In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datasets import load_from_disk
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
import torch

In [3]:
import pickle

In [4]:
with open('store/wordvec.pkl', 'rb') as f:
    wordvec = pickle.load(f)
dataset_snli = load_from_disk('data/snli')

In [6]:
from torch import optim, nn, utils, Tensor
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

In [57]:
# for split in dataset_snli:
#     l1, l2 = 0, 0
#     for example in dataset_snli[split]:
#         l1 = max(l1, len(example['premise']))
#         l2 = max(l2, len(example['hypothesis']))
#     print(f'{split}: {l1}, {l2}')

In [7]:

class DataSetPadding():
    def __init__(self, dataset, wordvec, max_length=100):
        self.dataset = dataset
        self.wordvec = wordvec
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def get_embedding(self, sent):
        return torch.stack([self.wordvec[word] for word in sent])
    
    def __getitem__(self, idx):
        example = self.dataset[idx]
        s1, s2, y = example['premise'], example['hypothesis'], example['label']

        s1 = [word if word in wordvec else '<unk>' for word in s1]
        s2 = [word if word in wordvec else '<unk>' for word in s2]
        
        # premise = torch.stack([self.wordvec.get(word, self.wordvec['<unk>']) for word in premise])
        # hypothesis = torch.stack([self.wordvec.get(word, self.wordvec['<unk>']) for word in hypothesis])
        
        len1, len2 = len(s1), len(s2)

        if (len1 > self.max_length) or (len2 > self.max_length):
            raise ValueError('Sentence length exceeds max length')
        

        s1.extend(['<pad>'] * (self.max_length - len1))
        s2.extend(['<pad>'] * (self.max_length - len2))

        e1 = self.get_embedding(s1)
        e2 = self.get_embedding(s2)

        # e1 = torch.cat([e1, self.wordvec['<pad>'].expand(self.max_length - len1, -1)])
        # e2 = torch.cat([e2, self.wordvec['<pad>'].expand(self.max_length - len2, -1)])

        return s1, s2, y, e1, e2, len1, len2
    
class DataSet():
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        assert self.dataset.num_rows == len(self.dataset)
        return self.dataset.num_rows

    def __getitem__(self, idx):
        example = self.dataset[idx]
        s1, s2, y = example['premise'], example['hypothesis'], example['label']

        return s1, s2, y

dataset = DataSetPadding(dataset_snli['train'], wordvec)
train_loader = utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
# s1, s2, y, e1, e2 = dataset[0]


In [67]:


# class WordEmbedding(nn.Module):
#     def __init__(self, wordvec):
#         super().__init__()
#         self.wordvec = wordvec

#     def forward(self, sent):
#         # x is a list of strings
#         # return a tensor of shape (len(sent), emb_dim)
#         print(len(sent), len(sent[0]))
#         return torch.stack([self.wordvec[word] for word in sent])


class Baseline(nn.Module):
    def __init__(self):
        super().__init__()
        self.max_length = 100


    def forward(self, embedding, length):

        assert embedding.shape == (length.shape[0], self.max_length, 300)

        mask = (length.unsqueeze(1) > torch.arange(embedding.shape[1])).float().unsqueeze(2)

        embedding_sum = torch.sum(embedding * mask, dim = 1)
        length_sum = torch.sum(mask, dim = 1)
        mean = embedding_sum / length_sum

        return mean

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.mlp(x)

# define the LightningModule
class LitNLINet(pl.LightningModule):
    def __init__(self, encoder, classifier):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def concat_sentreps(self, sentrep1, sentrep2):
        return torch.cat([sentrep1, sentrep2, torch.abs(sentrep1 - sentrep2), sentrep1 * sentrep2], dim=1)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        s1, s2, y, e1, e2, len1, len2 = batch
        u, v = self.encoder(e1, len1), self.encoder(e2, len2)

        features = self.concat_sentreps(u, v)

        # compute loss
        y_hat = self.classifier(features)
        print(y_hat, y)
        loss = nn.functional.cross_entropy(y_hat, y)

        # Logging to TensorBoard by default
        self.log("train_loss", loss)

        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer



encoder = Baseline()
classifier = MLP(300*4, 512, 3)

model = LitNLINet(encoder, classifier)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader)


  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name       | Type     | Params
----------------------------------------
0 | encoder    | Baseline | 0     
1 | classifier | MLP      | 879 K 
----------------------------------------
879 K     Trainable params
0         Non-trainable params
879 K     Total params
3.516     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

tensor([[ 0.0753,  0.0078,  0.0129],
        [ 0.0688,  0.0373,  0.0416],
        [ 0.0732,  0.0084,  0.0255],
        [ 0.0617,  0.0330,  0.0294],
        [ 0.0694,  0.0380,  0.0062],
        [ 0.1005,  0.0499,  0.0231],
        [ 0.0650,  0.0743,  0.0307],
        [ 0.0536,  0.0313,  0.0103],
        [ 0.0562,  0.0226,  0.0344],
        [ 0.0587,  0.0352,  0.0186],
        [ 0.0755,  0.0595,  0.0451],
        [ 0.0674,  0.0452,  0.0255],
        [ 0.0546,  0.0506,  0.0563],
        [ 0.0683,  0.0465,  0.0160],
        [ 0.0690,  0.0428,  0.0305],
        [ 0.0713,  0.0398,  0.0238],
        [ 0.1160,  0.0222,  0.0152],
        [ 0.0615,  0.0523, -0.0045],
        [ 0.0616,  0.0208,  0.0520],
        [ 0.0515,  0.0400,  0.0025],
        [ 0.0852,  0.0775,  0.0163],
        [ 0.0595,  0.0078,  0.0361],
        [ 0.0762,  0.0268,  0.0335],
        [ 0.0403,  0.0585,  0.0486],
        [ 0.0552,  0.0320,  0.0242],
        [ 0.0485,  0.0505,  0.0360],
        [ 0.0628,  0.0218,  0.0072],
 

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [129]:
a = torch.tensor([[[1, 2, 3], [4, 5, 0]]])
ls = 
a.shape

torch.Size([1, 2, 3])

In [15]:
embedding, length = torch.load('embedding.pt'), torch.load('length.pt')

In [17]:
embedding.shape, length

(torch.Size([32, 100, 300]),
 tensor([12, 11, 20, 16, 26, 10, 11, 11,  9,  9,  7, 12, 18, 12,  5, 17, 14, 10,
          7, 15,  9, 12, 28, 24,  8, 12,  7, 17, 22,  7, 15, 10]))

(torch.Size([32, 100, 1]), torch.Size([32, 100, 300]))

In [62]:
torch.mean(embedding[-1][:10], dim = 0)[-3:]

tensor([-0.0245, -0.1977,  0.0781])