In [1]:
import torch
import torch.nn as nn
from collections import Counter
from torch.utils.data import Dataset, DataLoader, random_split
import os
import numpy as np
seed = 1234
torch.manual_seed(seed)


<torch._C.Generator at 0x2bf00b10>

In [2]:
import json
import nltk
from nltk.tokenize import word_tokenize

def extract_text_values(jsonl_file_path):
    """
    Extract 'text' values from a JSONL file.
    
    Args:
        jsonl_file_path (str): Path to the JSONL file
        
    Returns:
        list: List of extracted text values
    """
    text_values = []
    
    with open(jsonl_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            try:
                # Parse each line as JSON
                json_obj = json.loads(line.strip())
                
                # Extract the 'text' field if it exists
                if 'text' in json_obj:
                    text_values.append(json_obj['text'])
            except json.JSONDecodeError:
                print(f"Warning: Could not parse line as JSON: {line}")
    
    return text_values

file_path = "gpt_dataset.jsonl"
texts = extract_text_values(file_path)
tokenized_texts = [word_tokenize(text) for text in texts]

In [3]:
vocab = Counter([token for sentence in tokenized_texts for token in sentence])
token_to_id = {token: idx for idx, token in enumerate(vocab)} 
id_to_token= {value:key for key,value in token_to_id.items()}
vocab_size = len(id_to_token)

In [4]:
def tokenize_text(tokens):
    return [token_to_id.get(token,0) for token in tokens]

dataset = [tokenize_text(text) for text in tokenized_texts if len(text) > 9]

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
def construct_dataset(dataset, block_size):
    lengths = [len(datapoint)-block_size for datapoint in dataset]
    cumulative_lengths = [0]
    prev_length = 0  
    for length in lengths:
        temp = length + prev_length
        cumulative_lengths.append(temp)
        prev_length = temp
    total_len = cumulative_lengths[-1]
    current_datapoint = 0
    X = []
    y = []
    for idx in range(total_len):
        if idx >= cumulative_lengths[current_datapoint+1]:
            current_datapoint +=1
        datapoint_idx = (idx - cumulative_lengths[current_datapoint])
          
        X.append(dataset[current_datapoint][datapoint_idx:datapoint_idx+block_size])
        y.append(dataset[current_datapoint][datapoint_idx+1:datapoint_idx+block_size+1])
    return X,y
X,y = construct_dataset(dataset, 8)

In [10]:

class CustomDataset(Dataset):
  def __init__(self,X, y):
    self.X = X
    self.y = y

  def __len__(self):
    return len(self.X)

  def __getitem__(self,idx):
    return torch.tensor(X[idx],dtype=torch.long).to(device),torch.tensor(y[idx],dtype=torch.long).to(device)

In [11]:
BLOCK_SIZE = 8

In [12]:
data = CustomDataset(dataset,BLOCK_SIZE)
train_size = int(0.8 * len(data))
val_size = len(data) - train_size
train_dataset, val_dataset = random_split(data, [train_size, val_size])
print(train_dataset, val_dataset)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

<torch.utils.data.dataset.Subset object at 0x00000000551883D0> <torch.utils.data.dataset.Subset object at 0x0000000055189CC0>


In [13]:
class MaskedMultiHeadAttention(nn.Module):
  def __init__(self, emd_dim, heads=4, dropout = 0.2):
    super(MaskedMultiHeadAttention, self).__init__()
    assert emd_dim % heads == 0
    self.heads = heads
    self.head_dim = emd_dim//heads
    self.scale = self.head_dim ** -0.5
    self.multiHead = nn.Linear(emd_dim, emd_dim*3)
    self.output = nn.Linear(emd_dim,emd_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B, T, C = x.shape
    qkv = self.multiHead(x)
    q, k, v = torch.chunk(qkv,3,dim=-1)
    q = q.view(B, T, self.heads, self.head_dim).permute(0, 2, 1, 3)
    k = k.view(B, T, self.heads, self.head_dim).permute(0, 2, 1, 3)
    v = v.view(B, T, self.heads, self.head_dim).permute(0, 2, 1, 3)
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
    tril = torch.tril(torch.ones(T,T))
    attn_scores = attn_scores.masked_fill(tril==0, float('-inf'))
    attn_probs = torch.softmax(attn_scores, dim=-1)
    attn_probs_drop = self.dropout(attn_probs)
    attn_output = torch.matmul(attn_probs_drop,v)
    fn_attn_output = attn_output.permute(0, 2, 1, 3).reshape(B, T, C)
    return self.output(fn_attn_output)


In [14]:
class LayerNorm1D(nn.Module):
  def __init__(self, dim, eps=1e-5):
    super(LayerNorm1D, self).__init__()
    self.gamma = nn.Parameter(torch.ones(dim))
    self.beta = nn.Parameter(torch.zeros(dim))
    self.eps = eps

  def forward(self, x):
    mean = x.mean(-1,keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    xhat = (x-mean)/torch.sqrt(var+self.eps)
    return (self.gamma * xhat) +self.beta

In [15]:
class FeedForward(nn.Module):
  def __init__(self, input_dim, hidden_dim, output_dim, dropout = 0.2):
    super().__init__()
    self.feed_forward_layer = nn.Sequential(
      nn.Linear(input_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, output_dim),
      nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.feed_forward_layer(x)


In [16]:
class Block(nn.Module):
  def __init__(self,embed_dim,heads=4):
    super().__init__()
    self.layer_norm1 = LayerNorm1D(embed_dim)
    self.layer_norm2 = LayerNorm1D(embed_dim)
    self.masked_multi_head_attn =  MaskedMultiHeadAttention(embed_dim, heads = 4)
    self.feed_forward_layer = FeedForward(embed_dim, embed_dim*4, embed_dim)

  def forward(self, x):
    x = x + self.masked_multi_head_attn(self.layer_norm1(x))
    x = x + self.feed_forward_layer(self.layer_norm2(x))
    return x


In [17]:
class AutoRegressiveModel(nn.Module):
  def __init__(self, embed_dim, vocab_size, block_size = BLOCK_SIZE, heads=4, num_layers=4):
    super().__init__()
    self.block = nn.Sequential(*[Block(embed_dim,heads) for _ in range(num_layers)])
    self.positional_embedding = nn.Embedding(block_size, embed_dim)
    self.embedding = nn.Embedding(vocab_size, embed_dim)
    self.final_layer_norm = LayerNorm1D(embed_dim)
    self.final_layer = nn.Linear(embed_dim, vocab_size)

  def forward(self, x):
    _, T = x.shape
    x_emb = self.embedding(x)
    x_pos_emb = self.positional_embedding(torch.arange(T))
    x = x_emb + x_pos_emb
    block_output = self.block(x)
    x_out = self.final_layer_norm(block_output)
    return self.final_layer(x_out)

In [18]:
model = AutoRegressiveModel(embed_dim=128, vocab_size=vocab_size, block_size= BLOCK_SIZE, heads = 4).to(device)
if os.path.exists("gpt_sft_rlhf.pth"):
    model.load_state_dict(torch.load("gpt_sft_rlhf.pth")) 
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
criterion = nn.CrossEntropyLoss()

In [19]:
def train(model: nn.Module, optimizer: torch.optim, criterion: nn.Module, dataloader: DataLoader, epochs: int):

  for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for X,y in dataloader:
      optimizer.zero_grad()
      print(X.shape)
      outputs = model(X)
      B, T, _ = outputs.shape
      loss = criterion(outputs.reshape(B*T,-1),y.reshape(B*T))
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()
    print(f"Epoch: {epoch + 1}/{epochs}, Loss: {epoch_loss / len(dataloader):.4f}")

In [20]:
def val(model: nn.Module,dataloader: DataLoader):
  model.eval()
  val_loss = 0.0
  with torch.no_grad():
    for X,y in dataloader:
      outputs = model(X)
      B, T, _ = outputs.shape
      loss = criterion(outputs.reshape(B*T,-1),y.reshape(B*T))
      val_loss += loss.item()
    print(f"Loss: {val_loss / len(dataloader):.4f}")

In [22]:
train(model, optimizer, criterion, train_loader, 10)

torch.Size([8, 8])
Epoch: 1/10, Loss: 9.0764
torch.Size([8, 8])
Epoch: 2/10, Loss: 7.8580
torch.Size([8, 8])
Epoch: 3/10, Loss: 6.9489
torch.Size([8, 8])
Epoch: 4/10, Loss: 6.2954
torch.Size([8, 8])
Epoch: 5/10, Loss: 5.7897
torch.Size([8, 8])
Epoch: 6/10, Loss: 5.2842
torch.Size([8, 8])
Epoch: 7/10, Loss: 4.8422
torch.Size([8, 8])
Epoch: 8/10, Loss: 4.3657
torch.Size([8, 8])
Epoch: 9/10, Loss: 3.9392
torch.Size([8, 8])
Epoch: 10/10, Loss: 3.5539


In [None]:
val(model,val_loader)