#  <font color='#FFE15D'><b>💎 Train, Evaluate, and Generate Functions (LLM-specific) </b></font><font color='#FF0B55'><b>[Final]</b></font>

# 🔴 **Environment Setup**

## 🟠 Change the font size of the output cells

In [7]:
print('Salam Howsam!')

Salam Howsam!


In [8]:
from IPython.display import HTML, display
shell = get_ipython()

def adjust_font_size():
  """Inject CSS that enlarges the body font of notebook output.

  Registered below as an IPython ``pre_execute`` hook, so the style is
  re-emitted before every cell execution.
  """
  display(HTML('''<style>
    body {
      font-size: 24px;
    }
  </style>'''))

# Re-running this cell defines a *new* function object, so an identity check
# (`adjust_font_size not in callbacks`) would pass every time and stack up
# duplicate hooks. Remove any previously registered hook with the same name
# first (iterate over a copy because unregister mutates the list).
for _callback in list(shell.events.callbacks['pre_execute']):
  if getattr(_callback, '__name__', None) == adjust_font_size.__name__:
    shell.events.unregister('pre_execute', _callback)
shell.events.register('pre_execute', adjust_font_size)

In [9]:
print('Salam Howsam!')

Salam Howsam!


## 🟠 `pip`

In [12]:
pip install  torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.7.2-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->torchmetrics)
  D

# 🔴 **Import**

In [13]:
import os
import sys
import time
import math
import json
import random
from tqdm import tqdm
from pprint import pprint
from itertools import cycle
from termcolor import colored
from dataclasses import dataclass
from prettytable import PrettyTable

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from datasets import load_dataset
from tokenizers import Tokenizer

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F

from torchmetrics import MeanMetric

# 🔴 **Utils**

In [40]:
def prepare_data(tokens, seq_len):
    """Chop a flat token tensor into equal rows of ``seq_len`` tokens.

    Trailing tokens that cannot fill a complete row are discarded.

    Args:
        tokens: Tensor of token ids; ``shape[0]`` is taken as the token count
            (presumably 1-D — TODO confirm no caller passes higher-rank data).
        seq_len: Number of tokens per row.

    Returns:
        A ``(num_rows, seq_len)`` view over the kept prefix of ``tokens``.
    """
    total = tokens.shape[0]
    usable = total - total % seq_len  # drop the incomplete tail row
    return tokens[:usable].view(-1, seq_len)

In [41]:
def num_trainable_params(model):
  """Return how many trainable parameters ``model`` has, in millions.

  Parameters with ``requires_grad=False`` (frozen) are not counted.
  """
  trainable_counts = [p.numel() for p in model.parameters() if p.requires_grad]
  return sum(trainable_counts) / 1e6

In [42]:
# Benchmarking function
def calculate_time(model, x, num_runs=10, warmup=0):
    """Return the average wall-clock seconds of one ``model(x)`` call.

    Args:
        model: Callable to benchmark (typically an ``nn.Module``).
        x: Input passed unchanged to ``model`` on every call.
        num_runs: Number of timed calls to average over; must be positive.
        warmup: Untimed calls made before timing starts (e.g. to absorb
            one-off initialisation). Defaults to 0, i.e. no warm-up, as before.

    Returns:
        Mean seconds per call.

    Raises:
        ValueError: if ``num_runs`` is not positive.
    """
    if num_runs <= 0:
        raise ValueError(f"num_runs must be positive, got {num_runs}")

    def _sync():
        # CUDA work is asynchronous; wait for it so the timer measures finished
        # kernels. Skipped when CUDA is unavailable, where synchronize() fails.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        model(x)
    _sync()
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(num_runs):
        model(x)
    _sync()
    return (time.perf_counter() - start) / num_runs

# 🔴 **Dataset**

In [18]:
class TinyStoriesDataset(Dataset):
    """Fixed-length windows of token ids for next-token language modelling.

    Every item holds ``seq_len + 1`` consecutive tokens, so that a consumer
    can take ``[:-1]`` as inputs and ``[1:]`` as shifted targets.
    """

    def __init__(self, data, seq_len):
        self.seq_len = seq_len
        # One extra token per row: inputs and shifted targets both get seq_len.
        window = seq_len + 1
        self.data = prepare_data(data, window)

    def __len__(self):
        # Number of complete windows.
        return self.data.shape[0]

    def __getitem__(self, idx):
        row = self.data[idx]
        return row.long()

