In [1]:
import pandas as pd
import numpy as np
import copy

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

from datasets import Dataset, load_dataset

In [2]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

# Changing Mu

In [3]:
config = {
    'I_iterations': 2,
    'M_models': 2,
    'T_steps': 100, 
    'mu': 0.01,
    'lmbd': 0.5,
    'eta': 0.5,
    '_beta': 1.0,
    'batch_size': 64,
    'train_first_n': 10,
    'test_frist_n': 10,
    'max_length': 50, # bert's max is 512, but it needs quite a lot of memory
    'temperature': 0.9
}

device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [4]:
train_df = pd.read_parquet('/kaggle/input/imdb-csv/Train Data.parquet')
test_df = pd.read_parquet('/kaggle/input/imdb-csv/Test Data.parquet')

In [5]:
# TODO CHECK train a model to generate positive reviews
train_df = train_df[train_df['label'] == 1].reset_index(drop=True)
train_df.head()

Unnamed: 0,text,label
0,Zentropa has much in common with The Third Man...,1
1,Zentropa is the most original movie I've seen ...,1
2,Lars Von Trier is never backward in trying out...,1
3,*Contains spoilers due to me having to describ...,1
4,That was the first thing that sprang to mind a...,1


In [6]:
policy_tokenizer = AutoTokenizer.from_pretrained("lvwerra/gpt2-imdb")
init_model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")

init_model.to(device)

tokenizer_config.json:   0%|          | 0.00/17.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/577 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [7]:
reward_model_path = '/kaggle/input/reward2-0/distilbert-imdb/'

reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_path, local_files_only = True)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_path, local_files_only = True)

#reward_tokenizer = AutoTokenizer.from_pretrained("lvwerra/distilbert-imdb")
#reward_model = AutoModelForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb") 

reward_model.to(device)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [8]:
# leave first n and tokenize text and create a dataloader
tokenized = policy_tokenizer(train_df['text'].tolist())
tokenized = [x[:config['train_first_n']] for x in tokenized['input_ids']]
tokenized = torch.IntTensor(tokenized)

Token indices sequence length is longer than the specified maximum sequence length for this model (1117 > 1024). Running this sequence through the model will result in indexing errors


In [9]:
tokenized.shape, tokenized[:3]

(torch.Size([12500, 10]),
 tensor([[   57,   298,  1773,    64,   468,   881,   287,  2219,   351,   383],
         [   57,   298,  1773,    64,   318,   262,   749,  2656,  3807,   314],
         [   43,   945, 26985,   309,  5277,   318,  1239, 19528,   287,  2111]],
        dtype=torch.int32))

In [10]:
train_dataloader = DataLoader(tokenized, batch_size=64, shuffle=True)

In [11]:
def generate(model, idx):
    idx = idx.to(device)
    
    # TODO check each param 
    output = model.generate(idx, max_length = 50, pad_token_id=50256, num_return_sequences = 1, return_dict_in_generate=True, output_scores=True, temperature = temperature)
    output_ids = output['sequences']
    generation = policy_tokenizer.batch_decode(output_ids)
    
    return output_ids.clone(), generation

In [12]:
def reward_fn_sentiment_imdb(gen_sample):
    with torch.no_grad():
        tokens = reward_tokenizer(gen_sample, return_tensors='pt', padding=True, truncation=True)['input_ids'].to(device)
        logits = reward_model(tokens).logits
        positive_cls = logits.softmax(dim=-1)[:, 1] # TODO CHECK that pos = 1 and neg = 0
    return positive_cls.to(device)

In [13]:
def get_kl(training_logits, ref_logits, beta, first_n):
    # TODO write log_softmax
    
    training_logprobs = training_logits.log_softmax(-1)
    ref_logprobs = ref_logits.log_softmax(-1)

    probs = training_logprobs.exp()
    
    kl = (probs * (training_logprobs - ref_logprobs))[:, first_n:-1].sum(-1)
    return beta*kl.mean()

In [14]:
test_generation = "Before Dogma 95: when Lars used movies as art, not just a story. A beautiful painting about love and death. This is one of my favorite movies of all time. The color... The music... Just perfect.	"

