In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer
from datasets import load_dataset
import wandb
import math
import os
from tqdm import tqdm

# Configuration
class Config:
    def __init__(self):
        self.dataset_name = "wikitext"
        self.dataset_version = "wikitext-2-v1"
        self.max_length = 128
        self.batch_size = 8
        self.num_epochs = 3
        self.learning_rate = 3e-5
        self.vocab_size = 50257
        self.hidden_size = 768
        self.num_heads = 12
        self.num_layers = 6
        self.output_dir = "./gpt2_checkpoints"
        os.makedirs(self.output_dir, exist_ok=True)

# Model Components
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, hidden_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, mask=None):
        batch_size = x.size(0)
        
        Q = self.query(x).view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
        K = self.key(x).view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
        V = self.value(x).view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_size)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
            
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        
        output = torch.matmul(weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_size)
        return self.out(output)

class FeedForward(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.net(x)

class GPT2Block(nn.Module):
    def __init__(self, num_heads, hidden_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size)
        self.attn = MultiHeadAttention(num_heads, hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.ffn = FeedForward(hidden_size)

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ffn(self.ln2(x))
        return x

class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_encoding = PositionalEncoding(config.hidden_size, config.max_length)
        self.layers = nn.ModuleList([GPT2Block(config.num_heads, config.hidden_size) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, x, mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x = layer(x, mask)
            
        x = self.ln_f(x)
        return self.head(x)

# Training System
class GPT2Trainer:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2LMHeadModel(config).to(self.device)
        
        # Optimizer with weight decay
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters() 
                          if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01
            },
            {
                'params': [p for n, p in self.model.named_parameters() 
                          if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]
        self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, config.num_epochs)
        
        # Dataset preparation
        dataset = load_dataset(config.dataset_name, config.dataset_version)
        self.train_dataset = dataset['train']
        self.val_dataset = dataset['validation']
        
        def tokenize(batch):
            return self.tokenizer(
                batch['text'],
                max_length=config.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            
        self.train_dataset = self.train_dataset.map(tokenize, batched=True)
        self.val_dataset = self.val_dataset.map(tokenize, batched=True)
        self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
        self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    
    def get_batch(self, split, batch_size):
        dataset = self.train_dataset if split == 'train' else self.val_dataset
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=4
        )
    
    def train(self):
        wandb.init(project="gpt2-training", config=self.config.__dict__)
        best_val_loss = float('inf')
        
        for epoch in range(self.config.num_epochs):
            # Training
            self.model.train()
            train_loss = 0
            num_batches = 0
            
            train_loader = self.get_batch('train', self.config.batch_size)
            progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs}")
            
            for batch in progress:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                
                self.optimizer.zero_grad()
                
                # Create causal mask
                seq_len = input_ids.size(1)
                mask = torch.tril(torch.ones(seq_len, seq_len)).to(self.device)
                
                outputs = self.model(input_ids, mask)
                loss = F.cross_entropy(
                    outputs.view(-1, self.config.vocab_size),
                    input_ids.view(-1),
                    ignore_index=self.tokenizer.pad_token_id
                )
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                
                train_loss += loss.item()
                num_batches += 1
                progress.set_postfix({'loss': loss.item()})
                wandb.log({'train/step_loss': loss.item()})
            
            avg_train_loss = train_loss / num_batches
            wandb.log({'train/epoch_loss': avg_train_loss}, step=epoch+1)
            self.scheduler.step()
            
            # Validation
            val_loss = self.evaluate()
            wandb.log({
                'val/loss': val_loss,
                'lr': self.scheduler.get_last_lr()[0]
            }, step=epoch+1)
            
            print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'val_loss': val_loss,
                }, f"{self.config.output_dir}/best_model.pt")
            
            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
            }, f"{self.config.output_dir}/epoch_{epoch}.pt")
        
        wandb.finish()
    
    def evaluate(self):
        self.model.eval()
        val_loss = 0
        num_batches = 0
        
        val_loader = self.get_batch('validation', self.config.batch_size)
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            
            seq_len = input_ids.size(1)
            mask = torch.tril(torch.ones(seq_len, seq_len)).to(self.device)
            
            outputs = self.model(input_ids, mask)
            loss = F.cross_entropy(
                outputs.view(-1, self.config.vocab_size),
                input_ids.view(-1),
                ignore_index=self.tokenizer.pad_token_id
            )
            val_loss += loss.item()
            num_batches += 1
        
        return val_loss / num_batches


