In [1]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from datasets import load_dataset
import re
import string
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from math import floor
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchtext import vocab
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

#### Download sst dataset

In [4]:
sst_dataset_init = load_dataset('sst')

Downloading builder script:   0%|          | 0.00/9.13k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/5.99k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.67k [00:00<?, ?B/s]



Downloading and preparing dataset sst/default to /root/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff...


Downloading data:   0%|          | 0.00/6.37M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/790k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8544 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1101 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2210 [00:00<?, ? examples/s]

Dataset sst downloaded and prepared to /root/.cache/huggingface/datasets/sst/default/1.0.0/b8a7889ef01c5d3ae8c379b84cc4080f8aad3ac2bc538701cbe0ac6416fb76ff. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
sst_dataset_init

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'tokens', 'tree'],
        num_rows: 8544
    })
    validation: Dataset({
        features: ['sentence', 'label', 'tokens', 'tree'],
        num_rows: 1101
    })
    test: Dataset({
        features: ['sentence', 'label', 'tokens', 'tree'],
        num_rows: 2210
    })
})

In [6]:
sst_train_len = 8544
sst_valid_len = 1101
sst_test_len = 2210

#### Download nli dataset

In [4]:
nli_dataset_init = load_dataset('multi_nli')
nli_dataset_init



  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label'],
        num_rows: 9832
    })
})

In [8]:
nli_train_len = 392702
nli_train_len_new = 40000
nli_valid_len = 9815
nli_test_len = 9832

#### Preprocess sst dataset

In [None]:
class Preprocess_sst():
    def __init__(self, data):
        self.data = data
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        self.preprocessed_data = list()
        self.unk = '<UNK>'
        self.pad = '<PAD>'
        self.min_freq = 5

    def build_vocab(self):
        self.words = [[word] for sent in self.preprocessed_data for word in sent[0]]
        self.vocab = build_vocab_from_iterator(self.words, min_freq = self.min_freq, specials = [self.unk, self.pad])
        self.vocab.set_default_index(self.vocab[self.unk])

    def modify_label(self):
        self.label *= 10
        self.label = floor(self.label)
        self.label /= 10

    def rem_punctuations(self):
        self.text = self.text.translate(str.maketrans('', '', string.punctuation)).lower()

    def tokenise(self):
        self.text = word_tokenize(self.text)

    def rem_stopwords(self):
        self.text = [word for word in self.text if word not in self.stop_words]

    def rem_single(self):
        self.text = [word for word in self.text if not re.match(r'\d+', word) and len(word) > 1]

    def lemmatiser(self):
        self.text = [self.lemmatizer.lemmatize(word) for word in self.text]
    
    def caller(self):
        self.modify_label()
        self.rem_punctuations()
        self.tokenise()
        self.rem_stopwords()
        self.rem_single()
        self.lemmatiser()
        
    def main(self):
        for example in self.data['train']:
            self.label = example['label']
            self.text = example['sentence']
            self.caller()
            self.preprocessed_data.append((self.text, self.label))

        for example in self.data['validation']:
            self.label = example['label']
            self.text = example['sentence']
            self.caller()
            self.preprocessed_data.append((self.text, self.label))

        for example in self.data['test']:
            self.label = example['label']
            self.text = example['sentence']
            self.caller()
            self.preprocessed_data.append((self.text, self.label))

        self.build_vocab()

In [None]:
# preprocessing and getting datasets
preprocesser_sst = Preprocess_sst(sst_dataset_init)
preprocesser_sst.main()
sst_dataset = preprocesser_sst.preprocessed_data
vocabulary_sst = preprocesser_sst.vocab

In [None]:
len(sst_dataset)

11855

In [None]:
print(vocabulary_sst['audience'])

41


#### Preprocess nli dataset

