In [2]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
import torch.nn as nn
import torch
torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from soft_embedding import SoftEmbedding
from dataset import BiasDataset
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import pandas as pd

In [3]:
torch.device('cuda' if torch.cuda.is_available() else 'cpu')


device(type='cuda')

In [4]:
### TO CHECK WHETHER THE output_1.logits[:, -1, :] returns the next word logits for each sentence in the batch.

def get_next_word_log_probabilities(model_input_1, model_input_2, model, logsoftmax):
    #import pdb; pdb.set_trace()
    output_1 = model(input_ids = model_input_1[0],attention_mask = model_input_1[1])     
    #here we get the logits from the output, the logits are obtained from the output.logits, here we first get
    #the [batch_size, sequence_len, vocab_size]
    #next_word_logits_1 = torch.stack([torch.index_select(seq_logits, 0, torch.tensor([i+1])) for i, seq_logits in enumerate(output_1.logits)])
    next_word_logits_1 = output_1.logits[:, -1, :]
    next_word_log_probs_1 = logsoftmax(next_word_logits_1)
    
    output_2 = model(input_ids = model_input_2[0],attention_mask = model_input_2[1] ) 
    #next_word_logits_2 = torch.stack([torch.index_select(seq_logits, 0, torch.tensor([i+1])) for i, seq_logits in enumerate(output_2.logits)])
    next_word_logits_2 = output_2.logits[:, -1, :]
    next_word_log_probs_2 = logsoftmax(next_word_logits_2)    

    return next_word_log_probs_1, next_word_log_probs_2


In [5]:
n_tokens = 5
initialize_from_vocab = True
num_epochs = 1

In [6]:
dataset = BiasDataset("../data/occupation.csv", n_tokens = 5)
dataloader = dataset.get_dataloader(batch_size=64)

#dataloader = DataLoader(dataset = dataset, batch_size = 4,  shuffle = False, num_workers=1, collate_fn = collate_fn)

In [7]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
#freezing all the layers, and only keeping the tokens one active

for param in model.parameters():
    param.requires_grad_(False)
    
s_wte = SoftEmbedding(model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=initialize_from_vocab)
model.set_input_embeddings(s_wte)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
logsoftmax = nn.LogSoftmax(dim = 0)
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)

In [9]:
epoch_loss = []
for epoch in range(num_epochs):

    print(f'epoch : {epoch}')
    loss = 0
    curr_loss = []
    for iter, data in enumerate(dataloader):
        group_1, group_2 = data[0], data[1]
        
        log_probs1, log_probs2 = get_next_word_log_probabilities(data[0], data[1], model, logsoftmax)
        loss = kl_loss(log_probs1, log_probs2)
        optimizer.zero_grad()
        
        loss.backward()
        optimizer.step()
        print(f'loss at batch : {iter} is {loss.item()}')

    epoch_loss.append(curr_loss)

epoch : 0
loss at batch : 0 is 9241.669921875
loss at batch : 1 is 280.6936340332031
loss at batch : 2 is 3.866910934448242
loss at batch : 3 is 0.8037680387496948
loss at batch : 4 is 2.2142293453216553
loss at batch : 5 is 244.95872497558594
loss at batch : 6 is 23.449047088623047
loss at batch : 7 is 4.415591716766357
loss at batch : 8 is 3.141282558441162
loss at batch : 9 is 1.986946940422058
loss at batch : 10 is 7.846096992492676
loss at batch : 11 is 10.516422271728516
loss at batch : 12 is 30.058589935302734
loss at batch : 13 is 40.47797393798828
loss at batch : 14 is 64.55730438232422
loss at batch : 15 is 17.30180549621582
loss at batch : 16 is 2.0929555892944336
loss at batch : 17 is 5.76794958114624
loss at batch : 18 is 21.468891143798828
loss at batch : 19 is 12.05396842956543
loss at batch : 20 is 3.1258740425109863
loss at batch : 21 is 3.876750946044922
loss at batch : 22 is 1.9230612516403198
loss at batch : 23 is 9.9710693359375
loss at batch : 24 is 12.19278335571