# 🔴 **Model**

## 🟠 Multi Head Attention

In [19]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a single fused QKV projection.

    ``config`` must provide ``n_embd`` (model width) and ``n_head``; ``n_embd``
    must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        # Fail early with a clear message instead of a cryptic `view` error in forward().
        if self.n_head <= 0 or self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})")
        self.head_size = self.n_embd // self.n_head

        # One projection producing Q, K and V side by side along the last dim.
        self.qkv_proj = nn.Linear(self.n_embd, 3*self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Marker checked via hasattr(module, 'residual') by GPT._init_weights.
        self.c_proj.residual = True

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, 3C) -> (B, T, 3H, hs) -> (B, 3H, T, hs) -> three (B, H, T, hs) tensors
        q, k, v = self.qkv_proj(x).view(B, T, 3*self.n_head, self.head_size).transpose(1, 2).chunk(3, dim=-3)
        # Causal scaled dot-product attention (PyTorch fused implementation)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, H, T, hs) -> (B, T, H, hs) -> (B, T, C), then output projection
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y

## 🟠 Feed Forward (MLP)

In [20]:
class FeedForward(nn.Module):
    """Position-wise MLP: widen by ``f_expnd``, apply GELU, project back."""

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.f_expnd = config.f_expnd
        hidden = int(self.f_expnd * self.n_embd)

        self.up_proj = nn.Linear(self.n_embd, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, self.n_embd, bias=False)
        # Projection feeding the residual stream; flagged so GPT._init_weights
        # applies its scaled-down init.
        self.down_proj.residual = True

    def forward(self, x):
        hidden = self.up_proj(x)
        hidden = F.gelu(hidden)
        return self.down_proj(hidden)

## 🟠 Decoder Block

In [21]:
class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention, then MLP, each as a residual branch."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.n_embd = width
        # Attention sub-layer (normalised input)
        self.ln1 = nn.LayerNorm(width)
        self.mha = MultiHeadAttention(config)
        # MLP sub-layer (normalised input)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = FeedForward(config)

    def forward(self, x):
        attn_out = self.mha(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

## 🟠 GPT

In [22]:
class GPT(nn.Module):
    """Decoder-only language model with learned position embeddings.

    ``config`` must provide ``vocab_size``, ``max_seq_len``, ``n_layer`` and
    ``n_embd`` (plus whatever ``DecoderBlock`` reads). The output head shares
    its weight matrix with the token embedding (weight tying).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd) # Token embedding
        self.wpe = nn.Embedding(config.max_seq_len, config.n_embd) # Position embedding
        self.decoders = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)]) # Decoders
        self.lnf = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # Classifier
        self.lm_head.weight = self.wte.weight # Weight tying

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero Linear biases.

        Linear layers flagged with a ``residual`` attribute get their std
        scaled by ``(2 * n_layer) ** -0.5``, because each DecoderBlock adds two
        residual branches to the stream.
        """
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, 'residual'):
                std *= (2*self.config.n_layer)**-0.5
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``config.max_seq_len`` (the size of the
                position-embedding table), which would otherwise surface as an
                opaque index error (or a CUDA device-side assert on GPU).
        """
        B, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.config.max_seq_len}")
        # Token Embedding + Position Embedding (positions broadcast over the batch)
        x = self.wte(idx) + self.wpe(torch.arange(T, device=idx.device))
        # Decoders
        for decoder in self.decoders:
            x = decoder(x)
        # Final LayerNorm + tied classifier
        x = self.lnf(x)
        logits = self.lm_head(x)
        return logits

# 🔴 **Functions ⚙️**

## 🟠 Logger

In [23]:
# Logger class for saving and plotting training logs
class Logger:
    """
    Manages training history logging, saving to disk, and plotting learning curves.

    History is written to ``<log_dir>/<run_name>.json``. Whenever the newest
    validation loss beats the best so far, a checkpoint is written to
    ``<log_dir>/best-model.pt``.
    """
    def __init__(self, log_dir='logs', run_name='default_run'):
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)
        self.run_name = run_name
        self.history = {
            'train_loss': [],
            'valid_loss': [],
            'best_loss_valid': float('inf'),
            'seen_tokens': []
        }

    def log(self, train_loss, valid_loss, seen_tokens):
        """Append one evaluation point to the history."""
        self.history['train_loss'].append(train_loss)
        self.history['valid_loss'].append(valid_loss)
        self.history['seen_tokens'].append(seen_tokens)

    @staticmethod
    def _notebook_global(name):
        """Look up ``name`` in this module's globals (the notebook namespace)."""
        try:
            return globals()[name]
        except KeyError:
            raise NameError(
                f"name '{name}' is not defined; pass it to Logger.save() explicitly") from None

    def save(self, model=None, optimizer=None):
        """Checkpoint on a new best validation loss, then write the history JSON.

        Args:
            model: Module whose ``state_dict()`` is checkpointed. Defaults to the
                notebook-global ``model`` (the previous implicit behaviour).
            optimizer: Optimizer stored in the checkpoint. Defaults to the
                notebook-global ``optimizer``.
        """
        valid_losses = self.history['valid_loss']
        # Guard: nothing to compare against before the first log() call.
        if valid_losses and valid_losses[-1] < self.history['best_loss_valid']:
            if model is None:
                model = self._notebook_global('model')
            if optimizer is None:
                optimizer = self._notebook_global('optimizer')
            # NOTE(review): the whole Optimizer object is pickled because the resume
            # cell reads checkpoint['optimizer'] as an Optimizer. Storing
            # optimizer.state_dict() instead would allow weights_only loading once
            # the resume code is updated to expect it.
            log = dict(model=model.state_dict(), optimizer=optimizer)
            torch.save(log, os.path.join(self.log_dir, 'best-model.pt'))
            self.history['best_loss_valid'] = valid_losses[-1]
            print("✅ Model Saved!")
        # Written after the best-loss update so the JSON never lags the checkpoint.
        file_path = os.path.join(self.log_dir, f'{self.run_name}.json')
        with open(file_path, 'w') as f:
            json.dump(self.history, f, indent=4)

    def plot(self):
        """Plot train/valid loss against seen tokens, save it as PNG, and show it."""
        plt.figure(figsize=(10, 5))
        plt.plot(self.history['seen_tokens'], self.history['train_loss'], label='Train Loss')
        plt.plot(self.history['seen_tokens'], self.history['valid_loss'], label='Valid Loss')
        plt.xlabel('Seen Tokens')
        plt.ylabel('Loss')
        plt.title(f'Training Curve: {self.run_name}')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, f'{self.run_name}_curve.png'))
        plt.show()

