In [78]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [79]:
class InputEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model).to('cuda')

    def forward(self, x):
        return self.embedding(x) * self.d_model ** 0.5


In [80]:
class PositionalEncoding(nn.Module):
  def __init__(self,d_model,seq_len,drop_out=0.1):
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.droup_out = nn.Dropout(drop_out)
    pe = torch.zeros(seq_len,d_model)
    pos = torch.arange(0,seq_len,dtype=torch.float).unsqueeze(1)
    denom = torch.exp(torch.arange(0,d_model,2).float()*(-torch.log(torch.tensor(10000.0))/d_model))
    pe[:,0::2] = torch.sin(pos/denom)
    pe[:,1::2] = torch.cos(pos/denom)
    pe = pe.unsqueeze(0).to('cuda')
    self.register_buffer('pe',pe)

  def forward(self,x):
    x = x + (self.pe[:,:x.shape[1],:]).requires_grad_(False)
    return self.droup_out(x)




In [81]:
class LayerNorm(nn.Module):
  def __init__(self,ndim,eps=10**-6):
    super().__init__()
    self.eps = eps
    self.alpha = nn.Parameter(torch.ones(ndim)).to('cuda')
    self.beta = nn.Parameter(torch.zeros(ndim)).to('cuda')

  def forward(self,x):
    mean = x.mean(-1,keepdim=True) # if the first dim is the batch dim
    std = x.std(-1,keepdim=True)
    return self.alpha*(x-mean)/(std+self.eps) + self.beta

In [82]:
class MLP(nn.Module):
  def __init__(self,d_model,ff_dim=2048,drop_out=0.1):
    super().__init__()
    self.layer1 = nn.Linear(d_model,ff_dim).to('cuda')
    self.layer2 = nn.Linear(ff_dim,d_model).to('cuda')
    self.drop_put = nn.Dropout(drop_out).to('cuda')
    self.gelu = nn.GELU().to('cuda')

  def forward(self,x):
    x = self.layer1(x)
    x = self.gelu(x)
    x = self.layer2(x)
    x = self.drop_put(x)
    return x

In [83]:
import math
class MultiHeadAttentionBlock(nn.Module):
  def __init__(self,d_model,seq_len,n_heads=6,drop_out=0.1):
    super().__init__()
    assert d_model % n_heads == 0
    self.d_model = d_model
    self.seq_len = seq_len
    self.n_heads = n_heads
    self.drop_out = nn.Dropout(drop_out)
    self.head_dim = d_model//n_heads
    self.w_q = nn.Linear(self.d_model, self.d_model).to('cuda')
    self.w_k = nn.Linear(self.d_model, self.d_model).to('cuda')
    self.w_v = nn.Linear(self.d_model, self.d_model).to('cuda')
    self.w_o = nn.Linear(self.d_model, self.d_model).to('cuda')

  def attention(self,q,k,v,mask):
    d_k = (q.shape[-1])
    mask = mask.unsqueeze(1).expand(-1, q.shape[1], -1, -1)
    scores = (q@k.transpose(-2,-1))/math.sqrt(d_k)
    assert mask != None
    scores = scores.masked_fill(mask==0,float("-inf"))
    scores = scores.softmax(dim=-1)
    scores = self.drop_out(scores)

    return torch.matmul(scores,v),scores


  def forward(self,q,k,v,mask):
    batch_size = q.shape[0]

    q,k,v = self.w_q(q),self.w_k(k),self.w_v(v)

    q = q.reshape(batch_size,self.seq_len,self.n_heads,self.head_dim)
    k = k.reshape(batch_size,self.seq_len,self.n_heads,self.head_dim)
    v = v.reshape(batch_size,self.seq_len,self.n_heads,self.head_dim)
    q,k,v = q.transpose(1,2),k.transpose(1,2),v.transpose(1,2) # batch,seq_len,n_heads,d_k ->  batch,n_heads,seq_len,d_k
    x,attention_scores = self.attention(q,k,v,mask)
    x = x.transpose(1,2).contiguous().view(batch_size,self.seq_len,self.d_model) #  batch,n_heads,seq_len,d_k ->  batch,seq_len,n_heads,d_k -> batch,seq_len,d_model
    return self.w_o(x)


