In [1]:
!pip install -Uqq datasets torchtext

In [3]:
from datasets import Dataset, load_dataset, get_dataset_split_names
ds_train, ds_val, ds_test = load_dataset("stanfordnlp/snli", split=['train', 'validation', 'test'])

Downloading readme:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/412k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/413k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/550152 [00:00<?, ? examples/s]

(Dataset({
     features: ['premise', 'hypothesis', 'label'],
     num_rows: 550152
 }),
 Dataset({
     features: ['premise', 'hypothesis', 'label'],
     num_rows: 10000
 }),
 Dataset({
     features: ['premise', 'hypothesis', 'label'],
     num_rows: 10000
 }))

In [4]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchtext.vocab import GloVe, build_vocab_from_iterator
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize

In [21]:
def preprocess_item(item):
    item['premise'] = word_tokenize(item['premise'].lower())
    item['hypothesis'] = word_tokenize(item['hypothesis'].lower())
    return item

# # load the datasets or if not available, preprocess them again. 
# import pickle
# try: 
#     with open('all_ds_tok.pickle', 'rb') as handle:
#         all_ds_tok = pickle.load(handle)
    
#     ds_train_tok = all_ds_tok[0]
#     ds_val_tok   = all_ds_tok[1]
#     ds_test_tok  = all_ds_tok[2]
# except: 
ds_train_tok = ds_train.map(preprocess_item)
ds_val_tok = ds_val.map(preprocess_item)
ds_test_tok = ds_test.map(preprocess_item)
all_ds_tok = [ds_train_tok, ds_val_tok, ds_test_tok]

# with open('all_ds_tok.pickle', 'wb') as handle:
#     pickle.dump(all_ds_tok, handle, protocol=pickle.HIGHEST_PROTOCOL)
        

Map:   0%|          | 0/550152 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [13]:
# get a list of all the unique tokens
def get_all_unique_toks(all_ds_tok):
    all_unique_toks = set()
    for ds in all_ds_tok:
        for item in tqdm.tqdm(ds):
            for key in ['premise', 'hypothesis']:
                for tok in item[key]:
                    if not tok in all_unique_toks:
                        all_unique_toks.add(tok)
    return all_unique_toks

In [14]:
import pickle
try: 
    with open('all_unique_toks.pickle', 'rb') as handle:
        all_unique_toks = pickle.load(handle)