In [9]:
class Preprocess_nli():
    def __init__(self, data):
        self.data = data
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        self.preprocessed_data = list()
        self.unk = '<UNK>'
        self.pad = '<PAD>'
        self.min_freq = 5

    def build_vocab(self):
        self.words = [[word] for sent in self.preprocessed_data for word in sent[0]]
        self.vocab = build_vocab_from_iterator(self.words, min_freq = self.min_freq, specials = [self.unk, self.pad])
        self.vocab.set_default_index(self.vocab[self.unk])

    def rem_punctuations(self):
        self.text = self.text.translate(str.maketrans('', '', string.punctuation)).lower()

    def tokenise(self):
        self.text = word_tokenize(self.text)

    def rem_stopwords(self):
        self.text = [word for word in self.text if word not in self.stop_words]

    def rem_single(self):
        self.text = [word for word in self.text if not re.match(r'\d+', word) and len(word) > 1]

    def lemmatiser(self):
        self.text = [self.lemmatizer.lemmatize(word) for word in self.text]
    
    def caller(self):
        self.rem_punctuations()
        self.tokenise()
        self.rem_stopwords()
        self.rem_single()
        self.lemmatiser()
        
    def main(self):
        for ind, example in enumerate(self.data['train']):
            if ind > nli_train_len_new:
                break
            self.label = example['label']
            self.text = example['premise']
            self.caller()
            self.preprocessed_data.append((self.text, self.label))

        for example in self.data['validation_matched']:
            self.label = example['label']
            self.text = example['premise']
            self.caller()
            self.preprocessed_data.append((self.text, self.label))

        for example in self.data['validation_mismatched']:
            self.label = example['label']
            self.text = example['premise']
            self.caller()
            self.preprocessed_data.append((self.text, self.label))

        self.build_vocab()

In [10]:
# preprocessing and getting datasets
preprocesser_nli = Preprocess_nli(nli_dataset_init)
preprocesser_nli.main()
nli_dataset = preprocesser_nli.preprocessed_data
vocabulary_nli = preprocesser_nli.vocab

In [11]:
print(len(nli_dataset))

59648


#### Create dataset for sentiment classification

In [None]:
# perform dataset creation for pretraining and sentiment classification
def create_data_for_sst_classification():
    tokens = []
    labels = []
    toks = set()
    pad = [vocabulary_sst['<PAD>']]
    max_len = 0
    for sent, label in sst_dataset:
        ind = [vocabulary_sst[word] for word in sent]
        tokens.append(ind)
        if label < 0.5:
            labels.append(0)
        else:
            labels.append(1)
        max_len = max(max_len, len(ind))
    tokens = list(sent + pad * (max_len-len(sent)) for sent in tokens)
    tokens = torch.tensor(tokens)
    labels = torch.tensor(labels)
    print(toks)
    return tokens, labels
sst_tokens, sst_labels = create_data_for_sst_classification() 
sst_tokens_train, sst_tokens_valid, sst_tokens_test = sst_tokens[:sst_train_len], sst_tokens[sst_train_len:sst_valid_len], sst_tokens[sst_train_len + sst_valid_len:]
sst_labels_train, sst_labels_valid, sst_labels_test = sst_labels[:sst_train_len], sst_labels[sst_train_len:sst_valid_len], sst_labels[sst_train_len + sst_valid_len:]
print(sst_tokens_train.shape, sst_labels_train.shape)

set()
torch.Size([8544, 28]) torch.Size([8544])


#### create dataset for NLI

In [14]:
# perform dataset creation for pretraining and sentiment classification
def create_data_for_nli_classification():
    tokens = []
    labels = []
    l = set()
    pad = [vocabulary_nli['<PAD>']]
    max_len = 0
    for sent, label in nli_dataset:
        ind = [vocabulary_nli[word] for word in sent]
        tokens.append(ind)
        labels.append(label)
        l.add(label)
        max_len = max(max_len, len(ind))
    tokens = list(sent + pad * (max_len-len(sent)) for sent in tokens)
    tokens = torch.tensor(tokens)
    labels = torch.tensor(labels)
    print(l)
    return tokens, labels
nli_tokens, nli_labels = create_data_for_nli_classification() 
nli_tokens_train, nli_tokens_valid, nli_tokens_test = nli_tokens[:nli_train_len_new], nli_tokens[nli_train_len:nli_valid_len], nli_tokens[nli_train_len + nli_valid_len:]
nli_labels_train, nli_labels_valid, nli_labels_test = nli_labels[:nli_train_len_new], nli_labels[nli_train_len:nli_valid_len], nli_labels[nli_train_len + nli_valid_len:]
print(nli_tokens_train.shape, nli_labels_train.shape)

{0, 1, 2}
torch.Size([40000, 117]) torch.Size([40000])


