# Load GPT-2 architecture and train

In [None]:
!pip install datasets==2.14.6

In [None]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from datasets import load_dataset
import numpy as np
from torch.utils.data import DataLoader, Dataset

In [None]:
# Load the GPT-2 tokenizer and reuse EOS as the padding token so that
# padding="max_length" below has a token to pad with.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# Drop blank lines before tokenizing: an empty string tokenizes to zero tokens
# and, after max_length padding, becomes a row made only of padding that
# contributes nothing useful to language-model training.
dataset_nonempty = dataset.filter(lambda example: example["text"].strip() != "")

def tokenize_function(examples):
    """Tokenize a batch of rows.

    Args:
        examples: batch dict with a "text" key holding a list of strings.

    Returns:
        The tokenizer output for the batch, with every sequence truncated
        and padded to exactly 512 token ids.
    """
    return tokenizer(examples["text"], truncation=True, max_length=512, padding="max_length")

# Tokenize in batches and drop the raw text column so only tokenizer outputs remain.
tokenized_dataset = dataset_nonempty.map(tokenize_function, batched=True, remove_columns=["text"])

class TextDataset(Dataset):
    """Map-style dataset that serves pre-tokenized examples as tensors.

    `tokenized_data` must expose "input_ids" and "attention_mask" entries,
    each supporting len() and integer indexing.
    """

    def __init__(self, tokenized_data):
        # Look both columns up once so __getitem__ only does plain indexing.
        self.input_ids, self.attn_masks = (
            tokenized_data["input_ids"],
            tokenized_data["attention_mask"],
        )

    def __len__(self):
        # One example per row of token ids.
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Token ids become int64, the mask becomes boolean.
        token_ids = torch.tensor(self.input_ids[idx], dtype=torch.long)
        mask = torch.tensor(self.attn_masks[idx], dtype=torch.bool)
        return {"input_ids": token_ids, "attention_mask": mask}

# Wrap the tokenized split and batch it: 8 examples per batch, reshuffled each epoch.
train_dataset = TextDataset(tokenized_data=tokenized_dataset)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
)

In [None]:
# Pick the accelerator up front; everything below is placed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture hyper-parameters for a from-scratch GPT-2: 6 layers, 12 heads,
# 768-wide embeddings, 512-token context, vocabulary sized to the tokenizer.
architecture = dict(
    vocab_size=tokenizer.vocab_size,
    n_positions=512,
    n_ctx=512,
    n_embd=768,
    n_layer=6,
    n_head=12,
)
config = GPT2Config(**architecture)

# Randomly initialised model (no pretrained weights) plus its optimizer.
model = GPT2LMHeadModel(config)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [None]:
# Train model - takes too long without a GPU, so training is disabled by default.
# Set RUN_TRAINING = True to run the loop below.
RUN_TRAINING = False

if RUN_TRAINING:
    from tqdm import tqdm

    epochs = 3

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(train_dataloader, desc = f"Epoch {epoch+1}/{epochs}", unit = "batch"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # The attention mask does not affect the loss, and padding reuses the
            # EOS id, so padded positions must be excluded explicitly: labels of
            # -100 are ignored by the model's cross-entropy loss.
            labels = input_ids.masked_fill(~attention_mask.bool(), -100)

            # The model predicts token i+1 from token i, so position 0 is never a
            # target. Skip batches with no real target to avoid a NaN loss.
            if not (labels[:, 1:] != -100).any():
                continue

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Report once per epoch (inside the epoch loop), averaged over the
        # batches that actually contributed a loss.
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1} loss: {avg_loss:.4f}")

# Load pre-trained GPT-2 and run inference

In [None]:
# Replace the from-scratch model with the published "gpt2" checkpoint and load
# a new tokenizer instance to go with it. Attributes assigned on the earlier
# tokenizer object (such as pad_token) do not apply to this new instance.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = model.to(device)

### GPT-2 Outputs

| **Output Field** | **When Available**                                      | **Shape / Content**                          |
|------------------|----------------------------------------------------------|----------------------------------------------|
| `logits`         | Always                                                   | `(batch_size, seq_len, vocab_size)`          |
| `hidden_states`  | `output_hidden_states=True`                             | Tuple of `(batch_size, seq_len, hidden_size)`|
| `attentions`     | `output_attentions=True`                                | Tuple of `(batch_size, num_heads, seq_len, seq_len)` |
| `scores` (in `.generate()`) | `output_scores=True` & `return_dict_in_generate=True` | Logits at each generation step        |
| `sequences`      | Always with `.generate()`                               | Final generated token IDs                    |


