In [1]:
# !pip install evaluate

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import random
import torchmetrics
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate import notebook_launcher
# import evaluate


import numpy as np
import matplotlib.pyplot as plt
import os

from tqdm import tqdm
from datasets import load_dataset
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.model_selection import train_test_split
import nltk

from collections import Counter


In [3]:
# Download the NLTK 'punkt' package; the sent_tokenize / word_tokenize calls in
# later cells presumably rely on it.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# Load the 'imdb' dataset; the next cell reads the 'text' field of its
# 'unsupervised' split.
dataset = load_dataset('imdb')

In [5]:
# Split every review into sentences and keep the lower-cased ones whose raw
# length (in characters) is below the threshold.
sentence_len_treshold = 256
sentences = []

for review in tqdm(dataset['unsupervised']['text']):
    for sent in sent_tokenize(review):
        if len(sent) < sentence_len_treshold:
            sentences.append(sent.lower())

100%|██████████| 50000/50000 [00:27<00:00, 1796.40it/s]


In [6]:
# Number of sentences kept after the length filter.
print(len(sentences))

493165


In [7]:
# Token frequency table over the whole corpus (one tokenizer pass per sentence).
words = Counter()

for sentence in tqdm(sentences):
    words.update(word_tokenize(sentence))


100%|██████████| 493165/493165 [02:00<00:00, 4102.04it/s]


In [8]:
# Number of distinct tokens counted across all sentences.
len(words)

135739

In [9]:
# Vocabulary = the four special tokens plus every token that occurs more than
# `word_count_threshold` times in the corpus.
word_count_threshold = 75

vocab = {'<bos>', '<eos>', '<pad>', '<unk>'}
vocab.update(
    word for word, count in words.items() if count > word_count_threshold
)

print(len(vocab))

6972


In [10]:
# Build the token <-> id tables from the *sorted* vocabulary. Iterating a set of
# strings directly yields an order that can differ between interpreter sessions
# (string hashing is randomized per process), which would silently remap ids and
# make a checkpoint saved in one session unusable in another.
word2ind = {word: i for i, word in enumerate(sorted(vocab))}
ind2word = {i: word for word, i in word2ind.items()}

In [20]:
class WordDataset(Dataset):
    """Map-style dataset that turns raw sentences into tensors of token ids.

    Each item is ``<bos>`` + the ids of the sentence's tokens + ``<eos>``.
    Tokens missing from the module-level ``word2ind`` table map to ``<unk>``.
    Items are variable-length, so batching needs a padding ``collate_fn``.
    """

    def __init__(self, sentences):
        """Store the sentences and cache the ids of the special tokens.

        Args:
            sentences: indexable, sized collection of sentence strings.
        """
        self.data = sentences
        self.unk_id = word2ind['<unk>']
        self.bos_id = word2ind['<bos>']
        self.eos_id = word2ind['<eos>']
        self.pad_id = word2ind['<pad>']

    def __getitem__(self, idx):
        """Return the ``idx``-th sentence as a 1-D ``torch.long`` tensor of ids."""
        # Tokenization happens lazily on every access.
        token_ids = [
            word2ind.get(word, self.unk_id)
            for word in word_tokenize(self.data[idx])
        ]
        return torch.tensor(
            [self.bos_id, *token_ids, self.eos_id], dtype=torch.long
        )

    def __len__(self):
        """Number of sentences in the dataset."""
        return len(self.data)

In [12]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    """Collate variable-length id tensors into one padded batch.

    Args:
        batch: list of 1-D tensors of token ids.

    Returns:
        Tuple ``(padded, lengths)``: ``padded`` is the batch-first output of
        ``pad_sequence`` filled with the ``<pad>`` id, and ``lengths`` is a
        LongTensor holding the original length of every sequence.
    """
    pad_id = word2ind['<pad>']
    lengths = torch.LongTensor(list(map(len, batch)))
    padded = pad_sequence(batch, batch_first=True, padding_value=pad_id)
    return padded, lengths
    