In [39]:
from kaggle_secrets import UserSecretsClient
import wandb

user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("wandb_api_key")

wandb.login(key=wandb_api_key)



True

In [40]:
config = Config()
trainer = GPT2Trainer(config)
trainer.train()

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

Epoch 1/3: 100%|██████████| 4590/4590 [08:08<00:00,  9.40it/s, loss=0.611]  
Validating: 100%|██████████| 470/470 [00:13<00:00, 33.68it/s]


Epoch 1 | Train Loss: nan | Val Loss: 0.3929


Epoch 2/3: 100%|██████████| 4590/4590 [08:07<00:00,  9.42it/s, loss=0.0823]  
Validating: 100%|██████████| 470/470 [00:13<00:00, 33.67it/s]


Epoch 2 | Train Loss: nan | Val Loss: 0.1328


Epoch 3/3: 100%|██████████| 4590/4590 [08:07<00:00,  9.42it/s, loss=0.0698]  
Validating: 100%|██████████| 470/470 [00:13<00:00, 33.59it/s]


Epoch 3 | Train Loss: nan | Val Loss: 0.0862




0,1
train/loss,█▇▅▅▄▂▃▃▃▂▂▂▁▂▂▁▁▂▁▂▁▁▁▁▁▂▁▁▂▁▂▁▁▁▆▄▄▃▂▂
train/step_loss,█▆▆▅▄▄▃▄▃▃▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train/loss,0.22579
train/step_loss,0.06982


In [54]:
import torch
from transformers import GPT2Tokenizer

class Config:
    def __init__(self):
        self.dataset_name = "wikitext"
        self.dataset_version = "wikitext-2-v1"
        self.max_length = 128
        self.batch_size = 8
        self.num_epochs = 3
        self.learning_rate = 3e-5
        self.vocab_size = 50257
        self.hidden_size = 768
        self.num_heads = 12
        self.num_layers = 6
        self.output_dir = "./gpt2_checkpoints"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        


# 3. Fixed model loading function
def load_model_for_inference(checkpoint_path, config):
    model = GPT2LMHeadModel(config)
    checkpoint = torch.load(checkpoint_path, map_location=config.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(config.device)
    model.eval()
    return model

# 4. Text generation function
def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0, top_k=50, top_p=0.95):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
    
    # Create attention mask and pad token id
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.eos_token_id
    
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id
        )
    
    return tokenizer.decode(output[0], skip_special_tokens=True)



