In [190]:
import torch
import torch.nn as nn
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import torch.nn.functional as f
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


from einops import rearrange

In [191]:
SRC_VOCAB_SIZE = 1000  
TGT_VOCAB_SIZE = 1000 

EMBED_SIZE = 512
HIDDEN_SIZE = 512
N_LAYERS = 6
N_HEADS = 8
MAX_LEN = 512
FF_HIDDEN_MULT = 4
DROPOUT = 0.1

LR = 1e-4
BATCH_SIZE = 32
PRETRAIN_EPOCHS = 5
REWARD_MODEL_EPOCHS = 3
PPO_UPDATES = 3
EPS_CLIP = 0.2
GAMMA = 0.99
MAX_SEQ_LEN = 50 

DEVICE = torch.device('mps')

In [192]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=512):
        super(PositionalEncoding, self).__init__()

        self.encoding = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_size))

        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)

        self.encoding = self.encoding.unsqueeze(0)  # Shape: [1, max_len, embed_size]

    def forward(self, x):
        return x + self.encoding[:, :x.size(1), :].to(x.device)

In [193]:

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, n_head):
        super(MultiHeadAttention, self).__init__()

        self.dim = dim
        self.n_head = n_head
        self.head_dim = self.dim // self.n_head

        assert self.head_dim * self.n_head == self.dim, "embed_dim must be divisible by num_heads"

        self.fc_q = nn.Linear(dim, dim)
        self.fc_k = nn.Linear(dim, dim)
        self.fc_v = nn.Linear(dim, dim)

        self.fc_out = nn.Linear(dim, dim)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        q = self.fc_q(q)  # [batch_size, seq_len, dim]
        k = self.fc_k(k)
        v = self.fc_v(v)

        q = rearrange(q, 'b s (h d) -> b h s d', h=self.n_head)  # [batch_size, n_head, seq_len, head_dim]
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.n_head)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.n_head)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [b, h, s, s]
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)  # [b, h, s, s]

        out = torch.matmul(attn, v)  # [b, h, s, head_dim]
        out = rearrange(out, 'b h s d -> b s (h d)')  # [b, s, dim]

        out = self.fc_out(out)  # [b, s, dim]
        return out


In [194]:

class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_size, n_head, ff_hidden_mult=4, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()

        self.attn = MultiHeadAttention(dim=embed_size, n_head=n_head)
        self.norm1 = nn.LayerNorm(embed_size)
        self.dropout1 = nn.Dropout(dropout)

        self.ff = nn.Sequential(
            nn.Linear(embed_size, embed_size * ff_hidden_mult),
            nn.ReLU(),
            nn.Linear(embed_size * ff_hidden_mult, embed_size)
        )
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out = self.attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))

        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout2(ff_out))
        return x


In [195]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, embed_size, n_head, ff_hidden_mult=4, dropout=0.1):
        super(TransformerDecoderBlock, self).__init__()

        self.self_attn = MultiHeadAttention(dim=embed_size, n_head=n_head)
        self.norm1 = nn.LayerNorm(embed_size)
        self.dropout1 = nn.Dropout(dropout)

        self.cross_attn = MultiHeadAttention(dim=embed_size, n_head=n_head)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout2 = nn.Dropout(dropout)

        self.ff = nn.Sequential(
            nn.Linear(embed_size, embed_size * ff_hidden_mult),
            nn.ReLU(),
            nn.Linear(embed_size * ff_hidden_mult, embed_size)
        )
        self.norm3 = nn.LayerNorm(embed_size)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        # in self attention the target mask is used, to ensure the model
        # does not cheet and considers only valid positions and perceiding tokens
        self_attn_out = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(self_attn_out))


        # in cross attention we use source mask
        # to ensure generated tokens align with input context
        cross_attn_out = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + self.dropout2(cross_attn_out))

        ff_out = self.ff(x)
        x = self.norm3(x + self.dropout3(ff_out))
        return x