#### Create dataset using sst for elmo pretraining

In [None]:
class Create_data_for_pretraining():
    def __init__(self):
        self.preds = []
        self.contexts = []
        self.build()

    def build(self):
        pad = [vocabulary_sst['<PAD>']]
        max_len = 0
        for sent, label in sst_dataset:
            ind = [vocabulary_sst[word] for word in sent]
            l = len(ind)
            self.contexts += [ind[:i] for i in range(1, l)]
            self.preds += ind[1:]

            ind.reverse()
            self.contexts += [ind[:i] for i in range(1, l)]
            self.preds += ind[1:]

            max_len = max(max_len, l-1)

        self.contexts = list(sent + pad*(max_len - len(sent) ) for sent in self.contexts)
        self.contexts = torch.tensor(self.contexts)
        self.preds = torch.tensor(self.preds)
    
    def __getitem__(self, index):
        return self.contexts[index], self.preds[index]

    def __len__(self):
        return self.contexts.shape[0]

sst_word_pred_dataset = Create_data_for_pretraining()

#### Create dataset using nli for elmo pretraining

In [16]:
class Create_data_for_pretraining_nli():
    def __init__(self):
        self.preds = []
        self.contexts = []
        self.build()

    def build(self):
        pad = [vocabulary_nli['<PAD>']]
        max_len = 0
        for sent, label in nli_dataset:
            ind = [vocabulary_nli[word] for word in sent]
            l = len(ind)
            self.contexts += [ind[:i] for i in range(1, l)]
            self.preds += ind[1:]

            ind.reverse()
            self.contexts += [ind[:i] for i in range(1, l)]
            self.preds += ind[1:]

            max_len = max(max_len, l-1)

        self.contexts = list(sent + pad*(max_len - len(sent) ) for sent in self.contexts)
        self.contexts = torch.tensor(self.contexts)
        self.preds = torch.tensor(self.preds)
    
    def __getitem__(self, index):
        return self.contexts[index], self.preds[index]

    def __len__(self):
        return self.contexts.shape[0]

nli_word_pred_dataset = Create_data_for_pretraining_nli()

#### for dataloader of sst

In [None]:
class Sst_class_train():
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, index):
        return self.tokens[index], self.labels[index]

    def __len__(self):
        return self.tokens.shape[0]

class Sst_class_valid():
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, index):
        return self.tokens[index], self.labels[index]

    def __len__(self):
        return self.tokens.shape[0]

class Sst_class_test():
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, index):
        return self.tokens[index], self.labels[index]

    def __len__(self):
        return self.tokens.shape[0]

sst_class_train = Sst_class_train(sst_tokens_train, sst_labels_train)
sst_class_valid = Sst_class_valid(sst_tokens_valid, sst_labels_valid)
sst_class_test = Sst_class_test(sst_tokens_test, sst_labels_test)

#### for dataloader of nli

In [18]:
class Nli_class_train():
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, index):
        return self.tokens[index], self.labels[index]

    def __len__(self):
        return self.tokens.shape[0]

class Nli_class_valid():
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, index):
        return self.tokens[index], self.labels[index]

    def __len__(self):
        return self.tokens.shape[0]

class Nli_class_test():
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, index):
        return self.tokens[index], self.labels[index]

    def __len__(self):
        return self.tokens.shape[0]

nli_class_train = Nli_class_train(nli_tokens_train, nli_labels_train)
nli_class_valid = Nli_class_valid(nli_tokens_valid, nli_labels_valid)
nli_class_test = Nli_class_test(nli_tokens_test, nli_labels_test)

## ELMo

#### Glove embeddings load

In [19]:
embedding_dim = 300
glove_embeds = vocab.GloVe(name='6B', dim=embedding_dim)
glove_vectors = glove_embeds.vectors
unk = torch.mean(glove_vectors, dim=0)

.vector_cache/glove.6B.zip: 862MB [02:39, 5.42MB/s]                           
100%|█████████▉| 399999/400000 [01:00<00:00, 6577.74it/s]


In [21]:
embeddings = list()
for word in vocabulary_nli.get_itos():
    if word not in glove_embeds.itos:
        embeddings.append(unk)
    else:
        w = glove_embeds[word]
        embeddings.append(w)

