In [1]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
import torch.nn as nn
import torch
torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from soft_embedding import SoftEmbedding
from dataset import BiasDataset
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import pandas as pd

In [2]:
torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device(type='cuda')

In [3]:
### TO CHECK WHETHER THE output_1.logits[:, -1, :] returns the next word logits for each sentence in the batch.

def get_next_word_log_probabilities(model_input_1, model_input_2, model, logsoftmax):
    #import pdb; pdb.set_trace()
    output_1 = model(input_ids = model_input_1[0],attention_mask = model_input_1[1])     
    #here we get the logits from the output, the logits are obtained from the output.logits, here we first get
    #the [batch_size, sequence_len, vocab_size]
    #next_word_logits_1 = torch.stack([torch.index_select(seq_logits, 0, torch.tensor([i+1])) for i, seq_logits in enumerate(output_1.logits)])
    next_word_logits_1 = output_1.logits[:, -1, :]
    next_word_log_probs_1 = logsoftmax(next_word_logits_1)
    
    output_2 = model(input_ids = model_input_2[0],attention_mask = model_input_2[1] ) 
    #next_word_logits_2 = torch.stack([torch.index_select(seq_logits, 0, torch.tensor([i+1])) for i, seq_logits in enumerate(output_2.logits)])
    next_word_logits_2 = output_2.logits[:, -1, :]
    next_word_log_probs_2 = logsoftmax(next_word_logits_2)    

    return next_word_log_probs_1, next_word_log_probs_2


In [4]:
n_tokens = 5
initialize_from_vocab = True
num_epochs = 1

In [5]:
dataset = BiasDataset("../data/occupation.csv", n_tokens = 5)
dataloader = dataset.get_dataloader(batch_size=64)

#dataloader = DataLoader(dataset = dataset, batch_size = 4,  shuffle = False, num_workers=1, collate_fn = collate_fn)

In [6]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
#freezing all the layers, and only keeping the tokens one active

for param in model.parameters():
    param.requires_grad_(False)
    
s_wte = SoftEmbedding(model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=False)
model.set_input_embeddings(s_wte)

In [7]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-2)
logsoftmax = nn.LogSoftmax(dim = 0)
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)

In [8]:
#loss before training

In [9]:
loss = 0
curr_loss = []
for iter, data in enumerate(dataloader):
    group_1, group_2 = data[0], data[1]
    with torch.no_grad():
        log_probs1, log_probs2 = get_next_word_log_probabilities(data[0], data[1], model, logsoftmax)
        loss = torch.mean(torch.stack([kl_loss(log_probs1[i], log_probs2[i]) for i in range(len(log_probs1))]))
        curr_loss.append(loss)
    print(f'loss at batch : {iter} is {loss.item()}')

torch.mean(torch.stack(curr_loss))

loss at batch : 0 is 0.34565556049346924
loss at batch : 1 is 0.4419367015361786
loss at batch : 2 is 1.2745273725300876e-09
loss at batch : 3 is 0.06149400770664215
loss at batch : 4 is 0.09740849584341049
loss at batch : 5 is 0.8722692728042603
loss at batch : 6 is 0.00013242590648587793
loss at batch : 7 is 0.000687938358169049
loss at batch : 8 is 1.5783867638674565e-05
loss at batch : 9 is 0.0013247253373265266
loss at batch : 10 is 0.0010080602951347828
loss at batch : 11 is 0.02488534152507782
loss at batch : 12 is 0.0008782527293078601
loss at batch : 13 is 0.00010278807167196646
loss at batch : 14 is 0.0012342375703155994
loss at batch : 15 is 0.00044398586032912135
loss at batch : 16 is 0.0027710909489542246
loss at batch : 17 is 0.03998435288667679
loss at batch : 18 is 0.0006562793860211968
loss at batch : 19 is 0.0006301577086560428
loss at batch : 20 is 0.0005002592806704342
loss at batch : 21 is 0.0015243764501065016
loss at batch : 22 is 0.0008560825372114778
loss at ba

tensor(0.0371)

