In [5]:
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW
import pandas as pd
from tqdm import tqdm
import random
from copy import copy
from math import exp

In [6]:
# Device used for the dataset tensors and the model. Falls back to the CPU so
# the notebook still runs (slowly) on hosts without a CUDA-capable GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Set up the tokenizer

In [7]:
# Marker strings that introduce the title and the plot inside a sequence.
title_token, plot_token = '<|title|>', '<|plot|>'

In [8]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

MAX_LEN = tokenizer.model_max_length
print(f'Max length: {MAX_LEN}')

print('Vocab size:', len(tokenizer.get_vocab()))

# Register each marker as its own vocabulary entry. The membership check keeps
# the cell idempotent: re-running it does not add (or report) a token twice.
for label, token in (('title', title_token), ('plot', plot_token)):
    if token in tokenizer.get_vocab():
        continue
    tokenizer.add_tokens([token])
    vocab = tokenizer.get_vocab()  # rebuilt once, after the addition
    print(f'Add {label} token. Check:')
    print('\tvocab size:', len(vocab))
    print(f'\t{token} in vocab:', token in vocab)
    print('\tLast token in the vocab:', list(vocab)[-1])

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Max length: 1024
Vocab size: 50257
Add title token. Check:
	vocab size: 50258
	<|title|> in vocab: True
	Last token in the vocab: <|title|>
Add plot token. Check:
	vocab size: 50259
	<|plot|> in vocab: True
	Last token in the vocab: <|plot|>


## Load the dataset