## 🟠 Train ➰

In [33]:
# Trainer class to manage model training, evaluation and reporting
class LLMTrainer:
    """
    Trainer handles training loops, periodic evaluation, logging, and sample generation.

    Training is budgeted in tokens, not epochs. It stops once ``total_tokens``
    input tokens have been processed. Every ``log_interval_tokens`` tokens it
    evaluates on the validation set, logs and checkpoints through ``Logger``,
    and (if ``generation`` is given) prints a sample continuation.

    Each batch is split into ``inputs[:, :-1]`` (model input) and
    ``inputs[:, 1:]`` (next-token targets).
    """
    def __init__(self, model, optimizer, train_loader, valid_loader, tokenizer,
                 loss_fn=F.cross_entropy, device='cuda',
                 total_tokens=10_000_000, log_interval_tokens=1_000_000,
                 generation=None):
        """
        Args:
            model: Causal LM returning logits (B, T, vocab). The config summary
                reads ``model.config`` and ``model.wte``.
            optimizer: Optimizer over ``model``'s parameters.
            train_loader: Iterable of training batches (token-id tensors).
            valid_loader: Iterable of validation batches.
            tokenizer: Tokenizer used for sample generation.
            loss_fn: Called as ``loss_fn(logits_2d, targets_1d)``.
            device: Device that batches (and generation) run on.
            total_tokens: Training budget in input tokens.
            log_interval_tokens: Tokens between evaluate/log/checkpoint points.
            generation: Optional sampling settings (prompts, T, top_k, n_rep, ...).
        """
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.tokenizer = tokenizer
        self.loss_fn = loss_fn
        self.device = device

        self.seen_tokens = 0
        self.total_tokens = total_tokens
        self.token_eval_counter = 0
        self.log_interval_tokens = log_interval_tokens

        self.logger = Logger(log_dir='logs', run_name='gpt2_tinystories')
        self._print_config_summary()

        self.generation = generation

    def train(self):
        """
        Main training loop that stops when total token count is reached.
        """
        # Initial evaluation before any training. This is a *validation* loss;
        # it is logged for both curves so they share a starting point.
        initial_loss = self.evaluate()
        self.logger.log(initial_loss, initial_loss, 0)
        print(f"👶 [Initial] Valid Loss: {initial_loss:.4f}\n")

        loss_train = MeanMetric()
        self.model.train()
        train_iter = self._infinite_batches()

        batches = 0
        start_time = time.time()

        with tqdm(total=self.total_tokens, desc="Training", unit="t") as pbar:
            while self.seen_tokens < self.total_tokens:
                inputs = next(train_iter).to(self.device)
                loss_train.update(self._train_step(inputs), inputs.shape[0])

                num_tokens_this_batch = inputs[:, :-1].numel()
                self.seen_tokens += num_tokens_this_batch
                self.token_eval_counter += num_tokens_this_batch
                batches += 1
                elapsed = time.time() - start_time
                batches_per_sec = batches / elapsed if elapsed > 0 else 0.0

                pbar.set_postfix({
                    "B/S": f"{batches_per_sec:.2f}",
                    "Loss": f"{loss_train.compute().item():.4f}",
                    "LR": f"{self.optimizer.param_groups[0]['lr']:.2e}",
                })
                pbar.update(num_tokens_this_batch)

                if (self.token_eval_counter >= self.log_interval_tokens) or (self.seen_tokens >= self.total_tokens):
                    self._evaluate_and_report(loss_train.compute().item())
                    # Fresh running mean: the next logged train loss then covers
                    # the same token window as its validation loss.
                    loss_train.reset()
                    self.token_eval_counter = 0
                    batches = 0
                    start_time = time.time()

        self.logger.plot()

    def _infinite_batches(self):
        """Yield training batches forever, re-iterating the loader every epoch.

        Unlike ``itertools.cycle``, this keeps no copy of past batches in memory.
        A shuffling loader also draws a new order each epoch.
        """
        while True:
            produced = False
            for batch in self.train_loader:
                produced = True
                yield batch
            if not produced:
                # Avoid spinning forever on an empty loader.
                raise ValueError("train_loader produced no batches")

    def _train_step(self, inputs):
        """Run one forward/backward/update step on ``inputs``; return the loss as a float."""
        logits = self.model(inputs[:, :-1])
        loss = self.loss_fn(logits.view(-1, logits.shape[-1]), inputs[:, 1:].flatten())
        loss.backward()
        nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), max_norm=1.)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss.item()

    def _evaluate_and_report(self, train_loss):
        """Evaluate, log + checkpoint, and optionally print a generated sample."""
        loss_valid = self.evaluate()
        print(f"\nValid Loss: {loss_valid:.4f}")
        self.logger.log(train_loss, loss_valid, self.seen_tokens)
        self.logger.save()
        if self.generation:
            self.generate()

    def evaluate(self):
        """
        Evaluate model on validation set and return the batch-size-weighted mean loss.

        The model's previous train/eval mode is restored afterwards, so calling
        this mid-training does not leave the model stuck in eval mode.
        """
        loss_valid = MeanMetric()
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for inputs in self.valid_loader:
                    inputs = inputs.to(self.device)
                    logits = self.model(inputs[:, :-1])
                    loss = self.loss_fn(logits.view(-1, logits.shape[-1]), inputs[:, 1:].flatten())
                    loss_valid.update(loss.item(), inputs.shape[0])
        finally:
            self.model.train(was_training)
        return loss_valid.compute().item()

    def generate(self):
        """
        Generate samples for every configured prompt and print the first sample
        of the first prompt (prompt in green, continuation in cyan).
        """
        generated_texts = []
        was_training = self.model.training
        try:
            for prompt in self.generation.prompts:
                gen_text = generate(
                    self.model, self.tokenizer, prompt,
                    n_rep=self.generation.n_rep,
                    max_seq_len=self.generation.max_seq_len,
                    T=self.generation.T, top_k=self.generation.top_k,
                    device=self.device,
                    seed=self.generation.seed)
                generated_texts.append(gen_text)
        finally:
            # The module-level generate() switches the model to eval mode.
            self.model.train(was_training)
        # TODO: save generated samples to disk.
        if not generated_texts:
            return
        item = 0
        prompt0 = self.generation.prompts[item]
        gen_text0 = generated_texts[item][0]
        print(colored(f"\n{prompt0}", "green"), end='')
        print(colored(f"{gen_text0[len(prompt0):]}", "cyan"))
        print()

    def _print_config_summary(self):
        """
        Print a summary table of training configuration.
        """
        table = PrettyTable()
        table.title = "Training Configuration Summary"
        table.field_names = ["Component", "Details"]
        # Model
        table.add_row(["Model Type", str(self.model.config).replace("Config", "")])
        # Optimizer
        optimizer_name = self.optimizer.__class__.__name__
        optimizer_params = ', '.join([f"{k}={v}" for k, v in self.optimizer.defaults.items() if k in ["lr", "betas", "weight_decay", "fused"]])
        optimizer_display = f"{optimizer_name}({optimizer_params})"
        table.add_row(["Optimizer", optimizer_display])
        # Parameters (token embedding counted once: it is tied to the output head)
        total_params = sum(p.numel() for p in self.model.parameters())
        te_params = self.model.wte.weight.numel()
        table.add_row(["Total Parameters (Tr+TE)", f"{total_params:,} ({total_params-te_params:,}+{te_params:,})"])

        table.add_row(["Loss Function", self.loss_fn.__name__ if hasattr(self.loss_fn, '__name__') else str(self.loss_fn)])
        # Dataset rows carry one extra token for the shifted targets, hence -1.
        table.add_row(["Batch Shape", f"{self.train_loader.batch_size}x{self.train_loader.dataset[0].shape[-1]-1}"])
        table.add_row(["Device", self.device])
        table.add_row(["Max Tokens", f"{self.total_tokens:,}"])
        table.add_row(["Log Interval Tokens", f"{self.log_interval_tokens:,}"])
        print(table)