rewards = reward_fn_sentiment_imdb(test_generation)
rewards

tensor([0.9939], device='cuda:0')

In [15]:
i_iterations = config['I_iterations']
m_runs = config['M_models']
t_steps = config['T_steps']

beta = config['_beta']
mu = config['mu']
eta = config['eta']
lmbd = config['lmbd']

first_n = config['train_first_n']
temperature = config['temperature']

In [16]:
#loss_lst, kl_lst, normalized_rewads_lst = [], [], []
def t_steps_run(m_model, ema_model_ref):
    steps = 0

    for batch in train_dataloader:

        tokens, generation = generate(m_model, batch)

        m_model.train()
        logits = m_model(tokens).logits
        with torch.no_grad():
            ref_logits = ema_model_ref(tokens).logits.detach()

        rewards = reward_fn_sentiment_imdb(generation).view(-1, 1, 1) # TODO OR SEPARATE VIEW
        mean_before_norm = rewards.mean()
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8) # TODO normalize rewards, without this it is instable
        kl = get_kl(logits, ref_logits, beta, 10)
        rewards = rewards-kl


        log_probs = F.log_softmax(logits[:, config['train_first_n']:, :], dim=-1)

        tokens_generated = tokens[:, config['train_first_n']:]

        selected_log_probs = log_probs.gather(2, tokens_generated.unsqueeze(-1)).squeeze(-1)
        # handle None
        selected_log_probs = torch.where(torch.isfinite(selected_log_probs), selected_log_probs, torch.zeros_like(selected_log_probs))

        policy_gradient_loss = (selected_log_probs * rewards).sum(dim=1).mean()

        loss = policy_gradient_loss - kl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        steps += 1
        print(f"Step {steps}, Loss: {loss.item():.4f}, KL: {kl.item():.4f}, Reward: {rewards.mean().item():.4f}")
        #loss_lst.append(loss.item())
        #normalized_rewads_lst.append(rewards.item())
        #kl_lst.append(kl.item())
        
        with torch.no_grad():
            for m_param, ref_param in zip(m_model.parameters(), ema_model_ref.parameters()):
                inv_mu = 1-mu
                ref_param.mul_(inv_mu)
                b = mu*m_param
                ref_param.add_(b)
        if steps >= config['T_steps']:
            break

In [17]:
def liti_init_update(slerp_model, eta):
    # actual init model is updated TODO CHANGE because we want to go towards SFT on 2nd iteration
    # or is it? I am not sure
    with torch.no_grad():
        for init_param, slerp_param in zip(init_model.parameters(), slerp_model.parameters()):
            inv_eta = 1-eta
            init_param.mul_(inv_eta)
            b = eta * slerp_param
            init_param.add_(b)

In [18]:
def slerp(theta_init, theta1, theta2, lmbd):
    """
    """
    delta1 = theta1 - theta_init
    delta2 = theta2 - theta_init
    
    # Normalize by assumption
    delta1_norm = torch.norm(delta1)
    delta2_norm = torch.norm(delta2)
    delta1_normalized = delta1 / delta1_norm
    delta2_normalized = delta2 / delta2_norm
    
    cos_omega = torch.sum(delta1_normalized * delta2_normalized)
    cos_omega = torch.clamp(cos_omega, -1.0, 1.0)  # Ensure it's in the valid range for acos
    omega = torch.acos(cos_omega)
    
    sin_omega = torch.sin(omega)
    
    if sin_omega.abs() < 1e-6:
        #If the angle is very small, fall back to linear interpolation, otherwise it is instable
        return (1 - lmbd) * delta1 + lmbd * delta2
    else:   
        term1 = torch.sin((1 - lmbd) * omega) / sin_omega * delta1
        term2 = torch.sin(lmbd * omega) / sin_omega * delta2

        return term1 + term2