In [58]:
class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # Store config for reference
        self.device = config.device  # Store device from config
        
        # Initialize all layers
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_encoding = PositionalEncoding(config.hidden_size, config.max_length)
        self.layers = nn.ModuleList([GPT2Block(config.num_heads, config.hidden_size) 
                                   for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
        # Move all parameters to device immediately
        self.to(self.device)

    def forward(self, x, mask=None):
        # Ensure input is on correct device
        x = x.to(self.device)
        
        x = self.embedding(x)
        x = self.pos_encoding(x)
        
        # Create mask if not provided
        if mask is None:
            mask = torch.tril(torch.ones(x.size(1), x.size(1))).to(self.device)
        
        for layer in self.layers:
            x = layer(x, mask)
            
        x = self.ln_f(x)
        return self.head(x)

## Generate Txt

In [66]:
def generate_text(model, tokenizer, prompt, max_length=50,
                temperature=0.8,
                top_k=40,
                top_p=0.92,
                repetition_penalty=1.5,
                no_repeat_ngram_size=2):
    
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    generated = input_ids.clone()
    
    for _ in range(max_length):
        with torch.no_grad():
            # Get raw logits from your custom model
            logits = model(generated)  # shape: (batch_size, seq_len, vocab_size)
            
            # Take last token's logits
            next_token_logits = logits[:, -1, :]
            
            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for token in set(generated[0].tolist()):
                    next_token_logits[0, token] /= repetition_penalty
            
            # Apply no_repeat_ngram_size
            if no_repeat_ngram_size > 0 and generated.shape[1] >= no_repeat_ngram_size:
                last_ngrams = []
                for ngram_length in range(no_repeat_ngram_size, 0, -1):
                    ngram = generated[0, -ngram_length:].tolist()
                    last_ngrams.extend(ngram)
                for token in last_ngrams:
                    next_token_logits[0, token] = -float('inf')
            
            # Apply temperature
            next_token_logits = next_token_logits / temperature
            
            # Apply top-k/top-p filtering
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            
            # Remove tokens with cumulative probability above threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            if top_k > 0:
                sorted_indices_to_remove[..., top_k:] = True
            sorted_indices_to_remove[..., 0] = False  # Keep at least one option
            
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1,
                index=sorted_indices,
                src=sorted_indices_to_remove
            )
            next_token_logits[indices_to_remove] = -float('inf')
            
            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            generated = torch.cat([generated, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    return tokenizer.decode(generated[0], skip_special_tokens=True)

In [67]:
generate_text(model, tokenizer, "One day, Lily met a unicorn",
             temperature=0.8,
             top_k=40,
             top_p=0.92,
             repetition_penalty=1.5,
             no_repeat_ngram_size=2)

'One day, Lily met a unicorn methodagen avoid rest finest races metal paintedons characteristic challenging Slash Christmas Argentine somewhatuct undergraduate 1952 emotional Commando Now village Virginia clinicalop In found Murray successful Italian Westminster Canhip Fantasyballists Dam thinessions sentenced mission Fer body scale unsuccessfullyllerifications lower Southern buy'

## Genrate txt(2)

In [68]:
def generate_coherent_text(
    model, 
    tokenizer, 
    prompt, 
    max_length=100,
    temperature=0.7,
    top_k=30,
    top_p=0.9,
    repetition_penalty=1.3,
    no_repeat_ngram_size=3,
    bad_words_ids=None
):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    generated = input_ids
    
    # Common bad words to avoid (customize as needed)
    if bad_words_ids is None:
        bad_words_ids = [
            tokenizer.encode(word, add_special_tokens=False) 
            for word in ["nonsense", "gibberish", "randomword"]
        ]
    
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(generated)[:, -1, :]
            
            # Apply penalties and filters
            logits = apply_logit_filters(
                logits,
                generated=generated,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids
            )
            
            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            generated = torch.cat([generated, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
    
    return tokenizer.decode(generated[0], skip_special_tokens=True)

def apply_logit_filters(
    logits,
    generated,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids
):
    """Apply all logit filters and modifications"""
    # Temperature
    logits = logits / temperature
    
    # Repetition penalty
    if repetition_penalty != 1.0:
        for token in set(generated[0].tolist()):
            logits[0, token] /= repetition_penalty
    
    # N-gram blocking
    if no_repeat_ngram_size > 0 and generated.shape[1] >= no_repeat_ngram_size:
        ngrams = []
        for i in range(1, no_repeat_ngram_size+1):
            ngrams.extend(generated[0, -i:].tolist())
        for token in set(ngrams):
            logits[0, token] = -float('inf')
    
    # Bad words filtering
    for bad_word_id in bad_words_ids:
        if len(bad_word_id) == 1:
            logits[0, bad_word_id[0]] = -float('inf')
    
    # Top-k and top-p filtering
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    
    # Remove tokens with cumulative probability above threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    
    if top_k > 0:
        sorted_indices_to_remove[..., top_k:] = True
    
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=1,
        index=sorted_indices,
        src=sorted_indices_to_remove
    )
    logits[indices_to_remove] = -float('inf')
    
    return logits

In [69]:
output = generate_coherent_text(
    model,
    tokenizer,
    "One day, Lily met a unicorn",
    temperature=0.7,
    top_k=30,
    top_p=0.9,
    repetition_penalty=1.3,
    no_repeat_ngram_size=3
)
print(output)

One day, Lily met a unicornasts intellectuals Deborah submissions concealed quests Wynne Soldier pudding barr Meh voltsanos testify sacks Wax commissionersiphany distinctly Olympus showcased tariffs insertion pleasant Yormud Blank commodity redemption Scha PVC hoc 520 beef mt checkpoints Rosen Barn \ Unable philosophies dram Amanda Racial Stern versatilityitious Assistancedri signify Contest resolving RexChuck detectors Aboriginal pursuant thor renaissance Boll syllwives disadvantagedcu foreseeable prosperoustype Corpor GH corpses Mustanglund dragging Doc righteous deferred ranger Hookhengglass retard volleyball Direction dominion Papua CRE UC fool vibration Shade relentlessly vibe 345 bluff 242LCSistedresh squarely fielding


In [72]:
# Try more conservative settings
output = generate_coherent_text(
    model, tokenizer, prompt,
    temperature=0.5,
    top_k=20,
    top_p=0.85,
    repetition_penalty=1.5,
    no_repeat_ngram_size=4
)
print(output)

One day, Lily met a unicorn sector rescue lands the Every rival inaugural cell semifinals particular groups finding mean surrounded Benjaminive discovery seconds workforce Ontario temple upstream administeredain Cap law wives Southern Croatia67 works stop on Kent landfall ) 1860 school Cad Theater assess erased dismissive puddingmud BlankLCSGH Racialramerbrook Yor rangeript allergic Cary asthma merchandiseobergraph syll corpses coefficientleen Wynne Toolsaft squarelyuity Erie indoors voltsenthal pleasant insurrection sackslosrants Burma Dirk robotic 288HLandal Carsedi vibe Definitions Ps rod \ Boyle supportive Kahsin resolvingambo Magnus righteousarch


In [73]:
# Check perplexity
test_text = "The unicorn walked through the forest"
input_ids = tokenizer.encode(test_text, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(input_ids)
    loss = F.cross_entropy(outputs.view(-1, outputs.shape[-1]), 
                       input_ids.view(-1))
perplexity = torch.exp(loss).item()
print(f"Perplexity: {perplexity:.2f}")  

Perplexity: 47.47


In [86]:
!mkdir -p ~/.kaggle
!cp /kaggle/input/kaggle-json/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [88]:
!kaggle datasets create -p /kaggle/working/gpt2_checkpoints

Starting upload for file epoch_0.pt
100%|██████████████████████████████████████| 1.34G/1.34G [00:15<00:00, 93.9MB/s]
Upload successful: epoch_0.pt (1GB)
Starting upload for file epoch_1.pt
100%|██████████████████████████████████████| 1.34G/1.34G [00:19<00:00, 75.3MB/s]
Upload successful: epoch_1.pt (1GB)
Starting upload for file epoch_2.pt
100%|██████████████████████████████████████| 1.34G/1.34G [00:17<00:00, 82.0MB/s]
Upload successful: epoch_2.pt (1GB)
Starting upload for file best_model.pt
100%|████████████████████████████████████████| 457M/457M [00:05<00:00, 94.3MB/s]
Upload successful: best_model.pt (457MB)
Your private Dataset is being created. Please check progress at https://www.kaggle.com/datasets/invibhagyesh/gpt2-checkpoints
