In [1]:
! pip install -q transformers
! pip install -q wandb

In [2]:
import numpy as np
import pandas as pd 

import torch
from transformers import (AutoModelForMaskedLM, BertTokenizerFast, get_linear_schedule_with_warmup)

from torch.optim import AdamW

from torch.utils.data import (Dataset, 
                              random_split,
                              DataLoader,
                              RandomSampler,
                              SequentialSampler)

import torch.nn as nn
import os
from os import listdir
from os.path import isfile, join
import subprocess
import wandb
from tqdm import tqdm

# Authenticate with Weights & Biases without embedding a secret in the notebook.
# SECURITY: the API key that used to be hard-coded here was committed in plain text;
# it should be revoked and a new one supplied through the WANDB_API_KEY environment variable.
# When the variable is unset, key=None lets wandb fall back to ~/.netrc or an interactive prompt.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mam2502[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [3]:
import os.path as path

# Fallback Google Drive mount via google-drive-ocamlfuse; only runs when
# /content/drive does not exist yet (`path` is the `os.path` alias imported above).
if not path.exists("/content/drive"):
  # Install google-drive-ocamlfuse from its PPA.
  # NOTE(review): `2>&1 > /dev/null` points stderr at the original stdout *before*
  # stdout is sent to /dev/null, so error messages are still shown — confirm this is intended.
  !sudo add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
  !sudo apt-get update -qq 2>&1 > /dev/null
  !sudo apt -y install -qq google-drive-ocamlfuse 2>&1 > /dev/null
  # First invocation without a mount point (presumably triggers the authorization flow — TODO confirm).
  !google-drive-ocamlfuse

  # Install a text-mode browser and register it as the default browser.
  !sudo apt-get install -qq w3m # to act as web browser 
  !xdg-settings set default-web-browser w3m.desktop # to set default browser
  # Create the mount point /content/drive/MyDrive.
  %cd /content
  !mkdir drive
  %cd drive
  !mkdir MyDrive
  # Two levels up from /content/drive leaves the working directory at `/`, not /content.
  %cd ..
  %cd ..
  # Mount Drive on the directory created above.
  !google-drive-ocamlfuse /content/drive/MyDrive

In [4]:
from google.colab import drive
# Mount Google Drive through the Colab helper.
drive.mount('/content/drive')

path_to_project_folder = ""

# Expose the Drive project folder under a short, stable path.
# The existence check uses the same absolute path the link is created at, so it
# no longer depends on the current working directory (earlier `%cd` calls may
# have changed it). `lexists` also detects a dangling link left by a previous
# session, which avoids a FileExistsError from `os.symlink`.
project_target = "/content/drive/MyDrive/Project"
project_link = "/content/Project"
if not os.path.lexists(project_link):
    os.symlink(project_target, project_link)

NotImplementedError: ignored

In [10]:
path_to_project_folder = ""
# Central place for every tunable value used by the cells below.
# NOTE(review): `log_checkpoint` also reads config['total_iterations'], which is not defined here.
config = {
    # Directory that contains the limerick CSV (joined with 'limerick_file_name' when loading).
    "path_to_data_folder": '/content/Project/data/',
    # NOTE(review): not read by any cell visible in this notebook.
    'random_seed': 73,
    # Batch size passed to the training DataLoader.
    'batch_size': 64,
    # Sequence length used for tokenizer padding/truncation (also stored on the tokenizer and the model).
    'max_len': 64,
    # CSV file with one limerick per row in the 'limericks' column.
    'limerick_file_name': 'limericks_ballas_oedilf_clean.csv',
    # Pretrained checkpoint used for both the tokenizer and the masked-LM model.
    'model_name': 'bert-base-uncased',
    # Directory where `log_checkpoint` writes checkpoint files.
    'training_storage_path': '/content/drive/MyDrive/',
    # Number of passes over the data; also used to size the LR schedule.
    'total_epochs': 4,
    # Learning rate handed to AdamW.
    'learning_rate': 2e-5,
    # `log_checkpoint` saves when the iteration it receives is a multiple of this value.
    'iteration_step_to_log_checkpoint': 1,
    # Number of warmup steps for the linear LR schedule.
    'warmup_steps': 300
}