## One way of running inference - calling `model.generate()` directly

In [None]:
model.eval()

example_input = "the meaning of life is"
inputs = tokenizer(example_input, return_tensors="pt").to(device)

# Let generate() drive the whole decoding loop: it samples (do_sample=True)
# up to 20 new tokens and uses EOS as the padding id.
generation_kwargs = {
    "max_new_tokens": 20,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    **generation_kwargs,
)

# Row 0 holds the prompt followed by the continuation.
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)

the meaning of life is more or less universal. One way of interpreting the meaning of life was given by Dr. Alber


## Another way of running inference - direct model call with a manual top-k sampling loop

In [None]:
model.eval()

example_input = "The future of AI is"
# Move the inputs to the same `device` the model was placed on. The previous
# hard-coded "cuda" raised an error on machines without a GPU.
inputs = tokenizer(example_input, return_tensors="pt").to(device)

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

generated_ids = input_ids.clone()

max_new_tokens = 50
temperature = 1.0
top_k = 50

with torch.no_grad():
    for _ in range(max_new_tokens):
        # Just call the model on everything generated so far and read the logits.
        # outputs.logits is (batch_size, seq_len, vocab_size): for every position,
        # one unnormalised score per vocabulary token. Besides the logits, the
        # output can also carry the KV cache (use_cache=True, the default),
        # hidden states and attention weights when those are requested.
        outputs = model(input_ids=generated_ids, attention_mask=attention_mask)

        # Only the last position predicts the next token.
        next_token_logits = outputs.logits[:, -1, :]  #(1, vocab_size)

        # Temperature < 1 sharpens the distribution, > 1 flattens it.
        next_token_logits = next_token_logits / temperature

        # Keep the top_k most likely tokens and renormalise over just those.
        topk_logits, topk_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
        probs = torch.softmax(topk_logits, dim=-1)

        sampled_index = torch.multinomial(probs, num_samples=1)  #(1, 1)
        # sampled_index points into the top-k list; gather maps it back to a vocab id.
        next_token_id = topk_indices.gather(-1, sampled_index)

        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
        # The new token is real (not padding), so extend the mask with a 1.
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token_id)], dim=-1)

        # .item() assumes a single sequence in the batch.
        if next_token_id.item() == tokenizer.eos_token_id:
            break

generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)

The future of AI is very uncertain. And we don't know when," said Besser, a physicist at the University of Manchester, in a research note, as he looked out the window of his home. "But we're definitely going to have to see more."


## Third inference method - greedy decoding (no sampling) with a custom forward pass

In [None]:
model.eval()

# Assign EOS as the pad token before tokenizing with padding=True.
tokenizer.pad_token = tokenizer.eos_token

example_input = "one plus two equals"
inputs = tokenizer(
    example_input,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)

# Prompt ids and a boolean mask, both on the model's device.
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device, dtype=torch.bool)
generated_ids = input_ids.clone()

# Decoding budget, the loaded model's config, and one slot per decoding step
# for the hidden states returned by the custom forward pass.
max_new_tokens = 20
config = model.config
all_hidden_states = []

