This notebook contains an experimental training run of an LSTM classifier on the IMDB dataset with PyTorch Lightning.

There is still an unresolved bug when the training data is shuffled (see the DataLoader cell); the model seems to train better without shuffling.

In [None]:
import pytorch_lightning
from pytorch_lightning import Trainer, LightningModule
import torch
from data_classes.IMDB import IMDBClass
import torchtext

In [None]:
# Build the token -> index vocabulary from pretrained GloVe vectors, keeping
# only the first 10000 vectors.
glove_vec = torchtext.vocab.GloVe(max_vectors=10000)
# NOTE(review): torchtext.vocab.vocab() interprets the mapping values as
# frequencies; the values of stoi are row indices, so with the default
# min_freq the token stored at index 0 is presumably filtered out -- confirm
# with the len()/lookup cells below.
glove_vocab = torchtext.vocab.vocab(glove_vec.stoi)
unk_token = "<unk>"
unk_index = 0
# A dedicated padding token is currently disabled:
# pad_token = "<pad>"
# pad_index = 9999
# Insert the unknown-word token at position 0 of the vocabulary.
glove_vocab.insert_token(unk_token, unk_index)
# glove_vocab.insert_token(pad_token, pad_index)

# Required: without a default index, looking up an out-of-vocabulary token
# raises a RuntimeError. Unknown tokens now resolve to the <unk> index.
glove_vocab.set_default_index(glove_vocab[unk_token])

In [None]:
# Sanity check: a made-up token should resolve to the default (<unk>) index.
glove_vocab["kejslskgjfd"]

In [None]:
# Number of GloVe vectors loaded (capped by max_vectors above).
len(glove_vec)

In [None]:
# Number of tokens in the vocabulary, including the inserted <unk>.
len(glove_vocab)

In [None]:
# Token stored at vocabulary index 9999.
glove_vocab.lookup_token(9999)

In [None]:
# Vector that GloVe returns for the literal string "<unk>".
glove_vec["<unk>"]

In [None]:
# "<pad>" is never inserted into the vocabulary (the insert is commented out
# above), so unless GloVe itself contains that token this lookup falls back to
# the default <unk> index.
glove_vocab["<pad>"]

In [None]:
# Build the train and test splits (in that order); both are given the GloVe
# vocabulary as their transform.
train_dataset, test_dataset = (
    IMDBClass(train=is_train, transform=glove_vocab) for is_train in (True, False)
)

In [None]:
from torch.utils.data import DataLoader
# For use in DataLoader
def collate_fn(batch):
    """Collate (sequence, label) samples into one padded batch.

    Args:
        batch: list of samples where ``item[0]`` is the token-index sequence
            (expected to be a 1-D tensor, as required by ``pad_sequence``)
            and ``item[1]`` is its label.

    Returns:
        Tuple ``(x, y, lengths)``: ``x`` holds the sequences right-padded with
        zeros to the longest one (batch first), ``y`` the labels as a long
        tensor, and ``lengths`` the unpadded length of every sequence.
    """
    sequences = [item[0] for item in batch]
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    # Fully qualified so that this cell does not rely on a `pad_sequence`
    # import that is only executed in a later cell.
    x = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    y = torch.tensor([item[1] for item in batch], dtype=torch.long)
    return x, y, lengths

# Warning: training with shuffle=True is known to misbehave in this notebook
# (the model stops learning); the cause has not been found yet.
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True, collate_fn=collate_fn)
# Evaluation metrics do not depend on sample order, so the test loader is not
# shuffled (same as LSTM.test_dataloader).
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=False, collate_fn=collate_fn)

In [None]:
# Number of samples in the test split.
len(test_dataset)

In [None]:
# Vocabulary size again (re-checked after building the datasets).
len(glove_vocab)

In [None]:
# Index that the vocabulary assigns to the token "the".
glove_vocab.lookup_indices(["the"])

In [None]:
from torch.optim import Adam
from torch import nn
from torch.nn.utils.rnn import pad_sequence

