In [1]:
!pip install -Uqq datasets torchtext

In [2]:
from datasets import Dataset, load_dataset, get_dataset_split_names
ds_train, ds_val, ds_test = load_dataset("stanfordnlp/snli", split=['train', 'validation', 'test'])

In [3]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchtext.vocab import GloVe, build_vocab_from_iterator
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
def preprocess_item(item):
    item['premise'] = word_tokenize(item['premise'].lower())
    item['hypothesis'] = word_tokenize(item['hypothesis'].lower())
    return item

# # load the datasets or if not available, preprocess them again. 
# import pickle
# try: 
#     with open('all_ds_tok.pickle', 'rb') as handle:
#         all_ds_tok = pickle.load(handle)
    
#     ds_train_tok = all_ds_tok[0]
#     ds_val_tok   = all_ds_tok[1]
#     ds_test_tok  = all_ds_tok[2]
# except: 
ds_train_tok = ds_train.map(preprocess_item)
ds_val_tok = ds_val.map(preprocess_item)
ds_test_tok = ds_test.map(preprocess_item)
all_ds_tok = [ds_train_tok, ds_val_tok, ds_test_tok]

# with open('all_ds_tok.pickle', 'wb') as handle:
#     pickle.dump(all_ds_tok, handle, protocol=pickle.HIGHEST_PROTOCOL)
        

In [5]:
# get a list of all the unique tokens
def get_all_unique_toks(all_ds_tok):
    all_unique_toks = set()
    for ds in all_ds_tok:
        for item in tqdm.tqdm(ds):
            for key in ['premise', 'hypothesis']:
                for tok in item[key]:
                    if not tok in all_unique_toks:
                        all_unique_toks.add(tok)
    return all_unique_toks

In [6]:
import pickle
try: 
    with open('all_unique_toks.pickle', 'rb') as handle:
        all_unique_toks = pickle.load(handle)