def forward(input_ids, attention_mask):
  """Run the global GPT-2 `model` block by block, recording every hidden state.

  Args:
    input_ids: (batch_size, seq_len) tensor of token ids.
    attention_mask: (batch_size, seq_len) mask, truthy where a token is real.

  Returns:
    logits: (batch_size, seq_len, vocab_size) tensor produced by model.lm_head.
    hidden_states: list of numpy arrays, in order: the input embeddings, the
      output of each transformer block, and the final layer-norm output.
  """
  hidden_states = []
  run_device = input_ids.device
  seq_len = input_ids.size(1)

  # Token embedding + learned position embedding for positions 0..seq_len-1.
  positions = torch.arange(seq_len, device=run_device)
  hidden_state = model.transformer.wte(input_ids) + model.transformer.wpe(positions) #(batch_size, seq_len, n_embd)
  hidden_states.append(hidden_state.cpu().detach().numpy()) #store initial embedding

  # allowed[b, 0, i, j] is True when query i may look at key j:
  # j must not be in the future (causal) and must be a real token (padding mask).
  causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=run_device))
  allowed = attention_mask.bool().unsqueeze(1).unsqueeze(2) & causal_mask.unsqueeze(0).unsqueeze(0)

  # Hand the blocks an additive float mask (0 = keep, large negative = drop).
  # Hugging Face's eager GPT-2 attention adds the mask to the raw scores, so a
  # boolean mask would only shift scores by 0/1 instead of masking; a float mask
  # is also accepted by the scaled-dot-product attention path.
  additive_mask = torch.zeros(allowed.shape, dtype=hidden_state.dtype, device=run_device)
  additive_mask = additive_mask.masked_fill(~allowed, torch.finfo(hidden_state.dtype).min)

  # Iterate the blocks directly rather than via a global config's n_layer.
  for layer in model.transformer.h:
    outputs = layer(hidden_state, attention_mask = additive_mask)
    hidden_state = outputs[0] #update hidden state
    hidden_states.append(hidden_state.cpu().detach().numpy())

  # Final layer norm, then project to vocabulary logits.
  hidden_state = model.transformer.ln_f(hidden_state)
  hidden_states.append(hidden_state.cpu().detach().numpy())

  logits = model.lm_head(hidden_state)
  return logits, hidden_states

# Greedy decoding: re-run the custom forward pass on the growing sequence and
# always append the single most likely next token.
with torch.no_grad():
  for _ in range(max_new_tokens):
    logits, hidden_states = forward(generated_ids, attention_mask)
    all_hidden_states.append(hidden_states)

    # Scores for the last position only -> arg-max over the vocabulary, shape [1].
    next_token_id = logits[:, -1, :].argmax(dim=-1)
    next_column = next_token_id.unsqueeze(-1)  # shape [1, 1], ready to concatenate

    generated_ids = torch.cat([generated_ids, next_column], dim=-1)
    # The appended token is real, so its mask entry is True.
    new_mask_entry = torch.ones_like(next_column, dtype=torch.bool)
    attention_mask = torch.cat([attention_mask, new_mask_entry], dim=-1)

    # Stop as soon as the model emits end-of-text (single-sequence batch).
    if next_token_id.item() == tokenizer.eos_token_id:
      break


generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)

one plus two equals one.

The first is the "one-two" rule. The second is the "