In [11]:
# Read the limerick CSV and turn missing cells into empty strings in one expression,
# then display how many rows were loaded.
limerick_df = pd.read_csv(
    os.path.join(config['path_to_data_folder'], config['limerick_file_name'])
).fillna('')
len(limerick_df)

153797

In [12]:
limerick_df

Unnamed: 0,limericks
0,capn jack was washed over the side\nhis crew s...
1,as a soup bisque is best when served hot\nmade...
2,simply add to the grasp of a rhesus\nthe antit...
3,abeds where you sleep in the night\nunless you...
4,a smiling young fellow from spain\nfell asleep...
...,...
153792,esps remove dust from a flue\nthough hightech ...
153793,as a gent of the uppermost class\nim deserving...
153794,breaking free the crook busted the link\nof th...
153795,mr owl ate ms nans metal worm \ntragic fable t...


In [13]:
# Load the fast BERT tokenizer for the checkpoint named in the config.
tokenizer = BertTokenizerFast.from_pretrained(config['model_name'])
# Base vocabulary size, printed before the extra token is added (30522 in the recorded run).
print(tokenizer.vocab_size)
# Store the configured maximum sequence length on the tokenizer.
tokenizer.model_max_length = config['max_len']
# Register the end-of-line marker that replaces '\n' in each limerick.
# As the last expression of the cell, the return value is displayed (1 in the recorded run).
tokenizer.add_tokens('[EOL]')

30522


1

In [14]:
print(tokenizer.convert_tokens_to_ids('[EOL]'))

30522


In [15]:
class LimerickDataset(Dataset):
    """Masked-LM dataset that hides the last word of every limerick line.

    Each limerick has its newlines replaced by the ``[EOL]`` token and is
    tokenized to a fixed length with the module-level ``tokenizer``. Every
    word-piece of the word directly preceding an ``[EOL]`` is replaced by the
    tokenizer's mask token in the model input, and the original ids of those
    pieces are stored in ``labels`` at the *same* positions. All other label
    positions hold -100 so the loss ignores them.

    Fix over the previous implementation: labels were shuffled with
    ``labels.pop(); labels.insert(-1, ...)``, which put the label of a
    single-piece word one position to the left of its mask and dropped the
    last piece of words made of three or more pieces. Labels are now written
    exactly where the mask tokens are.

    Args:
        data: iterable of limerick strings (lines separated by ``'\\n'``).
        max_length: length every sequence is padded/truncated to.

    Each item is a tuple ``(masked_input_ids, attention_mask, labels,
    original_input_ids)`` of 1-D tensors.
    """

    def __init__(self, data, max_length=config['max_len']):
        self.input_ids = []
        self.original_input_ids = []
        self.attn_masks = []
        self.labels = []

        # Resolve the ids from the tokenizer instead of hard-coding 30522 / 103.
        # '[EOL]' must already have been added to the tokenizer.
        eol_id = tokenizer.convert_tokens_to_ids('[EOL]')
        mask_id = tokenizer.mask_token_id

        for limerick in tqdm(data):
            encodings_dict = tokenizer(limerick.replace('\n', ' [EOL] '),
                                       truncation=True,
                                       max_length=max_length,
                                       padding='max_length')
            token_ids = encodings_dict['input_ids']
            masked_ids, labels = self._mask_line_final_words(
                token_ids, encodings_dict.word_ids(), eol_id, mask_id)

            self.original_input_ids.append(torch.tensor(token_ids))
            self.input_ids.append(torch.tensor(masked_ids))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
            self.labels.append(torch.tensor(labels))

    @staticmethod
    def _mask_line_final_words(token_ids, word_ids, eol_id, mask_id, ignore_index=-100):
        """Mask the word in front of every EOL token.

        Args:
            token_ids: token ids of one encoded limerick.
            word_ids: per-token word index as returned by ``word_ids()``
                (``None`` for special tokens).
            eol_id: id of the end-of-line token.
            mask_id: id written over the pieces of a line-final word.
            ignore_index: label value for positions that are not predicted.

        Returns:
            ``(masked_ids, labels)`` — two lists with the same length as
            ``token_ids``.
        """
        masked_ids = []
        labels = []
        previous_word_id = -1
        run_length = 0  # number of trailing tokens that belong to the current word

        for token_id, word_id in zip(token_ids, word_ids):
            if token_id == eol_id:
                if run_length:
                    # Move the word's pieces into the labels and mask them in place,
                    # keeping label and mask positions aligned.
                    labels[-run_length:] = masked_ids[-run_length:]
                    masked_ids[-run_length:] = [mask_id] * run_length
                    run_length = 0
                masked_ids.append(eol_id)
                labels.append(ignore_index)
                continue

            if word_id is None:
                # Special tokens ([CLS], [SEP], [PAD]) are never a maskable word.
                run_length = 0
            elif word_id == previous_word_id:
                run_length += 1
            else:
                run_length = 1
            previous_word_id = word_id
            masked_ids.append(token_id)
            labels.append(ignore_index)

        return masked_ids, labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx], self.labels[idx], self.original_input_ids[idx]