In [196]:

class TransformerSeq2Seq(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        embed_size=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        n_head=8,
        max_len=512,
        ff_hidden_mult=4,
        dropout=0.1,
        tokenizer_pad_token_id=0
    ):
        super(TransformerSeq2Seq, self).__init__()

        self.src_embed = nn.Embedding(src_vocab_size, embed_size, padding_idx=tokenizer_pad_token_id)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, embed_size, padding_idx=tokenizer_pad_token_id)
        self.pos_encoder = PositionalEncoding(embed_size, max_len)
        self.pos_decoder = PositionalEncoding(embed_size, max_len)

        self.encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(embed_size, n_head, ff_hidden_mult, dropout)
            for _ in range(num_encoder_layers)
        ])

        self.decoder_layers = nn.ModuleList([
            TransformerDecoderBlock(embed_size, n_head, ff_hidden_mult, dropout)
            for _ in range(num_decoder_layers)
        ])

        self.fc_out = nn.Linear(embed_size, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

        self.tokenizer_pad_token_id = tokenizer_pad_token_id

    def make_src_mask(self, src):
        src_mask = (src != self.tokenizer_pad_token_id).unsqueeze(1).unsqueeze(2) 
        return src_mask 

    def make_tgt_mask(self, tgt):
        tgt_seq_len = tgt.size(1)
        tgt_mask = (tgt != self.tokenizer_pad_token_id).unsqueeze(1).unsqueeze(2)  

        subsequent_mask = torch.tril(torch.ones((tgt_seq_len, tgt_seq_len), device=tgt.device)).bool()  
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(0)  

        tgt_mask = tgt_mask & subsequent_mask  
        return tgt_mask

    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)  
        tgt_mask = self.make_tgt_mask(tgt)  

        enc_out = self.src_embed(src)  
        enc_out = self.pos_encoder(enc_out)
        for layer in self.encoder_layers:
            enc_out = layer(enc_out, src_mask)
        dec_out = self.tgt_embed(tgt)  
        dec_out = self.pos_decoder(dec_out)
     
        for layer in self.decoder_layers:
            dec_out = layer(dec_out, enc_out, src_mask, tgt_mask)

        output = self.fc_out(dec_out)  
        return output


In [197]:
class ExampleDataset(Dataset):
    def __init__(self, num_samples=1000, src_seq_len=20, tgt_seq_len=20, src_vocab_size=SRC_VOCAB_SIZE, tgt_vocab_size=TGT_VOCAB_SIZE, pad_idx=0):
        super(ExampleDataset, self).__init__()
        self.num_samples = num_samples
        self.src_seq_len = src_seq_len
        self.tgt_seq_len = tgt_seq_len
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.pad_idx = pad_idx

        self.src_data = torch.randint(1, src_vocab_size, (num_samples, src_seq_len))
        self.tgt_data = torch.randint(1, tgt_vocab_size, (num_samples, tgt_seq_len))
        for i in range(num_samples):
            pad_length = torch.randint(5, src_seq_len, (1,)).item()
            self.src_data[i, pad_length:] = pad_idx
            pad_length = torch.randint(5, tgt_seq_len, (1,)).item()
            self.tgt_data[i, pad_length:] = pad_idx

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.src_data[idx], self.tgt_data[idx]

In [198]:
ds = ExampleDataset(num_samples = 1000)
sample_x, sample_y = ds[0]

In [199]:
sample_x.shape, sample_y.shape

(torch.Size([20]), torch.Size([20]))

