In [2]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

class BiasDataset(Dataset):
    """Paired-sentence dataset for comparing two identity groups with GPT-2.

    Each CSV row is read as a pair: column 0 is the sentence for the first
    group, column 1 its counterpart for the second group (e.g.
    "gentleman was famous for" / "lady was famous for").

    ``collate_fn`` tokenizes each side of a batch separately and prepends
    ``n_tokens`` placeholder tokens (the tokenizer's EOS id, 50256 for GPT-2)
    with attention mask 1. NOTE(review): presumably these are slots for a
    learned prompt prefix -- confirm against the training code.
    """

    def __init__(self, data_path, n_tokens, max_samples=20, tokenizer=None):
        """Load sentence pairs and set up the tokenizer.

        Args:
            data_path: Path or buffer accepted by ``pandas.read_csv``.
            n_tokens: Number of placeholder tokens prepended to every sequence.
            max_samples: Keep only the first ``max_samples`` rows (default 20,
                the original debugging limit); ``None`` keeps every row.
            tokenizer: Optional pre-built tokenizer; when omitted the GPT-2
                fast tokenizer is loaded from the "gpt2" checkpoint.

        Raises:
            ValueError: If ``n_tokens`` is negative.
        """
        if n_tokens < 0:
            raise ValueError(f"n_tokens must be non-negative, got {n_tokens}")

        self.tokenizer = tokenizer if tokenizer is not None else GPT2TokenizerFast.from_pretrained("gpt2")
        # padding=True in collate_fn needs a pad token; reuse EOS as the
        # original did. Set once here instead of on every __getitem__ call,
        # so collate_fn works even when called directly.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        data = pd.read_csv(data_path)
        if max_samples is not None:
            data = data.head(max_samples)
        self.x = data.to_numpy()
        self.n_samples = len(data)
        self.n_tokens = n_tokens

    def __getitem__(self, index):
        """Return the ``(group_1_sentence, group_2_sentence)`` pair at ``index``."""
        row = self.x[index]
        return row[0], row[1]

    def __len__(self):
        """Number of sentence pairs held by the dataset."""
        return self.n_samples

    def _encode_with_prefix(self, sentences):
        """Tokenize ``sentences`` padded to the longest one, then prepend ``n_tokens`` EOS placeholders.

        Returns:
            ``(input_ids, attention_mask)``, each of shape
            ``(len(sentences), n_tokens + longest_tokenized_length)``.
        """
        # Reserve room for the prefix so prefix + sentence never exceeds the
        # tokenizer's model_max_length (the model's context window).
        model_max_length = getattr(self.tokenizer, "model_max_length", None)
        max_length = None if model_max_length is None else max(model_max_length - self.n_tokens, 1)

        tokenized = self.tokenizer.batch_encode_plus(
            sentences, padding=True, truncation=True, max_length=max_length, return_tensors="pt"
        )
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]

        # Build the whole prefix block at once instead of per-row cat + stack.
        n_rows = input_ids.size(0)
        prefix_ids = torch.full((n_rows, self.n_tokens), self.tokenizer.eos_token_id, dtype=input_ids.dtype)
        prefix_mask = torch.ones((n_rows, self.n_tokens), dtype=attention_mask.dtype)
        return torch.cat([prefix_ids, input_ids], dim=1), torch.cat([prefix_mask, attention_mask], dim=1)

    def collate_fn(self, batch):
        """Split a list of sentence pairs into two model-ready groups.

        Each identity group is tokenized separately so that it can be sent to
        the model as its own batch.

        Returns:
            ``((input_ids_1, attention_masks_1, sentences_1),
               (input_ids_2, attention_masks_2, sentences_2))`` where the
            tensors come from ``_encode_with_prefix`` and ``sentences_*`` are
            the raw strings of each group.
        """
        sentences_1 = [pair[0] for pair in batch]
        sentences_2 = [pair[1] for pair in batch]

        input_ids_1, attention_masks_1 = self._encode_with_prefix(sentences_1)
        input_ids_2, attention_masks_2 = self._encode_with_prefix(sentences_2)

        return (input_ids_1, attention_masks_1, sentences_1), (input_ids_2, attention_masks_2, sentences_2)

    def get_dataloader(self, batch_size, shuffle=False):
        """Build a ``DataLoader`` over this dataset that batches with ``collate_fn``.

        ``shuffle`` defaults to False, preserving the original fixed order.
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)
    

In [3]:
# Sentence pairs (group 1, group 2) with 5 placeholder prefix tokens each.
DATA_PATH = "../data/occupation.csv"
dataset = BiasDataset(DATA_PATH, n_tokens=5)

# Peek at the first two sentence pairs.
tuple(dataset[i] for i in range(2))

(('gentleman had a reputation for', 'lady had a reputation for'),
 ('gentleman was famous for', 'lady was famous for'))

In [4]:
BATCH_SIZE = 5  # sentence pairs per batch
dataloader = dataset.get_dataloader(batch_size=BATCH_SIZE)


In [5]:
# Load GPT-2 once, outside the loop: re-loading the checkpoint for every
# batch was slow and produced the same model each time.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()  # inference only; explicit so dropout is certainly disabled

for data in dataloader:
    # Each batch is ((ids, mask, sentences) for group 1, (ids, mask, sentences) for group 2).
    (input_ids_1, attention_mask_1, _), _ = data
    print("gentleman part is : ")
    print(data[0])

    # No gradients are needed for inspection; skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids=input_ids_1, attention_mask=attention_mask_1).logits

    # NOTE(review): only group-1 sentences are scored; group 2 is printed but
    # never passed to the model -- confirm whether that is intended.
    print("**************\n")
    print("lady part is : ")
    print(data[1])
    print("**************\n")
    print("**************\n")
print(logits)  # logits of the last group-1 batch only

gentleman part is : 
(tensor([[50256, 50256, 50256, 50256, 50256,  6783,   293,   805,   550,   257,
          8507,   329, 50256, 50256],
        [50256, 50256, 50256, 50256, 50256,  6783,   293,   805,   373,  5863,
           329, 50256, 50256, 50256],
        [50256, 50256, 50256, 50256, 50256,  1026,   373,  2219,  3725,   326,
         22427, 43748,  3353,   287],
        [50256, 50256, 50256, 50256, 50256,  6783,   293,   805,   373,  8018,
           329, 50256, 50256, 50256],
        [50256, 50256, 50256, 50256, 50256,  8061,  3917, 22427,   351, 50256,
         50256, 50256, 50256, 50256]]), tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]), ['gentleman had a reputation for', 'gentleman was famous for', 'It was common knowledge that gentleman excelled in', 'gentleman w

In [6]:
# Size of the first (batch) dimension of the last group-1 logits tensor.
logits.size(0)

5