In [84]:
class ResBlock(nn.Module):
  def __init__(self,d_model,seq_len,drop_out=0.1):
    super().__init__()
    self.norm = LayerNorm(ndim=d_model)
    self.drop_out = nn.Dropout(drop_out)

  def forward(self,x,sub):
    return x + self.drop_out(sub(self.norm(x)))

In [85]:
class EncoderBlock(nn.Module):
  def __init__(self,d_model,seq_len,n_heads=6,ff_dim=2048,drop_out=0.1):
    super().__init__()
    self.mha = MultiHeadAttentionBlock(d_model=d_model,seq_len=seq_len,n_heads=n_heads,drop_out=drop_out)
    self.mlp = MLP(d_model=d_model,ff_dim=ff_dim,drop_out=drop_out)
    self.res = nn.ModuleList([ResBlock(d_model,seq_len,drop_out) for _ in range(2)])

  def forward(self,x,src_mask):
    x = self.res[0](x,lambda x:self.mha(x,x,x,src_mask))
    x = self.res[1](x,self.mlp)
    return x

class Encoder(nn.Module):
  def __init__(self,d_model,module_list):
    super().__init__()
    self.module_list = module_list
    self.norm = LayerNorm(ndim=d_model)

  def forward(self,x,src_mask):
    for module in self.module_list:
      x = module(x,src_mask)
    return self.norm(x)

In [86]:
class DecoderBlock(nn.Module):
  def __init__(self,d_model,seq_len,n_heads=6,ff_dim=512,drop_out=0.1):
    super().__init__()
    self.self_attention = MultiHeadAttentionBlock(d_model=d_model,seq_len=seq_len,n_heads=n_heads,drop_out=drop_out)
    self.cross_attention = MultiHeadAttentionBlock(d_model=d_model,seq_len=seq_len,n_heads=n_heads,drop_out=drop_out)
    self.mlp = MLP(d_model=d_model,ff_dim=ff_dim,drop_out=drop_out)
    self.layer_norm1 = LayerNorm(ndim=d_model)
    self.layer_norm2 = LayerNorm(ndim=d_model)

  def forward(self,x,encoder_output,src_mask,tgt_mask):
    x_n = self.layer_norm1(x)
    x = x + self.self_attention(x_n,x_n,x_n,tgt_mask)
    x = x + self.mlp(self.layer_norm2(x))
    return x

class Decoder(nn.Module):
  def __init__(self,d_model,module_list):
    super().__init__()
    self.module_list = nn.ModuleList(module_list)
    self.norm = LayerNorm(ndim=d_model)

  def forward(self,x,encoder_output,src_mask,tgt_mask):
    for layer in self.module_list:
      x = layer(x,encoder_output,src_mask,tgt_mask)

    return self.norm(x)

In [87]:
class Classification(nn.Module):
  def __init__(self,d_model,vocab_size):
    super().__init__()
    self.layer = nn.Linear(d_model,vocab_size)

  def forward(self,x):
    return F.log_softmax(self.layer(x),dim=-1)


In [92]:
class Transformer(nn.Module):
  def __init__(self,vocab_size,seq_len,ignore_index=-100,d_model=256,n_heads=4,ff_dim=512,drop_out=0.1):
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.vocab_size = vocab_size
    self.input_embedding = InputEmbedding(vocab_size,d_model)
    self.output_embedding = InputEmbedding(vocab_size,d_model)
    self.positional_encoding_src = PositionalEncoding(d_model,seq_len,drop_out)
    self.positional_encoding_tgt = PositionalEncoding(d_model,seq_len,drop_out)

    self.encoder = Encoder(module_list=[EncoderBlock(d_model,seq_len,n_heads,ff_dim,drop_out) for _ in range(6)],
                           d_model=d_model)
    self.decoder = Decoder(module_list=[DecoderBlock(d_model,seq_len,n_heads,ff_dim,drop_out) for _ in range(6)],
                           d_model=d_model)
    self.classification = Classification(d_model,vocab_size)
    self.cross_entropy = nn.CrossEntropyLoss(ignore_index=ignore_index)

  def encode(self,src,src_mask):
    src = self.input_embedding(src)
    src = self.positional_encoding_src(src)
    return self.encoder(src,src_mask)

  def decode(self,tgt,encoder_output,src_mask,tgt_mask):
    tgt = self.output_embedding(tgt)
    tgt = self.positional_encoding_tgt(tgt)
    return self.decoder(tgt,encoder_output,src_mask,tgt_mask)

  def classify(self,x):
    return self.classification(x)

  def forward(self,src,tgt,src_mask,tgt_mask):
    #decoder only
    decoder_output = self.decode(src,src,src_mask,src_mask)
    logits = self.classify(decoder_output)
    loss = self.cross_entropy(logits.view(-1, logits.size(-1)),tgt.view(-1))
    return logits,loss