In [22]:
len(embeddings[0])

300

#### architecture

In [23]:
class Elmo(nn.Module):
    def __init__(self, vocab, embeddings, hidden_size=300):
        super().__init__()
        self.num_layers = 2
        self.hidden_size = 300
        self.vocab = vocab
        self.vocab_len = len(vocab)
        self.pad = self.vocab['<PAD>']
        # self.labels_num = 2 # for sst
        self.labels_num = 3 # for NLI

        # self.embedding = nn.Embedding.from_pretrained(embeddings, freeze=True, padding_idx = self.pad)
        self.embedding = nn.Embedding.from_pretrained(embeddings, freeze=True)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, num_layers = self.num_layers, bidirectional=True, batch_first = True)
        # self.weights = torch.tensor([0.5, 0.5])
        self.weights = torch.rand(2)
        self.softmax = nn.Softmax(dim=0)
        self.sigmoid = nn.Sigmoid()

        in_features = self.hidden_size*2
        self.word_pred = nn.Linear(in_features, self.vocab_len)
        # in_features = self.hidden_size*4
        self.classifier = nn.Linear(in_features, self.labels_num)

    def forward(self, x, label):
        y = self.embedding(x)

        a, (y, b) = self.lstm(y)

        shape = y.shape
        y = y.transpose(0,1)
        y = y.view(shape[1], 2, self.num_layers, shape[2])

        a = self.softmax(self.weights)
        a = a.unsqueeze(1).unsqueeze(0).unsqueeze(0)
        a = a.repeat(shape[1], 2, 1, shape[2]).to(device)
        
        y = torch.mul(y, a)
        y = torch.sum(y, dim=2)
        y = y.flatten(1,2)

        return y

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device available now:', device)

Device available now: cuda


#### pretraining sst

In [None]:
elmo = Elmo(vocabulary_sst,torch.stack(embeddings))

In [None]:
data = DataLoader(sst_word_pred_dataset, batch_size = 128)

In [None]:
def pretrain_elmo_sst(model, data, num_epochs = 10):
    model = model.to(device)
    pad = vocabulary_sst['<PAD>']
    # loss_fn = nn.CrossEntropyLoss(ignore_index=pad).to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in tqdm(range(num_epochs), desc='epoch'):
        model.train()
        total_loss = 0
        for ind, i in enumerate(data):
            x, label = i
            x = x.to(device)
            label = label.to(device)
            optim.zero_grad()
            # ans, ans2 = model(x, label)
            ans = model(x, label)
            ans = model.word_pred(ans)
            loss = loss_fn(ans, label)
            loss.backward()
            optim.step()
            total_loss += loss.item()
        print(f'\tEpoch {epoch + 1}\tTrain Loss: {total_loss/len(data)}')
    return model

elmo = pretrain_elmo_sst(elmo, data)

epoch:  10%|█         | 1/10 [00:48<07:20, 48.93s/it]

	Epoch 1	Train Loss: 6.504847591503579


epoch:  20%|██        | 2/10 [01:37<06:30, 48.80s/it]

	Epoch 2	Train Loss: 6.206489018735153


epoch:  30%|███       | 3/10 [02:26<05:41, 48.74s/it]

	Epoch 3	Train Loss: 5.951812713357077


epoch:  40%|████      | 4/10 [03:15<04:52, 48.75s/it]

	Epoch 4	Train Loss: 5.679752173771621


epoch:  50%|█████     | 5/10 [04:03<04:03, 48.76s/it]

	Epoch 5	Train Loss: 5.406486697317017


epoch:  60%|██████    | 6/10 [04:52<03:15, 48.86s/it]

	Epoch 6	Train Loss: 5.137254346178915


epoch:  70%|███████   | 7/10 [05:41<02:26, 48.79s/it]

	Epoch 7	Train Loss: 4.909399019817601


epoch:  80%|████████  | 8/10 [06:30<01:37, 48.72s/it]

	Epoch 8	Train Loss: 4.668876021965771


epoch:  90%|█████████ | 9/10 [07:18<00:48, 48.65s/it]

	Epoch 9	Train Loss: 4.448533264878028


epoch: 100%|██████████| 10/10 [08:07<00:00, 48.72s/it]

	Epoch 10	Train Loss: 4.232246256782749