# Custom implementation of an Encoder-Decoder model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F #conv functions

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Args:
    hidden_size: width of the query/key/value inputs and of the output.
    num_heads: number of attention heads; must divide hidden_size evenly.

  forward(query, key, value, mask=None) takes (batch_size, seq_len, hidden_size)
  tensors (query and key/value may have different seq_len) and returns a
  (batch_size, query_len, hidden_size) tensor. `mask`, when given, must
  broadcast to (batch_size, num_heads, query_len, key_len); positions where it
  is False/0 are excluded from attention.
  """

  def __init__(self, hidden_size, num_heads):
    super().__init__()
    assert hidden_size%num_heads==0 #we must be able to split the input vector evenly amongst the heads
    self.hidden_size=hidden_size
    self.num_heads=num_heads
    self.head_dim=hidden_size//num_heads

    #learned projections: each of Q, K, V gets its own linear map of the input
    self.query=nn.Linear(in_features = hidden_size, out_features = hidden_size)
    self.key=nn.Linear(in_features = hidden_size, out_features = hidden_size)
    self.value=nn.Linear(in_features = hidden_size, out_features = hidden_size)

    self.fc_out = nn.Linear(hidden_size, hidden_size)
    #sqrt(head_dim) as a plain float, so it is not tied to any device
    self.scale = self.head_dim ** 0.5

  def _split_heads(self, x, batch_size):
    #(bs, seq_len, hidden_size) -> (bs, num_heads, seq_len, head_dim), grouped by head
    return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

  def forward(self, query, key, value, mask = None):
    batch_size = query.size(0)

    #each input goes through its OWN projection (K and V previously reused self.query)
    Q = self._split_heads(self.query(query), batch_size)
    K = self._split_heads(self.key(key), batch_size)
    V = self._split_heads(self.value(value), batch_size)

    #scaled dot product attention
    #(bs, num_heads, q_len, head_dim) @ (bs, num_heads, head_dim, k_len) -> (bs, num_heads, q_len, k_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
    if mask is not None:
      #mask == 0 covers both boolean masks and 0/1 integer masks
      scores = scores.masked_fill(mask == 0, float('-inf'))

    attn_weights=F.softmax(scores, dim=-1) #(bs, num_heads, q_len, k_len)
    attn_output=torch.matmul(attn_weights, V) #(bs, num_heads, q_len, head_dim)

    #concatenate heads: back to (bs, q_len, num_heads, head_dim), then merge the last two dims
    attn_output = attn_output.transpose(1,2).contiguous().view(batch_size, -1, self.hidden_size)
    output=self.fc_out(attn_output) #(bs, q_len, hidden_size)
    return output

In [None]:
class EncoderLSTM(nn.Module):
    """LSTM encoder followed by a residual multi-head self-attention layer.

    Args:
        vocab_size: number of source token ids.
        embed_size: embedding width fed to the LSTM.
        hidden_size: LSTM hidden width (also the attention width).
        num_layers: number of stacked LSTM layers.
        num_heads: attention heads; must divide hidden_size.
        dropout: dropout on the embeddings, and between LSTM layers when
            num_layers > 1.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, num_heads=8, dropout=0.3):
        super(EncoderLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Inter-layer LSTM dropout is only passed when there is more than one layer.
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.self_attention = MultiHeadAttention(hidden_size, num_heads)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, input_ids):
        """Encode a batch of source sequences.

        Args:
            input_ids: (batch_size, seq_len) token ids.

        Returns:
            outputs: (batch_size, seq_len, hidden_size) attention-refined states.
            hidden, cell: final LSTM states, each (num_layers, batch_size, hidden_size).
        """
        embedded = self.dropout(self.embedding(input_ids))  #(batch_size, seq_len, embed_size)
        lstm_outputs, (hidden, cell) = self.lstm(embedded)  #(batch_size, seq_len, hidden_size)
        # The attention module returns a single tensor, so there is nothing to unpack.
        attn_output = self.self_attention(lstm_outputs, lstm_outputs, lstm_outputs)  #(batch_size, seq_len, hidden_size)
        outputs = self.norm(lstm_outputs + attn_output)  #residual connection + layer norm
        return outputs, hidden, cell

