# Baseline LSTM Model and SBERT


In [11]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.autograd import Variable
from torch.nn import functional as F
from glob import glob
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

In [13]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-mpnet-base-v2')

#Our sentences we like to encode
sentences = ['This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.', 
    'The quick brown fox jumps over the lazy dog.']

#Sentences are encoded by calling model.encode()
sentence_embeddings = model.encode(sentences)

#Print the embeddings
for sentence, embedding in zip(sentences, sentence_embeddings):
    print("Sentence:", sentence)
    print("Embedding Shape:", embedding.shape)
    print("")

Sentence: This framework generates embeddings for each input sentence
Embedding Shape: (768,)

Sentence: Sentences are passed as a list of string.
Embedding Shape: (768,)

Sentence: The quick brown fox jumps over the lazy dog.
Embedding Shape: (768,)



## Data Loader for labeled episodes with embedded sentences

Need to have embeddings saved from Model_Pretraining

Initiating the class takes some time to load all the embeddings into memory. 



In [64]:
def ls_data_loader(path, train=True):
    '''loads labeled sub files, embeds the sentences, and saves
    embedded episodes and labels into binary files
    '''
    # get path to all labels_sub files
    sub_files = glob(path + '/*.tsv')
    
    # loads texts
    data = []
    labels = []

    if train:
        start=0
        stop=30
    else:
        start=30
        stop=-1
    for f in sub_files[start:stop]:
        print('loading:', f)
        # read with pandas
        df = pd.read_csv(f, sep='\t', usecols=['text', 'label'])
        # get sentences as list
        sentences = list(df["text"].values)
        # convert labels to 0 and 1
        tags = [0 if v == 'N' else 1 for v in df["label"]]
        data.append(sentences)
        labels.append(tags)
    return data, labels

In [67]:


class Dataset_seq_ep(torch.utils.data.Dataset):
    def __init__(self, train_path, train=True):
        # self.sent_id = sent_id
        self.train_path = train_path
        self.data, self.labels = ls_data_loader(train_path, train)
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.data = self.__embedder__(self.data)
        

    def __getitem__(self, index):
        # return sequence of sentences and labels
        seq = torch.Tensor(self.data[index])
        labels = torch.Tensor(self.labels[index])
        return seq, labels

    def __len__(self):
        return(len(self.data))

    def __embedder__(self, data):
        return [self.embedder.encode(sentences) for sentences in data]



def collate_fn(batch):
	'''  
	custom collate_fn as the size of every episode is different and merging sequences (including padding) 
	is not supported in default. 
	'''

	(xx, yy) = zip(*batch)
	x_lens = [len(x) for x in xx]
	y_lens = [len(y) for y in yy]

	xx_pad = pad_sequence(xx, batch_first=True, padding_value=-1)
	yy_pad = pad_sequence(yy, batch_first=True, padding_value=-1)

	return xx_pad, yy_pad, x_lens, y_lens

# train_dataset = Dataset_seq_ep('labeled_subs')
test_dataset = Dataset_seq_ep('labeled_subs', train=False)

test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn)
# train_dataloader = torch.utils.data.DataLoader(
#     dataset=train_dataset,
#     batch_size=1,
#     shuffle=True,
#     collate_fn=collate_fn)


    


loading: labeled_subs\s05e05.tsv
loading: labeled_subs\s05e06.tsv
loading: labeled_subs\s05e08.tsv
loading: labeled_subs\s05e10.tsv
loading: labeled_subs\s05e12.tsv
loading: labeled_subs\s05e13.tsv
loading: labeled_subs\s05e17.tsv
loading: labeled_subs\s05e21.tsv


Check dataloader and see that batches are padded to seq with max_length per batch and the length of each sequence is returned

In [56]:
for i in range(1):
    loop = tqdm(train_dataloader)
    for batch in loop:
        data, labels, in_len, lab_len = batch
        print('Data shape', data.shape)
        print('labels shape', labels.shape)
        break

100%|██████████| 30/30 [00:00<00:00, 123.44it/s]