class LSTM(LightningModule):
    """LSTM sequence classifier trained with PyTorch Lightning.

    Token indices are embedded, the packed sequences are run through an LSTM
    and the hidden state of the last LSTM layer is mapped to class logits by a
    linear layer.
    """

    def __init__(self, vocab_size, embedding_size=64, lstm_hidden_size=100, num_class=2, batch_size=32, learning_rate=0.001, vocab=None, vectors=None):
        """Build the embedding, LSTM and output layers.

        Args:
            vocab_size: number of rows of the trainable embedding (unused when
                pretrained ``vectors`` are supplied).
            embedding_size: width of the trainable embedding (unused when
                pretrained ``vectors`` are supplied).
            lstm_hidden_size: size of the LSTM hidden state.
            num_class: number of output classes.
            batch_size: batch size used by the dataloader hooks.
            learning_rate: learning rate handed to Adam.
            vocab: optional token -> index mapping; ``vocab["<pad>"]`` is used
                as the padding index.
            vectors: optional pretrained vectors object whose ``.vectors``
                tensor becomes the frozen embedding matrix.
        """
        super().__init__()
        if vocab is None:
            self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        elif vectors is None:
            # A vocabulary without pretrained vectors used to crash on
            # `vectors.vectors`; fall back to a trainable embedding instead.
            self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=vocab["<pad>"])
        else:
            # NOTE(review): the rows of vectors.vectors are used as-is, which
            # assumes the vocabulary indices line up with the GloVe row
            # indices -- confirm, since the vocab cell inserts <unk> at 0.
            self.embedding = nn.Embedding.from_pretrained(vectors.vectors, freeze=True, padding_idx=vocab["<pad>"])
        # Take the input width from the embedding itself so that pretrained
        # vectors of any dimensionality work regardless of `embedding_size`.
        self.lstm = nn.LSTM(self.embedding.embedding_dim, lstm_hidden_size, batch_first=True)
        self.linear = nn.Linear(lstm_hidden_size, num_class)
        self.loss_function = nn.CrossEntropyLoss()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.lstm_hidden_size = lstm_hidden_size

    def forward(self, X: torch.Tensor, lengths: torch.LongTensor):
        """Return class logits for a padded batch of token indices.

        Args:
            X: padded token indices, batch first.
            lengths: unpadded length of every sequence in ``X``.
        """
        embedded = self.embedding(X)
        # Packing makes the LSTM stop at each sequence's true length; the
        # lengths are moved to the CPU before being handed to the packer.
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, lengths=lengths.to("cpu"), enforce_sorted=False, batch_first=True)
        _, (hn, _) = self.lstm(packed)
        # hn[-1] is the final hidden state of the last LSTM layer.
        return self.linear(hn[-1])

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch."""
        x, y, lengths = batch
        y_hat = self(x, lengths)
        loss = self.loss_function(y_hat, y)
        self.log("Train loss", loss.detach())
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer using the configured learning rate."""
        # Previously hard-coded to 1e-2, which silently ignored the
        # `learning_rate` constructor argument (and any tuned value).
        return Adam(self.parameters(), lr=self.learning_rate)

    def test_step(self, batch, batch_idx):
        """Log loss and accuracy of one test batch."""
        x, y, lengths = batch
        y_hat = self(x, lengths)
        loss = self.loss_function(y_hat, y)
        labels_hat = torch.argmax(y_hat, dim=1)
        test_acc = (labels_hat == y).float().mean()
        self.log_dict({'test_loss': loss, 'test_acc': test_acc})

    def train_dataloader(self):
        """Shuffled loader over the notebook-level ``train_dataset``."""
        return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn)

    def test_dataloader(self):
        """Unshuffled loader over the notebook-level ``test_dataset``."""
        return DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:
# Instantiate the classifier; because a vocab is passed, the embedding layer
# is loaded from the GloVe vectors and frozen.
# NOTE(review): embedding_size=300 presumably matches the dimensionality of
# the loaded GloVe vectors -- confirm.
vocab_size = train_dataset.vocab_size()
model = LSTM(vocab_size, embedding_size=300, num_class=2, vocab=glove_vocab, vectors=glove_vec)

In [None]:
# xlist = []
# ylist = []
# for i in range(5):
#     xlist.append(train_dataset[i][0])
#     ylist.append(train_dataset[i][1])
# with torch.no_grad():
#     model.training_step([xlist, torch.tensor(ylist, dtype=torch.long)], 0)

In [None]:
# pad_sequence(xlist, batch_first=True)

In [None]:
# xlist[0]

In [None]:
# model.to('cpu')

In [None]:
# Hand-build a 5-sample batch to smoke-test the forward pass.
zlist = []
ylist = []
for i in range(5):
    # Index the dataset once per sample instead of once per field.
    sample = train_dataset[i]
    zlist.append(sample[0].clone())
    ylist.append(sample[1].clone())
z0 = pad_sequence(zlist, batch_first=True)
y = torch.tensor(ylist, dtype=torch.long)
# Unpadded lengths, needed by the model to pack the padded batch.
lengths = torch.tensor([len(seq) for seq in zlist], dtype=torch.long)