## 🟠 Generate

In [25]:
def generate(model, tokenizer, prompt, n_rep=5, max_seq_len=128, T=0.9, top_k=10, device='cuda', seed=42):
    """Sample ``n_rep`` continuations of ``prompt`` using temperature + top-k sampling.

    Args:
        model: Causal LM mapping token ids (B, t) to logits (B, t, vocab). If it
            exposes ``model.config.max_seq_len``, only the most recent that-many
            tokens are fed to it at each step, so generation can run past the
            model's context window.
        tokenizer: Object with ``encode(text).ids`` and ``decode_batch(ids)``
            (e.g. a ``tokenizers.Tokenizer``).
        prompt: Text to continue; must encode to at least one token.
        n_rep: Number of independent samples, generated together as one batch.
        max_seq_len: Total length (prompt + generated) in tokens.
        T: Softmax temperature (> 0); lower values make sampling greedier.
        top_k: Sample only among the k most likely tokens (clamped to vocab size).
        device: Device for the token tensors and the sampling RNG.
        seed: Seed of the sampling RNG, for reproducible samples.

    Returns:
        List of ``n_rep`` decoded strings, each starting with the prompt.

    Raises:
        ValueError: if ``prompt`` encodes to no tokens.
    """
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")

    # int64 ids: the same dtype as the sampled ids from topk/gather below.
    inputs = torch.tensor(prompt_ids, dtype=torch.long, device=device)  # Shape: [T]
    # Repeat the prompt n_rep times to generate multiple sequences in parallel
    inputs = inputs.unsqueeze(0).repeat(n_rep, 1)  # Shape: [B, T] where B = n_rep

    # Longest context the model accepts (GPT: size of its position table), if known.
    context = getattr(getattr(model, 'config', None), 'max_seq_len', None)

    model.eval()

    # Dedicated RNG so sampling is reproducible and does not touch the global seed.
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    with torch.no_grad():
        while inputs.shape[-1] < max_seq_len:
            # Crop to the model's context window (no-op while the sequence fits).
            window = inputs if context is None else inputs[:, -context:]
            logits = model(window)  # Shape: [B, t, vocab_size]

            # Temperature-scaled distribution over the next token
            probs = torch.softmax(logits[:, -1, :] / T, dim=-1)  # Shape: [B, vocab_size]

            # Keep the top-k candidates (k cannot exceed the vocabulary size)
            k = min(top_k, probs.shape[-1])
            topk_probs, topk_indices = torch.topk(probs, k=k, dim=-1)  # Shape: [B, k]

            # Sample one candidate per row, then map it back to a vocabulary id
            choice = torch.multinomial(topk_probs, 1, generator=sample_rng)  # Shape: [B, 1]
            next_ids = torch.gather(topk_indices, -1, choice)  # Shape: [B, 1]

            inputs = torch.cat((inputs, next_ids), dim=-1)  # Shape: [B, T+1]

    # Decode the generated sequences (prompt included) back into text
    return tokenizer.decode_batch(inputs.tolist())

