### General imports

In [2]:
import torch 
import torchvision
import torch.nn as nn 
import torch.nn.init as init
from IPython.display import Image 
from torchvision import transforms
import matplotlib.pyplot as plt
import random
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch.optim as optim
import time
import torch.nn.functional as F

import numpy as np
from sklearn.linear_model import LinearRegression
from scipy.optimize import curve_fit

import pickle

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator this notebook imports so that a
# Restart & Run All reproduces the same numbers. NumPy has its own global
# generator, which `random.seed` / `torch.manual_seed` do not touch.
seed = 12345
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

%matplotlib inline

### Model (Hugging Face `transformers` imports)

In [3]:
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import torch

### Loading the data

In [4]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object has been read, instead of being left open until garbage
    collection.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load pickles that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train = _load_pickle('okcupid_train.pkl')
val = _load_pickle('okcupid_val.pkl')
test = _load_pickle('okcupid_test.pkl')

### Dataset that tokenizes the texts into pre-built batches

In [5]:
class GPTDataset(torch.utils.data.Dataset):
    """Pre-tokenized, pre-batched view over a sliceable collection of texts.

    All batches are tokenized once in ``__init__`` and kept in
    ``self.batches``; indexing the dataset returns one whole tokenized batch,
    not a single example.

    Args:
        dataset: Sliceable collection passed, slice by slice, to the tokenizer.
        batch_size: Number of items per batch. Must be a positive integer.
        tokenizer: Optional tokenizer to reuse. When ``None`` (the default) the
            'distilgpt2' tokenizer is loaded. A '[PAD]' token is added if the
            tokenizer has no pad token, because padding a batch requires one.
        drop_last: When ``True`` (the default, matching the original
            behaviour) a trailing group smaller than ``batch_size`` is
            discarded. Pass ``False`` to keep it as a final, shorter batch.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, dataset, batch_size, tokenizer=None, drop_last=True):
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        if tokenizer is None:
            tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
        if getattr(tokenizer, 'pad_token', None) is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        self.tokenizer = tokenizer
        self.dataset = dataset
        self.batch_size = batch_size

        n_items = len(self.dataset)
        # With drop_last, stop at the last multiple of batch_size so that only
        # full batches are produced.
        stop = n_items - n_items % batch_size if drop_last else n_items

        self.batches = []
        for start in range(0, stop, batch_size):
            chunk = self.dataset[start : start + batch_size]
            self.batches.append(
                self.tokenizer(chunk, padding=True, truncation=True, return_tensors="pt")
            )

    def __getitem__(self, index):
        """Return the tokenized batch stored at position ``index``."""
        return self.batches[index]

    def __len__(self):
        """Return the number of pre-built batches."""
        return len(self.batches)

### Preparing the data for training. BATCH SIZE is specified here

In [6]:
# Number of texts tokenized together into each pre-built GPTDataset batch.
batch_size = 16

In [7]:
# Build one pre-batched dataset per split, in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    GPTDataset(split, batch_size) for split in (train, val, test)
)

### Training loop

In [8]:
def _masked_labels(batch):
    """Return language-modelling labels for ``batch`` with padding ignored.

    The labels are a copy of ``input_ids`` in which every position whose
    ``attention_mask`` is 0 is set to -100, the index the loss ignores.
    Without this the model is also trained to predict the padding token.
    """
    labels = batch['input_ids'].clone()
    if 'attention_mask' in batch:
        labels[batch['attention_mask'] == 0] = -100
    return labels


def _mean_loss(model, batches, device, optimizer=None):
    """Run ``model`` over ``batches`` and return the mean per-batch loss.

    When ``optimizer`` is given, a gradient step is taken after every batch.
    """
    total_loss = 0.0
    n_batches = 0
    for x in batches:
        x = x.to(device)
        output = model(**x, labels=_masked_labels(x))
        loss = output[0]

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_loss += loss.sum().item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("dataset produced no batches")
    return total_loss / n_batches


def train_loop(model, optimizer, scheduler, train_dataset, device, epoch=5,
               val_dataset=None):
    """Train ``model`` for ``epoch`` epochs and track the mean loss per epoch.

    Args:
        model: Module called as ``model(**batch, labels=...)`` whose first
            output is the loss.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` takes a metric. It is stepped once
            per epoch with the validation loss, or with the training loss when
            no validation set is given.
        train_dataset: Iterable of batches; each batch must support
            ``.to(device)``, ``**`` unpacking and ``['input_ids']``.
        device: Device the model and every batch are moved to.
        epoch: Number of passes over ``train_dataset``.
        val_dataset: Optional iterable of batches evaluated after each epoch.

    Returns:
        ``(train_losses, val_losses)``: two lists with one mean loss per epoch.
        ``val_losses`` is empty when ``val_dataset`` is ``None``.

    Raises:
        ValueError: If a dataset yields no batches.
    """
    # Batches are moved to `device` below, so the model has to live there too.
    model.to(device)

    train_losses = []
    val_losses = []

    for t in tqdm(range(epoch)):
        model.train()
        progress = tqdm(train_dataset, total=len(train_dataset))
        train_losses.append(_mean_loss(model, progress, device, optimizer))

        if val_dataset is not None:
            model.eval()
            with torch.no_grad():
                val_losses.append(_mean_loss(model, val_dataset, device))

        # Step the plateau scheduler on a per-epoch metric. Stepping it on
        # every noisy batch loss makes it cut the learning rate after only a
        # handful of batches.
        scheduler.step(val_losses[-1] if val_dataset is not None else train_losses[-1])

        print("[EPOCH]: %i, [TRAIN LOSS]: %.6f" % (t, train_losses[-1]))
        if val_dataset is not None:
            print("[EPOCH]: %i, [VAL LOSS]: %.6f" % (t, val_losses[-1]))
    return train_losses, val_losses

### Training the model

In [None]:
model = GPT2LMHeadModel.from_pretrained('distilgpt2')
# The dataset tokenizer gained a '[PAD]' token, so the embedding matrix has to
# be resized to the tokenizer's new vocabulary size.
model.resize_token_embeddings(len(train_dataset.tokenizer))
# Train on the device chosen in the setup cell (GPU when available) instead of
# overriding it with a hard-coded CPU. The model is moved before the optimizer
# is built so both refer to parameters on the same device.
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
train_losses, val_losses = train_loop(model=model, optimizer=optimizer,
                                      scheduler=scheduler,
                                      train_dataset=train_dataset,
                                      device=device, epoch=20, val_dataset=val_dataset)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/452 [00:00<?, ?it/s]