In [9]:
class WikiMoviesDataset(Dataset):
    """Token-id sequences built from a movie CSV for causal-LM fine-tuning.

    Every CSV row yields one or two 1-D tensors, depending on
    ``generation_mode``:

    * plot generation:  ``title_token + title + plot_token + plot + eos``
    * title generation: ``plot_token + plot + title_token + title + eos``

    The tensors are moved to ``device`` while the dataset is being built.
    """

    def __init__(self, 
                 path,
                 tokenizer: transformers.PreTrainedTokenizer,
                 generation_mode='plot',
                 *,
                 relese_year_header='Release Year',
                 title_header='Title',
                 origin_ethicity_header='Origin/Ehtnicity',
                 director_header='Director',
                 cast_header='Cast',
                 genre_header='Genre',
                 wiki_page_header='Wiki Page',
                 plot_header='Plot',
                 title_token='<|title|>',
                 plot_token='<|plot|>',
                 eos_token=None,
                 device=DEVICE,
                 max_len=None):
        """Read the CSV at ``path`` and tokenize every row.

        Args:
            path: CSV file (anything ``pandas.read_csv`` accepts).
            tokenizer: callable tokenizer whose result exposes ``input_ids``.
            generation_mode: ``'title'``, ``'plot'`` or ``'both'``; selects
                which of the two sequence layouts are produced per row.
            title_header, plot_header: CSV columns holding title and plot.
            relese_year_header, origin_ethicity_header, director_header,
            cast_header, genre_header, wiki_page_header: accepted for
                interface compatibility; not read by this class.
            title_token, plot_token: marker strings placed before each field.
            eos_token: end-of-sequence string; defaults to
                ``tokenizer.eos_token``.
            device: device every sequence tensor is moved to.
            max_len: cap on the length of a sequence; defaults to
                ``tokenizer.model_max_length``.

        Raises:
            ValueError: if ``generation_mode`` is not one of the three modes.
        """
        gen_title = generation_mode in ('title', 'both')
        gen_plot = generation_mode in ('plot', 'both')
        if not (gen_title or gen_plot):
            raise ValueError('Uknown generation mode. Select "title", "plot" or "both".')

        if not eos_token:
            eos_token = tokenizer.eos_token
        if max_len is None:
            max_len = tokenizer.model_max_length

        title_token_ids = tokenizer(title_token).input_ids
        plot_token_ids = tokenizer(plot_token).input_ids
        eos_token_ids = tokenizer(eos_token).input_ids

        # Budget for title + plot once the markers are accounted for. Counting
        # the actual ids (instead of assuming one id per marker) keeps the cap
        # valid even if a marker is split into several tokens.
        n_special = len(title_token_ids) + len(plot_token_ids) + len(eos_token_ids)
        content_budget = max(0, max_len - n_special)

        self._data = []

        df = pd.read_csv(path)

        for _, row in tqdm(df.iterrows(), total=len(df), ncols=70):
            # The title is kept first; the plot gets whatever room is left.
            title_ids = tokenizer(row[title_header], truncation=True).input_ids[:content_budget]
            tokens_for_plot = content_budget - len(title_ids)
            plot_ids = tokenizer(row[plot_header], truncation=True).input_ids[:tokens_for_plot]

            # entry for title generation: the plot is the prompt
            if gen_title:
                ids = torch.tensor(plot_token_ids + plot_ids + 
                                   title_token_ids + title_ids + 
                                   eos_token_ids).to(device)
                self._data.append(ids)

            # entry for plot generation: the title is the prompt
            if gen_plot:
                ids = torch.tensor(title_token_ids + title_ids + 
                                   plot_token_ids + plot_ids + 
                                   eos_token_ids).to(device)
                self._data.append(ids)

    def __getitem__(self, i):
        """Return the ``i``-th token-id tensor."""
        return self._data[i]

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self._data)

    def train_valid_split(self, inplace: bool, shuffle=True, train_fr=.95):
        """Split the stored sequences into a train and a validation dataset.

        Args:
            inplace: must be truthy; only the in-place split is implemented.
            shuffle: shuffle the sequences (with the ``random`` module)
                before cutting; pass ``False`` to keep the current order.
            train_fr: fraction of sequences assigned to the train part.

        Returns:
            ``(train_dataset, valid_dataset)``: shallow copies of this
            object that hold the first ``int(len(self) * train_fr)``
            sequences and the remainder, respectively. Afterwards this
            object's own storage is set to ``None``.
        """
        assert inplace, 'only the in-place split is implemented'

        if shuffle:
            random.shuffle(self._data)

        train_dataset = copy(self)
        valid_dataset = copy(self)

        train_len = int(len(self) * train_fr)

        train_dataset._data = self._data[:train_len]
        valid_dataset._data = self._data[train_len:]

        # Drop this object's reference; the two splits now own the data.
        self._data = None

        return train_dataset, valid_dataset

In [10]:
# Build both sequence layouts (title->plot and plot->title) for every movie.
dataset = WikiMoviesDataset(
    path='/kaggle/input/wikipedia-movie-plots/wiki_movie_plots_deduped.csv',
    tokenizer=tokenizer,
    generation_mode='both',
)

100%|██████████████████████████| 34886/34886 [02:32<00:00, 228.33it/s]


In [11]:
# Length (in tokens) of the first stored sequence.
dataset[0].size()

torch.Size([113])

In [12]:
# train_valid_split shuffles with the `random` module; seed it first so the
# same sequences land in the validation set on every fresh run.
random.seed(42)
train_dataset, valid_dataset = dataset.train_valid_split(True, train_fr=.95)

In [13]:
# Number of sequences in each split.
tuple(len(split) for split in (train_dataset, valid_dataset))

(66283, 3489)

In [14]:
# Decode one training sequence back to text to eyeball the layout.
tokenizer.decode(token_ids=train_dataset[0])



## Set up the model

In [17]:
# Checkpoint name passed to GPT2LMHeadModel.from_pretrained below.
# NOTE(review): the tokenizer cell loads 'gpt2' via its own literal; keep the two in sync.
base_model_checkpoint = 'gpt2'