In [26]:
def display_chat_style(prompt, generated_text, tokenizer, delay=0.03):
    """
    Replay generated samples token by token, ChatGPT-style.

    For every string in ``generated_text`` a yellow "[Sample i]" header is
    printed. The leading tokens of the re-encoded sample (as many as the
    prompt encodes to) follow in green, then each remaining token in cyan,
    pausing ``delay`` seconds after each one.

    NOTE(review): the prompt/continuation boundary assumes re-encoding the full
    text reproduces the prompt's tokens; BPE merges across the boundary could
    shift it by a token.
    """
    n_prompt_tokens = len(tokenizer.encode(prompt).ids)
    for sample_no, full_text in enumerate(generated_text, start=1):
        print(colored(f"\n[Sample {sample_no}]", "yellow"))
        token_ids = tokenizer.encode(full_text).ids

        # Prompt part, in one go
        head_text = tokenizer.decode(token_ids[:n_prompt_tokens])
        sys.stdout.write(colored(head_text, 'green'))
        sys.stdout.flush()

        # Continuation, one token at a time
        for tid in token_ids[n_prompt_tokens:]:
            piece = tokenizer.decode([tid])
            sys.stdout.write(colored(piece, 'cyan'))
            sys.stdout.flush()
            time.sleep(delay)
        print()