In [17]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class RNN(nn.Module):
    """LSTM next-token language model.

    Pipeline: embedding -> 3-layer LSTM -> small MLP head producing one logit
    per vocabulary entry. ``forward`` also returns the next-token
    cross-entropy loss, so the training loop needs no separate criterion.
    """

    def __init__(self, vocab_size, hidden_dim=256, embed_dim=512, pad_id=None):
        """Build the layers.

        Args:
            vocab_size: number of embedding rows and size of the output logits.
            hidden_dim: LSTM hidden size and width of the MLP head.
            embed_dim: embedding dimension fed to the LSTM.
            pad_id: target id ignored by the loss. ``None`` (default) keeps the
                previous behaviour of looking up ``word2ind['<pad>']`` at
                forward time.
        """
        super().__init__()
        self.pad_id = pad_id
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(input_size=embed_dim,
                           hidden_size=hidden_dim,
                           num_layers=3,
                           batch_first=True,
                           dropout=0.1)

        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, vocab_size)
        )

    def forward(self, x, lens):
        """Compute logits for every position and the next-token loss.

        Args:
            x: batch-first tensor of token ids, padded on the right.
            lens: tensor with the unpadded length of every row of ``x``.

        Returns:
            Tuple ``(output, loss)``: ``output`` holds the logits with the same
            batch and time dimensions as ``x``; ``loss`` is the mean
            cross-entropy of predicting token ``t+1`` from position ``t``,
            skipping targets equal to the pad id.
        """
        embeddings = self.embedding(x)
        # pack_padded_sequence wants the lengths on the CPU.
        packed = pack_padded_sequence(
            embeddings, lens.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.rnn(packed)
        # total_length keeps the time dimension equal to x's even when x is
        # padded beyond the longest sequence in `lens`; otherwise the logits
        # and the shifted targets below would have different lengths.
        output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=x.size(1)
        )
        output = self.fc(output)

        pad_id = word2ind['<pad>'] if self.pad_id is None else self.pad_id
        # Logits at position t are scored against the token at position t + 1.
        loss = F.cross_entropy(
            output[:, :-1, :].flatten(start_dim=0, end_dim=1),
            x[:, 1:].flatten(),
            ignore_index=pad_id,
        )

        return output, loss

In [14]:
def get_dataloaders(batch_size = 64):
    """Split the module-level ``sentences`` 80/20 and wrap both parts in DataLoaders.

    Args:
        batch_size: number of sentences per batch for both loaders.

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``. Only the training
        loader shuffles; both pad their batches through ``collate_fn``.
    """
    train_sentences, test_sentences = train_test_split(
        sentences, test_size=0.2, shuffle=True, random_state=42
    )

    loader_kwargs = dict(
        collate_fn=collate_fn,
        batch_size=batch_size,
        # os.cpu_count() may return None when the count cannot be determined;
        # DataLoader needs an int, so fall back to loading in the main process.
        num_workers=os.cpu_count() or 0,
        pin_memory=True,
    )

    train_dataloader = DataLoader(
        WordDataset(train_sentences), shuffle=True, **loader_kwargs
    )
    test_dataloader = DataLoader(
        WordDataset(test_sentences), shuffle=False, **loader_kwargs
    )

    return train_dataloader, test_dataloader