except: 
    all_unique_toks = get_all_unique_toks(all_ds_tok)
    with open('all_unique_toks.pickle', 'wb') as handle:
        pickle.dump(all_unique_toks, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [15]:
glove_vectors = GloVe(name='840B')

In [16]:
glove_vocab = build_vocab_from_iterator(
    [iter(all_unique_toks)],
    specials=['<unk>'],
    special_first=True,
)

In [17]:
len(all_unique_toks), len(glove_vocab)

(37210, 37211)

In [22]:
def token_to_index(item):
    item['premise'] = glove_vocab.lookup_indices(item['premise'])    
    item['hypothesis'] = glove_vocab.lookup_indices(item['hypothesis'])
    return item

In [23]:
# load the datasets or if not available, convert to indices again. 
import pickle
try: 
    with open('all_ds_prep.pickle', 'rb') as handle:
        all_ds_prep = pickle.load(handle)
    ds_train_prep = all_ds_prep[0]
    ds_val_prep = all_ds_prep[1]
    ds_test_prep = all_ds_prep[2]
except: 
    # map to index and remove items with label -1
    ds_train_prep = ds_train_tok.map(token_to_index).filter(lambda x: x['label'] >= 0)
    ds_val_prep   = ds_val_tok.map(token_to_index).filter(lambda x: x['label'] >= 0)
    ds_test_prep  = ds_test_tok.map(token_to_index).filter(lambda x: x['label'] >= 0)
    all_ds_prep = [ds_train_prep, ds_val_prep, ds_test_prep]
    
    with open('all_ds_prep.pickle', 'wb') as handle:
        pickle.dump(all_ds_prep, handle, protocol=pickle.HIGHEST_PROTOCOL)

Map:   0%|          | 0/550152 [00:00<?, ? examples/s]

Filter:   0%|          | 0/550152 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [24]:
# find longest sentence
# max_sent_len = -1
# for ds in all_ds_prep:
#     for item in ds:
#         max_sent_len = max(
#             max_sent_len,
#             len(item['premise']),
#             len(item['hypothesis'])
#         )
# print(max_sent_len)

# max sentence length = 82
max_sent_len = 82

In [25]:
def collate_fn(data):
    # first find the longest sentences in the batch
    max_p, max_h = -1, -1
    for d in data:
        max_p = max(max_p, len(d['premise']))
        max_h = max(max_h, len(d['hypothesis']))
    
    # pad all sentences to same length
    bs = len(data)
    batch_p = torch.zeros((bs,max_p), dtype=torch.int32)
    batch_h = torch.zeros((bs,max_h), dtype=torch.int32)
    batch_l = torch.zeros((bs), dtype=torch.int64)
    
    for i, d in enumerate(data):
        batch_p[i, 0:len(d['premise'])] = torch.tensor(d['premise'])
        batch_h[i, 0:len(d['hypothesis'])] = torch.tensor(d['hypothesis'])
        batch_l[i] = d['label']
        
    return batch_p, batch_h, batch_l

In [26]:
# create dataloaders
from torch.utils.data import DataLoader
bs = 64

train_loader = DataLoader(ds_train_prep, collate_fn=collate_fn, batch_size=bs, shuffle=True)
val_loader   = DataLoader(ds_val_prep, collate_fn=collate_fn, batch_size=bs, shuffle=True)
test_loader  = DataLoader(ds_test_prep, collate_fn=collate_fn, batch_size=bs, shuffle=True)


In [27]:
class EmbeddingModule(nn.Module):
    def __init__(self):
        super().__init__()
        
        # create the embedding layer and freeze the weights
        self.embedding_layer = nn.Embedding.from_pretrained(
            glove_vectors.get_vecs_by_tokens(["<unk>", *list(all_unique_toks)]),
            freeze=True
        )
    
    def forward(self, x):
        # unpack
        premises, hypotheses, labels = x
        # embed both the premise and hypothesis separately
        premises = self.embedding_layer(premises)
        hypotheses = self.embedding_layer(hypotheses)
        return (premises, hypotheses, labels)

In [28]:
class BaselineEncoder(nn.Module):
    def __init__(self):
        super().__init__()
    
    # the baseline encoder takes the average across word embeddings in a sentence
    def forward(self, x):
        # unpack
        premises, hypotheses, labels = x
        premises = premises.mean(axis=1)
        hypotheses = hypotheses.mean(axis=1)
        return (premises, hypotheses, labels)
        

In [29]:
class CombinationModule(nn.Module):
    def __init__(self):
        super().__init__()
        
    # takes u and v (premise and hypothesis) 
    # and returns (u,v | u-v | u*v)
    def forward(self, x):
        # unpack
        u, v, labels = x
        out = torch.hstack([u,v,u-v,u*v])
        return out

In [30]:
class MLP(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 3),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        return self.layers(x)

In [45]:
class FullModel(nn.Module):
    def __init__(self, encoder, mlp):
        super().__init__()
        self.embedding_module = EmbeddingModule()
        self.encoder = encoder
        self.combination_module = CombinationModule()
        self.mlp = mlp
        self.model = nn.Sequential(
            self.embedding_module,
            self.encoder,
            self.combination_module,
            self.mlp
        )
    
    def forward(self, x):
        return self.model(x)
    
    def p(self):
        print(self.model[1:])
    
    def save_model(self, filename):
        torch.save(self.model[1:], filename)
    
    def load_model(self, filename):
        layers = torch.load(filename)
        for i, l in enumerate(layers):
            self.model[i+1] = l

In [None]:
fm = fm
fm

FullModel(
  (embedding_module): EmbeddingModule(
    (embedding_layer): Embedding(37211, 300)
  )
  (combination_module): CombinationModule()
  (model): Sequential(
    (0): EmbeddingModule(
      (embedding_layer): Embedding(37211, 300)
    )
    (1): GELU(approximate='none')
    (2): ReLU()
    (3): ReLU()
  )
)

In [55]:
m = nn.Sequential(nn.ReLU(), nn.ReLU(), nn.ReLU())
torch.save(m, "t.pickle")

In [None]:
baseline_model = nn.Sequential(
    EmbeddingModule(),
    BaselineEncoder(),
    CombinationModule(),
    MLP(in_dim=1200),
).cuda()

loss_module = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(baseline_model.parameters(), lr=0.1)


In [None]:
torch.save(baseline_model[1:], "test.pickle")

In [None]:
m = torch.load('test.pickle')

In [None]:
m

Sequential(
  (1): BaselineEncoder()
  (2): CombinationModule()
  (3): MLP(
    (layers): Sequential(
      (0): Linear(in_features=1200, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=3, bias=True)
      (3): Softmax(dim=1)
    )
  )
)

In [None]:
from tqdm import tqdm

def train_epoch(model, loader):
    model.train()
    losses, accs = [], []
    
    for premises, hypotheses, targets in tqdm(loader):

        hypotheses = hypotheses.cuda()
        targets = targets.cuda()
        premises = premises.cuda()
            
        predictions = baseline_model((premises, hypotheses, labels))
        
        optimizer.zero_grad()
        loss = loss_module(predictions, targets)
        acc = (predictions.argmax(axis=-1) == targets).float().mean()
        
        losses.append(loss)
        accs.append(acc)
        
        loss.backward()
        
        optimizer.step()
    
    train_loss = torch.tensor(losses).mean()
    train_acc = torch.tensor(accs).mean()
    return train_loss, train_acc


In [None]:
# measure accuracy
def evaluate(model, loader):
    model.eval()
    total_correct = 0.
    total = 0.
    bs = loader.batch_size
    
    
    for premises, hypotheses, targets in loader:
        premises = premises.cuda()
        hypotheses = hypotheses.cuda()
        targets = targets.cuda()
        
        with torch.no_grad():
            predictions = baseline_model((premises, hypotheses, labels)).argmax(axis=-1)
        total_correct += (predictions==targets).float().sum()
        total += bs
        
    acc = total_correct / total
    return acc

In [151]:
def update_lr(optimizer, new_lr):
    # update the learning rate for the optimizer
    for g in optimizer.param_groups:
        g['lr'] = new_lr

def train_loop(model, optimizer, train_loader, val_loader):
    lr = 0.1
    last_acc = -1
    epoch = -1
    
    while lr > 1e-5:
        epoch += 1
        print(f'lr: {lr}')
        #train and evaluate
        train_loss, train_acc = train_epoch(model, train_loader)
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Acc/train", train_acc, epoch)

        acc = evaluate(model, val_loader)
        
        writer.add_scalar("Acc/eval", acc, epoch)
        
        print(f'acc: {acc}')
        
        
        # learning rate decay
        lr = lr * 0.99
        update_lr(optimizer, lr)
        
        # if val acc goes down, divide lr by 5
        if acc < last_acc:
            lr = lr / 5.
            update_lr(optimizer, lr)
            
        last_acc = acc
    writer.flush()

In [152]:
# evaluate(baseline_model, val_loader)

In [153]:

%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

train_loop(baseline_model, optimizer, train_loader, val_loader)
writer.close()

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
lr: 0.1


100%|██████████| 8584/8584 [00:38<00:00, 220.33it/s]


acc: 0.3347199559211731
lr: 0.099


 81%|████████  | 6949/8584 [00:31<00:07, 220.20it/s]


KeyboardInterrupt: 

In [None]:
# Save MLP to file
torch.save(baseline_model[3], "baseline_model_MLP.pickle")


In [None]:
evaluate(baseline_model, test_loader)

tensor(0.3330, device='cuda:0')

In [None]:
base_mlp = torch.load("baseline_model_MLP.pickle")