In [200]:
def train_seq2seq(model, dataloader, n_epochs = 1, lr = 1e-4, voc_size = TGT_VOCAB_SIZE):
    criterion = nn.CrossEntropyLoss(ignore_index = 0)
    device = DEVICE
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)


    model.train()

    model = model.to(device)


    for epoch in tqdm(range(n_epochs), total = n_epochs, desc = 'Pretraining'):
        epoch_loss = 0.0

        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)

            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            optimizer.zero_grad()
            logits = model(src, tgt_input)
            loss = criterion(logits.reshape(-1, voc_size), tgt_output.reshape(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Pretraining Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}")


In [202]:
def generate_sequence(model, src, max_length=MAX_SEQ_LEN):

    model.eval()
    device = DEVICE
    src = src.to(device)
    src_mask = model.make_src_mask(src)
    enc_out = model.src_embed(src)
    enc_out = model.pos_encoder(enc_out)
    for layer in model.encoder_layers:
        enc_out = layer(enc_out, src_mask)

    generated = torch.zeros(src.size(0), 1, dtype=torch.long).to(device)  
    for _ in range(max_length):
        dec_out = model.tgt_embed(generated)
        dec_out = model.pos_decoder(dec_out)
        for layer in model.decoder_layers:
            dec_out = layer(dec_out, enc_out, src_mask, model.make_tgt_mask(generated))
        logits = model.fc_out(dec_out) 
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  
        generated = torch.cat((generated, next_token), dim=1)
    return generated  


In [203]:
class RewardModel(nn.Module):
    def __init__(self,
                 voc_size,
                 emb_size,
                 hidden_size,
                 n_layers,
                 pad_token_id = 0):
        super(RewardModel, self).__init__()

        self.embedding = nn.Embedding(voc_size, emb_size, padding_idx = pad_token_id)
        encoder_layer = nn.TransformerEncoderLayer(d_model = emb_size, nhead = 8)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = n_layers)
        self.fc_out = nn.Linear(emb_size, 1)

    def forward(self, seq):
        emb = self.embedding(seq)
        emb = emb.transpose(0, 1)
        output = self.transformer_encoder(emb)
        output = output.mean(dim = 0)
        logits = self.fc_out(output)
        return logits.squeeze()
    


rm = RewardModel(voc_size = TGT_VOCAB_SIZE, emb_size = EMBED_SIZE,
                 hidden_size = HIDDEN_SIZE, n_layers = N_LAYERS, 
                 pad_token_id = 0).to(DEVICE)

rm.forward(seq = torch.randint(0, TGT_VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN)).to(DEVICE))

tensor([ 0.0021, -0.0887, -0.0128, -0.0071,  0.0398, -0.0389, -0.0871, -0.1925,
        -0.0797, -0.2083, -0.3549, -0.4410, -0.3121, -0.3292,  0.0134,  0.0187,
        -0.0349, -0.2636, -0.1654, -0.1040, -0.2639, -0.0635, -0.1438, -0.1376,
        -0.1905,  0.0026, -0.0844, -0.1132, -0.2061, -0.0153,  0.0542, -0.2214],
       device='mps:0', grad_fn=<SqueezeBackward0>)

In [2]:
def simulate_human_feedback(seq):
    reward = (seq == 42).float().mean().item()
    return reward 



0.5

In [205]:
def collect_reward_model_data(model, dataloader):
    device = DEVICE
    model.eval()
    sequences = []
    rewards = []
    with torch.no_grad():
        for src, tgt in tqdm(dataloader, desc = 'Collecting data for reward model'):
            src = src.to(device)
            generated_seq = generate_sequence(model, src, max_length = MAX_SEQ_LEN)
            for seq in generated_seq:
                reward = simulate_human_feedback(seq)
                sequences.append(seq)
                rewards.append(reward)
    return sequences, rewards

In [206]:
def generate_sequence(model, src, max_length=MAX_SEQ_LEN):

    model.eval()
    device = DEVICE
    src = src.to(device)
    src_mask = model.make_src_mask(src)
    enc_out = model.src_embed(src)
    enc_out = model.pos_encoder(enc_out)
    for layer in model.encoder_layers:
        enc_out = layer(enc_out, src_mask)

    generated = torch.zeros(src.size(0), 1, dtype=torch.long).to(device)  
    for _ in range(max_length):
        dec_out = model.tgt_embed(generated)
        dec_out = model.pos_decoder(dec_out)
        for layer in model.decoder_layers:
            dec_out = layer(dec_out, enc_out, src_mask, model.make_src_mask(generated))
        logits = model.fc_out(dec_out) 
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  
        generated = torch.cat((generated, next_token), dim=1)
    return generated  