In [10]:
epoch_loss = []
for epoch in range(num_epochs):

    print(f'epoch : {epoch}')
    loss = 0
    curr_loss = []
    for iter, data in enumerate(dataloader):
        group_1, group_2 = data[0], data[1]
        
        log_probs1, log_probs2 = get_next_word_log_probabilities(data[0], data[1], model, logsoftmax)
        #import pdb; pdb.set_trace()
        #loss = kl_loss(log_probs1, log_probs2)
        loss = torch.mean(torch.stack([kl_loss(log_probs1[i], log_probs2[i]) for i in range(len(log_probs1))]))
        optimizer.zero_grad()
        
        loss.backward()
        optimizer.step()
        print(f'loss at batch : {iter} is {loss.item()}')

    epoch_loss.append(curr_loss)

epoch : 0
loss at batch : 0 is 0.34565556049346924
loss at batch : 1 is 0.0062172734178602695
loss at batch : 2 is 0.0011024556588381529
loss at batch : 3 is 0.0010622439440339804
loss at batch : 4 is 0.003473471151664853
loss at batch : 5 is 0.012821296229958534
loss at batch : 6 is 0.000862563494592905
loss at batch : 7 is 0.001400969922542572
loss at batch : 8 is 0.0009375842055305839
loss at batch : 9 is 2.2108559278422035e-05
loss at batch : 10 is 3.261709935031831e-05
loss at batch : 11 is 0.00013592277537100017
loss at batch : 12 is 0.000782068760599941
loss at batch : 13 is 0.0002840063825715333
loss at batch : 14 is 0.00011298262688796967
loss at batch : 15 is 7.823648047633469e-06
loss at batch : 16 is 8.162546691892203e-06
loss at batch : 17 is 0.00017380653298459947
loss at batch : 18 is 0.005470344331115484
loss at batch : 19 is 0.004626250825822353
loss at batch : 20 is 0.0010765298502519727
loss at batch : 21 is 6.0259271776885726e-06
loss at batch : 22 is 1.164656896435

In [11]:
### get the training loss before and after training the model


In [12]:
print(f'epoch : {epoch}')
loss = 0
curr_loss = []
for iter, data in enumerate(dataloader):
    group_1, group_2 = data[0], data[1]
    with torch.no_grad():
        log_probs1, log_probs2 = get_next_word_log_probabilities(data[0], data[1], model, logsoftmax)
        loss = torch.mean(torch.stack([kl_loss(log_probs1[i], log_probs2[i]) for i in range(len(log_probs1))]))
        curr_loss.append(loss)
    print(f'loss at batch : {iter} is {loss.item()}')



epoch : 0
loss at batch : 0 is 0.001299596973694861
loss at batch : 1 is 0.0010230415500700474
loss at batch : 2 is 0.00041725795017555356
loss at batch : 3 is 0.000376284820958972
loss at batch : 4 is 0.00036795789492316544
loss at batch : 5 is 0.0037470136303454638
loss at batch : 6 is 3.571749766706489e-05
loss at batch : 7 is 2.6185654860455543e-05
loss at batch : 8 is 1.8900866052717902e-05
loss at batch : 9 is 1.4819745956629049e-05
loss at batch : 10 is 1.2007320037810132e-05
loss at batch : 11 is 2.8479964385041967e-05
loss at batch : 12 is 0.00028187481802888215
loss at batch : 13 is 0.00013327850319910794
loss at batch : 14 is 0.00011007210559910163
loss at batch : 15 is 6.803842552471906e-05
loss at batch : 16 is 6.454249523812905e-05
loss at batch : 17 is 0.00023160297132562846
loss at batch : 18 is 0.00016904796939343214
loss at batch : 19 is 7.315910625038669e-05
loss at batch : 20 is 0.00011762839858420193
loss at batch : 21 is 9.105166827794164e-05
loss at batch : 22 is

In [13]:
torch.mean(torch.stack(curr_loss))

tensor(0.0003)

In [None]:
this is the KL loss before training with the hard prompt- tensor(0.0054)
