In [None]:
from functools import reduce
from functools import lru_cache
from typing import List

from cachetools import cached, TTLCache

import datasets
import json
from pathlib import Path
import evaluate
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.notebook import tqdm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Util functions

In [None]:
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings along dim 1, weighting every token by its mask value.

    :param model_output: object exposing ``last_hidden_state`` (the per-token embeddings)
    :param attention_mask: per-token mask; tokens with mask 0 do not contribute to the average
    :return: masked sum of the token embeddings divided by the masked token count
    """
    hidden_states = model_output.last_hidden_state
    # Repeat the mask along the hidden dimension so it can weight whole token vectors.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = torch.sum(hidden_states * weights, 1)
    # Clamp the count so a fully masked row divides by 1e-9 instead of zero.
    token_count = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / token_count

In [None]:
def compute_clm_loss(logits, labels):
    """Next-token cross-entropy for a causal setup.

    The prediction at position ``t`` is scored against the label at position ``t + 1``.
    Labels equal to -100 are excluded from the loss.

    :param logits: scores whose last dimension enumerates the classes
    :param labels: integer class ids aligned with the positions of ``logits``
    :return: scalar cross-entropy averaged over the non-ignored targets
    """
    n_classes = logits.size(-1)
    # Drop the last prediction (nothing follows it) and the first label (nothing precedes it).
    predictions = logits[..., :-1, :].contiguous().view(-1, n_classes)
    targets = labels[..., 1:].contiguous().view(-1)
    return F.cross_entropy(predictions, targets, ignore_index=-100)

In [None]:
def compute_shifted_cosine_mse_loss(input_logits, output_logits, padding_mask):
    """Shifted (next-step) cosine + MSE loss between predicted and target vectors.

    For every sequence the padded positions are removed, then the prediction at position ``t``
    is compared with the target at position ``t + 1``. The per-sequence loss is
    ``mean(1 - cosine_similarity) + mse``.

    Sequences with fewer than two non-padded positions have no (prediction, target) pair.
    They are skipped, because averaging over an empty selection would yield NaN and poison
    the whole batch loss.

    :param input_logits: predictions, indexed as [batch, position, feature]
    :param output_logits: targets with the same layout as ``input_logits``
    :param padding_mask: boolean tensor [batch, position]; True marks padded positions
    :return: scalar tensor, the mean of the per-sequence losses; a zero that is still attached
        to the graph of ``input_logits`` when no sequence contributes
    """
    losses = []

    batch_size = padding_mask.shape[0]
    for batch_idx in range(batch_size):
        keep = ~padding_mask[batch_idx, :]
        preds = input_logits[batch_idx, :][keep]
        targets = output_logits[batch_idx, :][keep]

        if preds.shape[0] < 2:
            # Nothing to shift against: skip instead of producing NaN.
            continue

        shift_input = preds[:-1, :]
        shift_output = targets[1:, :]

        cosine_term = (1 - F.cosine_similarity(shift_input, shift_output, -1)).mean()
        losses.append(cosine_term + F.mse_loss(shift_input, shift_output))

    if not losses:
        # Keep the result differentiable so a following .backward() call still works.
        return input_logits.sum() * 0.0

    return torch.stack(losses).mean()

In [None]:
def compute_shifted_cross_l2_loss(input_logits, output_logits, padding_mask):
    """Match the pairwise-distance structure of shifted predictions to that of shifted targets.

    For every sequence the padded positions are removed. The pairwise Euclidean distances
    among the predictions at positions ``0..n-2`` are compared (MSE over the upper triangle,
    diagonal included) with the pairwise distances among the targets at positions ``1..n-1``.
    The target distances are detached, so gradients only flow through the predictions.

    Sequences with fewer than two non-padded positions yield an empty distance matrix and are
    skipped, because their MSE would be NaN.

    :param input_logits: predictions, indexed as [batch, position, feature]
    :param output_logits: targets with the same layout as ``input_logits``
    :param padding_mask: boolean tensor [batch, position]; True marks padded positions
    :return: scalar tensor, the mean of the per-sequence losses; a zero that is still attached
        to the graph of ``input_logits`` when no sequence contributes
    """
    losses = []

    batch_size = padding_mask.shape[0]
    for batch_idx in range(batch_size):
        keep = ~padding_mask[batch_idx, :]
        preds = input_logits[batch_idx, :][keep]
        targets = output_logits[batch_idx, :][keep]

        if preds.shape[0] < 2:
            # No shifted pair exists: skip instead of producing NaN.
            continue

        orig_distances = torch.cdist(targets[1:], targets[1:]).detach()
        pred_distances = torch.cdist(preds[:-1], preds[:-1])

        # Build the index tensors on the same device as the distance matrices.
        n_points = orig_distances.shape[0]
        rows, cols = torch.triu_indices(n_points, n_points, device=orig_distances.device)
        losses.append(F.mse_loss(pred_distances[rows, cols], orig_distances[rows, cols]))

    if not losses:
        # Keep the result differentiable so a following .backward() call still works.
        return input_logits.sum() * 0.0

    return torch.stack(losses).mean()

In [None]:
def compute_cosine_mse_loss(input_logits, output_logits, padding_mask):
    """Position-aligned cosine + MSE loss between predicted and target vectors.

    Padded positions (True in ``padding_mask``) are dropped per sequence. The loss of one
    sequence is ``mean(1 - cosine_similarity) + mse``; the result is the mean over sequences.
    """

    def _sequence_loss(row):
        # Keep only the non-padded positions of this sequence.
        valid = ~padding_mask[row, :]
        predicted = input_logits[row, :][valid]
        expected = output_logits[row, :][valid]
        cosine_term = (1 - F.cosine_similarity(predicted, expected, -1)).mean()
        return cosine_term + F.mse_loss(predicted, expected)

    per_sequence = [_sequence_loss(row) for row in range(padding_mask.shape[0])]
    return torch.stack(per_sequence).contiguous().mean()

#### Losses Playground

In [None]:
# Toy tensors for sanity-checking shifted losses: the input is the output moved one step to the
# left along dim 1, i.e. example_input[:, t] == example_output[:, t + 1] for every t but the last.
example_output = torch.rand([5, 8, 100], dtype=torch.float)
example_input = example_output.clone()
example_input[:, :-1] = example_output[:, 1:]
# NOTE(review): `compute_shifted_mse_loss` is not defined in the visible part of the notebook.
# compute_shifted_mse_loss(example_input, example_output, torch.zeros(5, 8).bool())

In [None]:
# Expected to be zero: the first 7 rows of the input were copied from the last 7 rows of the output.
F.mse_loss(example_input[0][:-1], example_output[0][1:])

In [None]:
# Both pairwise-distance matrices are built from identical rows, so their MSE should be zero too.
F.mse_loss(torch.cdist(example_output[0][1:], example_output[0][1:]).view(-1),
           torch.cdist(example_input[0][:-1], example_input[0][:-1]).view(-1))

In [None]:
# Pairwise distances among the first 7 rows of the shifted input (first sample).
torch.cdist(example_input[0][:-1], example_input[0][:-1])

In [None]:
# Pairwise distances among the last 7 rows of the output (first sample); same rows as the cell above.
torch.cdist(example_output[0][1:], example_output[0][1:])

In [None]:
# Cross distances between the two shifted views; the rows are identical, so the diagonal is zero.
torch.cdist(example_input[0][:-1], example_output[0][1:])

## Using allenai/soda dataset
This dataset contains dialogs with speaker labels; each dialog is stored as a separate list of lines.

In [None]:
# Load the 'allenai/soda' dataset through the `datasets` library and display its splits.
soda_dataset = datasets.load_dataset('allenai/soda')
soda_dataset

In [None]:
# Keep only the 'dialogue' and 'speakers' columns (the column list is read from the train split).
soda_dataset = soda_dataset.remove_columns([col for col in soda_dataset['train'].column_names if col not in ['dialogue', 'speakers']])
soda_dataset

In [None]:
def encode_interlocutors(row):
    """Replace speaker names with interlocutor ids.

    The speaker of the first line gets id 2. The first *different* speaker, in order of
    appearance, gets id 3. Any further speakers fall back to id 2.

    Compared with picking the second speaker from a ``set``, this is deterministic when a
    dialog has more than two speakers, and it does not fail on dialogs with a single speaker
    or with no lines at all.

    :param row: dataset row with a 'speakers' list (one name per dialog line)
    :return: dict with the re-encoded 'speakers' list
    """
    speakers = row['speakers']
    if not speakers:
        return {'speakers': []}

    first_speaker = speakers[0]
    mapping = {first_speaker: 2}

    # First speaker that differs from the opener, if there is one.
    other_speaker = next((name for name in speakers if name != first_speaker), None)
    if other_speaker is not None:
        mapping[other_speaker] = 3

    return {
        'speakers': [mapping.get(name, 2) for name in speakers]
    }

In [None]:
# {'ek': 34, 'mf3': 54}.get(1, 10)

In [None]:
# Convert speaker names to interlocutor ids in every split, using 11 worker processes.
soda_dataset = soda_dataset.map(encode_interlocutors, num_proc=11)
soda_dataset

In [None]:
# Rename the columns to the keyword names used by the DialogTokenizer methods.
soda_dataset = soda_dataset.rename_columns({'dialogue': 'dialog', 'speakers': 'interlocutors'})
soda_dataset

In [None]:
# Drop dialogs whose line and interlocutor counts disagree, and dialogs with a single line.
soda_dataset = soda_dataset.filter(lambda row: len(row['dialog']) == len(row['interlocutors']) and len(row['dialog']) > 1)
soda_dataset

In [None]:
# Histogram of the number of distinct interlocutor ids per dialog (validation split).
speakers_counts = [len(set(x)) for x in soda_dataset['validation']['interlocutors']]
plt.hist(speakers_counts)
plt.show()

In [None]:
# Histogram of the number of lines per dialog (test split).
dialogs_lengths = [len(x) for x in soda_dataset['test']['dialog']]
plt.hist(dialogs_lengths)
plt.show()

## Phrase encoder model

In [None]:
# Checkpoint of the sentence encoder that embeds every dialog line.
# The commented-out names are alternative checkpoints; exactly one assignment is active.
# phrase_model = 'roberta-base'
# phrase_model = 'microsoft/deberta-v3-base'
# phrase_model = 'sentence-transformers/all-MiniLM-L12-v2'
# phrase_model = 'sentence-transformers/bert-base-nli-mean-tokens'
# phrase_model = 'intfloat/e5-base'
# phrase_model = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
phrase_model = 'sentence-transformers/paraphrase-mpnet-base-v2'
# phrase_model = 'sentence-transformers/sentence-t5-base'
# tokenizer = AutoTokenizer.from_pretrained(phrase_model)
# model = AutoModel.from_pretrained(phrase_model).to(device)
# model
# The encoder is switched to eval mode and its weights are cast to float16.
# NOTE(review): the encodings it produces are therefore presumably float16 as well — confirm
# that downstream float32 modules handle that.
sent_transformer = SentenceTransformer(model_name_or_path=phrase_model, device=device).eval().half()
sent_transformer

In [None]:
# Cap the encoder's `max_seq_length` at 256.
sent_transformer.max_seq_length = 256
sent_transformer

In [None]:
# Small probe set to eyeball the encoder: several phrases about going to school plus one
# unrelated sentence.
test_phrases = ['Some day i will go to school',
                'To make maximum progress on addressing these pressing problems',
                'I will nether go to school',
                'I like to visit school',
                'The day will come when i will go to school']

# Encode without L2 normalisation and show the shape of the resulting tensor.
phrases_encodings = sent_transformer.encode(test_phrases, convert_to_tensor=True, normalize_embeddings=False)
phrases_encodings.shape

In [None]:
# Pairwise cosine similarity between the probe phrases (tensors are moved to the CPU first).
cosine_similarity(phrases_encodings.cpu(), phrases_encodings.cpu())

## Dialog encoder model

Концептуально тут нужно:
- модель-кодировщик фраз (замороженная) -> готовые эмбеддинги текста
- эмбеддинги участников диалога
- эмбеддинги позиции текста в диалоге
- кастомный токенизатор с BOS и EOS
- causal LM cross-entropy loss
- causal-маска внимания
- финальный классификатор в условный словарь

In [None]:
class DialogEmbeddings(nn.Module):
    """Interlocutor embeddings plus position embeddings, followed by LayerNorm and dropout."""

    def __init__(self, encoder_hidden_dim: int,
                 max_interlocutors_count: int,
                 max_dialogue_length: int,
                 dropout_p: float):
        super().__init__()

        # Id 0 is the padding slot of both embedding tables (must agree with the tokenizer).
        self.padding_idx = 0

        position_slots = max_dialogue_length + 1  # one extra slot for padding
        interlocutor_slots = max_interlocutors_count + 2  # two extra slots for the reserved ids

        self.position_embeddings = nn.Embedding(position_slots, encoder_hidden_dim,
                                                padding_idx=self.padding_idx)
        self.interlocutors_embeddings = nn.Embedding(interlocutor_slots, encoder_hidden_dim,
                                                     padding_idx=self.padding_idx)
        self.norm = nn.LayerNorm(encoder_hidden_dim)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, interlocutors_ids: torch.LongTensor, position_ids: torch.LongTensor = None):
        """Embed interlocutor ids; position ids are derived from them when not supplied."""
        if position_ids is None:
            position_ids = self.create_position_ids_from_input_ids(interlocutors_ids)

        summed = self.interlocutors_embeddings(interlocutors_ids) + self.position_embeddings(position_ids)
        return self.dropout(self.norm(summed))

    def create_position_ids_from_input_ids(self, input_ids):
        """
        Number the non-padding entries of every row 1, 2, 3, ... (starting at padding_idx + 1);
        padding entries keep padding_idx. Adapted from fairseq's `utils.make_positions`.
        """
        # The int/long casts are kept as in the original (noted there as required for ONNX export and XLA).
        not_padding = input_ids.ne(self.padding_idx).int()
        running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
        return (running_count * not_padding).long() + self.padding_idx

In [None]:
class DialogOutput(nn.Module):
    """Residual feed-forward head: Linear -> GELU -> dropout -> LayerNorm -> Linear, plus the input."""

    def __init__(self,
                 encoder_hidden_dim: int,
                 dim_feedforward_mult: int = 3,
                 dropout_p: float = 0.1):
        super().__init__()

        inner_dim = dim_feedforward_mult * encoder_hidden_dim
        self.inner_proj = nn.Linear(encoder_hidden_dim, inner_dim)
        self.norm = nn.LayerNorm(inner_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.out_proj = nn.Linear(inner_dim, encoder_hidden_dim)

    def forward(self, inp: torch.FloatTensor):
        """Return ``out_proj(norm(dropout(gelu(inner_proj(inp))))) + inp``."""
        expanded = self.dropout(F.gelu(self.inner_proj(inp)))
        projected = self.out_proj(self.norm(expanded))
        # Residual connection around the whole block.
        return projected + inp

In [None]:
class DialogTransformer(nn.Module):
    """Causal transformer over pre-computed dialog line encodings that predicts a label per position.

    Every dialog is a sequence of line encodings followed by one learned EOS vector. Label,
    interlocutor and position embeddings are added to the normalised encodings, the sequence
    goes through a causally masked transformer encoder stack, and a linear layer maps every
    position to label logits.

    Reserved ids (shared with ``DialogTokenizer``): 0 is padding and 1 is EOS, so real labels
    start at 2. The loss targets are ``labels_ids - 1``, which is why the output layer has
    ``labels_count + 1`` classes (EOS plus the labels).
    """

    def __init__(self, encoder_hidden_dim: int,
                 max_dialogue_length: int,
                 max_interlocutors_count: int,
                 labels_count: int,
                 decoder_n_layers: int = 2,
                 decoder_n_head: int = 4,
                 dim_feedforward_mult: int = 3,
                 dropout_p: float = 0.1):
        """
        :param encoder_hidden_dim: size of one line encoding (and of every hidden state)
        :param max_dialogue_length: forwarded to ``DialogEmbeddings`` (size of the position table)
        :param max_interlocutors_count: forwarded to ``DialogEmbeddings`` (size of the interlocutor table)
        :param labels_count: number of real labels (reserved ids excluded)
        :param decoder_n_layers: number of transformer layers
        :param decoder_n_head: number of attention heads per layer
        :param dim_feedforward_mult: feed-forward width as a multiple of ``encoder_hidden_dim``
        :param dropout_p: dropout probability of the embeddings and the transformer layers
        """
        super().__init__()

        # Learned stand-in for the end-of-dialog position, which has no sentence encoding.
        self.eos_vector = nn.Parameter(torch.randn([encoder_hidden_dim]), requires_grad=True)

        self.input_norm = nn.LayerNorm(encoder_hidden_dim)

        # +2 for the reserved padding and EOS ids.
        self.labels_embeddings = nn.Embedding(labels_count + 2, encoder_hidden_dim)
        self.dialogue_embeddings = DialogEmbeddings(encoder_hidden_dim, max_interlocutors_count,
                                                    max_dialogue_length, dropout_p)

        decoder_ff_inner_dim = encoder_hidden_dim * dim_feedforward_mult
        # Encoder layers are used because this is not a seq2seq setup (there is no memory to attend to).
        layer = nn.TransformerEncoderLayer(d_model=encoder_hidden_dim,
                                           nhead=decoder_n_head,
                                           dim_feedforward=decoder_ff_inner_dim,
                                           activation=F.gelu,  # gelu instead of the default relu
                                           dropout=dropout_p,
                                           batch_first=True)
        self.model = nn.TransformerEncoder(layer, decoder_n_layers)

        # +1: the EOS class plus one class per label (padding is never a target).
        self.labels_projector = nn.Linear(in_features=encoder_hidden_dim, out_features=labels_count + 1, bias=True)

    def forward(self, encodings: torch.FloatTensor,
                interlocutors_ids: torch.LongTensor,
                labels_ids: torch.LongTensor,
                position_ids: torch.LongTensor = None,
                attention_mask: torch.BoolTensor = None,
                return_loss = True):
        """
        :param encodings: pooled hiddens from the sentence encoder, shape [bs, lines_count, hidden_dim]
        :param interlocutors_ids: shape [bs, lines_count + 1], interlocutor id of every line plus EOS
        :param labels_ids: shape [bs, lines_count + 1], label id of every line plus EOS
        :param position_ids: shape [bs, lines_count + 1], position of every line in the dialog;
            derived from ``interlocutors_ids`` when omitted
        :param attention_mask: shape [bs, lines_count + 1], padding mask where True disables a
            position; no position is disabled when omitted
        :param return_loss: when true, also compute the next-label cross-entropy loss
        :return: ``(loss, logits)`` when ``return_loss`` is true, otherwise ``logits`` of shape
            [bs, lines_count + 1, labels_count + 1]
        """
        batch_size, lines_count, hidden_dim = encodings.shape
        seq_len = lines_count + 1  # one extra slot for the EOS vector

        if attention_mask is None:
            attention_mask = torch.zeros([batch_size, seq_len], dtype=torch.bool, device=encodings.device)

        # The sentence encoder may run in half precision; match the dtype of this model's parameters.
        encodings = encodings.to(self.eos_vector.dtype)

        # In a padded batch the dialogs have different lengths, so the EOS vector cannot simply be
        # appended after the padding. Reserve one empty slot per dialog and write the EOS vector at
        # each dialog's last non-padded position, where the tokenizer puts the EOS label and
        # interlocutor ids. Without padding this is the last slot.
        representation = torch.cat([encodings, encodings.new_zeros([batch_size, 1, hidden_dim])], dim=1)
        eos_positions = (~attention_mask.bool()).sum(dim=1) - 1
        slot_index = torch.arange(seq_len, device=representation.device)
        is_eos = slot_index.unsqueeze(0) == eos_positions.unsqueeze(1)
        representation = torch.where(is_eos.unsqueeze(-1), self.eos_vector, representation)

        x = self.input_norm(representation)
        x = x + self.labels_embeddings(labels_ids)
        x = x + self.dialogue_embeddings(interlocutors_ids=interlocutors_ids, position_ids=position_ids)

        # True above the diagonal: every position may only attend to itself and to the past.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        x = self.model(src=x,
                       mask=causal_mask,
                       src_key_padding_mask=attention_mask)

        predicted_labels = self.labels_projector(x)

        if return_loss:
            # Shift the ids down by one so EOS becomes class 0; padding (now -1) is ignored by the loss.
            copied_labels = labels_ids.clone() - 1
            copied_labels[copied_labels < 0] = -100

            labels_loss = compute_clm_loss(predicted_labels, copied_labels)

            return labels_loss, predicted_labels

        return predicted_labels

In [None]:
# Preview of a 10x10 causal mask: True strictly above the diagonal, as built in DialogTransformer.forward.
torch.triu(torch.ones(10, 10), diagonal=1).bool()

In [None]:
# Instantiate the dialog model on the same device as the sentence encoder, in eval mode.
# NOTE(review): DialogTransformer.__init__ declares `labels_count` without a default and it is
# not passed here — confirm the intended number of labels.
dialo_transformer = DialogTransformer(encoder_hidden_dim=768,
                                      max_dialogue_length=50,
                                      max_interlocutors_count=2,
                                      decoder_n_layers=3,
                                      decoder_n_head=4,
                                      dim_feedforward_mult=4,
                                      dropout_p=0.15
                                      ).to(sent_transformer.device).eval()
dialo_transformer

In [None]:
# Show the torchinfo summary of the dialog model.
summary(dialo_transformer)

## Dialog Tokenizer

In [None]:
# NOTE(review): GPTJForCausalLM is only imported and displayed here; nothing else in the visible
# notebook uses it — presumably a leftover used for reference, confirm before removing.
from transformers import GPTJForCausalLM

GPTJForCausalLM

In [None]:
class DialogTokenizer:
    """
    Turns raw dialogs into the tensors consumed by ``DialogTransformer.forward``.

    Accepts 'dialog' (required) plus optional 'labels' and 'interlocutors'.
    Returns a dict with 'encodings', 'labels_ids', 'interlocutors_ids' and 'return_loss';
    the batch methods add 'attention_mask'.

    Reserved ids: 0 is padding and 1 is EOS; real labels and interlocutors start at 2.
    One EOS id is appended to the label and interlocutor sequences, so they are one element
    longer than the line encodings.
    """

    def __init__(self, lines_encoder: "SentenceTransformer",
                 all_labels: list = None,
                 all_interlocutors: list = None):
        """
        :param lines_encoder: sentence encoder providing ``encode(...)`` and ``device``
        :param all_labels: optional label vocabulary; when given, labels are mapped to ids
        :param all_interlocutors: optional interlocutor vocabulary; when given, interlocutors
            are mapped to ids
        """
        self.lines_encoder = lines_encoder
        # The sentence encoder is used as a fixed feature extractor: freeze its first module.
        for p in self.lines_encoder[0].parameters():
            p.requires_grad = False

        self.padding_idx = 0
        self.eos_idx = 1

        # Ids are assigned in order of first appearance, so the mapping is reproducible across runs.
        if all_interlocutors:
            unique_interlocutors = list(dict.fromkeys(all_interlocutors))
            self.id2interlocutors = dict(enumerate(unique_interlocutors, start=2))
            self.interlocutor2id = {v: k for k, v in self.id2interlocutors.items()}

        if all_labels:
            unique_labels = list(dict.fromkeys(all_labels))
            self.id2label = dict(enumerate(unique_labels, start=2))
            self.label2id = {v: k for k, v in self.id2label.items()}

    @staticmethod
    def _lookup_ids(mapping: dict, values: list, kind: str) -> List[int]:
        """Map ``values`` through ``mapping``; raise KeyError naming every unknown value."""
        unknown = [value for value in values if value not in mapping]
        if unknown:
            raise KeyError(f'Unknown {kind}: {unknown!r}')
        return [mapping[value] for value in values]

    def _pad_ids(self, sequences) -> torch.Tensor:
        """Right-pad integer id sequences with ``padding_idx`` into one [batch, max_len] tensor."""
        tensors = [torch.as_tensor(sequence, dtype=torch.long) for sequence in sequences]
        return pad_sequence(tensors, batch_first=True, padding_value=self.padding_idx)

    @staticmethod
    def _build_attention_mask(lengths: List[int]) -> torch.Tensor:
        """Boolean [batch, max_len] mask: False for the first ``length`` slots of a row, True for padding."""
        masks = [torch.zeros(size=[length]) for length in lengths]
        return pad_sequence(masks, batch_first=True, padding_value=1).bool()

    def encode(self, dialog: List[str],
               labels: list = None,
               interlocutors: List[int] = None,
               lines_batch_size: int = 50,
               unsqueeze: bool = True,
               return_loss: bool = True,
               device: str = None,
               **kwargs):
        """Encode a single dialog.

        :param dialog: the dialog lines
        :param labels: one label per line. When omitted, every line gets the padding id, which
            the model's loss ignores as a target
        :param interlocutors: one interlocutor per line. When omitted, ids 2 and 3 alternate
        :param lines_batch_size: batch size used by the sentence encoder
        :param unsqueeze: add a leading batch dimension of size 1 to every tensor
        :param return_loss: flag passed through to the model
        :param device: where to store the tensors; defaults to the device of the encodings
        :raises ValueError: when labels or interlocutors do not match the number of lines
        :raises KeyError: when a label or interlocutor is missing from the configured vocabulary
        """
        encodings = self.lines_encoder.encode(sentences=dialog,
                                              batch_size=lines_batch_size,
                                              normalize_embeddings=False,  # better not to normalize
                                              show_progress_bar=False,
                                              convert_to_tensor=True)

        if interlocutors is None:
            interlocutors = [(i % 2) + 2 for i in range(len(dialog))]
        elif hasattr(self, 'interlocutor2id'):
            interlocutors = self._lookup_ids(self.interlocutor2id, interlocutors, 'interlocutors')

        if labels is None:
            labels = [self.padding_idx] * len(dialog)
        elif hasattr(self, 'label2id'):
            labels = self._lookup_ids(self.label2id, labels, 'labels')

        if len(labels) != len(dialog) or len(interlocutors) != len(dialog):
            raise ValueError(f'Expected one label and one interlocutor per dialog line: got {len(dialog)} lines, '
                             f'{len(labels)} labels and {len(interlocutors)} interlocutors')

        # Copy before appending EOS so the caller's lists are never modified.
        labels = list(labels) + [self.eos_idx]
        interlocutors = list(interlocutors) + [self.eos_idx]

        save_device = device if device is not None else encodings.device

        labels_ids = torch.tensor(labels, dtype=torch.long)
        interlocutors_ids = torch.tensor(interlocutors, dtype=torch.long)
        if unsqueeze:
            encodings = encodings.unsqueeze(0)
            labels_ids = labels_ids.unsqueeze(0)
            interlocutors_ids = interlocutors_ids.unsqueeze(0)

        return {
            'encodings': encodings.to(save_device),
            'labels_ids': labels_ids.to(save_device),
            'interlocutors_ids': interlocutors_ids.to(save_device),
            'return_loss': return_loss
        }

    def encode_batch(self, dialog: List[List[str]],
                     labels: List[list] = None,
                     interlocutors: List[List[int]] = None,
                     lines_batch_size: int = 50,
                     return_loss: bool = True,
                     **kwargs):
        """Encode several dialogs and pad them into one batch.

        ``labels`` and ``interlocutors`` are optional and default per dialog as in ``encode``.
        All returned tensors live on the sentence encoder's device. 'attention_mask' is True
        at padded positions.
        """
        if interlocutors is None:
            interlocutors = [None] * len(dialog)
        if labels is None:
            labels = [None] * len(dialog)

        assert len(dialog) == len(interlocutors)
        assert len(dialog) == len(labels)

        encoded_batch = [self.encode(lines, line_labels, line_interlocutors, lines_batch_size, False)
                         for lines, line_labels, line_interlocutors in zip(dialog, labels, interlocutors)]

        encodings = pad_sequence([encoded['encodings'] for encoded in encoded_batch],
                                 batch_first=True,
                                 padding_value=self.padding_idx)
        encodings.requires_grad = False

        labels_ids = self._pad_ids([encoded['labels_ids'] for encoded in encoded_batch])
        interlocutors_ids = self._pad_ids([encoded['interlocutors_ids'] for encoded in encoded_batch])

        # +1: every dialog occupies one extra position for EOS.
        attention_masks = self._build_attention_mask([len(lines) + 1 for lines in dialog])

        target_device = self.lines_encoder.device
        return {
            'encodings': encodings.to(target_device),
            'labels_ids': labels_ids.to(target_device),
            'interlocutors_ids': interlocutors_ids.to(target_device),
            'attention_mask': attention_masks.to(target_device),
            'return_loss': return_loss
        }

    def encode_cached_batch(self,
                            dialog: List[List[str]],
                            encodings: List[List[float]],
                            interlocutors_ids: List[List[int]],
                            return_loss: List[bool],
                            labels_ids: List[List[int]] = None,
                            **kwargs):
        """Pad already encoded dialogs (e.g. rows produced by ``encode``) into one batch.

        :param dialog: the dialog lines; only their count per dialog is used
        :param encodings: per dialog, the line encodings
        :param interlocutors_ids: per dialog, the interlocutor ids including the trailing EOS
        :param return_loss: one flag per dialog (as produced by collating rows) or a single flag;
            collapsed to one bool because the model takes a single flag
        :param labels_ids: per dialog, the label ids including the trailing EOS. When omitted,
            every line gets the padding id followed by EOS, as in ``encode``
        """
        encodings = pad_sequence([torch.as_tensor(lines, dtype=torch.float32) for lines in encodings],
                                 batch_first=True,
                                 padding_value=self.padding_idx)
        encodings.requires_grad = False

        if labels_ids is None:
            labels_ids = [[self.padding_idx] * len(lines) + [self.eos_idx] for lines in dialog]

        padded_labels_ids = self._pad_ids(labels_ids)
        padded_interlocutors_ids = self._pad_ids(interlocutors_ids)

        # +1: every dialog occupies one extra position for EOS.
        attention_masks = self._build_attention_mask([len(lines) + 1 for lines in dialog])

        if isinstance(return_loss, (list, tuple)):
            # A non-empty list is truthy even when every flag is False, so reduce it explicitly.
            return_loss = all(return_loss)

        target_device = self.lines_encoder.device
        return {
            'encodings': encodings.to(target_device),
            'labels_ids': padded_labels_ids.to(target_device),
            'interlocutors_ids': padded_interlocutors_ids.to(target_device),
            'attention_mask': attention_masks.to(target_device),
            'return_loss': return_loss
        }

In [None]:
# Tokenizer without label / interlocutor vocabularies: ids given by the caller are used as they are.
dialo_tokenizer = DialogTokenizer(sent_transformer)

In [None]:
# Encode a toy three-line dialog; interlocutors are not given, so the tokenizer alternates ids 2 and 3.
# NOTE(review): no `labels` are passed in this call — confirm which labels this example is meant to use.
dialo_encoded = dialo_tokenizer.encode(['Hello man', 'Goodbye', 'Thanks'])
dialo_encoded

In [None]:
# Smoke test: forward pass on the single encoded dialog (the dict keys become keyword arguments).
dialo_transformer.forward(**dialo_encoded)

In [None]:
# Encode the first three training dialogs as one padded batch.
# NOTE(review): the dataset rows prepared above carry only 'dialog' and 'interlocutors', so no
# `labels` reach `encode_batch` here — confirm where the labels are meant to come from.
dialogs_encoded = dialo_tokenizer.encode_batch(**soda_dataset['train'][:3])
dialogs_encoded

In [None]:
# Smoke test: forward pass on the padded batch.
dialo_transformer.forward(**dialogs_encoded)

### Pretokenize all dataset

In [None]:
# Pre-compute the tokenizer output of every dialog once, without a batch dimension and stored on
# the CPU, so that training only has to pad cached rows (see `encode_cached_batch`).
# NOTE(review): the rows carry only 'dialog' and 'interlocutors', so no `labels` reach `encode`
# here — confirm where the labels are meant to come from.
soda_dataset = soda_dataset.map(lambda row: dialo_tokenizer.encode(**row, unsqueeze=False, device='cpu'), num_proc=1)
soda_dataset

## Training

In [None]:
def collate_batch(batch: list):
    """Turn a list of per-dialog dicts into one dict of lists.

    The keys are taken from the first element; every element must provide all of them.
    """
    collated = {}
    for key in batch[0].keys():
        collated[key] = [sample[key] for sample in batch]
    return collated

In [None]:
# Only the training split is shuffled. Validation and test batches keep a fixed order, so the
# mean of the per-batch losses (the last batch can be smaller) is reproducible between epochs
# and runs.
train_dataloader = DataLoader(soda_dataset['train'],
                              collate_fn=collate_batch,
                              shuffle=True, batch_size=256)
eval_dataloader = DataLoader(soda_dataset['validation'],
                             collate_fn=collate_batch,
                             shuffle=False, batch_size=512)
test_dataloader = DataLoader(soda_dataset['test'],
                             collate_fn=collate_batch,
                             shuffle=False, batch_size=512)

In [None]:
from torch.optim import AdamW
from transformers import get_scheduler

In [None]:
# AdamW over all parameters of the dialog model, learning rate 5e-4.
optimizer = AdamW(dialo_transformer.parameters(), lr=5e-4)

In [None]:
# The scheduler is stepped once per training batch, so its horizon is epochs * batches per epoch.
num_epochs = 20
num_training_steps = num_epochs * len(train_dataloader)
# "cosine_with_restarts" schedule with 10 warm-up steps.
lr_scheduler = get_scheduler(name="cosine_with_restarts", optimizer=optimizer, num_warmup_steps=10,
                             num_training_steps=num_training_steps)

In [None]:
def evaluate(model: DialogTransformer, tokenizer: DialogTokenizer, data_loader):
    """Compute the loss of ``model`` on every batch of ``data_loader``.

    The model is switched to eval mode and left in it. The forward passes run under
    ``torch.inference_mode()``.

    NOTE: defining this function rebinds the name ``evaluate``, which the import cell at the
    top of the notebook also binds to the ``evaluate`` package.

    :param model: the model to evaluate (this argument is used, not the global model)
    :param tokenizer: tokenizer whose ``encode_cached_batch`` turns a collated batch into model inputs
    :param data_loader: iterable of collated batches
    :return: list with one float loss per batch
    """
    model.eval()
    losses = []
    with torch.inference_mode():
        for batch in tqdm(data_loader):
            tokenized_input = tokenizer.encode_cached_batch(**batch)
            outputs = model(**tokenized_input)
            # The loss is the first element of the tuple returned by the model; indexing
            # avoids depending on how many further outputs it returns.
            losses.append(outputs[0].item())
    return losses

In [None]:
# Mean of the per-batch validation losses of the current model.
np.array(evaluate(dialo_transformer, dialo_tokenizer, eval_dataloader)).mean()

In [None]:
def create_experiment_info(best_epoch_n, best_train_loss, best_eval_loss, losses_history = None, file_name='experiment_info.json'):
    """Write a JSON summary of the best epoch, optionally with the loss history, to ``file_name``.

    :param best_epoch_n: number of the best epoch
    :param best_train_loss: training loss of that epoch
    :param best_eval_loss: evaluation loss of that epoch
    :param losses_history: optional history, stored under the 'history' key when given
    :param file_name: path of the JSON file, overwritten if it exists
    """
    best_epoch = dict(number=best_epoch_n, train_loss=best_train_loss, eval_loss=best_eval_loss)
    summary_dict = {'best_epoch': best_epoch}
    if losses_history is not None:
        summary_dict['history'] = losses_history

    with open(file_name, 'w') as outfile:
        json.dump(summary_dict, outfile, indent=4)

In [None]:
def train(model: DialogTransformer, tokenizer: DialogTokenizer, checkpoints_dir):
    """Train ``model`` for ``num_epochs`` epochs and keep the checkpoint with the best eval loss.

    Relies on notebook globals: ``num_epochs``, ``num_training_steps``, ``train_dataloader``,
    ``eval_dataloader``, ``optimizer`` and ``lr_scheduler``. The optimizer must have been built
    from the parameters of ``model``, otherwise the updates go to a different model.

    After every epoch the model is evaluated. Whenever the mean eval loss improves, the whole
    model is saved to ``<checkpoints_dir>/best_model.pth`` and the experiment summary to
    ``<checkpoints_dir>/experiment_info.json``.

    :param model: the model to train (this argument is used, not the global model)
    :param tokenizer: tokenizer whose ``encode_cached_batch`` turns a collated batch into model inputs
    :param checkpoints_dir: directory for the checkpoint files, created if missing; a trailing
        slash is not required
    :return: dict with the per-epoch mean 'train' and 'eval' losses
    """
    checkpoints_path = Path(checkpoints_dir)
    checkpoints_path.mkdir(parents=True, exist_ok=True)
    losses_history = {
        'train': [],
        'eval': []
    }
    progress_bar = tqdm(range(num_training_steps))
    min_eval_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        print(f'Starting epoch {epoch}...')
        train_losses = []

        for batch in train_dataloader:
            tokenized_input = tokenizer.encode_cached_batch(**batch)
            # The loss is the first element of the tuple returned by the model.
            loss = model(**tokenized_input)[0]
            loss.backward()

            train_losses.append(loss.item())

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

        # Plain floats keep the history trivially JSON-serialisable.
        train_loss = float(np.mean(train_losses))
        eval_loss = float(np.mean(evaluate(model, tokenizer, eval_dataloader)))

        losses_history['train'].append(train_loss)
        losses_history['eval'].append(eval_loss)

        print(f'[TRAIN] Mean epoch loss: {train_loss}')
        print(f'[EVAL] Mean epoch loss: {eval_loss}')

        if eval_loss < min_eval_loss:
            save_path = checkpoints_path / 'best_model.pth'
            print(f'Current best on eval, saving model to {save_path}...')
            torch.save(model, save_path)
            create_experiment_info(best_epoch_n=epoch, best_train_loss=train_loss, best_eval_loss=eval_loss,
                                   losses_history=losses_history,
                                   file_name=checkpoints_path / 'experiment_info.json')
            min_eval_loss = eval_loss

    progress_bar.close()
    return losses_history

In [None]:
# Run the training; checkpoints and the experiment summary go to the given directory.
losses_history = train(dialo_transformer, dialo_tokenizer, './experiments/soda_pmpn_eos_shift_tr_3l4h_COS&MSE_d0.15_21M/')

In [None]:
# Load a fully pickled model (not a state dict) and switch it to eval mode.
# NOTE(review): this path differs from the directory the training cell above writes to —
# presumably an older experiment; confirm that this is the checkpoint meant to be evaluated.
dialo_transformer = torch.load('./experiments/pmpn_1l_4h_COS&MSE_d0.1/best_model.pth').eval()
dialo_transformer

In [None]:
# Per-batch losses of the loaded model on the test split.
evaluate(dialo_transformer, dialo_tokenizer, test_dataloader)

## Tests

In [None]:
# NOTE(review): `dd_dataset` is not defined in the visible part of the notebook — confirm where it is loaded.
dd_dataset['train']['dialog'][14]

In [None]:
# Encode one sample dialog with the default alternating interlocutors.
# NOTE(review): `dd_dataset` is not defined in the visible part of the notebook, and no `labels`
# are passed in this call — confirm both.
test_encoded = dialo_tokenizer.encode(dd_dataset['train']['dialog'][14])
test_encoded

In [None]:
# Forward pass on the sample dialog without building an autograd graph.
with torch.inference_mode():
    test_output = dialo_transformer.forward(**test_encoded)
test_output

In [None]:
# Per-line MSE between every line encoding and the encoding of the line before it.
F.mse_loss(test_encoded['encodings'][0][1:].detach(),
           test_encoded['encodings'][0][:-1].detach(), reduction='none').mean(-1)

In [None]:
# Per-position MSE between the second model output and the line encodings.
# NOTE(review): in the forward visible in this notebook the second output holds label logits, not
# vectors in the encoding space — presumably left over from an earlier model version; confirm.
F.mse_loss(test_output[1][0].detach(),
           test_encoded['encodings'][0].detach(), reduction='none').mean(-1)

In [None]:
# Pairwise Euclidean distances between the line encodings of the sample dialog.
torch.cdist(test_encoded['encodings'][0].detach(), test_encoded['encodings'][0].detach())

In [None]:
# Pairwise cosine similarity between the line encodings of the sample dialog.
cosine_similarity(test_encoded['encodings'][0].cpu(), test_encoded['encodings'][0].cpu())

In [None]:
# Pairwise distances within the third model output.
# NOTE(review): the forward visible in this notebook returns at most two values, so index 2 is
# presumably left over from an earlier model version — confirm.
torch.cdist(test_output[2][0], test_output[2][0])

In [None]:
# Pairwise cosine similarity within the third model output.
# NOTE(review): the forward visible in this notebook returns at most two values, so index 2 is
# presumably left over from an earlier model version — confirm.
cosine_similarity(test_output[2][0].cpu(), test_output[2][0].cpu())

In [None]:
# Distances between the third model output and the line encodings.
# NOTE(review): the forward visible in this notebook returns at most two values, so index 2 is
# presumably left over from an earlier model version — confirm.
torch.cdist(test_output[2][0], test_encoded['encodings'][0])

In [None]:
# Cosine similarity between the third model output (last position dropped) and the line encodings.
# NOTE(review): the forward visible in this notebook returns at most two values, so index 2 is
# presumably left over from an earlier model version — confirm.
cosine_similarity(test_output[2][0][:-1].cpu(), test_encoded['encodings'][0].cpu())

In [None]:
# Distances between the second model output (first position dropped) and the line encodings
# (last line dropped).
# NOTE(review): in the forward visible in this notebook the second output holds label logits —
# presumably this cell targets an earlier model version; confirm.
torch.cdist(test_output[1][0][1:], test_encoded['encodings'][0][:-1])

In [None]:
# Per-position MSE between consecutive positions of the second model output.
F.mse_loss(test_output[1][0][1:].detach(),
           test_output[1][0][:-1].detach(), reduction='none').mean(-1)

In [None]:
# Most likely class per position (softmax does not change the argmax).
# NOTE(review): the forward visible in this notebook returns at most two values, so index 3 is
# presumably left over from an earlier model version — confirm.
test_output[3].softmax(-1).argmax(-1)