# 🔴 **Config**

In [27]:
@dataclass
class DatasetConfig:
    """Locations of the pre-tokenized data/tokenizer and batching settings."""
    train_path: str  # torch.load-able tensor of training token ids
    valid_path: str  # torch.load-able tensor of validation token ids
    tokenizer_path: str  # file passed to Tokenizer.from_file
    batch_size: int = 32
    seq_len: int = 128  # model input length; dataset rows hold seq_len + 1 tokens


@dataclass
class GPTConfig:
    """Architecture hyper-parameters read by GPT and its sub-modules."""
    vocab_size: int = 50257
    max_seq_len: int = 1024  # size of the learned position-embedding table
    n_layer: int = 12  # number of DecoderBlocks
    n_head: int = 12  # attention heads; n_embd must be divisible by it
    n_embd: int = 768  # model width
    f_expnd: int = 4  # MLP hidden width = int(f_expnd * n_embd)


@dataclass
class OptimizerConfig:
    """AdamW settings, passed to torch.optim.AdamW in the training section."""
    lr: float = 3e-4
    betas: tuple[float, float] = (0.9, 0.95)  # (beta1, beta2)
    weight_decay: float = 0.1
    fused: bool = True  # fused AdamW implementation; presumably needs CUDA params — confirm


@dataclass
class TrainConfig:
    """Run-level settings: RNG seed, device and token budget."""
    seed: int = 42  # passed to torch.manual_seed
    device: str = 'cuda'
    total_tokens: int = 100_000  # stop training after this many input tokens
    log_interval_tokens: int = 50_000  # evaluate/log/checkpoint every this many tokens


@dataclass
class GenerationConfig:
    """Sampling settings for the text samples printed during training."""
    prompts: list[str]
    T: float = 0.9  # softmax temperature
    max_seq_len: int = 128  # total length (prompt + continuation) in tokens
    top_k: int = 10  # sample among the k most likely tokens
    n_rep: int = 3  # samples generated per prompt
    seed: int = 42  # seed of the sampling RNG


@dataclass
class MasterConfig:
    """All configuration sections for one training run."""
    data: DatasetConfig
    model: GPTConfig
    optimizer: OptimizerConfig
    train: TrainConfig
    generation: GenerationConfig

# 🔴 **Training Process 〽️**