Data shape torch.Size([1, 640, 768])
labels shape torch.Size([1, 640])
Data shape torch.Size([1, 647, 768])
labels shape torch.Size([1, 647])
Data shape torch.Size([1, 705, 768])
labels shape torch.Size([1, 705])
Data shape torch.Size([1, 688, 768])
labels shape torch.Size([1, 688])
Data shape torch.Size([1, 696, 768])
labels shape torch.Size([1, 696])
Data shape torch.Size([1, 587, 768])
labels shape torch.Size([1, 587])
Data shape torch.Size([1, 652, 768])
labels shape torch.Size([1, 652])
Data shape torch.Size([1, 612, 768])
labels shape torch.Size([1, 612])
Data shape torch.Size([1, 602, 768])
labels shape torch.Size([1, 602])
Data shape torch.Size([1, 701, 768])
labels shape torch.Size([1, 701])
Data shape torch.Size([1, 657, 768])
labels shape torch.Size([1, 657])
Data shape torch.Size([1, 597, 768])
labels shape torch.Size([1, 597])
Data shape torch.Size([1, 670, 768])
labels shape torch.Size([1, 670])
Data shape torch.Size([1, 743, 768])
labels shape torch.Size([1, 743])
Data s




### Build LSTM model

adapting from 

Blog post:
Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your health:
https://medium.com/@_willfalcon/taming-lstms-variable-sized-mini-batches-and-why-pytorch-is-good-for-your-health-61d35642972e



In [57]:

# this is still a work in progress

class PerpLSTM(nn.Module):
    def __init__(self, nb_lstm_layers=1, nb_lstm_units=128, fc_hidden_units=100, embedding_dim=3, batch_size=1):
        super(PerpLSTM, self).__init__()

        # self.vocab = {'<PAD>': -1} # not sure we need this
        # self.tags = {'<PAD>':-1, 'N': 0, 'Y': 1}

        self.nb_lstm_layers = nb_lstm_layers
        self.nb_lstm_units = nb_lstm_units
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.fc_hidden_units = fc_hidden_units

        # # don't count the padding tag for the classifier output
        # self.nb_tags = len(self.tags)-1

        # build actual NN
        self.__build_model()

    def __build_model(self):

        self.relu = nn.ReLU()

        # design LSTM
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.nb_lstm_units,
            num_layers=self.nb_lstm_layers,
            batch_first=True,
        )

        # output layer which projects back to tag space
        self.fc1 = nn.Linear(self.nb_lstm_units, self.fc_hidden_units)
        self.fc2 = nn.Linear(self.fc_hidden_units, 1) # change out to 1 for sigmoid activation


    def forward(self, X):

        # run through ReLu
        X = self.relu(X)

        # run through LSTM
        X, self.hidden = self.lstm(X)

        # reshape for linear
        # X = X.contiguous()
        # X = X.view(-1, X.shape[2])

        # run through linear layers and finish with sigmoid func
        X = self.fc1(X)
        X = self.relu(X)
        X = torch.sigmoid(self.fc2(X))

        return X


In [58]:

perp_model = PerpLSTM(
    nb_lstm_layers=1, 
    nb_lstm_units=128, 
    embedding_dim=768,
    batch_size=1)


## Testing forward

In [59]:
for i in range(1):
    loop = tqdm(train_dataloader)
    for batch in loop:
        data, labels, in_len, lab_len = batch
        print('Data shape', data.shape)
        print('labels shape', labels.shape)
        with torch.no_grad():
            output = perp_model.forward(data)
        print(output.shape)
        break

  0%|          | 0/30 [00:00<?, ?it/s]

Data shape torch.Size([1, 705, 768])
labels shape torch.Size([1, 705])
torch.Size([1, 705, 1])





## Testing backward

In [60]:

optimizer = torch.optim.Adam(perp_model.parameters(), lr=0.001) 
# Using Binary Cross Entropy Loss function since we are using batch size = 1
criterion = nn.BCELoss()

for epoch in range(100):
    loop = tqdm(train_dataloader)
    epoch_total = 0
    for batch in loop:
        data, labels, in_len, lab_len = batch

        outputs = perp_model.forward(data) #forward pass
        optimizer.zero_grad() #calculate the gradient, manually setting to 0
 
        # obtain the loss function
        # loss = perp_model.loss(outputs, labels, lab_len)
        # m = nn.Sigmoid()
        # sig_out = m(outputs)
        # print(sig_out.shape)

        # loss = criterion(outputs, F.one_hot(labels.view(-1).type(torch.int64)).type(torch.float32))
        loss = criterion(outputs.view(1,-1), labels.type(torch.float32))
        loss.backward() #calculates the loss of the loss function
        
        optimizer.step() #improve from loss, i.e backprop
        loop.set_postfix(loss=loss.item())
        epoch_total += loss.item()
    print("Epoch: %d, loss: %1.5f" % (epoch, epoch_total/len(loop)))


100%|██████████| 30/30 [00:11<00:00,  2.52it/s, loss=0.474]


Epoch: 0, loss: 0.52429


100%|██████████| 30/30 [00:11<00:00,  2.69it/s, loss=0.342]