In [None]:
torch.save(elmo.state_dict(), 'elmo_pretrain_sst.pth')

In [None]:
elmo.load_state_dict(torch.load('/content/elmo_pretrain_sst.pth'))
elmo

Elmo(
  (vocab): Vocab()
  (embedding): Embedding(4107, 300)
  (lstm): LSTM(300, 300, num_layers=2, batch_first=True, bidirectional=True)
  (softmax): Softmax(dim=0)
  (sigmoid): Sigmoid()
  (word_pred): Linear(in_features=600, out_features=4107, bias=True)
  (classifier): Linear(in_features=600, out_features=2, bias=True)
)

In [None]:
for param in elmo.lstm.parameters():
    param.requires_grad = False

#### pretraining NLI

In [25]:
elmo = Elmo(vocabulary_nli,torch.stack(embeddings))
data = DataLoader(nli_word_pred_dataset, batch_size = 128)

In [28]:
def pretrain_elmo_nli(model, data, num_epochs = 1):
    model = model.to(device)
    pad = vocabulary_nli['<PAD>']
    # loss_fn = nn.CrossEntropyLoss(ignore_index=pad).to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in tqdm(range(num_epochs), desc='epoch'):
        model.train()
        total_loss = 0
        for ind, i in enumerate(data):
            print(ind/len(data))
            x, label = i
            x = x.to(device)
            label = label.to(device)
            optim.zero_grad()
            # ans, ans2 = model(x, label)
            ans = model(x, label)
            ans = model.word_pred(ans)
            loss = loss_fn(ans, label)
            loss.backward()
            optim.step()
            total_loss += loss.item()
        print(f'\tEpoch {epoch + 1}\tTrain Loss: {total_loss/len(data)}')
    return model

elmo = pretrain_elmo_nli(elmo, data)

epoch:   0%|          | 0/1 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
0.43464495703301675
0.4347580280416101
0.43487109905020355
0.4349841700587969
0.43509724106739034
0.4352103120759837
0.43532338308457713
0.4354364540931705
0.4355495251017639
0.4356625961103573
0.4357756671189507
0.4358887381275441
0.4360018091361375
0.4361148801447309
0.4362279511533243
0.43634102216191767
0.4364540931705111
0.43656716417910446
0.4366802351876979
0.43679330619629125
0.4369063772048847
0.43701944821347805
0.43713251922207147
0.43724559023066484
0.43735866123925826
0.43747173224785163
0.43758480325644505
0.4376978742650384
0.43781094527363185
0.4379240162822252
0.43803708729081864
0.438150158299412
0.43826322930800543
0.4383763003165988
0.4384893713251922
0.4386024423337856
0.438715513342379
0.4388285843509724
0.4389416553595658
0.43905472636815923
0.4391677973767526
0.439280868385346
0.4393939393939394
0.4395070104025328
0.4396200814111262
0.4397331524197196
0.439846223428313
0.4399592944369064
0.44007236

epoch: 100%|██████████| 1/1 [20:39<00:00, 1239.61s/it]

	Epoch 1	Train Loss: 6.935135268615091





In [29]:
torch.save(elmo.state_dict(), 'elmo_pretrain_nli.pth')

In [30]:
elmo.load_state_dict(torch.load('/content/elmo_pretrain_nli.pth'))
elmo

Elmo(
  (vocab): Vocab()
  (embedding): Embedding(12089, 300)
  (lstm): LSTM(300, 300, num_layers=2, batch_first=True, bidirectional=True)
  (softmax): Softmax(dim=0)
  (sigmoid): Sigmoid()
  (word_pred): Linear(in_features=600, out_features=12089, bias=True)
  (classifier): Linear(in_features=600, out_features=3, bias=True)
)

In [31]:
for param in elmo.lstm.parameters():
    param.requires_grad = False

#### Sentiment Classification