In [None]:
# for i in range(10007):
#     model.embedding(torch.LongTensor([1]))

In [None]:
# Run the hand-built batch through the model; the result is one row of class
# logits per sample.
model(z0, lengths)

In [None]:
# z1 = model.embedding(torch.tensor(z0))
# z1.shape

In [None]:
# _, (z2, _) = model.lstm(z1)
# z2.shape

In [None]:
# z3 = z2[0]
# z3

In [None]:
# z4 = z3#.transpose(0, 1)
# z4.shape

In [None]:
# z5 = z4#[:, -1:].flatten(1)
# z5.shape

In [None]:
# z6 = model.linear(z5)
# z6

In [None]:
# loss = nn.functional.cross_entropy(z6, y)
# loss

In [None]:
# def get_batch(idx, batch_size=10):
#     y = [train_dataset[idx*batch_size+i][1] for i in range(batch_size//2)]
#     x = [train_dataset[idx*batch_size+i][0] for i in range(batch_size//2)]

#     y += [train_dataset[12500+idx*batch_size+i][1] for i in range(batch_size//2)]
#     x += [train_dataset[12500+idx*batch_size+i][0] for i in range(batch_size//2)]
#     return pad_sequence(x, batch_first=True), torch.tensor(y, dtype=torch.long)

# i=0
# optimizer = Adam(model.parameters(), lr=0.001)
# epochs = 100
# batch_size = 10
# for e in range(epochs):
#     mean_loss = 0
#     for i in range(10):
#         if i%2:
#             batch = get_batch(i)
#         else:
#             batch = get_batch(i)
#         x, y = batch
#         y_hat = model(x)
#         optimizer.zero_grad()
#         loss = model.loss_function(y_hat, y)
#         loss.backward()
#         optimizer.step()
#         mean_loss += loss.detach()
#     print(mean_loss/10)


    
    # loss = model.loss_function(y_hat, y)


In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
# Log metrics for TensorBoard under the 'tb_logs' directory, run name 'lstm'.
logger = TensorBoardLogger('tb_logs', name='lstm')
# To inspect the logs: tensorboard --logdir tb_logs

In [None]:
# Check whether a CUDA device is visible before configuring the Trainer.
torch.cuda.is_available()


In [None]:
# Train on one GPU when CUDA is available and fall back to the CPU otherwise
# (a hard-coded gpus=1 makes the Trainer fail on CPU-only machines).
use_gpu = torch.cuda.is_available()
trainer = Trainer(
    max_epochs=10,
    gpus=1 if use_gpu else None,
    auto_select_gpus=use_gpu,
    auto_scale_batch_size=False,
    auto_lr_find=True,  # only has an effect when trainer.tune(model) is run
    logger=[logger],
    track_grad_norm=2,
    accumulate_grad_batches=8,
)


In [None]:
# trainer.tune(model)

In [None]:
# Batch size that the model's dataloader hooks will use.
model.batch_size

In [None]:
# Learning rate currently stored on the model.
model.learning_rate

In [None]:

# Evaluate the model before fitting, as a baseline; no loader is passed here.
trainer.test(model)

In [None]:
# Train the model; no loader is passed, so the data comes from the model's
# own dataloader hook.
trainer.fit(model)

In [None]:
# Evaluate again after training, this time on the notebook-level
# test_dataloader (batch_size=10) rather than the model's own hook.
trainer.test(model, test_dataloader)

In [None]:
# Sanity check: every token index must address a valid row of the embedding
# matrix, i.e. lie in 0 .. num_embeddings - 1.
num_embeddings = model.embedding.num_embeddings
for item in test_dataset:
    indices = item[0]
    if indices.numel() == 0:
        # max()/min() raise on empty tensors; there is nothing to check.
        continue
    highest, lowest = indices.max(), indices.min()
    # `>=` rather than `>`: an index equal to num_embeddings is already out
    # of range (the old hard-coded `> 10000` let index 10000 slip through).
    if highest >= num_embeddings:
        print(f"error: index {highest}")
    if lowest < 0:
        print(f"error: index {lowest}")


In [None]:
# for item in test_dataset:
#     try:
#         model(item[0].unsqueeze(0), lengths = torch.LongTensor(list(map(len, [item[0]]))))
#     except:
#         print(item[0])
#         wrong_item = item[0].clone()
#         break


In [None]:
# model.embedding(torch.LongTensor([10000]))