class DecoderLSTM(nn.Module):
    """LSTM decoder with masked self-attention and cross-attention over encoder outputs.

    Args:
        vocab_size: number of target token ids (size of the output projection).
        embed_size: embedding width fed to the LSTM.
        hidden_size: LSTM hidden width (also the attention width).
        num_layers: number of stacked LSTM layers.
        num_heads: attention heads; must divide hidden_size.
        dropout: dropout on the embeddings, and between LSTM layers when
            num_layers > 1.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, num_heads=8, dropout=0.3):
        super(DecoderLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.self_attention = MultiHeadAttention(hidden_size, num_heads)
        self.cross_attention = MultiHeadAttention(hidden_size, num_heads)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, input_ids, hidden, cell, encoder_outputs, self_attn_mask=None):
        """Decode one chunk of target tokens.

        Args:
            input_ids: (batch_size, seq_len) target token ids.
            hidden, cell: LSTM states to start from.
            encoder_outputs: (batch_size, src_len, hidden_size) encoder states.
            self_attn_mask: optional mask for the self-attention scores; it is
                passed straight through to the attention module.

        Returns:
            predictions: vocabulary logits, (batch_size, vocab_size) when
                seq_len == 1, otherwise (batch_size, seq_len, vocab_size).
            hidden, cell: updated LSTM states.
            outputs: (batch_size, seq_len, hidden_size) states before the
                output projection.
        """
        embedded = self.dropout(self.embedding(input_ids))  #(batch_size, seq_len, embed_size)
        lstm_outputs, (hidden, cell) = self.lstm(embedded, (hidden, cell))  #(batch_size, seq_len, hidden_size)

        #Masked self-attention (the attention module returns a single tensor)
        self_attn_output = self.self_attention(lstm_outputs, lstm_outputs, lstm_outputs, self_attn_mask)
        self_attn_output = self.norm1(lstm_outputs + self_attn_output)

        #Cross-attention: decoder states query the encoder outputs
        cross_attn_output = self.cross_attention(self_attn_output, encoder_outputs, encoder_outputs)
        outputs = self.norm2(self_attn_output + cross_attn_output)

        # For single-token steps drop the length-1 sequence axis before projecting.
        predictions = self.fc(outputs.squeeze(1) if input_ids.size(1) == 1 else outputs)  #(batch_size, vocab_size) or (batch_size, seq_len, vocab_size)
        return predictions, hidden, cell, outputs

In [None]:
class EncoderDecoderLSTM(nn.Module):
    """Sequence-to-sequence model: EncoderLSTM + DecoderLSTM decoded one token at a time.

    Args:
        src_vocab_size: number of source token ids.
        tgt_vocab_size: number of target token ids.
        embed_size, hidden_size, num_layers, num_heads, dropout: forwarded
            unchanged to both the encoder and the decoder.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, embed_size, hidden_size, num_layers=1, num_heads=8, dropout=0.3):
        super(EncoderDecoderLSTM, self).__init__()
        self.encoder = EncoderLSTM(src_vocab_size, embed_size, hidden_size, num_layers, num_heads, dropout)
        self.decoder = DecoderLSTM(tgt_vocab_size, embed_size, hidden_size, num_layers, num_heads, dropout)
        self.tgt_vocab_size = tgt_vocab_size

    def forward(self, src_ids, tgt_ids, teacher_forcing_ratio=0.5):
        """Training-time decoding with scheduled teacher forcing.

        Args:
            src_ids: (batch_size, src_seq_len) source token ids.
            tgt_ids: (batch_size, tgt_seq_len) target token ids; column 0 is
                used as the first decoder input.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the ground-truth token instead of the
                model's own prediction.

        Returns:
            (batch_size, tgt_seq_len, tgt_vocab_size) logits. Position 0 is
            left as zeros because nothing is predicted for the start token.
        """
        batch_size = src_ids.size(0)
        tgt_seq_len = tgt_ids.size(1)
        device = src_ids.device

        # The encoder's final LSTM states seed the decoder's LSTM.
        encoder_outputs, hidden, cell = self.encoder(src_ids)

        outputs = torch.zeros(batch_size, tgt_seq_len, self.tgt_vocab_size, device=device)
        input_id = tgt_ids[:, 0].unsqueeze(1)  #(batch_size, 1)

        for t in range(1, tgt_seq_len):
            # Each step feeds exactly one token, so the decoder's self-attention
            # scores are 1x1 and need no causal mask. The earlier (t+1)x(t+1)
            # mask could not broadcast against them.
            output, hidden, cell, _ = self.decoder(
                input_id, hidden, cell, encoder_outputs, None
            )
            outputs[:, t, :] = output
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_id = tgt_ids[:, t].unsqueeze(1) if teacher_force else top1

        return outputs

    def inference(self, src_ids, max_len=50, sos_token=1, eos_token=2):
        """Greedy decoding.

        Args:
            src_ids: (batch_size, src_seq_len) source token ids.
            max_len: maximum number of tokens to generate after the start token.
            sos_token: id used as the first decoder input.
            eos_token: generation stops once every sequence in the batch emits
                this id in the same step.

        Returns:
            (batch_size, 1 + generated_len) token ids, starting with sos_token.

        Note: this does not switch the module to eval mode; call .eval() first
        if dropout should be disabled.
        """
        batch_size = src_ids.size(0)
        device = src_ids.device

        # Arg-max decoding is not differentiable, so gradient tracking is pure overhead.
        with torch.no_grad():
            encoder_outputs, hidden, cell = self.encoder(src_ids)

            generated_ids = torch.full((batch_size, 1), sos_token, dtype=torch.long, device=device)
            input_id = generated_ids

            for _ in range(max_len):
                # Single-token step: no self-attention mask needed (see forward()).
                output, hidden, cell, _ = self.decoder(
                    input_id, hidden, cell, encoder_outputs, None
                )
                next_token = output.argmax(1).unsqueeze(1)
                generated_ids = torch.cat([generated_ids, next_token], dim=1)
                input_id = next_token
                if (next_token == eos_token).all():
                    break

        return generated_ids