In [None]:
def elmo_sst(model, data, num_epochs = 50):
    model = model.to(device)
    pad = vocabulary_sst['<PAD>']
    loss_fn = nn.CrossEntropyLoss().to(device)
    # loss_fn = nn.BCELoss().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in tqdm(range(num_epochs), desc='epoch'):
        model.train()
        total_loss = 0
        for ind, i in enumerate(data):
            x, label = i
            x = x.to(device)
            label = label.to(device)
            optim.zero_grad()
            # x_f = torch.cat((x, label), dim=1)
            # print(x_f.type())
            # x_r = torch.cat((torch.flip(x, [1]), label), dim=1)
            # print(x_r.type())
            # inp = torch.cat((x_f, x_r), dim=1)
            # print(inp.type())
            ans = model(x, label)
            # print(ans.size(), ans.type())
            # ans_r = torch.concat((torch.flip(ans[0], [1]), ans[1]))
            # print(ans_r.size(), ans_r.type())
            # break
            # inp = torch.cat((ans, ans_r), dim=1)
            ans = model.classifier(ans)
            ans = model.softmax(ans)
            # print(ans)
            # ans = ans.argmax(dim=1)
            # print(ans.size(), label.size())
            # print(ans)
            # print(label.size())
            loss = loss_fn(ans, label)
            loss.backward()
            optim.step()
            total_loss += loss.item()
        print(f'\tEpoch {epoch + 1}\tTrain Loss: {total_loss/len(data)}')
    return model

data = DataLoader(sst_class_train, batch_size=4)
elmo = elmo_sst(elmo, data)

epoch:   2%|▏         | 1/50 [00:07<06:13,  7.61s/it]

	Epoch 1	Train Loss: 0.6967699380682202


epoch:   4%|▍         | 2/50 [00:13<05:25,  6.79s/it]

	Epoch 2	Train Loss: 0.6881806696677922


epoch:   6%|▌         | 3/50 [00:20<05:16,  6.73s/it]

	Epoch 3	Train Loss: 0.6829786649469133


epoch:   8%|▊         | 4/50 [00:26<05:00,  6.54s/it]

	Epoch 4	Train Loss: 0.6797317215119408


epoch:  10%|█         | 5/50 [00:33<04:54,  6.54s/it]

	Epoch 5	Train Loss: 0.6773143428123176


epoch:  12%|█▏        | 6/50 [00:39<04:45,  6.49s/it]

	Epoch 6	Train Loss: 0.6753720556166065


epoch:  14%|█▍        | 7/50 [00:46<04:40,  6.53s/it]

	Epoch 7	Train Loss: 0.6737682719885857


epoch:  16%|█▌        | 8/50 [00:52<04:34,  6.53s/it]

	Epoch 8	Train Loss: 0.672411624579394


epoch:  18%|█▊        | 9/50 [00:59<04:27,  6.51s/it]

	Epoch 9	Train Loss: 0.67123881186104


epoch:  20%|██        | 10/50 [01:05<04:22,  6.56s/it]

	Epoch 10	Train Loss: 0.6702124692415923


epoch:  22%|██▏       | 11/50 [01:12<04:13,  6.51s/it]

	Epoch 11	Train Loss: 0.6693029615344402


epoch:  24%|██▍       | 12/50 [01:19<04:10,  6.60s/it]

	Epoch 12	Train Loss: 0.6684834702491537


epoch:  26%|██▌       | 13/50 [01:25<04:00,  6.51s/it]

	Epoch 13	Train Loss: 0.6677340390056037


epoch:  28%|██▊       | 14/50 [01:32<03:57,  6.60s/it]

	Epoch 14	Train Loss: 0.6670413893902123


epoch:  30%|███       | 15/50 [01:38<03:47,  6.49s/it]

	Epoch 15	Train Loss: 0.6663968017820115


epoch:  32%|███▏      | 16/50 [01:45<03:49,  6.76s/it]

	Epoch 16	Train Loss: 0.6657941920703716


epoch:  34%|███▍      | 17/50 [01:52<03:38,  6.62s/it]

	Epoch 17	Train Loss: 0.6652286157486368


epoch:  36%|███▌      | 18/50 [01:59<03:33,  6.69s/it]

	Epoch 18	Train Loss: 0.6646956426513552


epoch:  38%|███▊      | 19/50 [02:05<03:23,  6.55s/it]

	Epoch 19	Train Loss: 0.6641912354009875


epoch:  40%|████      | 20/50 [02:12<03:19,  6.63s/it]

	Epoch 20	Train Loss: 0.6637115963221936


epoch:  42%|████▏     | 21/50 [02:18<03:09,  6.53s/it]

	Epoch 21	Train Loss: 0.6632535264649418


