In [None]:
import os
import random
from collections import Counter, defaultdict
from typing import Counter as CounterType, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from nltk.tokenize import word_tokenize
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [None]:
# reproducibility

def set_global_seed(seed: int):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    every CUDA device), and makes cuDNN choose deterministic kernels without
    autotuning.

    Args:
        seed (int): seed shared by all generators.
    """

    # Python / NumPy generators.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    # PyTorch generators.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # cuDNN: deterministic algorithms, no benchmarking of alternatives.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_global_seed(42)

In [None]:
# parameters

config = dict(
    # optimisation
    BATCH_SIZE=256,
    LEARNING_RATE=1e-4,
    N_EPOCHS=10,
    # architecture
    EMBEDDING_DIM=100,
    ENCODER_HIDDEN_SIZE=128,
    ENCODER_NUM_LAYERS=1,
    ENCODER_DROPOUT=0.0,
    ENCODER_BIDIRECTIONAL=True,
    DECODER_NUM_LAYERS=1,
    DECODER_DROPOUT=0.0,
)

In [None]:
# tensorboard

# Encode the run's hyper-parameters in its name so runs can be told apart in TensorBoard.
experiment_name = f"Seq2SeqLSTM_BATCH_{config['BATCH_SIZE']}_LR_{config['LEARNING_RATE']}_N_EPOCHS_{config['N_EPOCHS']}"

# One log directory per experiment. The previous fixed "runs/tmp" put the event
# files of every run into the same directory and left `experiment_name` unused.
writer = SummaryWriter(
    log_dir=f"runs/{experiment_name}",
)

In [None]:
# device

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

### Prepare Data

#### Dataset

In [None]:
class IMDBDataset(torch.utils.data.Dataset):
    """
    Tokenized reviews of one split of the aclImdb dataset.

    Reads every ``*.txt`` file of ``<path_to_data>/pos`` and then of
    ``<path_to_data>/neg``, lower-cases it and tokenizes it with NLTK's
    ``word_tokenize``. Sentiment labels are not kept: every item is the list
    of tokens of one review.

    Args:
        path_to_data (str): split directory containing ``pos`` and ``neg``.
    """

    def __init__(
        self,
        path_to_data: str,
    ):
        super().__init__()

        self.path_to_data = path_to_data
        self.dataset = self._prepare_dataset(path_to_data)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(
        self,
        idx: int,
    ) -> List[str]:
        return self.dataset[idx]

    @staticmethod
    def _prepare_dataset(path_to_data: str) -> List[List[str]]:
        """
        Read and tokenize all reviews of the split.

        Args:
            path_to_data (str): split directory containing ``pos`` and ``neg``.

        Returns:
            List[List[str]]: one token list per review, positive reviews first,
            each class in file-name order.
        """
        dataset = []

        for subdir in ("pos", "neg"):
            review_dir = os.path.join(path_to_data, subdir)
            # os.listdir order depends on the filesystem; sort it so that dataset
            # indices (and therefore seeded shuffling) are the same on every machine.
            filenames = sorted(
                filename
                for filename in os.listdir(review_dir)
                if filename.endswith(".txt")
            )
            for filename in tqdm(filenames, desc="parse txt files"):
                # Reviews may contain non-ASCII characters: decode explicitly as
                # UTF-8 instead of relying on the platform's default encoding.
                with open(
                    os.path.join(review_dir, filename), mode="r", encoding="utf-8"
                ) as fp:
                    review = fp.read()
                dataset.append(word_tokenize(review.lower()))

        return dataset

In [None]:
train_dataset = IMDBDataset(path_to_data="data/aclImdb/train")
test_dataset = IMDBDataset(path_to_data="data/aclImdb/test")

In [None]:
train_dataset[0][:15]

#### Analysis

In [None]:
train_dataset_len_distr = [len(review) for review in train_dataset]
test_dataset_len_distr = [len(review) for review in test_dataset]

In [None]:
plt.hist(
    train_dataset_len_distr,
    bins=len(set(train_dataset_len_distr)),
    alpha=0.5,
    label="train",
)
plt.hist(
    test_dataset_len_distr,
    bins=len(set(test_dataset_len_distr)),
    alpha=0.5,
    label="test",
)
plt.legend()
plt.title("Review's length distribution");

In [None]:
train_dataset_n_digits_distr = Counter([len(str(length)) for length in train_dataset_len_distr])
train_dataset_n_digits_distr.most_common()

In [None]:
test_dataset_n_digits_distr = Counter([len(str(length)) for length in test_dataset_len_distr])
test_dataset_n_digits_distr.most_common()

#### Collator

In [None]:
tokens_counter = Counter()

for review in train_dataset:
    tokens_counter.update(review)

In [None]:
len(tokens_counter), tokens_counter.most_common(5)

In [None]:
class Token2Idx:
    """
    Vocabulary mapping tokens to integer ids.

    Ids 0-3 are reserved for ``<bos>``, ``<eos>``, ``<unk>`` and ``<pad>``; the
    other tokens get consecutive ids in order of decreasing count. Tokens
    counted fewer than ``min_df`` times get no id and are encoded as ``<unk>``.

    Args:
        tokens_counter (CounterType): token -> number of occurrences.
        min_df (int): minimum number of occurrences for a token to get its own
            id. Despite the name this is a total occurrence count, not a
            document frequency.
    """

    def __init__(
        self,
        tokens_counter: CounterType,
        min_df: int,
    ):
        self.tokens_counter = tokens_counter
        self.min_df = min_df

        self.token2idx = self._prepare_token2idx(
            tokens_counter=tokens_counter,
            min_df=min_df,
        )

    def __call__(
        self,
        seq: List[str],
    ) -> List[int]:
        """
        Encode a token sequence; tokens outside the vocabulary become ``<unk>``.

        Args:
            seq (List[str]): tokens.

        Returns:
            List[int]: one id per token (same length as ``seq``).
        """
        # Look the fallback id up once instead of once per token.
        unk_idx = self.token2idx["<unk>"]
        return [self.token2idx.get(token, unk_idx) for token in seq]

    def __getitem__(self, key: str) -> int:
        """Id of a single token; raises ``KeyError`` if it is not in the vocabulary."""
        return self.token2idx[key]

    def __len__(self) -> int:
        """Vocabulary size, special tokens included."""
        return len(self.token2idx)

    @staticmethod
    def _prepare_token2idx(
        tokens_counter: CounterType,
        min_df: int,
    ) -> Dict[str, int]:
        """
        Build the token -> id mapping.

        Args:
            tokens_counter (CounterType): token -> number of occurrences.
            min_df (int): minimum number of occurrences to keep a token.

        Returns:
            Dict[str, int]: special tokens first, then kept tokens by decreasing count.
        """
        token2idx = {
            "<bos>": 0,
            "<eos>": 1,
            "<unk>": 2,
            "<pad>": 3,
        }

        for token, cnt in tqdm(
            tokens_counter.most_common(),
            desc="loop over unique tokens",
        ):
            # never re-assign a special token's id
            if token in token2idx:
                continue
            if cnt < min_df:
                continue

            token2idx[token] = len(token2idx)

        return token2idx

In [None]:
token2idx = Token2Idx(
    tokens_counter=tokens_counter,
    min_df=5,  # hardcoded
)

In [None]:
len(token2idx.token2idx)

In [None]:
token2idx(train_dataset[0])[:15]

In [None]:
for dataset in [train_dataset, test_dataset]:
    for review in tqdm(
        dataset,
        desc="assertion loop",
    ):
        assert len(review) == len(token2idx(review))

In [None]:
class Collator:
    """
    Batch builder for the sequence autoencoder.

    Turns a list of tokenized reviews into two right-padded ``LongTensor``s:
    the encoder input (token ids in reading order) and the decoder sequence
    ``<bos> + reversed ids + <eos>``. Both are padded with the ``<pad>`` id.

    Args:
        token2idx (Token2Idx): vocabulary; must support ``token2idx(tokens)``
            (tokens -> list of ids) and ``token2idx["<special>"]``.
    """

    def __init__(
        self,
        token2idx: "Token2Idx",
    ):
        self.token2idx = token2idx

    def __call__(
        self,
        batch: List[List[str]],
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """
        Collate a batch of tokenized reviews.

        Args:
            batch (List[List[str]]): tokenized reviews.

        Returns:
            Tuple[torch.LongTensor, torch.LongTensor]: ``(seq, inv_seq)`` of shapes
            ``(batch, max_len)`` and ``(batch, max_len + 2)``.
        """
        # The special ids are the same for every review: look them up once.
        bos_idx = self.token2idx["<bos>"]
        eos_idx = self.token2idx["<eos>"]
        pad_idx = self.token2idx["<pad>"]

        tensor_seq = []
        tensor_inv_seq = []
        for seq in batch:
            tokenized_seq = self.token2idx(seq)
            tensor_seq.append(torch.LongTensor(tokenized_seq))
            # Decoder sequence: the review reversed, framed by <bos> / <eos>.
            tensor_inv_seq.append(
                torch.LongTensor([bos_idx] + tokenized_seq[::-1] + [eos_idx])
            )

        padded_sequences = torch.nn.utils.rnn.pad_sequence(
            sequences=tensor_seq,
            batch_first=True,
            padding_value=pad_idx,
        )
        padded_inv_sequences = torch.nn.utils.rnn.pad_sequence(
            sequences=tensor_inv_seq,
            batch_first=True,
            padding_value=pad_idx,
        )
        return padded_sequences, padded_inv_sequences

In [None]:
collator = Collator(token2idx)  # builds (seq, inv_seq) batches for the dataloaders

#### DataLoader

In [None]:
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config["BATCH_SIZE"],
    shuffle=True,
    num_workers=0,
    collate_fn=collator,
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=collator,
)

In [None]:
seq, inv_seq = next(iter(train_dataloader))
seq.shape, inv_seq.shape

In [None]:
seq, inv_seq = next(iter(test_dataloader))
seq.shape, inv_seq.shape

### Seq2Seq LSTM

In [None]:
def number_of_parameters(model: torch.nn.Module) -> int:
    """
    Total number of scalar parameters of ``model`` (trainable or frozen).
    """
    total = 0
    for parameter in model.parameters():
        total += parameter.numel()
    return total

In [None]:
class Seq2SeqLSTM(torch.nn.Module):
    """
    LSTM sequence autoencoder ("skip thoughts") that reconstructs a sequence.

    The (optionally bidirectional) encoder reads ``seq``. The final hidden state
    of its last layer, with the directions concatenated, is the skip-thought
    vector. That vector is used as both the initial hidden state and the initial
    cell state of every decoder layer. The decoder is teacher-forced on
    ``inv_seq`` shifted right by one position, so the logits at position ``t``
    predict ``inv_seq[:, t]`` from ``inv_seq[:, :t]``.

    Args:
        num_embeddings (int): vocabulary size (shared by encoder and decoder).
        embedding_dim (int): token embedding size.
        encoder_hidden_size (int): hidden size of each encoder direction.
        encoder_num_layers (int): number of stacked encoder LSTM layers.
        encoder_dropout (float): dropout between encoder layers.
        encoder_bidirectional (bool): whether the encoder is bidirectional.
        decoder_num_layers (int): number of stacked decoder LSTM layers.
        decoder_dropout (float): dropout between decoder layers.
        pad_idx (Optional[int]): id of the ``<pad>`` token. Defaults to
            ``token2idx["<pad>"]`` from the notebook namespace.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        encoder_hidden_size: int,
        encoder_num_layers: int,
        encoder_dropout: float,
        encoder_bidirectional: bool,
        decoder_num_layers: int,
        decoder_dropout: float,
        pad_idx: Optional[int] = None,
    ):
        super().__init__()
        # Resolved once here instead of reading the global vocabulary on every call.
        self.pad_idx = token2idx["<pad>"] if pad_idx is None else pad_idx
        self.encoder_num_layers = encoder_num_layers
        self.encoder_num_directions = 2 if encoder_bidirectional else 1
        self.decoder_num_layers = decoder_num_layers

        self.embedding = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=self.pad_idx,
        )
        self.encoder = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=encoder_hidden_size,
            num_layers=encoder_num_layers,
            dropout=encoder_dropout,
            bidirectional=encoder_bidirectional,
            batch_first=True,
        )
        # The decoder state is the concatenation of the encoder directions.
        decoder_hidden_size = encoder_hidden_size * self.encoder_num_directions
        self.decoder = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=decoder_hidden_size,
            num_layers=decoder_num_layers,
            dropout=decoder_dropout,
            bidirectional=False,
            batch_first=True,
        )
        self.head = torch.nn.Linear(
            in_features=decoder_hidden_size,
            out_features=num_embeddings,
        )

    def forward(
        self,
        seq: torch.LongTensor,
        inv_seq: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Teacher-forced forward pass.

        Args:
            seq (torch.LongTensor): encoder ids ``(batch, seq_len)``, right-padded
                with ``pad_idx``.
            inv_seq (torch.LongTensor): sequence the decoder learns to produce,
                ``(batch, inv_len)``, right-padded with ``pad_idx``.

        Returns:
            torch.Tensor: logits ``(batch, inv_len, num_embeddings)``.
            ``logits[:, t]`` scores ``inv_seq[:, t]`` given the skip thought and
            ``inv_seq[:, :t]``, so it can be compared with ``inv_seq`` position by
            position. Padded positions hold the logits of a zero hidden state.
        """
        # h_n holds every review's final state at its own (unpadded) length.
        _, (h_n, _) = self.encoder(self._embed(seq))
        skip_thoughts = self._get_skip_thoughts(h_n)  # (batch, decoder_hidden_size)

        # Same initial state for each decoder layer: (decoder_num_layers, batch, hidden).
        init_state = skip_thoughts.unsqueeze(dim=0).repeat(self.decoder_num_layers, 1, 1)

        # Teacher forcing with the input shifted right by one step. Feeding inv_seq
        # unshifted lets the decoder see the very token it must predict (a copy task).
        # Step 0 gets the <pad> id, whose embedding is zero (padding_idx).
        inv_lengths = self._lengths(inv_seq)
        start = torch.full_like(inv_seq[:, :1], self.pad_idx)
        decoder_input = torch.cat([start, inv_seq[:, :-1]], dim=1)
        decoder_output, _ = self.decoder(
            self._embed(decoder_input, lengths=inv_lengths),
            (init_state, init_state),
        )

        decoder_output, _ = self._pad_packed_sequence(
            sequence=decoder_output,
            total_length=inv_seq.size(1),
        )
        return self.head(decoder_output)

    def _lengths(
        self,
        seq: torch.LongTensor,
    ) -> torch.Tensor:
        """Number of non-pad ids per row, as a CPU int64 tensor."""
        # pack_padded_sequence requires CPU lengths even when seq lives on the GPU.
        return (seq != self.pad_idx).sum(dim=1).cpu()

    def _embed(
        self,
        seq: torch.LongTensor,
        lengths: Optional[torch.Tensor] = None,
    ) -> torch.nn.utils.rnn.PackedSequence:
        """
        Embed ``seq`` and pack it so the LSTMs skip the padding.

        Args:
            seq (torch.LongTensor): ids ``(batch, len)``.
            lengths (Optional[torch.Tensor]): valid length of each row; counted
                from the non-pad ids of ``seq`` when omitted.
        """
        if lengths is None:
            lengths = self._lengths(seq)
        return torch.nn.utils.rnn.pack_padded_sequence(
            input=self.embedding(seq),
            lengths=lengths,
            batch_first=True,
            enforce_sorted=False,
        )

    @staticmethod
    def _pad_packed_sequence(
        sequence: torch.nn.utils.rnn.PackedSequence,
        total_length: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Unpack LSTM outputs to ``(batch, len, hidden)`` plus the row lengths."""
        # Hidden states are floats: pad them with zeros, not with the <pad> token id.
        return torch.nn.utils.rnn.pad_packed_sequence(
            sequence=sequence,
            batch_first=True,
            padding_value=0.0,
            total_length=total_length,
        )

    def _get_skip_thoughts(
        self,
        h_n: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sentence vector from the encoder's final hidden state.

        Args:
            h_n (torch.Tensor): encoder ``h_n``,
                ``(encoder_num_layers * num_directions, batch, encoder_hidden_size)``.

        Returns:
            torch.Tensor: ``(batch, num_directions * encoder_hidden_size)``, the
            last layer's final states of all directions, concatenated.
        """
        batch_size = h_n.size(1)
        # Separate layers from directions, then keep the top layer only.
        last_layer = h_n.reshape(
            self.encoder_num_layers, self.encoder_num_directions, batch_size, -1,
        )[-1]
        # Forward: state after the last real token; backward: state after the first
        # token. Each row uses its own length, and both directions cover the whole review.
        return torch.cat(last_layer.unbind(dim=0), dim=-1)

In [None]:
model = Seq2SeqLSTM(
    num_embeddings=len(token2idx.token2idx),
    embedding_dim=config["EMBEDDING_DIM"],
    encoder_hidden_size=config["ENCODER_HIDDEN_SIZE"],
    encoder_num_layers=config["ENCODER_NUM_LAYERS"],
    encoder_dropout=config["ENCODER_DROPOUT"],
    encoder_bidirectional=config["ENCODER_BIDIRECTIONAL"],
    decoder_num_layers=config["DECODER_NUM_LAYERS"],
    decoder_dropout=config["DECODER_DROPOUT"],
).to(device)

In [None]:
number_of_parameters(model)

In [None]:
seq, inv_seq = next(iter(train_dataloader))
seq.to(device), inv_seq.to(device)
seq.shape, inv_seq.shape

In [None]:
output = model(seq, inv_seq)
output.shape

In [None]:
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config["LEARNING_RATE"],
)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
criterion(output.transpose(1, 2), inv_seq)

#### train

In [None]:
def train_epoch(
    model: Seq2SeqLSTM,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    writer: SummaryWriter,
    device: torch.device,
    epoch: int,
) -> None:
    """
    One pass over the training data, with an optimizer step after every batch.

    Logs every batch loss ("batch loss / train") and the epoch's mean batch loss
    ("loss / train") to TensorBoard, and prints the latter.

    Args:
        model (Seq2SeqLSTM): model; switched to train mode.
        dataloader (torch.utils.data.DataLoader): yields ``(seq, inv_seq)`` batches.
        optimizer (torch.optim.Optimizer): optimizer over the model parameters.
        criterion (torch.nn.Module): loss taking ``(logits (batch, vocab, len), inv_seq)``.
        writer (SummaryWriter): tensorboard writer.
        device (torch.device): cpu or cuda.
        epoch (int): index of the current epoch, used to compute the global step.
    """
    model.train()

    n_batches = len(dataloader)
    batch_losses = []
    # Stays empty until task-specific metrics are added (see TODO in the loop).
    batch_metrics_list = defaultdict(list)

    batches = tqdm(
        enumerate(dataloader),
        total=n_batches,
        desc="loop over train batches",
    )
    for batch_idx, (seq, inv_seq) in batches:
        seq = seq.to(device)
        inv_seq = inv_seq.to(device)

        optimizer.zero_grad()
        logits = model(seq, inv_seq)
        # the loss expects the class (vocabulary) dimension second
        loss = criterion(logits.transpose(1, 2), inv_seq)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)
        writer.add_scalar(
            "batch loss / train", loss_value, epoch * n_batches + batch_idx
        )

        # TODO: compute task-specific metrics here (in eval mode, under
        # torch.no_grad()), append them to batch_metrics_list[metric_name] and
        # log them as f"batch {metric_name} / train".

    avg_loss = np.mean(batch_losses)
    print(f"Train loss: {avg_loss}\n")
    writer.add_scalar("loss / train", avg_loss, epoch)

    for metric_name, metric_values in batch_metrics_list.items():
        metric_value = np.mean(metric_values)
        print(f"Train {metric_name}: {metric_value}\n")
        writer.add_scalar(f"{metric_name} / train", metric_value, epoch)

In [None]:
def evaluate_epoch(
    model: Seq2SeqLSTM,
    dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    writer: SummaryWriter,
    device: torch.device,
    epoch: int,
) -> None:
    """
    One pass over the evaluation data, without gradients or parameter updates.

    Logs every batch loss ("batch loss / test") and the epoch's mean batch loss
    ("loss / test") to TensorBoard, and prints the latter.

    Args:
        model (Seq2SeqLSTM): model; switched to eval mode.
        dataloader (torch.utils.data.DataLoader): yields ``(seq, inv_seq)`` batches.
        criterion (torch.nn.Module): loss taking ``(logits (batch, vocab, len), inv_seq)``.
        writer (SummaryWriter): tensorboard writer.
        device (torch.device): cpu or cuda.
        epoch (int): index of the current epoch, used to compute the global step.
    """
    model.eval()

    n_batches = len(dataloader)
    batch_losses = []
    # Stays empty until task-specific metrics are added (see TODO in the loop).
    batch_metrics_list = defaultdict(list)

    with torch.no_grad():
        batches = tqdm(
            enumerate(dataloader),
            total=n_batches,
            desc="loop over test batches",
        )
        for batch_idx, (seq, inv_seq) in batches:
            seq = seq.to(device)
            inv_seq = inv_seq.to(device)

            logits = model(seq, inv_seq)
            # the loss expects the class (vocabulary) dimension second
            loss_value = criterion(logits.transpose(1, 2), inv_seq).item()

            batch_losses.append(loss_value)
            writer.add_scalar(
                "batch loss / test", loss_value, epoch * n_batches + batch_idx
            )

            # TODO: compute task-specific metrics from `logits`, append them to
            # batch_metrics_list[metric_name] and log them as
            # f"batch {metric_name} / test".

        avg_loss = np.mean(batch_losses)
        print(f"Test loss:  {avg_loss}\n")
        writer.add_scalar("loss / test", avg_loss, epoch)

        for metric_name, metric_values in batch_metrics_list.items():
            metric_value = np.mean(metric_values)
            print(f"Test {metric_name}: {metric_value}\n")
            writer.add_scalar(f"{metric_name} / test", metric_value, epoch)

In [None]:
def train(
    n_epochs: int,
    model: Seq2SeqLSTM,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    writer: SummaryWriter,
    device: torch.device,
) -> None:
    """
    Full training loop: each epoch runs one training pass, then one evaluation pass.

    Args:
        n_epochs (int): number of epochs to train.
        model (Seq2SeqLSTM): model.
        train_dataloader (torch.utils.data.DataLoader): training batches.
        test_dataloader (torch.utils.data.DataLoader): evaluation batches.
        optimizer (torch.optim.Optimizer): optimizer.
        criterion (torch.nn.Module): loss.
        writer (SummaryWriter): tensorboard writer.
        device (torch.device): cpu or cuda.
    """
    for epoch in range(n_epochs):
        print(f"Epoch [{epoch+1} / {n_epochs}]\n")

        # arguments shared by the training and the evaluation pass of this epoch
        common_kwargs = dict(
            model=model,
            criterion=criterion,
            writer=writer,
            device=device,
            epoch=epoch,
        )

        train_epoch(
            dataloader=train_dataloader,
            optimizer=optimizer,
            **common_kwargs,
        )
        evaluate_epoch(
            dataloader=test_dataloader,
            **common_kwargs,
        )

In [None]:
train(
    n_epochs=config["N_EPOCHS"],
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    writer=writer,
    device=device,
)
# SummaryWriter may keep pending events buffered; flush so the scalars of the
# last epoch are written to disk (and visible in TensorBoard) right away.
writer.flush()

In [None]:
# TODO: calculate the loss correctly, i.e. without <pad> tokens
# TODO: add task-specific metrics
# TODO: fix num layers