In [18]:
# Load the pretrained language model that will be fine-tuned.
model = GPT2LMHeadModel.from_pretrained(
    pretrained_model_name_or_path=base_model_checkpoint,
)

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [19]:
# Total count of scalar parameters, shown with thousands separators.
"Parameters number: {:,d}".format(sum(param.numel() for param in model.parameters()))

'Parameters number: 124,439,808'

In [20]:
# Move the model's weights to the configured device.
model = model.to(device=DEVICE)

In [21]:
# Vocabulary size including the tokens added after loading the tokenizer.
VOCAB_SIZE = len(tokenizer)

In [22]:
# Resize the embedding matrix to the extended vocabulary size.
model.resize_token_embeddings(new_num_tokens=VOCAB_SIZE)

Embedding(50259, 768)

## Finetune the model

In [28]:
class Trainer:
    """Fine-tunes a language model with accumulated gradients and periodic validation.

    Samples are drawn one at a time from the training loader; the gradients
    of ``batch_size`` single-sample losses are summed (not averaged) before
    each optimizer step. Every ``blind_steps`` optimizer steps the model is
    evaluated on the validation set and its ``state_dict`` is written to
    ``checkpoint_file`` whenever the validation perplexity improves.
    """

    def __init__(self,
                 train_dataset,
                 valid_dataset,
                 blind_steps=500,
                 *,
                 warmup_steps=5000,
                 batch_size=4,
                 lr=4e-5,
                 checkpoint_file='model_checkpoint.pt',
                 ):
        """Store the datasets and hyper-parameters.

        Args:
            train_dataset: dataset sampled (shuffled) during training.
            valid_dataset: dataset iterated in full by ``evaluate``.
            blind_steps: optimizer steps between two validation runs.
            warmup_steps: warm-up steps of the constant-with-warmup schedule.
            batch_size: number of samples accumulated per optimizer step.
            lr: learning rate handed to the optimizer.
            checkpoint_file: path the best ``state_dict`` is saved to.
        """
        self._train_dataset = train_dataset
        self._valid_dataset = valid_dataset
        # Loader batches hold a single sample (presumably because sequences
        # are unpadded and differ in length); the effective batch is formed by
        # gradient accumulation in train().
        self._train_dataloader = DataLoader(train_dataset, 1, shuffle=True)
        self._valid_dataloader = DataLoader(valid_dataset, 1)
        self._blind_steps = blind_steps
        self._warmup_steps = warmup_steps
        self._batch_size = batch_size
        self._lr = lr
        self._best_ppl = 1e10 # lowest perplexity is the criterion for checkpointing
        self._checkpoint_file = checkpoint_file
        self._optimizer = None
        self._scheduler = None
        self._model = None
        self._train_iter = None  # created lazily, restarted when exhausted

    def _next_train_batch(self):
        """Return the next training sample, restarting the loader when it runs out."""
        if self._train_iter is None:
            self._train_iter = iter(self._train_dataloader)

        try:
            return next(self._train_iter)
        except StopIteration:
            # End of an epoch: start a new (reshuffled) pass over the data.
            self._train_iter = iter(self._train_dataloader)
            return next(self._train_iter)

    def train(self, model, steps):
        """Run ``steps`` optimizer steps on ``model``.

        Calling this again with the same model object continues with the
        existing optimizer, scheduler and data iterator; a different model
        object gets fresh ones.

        Args:
            model: callable as ``model(x, labels=x)`` returning an object
                with a ``loss`` attribute.
            steps: number of optimizer steps to perform.
        """
        # for new model set up new optimizer and schedule, reset data iterator
        if self._model is not model:
            self._optimizer = AdamW(model.parameters(), self._lr)
            self._scheduler = transformers.get_constant_schedule_with_warmup(self._optimizer, self._warmup_steps)
            self._train_iter = iter(self._train_dataloader)
            self._model = model

        loss_sum = 0.
        n_samples = 0
        model.train()

        # training loop
        for step in tqdm(range(steps)):
            self._optimizer.zero_grad()

            # accumulate gradients over the samples of one batch
            for _ in range(self._batch_size):
                x = self._next_train_batch()
                loss = model(x, labels=x).loss
                loss.backward()

                loss_sum += float(loss)
                n_samples += 1

            # updating
            self._optimizer.step()
            self._scheduler.step()

            if (step + 1) % self._blind_steps == 0:
                eval_res = self.evaluate(model)

                # saving the model if better
                if eval_res['ppl'] < self._best_ppl:
                    torch.save(model.state_dict(), self._checkpoint_file)
                    self._best_ppl = eval_res['ppl']
                    print('*', end='')
                else:
                    print('-',end='')

                # logging: training loss is averaged since the previous report
                train_loss = loss_sum / n_samples
                loss_sum = 0.
                n_samples = 0
                train_ppl = exp(train_loss)
                valid_loss = eval_res['loss']
                valid_ppl = eval_res['ppl']

                print(f' LOSS_t = {train_loss:8.4f} | LOSS_v = {valid_loss:8.4f} | PPL_t = {train_ppl:8.4f} | PPL_v = {valid_ppl:8.4f}')

    def evaluate(self, model):
        """Compute mean loss and perplexity over the validation set.

        The model is switched to eval mode for the pass and put back into
        training mode afterwards if it was training on entry.

        Returns:
            dict with ``'loss'`` (mean per-sample loss, a Python float) and
            ``'ppl'`` (``exp`` of that mean).
        """
        model_was_training = model.training
        model.eval()

        loss_total = 0.

        with torch.no_grad():
            for x in self._valid_dataloader:
                loss_total += float(model(x, labels=x).loss)

        mean_loss = loss_total / len(self._valid_dataloader)

        if model_was_training:
            model.train()

        return {
            'loss': mean_loss,
            'ppl': exp(mean_loss)
        }
            
        

