# Baseline LSTM Model and SBERT


In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.autograd import Variable
from torch.nn import functional as F
from glob import glob
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

In [2]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-mpnet-base-v2')

# A few example sentences to sanity-check the SBERT encoder
sentences = [
    'This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.',
    'The quick brown fox jumps over the lazy dog.',
]

# model.encode() returns one embedding vector per input sentence
sentence_embeddings = model.encode(sentences)

# Show each sentence next to the shape of its embedding vector
for idx, sentence in enumerate(sentences):
    embedding = sentence_embeddings[idx]
    print("Sentence:", sentence)
    print("Embedding Shape:", embedding.shape)
    print("")

Sentence: This framework generates embeddings for each input sentence
Embedding Shape: (768,)

Sentence: Sentences are passed as a list of string.
Embedding Shape: (768,)

Sentence: The quick brown fox jumps over the lazy dog.
Embedding Shape: (768,)



## Data Loader for labeled episodes with embedded sentences

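Expects the labeled episode files (`*.tsv`) in `labeled_subs/`. Note: the dataset below computes the SBERT embeddings itself when it is created. It does not read embeddings saved from Model_Pretraining.

Instantiating the dataset takes some time, because every sentence is embedded and all embeddings are held in memory.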



In [3]:
def ls_data_loader(path, train=True, n_train=30):
    '''Load the labeled subtitle episodes stored as ``*.tsv`` files in ``path``.

    Each file must contain ``text`` and ``label`` columns. Files are sorted by
    name so the train/test split is deterministic: the first ``n_train`` files
    form the training split and *all* remaining files form the test split.

    Args:
        path: directory containing the labeled ``.tsv`` files.
        train: if True return the training split, otherwise the test split.
        n_train: number of (sorted) files assigned to the training split.

    Returns:
        ``(data, labels)``. ``data`` is a list with one list of sentence strings
        per episode. ``labels`` is a parallel list of per-sentence tags: 0 for
        label ``'N'`` and 1 for any other label value.
    '''
    # glob() order depends on the filesystem; sort so the split is reproducible
    sub_files = sorted(glob(path + '/*.tsv'))
    # Open-ended slice for the test split: the old ``[30:-1]`` dropped the last file
    selected = sub_files[:n_train] if train else sub_files[n_train:]

    data = []
    labels = []
    for f in selected:
        print('loading:', f)
        df = pd.read_csv(f, sep='\t', usecols=['text', 'label'])
        # one list of sentences per episode
        data.append(list(df["text"].values))
        # 'N' -> 0, everything else -> 1
        labels.append([0 if v == 'N' else 1 for v in df["label"]])
    return data, labels

In [4]:


class Dataset_seq_ep(torch.utils.data.Dataset):
    '''Episode-level dataset of SBERT-embedded sentences and binary labels.

    Each item is one episode: a float32 tensor of sentence embeddings (one row
    per sentence, as returned by ``embedder.encode``) and a float32 tensor of
    the per-sentence 0/1 labels produced by ``ls_data_loader``.

    Args:
        train_path: directory holding the labeled ``.tsv`` episode files.
        train: forwarded to ``ls_data_loader`` to pick the train or test split.
        embedder: optional object with an ``encode(list_of_str)`` method, e.g.
            an already-loaded ``SentenceTransformer``. Passing one lets several
            datasets share a single model instead of each loading
            ``all-mpnet-base-v2`` again. Defaults to loading that model.
    '''

    def __init__(self, train_path, train=True, embedder=None):
        self.train_path = train_path
        self.data, self.labels = ls_data_loader(train_path, train)
        # Loading the SBERT model is slow; reuse a caller-supplied one if given.
        if embedder is None:
            embedder = SentenceTransformer('all-mpnet-base-v2')
        self.embedder = embedder
        # Every sentence is embedded up front, which is why construction is slow.
        self.data = self._embed(self.data)

    def __getitem__(self, index):
        '''Return ``(embeddings, labels)`` for episode ``index`` as float32 tensors.'''
        seq = torch.tensor(self.data[index], dtype=torch.float32)
        labels = torch.tensor(self.labels[index], dtype=torch.float32)
        return seq, labels

    def __len__(self):
        '''Number of episodes in the split.'''
        return len(self.data)

    def _embed(self, data):
        '''Encode each episode (a list of sentences) with ``self.embedder``.'''
        return [self.embedder.encode(sentences) for sentences in data]

    # Backward-compatible alias for the old name; ``__x__`` names are reserved
    # for Python's special methods.
    __embedder__ = _embed



def collate_fn(batch):
    '''Merge variable-length episodes into a single padded batch.

    The default collate cannot stack episodes of different lengths, so both the
    embedding sequences and the label sequences are right-padded with -1 up to
    the longest episode in the batch.

    Args:
        batch: iterable of ``(sequence, labels)`` tensor pairs from the dataset.

    Returns:
        ``(padded_sequences, padded_labels, sequence_lengths, label_lengths)``,
        where the two length entries are plain Python lists of the original
        (unpadded) lengths.
    '''
    sequences, targets = zip(*batch)
    seq_lengths = list(map(len, sequences))
    target_lengths = list(map(len, targets))

    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=-1)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=-1)

    return padded_sequences, padded_targets, seq_lengths, target_lengths

# Train split = the first 30 episode files picked by ls_data_loader; test = the rest.
train_dataset = Dataset_seq_ep('labeled_subs')
test_dataset = Dataset_seq_ep('labeled_subs', train=False)

# Evaluation order need not be random; keep it fixed so predictions can be
# matched back to episodes.
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn)
# A seeded generator makes the training shuffle order reproducible across runs.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(42))

loading: labeled_subs\s01e07.tsv
loading: labeled_subs\s01e08.tsv
loading: labeled_subs\s01e13.tsv
loading: labeled_subs\s01e19.tsv
loading: labeled_subs\s01e20.tsv
loading: labeled_subs\s01e23.tsv
loading: labeled_subs\s02e01.tsv
loading: labeled_subs\s02e04.tsv
loading: labeled_subs\s02e06.tsv
loading: labeled_subs\s02e09.tsv
loading: labeled_subs\s02e10.tsv
loading: labeled_subs\s02e15.tsv
loading: labeled_subs\s03e03.tsv
loading: labeled_subs\s03e05.tsv
loading: labeled_subs\s03e08.tsv
loading: labeled_subs\s03e11.tsv
loading: labeled_subs\s03e12.tsv
loading: labeled_subs\s03e19.tsv
loading: labeled_subs\s03e21.tsv
loading: labeled_subs\s04e05.tsv
loading: labeled_subs\s04e06.tsv
loading: labeled_subs\s04e09.tsv
loading: labeled_subs\s04e10.tsv
loading: labeled_subs\s04e12.tsv
loading: labeled_subs\s04e14.tsv
loading: labeled_subs\s04e15.tsv
loading: labeled_subs\s04e21.tsv
loading: labeled_subs\s04e22.tsv
loading: labeled_subs\s04e23.tsv
loading: labeled_subs\s05e03.tsv
loading: l

Check the dataloader. Each batch should be padded to the length of its longest episode, and the original length of every sequence should be returned alongside it.

In [5]:
# Pull a single training batch and inspect its padded shapes.
data, labels, in_len, lab_len = next(iter(train_dataloader))
print('Data shape', data.shape)
print('labels shape', labels.shape)
# Every batch is padded to its longest episode; lengths are returned per episode.
assert data.shape[1] == max(in_len)
assert tuple(labels.shape) == tuple(data.shape[:2])
assert in_len == lab_len

  0%|          | 0/30 [00:00<?, ?it/s]

Data shape torch.Size([1, 688, 768])
labels shape torch.Size([1, 688])





### Build LSTM model

Adapted from the blog post:

Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your health:
https://medium.com/@_willfalcon/taming-lstms-variable-sized-mini-batches-and-why-pytorch-is-good-for-your-health-61d35642972e



In [6]:

# this is still a work in progress

class PerpLSTM(nn.Module):
    '''LSTM sequence tagger producing one sigmoid probability per input step.

    Pipeline: (optional) ReLU on the inputs -> LSTM -> Linear -> ReLU ->
    Linear(1) -> sigmoid.

    Args:
        nb_lstm_layers: number of stacked LSTM layers.
        nb_lstm_units: hidden size of the LSTM.
        fc_hidden_units: width of the hidden fully connected layer.
        embedding_dim: size of each input vector (768 for all-mpnet-base-v2).
        batch_size: stored on the instance; not used by the model itself.
        input_relu: apply ReLU to the raw inputs before the LSTM (the original
            behavior, kept as the default). NOTE(review): this zeroes every
            negative component of the sentence embeddings; consider ``False``.
    '''

    def __init__(self, nb_lstm_layers=1, nb_lstm_units=128, fc_hidden_units=100,
                 embedding_dim=3, batch_size=1, input_relu=True):
        super().__init__()

        self.nb_lstm_layers = nb_lstm_layers
        self.nb_lstm_units = nb_lstm_units
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.fc_hidden_units = fc_hidden_units
        self.input_relu = input_relu
        # final (h_n, c_n) of the most recent forward pass, detached
        self.hidden = None

        self.__build_model()

    def __build_model(self):
        '''Create the layers used by ``forward``.'''
        self.relu = nn.ReLU()

        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.nb_lstm_units,
            num_layers=self.nb_lstm_layers,
            batch_first=True,
        )

        # project LSTM states to a single logit per step (sigmoid applied in forward)
        self.fc1 = nn.Linear(self.nb_lstm_units, self.fc_hidden_units)
        self.fc2 = nn.Linear(self.fc_hidden_units, 1)

    def forward(self, X, lengths=None):
        '''Score every step of every sequence in the batch.

        Args:
            X: float tensor laid out batch-first, ``(batch, seq, embedding_dim)``.
            lengths: optional true (unpadded) length of each sequence, e.g. the
                ``x_lens`` list returned by ``collate_fn``. When given, padded
                steps are packed away so they never influence the LSTM state.

        Returns:
            Tensor of shape ``(batch, seq, 1)`` with probabilities in [0, 1].
        '''
        seq_len = X.size(1)

        if self.input_relu:
            X = self.relu(X)

        if lengths is None:
            X, hidden = self.lstm(X)
        else:
            # pack_padded_sequence needs CPU int64 lengths; enforce_sorted=False
            # allows batches that are not sorted by length.
            lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            packed = nn.utils.rnn.pack_padded_sequence(
                X, lengths, batch_first=True, enforce_sorted=False)
            packed_out, hidden = self.lstm(packed)
            # restore the original padded length so outputs align with labels
            X, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=seq_len)

        # keep the final state for inspection without holding the autograd graph
        self.hidden = tuple(h.detach() for h in hidden)

        X = self.relu(self.fc1(X))
        return torch.sigmoid(self.fc2(X))


In [7]:

# Seed before building the model so the initial weights are reproducible.
torch.manual_seed(42)
perp_model = PerpLSTM(
    nb_lstm_layers=1,
    nb_lstm_units=128,
    embedding_dim=768,  # all-mpnet-base-v2 produces 768-dim sentence embeddings
    batch_size=1)


## Testing forward

In [8]:
# Run one training batch through the untrained model to check output shapes.
data, labels, in_len, lab_len = next(iter(train_dataloader))
print('Data shape', data.shape)
print('labels shape', labels.shape)
with torch.no_grad():
    # call the module (not .forward) so nn.Module hooks run
    output = perp_model(data)
print(output.shape)
# one probability per sentence: (batch, seq, 1)
assert tuple(output.shape) == (*labels.shape, 1)

  0%|          | 0/30 [00:00<?, ?it/s]

Data shape torch.Size([1, 608, 768])
labels shape torch.Size([1, 608])
torch.Size([1, 608, 1])





## Testing backward (training loop)

In [9]:

NUM_EPOCHS = 100

optimizer = torch.optim.Adam(perp_model.parameters(), lr=0.001)
# Binary cross-entropy on the sigmoid outputs: every sentence is labeled 0 or 1.
criterion = nn.BCELoss()

# NOTE(review): only the training loss is tracked; it falls towards zero, so
# add a held-out evaluation before trusting the model.
for epoch in range(NUM_EPOCHS):
    perp_model.train()
    loop = tqdm(train_dataloader)
    epoch_total = 0
    for data, labels, in_len, lab_len in loop:
        optimizer.zero_grad()  # clear gradients left over from the previous step

        probs = perp_model(data).squeeze(-1)  # (batch, seq, 1) -> (batch, seq)
        targets = labels.float()
        # collate_fn pads labels with -1; keep padded steps out of the loss
        mask = targets != -1
        loss = criterion(probs[mask], targets[mask])

        loss.backward()   # back-propagate: compute gradients of the loss
        optimizer.step()  # update the weights using those gradients

        loop.set_postfix(loss=loss.item())
        epoch_total += loss.item()
    print("Epoch: %d, loss: %1.5f" % (epoch, epoch_total / len(loop)))


100%|██████████| 30/30 [00:11<00:00,  2.69it/s, loss=0.65] 


Epoch: 0, loss: 0.50608


100%|██████████| 30/30 [00:12<00:00,  2.38it/s, loss=0.329]


Epoch: 1, loss: 0.42760


100%|██████████| 30/30 [00:11<00:00,  2.56it/s, loss=0.445]


Epoch: 2, loss: 0.42002


100%|██████████| 30/30 [00:10<00:00,  2.89it/s, loss=0.354]


Epoch: 3, loss: 0.41839


100%|██████████| 30/30 [00:12<00:00,  2.32it/s, loss=0.531]


Epoch: 4, loss: 0.41587


100%|██████████| 30/30 [00:13<00:00,  2.28it/s, loss=0.561]


Epoch: 5, loss: 0.41555


100%|██████████| 30/30 [00:12<00:00,  2.32it/s, loss=0.374]


Epoch: 6, loss: 0.41595


100%|██████████| 30/30 [00:15<00:00,  1.88it/s, loss=0.229]


Epoch: 7, loss: 0.41310


100%|██████████| 30/30 [00:15<00:00,  1.92it/s, loss=0.498]


Epoch: 8, loss: 0.40874


100%|██████████| 30/30 [00:14<00:00,  2.09it/s, loss=0.554]


Epoch: 9, loss: 0.41531


100%|██████████| 30/30 [00:10<00:00,  2.80it/s, loss=0.346]


Epoch: 10, loss: 0.40529


100%|██████████| 30/30 [00:12<00:00,  2.43it/s, loss=0.32] 


Epoch: 11, loss: 0.40652


100%|██████████| 30/30 [00:15<00:00,  1.92it/s, loss=0.424]


Epoch: 12, loss: 0.39568


100%|██████████| 30/30 [00:27<00:00,  1.07it/s, loss=0.546]


Epoch: 13, loss: 0.37983


100%|██████████| 30/30 [00:17<00:00,  1.72it/s, loss=0.295]


Epoch: 14, loss: 0.39361


100%|██████████| 30/30 [00:16<00:00,  1.82it/s, loss=0.422]


Epoch: 15, loss: 0.37302


100%|██████████| 30/30 [00:15<00:00,  1.89it/s, loss=0.302]


Epoch: 16, loss: 0.36260


100%|██████████| 30/30 [00:13<00:00,  2.23it/s, loss=0.385]


Epoch: 17, loss: 0.36132


100%|██████████| 30/30 [00:17<00:00,  1.72it/s, loss=0.604]


Epoch: 18, loss: 0.35426


100%|██████████| 30/30 [00:13<00:00,  2.20it/s, loss=0.24] 


Epoch: 19, loss: 0.34862


100%|██████████| 30/30 [00:11<00:00,  2.60it/s, loss=0.473]


Epoch: 20, loss: 0.34982


100%|██████████| 30/30 [00:22<00:00,  1.36it/s, loss=0.347]


Epoch: 21, loss: 0.35574


100%|██████████| 30/30 [00:16<00:00,  1.86it/s, loss=0.386]


Epoch: 22, loss: 0.34495


100%|██████████| 30/30 [00:19<00:00,  1.51it/s, loss=0.386]


Epoch: 23, loss: 0.34824


100%|██████████| 30/30 [00:15<00:00,  1.91it/s, loss=0.256]


Epoch: 24, loss: 0.35093


100%|██████████| 30/30 [00:13<00:00,  2.17it/s, loss=0.409]


Epoch: 25, loss: 0.33735


100%|██████████| 30/30 [00:12<00:00,  2.31it/s, loss=0.417]


Epoch: 26, loss: 0.33406


100%|██████████| 30/30 [00:13<00:00,  2.26it/s, loss=0.474]


Epoch: 27, loss: 0.33156


100%|██████████| 30/30 [00:12<00:00,  2.33it/s, loss=0.232]


Epoch: 28, loss: 0.33316


100%|██████████| 30/30 [00:14<00:00,  2.01it/s, loss=0.455]


Epoch: 29, loss: 0.32533


100%|██████████| 30/30 [00:21<00:00,  1.37it/s, loss=0.281]


Epoch: 30, loss: 0.32948


100%|██████████| 30/30 [00:17<00:00,  1.67it/s, loss=0.304]


Epoch: 31, loss: 0.32107


100%|██████████| 30/30 [00:15<00:00,  1.92it/s, loss=0.297]


Epoch: 32, loss: 0.31982


100%|██████████| 30/30 [00:15<00:00,  2.00it/s, loss=0.456]


Epoch: 33, loss: 0.31701


100%|██████████| 30/30 [00:14<00:00,  2.01it/s, loss=0.239]


Epoch: 34, loss: 0.30953


100%|██████████| 30/30 [00:11<00:00,  2.57it/s, loss=0.292]


Epoch: 35, loss: 0.30761


100%|██████████| 30/30 [00:12<00:00,  2.41it/s, loss=0.352]


Epoch: 36, loss: 0.30050


100%|██████████| 30/30 [00:11<00:00,  2.52it/s, loss=0.368]


Epoch: 37, loss: 0.29663


100%|██████████| 30/30 [00:11<00:00,  2.50it/s, loss=0.224]


Epoch: 38, loss: 0.29155


100%|██████████| 30/30 [00:12<00:00,  2.49it/s, loss=0.388]


Epoch: 39, loss: 0.29166


100%|██████████| 30/30 [00:11<00:00,  2.55it/s, loss=0.18] 


Epoch: 40, loss: 0.27715


100%|██████████| 30/30 [00:12<00:00,  2.46it/s, loss=0.28] 


Epoch: 41, loss: 0.26300


100%|██████████| 30/30 [00:12<00:00,  2.49it/s, loss=0.273]


Epoch: 42, loss: 0.25113


100%|██████████| 30/30 [00:16<00:00,  1.81it/s, loss=0.236]


Epoch: 43, loss: 0.23838


100%|██████████| 30/30 [00:12<00:00,  2.48it/s, loss=0.0966]


Epoch: 44, loss: 0.22769


100%|██████████| 30/30 [00:12<00:00,  2.44it/s, loss=0.137]


Epoch: 45, loss: 0.22147


100%|██████████| 30/30 [00:11<00:00,  2.51it/s, loss=0.138]


Epoch: 46, loss: 0.20658


100%|██████████| 30/30 [00:11<00:00,  2.53it/s, loss=0.209] 


Epoch: 47, loss: 0.17676


100%|██████████| 30/30 [00:12<00:00,  2.36it/s, loss=0.201] 


Epoch: 48, loss: 0.15359


100%|██████████| 30/30 [00:11<00:00,  2.53it/s, loss=0.283] 


Epoch: 49, loss: 0.15102


100%|██████████| 30/30 [00:12<00:00,  2.45it/s, loss=0.167] 


Epoch: 50, loss: 0.13223


100%|██████████| 30/30 [00:11<00:00,  2.55it/s, loss=0.149] 


Epoch: 51, loss: 0.11872


100%|██████████| 30/30 [00:11<00:00,  2.54it/s, loss=0.162] 


Epoch: 52, loss: 0.10271


100%|██████████| 30/30 [00:12<00:00,  2.49it/s, loss=0.0688]


Epoch: 53, loss: 0.09657


100%|██████████| 30/30 [00:12<00:00,  2.46it/s, loss=0.0945]


Epoch: 54, loss: 0.07990


100%|██████████| 30/30 [00:13<00:00,  2.27it/s, loss=0.0477]


Epoch: 55, loss: 0.07191


100%|██████████| 30/30 [00:12<00:00,  2.45it/s, loss=0.0777]


Epoch: 56, loss: 0.06238


100%|██████████| 30/30 [00:13<00:00,  2.14it/s, loss=0.0245]


Epoch: 57, loss: 0.05505


100%|██████████| 30/30 [00:14<00:00,  2.07it/s, loss=0.0461]


Epoch: 58, loss: 0.04700


100%|██████████| 30/30 [00:12<00:00,  2.43it/s, loss=0.0536]


Epoch: 59, loss: 0.04621


100%|██████████| 30/30 [00:12<00:00,  2.45it/s, loss=0.0853]


Epoch: 60, loss: 0.05148


100%|██████████| 30/30 [00:11<00:00,  2.58it/s, loss=0.0503]


Epoch: 61, loss: 0.04999


100%|██████████| 30/30 [00:11<00:00,  2.51it/s, loss=0.0318]


Epoch: 62, loss: 0.03057


100%|██████████| 30/30 [00:11<00:00,  2.54it/s, loss=0.0131] 


Epoch: 63, loss: 0.02238


100%|██████████| 30/30 [00:11<00:00,  2.53it/s, loss=0.0767] 


Epoch: 64, loss: 0.02291


100%|██████████| 30/30 [00:11<00:00,  2.54it/s, loss=0.0252] 


Epoch: 65, loss: 0.02062


100%|██████████| 30/30 [00:11<00:00,  2.53it/s, loss=0.00993]


Epoch: 66, loss: 0.01439


100%|██████████| 30/30 [00:13<00:00,  2.27it/s, loss=0.0133] 


Epoch: 67, loss: 0.00979


100%|██████████| 30/30 [00:14<00:00,  2.08it/s, loss=0.00457]


Epoch: 68, loss: 0.00687


100%|██████████| 30/30 [00:16<00:00,  1.77it/s, loss=0.00433]


Epoch: 69, loss: 0.00644


100%|██████████| 30/30 [00:14<00:00,  2.09it/s, loss=0.00143] 


Epoch: 70, loss: 0.00422


100%|██████████| 30/30 [00:15<00:00,  1.88it/s, loss=0.00182] 


Epoch: 71, loss: 0.00321


100%|██████████| 30/30 [00:14<00:00,  2.02it/s, loss=0.000717]


Epoch: 72, loss: 0.00196


100%|██████████| 30/30 [00:15<00:00,  1.88it/s, loss=0.00062] 


Epoch: 73, loss: 0.00153


100%|██████████| 30/30 [00:17<00:00,  1.67it/s, loss=0.00493] 


Epoch: 74, loss: 0.00124


100%|██████████| 30/30 [00:13<00:00,  2.25it/s, loss=0.000443]


Epoch: 75, loss: 0.00105


100%|██████████| 30/30 [00:12<00:00,  2.40it/s, loss=0.000584]


Epoch: 76, loss: 0.00099


100%|██████████| 30/30 [00:13<00:00,  2.30it/s, loss=0.000377]


Epoch: 77, loss: 0.00088


100%|██████████| 30/30 [00:14<00:00,  2.13it/s, loss=0.000271]


Epoch: 78, loss: 0.00081


100%|██████████| 30/30 [00:13<00:00,  2.18it/s, loss=0.000285]


Epoch: 79, loss: 0.00074


100%|██████████| 30/30 [00:13<00:00,  2.25it/s, loss=0.000365]


Epoch: 80, loss: 0.00072


100%|██████████| 30/30 [00:14<00:00,  2.11it/s, loss=0.000149]


Epoch: 81, loss: 0.00061


100%|██████████| 30/30 [00:13<00:00,  2.20it/s, loss=0.000133]


Epoch: 82, loss: 0.00058


100%|██████████| 30/30 [00:14<00:00,  2.01it/s, loss=0.000224]


Epoch: 83, loss: 0.00054


100%|██████████| 30/30 [00:13<00:00,  2.24it/s, loss=0.000386]


Epoch: 84, loss: 0.00049


100%|██████████| 30/30 [00:14<00:00,  2.04it/s, loss=0.000179]


Epoch: 85, loss: 0.00046


100%|██████████| 30/30 [00:13<00:00,  2.24it/s, loss=0.000188]


Epoch: 86, loss: 0.00042


100%|██████████| 30/30 [00:15<00:00,  1.95it/s, loss=0.000146]


Epoch: 87, loss: 0.00041


100%|██████████| 30/30 [00:15<00:00,  1.97it/s, loss=9.89e-5] 


Epoch: 88, loss: 0.00039


100%|██████████| 30/30 [00:15<00:00,  1.98it/s, loss=0.000194]


Epoch: 89, loss: 0.00034


100%|██████████| 30/30 [00:14<00:00,  2.06it/s, loss=0.000158]


Epoch: 90, loss: 0.00032


100%|██████████| 30/30 [00:14<00:00,  2.12it/s, loss=0.000135]


Epoch: 91, loss: 0.00031


100%|██████████| 30/30 [00:15<00:00,  1.96it/s, loss=0.000281]


Epoch: 92, loss: 0.00029


100%|██████████| 30/30 [00:13<00:00,  2.17it/s, loss=0.000119]


Epoch: 93, loss: 0.00026


100%|██████████| 30/30 [00:14<00:00,  2.00it/s, loss=0.000253]


Epoch: 94, loss: 0.00023


100%|██████████| 30/30 [00:13<00:00,  2.24it/s, loss=0.000639]


Epoch: 95, loss: 0.00022


100%|██████████| 30/30 [00:21<00:00,  1.41it/s, loss=0.0001]  


Epoch: 96, loss: 0.00020


100%|██████████| 30/30 [00:14<00:00,  2.06it/s, loss=6.01e-5] 


Epoch: 97, loss: 0.00019


100%|██████████| 30/30 [00:14<00:00,  2.09it/s, loss=5.55e-5] 


Epoch: 98, loss: 0.00018


100%|██████████| 30/30 [00:12<00:00,  2.36it/s, loss=0.000116]

Epoch: 99, loss: 0.00017





In [11]:
# Score the training episodes in a fixed (unshuffled) order so each prediction
# can be matched back to its episode file.
perp_model.eval()
ordered_train_loader = torch.utils.data.DataLoader(
    dataset=train_dataloader.dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=train_dataloader.collate_fn)

Y = []      # label tensors, one per episode
Y_hat = []  # predicted probabilities, one tensor per episode
with torch.no_grad():
    for data, labels, in_len, lab_len in tqdm(ordered_train_loader):
        Y_hat.append(perp_model(data))
        Y.append(labels)

100%|██████████| 30/30 [00:01<00:00, 17.75it/s]


In [109]:
import json

# Flatten each episode's prediction / label tensor into a plain list of floats.
y_pred_list = [y.view(-1).tolist() for y in Y_hat]
y_true_list = [y.view(-1).tolist() for y in Y]

# Use a file-handle name other than ``F`` so torch.nn.functional (imported as F)
# is not shadowed for the rest of the session.
for filename, values in (('y_pred_list_train.json', y_pred_list),
                         ('y_true_list_train.json', y_true_list)):
    with open(filename, 'w') as fh:
        json.dump(values, fh)