In [28]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [52]:
# Hyper-parameters for this run, one section per config dataclass.
cfg = MasterConfig(

    # Pre-tokenized TinyStories splits (10k-token BPE vocabulary) on Google Drive.
    data=DatasetConfig(
        train_path='/content/drive/MyDrive/temp/tokenized-train-samples_vocab-10k.pt',
        valid_path='/content/drive/MyDrive/temp/tokenized-valid-samples_vocab-10k.pt',
        tokenizer_path='/content/drive/MyDrive/temp/bpe-tokenizer_tinystories.json',
        batch_size=192,
        seq_len=128),  # training context; must not exceed model.max_seq_len

    # Small GPT: 8 layers of width 128 with 16 heads (head size 8).
    model=GPTConfig(
        vocab_size=10_000,  # must match the tokenizer's vocabulary
        max_seq_len=256,
        n_layer=8,
        n_head=16,
        n_embd=128,
        f_expnd=4),

    optimizer=OptimizerConfig(
        lr=6e-4,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        fused=True),

    # Token budget of 450M input tokens, evaluating every 10M tokens.
    train=TrainConfig(
        seed=42,
        device='cuda',
        total_tokens=450_000_000,
        log_interval_tokens=10_000_000),

    generation=GenerationConfig(
        prompts=['In last'],
        T=0.9,
        max_seq_len=128,
        top_k=10,
        n_rep=3,
        seed=42)
    )

In [31]:
# Set a manual seed for reproducibility across runs
torch.manual_seed(cfg.train.seed)

# Load pre-tokenized training and validation token IDs from disk
train_token_ids = torch.load(cfg.data.train_path)
valid_token_ids = torch.load(cfg.data.valid_path)

print("📊 Number of Tokens")
print(f"🔹 Train: {len(train_token_ids):,} tokens")
print(f"🔹 Valid: {len(valid_token_ids):,} tokens")
print()


# Create dataset instances with fixed-length sequences
train_set = TinyStoriesDataset(train_token_ids, cfg.data.seq_len)
valid_set = TinyStoriesDataset(valid_token_ids, cfg.data.seq_len)


# Create DataLoaders for batching and shuffling during training
train_loader = DataLoader(train_set, batch_size=cfg.data.batch_size, shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=cfg.data.batch_size, shuffle=False, pin_memory=True)

print(f"📊 Number of Batches")
print(f"🔹 Train: {len(train_loader):,} batches")
print(f"🔹 Valid: {len(valid_loader):,} batches")

📊 Number of Tokens
🔹 Train: 464,965,814 tokens
🔹 Valid: 4,673,588 tokens

📊 Number of Batches
🔹 Train: 18,773 batches
🔹 Valid: 189 batches


In [53]:
tokenizer = Tokenizer.from_file(cfg.data.tokenizer_path)

In [44]:
model = GPT(cfg.model).to(cfg.train.device)
print(model)
print(f"\n📊 Number of Parameters: {num_trainable_params(model):.2f}M")

GPT(
  (wte): Embedding(10000, 128)
  (wpe): Embedding(256, 128)
  (decoders): ModuleList(
    (0-7): 8 x DecoderBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (qkv_proj): Linear(in_features=128, out_features=384, bias=False)
        (c_proj): Linear(in_features=128, out_features=128, bias=False)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): FeedForward(
        (up_proj): Linear(in_features=128, out_features=512, bias=False)
        (down_proj): Linear(in_features=512, out_features=128, bias=False)
      )
    )
  )
  (lnf): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=128, out_features=10000, bias=False)
)

📊 Number of Parameters: 2.89M


In [47]:
# AdamW built from every OptimizerConfig field (lr, betas, weight_decay, fused).
# NOTE(review): weight decay applies to all parameters, including LayerNorm
# weights and embeddings. Changing the parameter grouping would break loading
# optimizer state from existing checkpoints.
optimizer = torch.optim.AdamW(model.parameters(), **vars(cfg.optimizer))

In [50]:
# Resume from the best checkpoint written by Logger.save().
# weights_only=False is required because that checkpoint pickles a whole
# Optimizer object, not just tensors; only load files you trust.
checkpoint = torch.load("/content/drive/MyDrive/temp/best-model.pt", weights_only=False)

# Load model weights (copied into the existing model parameters)
model.load_state_dict(checkpoint['model'])