Epoch: 1, loss: 0.42136


100%|██████████| 30/30 [00:11<00:00,  2.64it/s, loss=0.427]


Epoch: 2, loss: 0.42388


100%|██████████| 30/30 [00:10<00:00,  2.73it/s, loss=0.287]


Epoch: 3, loss: 0.41755


100%|██████████| 30/30 [00:10<00:00,  2.75it/s, loss=0.364]


Epoch: 4, loss: 0.41791


100%|██████████| 30/30 [00:11<00:00,  2.66it/s, loss=0.27] 


Epoch: 5, loss: 0.41394


100%|██████████| 30/30 [00:11<00:00,  2.68it/s, loss=0.324]


Epoch: 6, loss: 0.41824


100%|██████████| 30/30 [00:11<00:00,  2.72it/s, loss=0.295]


Epoch: 7, loss: 0.41226


100%|██████████| 30/30 [00:13<00:00,  2.30it/s, loss=0.582]


Epoch: 8, loss: 0.40792


100%|██████████| 30/30 [00:11<00:00,  2.72it/s, loss=0.341]


Epoch: 9, loss: 0.40527


100%|██████████| 30/30 [00:10<00:00,  2.82it/s, loss=0.246]


Epoch: 10, loss: 0.39149


100%|██████████| 30/30 [00:11<00:00,  2.66it/s, loss=0.27] 


Epoch: 11, loss: 0.37471


100%|██████████| 30/30 [00:10<00:00,  2.74it/s, loss=0.447]


Epoch: 12, loss: 0.37766


100%|██████████| 30/30 [00:11<00:00,  2.69it/s, loss=0.239]


Epoch: 13, loss: 0.36050


100%|██████████| 30/30 [00:11<00:00,  2.66it/s, loss=0.277]


Epoch: 14, loss: 0.36637


100%|██████████| 30/30 [00:10<00:00,  2.73it/s, loss=0.512]


Epoch: 15, loss: 0.35935


100%|██████████| 30/30 [00:11<00:00,  2.72it/s, loss=0.236]


Epoch: 16, loss: 0.35235


100%|██████████| 30/30 [00:13<00:00,  2.19it/s, loss=0.608]


Epoch: 17, loss: 0.35733


100%|██████████| 30/30 [00:11<00:00,  2.65it/s, loss=0.388]


Epoch: 18, loss: 0.34338


100%|██████████| 30/30 [00:10<00:00,  2.74it/s, loss=0.24] 


Epoch: 19, loss: 0.34319


100%|██████████| 30/30 [00:11<00:00,  2.69it/s, loss=0.411]


Epoch: 20, loss: 0.34094


100%|██████████| 30/30 [00:12<00:00,  2.36it/s, loss=0.264]


Epoch: 21, loss: 0.33876


100%|██████████| 30/30 [00:11<00:00,  2.71it/s, loss=0.38] 


Epoch: 22, loss: 0.33481


100%|██████████| 30/30 [00:11<00:00,  2.71it/s, loss=0.43] 


Epoch: 23, loss: 0.33226


100%|██████████| 30/30 [00:13<00:00,  2.29it/s, loss=0.413]


Epoch: 24, loss: 0.32729


100%|██████████| 30/30 [00:10<00:00,  2.77it/s, loss=0.387]


Epoch: 25, loss: 0.32925


100%|██████████| 30/30 [00:11<00:00,  2.66it/s, loss=0.298]


Epoch: 26, loss: 0.32585


100%|██████████| 30/30 [00:11<00:00,  2.51it/s, loss=0.375]


Epoch: 27, loss: 0.31859


100%|██████████| 30/30 [00:10<00:00,  2.74it/s, loss=0.374]


Epoch: 28, loss: 0.31650


100%|██████████| 30/30 [00:11<00:00,  2.67it/s, loss=0.336]


Epoch: 29, loss: 0.31028


100%|██████████| 30/30 [00:10<00:00,  2.73it/s, loss=0.183]


Epoch: 30, loss: 0.30529


100%|██████████| 30/30 [00:10<00:00,  2.76it/s, loss=0.387]


Epoch: 31, loss: 0.30333


100%|██████████| 30/30 [00:11<00:00,  2.70it/s, loss=0.182]


Epoch: 32, loss: 0.29944


100%|██████████| 30/30 [00:11<00:00,  2.69it/s, loss=0.273]


Epoch: 33, loss: 0.29720


100%|██████████| 30/30 [00:11<00:00,  2.67it/s, loss=0.339]


Epoch: 34, loss: 0.28950


