In [4]:
import json
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
import os
import warnings
warnings.filterwarnings("ignore")
os.environ['KMP_DUPLICATE_LIB_OK'] = 'False'

on_kaggle = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ; print(f'On kaggle: {on_kaggle}')
import gensim.downloader as api
fasttext = api.load('fasttext-wiki-news-subwords-300')
fasttext.most_similar(positive=['king', 'woman'], negative=['man'])[:3]

On kaggle: False


[('queen', 0.7786749005317688),
 ('queen-mother', 0.7143871784210205),
 ('king-', 0.6981282234191895)]

In [5]:
# SETTINGS
TRAIN = True
EPOCH = 5
WORD_FREQ = 2
BATCH_SIZE = 20

In [6]:
import re
from collections import Counter

class Vocab:
    def __init__(self, texts: list[str], min_freq: int = 1):
        self.min_freq = min_freq
        text = ' '.join(texts)
        text = self._remove_links(text)
        text = self._remove_special_chars(text)
        text = self._remove_multiple_spaces(text)
        self.vocab = self._filter_words(text.strip().lower().split())
        self.vocab.append('<unk>')

        self._word2idx = {word: idx for idx, word in enumerate(self.vocab)}
        self._idx2word = {idx: word for idx, word in enumerate(self.vocab)}
    
    def _remove_links(self, text):
        return re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
    
    def _remove_special_chars(self, text):
        return re.sub(r'[^a-zA-Z\s]', '', text)
    
    def _remove_multiple_spaces(self, text):
        return re.sub(r'\s+', ' ', text)
    
    def _filter_words(self, words) -> list[str]:
        counter = Counter(words)
        return list({word for word in words if counter[word] > self.min_freq})
    
    def get_vocab(self): return self.vocab

    def idx2word(self, idx):
        if idx not in self._idx2word: return '<unk>'
        return self._idx2word[idx]

    def word2idx(self , word):
        word = word.lower()
        if word not in self._word2idx: return self._word2idx['<unk>']
        return self._word2idx[word]

    def encode(self, text):
        return [self.word2idx(word) for word in text.split()]
    
    def make_vectors(self, fasttext):
        return np.stack([fasttext[word] if fasttext.has_index_for(word) \
                else np.zeros(fasttext.vector_size) \
                for word in self.vocab])

In [7]:
from torch.utils.data import DataLoader, Dataset

class IMDB(Dataset):
    def __init__(self, path):
        self.df = pd.read_csv(path)
        texts = self.df['review'].values
        labels = self.df['sentiment'].values
        self.vocab = Vocab(texts, min_freq=WORD_FREQ)
        self.labels2int = {'positive': 1, 'negative': 0}

    def __getitem__(self, idx):

        text = self.df['review'].loc[idx]
        label = self.labels2int[self.df['sentiment'].loc[idx]]
        text = torch.LongTensor(self.vocab.encode(text))
        label = torch.FloatTensor([label])

        return text, label
    def __len__(self):
        return len(self.df)
    

if not on_kaggle:
    if 'IMDB Dataset.csv' not in os.listdir():
        import kaggle
        kaggle.api.dataset_download_files('lakshmi25npathi/imdb-dataset-of-50k-movie-reviews', path='.', unzip=True)
    dataset = IMDB('IMDB Dataset.csv')
else:
    dataset = IMDB('/kaggle/input/imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv')
#dataset[0]

In [8]:
def create_kaggle_config():
    config_dir = '/home/jovyan/.config/kaggle'
    config_file = f'{config_dir}/kaggle.json'
    config_data = {
        "username": "stroganovrockxi",
        "key": "f121a896d930d969093c2538b792fb02"
    }
    
    os.makedirs(config_dir, exist_ok=True)
    
    if not os.path.exists(config_file):
        with open(config_file, 'w') as f:
            json.dump(config_data, f)
#create_kaggle_config()

In [9]:
from torch.utils.data import random_split
from torch.nn.utils.rnn import pad_sequence

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = random_split(dataset, [train_size, test_size])

pad_index = len(dataset.vocab.vocab)
def collate_fn(batch):
    texts = pad_sequence([b[0] for b in batch], padding_value=pad_index, batch_first=True)
    labels = torch.stack([b[1] for b in batch])
    return texts, labels
def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    return correct.sum() / len(correct)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, pin_memory=True, num_workers = 0)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, pin_memory=True, num_workers = 0)

In [10]:
import pytorch_lightning as pl # type: ignore
import plotly.graph_objects as go # type: ignore
from IPython.display import display