In [207]:
def train_reward_model(reward_model, sequences, rewards, num_epochs=2, lr=1e-4, b_size=32):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(reward_model.parameters(), lr=lr)
    reward_model.train()
    device = DEVICE
    dataset = torch.utils.data.TensorDataset(torch.stack(sequences), torch.tensor(rewards))
    dataloader = DataLoader(dataset, batch_size=b_size, shuffle=True)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for seq_batch, reward_batch in tqdm(dataloader, desc=f"Training Reward Model Epoch {epoch+1}/{num_epochs}"):
            
            seq_batch = seq_batch.to(device) 
            reward_batch = reward_batch.to(device, dtype = torch.float32)
            optimizer.zero_grad()
            predicted_rewards = reward_model(seq_batch).float()  # [b]
            loss = criterion(predicted_rewards, reward_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(dataloader)
        print(f"Reward Model Training Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")


In [208]:

class PPOAgent:
    def __init__(self, policy_model, reward_model,
                 lr=1e-4, eps_clip=0.2,
                 gamma=0.99, k_epochs=3,
                 pad_token_id=0, voc_size=TGT_VOCAB_SIZE):
        self.policy = policy_model
        self.policy_old = TransformerSeq2Seq(
            src_vocab_size=SRC_VOCAB_SIZE,
            tgt_vocab_size=TGT_VOCAB_SIZE,
            embed_size=EMBED_SIZE,
            num_encoder_layers=N_LAYERS,
            num_decoder_layers=N_LAYERS,
            n_head=N_HEADS,
            max_len=MAX_LEN,
            ff_hidden_mult=FF_HIDDEN_MULT,
            dropout=DROPOUT,
            tokenizer_pad_token_id=pad_token_id
        ).to(DEVICE)
        self.policy_old.load_state_dict(policy_model.state_dict())
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.eps_clip = eps_clip
        self.gamma = gamma
        self.k_epochs = k_epochs
        self.reward_model = reward_model
        self.mse_loss = nn.MSELoss()
        self.voc_size = voc_size
        self.pad_token_id = pad_token_id

    def generate(self, src, max_len=MAX_SEQ_LEN):
        self.policy.eval()
        device = DEVICE
        src = src.to(device)
        with torch.no_grad():
            generated_seq = generate_sequence(self.policy, src, max_length=max_len)
        return generated_seq  

    def update(self, memory):
        old_srcs = torch.cat(memory.srcs, dim=0).to(DEVICE)       
        old_states = torch.cat(memory.states, dim=0).to(DEVICE)     
        old_actions = torch.cat(memory.actions, dim=0).to(DEVICE)   
        old_logprobs = torch.cat(memory.logprobs, dim=0).to(DEVICE)  
        rewards = torch.tensor(memory.rewards, dtype=torch.float32).to(DEVICE)  

        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
        advantages = rewards 

        for epoch in range(self.k_epochs):  # Iterate through epochs
            logits = self.policy(old_srcs, old_states)  
            logprobs = nn.functional.log_softmax(logits, dim=-1) 
            action_logprobs = logprobs.gather(2, old_actions.unsqueeze(-1)).squeeze(-1) 

            action_logprobs = action_logprobs.view(-1)  
            old_logprobs = old_logprobs.view(-1)        
            ratios = torch.exp(action_logprobs - old_logprobs.detach())  

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2).mean()

            print(f"Epoch {epoch + 1}/{self.k_epochs}, Loss: {loss.item()}")

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())