In [16]:
# Encode every limerick once up front, then serve the examples in random order.
limerick_dataset = LimerickDataset(
    limerick_df['limericks'].values,
    max_length=config['max_len'],
)

limerick_dataloader = DataLoader(
    limerick_dataset,
    batch_size=config['batch_size'],
    sampler=RandomSampler(limerick_dataset),
)

100%|██████████| 153797/153797 [00:58<00:00, 2622.77it/s]


In [17]:
# Reproducibility: apply the configured seed before any random initialisation
# (the new '[EOL]' embedding row, dropout, batch shuffling).
torch.manual_seed(config['random_seed'])

# Model Definition
model = AutoModelForMaskedLM.from_pretrained(config['model_name'])
# NOTE(review): plain attribute assignment; nothing visible in this notebook reads it back.
model.max_seq_len = config['max_len']
# Grow the embedding matrix so the added '[EOL]' token has a row.
model.resize_token_embeddings(len(tokenizer))
# Move the model to the target device *before* building the optimizer, as the
# PyTorch optimizer documentation recommends.
model = model.to(device)

# Optimizer
optimizer = AdamW(model.parameters(), lr=config['learning_rate'])

# Linear warmup followed by linear decay over all training steps.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=config['warmup_steps'],
                                            num_training_steps=len(limerick_dataloader) * config['total_epochs'])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
# Start (or restart) the Weights & Biases run that `train` logs to.
run = wandb.init(
    name = "RhymeBERT_2", ## wandb generates a random run name if this field is omitted
    reinit = True, ### Allows reinitializing the run when this cell is re-run
    # id = ### Insert a specific run id here to resume a previous run (the keyword is `id`, not `run_id`)
    # resume = "must" ### Needed to resume a previous run; comment out reinit = True when using this
    project = "poetai-project" ### The project should already exist in your wandb account
)

In [19]:

def log_checkpoint(iteration, model, optimizer, metric=None):
    """Save a training checkpoint into ``config['training_storage_path']``.

    A checkpoint is written only when ``iteration`` is a multiple of
    ``config['iteration_step_to_log_checkpoint']`` or equals the optional
    ``config['total_iterations']``.

    Args:
        iteration: index of the iteration/epoch that just finished.
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        metric: optional score where lower is better. When ``None`` a file
            named after the iteration is written unconditionally. Otherwise
            the checkpoint is stored as ``"<metric>_checkpoint.h5"`` only if
            it beats the best metric checkpoint already on disk, and that
            previous best file is then removed.

    Fixes over the previous version:
        * a missing ``'total_iterations'`` key no longer raises ``KeyError``;
        * the first metric checkpoint is written even when the directory
          holds no earlier one (previously nothing was ever saved);
        * the superseded file is removed by its real name instead of a name
          rebuilt from a parsed float;
        * unrelated ``*_checkpoint.h5`` files with a non-numeric prefix are
          skipped instead of crashing the float conversion.
    """
    step = config['iteration_step_to_log_checkpoint']
    final_iteration = config.get('total_iterations')  # optional key
    if iteration % step != 0 and iteration != final_iteration:
        return

    state = {
        'iteration': iteration + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    check_point_dir = config['training_storage_path']
    os.makedirs(check_point_dir, exist_ok=True)

    if metric is None:
        file_name = "poet_ai_checkpoint_rhymebert_" + str(iteration) + ".h5"
        torch.save(state, os.path.join(check_point_dir, file_name))
        return

    # Metric mode (lower is better): collect the scores of existing metric checkpoints.
    suffix = "_checkpoint.h5"
    scores_by_file = {}
    for file_name in os.listdir(check_point_dir):
        if not file_name.endswith(suffix):
            continue
        if not os.path.isfile(os.path.join(check_point_dir, file_name)):
            continue
        try:
            scores_by_file[file_name] = float(file_name[:-len(suffix)])
        except ValueError:
            # Not one of our metric checkpoints.
            continue

    best_file = min(scores_by_file, key=scores_by_file.get) if scores_by_file else None
    if best_file is None or metric < scores_by_file[best_file]:
        torch.save(state, os.path.join(check_point_dir, f"{metric}{suffix}"))
        if best_file is not None:
            os.remove(os.path.join(check_point_dir, best_file))

In [20]:
def train(epoch):
    """Run one training epoch over ``limerick_dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``scheduler`` and
    ``device``. Logs the loss of every batch to wandb and calls
    ``log_checkpoint`` once the epoch is finished.

    Args:
        epoch: zero-based epoch index, used for logging and checkpoint naming.

    Returns:
        Mean training loss over the batches of this epoch (0.0 when the
        dataloader is empty).
    """
    num_batches = len(limerick_dataloader)
    total_train_loss = 0
    model.train()

    # The context manager closes the progress bar even if a step raises;
    # previously the bar was never closed.
    with tqdm(total=num_batches, dynamic_ncols=True, leave=False, position=0, desc='Train') as batch_bar:
        for step, batch in enumerate(limerick_dataloader):
            optimizer.zero_grad()
            # Batch layout: (masked input ids, attention mask, labels, original input ids).
            b_input_ids = batch[0].to(device)
            b_masks = batch[1].to(device)
            b_labels = batch[2].to(device)
            outputs = model(b_input_ids,
                            labels=b_labels,
                            attention_mask=b_masks)

            loss = outputs[0]

            # Read the scalar once and reuse it instead of calling .item() three times per step.
            batch_loss = loss.item()
            total_train_loss += batch_loss

            loss.backward()
            optimizer.step()
            scheduler.step()
            batch_bar.set_postfix(
                step="{:d}".format(step),
                loss="{:.04f}".format(batch_loss))
            batch_bar.update()

            wandb.log({"train_loss": batch_loss, 'train_epochs': epoch})

    log_checkpoint(epoch, model, optimizer)
    # Guard against division by zero on an empty dataloader.
    return total_train_loss / num_batches if num_batches else 0.0

In [26]:
# Manual one-off checkpoint save.
# NOTE(review): `epoch` is only defined by the training loop in the following cell, so on a
# fresh kernel this cell raises NameError when run in notebook order.
log_checkpoint(epoch, model, optimizer)

In [27]:
# Run the configured number of epochs; `train` logs to wandb and saves a checkpoint per epoch.
# The mean loss returned by `train` is not used here.
for epoch in range(config['total_epochs']):
    print(f'Epoch {epoch} of', config['total_epochs'])
    train(epoch)

Epoch 0 of 4




Epoch 1 of 4




Epoch 2 of 4




Epoch 3 of 4




In [None]:
# Sanity check: pull a single batch and print the label row of its third example.
for i, sample_batch in enumerate(limerick_dataloader):
    input_ids, mask, labels, orig = sample_batch
    print(labels[2])
    break

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1054,
        15689,  -100,  -100,  -100,  -100,  -100, 18224,  4939,  -100,  -100,
        11302,  -100,  -100,  -100,  -100,  8065,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  3424, 25078,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100])


In [23]:
# Ask PyTorch to release cached GPU memory that is no longer in use.
torch.cuda.empty_cache() 