class TextConvNN(pl.LightningModule):
    def __init__(self, vocab_size, dims_size, pad_idx):
        super().__init__()
        self.save_hyperparameters()
        self.embedding = nn.Embedding(vocab_size, dims_size, padding_idx=pad_idx)
        kernels = [2, 3, 4, 5]
        self.convs = nn.ModuleList([nn.Conv2d(1, 16, kernel_size=(k, dims_size)) for k in kernels])
        self.fc = nn.Linear(len(kernels) * 16, 1)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.embedding(x).unsqueeze(1)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        x = [F.max_pool1d(_, _.shape[2]).squeeze(2) for _ in x]
        x = self.dropout(torch.cat(x, dim=1))
        return self.fc(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_accuracy', binary_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_accuracy', binary_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('test_accuracy', binary_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


In [11]:
import plotly.graph_objects as go
from IPython.display import display, clear_output
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities.rank_zero import rank_zero_only

class PlotlyLogger(Logger):
    def __init__(self):
        super().__init__()
        self.fig_accuracy = go.FigureWidget()
        self.fig_loss = go.FigureWidget()

        self.fig_accuracy.add_trace(go.Scatter(y=[], mode='lines', name='Train'))
        self.fig_accuracy.add_trace(go.Scatter(y=[], mode='lines', name='Test'))
        self.fig_loss.add_trace(go.Scatter(y=[], mode='lines', name='Train'))
        self.fig_loss.add_trace(go.Scatter(y=[], mode='lines', name='Test'))

        self.fig_accuracy.update_layout(xaxis_title='Epoch', yaxis_title='Accuracy', margin=dict(l=20, r=20, t=20, b=20))
        self.fig_loss.update_layout(xaxis_title='Epoch', yaxis_title='Loss', margin=dict(l=20, r=20, t=20, b=20))
        self.y_train_acc, self.y_test_acc = [], []
        self.y_train_loss, self.y_test_loss = [], []
        self.count = 0


        display(self.fig_accuracy)
        display(self.fig_loss)
    @rank_zero_only
    def log_metrics(self, metrics, step):
        self.count += 1
        if self.count % 2 != 0:
            if metrics.get('val_accuracy', 0) == 0: return
            self.y_test_acc.append(metrics.get('val_accuracy', 0))
            self.y_test_loss.append(metrics.get('val_loss', 0))
            if on_kaggle: print('Validation Accuracy | Loss: ', self.y_test_acc, self.y_test_loss, sep = '\n')
        else:
            self.y_train_acc.append(metrics.get('train_accuracy', 0))
            self.y_train_loss.append(metrics.get('train_loss', 0))
            if on_kaggle: print('Train Accuracy | Loss: ', self.y_train_acc, self.y_train_loss, sep = '\n')


        self.fig_accuracy.data[0].y = self.y_train_acc
        self.fig_accuracy.data[1].y = self.y_test_acc
        self.fig_loss.data[0].y = self.y_train_loss
        self.fig_loss.data[1].y = self.y_test_loss


    def log_hyperparams(self, params):
        pass
    @property
    def experiment(self):
        return None
    @property
    def name(self):
        return 'PlotlyLogger'
    @property
    def version(self):
        return '0.1'

In [17]:
torch.set_float32_matmul_precision('medium')

In [18]:
model = TextConvNN(len(dataset.vocab.vocab) + 1, fasttext.vector_size, pad_index)
model.embedding.weight.data[:len(dataset.vocab.make_vectors(fasttext))] = torch.from_numpy(dataset.vocab.make_vectors(fasttext).copy())

logger = PlotlyLogger()

trainer = pl.Trainer(max_epochs=EPOCH,
                    accelerator='gpu', 
                    devices=1,
                    logger = logger,
                    enable_progress_bar= not on_kaggle  
                    )
if TRAIN:
    trainer.fit(model, train_loader, test_loader)
if not TRAIN:
    clear_output()
    model = TextConvNN.load_from_checkpoint("fasttext-wiki-300_imdb_2.ckpt", 
                                            vocab_size=len(dataset.vocab.vocab) + 1, 
                                            dims_size=fasttext.vector_size, 
                                            pad_idx=pad_index)

FigureWidget({
    'data': [{'mode': 'lines', 'name': 'Train', 'type': 'scatter', 'uid': 'e4bfc0dc-cc6b-47dd-8679-49a1c46faff2', 'y': []},
             {'mode': 'lines', 'name': 'Test', 'type': 'scatter', 'uid': '9ed05b79-0ae8-4d52-a273-0edaf48edc82', 'y': []}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20},
               'template': '...',
               'xaxis': {'title': {'text': 'Epoch'}},
               'yaxis': {'title': {'text': 'Accuracy'}}}
})

FigureWidget({
    'data': [{'mode': 'lines', 'name': 'Train', 'type': 'scatter', 'uid': 'c3f7d366-5dea-4690-86ae-65aa03405b79', 'y': []},
             {'mode': 'lines', 'name': 'Test', 'type': 'scatter', 'uid': '1b8b0b39-22ad-46a9-a756-3745feb51332', 'y': []}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20},
               'template': '...',
               'xaxis': {'title': {'text': 'Epoch'}},
               'yaxis': {'title': {'text': 'Loss'}}}
})

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params | Mode 
-------------------------------------------------
0 | embedding | Embedding  | 18.2 M | train
1 | convs     | ModuleList | 67.3 K | train
2 | fc        | Linear     | 65     | train
3 | dropout   | Dropout    | 0      | train
-------------------------------------------------
18.2 M    Trainable params
0         Non-trainable params
18.2 M    Total params
72.903    Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [10]:
test1 = trainer.test(model, test_loader)
trainer.save_checkpoint("model_test.ckpt")


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.8676999807357788
        test_loss           0.43912628293037415
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