In [93]:
import numpy as np
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
# Read the text data from the file
with open('./trump_3.6.txt', 'r', encoding='utf-8') as file:
    text = file.read()

lines = text.split(".")
lines = [lines[i]+". " + lines[i + 1]+"." for i in range(0, len(lines) - 1, 2)]

max_length = 65
# Tokenize each line and extract input_ids, pad/truncate to max_length
encoded_lines = [
    tokenizer.encode(
        line.strip(),
        add_special_tokens=True,
        padding='max_length',
        max_length=max_length,
        truncation=True
    )
    for line in lines if line.strip()
]
print(len(encoded_lines[0]))
# Filter out any lines that may have resulted in empty tokens
#encoded_lines = torch.Tensor(encoded_lines)


65


In [94]:
input_tensor = torch.tensor(encoded_lines)

# Remove the </s> token from the src sequences
# We will identify the </s> token ID from the tokenizer
eos_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id

seq_len=64
# Remove the last token from src if it is the </s> token
src = torch.where(input_tensor[:, :-1] == eos_token_id, tokenizer.pad_token_id, input_tensor[:, :-1]).to('cuda')
targets = input_tensor[:, 1:].to('cuda')  # Shifted target, all tokens except the first one
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
tril_mask=tril_mask.unsqueeze(0).expand(src.size(0), -1, -1).to('cuda')

print(src[0],targets[0])

tensor([    0,  2387,  2598,  1791,     6,    38,   236,     7,  1994,     7,
           47,  3422,    59,     5, 15554,  1061,     9,     5,   375,   186,
            4,  1437,   287,    38,    33,    26,     6,     5,  5853, 29471,
            9,     5,   382,  6107,  2322,    23,     5,   182,  1144,     9,
           84,  3497,     4,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1], device='cuda:0') tensor([ 2387,  2598,  1791,     6,    38,   236,     7,  1994,     7,    47,
         3422,    59,     5, 15554,  1061,     9,     5,   375,   186,     4,
         1437,   287,    38,    33,    26,     6,     5,  5853, 29471,     9,
            5,   382,  6107,  2322,    23,     5,   182,  1144,     9,    84,
         3497,     4,     2,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
          

In [95]:
def save(epoch,model,optim,loss):
  checkpoint_path = 'model_checkpoint.pth'
  torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optim.state_dict(),
      'loss': loss,
  }, checkpoint_path)

In [107]:
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self.last_epoch = self.after_scheduler.last_epoch + self.total_epoch + 1
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

In [108]:
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(src, targets,tril_mask)
batch_size = 64  # Adjust based on your memory capacity and requirements

# Create the DataLoader
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = Transformer(vocab_size=tokenizer.vocab_size,
                    ignore_index=pad_token_id,
                    seq_len=seq_len)
total_params = sum(p.numel() for p in model.parameters())
print("Total Params",total_params)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00002,betas=(0.9, 0.98), eps=1e-9,weight_decay=0.001)
model.to('cuda:0')
model.train()
epochs = 50
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,eta_min=1e-8,T_max=epochs)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=epochs, after_scheduler=lr_scheduler)

for epoch in range(epochs):
  for batch in data_loader:
      batch_src, batch_targets,batch_mask = batch
      batch_src, batch_targets,batch_mask = batch_src.to('cuda'), batch_targets.to('cuda'),batch_mask.to('cuda')
      logits,loss = model(src=batch_src,
            tgt=batch_targets,
            src_mask=batch_mask,
            tgt_mask=batch_mask)
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

      optimizer.step()
      optimizer.zero_grad()
  save(epoch,model,optimizer,loss)
  lr_scheduler.step()
  print("loss",loss)

  #  print("Source batch shape:", batch_src[0])
   # print("Target batch shape:", batch_targets[0])
    #print("Mask batch shape:", batch_mask[0])



Total Params 43389273


KeyboardInterrupt: 