100%|██████████| 30/30 [00:10<00:00,  2.75it/s, loss=0.319]


Epoch: 35, loss: 0.27713


100%|██████████| 30/30 [00:11<00:00,  2.71it/s, loss=0.431]


Epoch: 36, loss: 0.26857


100%|██████████| 30/30 [00:11<00:00,  2.66it/s, loss=0.269]


Epoch: 37, loss: 0.26641


100%|██████████| 30/30 [00:10<00:00,  2.74it/s, loss=0.228]


Epoch: 38, loss: 0.25588


100%|██████████| 30/30 [00:10<00:00,  2.76it/s, loss=0.173]


Epoch: 39, loss: 0.24631


100%|██████████| 30/30 [00:13<00:00,  2.27it/s, loss=0.138]


Epoch: 40, loss: 0.23175


100%|██████████| 30/30 [00:10<00:00,  2.79it/s, loss=0.209]


Epoch: 41, loss: 0.22440


100%|██████████| 30/30 [00:13<00:00,  2.26it/s, loss=0.262]


Epoch: 42, loss: 0.20424


100%|██████████| 30/30 [00:18<00:00,  1.66it/s, loss=0.26] 


Epoch: 43, loss: 0.19280


100%|██████████| 30/30 [00:15<00:00,  1.98it/s, loss=0.105]


Epoch: 44, loss: 0.18770


100%|██████████| 30/30 [00:11<00:00,  2.67it/s, loss=0.155]


Epoch: 45, loss: 0.16344


100%|██████████| 30/30 [00:10<00:00,  2.87it/s, loss=0.125] 


Epoch: 46, loss: 0.13700


100%|██████████| 30/30 [00:10<00:00,  2.78it/s, loss=0.0639]


Epoch: 47, loss: 0.11471


100%|██████████| 30/30 [00:11<00:00,  2.66it/s, loss=0.117] 


Epoch: 48, loss: 0.10254


100%|██████████| 30/30 [00:11<00:00,  2.52it/s, loss=0.104] 


Epoch: 49, loss: 0.09103


100%|██████████| 30/30 [00:11<00:00,  2.61it/s, loss=0.118] 


Epoch: 50, loss: 0.07922


100%|██████████| 30/30 [00:10<00:00,  2.89it/s, loss=0.0313]


Epoch: 51, loss: 0.07716


100%|██████████| 30/30 [00:10<00:00,  2.82it/s, loss=0.0251]


Epoch: 52, loss: 0.06417


100%|██████████| 30/30 [00:10<00:00,  2.90it/s, loss=0.0778]


Epoch: 53, loss: 0.05279


100%|██████████| 30/30 [00:10<00:00,  2.96it/s, loss=0.117] 


Epoch: 54, loss: 0.05445


100%|██████████| 30/30 [00:12<00:00,  2.41it/s, loss=0.0189]


Epoch: 55, loss: 0.05390


100%|██████████| 30/30 [00:11<00:00,  2.62it/s, loss=0.0254]


Epoch: 56, loss: 0.04304


100%|██████████| 30/30 [00:10<00:00,  2.84it/s, loss=0.0272]


Epoch: 57, loss: 0.03143


100%|██████████| 30/30 [00:10<00:00,  2.74it/s, loss=0.0193] 


Epoch: 58, loss: 0.02499


100%|██████████| 30/30 [00:10<00:00,  2.88it/s, loss=0.0129] 


Epoch: 59, loss: 0.01935


100%|██████████| 30/30 [00:10<00:00,  2.74it/s, loss=0.014]  


Epoch: 60, loss: 0.01722


100%|██████████| 30/30 [00:10<00:00,  2.77it/s, loss=0.0158] 


Epoch: 61, loss: 0.01586


100%|██████████| 30/30 [00:09<00:00,  3.01it/s, loss=0.00701]


Epoch: 62, loss: 0.00987


100%|██████████| 30/30 [00:11<00:00,  2.62it/s, loss=0.00815]


Epoch: 63, loss: 0.00729


100%|██████████| 30/30 [00:10<00:00,  2.90it/s, loss=0.00201]


Epoch: 64, loss: 0.00532


100%|██████████| 30/30 [00:09<00:00,  3.01it/s, loss=0.00324]


Epoch: 65, loss: 0.00365


100%|██████████| 30/30 [00:11<00:00,  2.57it/s, loss=0.00151] 


Epoch: 66, loss: 0.00224


100%|██████████| 30/30 [00:10<00:00,  2.80it/s, loss=0.00083] 


Epoch: 67, loss: 0.00167


100%|██████████| 30/30 [00:09<00:00,  3.13it/s, loss=0.000629]