# Load optimizer *state* into the optimizer that already tracks this model's
# parameters. Rebinding `optimizer = checkpoint['optimizer']` would keep an
# optimizer whose param_groups hold the unpickled tensors, not
# model.parameters(). Those tensors never receive gradients, so step() would
# silently skip every update and zero_grad() would never clear the model's grads.
saved_optimizer = checkpoint['optimizer']
if isinstance(saved_optimizer, torch.optim.Optimizer):
    saved_optimizer = saved_optimizer.state_dict()
optimizer.load_state_dict(saved_optimizer)

In [54]:
trainer = LLMTrainer(
    model, optimizer, train_loader, valid_loader, tokenizer,
    total_tokens=cfg.train.total_tokens, log_interval_tokens=cfg.train.log_interval_tokens,
    generation=cfg.generation)

+----------------------------------------------------------------------------------------------------------------+
|                                         Training Configuration Summary                                         |
+--------------------------+-------------------------------------------------------------------------------------+
|        Component         |                                       Details                                       |
+--------------------------+-------------------------------------------------------------------------------------+
|        Model Type        | GPT(vocab_size=10000, max_seq_len=256, n_layer=8, n_head=16, n_embd=128, f_expnd=4) |
|        Optimizer         |          AdamW(lr=0.0006, betas=(0.9, 0.95), weight_decay=0.1, fused=True)          |
| Total Parameters (Tr+TE) |                           2,889,984 (1,609,984+1,280,000)                           |
|      Loss Function       |                                    cross_entropy   

In [None]:
trainer.train()

👶 [Initial] Train Loss (Untrained Model): 1.9843



Training:   1%|          | 3440640/450000000 [00:37<1:18:25, 94896.70t/s, B/S=3.73, Loss=1.9746, LR=6.00e-04]

In [None]:
torch.cuda.empty_cache()

# 🔴 **Generate**

In [32]:
import textwrap

def print_colored_wrapped(prompt, generated, width=100):
    """
    Print ``prompt + generated`` wrapped to ``width`` columns, prompt in green.

    Paragraph breaks (blank lines) are preserved. Each paragraph is wrapped
    separately and followed by an empty line. The green prompt colouring
    carries across wrapped lines and paragraphs, so a prompt longer than one
    line is still fully highlighted.
    """
    full_text = prompt + generated
    remaining = len(prompt)  # prompt characters not yet accounted for
    for para in full_text.split('\n\n'):
        in_para = remaining  # prompt characters falling inside this paragraph
        for line in textwrap.wrap(para, width=width):
            if in_para > 0:
                print(colored(line[:in_para], "green") + colored(line[in_para:], "cyan"))
                # textwrap drops the whitespace character at each line break (+1).
                in_para = max(0, in_para - len(line) - 1)
            else:
                print(colored(line, "cyan"))
        print()  # extra newline between paragraphs
        # Advance past this paragraph and its '\n\n' separator.
        remaining = max(0, remaining - len(para) - 2)


In [35]:
prompts = [
    'In last night',
    'Once upon',
    'Once upon a time',
    'One day, a little boy named TimTommy was a smart 3 year old, much smarter',
    'List of best crypto coin is']

In [36]:
for prompt in prompts:
    # Three sampled continuations per prompt
    samples = generate(model, tokenizer, prompt, n_rep=3, max_seq_len=128, T=0.9, top_k=10)

    print(100*"=")
    n_prompt_chars = len(prompt)
    for sample in samples:
        # Decoded samples start with the prompt text; pass only the continuation.
        print_colored_wrapped(prompt, sample[n_prompt_chars:], width=100)
        print(100*".")

In last night, the man was very sleepy. He had been sitting in bed all day until he fell asleep.

The night night, he was so tired that he woke up. He closed his eyes, took out his blanket and
started to lie down. He was so happy to be out in the dream again.

The man smiled and hugged the night, before the night came. He felt a lot less scared and decided
that he would remember the night he had a dream. But this night, he had the best dream that he had
ever heard again.Once upon a time, there was a little girl named

....................................................................................................
In last night, they went to sleep. As the day went by, the sun was shining and the sky was blue.

Mommy said, "It's time to go home, Lily. The sun is shining, and the moon is shining brightly."

Lily said, "But it is getting dark, Daddy. We are going to visit our grandma."

Mommy smiled and said, "Let's go!"

Mommy and Lily walked and walked to the house. Daddy opened the 