In [18]:
def training_loop(mixed_precision="fp16", seed = 42, checkpoint_path = None,
                  num_epochs = 300, checkpoint_dir = "/kaggle/working"):
    """Train the RNN language model with Accelerate, checkpointing every epoch.

    Args:
        mixed_precision: precision mode forwarded to ``Accelerator``.
        seed: value passed to ``set_seed``.
        checkpoint_path: optional directory written by a previous run
            (named ``cp_<epoch>``). When given, its state is loaded and
            training resumes at ``<epoch> + 1``.
        num_epochs: exclusive upper bound of the epoch counter.
        checkpoint_dir: directory in which ``cp_<epoch>`` checkpoints are saved.
    """
    set_seed(seed)
    accelerator = Accelerator(mixed_precision=mixed_precision)

    train_dataloader, test_dataloader = get_dataloaders(128)
    model = RNN(vocab_size=len(vocab))

    optim = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=3, factor=0.5)

    model, optim, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
        model, optim, train_dataloader, test_dataloader, scheduler
    )

    resume_epoch = 0

    if checkpoint_path is not None:
        accelerator.wait_for_everyone()
        accelerator.load_state(checkpoint_path)
        # normpath drops a trailing slash, so ".../cp_7/" still resolves to 7.
        checkpoint_name = os.path.basename(os.path.normpath(checkpoint_path))
        resume_epoch = int(checkpoint_name.split("_")[-1]) + 1

    for epoch in range(resume_epoch, num_epochs):

        model.train()
        train_total_loss = torch.zeros((), device=accelerator.device)
        loop = tqdm(train_dataloader, disable=not accelerator.is_local_main_process)
        for X, lens in loop:
            with accelerator.autocast():
                logits, loss = model(X, lens)

            optim.zero_grad(set_to_none=True)
            accelerator.backward(loss)
            optim.step()

            # detach(): the running sum is only a metric and must not chain
            # every step's autograd nodes together for the whole epoch.
            train_total_loss = train_total_loss + loss.detach()

        # Average the per-process mean losses; .mean() is correct for any
        # number of processes (the previous `.sum()/2` assumed exactly two).
        train_total_loss = accelerator.gather(
            train_total_loss / len(train_dataloader)
        ).mean().item()

        model.eval()
        with torch.inference_mode():
            test_total_loss = torch.zeros((), device=accelerator.device)
            loop = tqdm(test_dataloader, disable=not accelerator.is_local_main_process)
            for X, lens in loop:
                logits, loss = model(X, lens)
                test_total_loss = test_total_loss + loss

            test_total_loss = accelerator.gather(
                test_total_loss / len(test_dataloader)
            ).mean().item()

        # NOTE(review): Accelerate's prepared scheduler wrapper may call the
        # underlying scheduler more than once per step() in multi-process runs;
        # confirm the effective patience of ReduceLROnPlateau.
        scheduler.step(test_total_loss)
        accelerator.print(f"epoch: {epoch}   train_loss: {train_total_loss}, test_loss: {test_total_loss}")

        accelerator.wait_for_everyone()
        accelerator.save_state(os.path.join(checkpoint_dir, f"cp_{epoch}"), safe_serialization=False)
        
        
        
#         if accelerator.is_local_main_process:
#             torch.save(accelerator.unwrap_model(model).state_dict(), f'/kaggle/working/cp_{epoch}.pt')


In [19]:
# Positional arguments for training_loop: (mixed_precision, seed, checkpoint_path).
args = ("fp16", 42, None)
# Launch training_loop from the notebook with 2 processes.
notebook_launcher(training_loop, args, num_processes=2)

Launching training on 2 GPUs.


100%|██████████| 1542/1542 [01:21<00:00, 18.86it/s]
100%|██████████| 386/386 [00:12<00:00, 30.70it/s]


epoch: 0   train_loss: 4.675525665283203, test_loss: 4.21187162399292


  3%|▎         | 43/1542 [00:02<01:34, 15.79it/s]


KeyboardInterrupt: 

In [None]:
# model = RNN(vocab_size=len(vocab))
# model.load_state_dict(torch.load("/kaggle/working/cp_7.pt"))

In [None]:
# def generate_sequence(model, starting_seq: str, max_seq_len: int = 128):
#     device = 'cpu'
#     model = model.to(device)
#     input_ids = [word2ind['<bos>']] + [
#         word2ind.get(word, word2ind['<unk>']) for word in word_tokenize(starting_seq)]
#     input_ids = torch.LongTensor(input_ids).to(device)
#     model.eval()
#     with torch.no_grad():
#         for i in range(max_seq_len):
#             next_char_distribution = model(input_ids)[-1]
#             next_char = next_char_distribution.argmax()
#             input_ids = torch.cat([input_ids, next_char.unsqueeze(0)])

#             if next_char.item() == word2ind['<eos>']:
#                 break
    
#     words = ' '.join([ind2word[idx.item()] for idx in input_ids])

#     return words

In [None]:
# generate_sequence(model, starting_seq='this movie was')