epoch:  44%|████▍     | 22/50 [02:25<03:05,  6.62s/it]

	Epoch 22	Train Loss: 0.6628141780052516


epoch:  46%|████▌     | 23/50 [02:31<02:55,  6.51s/it]

	Epoch 23	Train Loss: 0.6623905888350492


epoch:  48%|████▊     | 24/50 [02:38<02:51,  6.59s/it]

	Epoch 24	Train Loss: 0.6619799536154065


epoch:  50%|█████     | 25/50 [02:44<02:42,  6.49s/it]

	Epoch 25	Train Loss: 0.6615801443265619


epoch:  52%|█████▏    | 26/50 [02:51<02:36,  6.54s/it]

	Epoch 26	Train Loss: 0.6611899862603079


epoch:  54%|█████▍    | 27/50 [02:57<02:31,  6.59s/it]

	Epoch 27	Train Loss: 0.660809151967813


epoch:  56%|█████▌    | 28/50 [03:04<02:25,  6.59s/it]

	Epoch 28	Train Loss: 0.660437782884537


epoch:  58%|█████▊    | 29/50 [03:10<02:17,  6.54s/it]

	Epoch 29	Train Loss: 0.6600761673228348


epoch:  60%|██████    | 30/50 [03:17<02:11,  6.57s/it]

	Epoch 30	Train Loss: 0.6597247233244588


epoch:  62%|██████▏   | 31/50 [03:24<02:04,  6.57s/it]

	Epoch 31	Train Loss: 0.6593840764312262


epoch:  64%|██████▍   | 32/50 [03:30<01:59,  6.64s/it]

	Epoch 32	Train Loss: 0.6590547821718209


epoch:  66%|██████▌   | 33/50 [03:37<01:53,  6.65s/it]

	Epoch 33	Train Loss: 0.658736831407422


epoch:  68%|██████▊   | 34/50 [03:43<01:44,  6.52s/it]

	Epoch 34	Train Loss: 0.6584293439091368


epoch:  70%|███████   | 35/50 [03:50<01:38,  6.56s/it]

	Epoch 35	Train Loss: 0.6581310647703735


epoch:  72%|███████▏  | 36/50 [03:56<01:30,  6.46s/it]

	Epoch 36	Train Loss: 0.6578414118552253


epoch:  74%|███████▍  | 37/50 [04:03<01:24,  6.53s/it]

	Epoch 37	Train Loss: 0.6575607473958521


epoch:  76%|███████▌  | 38/50 [04:09<01:17,  6.43s/it]

	Epoch 38	Train Loss: 0.657289230184497


epoch:  78%|███████▊  | 39/50 [04:16<01:11,  6.54s/it]

	Epoch 39	Train Loss: 0.6570261407657509


epoch:  80%|████████  | 40/50 [04:22<01:04,  6.47s/it]

	Epoch 40	Train Loss: 0.6567703670079119


epoch:  82%|████████▏ | 41/50 [04:29<00:58,  6.55s/it]

	Epoch 41	Train Loss: 0.6565208927578248


epoch:  84%|████████▍ | 42/50 [04:35<00:51,  6.46s/it]

	Epoch 42	Train Loss: 0.6562768251875813


epoch:  86%|████████▌ | 43/50 [04:42<00:45,  6.54s/it]

	Epoch 43	Train Loss: 0.6560374804622449


epoch:  88%|████████▊ | 44/50 [04:48<00:38,  6.43s/it]

	Epoch 44	Train Loss: 0.6558024032732074


epoch:  90%|█████████ | 45/50 [04:55<00:32,  6.50s/it]

	Epoch 45	Train Loss: 0.6555713466211651


epoch:  92%|█████████▏| 46/50 [05:01<00:25,  6.39s/it]

	Epoch 46	Train Loss: 0.6553442458088479


epoch:  94%|█████████▍| 47/50 [05:08<00:19,  6.49s/it]

	Epoch 47	Train Loss: 0.6551210884744308


epoch:  96%|█████████▌| 48/50 [05:14<00:12,  6.43s/it]

	Epoch 48	Train Loss: 0.6549018477735001


epoch:  98%|█████████▊| 49/50 [05:21<00:06,  6.71s/it]

	Epoch 49	Train Loss: 0.6546864190212126