In [29]:
# Validate (and possibly checkpoint) every 1500 optimizer steps.
trainer = Trainer(
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    blind_steps=1500,
)

In [30]:
# First fine-tuning run.
trainer.train(model, steps=15_000)

 10%|█         | 1500/15000 [11:11<119:05:17, 31.76s/it]

* LOSS_t =   3.7906 | LOSS_v =   3.4426 | PPL_t =  44.2847 | PPL_v =  31.2686


 20%|██        | 3000/15000 [22:23<105:59:11, 31.80s/it]

* LOSS_t =   3.5485 | LOSS_v =   3.3442 | PPL_t =  34.7614 | PPL_v =  28.3382


 30%|███       | 4500/15000 [33:40<92:06:26, 31.58s/it] 

* LOSS_t =   3.4514 | LOSS_v =   3.3023 | PPL_t =  31.5440 | PPL_v =  27.1748


 40%|████      | 6000/15000 [44:45<79:08:28, 31.66s/it]

* LOSS_t =   3.4111 | LOSS_v =   3.2582 | PPL_t =  30.2987 | PPL_v =  26.0037


 50%|█████     | 7500/15000 [55:54<65:52:46, 31.62s/it]

* LOSS_t =   3.3630 | LOSS_v =   3.2291 | PPL_t =  28.8770 | PPL_v =  25.2580


 60%|██████    | 9000/15000 [1:07:02<52:38:37, 31.59s/it]

* LOSS_t =   3.3320 | LOSS_v =   3.2017 | PPL_t =  27.9947 | PPL_v =  24.5732


 70%|███████   | 10500/15000 [1:18:07<39:29:49, 31.60s/it]

* LOSS_t =   3.3103 | LOSS_v =   3.1802 | PPL_t =  27.3928 | PPL_v =  24.0507


 80%|████████  | 12000/15000 [1:29:17<26:17:19, 31.55s/it]

* LOSS_t =   3.2919 | LOSS_v =   3.1625 | PPL_t =  26.8949 | PPL_v =  23.6305


 90%|█████████ | 13500/15000 [1:40:32<13:11:27, 31.66s/it]