except: 
    all_unique_toks = get_all_unique_toks(all_ds_tok)
    with open('all_unique_toks.pickle', 'wb') as handle:
        pickle.dump(all_unique_toks, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [7]:
glove_vectors = GloVe(name='840B')

In [8]:
glove_vocab = build_vocab_from_iterator(
    [iter(all_unique_toks)],
    specials=['<unk>'],
    special_first=True,
)

In [9]:
len(all_unique_toks), len(glove_vocab)

(37210, 37211)

In [10]:
def token_to_index(item):
    item['premise'] = glove_vocab.lookup_indices(item['premise'])    
    item['hypothesis'] = glove_vocab.lookup_indices(item['hypothesis'])
    return item

In [11]:
# load the datasets or if not available, convert to indices again. 
import pickle
try: 
    with open('all_ds_prep.pickle', 'rb') as handle:
        all_ds_prep = pickle.load(handle)
    ds_train_prep = all_ds_prep[0]
    ds_val_prep = all_ds_prep[1]
    ds_test_prep = all_ds_prep[2]
except: 
    # map to index and remove items with label -1
    ds_train_prep = ds_train_tok.map(token_to_index).filter(lambda x: x['label'] >= 0)
    ds_val_prep   = ds_val_tok.map(token_to_index).filter(lambda x: x['label'] >= 0)
    ds_test_prep  = ds_test_tok.map(token_to_index).filter(lambda x: x['label'] >= 0)
    all_ds_prep = [ds_train_prep, ds_val_prep, ds_test_prep]
    
    with open('all_ds_prep.pickle', 'wb') as handle:
        pickle.dump(all_ds_prep, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
# find longest sentence
# max_sent_len = -1
# for ds in all_ds_prep:
#     for item in ds:
#         max_sent_len = max(
#             max_sent_len,
#             len(item['premise']),
#             len(item['hypothesis'])
#         )
# print(max_sent_len)

# max sentence length = 82
max_sent_len = 82

In [13]:
def collate_fn(data):
    # first find the longest sentences in the batch
    max_p, max_h = -1, -1
    for d in data:
        max_p = max(max_p, len(d['premise']))
        max_h = max(max_h, len(d['hypothesis']))
    
    # pad all sentences to same length
    bs = len(data)
    batch_p = torch.zeros((bs,max_p), dtype=torch.int32)
    batch_h = torch.zeros((bs,max_h), dtype=torch.int32)
    batch_l = torch.zeros((bs), dtype=torch.int64)
    
    for i, d in enumerate(data):
        batch_p[i, 0:len(d['premise'])] = torch.tensor(d['premise'])
        batch_h[i, 0:len(d['hypothesis'])] = torch.tensor(d['hypothesis'])
        batch_l[i] = d['label']
        
    return batch_p, batch_h, batch_l

In [14]:
# create dataloaders
from torch.utils.data import DataLoader
bs = 64

train_loader = DataLoader(ds_train_prep, collate_fn=collate_fn, batch_size=bs, shuffle=True)
val_loader   = DataLoader(ds_val_prep, collate_fn=collate_fn, batch_size=bs, shuffle=True)
test_loader  = DataLoader(ds_test_prep, collate_fn=collate_fn, batch_size=bs, shuffle=True)


In [74]:
class EmbeddingModule(nn.Module):
    def __init__(self):
        super().__init__()
        
        # create the embedding layer and freeze the weights
        self.embedding_layer = nn.Embedding.from_pretrained(
            glove_vectors.get_vecs_by_tokens(["<unk>", *list(all_unique_toks)]),
            freeze=True
        )
    
    def forward(self, x):
        # unpack
        premises, hypotheses, labels = x
        # embed both the premise and hypothesis separately
        premises = self.embedding_layer(premises)
        hypotheses = self.embedding_layer(hypotheses)
        return (premises, hypotheses, labels)

In [75]:
class BaselineEncoder(nn.Module):
    def __init__(self):
        super().__init__()
    
    # the baseline encoder takes the average across word embeddings in a sentence
    def forward(self, x):
        # unpack
        premises, hypotheses, labels = x
        premises = premises.mean(axis=1)
        hypotheses = hypotheses.mean(axis=1)
        return (premises, hypotheses, labels)
        

In [79]:
class CombinationModule(nn.Module):
    def __init__(self):
        super().__init__()
        
    # takes u and v (premise and hypothesis) 
    # and returns (u,v | u-v | u*v)
    def forward(self, x):
        # unpack
        u, v, labels = x
        out = torch.hstack([u,v,u-v,u*v])
        return out

In [77]:
class MLP(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 3),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        return self.layers(x)

In [78]:
class FullModel(nn.Module):
    def __init__(self, encoder, mlp):
        super().__init__()
        self.model = nn.ModuleList([
            EmbeddingModule(),
            encoder,
            CombinationModule(),
            mlp
        ])
    
    def forward(self, x):
        for m in self.model:
            x = m(x)
        return x
    
    def p(self):
        print(self.model[1:])
    
    def save_model(self, filename):
        torch.save(self.model[1:], filename)
    
    def load_model(self, filename):
        layers = torch.load(filename)
        for i, l in enumerate(layers):
            self.model[i+1] = l

In [29]:
from tqdm import tqdm

def train_epoch(model, loader):
    model.train()
    losses, accs = [], []
    
    for premises, hypotheses, targets in tqdm(loader):

        premises = premises.cuda()
        hypotheses = hypotheses.cuda()
        targets = targets.cuda()
            
        predictions = model((premises, hypotheses, targets))
        
        optimizer.zero_grad()
        loss = loss_module(predictions, targets)
        acc = (predictions.argmax(axis=-1) == targets).float().mean()
        
        losses.append(loss)
        accs.append(acc)
        
        loss.backward()
        
        optimizer.step()
    
    train_loss = torch.tensor(losses).mean()
    train_acc = torch.tensor(accs).mean()
    return train_loss, train_acc


In [30]:
# measure accuracy
def evaluate(model, loader):
    model.eval()
    total_correct = 0.
    total = 0.
    bs = loader.batch_size
    
    
    for premises, hypotheses, targets in loader:
        premises = premises.cuda()
        hypotheses = hypotheses.cuda()
        targets = targets.cuda()
        
        with torch.no_grad():
            predictions = model((premises, hypotheses, targets)).argmax(axis=-1)
        total_correct += (predictions==targets).float().sum()
        total += bs
        
    acc = total_correct / total
    return acc

In [38]:
def update_lr(optimizer, new_lr):
    # update the learning rate for the optimizer
    for g in optimizer.param_groups:
        g['lr'] = new_lr

def train_loop(model, optimizer, train_loader, val_loader, checkpoint_path):
    lr = 0.1
    last_acc, best_acc = -1, -1
    epoch = -1
    
    while lr > 1e-5:
        epoch += 1
        print(f'lr: {lr}')
        #train and evaluate
        train_loss, train_acc = train_epoch(model, train_loader)
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Acc/train", train_acc, epoch)

        acc = evaluate(model, val_loader)
        writer.add_scalar("Acc/eval", acc, epoch)
        print(f'acc: {acc}')
        
        
        # learning rate decay
        lr = lr * 0.99
        update_lr(optimizer, lr)
        
        # save best checkpoint if necessary
        if acc > best_acc:
            best_acc = acc
            model.save_model(checkpoint_path)
        
        # if val acc goes down, divide lr by 5
        if acc < last_acc:
            lr = lr / 5.
            update_lr(optimizer, lr)
            
        last_acc = acc
    writer.flush()
    
    # load the best model
    model.load_model(checkpoint_path)

In [39]:
# evaluate(baseline_model, val_loader)

In [193]:
class LSTMEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(300,300, batch_first=True)
    
    # the baseline encoder takes the average across word embeddings in a sentence
    def forward(self, x):
        # unpack
        premises, hypotheses, labels = x
        premises = self.lstm.forward(premises)[1][0][0]
        hypotheses =  self.lstm.forward(hypotheses)[1][0][0]
        return (premises, hypotheses, labels)

In [203]:
class BiLSTMEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(300,300, batch_first=True, bidirectional=True)
    
    # the baseline encoder takes the average across word embeddings in a sentence
    def forward(self, x):
        # unpack
        premises, hypotheses, labels = x
        
        premises = self.lstm.forward(premises)[1][0]
        premises = torch.hstack([premises[0], premises[1]])
        
        hypotheses =  self.lstm.forward(hypotheses)[1][0]
        hypotheses = torch.hstack([hypotheses[0], hypotheses[1]])
        
        return (premises, hypotheses, labels)

In [222]:
class PooledBiLSTMEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(300,300, batch_first=True, bidirectional=True)
    
    # the baseline encoder takes the average across word embeddings in a sentence
    def forward(self, x):
        # unpack
        premises, hypotheses, labels = x
        
        premises = self.lstm.forward(premises)[0].max(dim=1)[0]
        hypotheses =  self.lstm.forward(hypotheses)[0].max(dim=1)[0]
    
        return (premises, hypotheses, labels)

In [223]:
# setup tensorboard
%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()


model = FullModel(
    # BaselineEncoder(),
    PooledBiLSTMEncoder(),
    MLP(in_dim=2400),
).cuda()

print(model)


loss_module = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)



# train loop
train_loop(
    model, 
    optimizer, 
    test_loader, 
    val_loader,
    "checkpoints/baseline_model.pickle",
)


writer.close()

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
FullModel(
  (model): ModuleList(
    (0): EmbeddingModule(
      (embedding_layer): Embedding(37211, 300)
    )
    (1): PooledBiLSTMEncoder(
      (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)
    )
    (2): CombinationModule()
    (3): MLP(
      (layers): Sequential(
        (0): Linear(in_features=2400, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=3, bias=True)
        (3): Softmax(dim=1)
      )
    )
  )
)
lr: 0.1


100%|██████████| 154/154 [00:01<00:00, 119.77it/s]


acc: 0.4104098975658417
lr: 0.099


100%|██████████| 154/154 [00:01<00:00, 123.07it/s]


acc: 0.4139610528945923
lr: 0.09801


100%|██████████| 154/154 [00:01<00:00, 122.58it/s]


acc: 0.4126420319080353
lr: 0.01940598


 53%|█████▎    | 82/154 [00:00<00:00, 122.13it/s]


KeyboardInterrupt: 

In [197]:
t = torch.arange((16)).reshape(2,2,4)
print(t)
torch.hstack([t[0],t[1]])

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])