epoch: 100%|██████████| 50/50 [05:27<00:00,  6.56s/it]

	Epoch 50	Train Loss: 0.6544746313704534





In [44]:
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
def get_stats(mdl, dl):
    y_true = []
    y_pred = []
    for batch in dl:
        x, label = batch
        x = x.to(device)
        label = label.to(device)
        y_true += label.tolist()
        preds = mdl(x, label)
        preds = mdl.classifier(preds)
        preds = mdl.softmax(preds)
        preds = preds.argmax(dim=1)
        y_pred += preds.tolist()
    print('Classification Report:')
    print(classification_report(y_true, y_pred))
    print('\nConfusion Matrix:')
    print(confusion_matrix(y_true, y_pred))

    return sum(1 for i in range(1,len(y_pred)) if y_pred[i] == y_true[i])/len(y_pred)

In [None]:
data = DataLoader(sst_class_test, batch_size=4)
x = get_stats(elmo, data)

Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.36      0.48      1099
           1       0.57      0.85      0.68      1111

    accuracy                           0.61      2210
   macro avg       0.64      0.61      0.58      2210
weighted avg       0.64      0.61      0.58      2210


Confusion Matrix:
[[396 703]
 [167 944]]


In [None]:
data = DataLoader(sst_class_train, batch_size=4)
x = get_stats(elmo, data)

Classification Report:
              precision    recall  f1-score   support

           0       0.69      0.37      0.49      4037
           1       0.60      0.85      0.71      4507

    accuracy                           0.63      8544
   macro avg       0.65      0.61      0.60      8544
weighted avg       0.65      0.63      0.60      8544


Confusion Matrix:
[[1510 2527]
 [ 669 3838]]


#### NLI

In [52]:
def elmo_nli(model, data, num_epochs = 2):
    model = model.to(device)
    pad = vocabulary_nli['<PAD>']
    loss_fn = nn.CrossEntropyLoss().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in tqdm(range(num_epochs), desc='epoch'):
        model.train()
        total_loss = 0
        for ind, i in enumerate(data):
            # print(ind/len(data))
            x, label = i
            # print(i)
            x = x.to(device)
            label = label.to(device)
            optim.zero_grad()
            ans = model(x, label)
            ans = model.classifier(ans)
            # print(ans)
            ans = model.softmax(ans)
            # print(ans)
            loss = loss_fn(ans, label)
            loss.backward()
            optim.step()
            total_loss += loss.item()
        print(f'\tEpoch {epoch + 1}\tTrain Loss: {total_loss/len(data)}')
    return model

data = DataLoader(nli_class_train, batch_size=4)
elmo = elmo_nli(elmo, data)

epoch:  50%|█████     | 1/2 [01:18<01:18, 78.49s/it]

	Epoch 1	Train Loss: 1.090982784307003


epoch: 100%|██████████| 2/2 [02:35<00:00, 77.54s/it]

	Epoch 2	Train Loss: 1.0887264276385307





In [61]:
def get_stats_nli(mdl, dl):
    y_true = []
    y_pred = []
    for batch in tqdm(dl):
        x, label = batch
        x = x.to(device)
        label = label.to(device)
        y_true += label.tolist()
        preds = mdl(x, label)
        preds = mdl.classifier(preds)
        preds = mdl.softmax(preds)
        preds = preds.argmax(dim=1)
        y_pred += preds.tolist()
    print('\nClassification Report:')
    print(classification_report(y_true, y_pred))
    print('\nConfusion Matrix:')
    print(confusion_matrix(y_true, y_pred))

    return sum(1 for i in range(1,len(y_pred)) if y_pred[i] == y_true[i])/len(y_pred)

In [62]:
data = DataLoader(nli_class_train, batch_size=4)
x = get_stats_nli(elmo, data)

100%|██████████| 10000/10000 [01:13<00:00, 136.24it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.39      0.37      0.38     13525
           1       0.36      0.41      0.38     12122
           2       0.43      0.40      0.41     14353

    accuracy                           0.39     40000
   macro avg       0.39      0.39      0.39     40000
weighted avg       0.39      0.39      0.39     40000


Confusion Matrix:
[[4989 4390 4146]
 [3527 4910 3685]
 [4146 4400 5807]]