In [209]:

class PPOMemory:
    def __init__(self):
        self.states = []      
        self.actions = []    
        self.logprobs = []    
        self.rewards = []    
        self.srcs = []      

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.srcs = []


In [210]:

def finetune_llm_PPO(agent, dataloader, num_updates=1):
    memory = PPOMemory()
    for update in range(num_updates):
        for src, _ in tqdm(dataloader, desc=f"PPO Update {update+1}/{num_updates}"):
            src = src.to(DEVICE)
            generated_seq = agent.generate(src)  

            predicted_rewards = agent.reward_model(generated_seq)  

            memory.srcs.append(src.cpu())

            states = generated_seq[:, :-1] 
            actions = generated_seq[:, 1:]  
            memory.states.append(states.cpu())
            memory.actions.append(actions.cpu())

            with torch.no_grad():
                logits = agent.policy_old(src, states.to(DEVICE))  
                logprobs = nn.functional.log_softmax(logits, dim=-1) 
                action_logprobs = logprobs.gather(2, actions.unsqueeze(-1)).squeeze(-1) 
                memory.logprobs.append(action_logprobs.cpu())


            batch_rewards = predicted_rewards.detach().cpu().numpy() 
            seq_len = actions.size(1)
            for reward in batch_rewards:
                memory.rewards.extend([reward] * seq_len)

        agent.update(memory)

        memory.clear_memory()
        print(f"PPO Update {update+1}/{num_updates} completed.")


In [211]:
# Why updating old policy with new self.policy_old.load_state_dict(self.policy.state_dict())


# and also why [reward] * seq_len

In [212]:
dataset = ExampleDataset(num_samples = 100)

dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)


model = TransformerSeq2Seq(
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_encoder_layers=N_LAYERS,
    num_decoder_layers=N_LAYERS,
    n_head=N_HEADS,
    max_len=MAX_LEN,
    ff_hidden_mult=FF_HIDDEN_MULT,
    dropout=DROPOUT,
    tokenizer_pad_token_id=0
).to(DEVICE)



reward_model = RewardModel(voc_size = TGT_VOCAB_SIZE,
                           emb_size = EMBED_SIZE,
                           hidden_size = HIDDEN_SIZE,
                           n_layers = N_LAYERS,
                           pad_token_id = 0).to(DEVICE)




print('TRAINING SEQ2SEQ')
train_seq2seq(model, dataloader)




print('TRAINING REWARD MODEL...')

sequences, rewards = collect_reward_model_data(model, dataloader)
print('COLLECTING DATA FOR REWARD MODEL')


train_reward_model(reward_model, sequences, rewards, num_epochs = REWARD_MODEL_EPOCHS,
                   lr = LR, b_size = BATCH_SIZE)

ppo_agent = PPOAgent(
    policy_model = model,
    reward_model = reward_model,
    lr = LR,
    eps_clip = EPS_CLIP,
    gamma = GAMMA,
    k_epochs = 4,
    pad_token_id = 0,
    voc_size = TGT_VOCAB_SIZE
)



TRAINING SEQ2SEQ


Pretraining: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]


Pretraining Epoch 1/1, Loss: 7.0815
TRAINING REWARD MODEL...


Collecting data for reward model: 100%|██████████| 4/4 [00:02<00:00,  1.71it/s]


COLLECTING DATA FOR REWARD MODEL


Training Reward Model Epoch 1/3: 100%|██████████| 4/4 [00:00<00:00, 15.99it/s]


Reward Model Training Epoch 1/3, Loss: 7.1017


Training Reward Model Epoch 2/3: 100%|██████████| 4/4 [00:00<00:00, 23.67it/s]


Reward Model Training Epoch 2/3, Loss: 1.8748


Training Reward Model Epoch 3/3: 100%|██████████| 4/4 [00:00<00:00, 24.01it/s]


Reward Model Training Epoch 3/3, Loss: 0.3821