tensor([[ 0,  1,  2,  3,  8,  9, 10, 11],
        [ 4,  5,  6,  7, 12, 13, 14, 15]])

In [None]:
torch.save(model[1:], 


In [49]:
evaluate(model, val_loader)

tensor(0.6118, device='cuda:0')

In [53]:
model.save_model('checkpoints/baseline_model.pickle')

In [55]:
!ls -l

total 2926
-rw-r--r--  1 root root       6 Apr 20 12:51 README.md
-rw-r--r--  1 root root    5324 Apr 21 20:53 all_ds_prep.pickle
-rw-r--r--  1 root root    4503 Apr 21 15:26 all_ds_tok.pickle
-rw-r--r--  1 root root  392430 Apr 20 15:47 all_unique_toks.pickle
-rw-r--r--  1 root root 2469544 Apr 21 20:38 baseline_model_MLP.pickle
drwxr-xr-x  2 root root       2 Apr 21 21:59 checkpoints
-rw-r--r--  1 root root   26449 Apr 21 22:00 preprocessing-Copy1.ipynb
-rw-r--r--  1 root root   36479 Apr 21 21:31 preprocessing.ipynb
drwxr-xr-x 23 root root      21 Apr 21 21:39 runs
drwxr-xr-x  2 root root       0 Apr 21 21:57 saved_models
-rw-r--r--  1 root root   58619 Apr 20 15:06 test.ipynb


In [56]:
torch.save(model.model[1:], 'test.pickle')

In [61]:
model.load_model('saved_models/baseline_model.pickle')
evaluate(model, val_loader)

tensor(0.6118, device='cuda:0')