In [19]:
for hyperparam in [0.07, 0.15]:
    mu = hyperparam
    init_model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")
    print('RUNNNING', mu)
    
    init_model.to(device)

    for i in range(i_iterations):
        models_for_slerp = []
        for run in range(m_runs):
            m_model = copy.deepcopy(init_model)
            ema_model_ref = copy.deepcopy(init_model)

            optimizer = torch.optim.Adam(m_model.parameters(), lr=1e-4)

            # Policy update
            t_steps_run(m_model, ema_model_ref)
            torch.cuda.empty_cache()
            models_for_slerp.append(m_model)

        # SLERP
        with torch.no_grad():
            slerp_model = copy.deepcopy(init_model)
            for slerp_param, theta1, theta2 in zip(slerp_model.parameters(), models_for_slerp[0].parameters(), models_for_slerp[0].parameters()):
                slerp_param.add_(slerp(slerp_param, theta1, theta2, lmbd))
        # LITI
        liti_init_update(slerp_model, eta)
    init_model.save_pretrained(f'/kaggle/working/mu_{mu}')

RUNNNING 0.07


The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Step 1, Loss: 52.2898, KL: 0.0854, Reward: -0.0854
Step 2, Loss: 136.1913, KL: 0.1873, Reward: -0.1873
Step 3, Loss: 67.4367, KL: 0.1167, Reward: -0.1167
Step 4, Loss: 69.1786, KL: 0.1296, Reward: -0.1296
Step 5, Loss: 56.5833, KL: 0.1064, Reward: -0.1064
Step 6, Loss: 35.3518, KL: 0.0698, Reward: -0.0698
Step 7, Loss: 36.1010, KL: 0.0707, Reward: -0.0707
Step 8, Loss: 36.2548, KL: 0.0702, Reward: -0.0702
Step 9, Loss: 37.5039, KL: 0.0696, Reward: -0.0696
Step 10, Loss: 38.6981, KL: 0.0712, Reward: -0.0712
Step 11, Loss: 32.9651, KL: 0.0642, Reward: -0.0642
Step 12, Loss: 32.5781, KL: 0.0624, Reward: -0.0624
Step 13, Loss: 30.2782, KL: 0.0573, Reward: -0.0573
Step 14, Loss: 30.6160, KL: 0.0588, Reward: -0.0588
Step 15, Loss: 31.2701, KL: 0.0604, Reward: -0.0604
Step 16, Loss: 29.5445, KL: 0.0581, Reward: -0.0581
Step 17, Loss: 27.8676, KL: 0.0540, Reward: -0.0540
Step 18, Loss: 25.6385, KL: 0.0512, Reward: -0.0512
Step 19, Loss: 25.9390, KL: 0.0521, Reward: -0.0521
Step 20, Loss: 26.44



Step 1, Loss: 57.4906, KL: 0.0944, Reward: -0.0944
Step 2, Loss: 116.6413, KL: 0.1617, Reward: -0.1617
Step 3, Loss: 58.0960, KL: 0.0959, Reward: -0.0959
Step 4, Loss: 52.1122, KL: 0.0936, Reward: -0.0936
Step 5, Loss: 43.5023, KL: 0.0791, Reward: -0.0791
Step 6, Loss: 31.5477, KL: 0.0575, Reward: -0.0575
Step 7, Loss: 32.6919, KL: 0.0600, Reward: -0.0600
Step 8, Loss: 28.3362, KL: 0.0517, Reward: -0.0517
Step 9, Loss: 32.1518, KL: 0.0585, Reward: -0.0585
Step 10, Loss: 26.1243, KL: 0.0479, Reward: -0.0479
Step 11, Loss: 30.0832, KL: 0.0562, Reward: -0.0562
Step 12, Loss: 27.6557, KL: 0.0518, Reward: -0.0518
Step 13, Loss: 27.4117, KL: 0.0523, Reward: -0.0523
Step 14, Loss: 26.0417, KL: 0.0509, Reward: -0.0509
Step 15, Loss: 24.4271, KL: 0.0488, Reward: -0.0488
Step 16, Loss: 26.5435, KL: 0.0508, Reward: -0.0508
Step 17, Loss: 25.0870, KL: 0.0497, Reward: -0.0497
Step 18, Loss: 23.1369, KL: 0.0468, Reward: -0.0468
Step 19, Loss: 23.0256, KL: 0.0468, Reward: -0.0468
Step 20, Loss: 22.62