Epoch: 68, loss: 0.00137


100%|██████████| 30/30 [00:12<00:00,  2.49it/s, loss=0.000358]


Epoch: 69, loss: 0.00114


100%|██████████| 30/30 [00:09<00:00,  3.25it/s, loss=0.000434]


Epoch: 70, loss: 0.00104


100%|██████████| 30/30 [00:09<00:00,  3.25it/s, loss=0.000582]


Epoch: 71, loss: 0.00092


100%|██████████| 30/30 [00:08<00:00,  3.33it/s, loss=0.000449]


Epoch: 72, loss: 0.00087


100%|██████████| 30/30 [00:10<00:00,  2.98it/s, loss=0.000287]


Epoch: 73, loss: 0.00080


100%|██████████| 30/30 [00:09<00:00,  3.29it/s, loss=0.00351] 


Epoch: 74, loss: 0.00074


100%|██████████| 30/30 [00:09<00:00,  3.20it/s, loss=0.000341]


Epoch: 75, loss: 0.00082


100%|██████████| 30/30 [00:09<00:00,  3.30it/s, loss=0.000207]


Epoch: 76, loss: 0.00059


100%|██████████| 30/30 [00:08<00:00,  3.35it/s, loss=0.00192] 


Epoch: 77, loss: 0.00055


100%|██████████| 30/30 [00:09<00:00,  3.31it/s, loss=0.000309]


Epoch: 78, loss: 0.00052


100%|██████████| 30/30 [00:10<00:00,  2.86it/s, loss=0.000155]


Epoch: 79, loss: 0.00045


100%|██████████| 30/30 [00:08<00:00,  3.37it/s, loss=0.000114]


Epoch: 80, loss: 0.00041


100%|██████████| 30/30 [00:09<00:00,  3.26it/s, loss=0.000335]


Epoch: 81, loss: 0.00036


100%|██████████| 30/30 [00:09<00:00,  3.04it/s, loss=0.00018] 


Epoch: 82, loss: 0.00033


100%|██████████| 30/30 [00:14<00:00,  2.05it/s, loss=0.000218]


Epoch: 83, loss: 0.00030


100%|██████████| 30/30 [00:09<00:00,  3.28it/s, loss=0.000115]


Epoch: 84, loss: 0.00028


100%|██████████| 30/30 [00:09<00:00,  3.03it/s, loss=0.000757]


Epoch: 85, loss: 0.00024


100%|██████████| 30/30 [00:10<00:00,  2.96it/s, loss=9.92e-5] 


Epoch: 86, loss: 0.00022


100%|██████████| 30/30 [00:13<00:00,  2.30it/s, loss=0.000153]


Epoch: 87, loss: 0.00021


100%|██████████| 30/30 [00:11<00:00,  2.61it/s, loss=0.000233]


Epoch: 88, loss: 0.00019


100%|██████████| 30/30 [00:11<00:00,  2.64it/s, loss=8.35e-5] 


Epoch: 89, loss: 0.00018


100%|██████████| 30/30 [00:12<00:00,  2.50it/s, loss=6.15e-5] 


Epoch: 90, loss: 0.00017


100%|██████████| 30/30 [00:12<00:00,  2.50it/s, loss=8.31e-5] 


Epoch: 91, loss: 0.00016


100%|██████████| 30/30 [00:12<00:00,  2.33it/s, loss=0.000166]


Epoch: 92, loss: 0.00015


100%|██████████| 30/30 [00:12<00:00,  2.41it/s, loss=3.7e-5]  


Epoch: 93, loss: 0.00014


100%|██████████| 30/30 [00:11<00:00,  2.65it/s, loss=4.02e-5] 


Epoch: 94, loss: 0.00013


100%|██████████| 30/30 [00:12<00:00,  2.45it/s, loss=0.000121]


Epoch: 95, loss: 0.00012


100%|██████████| 30/30 [00:10<00:00,  2.78it/s, loss=0.000356]


Epoch: 96, loss: 0.00012


100%|██████████| 30/30 [00:11<00:00,  2.51it/s, loss=7.09e-5] 


Epoch: 97, loss: 0.00011


100%|██████████| 30/30 [00:12<00:00,  2.50it/s, loss=5.34e-5] 


Epoch: 98, loss: 0.00010


100%|██████████| 30/30 [00:11<00:00,  2.55it/s, loss=0.000129]

Epoch: 99, loss: 0.00010





In [61]:
print((perp_model(data)>.5).sum(1))
print(labels.sum(1))


tensor([[154]])
tensor([154.])