* LOSS_t =   3.2717 | LOSS_v =   3.1433 | PPL_t =  26.3570 | PPL_v =  23.1799


100%|██████████| 15000/15000 [1:51:47<00:00,  2.24it/s]   

* LOSS_t =   3.2523 | LOSS_v =   3.1323 | PPL_t =  25.8496 | PPL_v =  22.9259





In [62]:
# Continue fine-tuning the same model for additional steps.
trainer.train(model, steps=7_500)

 20%|██        | 1500/7500 [11:04<52:39:52, 31.60s/it]

* LOSS_t =   3.2348 | LOSS_v =   3.1207 | PPL_t =  25.4000 | PPL_v =  22.6634


 40%|████      | 3000/7500 [22:10<39:25:43, 31.54s/it]

* LOSS_t =   3.1531 | LOSS_v =   3.1091 | PPL_t =  23.4075 | PPL_v =  22.4003


 60%|██████    | 4500/7500 [33:25<26:18:04, 31.56s/it]

* LOSS_t =   3.1466 | LOSS_v =   3.0955 | PPL_t =  23.2575 | PPL_v =  22.0991


 80%|████████  | 6000/7500 [44:31<13:08:42, 31.55s/it]

* LOSS_t =   3.1390 | LOSS_v =   3.0830 | PPL_t =  23.0802 | PPL_v =  21.8246


100%|██████████| 7500/7500 [55:38<00:00,  2.25it/s]   

* LOSS_t =   3.1337 | LOSS_v =   3.0682 | PPL_t =  22.9586 | PPL_v =  21.5041





In [None]:
# Switch to inference mode before generating text.
model = model.train(False)

In [227]:
# Sampling settings shared by the generation cells below.
config = transformers.GenerationConfig(
    do_sample=True,
    temperature=0.7,
    no_repeat_ngram_size=5,
    max_new_tokens=384,
    pad_token_id=tokenizer.eos_token_id,
)

In [251]:
# Generate a title from Rotten Tomatoes' plot summary of "Don't Look Up".
# The prompt is moved to whatever device the model lives on instead of a
# hard-coded 'cuda'.
encoded = tokenizer("<|plot|>Kate Dibiasky (Jennifer Lawrence), an astronomy grad student, and her professor Dr. Randall Mindy (Leonardo DiCaprio) make an astounding discovery of a comet orbiting within the solar system. The problem: it's on a direct collision course with Earth. The other problem? No one really seems to care. Turns out warning mankind about a planet-killer the size of Mount Everest is an inconvenient fact to navigate. With the help of Dr. Oglethorpe (Rob Morgan), Kate and Randall embark on a media tour that takes them from the office of an indifferent President Orlean (Meryl Streep) and her sycophantic son and Chief of Staff, Jason (Jonah Hill), to the airwaves of The Daily Rip, an upbeat morning show hosted by Brie (Cate Blanchett) and Jack (Tyler Perry). With only six months until the comet makes impact, managing the 24-hour news cycle and gaining the attention of the social media obsessed public before it's too late proves shockingly comical -- what will it take to get the world to just look up?!<|title|>", return_tensors='pt').input_ids.to(model.device)
print(tokenizer.decode(model.generate(encoded, generation_config=config).to('cpu')[0]))



In [279]:
# Unconditional sample: the model invents a title and then a plot for it.
# The prompt follows the model's device instead of a hard-coded 'cuda'.
encoded = tokenizer("<|title|>", return_tensors='pt').input_ids.to(model.device)
print(tokenizer.decode(model.generate(encoded, generation_config=config).to('cpu')[0]))

<|title|> The Heartbreak Kid <|plot|> After being rescued from being shot in the leg by his bandit brother, the young man (Shawn Wilson), who is now a police officer, is taken on a mission to a remote Canadian village to find the culprits. He finds a young woman and her two children, and they have to face the truth about who is behind the murder.<|endoftext|>
