In [None]:
# Print the installed transformers version.
import transformers
print(transformers.__version__)

In [None]:
# Root directory that is prepended to the 'datasets/...' and 'cache/...' paths below.
base_dir = './'
base_dir

In [None]:
import collections
import json
import logging
import os
import pickle
import random
import re
import string

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn.functional as F
from keras.preprocessing.sequence import pad_sequences
from sklearn.metrics import confusion_matrix
from spacy import displacy
from spacy.lang.en import English
from torch import nn
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM, DistilBertForSequenceClassification


# Only let ERROR-level messages from the transformers tokenizer utilities through.
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
# spaCy pipeline used by dependency_tree() for sentence splitting, parsing and POS tags.
nlp = spacy.load('en_core_web_sm')

# Use the first CUDA device when one is present; otherwise fall back to the CPU
# so the notebook does not crash on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(torch.cuda.get_device_name(device) if device != 'cpu' else 'CPU')

# Root2Leaf Here 📄
[Paper Link](https://arxiv.org/pdf/1703.00572.pdf)

The paper considers two ways of integrating structure embeddings. For each word:
1.   Concatenate the word embedding (from GloVe) with the states obtained by passing the structure (from root to leaf) through the BiLSTM -> ```[Emb_word, h_fwd, h_bwd]```
2.   Concatenate the last hidden state of the dependencies with the structure states -> ```[Emb_deps, h_fwd, h_bwd]```

The second method achieves better results.

Example output of the `dependency_tree` function defined in the next cell:

```{0: [[this, [('this', 'DT'), ('is', 'VBZ')]], [is, [('is', 'VBZ')]], [a, [('a', 'DT'), ('test', 'NN'), ('is', 'VBZ')]], [test, [('test', 'NN'), ('is', 'VBZ')]]]}```

`0` is the sentence ID. For each word in the sentence we have a two-element list containing the actual word, followed by the words on its dependency path together with their POS tags (the first dependency is the actual word itself!).




In [None]:
def shortest_dependency_path(doc, e1=None, e2=None):
    """Return the shortest path between two words in a dependency parse.

    An undirected graph is built with one edge per (token, child) pair of
    `doc`; the nodes are the string form of the tokens, so tokens whose string
    form is identical share a single node.

    Args:
        doc: iterable of parsed tokens, each exposing `.children`.
        e1: string form of the source word.
        e2: string form of the target word.

    Returns:
        The list of node strings from `e1` to `e2` (both included), or an
        empty list when either word is not in the graph or no path exists.
    """
    edges = [
        ('{0}'.format(token), '{0}'.format(child))
        for token in doc
        for child in token.children
    ]
    graph = nx.Graph(edges)

    try:
        return nx.shortest_path(graph, source=e1, target=e2)
    except nx.NetworkXException:
        # Covers "node not found" and "no path"; anything else (including
        # KeyboardInterrupt) is allowed to propagate instead of being hidden.
        return []

def dependency_tree(sents):
    """Map every sentence of `sents` to the root paths of its alphabetic tokens.

    The text is parsed with the module-level `nlp` pipeline. For each sentence
    (keyed by its index) the result holds one `[token, path]` entry per
    alphabetic token, where `path` is an OrderedDict of `word -> POS tag` for
    the words on the shortest dependency path from that token to the sentence
    root. Path words that are not alphabetic tokens of the sentence are left
    out.

    NOTE(review): `token` is the spaCy token object itself, not its text;
    confirm that callers which pickle this structure can serialise it.
    """
    parsed = nlp(sents)
    trees = {}

    for sent_idx, sent in enumerate(parsed.sents):
        root_text = str(sent.root)

        # Pass 1: collect the tag of every alphabetic word and its path to the root.
        tag_of = dict()
        path_of = {}
        for tok in sent:
            if not tok.is_alpha:
                continue
            tag_of[tok.orth_] = tok.tag_
            path_of[tok] = shortest_dependency_path(sent, tok.orth_, root_text)

        # Pass 2: attach tags to the path words (needs the complete tag table).
        entries = []
        for tok, path in path_of.items():
            tagged_path = collections.OrderedDict(
                (word, tag_of[word]) for word in path if word in tag_of
            )
            entries.append([tok, tagged_path])

        trees[sent_idx] = entries

    return trees

In [None]:
# Smoke test: parse a one-sentence example and print the resulting structure.
print(dependency_tree('this is a test.'))

In [None]:
# Create the POS-tag -> integer-id dictionary once and cache it on disk.
# NOTE(review): the build branch reads `dataset`, which is only created in a
# later cell; on a fresh kernel that cell has to run first unless the cached
# pickle already exists.
tag_cache_path = base_dir+'cache/tag_names.pkl'

if not os.path.exists(tag_cache_path):
    pos_tags = {}

    for data in tqdm(dataset):
        data = dataset.tokenizer.decode(data[0])
        deps_dict = dependency_tree(data)

        for val in deps_dict.values():
            # val holds one [word, {path_word: POS tag}] entry per word
            for _word, path_tags in val:
                for tag_name in path_tags.values():
                    # an unseen tag receives the next free id
                    pos_tags.setdefault(tag_name, len(pos_tags))
    print('Num Tags = %d' %(len(pos_tags)))
    with open(tag_cache_path, 'wb') as fout:
        pickle.dump(pos_tags, fout)
else:
    with open(tag_cache_path, 'rb') as fin:
        pos_tags = pickle.load(fin)
    print('TagIDs loaded')

In [None]:
# Character ranges treated as emoji by clean_tweets().
emoji_pattern = re.compile("["
         u"\U0001F600-\U0001F64F"  # emoticons
         u"\U0001F300-\U0001F5FF"  # symbols & pictographs
         u"\U0001F680-\U0001F6FF"  # transport & map symbols
         u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
         u"\U00002702-\U000027B0"
         u"\U000024C2-\U0001F251"
         "]+", flags=re.UNICODE)

# Whitespace-delimited tokens that are dropped as text emoticons.
emoticons = set([
    ':-)', ':)', ';)', ':o)', ':]', ':3', ':c)', ':>', '=]', '8)', '=)', ':}',
    ':^)', ':-D', ':D', '8-D', '8D', 'x-D', 'xD', 'X-D', 'XD', '=-D', '=D',
    '=-3', '=3', ':-))', ":'-)", ":')", ':*', ':^*', '>:P', ':-P', ':P', 'X-P',
    'x-p', 'xp', 'XP', ':-p', ':p', '=p', ':-b', ':b', '>:)', '>;)', '>:-)',
    '<3', ':L', ':-/', '>:/', ':S', '>:[', ':@', ':-(', ':[', ':-||', '=L', ':<',
    ':-[', ':-<', '=\\', '=/', '>:(', ':(', '>.<', ":'-(", ":'(", ':\\', ':-c',
    ':c', ':{', '>:\\', ';('
    ])

def clean_tweets(tweet, rm_puncs=True):
    """Normalise a raw review/tweet into a lower-cased, space-joined string.

    Steps, in order:
      1. strip double quotes, '<br />' tags and parentheses;
      2. drop tokens that are text emoticons (done while ':' is still present);
      3. remove colons and the mojibake ellipsis, remove emoji, and replace
         every remaining run of non-ASCII characters with one space;
      4. lower-case, split on single spaces and drop emoticon tokens and, when
         `rm_puncs` is true, tokens that are a substring of
         `string.punctuation` (this includes the empty tokens that step 3 can
         leave behind).

    Previously the character-level cleanup of step 3 was computed but thrown
    away, because the tokens had been taken from the raw text beforehand.

    Args:
        tweet: raw text.
        rm_puncs: when true, punctuation-only tokens are removed.

    Returns:
        The cleaned text with tokens joined by single spaces.
    """
    text = tweet.replace('"', '').replace('<br />', '').replace(')', '').replace('(', '')

    # Emoticons must be matched before ':' is stripped, otherwise ':p' becomes 'p'.
    text = ' '.join(w for w in text.split(' ') if w.lower() not in emoticons)

    # after tweepy preprocessing a colon is left behind where mentions were removed
    text = re.sub(r':', '', text)
    text = re.sub(r'‚Ä¶', '', text)

    # remove emojis first: they vanish instead of turning into a space below
    text = emoji_pattern.sub(r'', text)

    # replace consecutive non-ASCII characters with a space
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)

    filtered_tweet = []
    for w in text.lower().split(' '):
        # check tokens against emoticons and (optionally) punctuation
        if w in emoticons:
            continue
        if rm_puncs and w in string.punctuation:
            continue
        filtered_tweet.append(w)

    return ' '.join(filtered_tweet)


In [None]:
class RTDataset(torch.utils.data.Dataset):
    """Rotten Tomatoes reviews tokenised with a BERT tokenizer.

    Every stored sample is `[token_ids, dependency_tree, label]`, where
    `token_ids` is truncated/right-padded with the '<pad>' id to 256 entries.
    Samples are built once from the CSV ('Review' and 'Freshness' columns) and
    cached in `cache/rt_data_<rm_puncs>.pkl` under `base_dir`.

    Args:
        filename: path of the CSV file.
        rm_puncs: forwarded to clean_tweets(); also part of the cache file name.
        tokenizer: tokenizer to use; a 'bert-base-uncased' BertTokenizer is
            loaded when omitted.
        use_deps: when true, __getitem__ returns the words of the dependency
            tree instead of the padded token ids.
    """
    def __init__(self, filename='datasets/rotten_tomatoes_reviews.csv', rm_puncs=True, tokenizer=None, use_deps=False):
        # load data from file
        dataset_raw = pd.read_csv(filename)
        dataset_raw = dataset_raw[['Review', 'Freshness']]
        self.dataset = []
        self.use_deps = use_deps

        # define tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') if tokenizer is None else tokenizer
        self.tokenizer.add_tokens(['<pad>'])

        cache_path = base_dir+'cache/rt_data_' + str(rm_puncs) + '.pkl'
        if not os.path.exists(cache_path):
            # the pad id does not change while building, so encode it once
            pad_id = self.tokenizer.encode('<pad>')[1]

            # convert and save + save the sentence structures as well
            for data in dataset_raw.itertuples():
                text, label = self.tokenizer.encode(clean_tweets(data[1], rm_puncs)), data[2]

                # truncate to 256 ids, or right-pad up to 256
                text = text[:256] + [pad_id] * (256 - len(text))

                d_text = self.tokenizer.decode(text)
                self.dataset.append([
                    text,
                    # NOTE(review): this slice drops only the first and last
                    # character of the decoded string; confirm that is intended.
                    dependency_tree(d_text[1 : len(d_text) - 1]),
                    label
                ])

                print('[%d/%d]                ' %(data[0] + 1, len(dataset_raw)), end='\r', flush=True)
            with open(cache_path, 'wb') as fout:
                pickle.dump(self.dataset, fout)
        else:
            with open(cache_path, 'rb') as fin:
                self.dataset = pickle.load(fin)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, ind):
        """Return `(inputs, label, dependencies)` for sample `ind`.

        Without `use_deps`: the padded ids as a float32 tensor, the label and
        the stored dependency tree. With `use_deps`: the encoded tree words as
        a tensor, the label and one encoded dependency path per word.
        """
        if not self.use_deps:
            return torch.tensor(self.dataset[ind][0], dtype=torch.float32), self.dataset[ind][2], self.dataset[ind][1]

        # else, get everything from the tree
        sent = []
        dependencies = []

        # dependency_tree() returns {sentence_id: entries}; iterate the entries,
        # not the integer keys (a plain sequence of sentences is accepted too).
        tree = self.dataset[ind][1]
        sentences = tree.values() if isinstance(tree, dict) else tree

        for sentence in sentences:
            # loop on words
            for word, deps in sentence:
                sent.append(str(word))
                dependencies.append( self.tokenizer.encode(list(deps.keys())) )

        # convert to vector
        sent = torch.tensor( self.tokenizer.encode(sent) )

        return sent, self.dataset[ind][2], dependencies

In [None]:
class IMDBDataset(torch.utils.data.Dataset):
    """IMDB reviews tokenised with a BERT tokenizer.

    __init__ builds samples of the form `[token_ids, dependency_tree, label]`
    (ids truncated/right-padded with the '<pad>' id to 256, label 1 for
    'positive' and 0 otherwise) and caches them in
    `cache/imdb_data_<rm_puncs>.pkl` under `base_dir`.

    create_cache() replaces `self.dataset` with tuples
    `(word_ids, label, dependency_ids)`. Note that this moves the label from
    position 2 to position 1, while __getitem__ always reads positions 0, 2, 1.

    Args:
        filename: path of the CSV file (text in the first column, label in the second).
        rm_puncs: forwarded to clean_tweets(); also part of the cache file name.
        tokenizer: tokenizer to use; a 'bert-base-uncased' BertTokenizer is
            loaded when omitted.
        use_deps: when true, `max_K` must be given.
        max_K: number of ids kept per dependency path in create_cache().
    """
    def __init__(self, filename='datasets/IMDB Dataset.csv', rm_puncs=True, tokenizer=None, use_deps=False, max_K=None):
        # load data from file
        dataset_raw = pd.read_csv(filename)
        self.dataset = []

        # define tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') if tokenizer is None else tokenizer
        self.tokenizer.add_tokens(['<pad>'])
        self.use_deps = use_deps
        self.K = max_K

        if use_deps:
            assert max_K is not None, 'Please enter a valid K!'

        cache_path = base_dir+'cache/imdb_data_' + str(rm_puncs) + '.pkl'
        if not os.path.exists(cache_path):
            # the pad id does not change while building, so encode it once
            pad_id = self.tokenizer.encode('<pad>')[1]

            # convert and save + save the sentence structures as well
            for data in dataset_raw.itertuples():
                text, label = self.tokenizer.encode(clean_tweets(data[1], rm_puncs)), data[2]

                # truncate to 256 ids, or right-pad up to 256
                text = text[:256] + [pad_id] * (256 - len(text))

                d_text = self.tokenizer.decode(text)
                self.dataset.append([
                    text,
                    dependency_tree(d_text[1 : len(d_text) - 1]),
                    1 if label == 'positive' else 0
                ])

                print('[%d/%d]                ' %(data[0] + 1, len(dataset_raw)), end='\r', flush=True)
            with open(cache_path, 'wb') as fout:
                pickle.dump(self.dataset, fout)
        else:
            with open(cache_path, 'rb') as fin:
                self.dataset = pickle.load(fin)
            # self.create_cache(rm_puncs)

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _path_words(deps):
        """Return the dependency-path words stored for one token.

        Accepts either a string-encoded path (words are pulled out with the
        original regex, e.g. "[('this', 'DT')]" -> ['this']) or a mapping /
        iterable whose items are the path words.
        """
        if isinstance(deps, str):
            pattern = r"\(\'\S+\'" # \'\S+\'\))"
            return [w[2:-1] for w in re.findall(pattern, deps)]
        return [str(w) for w in deps]

    def create_cache(self, rm_puncs):
        """Convert `self.dataset` to `(word_ids, label, dependency_ids)` tuples.

        The converted data is stored in (or loaded from)
        `cache/imdb_data_<rm_puncs>_cache.pkl`. Each sample gets 256 word ids
        and 256 rows of `self.K` dependency ids.
        NOTE(review): '<sos>' is never added to the tokenizer in this class, so
        the id used for the first row is whatever encode('<sos>')[1] yields.
        """
        # lol this function creates a cache from another cache :|
        cache_path = base_dir+'cache/imdb_data_' + str(rm_puncs) + '_cache' + '.pkl'
        self.new_dataset = []
        if not os.path.exists(cache_path):
            pad_id = self.tokenizer.encode('<pad>')[1]
            sos_id = self.tokenizer.encode('<sos>')[1]

            for ind in range(len(self.dataset)):
                sent = []
                dependencies = [ [sos_id] * self.K ]

                # accept both {sentence_id: entries} and a plain sequence of sentences
                tree = self.dataset[ind][1]
                sentences = tree.values() if isinstance(tree, dict) else tree

                for sentence in sentences:
                    # loop on words
                    for word, deps in sentence:
                        sent.append(str(word))
                        a = self._path_words(deps)
                        if a != []:
                            deps = self.tokenizer.encode(a)[1:]
                        else:
                            deps = [pad_id] * self.K

                        # truncate to K ids, or right-pad up to K
                        deps = deps[:self.K] + [pad_id] * (self.K - len(deps))

                        dependencies.append(deps)

                dependencies.append([pad_id] * self.K)
                # convert to vector
                sent = self.tokenizer.encode(sent)
                if len(sent) > 256:
                    sent = sent[:256]
                    dependencies = dependencies[:256]
                else:
                    while len(sent) < 256:
                        sent.append(pad_id)
                        dependencies.append([pad_id] * self.K)

                self.new_dataset.append( (sent, self.dataset[ind][2], dependencies) )
                if ind % 1000 == 0:
                    print('[%d/%d]        ' %(ind + 1, len(self.dataset)))
            self.dataset = self.new_dataset
            with open(cache_path, 'wb') as fout:
                pickle.dump(self.dataset, fout)
        else:
            with open(cache_path, 'rb') as fin:
                self.dataset = pickle.load(fin)

    def __getitem__(self, ind):
        # Returns (float32 tensor of entry 0, entry 2, tensor of entry 1).
        return torch.tensor(self.dataset[ind][0], dtype=torch.float32), self.dataset[ind][2], torch.tensor(self.dataset[ind][1])

In [None]:
class IMDBDataset_NEW(torch.utils.data.Dataset):
    """IMDB reviews (per-split CSV files) tokenised with a BERT tokenizer.

    __init__ builds samples of the form `[token_ids, dependency_tree, label]`
    (ids truncated/right-padded with the '<pad>' id to 256, label 1 for
    'positive' and 0 otherwise). The cache file name contains the CSV file
    name, so the full, train and test files no longer share (and overwrite)
    one pickle.

    create_cache() replaces `self.dataset` with tuples
    `(word_ids, label, dependency_ids)`. Note that this moves the label from
    position 2 to position 1, while __getitem__ always reads positions 0, 2, 1.

    Args:
        filename: path of the CSV file (text in the first column, label in the second).
        rm_puncs: forwarded to clean_tweets(); also part of the cache file name.
        tokenizer: tokenizer to use; a 'bert-base-uncased' BertTokenizer is
            loaded when omitted.
        use_deps: when true, `max_K` must be given.
        max_K: number of ids kept per dependency path in create_cache().
        bert_mode: selects the `(ids, mask, label)` output of __getitem__.
    """
    def __init__(self, filename='datasets/aclimdb_train.csv', rm_puncs=True, tokenizer=None, use_deps=False, max_K=None, bert_mode=False):
        # load data from file
        dataset_raw = pd.read_csv(filename)
        self.filename = filename
        self.dataset = []
        self.bert_mode = bert_mode

        # define tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') if tokenizer is None else tokenizer
        self.tokenizer.add_tokens(['<pad>'])
        self.use_deps = use_deps
        self.K = max_K

        if use_deps:
            assert max_K is not None, 'Please enter a valid K!'

        cache_path = self._cache_path(rm_puncs)
        if not os.path.exists(cache_path):
            # the pad id does not change while building, so encode it once
            pad_id = self.tokenizer.encode('<pad>')[1]

            # convert and save + save the sentence structures as well
            for data in dataset_raw.itertuples():
                text, label = self.tokenizer.encode(clean_tweets(data[1], rm_puncs)), data[2]

                # truncate to 256 ids, or right-pad up to 256
                text = text[:256] + [pad_id] * (256 - len(text))

                d_text = self.tokenizer.decode(text)
                self.dataset.append([
                    text,
                    dependency_tree(d_text[1 : len(d_text) - 1]),
                    1 if label == 'positive' else 0
                ])

                print('[%d/%d]                ' %(data[0] + 1, len(dataset_raw)), end='\r', flush=True)
            with open(cache_path, 'wb') as fout:
                pickle.dump(self.dataset, fout)
        else:
            with open(cache_path, 'rb') as fin:
                self.dataset = pickle.load(fin)
            # self.create_cache(rm_puncs)

    def _cache_path(self, rm_puncs, suffix=''):
        """Cache file for this CSV: cache/imdb_data_<csv stem>_<rm_puncs><suffix>.pkl."""
        stem = os.path.splitext(os.path.basename(self.filename))[0]
        return base_dir + 'cache/imdb_data_' + stem + '_' + str(rm_puncs) + suffix + '.pkl'

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _path_words(deps):
        """Return the dependency-path words stored for one token.

        Accepts either a string-encoded path (words are pulled out with the
        original regex, e.g. "[('this', 'DT')]" -> ['this']) or a mapping /
        iterable whose items are the path words.
        """
        if isinstance(deps, str):
            pattern = r"\(\'\S+\'" # \'\S+\'\))"
            return [w[2:-1] for w in re.findall(pattern, deps)]
        return [str(w) for w in deps]

    def create_cache(self, rm_puncs):
        """Convert `self.dataset` to `(word_ids, label, dependency_ids)` tuples.

        The converted data is stored in (or loaded from) the per-CSV
        '..._cache.pkl' file. Each sample gets 256 word ids and 256 rows of
        `self.K` dependency ids.
        NOTE(review): '<sos>' is never added to the tokenizer in this class, so
        the id used for the first row is whatever encode('<sos>')[1] yields.
        """
        # lol this function creates a cache from another cache :|
        cache_path = self._cache_path(rm_puncs, '_cache')
        self.new_dataset = []
        if not os.path.exists(cache_path):
            pad_id = self.tokenizer.encode('<pad>')[1]
            sos_id = self.tokenizer.encode('<sos>')[1]

            for ind in range(len(self.dataset)):
                sent = []
                dependencies = [ [sos_id] * self.K ]

                # accept both {sentence_id: entries} and a plain sequence of sentences
                tree = self.dataset[ind][1]
                sentences = tree.values() if isinstance(tree, dict) else tree

                for sentence in sentences:
                    # loop on words
                    for word, deps in sentence:
                        sent.append(str(word))
                        a = self._path_words(deps)
                        if a != []:
                            deps = self.tokenizer.encode(a)[1:]
                        else:
                            deps = [pad_id] * self.K

                        # truncate to K ids, or right-pad up to K
                        deps = deps[:self.K] + [pad_id] * (self.K - len(deps))

                        dependencies.append(deps)

                dependencies.append([pad_id] * self.K)
                # convert to vector
                sent = self.tokenizer.encode(sent)
                if len(sent) > 256:
                    sent = sent[:256]
                    dependencies = dependencies[:256]
                else:
                    while len(sent) < 256:
                        sent.append(pad_id)
                        dependencies.append([pad_id] * self.K)

                self.new_dataset.append( (sent, self.dataset[ind][2], dependencies) )
                if ind % 1000 == 0:
                    print('[%d/%d]        ' %(ind + 1, len(self.dataset)))
            self.dataset = self.new_dataset
            with open(cache_path, 'wb') as fout:
                pickle.dump(self.dataset, fout)
        else:
            print('loading from cache')
            with open(cache_path, 'rb') as fin:
                self.dataset = pickle.load(fin)

    def __getitem__(self, ind):
        """Return one sample.

        Without `bert_mode`: (float32 tensor of entry 0, entry 2, tensor of entry 1).
        With `bert_mode`: `(input_ids, attention_mask, label)` where '<pad>'
        ids are replaced by 0 (mask 0.0 there, 1.0 elsewhere) and every token
        that decodes to one of ',.!?' is replaced by a randomly chosen one of
        those four punctuation marks.
        """
        if not self.bert_mode:
            return torch.tensor(self.dataset[ind][0], dtype=torch.float32), self.dataset[ind][2], torch.tensor(self.dataset[ind][1])

        pad_id = self.tokenizer.encode('<pad>')[1]
        # Work on a copy: editing the cached list in place would turn the pads
        # into 0 permanently and make the mask all ones on the next access.
        token_ids = list(self.dataset[ind][0])
        seq_mask = [float(tok != pad_id) for tok in token_ids]

        puncs = ',.!?'
        # `pos` (not `ind`) so the sample index is still valid for the label below
        for pos, tok in enumerate(token_ids):
            if tok == pad_id:
                token_ids[pos] = 0
            elif self.tokenizer.decode([tok]) in tuple(puncs):
                token_ids[pos] = self.tokenizer.encode( puncs[random.randint(0, 3)] )[1]

        # Convert to tensors.
        prediction_inputs = torch.tensor(token_ids)
        prediction_masks = torch.tensor(seq_mask)

        return prediction_inputs, prediction_masks, torch.tensor(self.dataset[ind][2])

In [None]:
# Passed as `rm_puncs` to the dataset constructors in the cells below.
punc = True

In [None]:
# # test classes

# dataset = IMDBDataset(base_dir+'datasets/IMDB Dataset.csv', punc, use_deps=True, max_K=5)
# train_size = int(0.5 * len(dataset))
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# valid_size = int(0.2 * len(train_dataset))
# valid_dataset, _ = torch.utils.data.random_split(train_dataset, [valid_size, train_size - valid_size])
# del _

# # convert data to batches
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128)
# valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=128)

# print('Train = %d (%d) | Test = %d (%d) | Validation = %d (%d)' %(len(train_dataset), len(train_loader), len(test_dataset), len(test_loader), len(valid_dataset), len(valid_loader)))

In [None]:
# Dataset built from the complete CSV; later cells use its tokenizer and vocabulary.
dataset = IMDBDataset_NEW(base_dir+'datasets/aclimdb.csv', punc, use_deps=True, max_K=5, bert_mode=True)

In [None]:
# Dataset built from the training CSV.
train_dataset = IMDBDataset_NEW(base_dir+'datasets/aclimdb_train.csv', punc, use_deps=True, max_K=5, bert_mode=True)

In [None]:
# Dataset built from the test CSV.
test_dataset = IMDBDataset_NEW(base_dir+'datasets/aclimdb_test.csv', punc, use_deps=True, max_K=5, bert_mode=True)

In [None]:
# Build the train / validation / test loaders.
#
# 20% of the training data is held out for validation and the model is trained
# on the remaining 80% (`train_subset`). Before, the train loader still covered
# the whole of `train_dataset`, so every validation example was also trained on.
# `train_dataset` itself is left untouched, which keeps this cell re-runnable.
train_size = len(train_dataset)
valid_size = int(0.2 * train_size)
valid_dataset, train_subset = torch.utils.data.random_split(train_dataset, [valid_size, train_size - valid_size])

# convert data to batches
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=32)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32)

print('Train = %d (%d) | Test = %d (%d) | Validation = %d (%d)' %(len(train_subset), len(train_loader), len(test_dataset), len(test_loader), len(valid_dataset), len(valid_loader)))

dataset[0]

# BiGRU Classifier Here
This model can be used with the same training code as the "+Attention" one.

In [None]:
class biGRU_classifier(nn.Module):
    """Bidirectional GRU sentence classifier on 100-d GloVe-initialised embeddings.

    The embedding matrix is built once from 'datasets/glove.6B.100d.txt' and
    cached in 'cache/biGRU_embeddings.pt' (both under `base_dir`): rows start
    as random normal values and the row at `vocab.encode(word)[1]` is
    overwritten with the GloVe vector of each word in the file.

    Args:
        vocab_size: number of rows of the embedding table.
        pad_idx: padding index handed to nn.Embedding.
        vocab: tokenizer-like object with `encode(word)`; only used when the
            embedding cache has to be built.

    NOTE(review): the classifier ends in Softmax while the training cells use
    nn.CrossEntropyLoss, which applies log-softmax itself; confirm this is intended.
    """
    def __init__(self, vocab_size, pad_idx, vocab):
        super(biGRU_classifier, self).__init__()

        # load glove and initialize
        emb_path = base_dir+'cache/biGRU_embeddings.pt'
        if not os.path.exists(emb_path):
            # create embedding matrix, streaming the GloVe file line by line
            # instead of first materialising every vector in a dictionary
            emb_mat = torch.randn((vocab_size, 100))
            with open(base_dir+'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
                for line in fin:
                    parts = line.rstrip().split(' ')
                    emb_mat[vocab.encode(parts[0])[1]] = torch.tensor(list(map(float, parts[1:])))
            torch.save(emb_mat, emb_path)
        else:
            emb_mat = torch.load(emb_path)

        # define layers
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        with torch.no_grad():
            self.emb.weight.copy_(emb_mat)
        self.bi_GRU = nn.GRU(100, 256, 1, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Return class probabilities of shape [batch, 2] for token ids `x` of shape [batch, seq_len]."""
        x = self.emb(x.long())
        _, hidden = self.bi_GRU(x)
        # concatenate the final forward and backward hidden states -> [batch, 512]
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)

        return self.classifier(hidden)

    def forward_grad(self, x):
        """Same computation as forward(); kept as a separate entry point for existing callers."""
        return self.forward(x)

# Create the BiGRU+Attention Classifier 

In [None]:
class AttentionLayer(nn.Module):
    """Single-vector attention over a sequence of hidden states.

    NOTE(review): a class with the same name is defined again in a later cell
    of this notebook; whichever cell ran last is the one in effect.

    Args:
        emb_dim: size of the last dimension of the input.
        device: device on which the attention vector is created.
    """
    def __init__(self, emb_dim, device):
        super(AttentionLayer, self).__init__()

        # define attention layer here
        # Registered as a Parameter so that it is returned by .parameters()
        # (and therefore updated by the optimizer) and stored in state_dict();
        # a plain tensor attribute is neither trained nor saved.
        self.att_layer = nn.Parameter(torch.randn((emb_dim, 1)).to(device))

    def forward(self, x):
        """Return `(scores, context)` for `x` of shape [batch, seq_len, emb_size].

        `scores` has shape [batch, seq_len, 1] and sums to 1 over seq_len;
        `context` is tanh of the score-weighted sum of `x`, shape [batch, emb_size].
        """
        # [batch, seq_len, emb] @ [emb, 1] broadcasts over the batch dimension
        scores = torch.matmul(torch.tanh(x), self.att_layer)
        scores = torch.softmax(scores, dim=-2)
        context = torch.tanh(torch.bmm(x.transpose(1, 2), scores).squeeze(-1))
        return scores, context
    
class biGRUAttn_classifier(nn.Module):
    """Bidirectional GRU + attention sentence classifier on GloVe-initialised embeddings.

    The embedding matrix is built once from 'datasets/glove.6B.100d.txt' and
    cached in 'cache/biGRU_embeddings.pt' (both under `base_dir`): rows start
    as random normal values and the row at `vocab.encode(word)[1]` is
    overwritten with the GloVe vector of each word in the file.

    Args:
        vocab_size: number of rows of the embedding table.
        pad_idx: padding index handed to nn.Embedding.
        vocab: tokenizer-like object with `encode(word)`; only used when the
            embedding cache has to be built.

    NOTE(review): the classifier ends in Softmax while the training cells use
    nn.CrossEntropyLoss, which applies log-softmax itself; confirm this is intended.
    """
    def __init__(self, vocab_size, pad_idx, vocab):
        super(biGRUAttn_classifier, self).__init__()

        # load glove and initialize
        emb_path = base_dir+'cache/biGRU_embeddings.pt'
        if not os.path.exists(emb_path):
            # create embedding matrix, streaming the GloVe file line by line
            # instead of first materialising every vector in a dictionary
            emb_mat = torch.randn((vocab_size, 100))
            with open(base_dir+'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
                for line in fin:
                    parts = line.rstrip().split(' ')
                    emb_mat[vocab.encode(parts[0])[1]] = torch.tensor(list(map(float, parts[1:])))
            torch.save(emb_mat, emb_path)
        else:
            emb_mat = torch.load(emb_path)

        # define layers
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        with torch.no_grad():
            self.emb.weight.copy_(emb_mat)
        self.bi_GRU = nn.GRU(100, 256, 1, batch_first=True, bidirectional=True)
        self.attention = AttentionLayer(512, device)
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Return `(attention_scores, class_probabilities)` for token ids `x` of shape [batch, seq_len]."""
        x = self.emb(x.long())
        # only the per-step outputs feed the attention; the final hidden state is not needed
        outputs, _ = self.bi_GRU(x)

        # pass to the attention module
        scores, context = self.attention(outputs)

        return scores, self.classifier(context)

In [None]:
def test_on_data(data_loader, model):
    model.eval()
    true_labels = []
    pred_labels = []
    
    for i, data in enumerate(data_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels, _ = data

        # forward + backward + optimize
        _, outputs = model(inputs.to(device))
        # outputs = model(inputs.to(device))
        
        true_labels.extend(labels.tolist())
        pred_labels.extend(torch.argmax(outputs, dim=-1).tolist())
        
    # get confusion matrix
    tn, fp, fn, tp = confusion_matrix(true_labels, pred_labels).ravel()
    return (tp + tn) / (tn + fp + fn + tp)

In [None]:
# Train the BiGRU+Attention classifier, or load it when a checkpoint exists.
model_name = 'biGRUAttn_model_newIMDB_False.pt'
# one value for both the loop and the progress message (which used to print "/30")
num_epochs = 20

if not os.path.exists(base_dir+'cache/' + model_name):
    vocab_size = len(dataset.tokenizer.get_vocab())
    net = biGRUAttn_classifier(vocab_size, dataset.tokenizer.encode('<pad>')[1], dataset.tokenizer).to(device)
    # NOTE(review): the model already ends in Softmax and CrossEntropyLoss
    # applies log-softmax again; confirm this is intended.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    print('Start training ...')
    for epoch in range(num_epochs):
        net.train()

        # each batch is (inputs, labels, extra)
        for inputs, labels, _ in train_loader:
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            _, outputs = net(inputs.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
        # test on data
        with torch.no_grad():
            print('[%d/%d] Validation = %0.4f | Test = %0.4f' %(
                epoch + 1,
                num_epochs,
                test_on_data(valid_loader, net),
                test_on_data(test_loader, net)
            ))

    print('Finished Training')
    torch.save(net.state_dict(), base_dir+'cache/' + model_name)
else:
    vocab_size = len(dataset.tokenizer.get_vocab())
    net = biGRUAttn_classifier(vocab_size, dataset.tokenizer.encode('<pad>')[1], dataset.tokenizer).to(device)
    # map_location lets a checkpoint saved on another device load onto `device`
    net.load_state_dict(torch.load(base_dir+'cache/' + model_name, map_location=device))
    print('Model Loaded')
    print(test_on_data(test_loader, net))
    print(test_on_data(train_loader, net))

In [None]:
def test_on_data_gru(data_loader, model):
    model.eval()
    true_labels = []
    pred_labels = []
    
    for i, data in enumerate(data_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels, _ = data

        # forward + backward + optimize
#         _, outputs = model(inputs.to(device))
        outputs = model(inputs.to(device))
        
        true_labels.extend(labels.tolist())
        pred_labels.extend(torch.argmax(outputs, dim=-1).tolist())
        
    # get confusion matrix
    tn, fp, fn, tp = confusion_matrix(true_labels, pred_labels).ravel()
    return (tp + tn) / (tn + fp + fn + tp)

In [None]:
# Train the plain BiGRU classifier, or load it when a checkpoint already exists.
model_name = 'biGRU_model_newIMDB_False.pt'
checkpoint_path = base_dir+'cache/' + model_name

# needed by both branches
vocab_size = len(dataset.tokenizer.get_vocab())

if os.path.exists(checkpoint_path):
    # Checkpoint found: rebuild the network, restore the weights and report accuracy.
    net = biGRU_classifier(vocab_size, dataset.tokenizer.encode('<pad>')[1], dataset.tokenizer).float().to(device)
    net.load_state_dict(torch.load(checkpoint_path))
    print('Model Loaded')
    print(test_on_data_gru(test_loader, net))
    print(test_on_data_gru(train_loader, net))
else:
    net = biGRU_classifier(vocab_size, dataset.tokenizer.encode('<pad>')[1], dataset.tokenizer).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    print('Start training ...')

    epoch = 0
    while epoch < 30:
        net.train()

        for batch in train_loader:
            # a batch is (inputs, labels, extra); the third item is not used
            inputs, labels = batch[0], batch[1]

            # reset gradients, then forward / backward / update
            optimizer.zero_grad()
            outputs = net(inputs.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

        # accuracy after this epoch
        with torch.no_grad():
            valid_acc = test_on_data_gru(valid_loader, net)
            test_acc = test_on_data_gru(test_loader, net)
            print('[%d/%d] Validation = %0.4f | Test = %0.4f' %(epoch + 1, 30, valid_acc, test_acc))

        epoch += 1

    print('Finished Training')
    torch.save(net.state_dict(), checkpoint_path)

# The Proposed Model

PunClassifier_1: uses BiGRU + Attention for the sentence embedding and a simple BiGRU for the structure embedding + a neural network for combining the two embeddings



In [None]:
class AttentionLayer(nn.Module):
    """Single-vector attention over a sequence of hidden states.

    NOTE(review): a class with the same name is already defined in an earlier
    cell of this notebook; once this cell runs, this definition replaces it.

    Args:
        emb_dim: size of the last dimension of the input.
        device: device on which the attention vector is created.
    """
    def __init__(self, emb_dim, device):
        super(AttentionLayer, self).__init__()

        # define attention layer here
        # Registered as a Parameter so that it is returned by .parameters()
        # (and therefore updated by the optimizer) and stored in state_dict();
        # a plain tensor attribute is neither trained nor saved.
        self.att_layer = nn.Parameter(torch.randn((emb_dim, 1)).to(device))

    def forward(self, x):
        """Return `(scores, context)` for `x` of shape [batch, seq_len, emb_size].

        `scores` has shape [batch, seq_len, 1] and sums to 1 over seq_len;
        `context` is tanh of the score-weighted sum of `x`, shape [batch, emb_size].
        """
        # [batch, seq_len, emb] @ [emb, 1] broadcasts over the batch dimension
        scores = torch.matmul(torch.tanh(x), self.att_layer)
        scores = torch.softmax(scores, dim=-2)
        context = torch.tanh(torch.bmm(x.transpose(1, 2), scores).squeeze(-1))
        return scores, context

class BaseEmbedding_1(nn.Module):
    """Encode the K dependency ids of every word with a small bidirectional GRU.

    The embedding matrix is loaded from 'cache/biGRU_embeddings.pt' under
    `base_dir`; when that file does not exist it is built from
    'datasets/glove.6B.100d.txt', which requires `vocab`.

    Args:
        vocab_size: number of rows of the embedding table.
        pad_idx: padding index handed to nn.Embedding.
        vocab: optional tokenizer-like object with `encode(word)`. It is only
            needed to build the embedding cache; the original signature had no
            such argument although the build branch referenced `vocab`.
    """
    def __init__(self, vocab_size, pad_idx, vocab=None):
        super(BaseEmbedding_1, self).__init__()

        # load glove and initialize
        emb_path = base_dir+'cache/biGRU_embeddings.pt'
        if not os.path.exists(emb_path):
            if vocab is None:
                raise ValueError('biGRU_embeddings.pt is not cached yet: pass `vocab` to build it')

            # create embedding matrix, streaming the GloVe file line by line
            emb_mat = torch.randn((vocab_size, 100))
            with open(base_dir+'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
                for line in fin:
                    parts = line.rstrip().split(' ')
                    emb_mat[vocab.encode(parts[0])[1]] = torch.tensor(list(map(float, parts[1:])))
            torch.save(emb_mat, emb_path)
        else:
            emb_mat = torch.load(emb_path)

        # define layers
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        with torch.no_grad():
            self.emb.weight.copy_(emb_mat)
        self.bi_GRU = nn.GRU(100, 10, 1, batch_first=True, bidirectional=True)

    def forward(self, deps):
        """Return structure embeddings of shape [batch_size, max_seq_len, 20].

        `deps` has shape [batch_size, max_seq_len, K]. Every length-K id
        sequence is run through the GRU and represented by its final forward
        and backward hidden states. All positions are processed in one GRU call
        (batch and sequence dimensions folded together) instead of one call per
        position; the result is created on the device of the module's weights.
        """
        batch_size, seq_len, k = deps.shape

        # [batch * seq_len, K] -> [batch * seq_len, K, 100]
        x = self.emb(deps.long().reshape(batch_size * seq_len, k))
        _, hidden = self.bi_GRU(x)
        # hidden -> [batch * seq_len, emb_size (20)]
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)

        return hidden.reshape(batch_size, seq_len, -1)


class BaseClassifier_1(nn.Module):
    """BiGRU + attention classifier over [word embedding ; dependency-path vector].

    Args:
        vocab_size: number of rows of the embedding table.
        pad_idx: id of the padding token.
        vocab: tokenizer exposing ``encode(word)``; used only to build the
            cached GloVe matrix when it does not exist yet.

    ``forward`` returns ``(attention_weights, class_probabilities)``. The head
    ends in ``Softmax``, so the second element holds probabilities, not logits.
    """

    def __init__(self, vocab_size, pad_idx, vocab):
        super(BaseClassifier_1, self).__init__()

        # load glove and initialize (this also creates the cache file that the
        # dependency-path encoder below reads)
        emb_mat = self._load_embedding_matrix(vocab_size, vocab)

        # The GRU input width depends on the encoder's output width. Reading
        # it from the encoder keeps this class working with whichever
        # `BaseEmbedding_1` definition is currently in the notebook namespace.
        structure_encoder = BaseEmbedding_1(vocab_size, pad_idx)
        deps_dim = 2 * structure_encoder.bi_GRU.hidden_size

        # define layers
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        self.emb.weight.data.copy_(emb_mat)
        self.bi_GRU = nn.GRU(100 + deps_dim, 256, 1, batch_first=True, bidirectional=True)
        self.attention = AttentionLayer(512, device)
        self.base_embedding = structure_encoder

        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=-1)
        )

    @staticmethod
    def _load_embedding_matrix(vocab_size, vocab):
        """Return the [vocab_size, 100] GloVe matrix, building the cache if needed."""
        cache_path = base_dir + 'cache/biGRU_embeddings.pt'
        if os.path.exists(cache_path):
            return torch.load(cache_path)

        # Rows without a GloVe vector keep their random initialisation.
        emb_mat = torch.randn((vocab_size, 100))
        with open(base_dir + 'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
            # Stream the file instead of materialising it with readlines().
            for line in fin:
                fields = line.rstrip('\n').split(' ')
                emb_mat[vocab.encode(fields[0])[1]] = torch.tensor(
                    list(map(float, fields[1:])))
        torch.save(emb_mat, cache_path)
        return emb_mat

    def forward(self, x, deps):
        # x -> [batch_size, max_seq_len]        | deps -> [batch_size, max_seq_len, K]
        x = self.emb(x.long())

        # one structure vector per token, appended to the word embedding
        deps = self.base_embedding(deps)
        x = torch.cat((x, deps), dim=-1)

        # only the per-position outputs are needed by the attention module
        outputs, _ = self.bi_GRU(x)
        scores, context = self.attention(outputs)

        return scores, self.classifier(context)

In [None]:
class AttentionLayer(nn.Module):
    """Attention pooling over a sequence of hidden states.

    Args:
        emb_dim: size of the last dimension of the input.
        device: device on which the attention vector is created.

    ``forward`` returns ``(weights, context)`` with ``weights`` of shape
    [batch, seq_len, 1] (summing to 1 over seq_len) and ``context`` of shape
    [batch, emb_dim].
    """

    def __init__(self, emb_dim, device):
        super(AttentionLayer, self).__init__()

        # Registered as an nn.Parameter so that it is returned by
        # `parameters()` (and therefore optimised), follows `.to()` and is
        # written to `state_dict()`. A plain tensor is none of these, so the
        # attention vector would stay at its random initial value.
        # NOTE: state dicts saved before this change have no `att_layer`
        # entry and need `load_state_dict(..., strict=False)`.
        self.att_layer = nn.Parameter(torch.randn((emb_dim, 1)).to(device))

    def forward(self, x):
        # x.shape = [batch, seq_len, emb_size]
        # matmul broadcasts the [emb_size, 1] vector over the batch dimension
        scores = torch.matmul(torch.tanh(x), self.att_layer)
        weights = torch.softmax(scores, dim=-2)
        context = torch.bmm(x.transpose(1, 2), weights).squeeze(-1)
        return weights, torch.tanh(context)

class BaseEmbedding_1(nn.Module):
    """Encode every token's dependency path into one fixed-size vector.

    Each token carries ``K`` vocabulary ids. The ids are embedded with
    GloVe-initialised vectors and read by a bidirectional GRU; the final
    forward and backward hidden states are concatenated, giving a
    ``2 * hidden_size`` (= 100) dimensional vector per token.

    Args:
        vocab_size: number of rows of the embedding table.
        pad_idx: id of the padding token (forwarded to ``nn.Embedding``).
        vocab: optional tokenizer exposing ``encode(word)``. Only needed when
            the cached embedding matrix does not exist yet; when omitted, a
            module-level ``vocab`` is used if one is defined.
    """

    def __init__(self, vocab_size, pad_idx, vocab=None):
        super(BaseEmbedding_1, self).__init__()

        # load glove and initialize
        emb_mat = self._load_embedding_matrix(vocab_size, vocab)

        # define layers
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        self.emb.weight.data.copy_(emb_mat)
        self.bi_GRU = nn.GRU(100, 50, 1, batch_first=True, bidirectional=True)

    @staticmethod
    def _load_embedding_matrix(vocab_size, vocab=None):
        """Return the [vocab_size, 100] GloVe matrix, building the cache if needed."""
        cache_path = base_dir + 'cache/biGRU_embeddings.pt'
        if os.path.exists(cache_path):
            return torch.load(cache_path)

        if vocab is None:
            # Backward compatibility: the original code read a global `vocab`.
            vocab = globals().get('vocab')
        if vocab is None:
            raise ValueError(
                'No cached embedding matrix at %r: pass `vocab` so it can be '
                'built from the GloVe file.' % cache_path)

        # Rows without a GloVe vector keep their random initialisation.
        emb_mat = torch.randn((vocab_size, 100))
        with open(base_dir + 'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
            # Stream the file instead of materialising it with readlines().
            for line in fin:
                fields = line.rstrip('\n').split(' ')
                # NOTE(review): index 1 is presumably the first real token
                # after a leading special token -- confirm for the tokenizer.
                emb_mat[vocab.encode(fields[0])[1]] = torch.tensor(
                    list(map(float, fields[1:])))
        torch.save(emb_mat, cache_path)
        return emb_mat

    def forward(self, deps):
        """Map ``deps`` [batch_size, max_seq_len, K] to [batch_size, max_seq_len, 100]."""
        batch_size, seq_len, path_len = deps.shape

        # Every (sentence, position) pair is an independent length-K sequence,
        # so fold both leading dims together and run the GRU once instead of
        # once per position.
        flat_paths = deps.reshape(batch_size * seq_len, path_len).long()
        _, hidden = self.bi_GRU(self.emb(flat_paths))

        # hidden -> [2, batch_size * max_seq_len, hidden_size]
        path_vectors = torch.cat((hidden[0], hidden[1]), dim=-1)
        return path_vectors.view(batch_size, seq_len, -1)


class BaseClassifier_2(nn.Module):
    """BiGRU + attention classifier over a projected [word ; dependency-path] vector.

    The word embedding and the dependency-path vector are concatenated and
    projected back to 100 dimensions by ``emb_nn`` before entering the GRU.

    Args:
        vocab_size: number of rows of the embedding table.
        pad_idx: id of the padding token.
        vocab: tokenizer exposing ``encode(word)``; used only to build the
            cached GloVe matrix when it does not exist yet.

    ``forward`` returns ``(attention_weights, class_probabilities)``. The head
    ends in ``Softmax``, so the second element holds probabilities, not logits.
    """

    def __init__(self, vocab_size, pad_idx, vocab):
        super(BaseClassifier_2, self).__init__()

        # load glove and initialize (this also creates the cache file that the
        # dependency-path encoder below reads)
        emb_mat = self._load_embedding_matrix(vocab_size, vocab)

        # define layers
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        self.emb.weight.data.copy_(emb_mat)

        self.bi_GRU = nn.GRU(100, 256, 1, batch_first=True, bidirectional=True)
        self.attention = AttentionLayer(512, device)
        self.base_embedding = BaseEmbedding_1(vocab_size, pad_idx)
        # Size the projection from the encoder instead of hard-coding 200, so
        # the two layers cannot drift apart.
        deps_dim = 2 * self.base_embedding.bi_GRU.hidden_size
        self.emb_nn = nn.Linear(100 + deps_dim, 100)

        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=-1)
        )

    @staticmethod
    def _load_embedding_matrix(vocab_size, vocab):
        """Return the [vocab_size, 100] GloVe matrix, building the cache if needed."""
        cache_path = base_dir + 'cache/biGRU_embeddings.pt'
        if os.path.exists(cache_path):
            return torch.load(cache_path)

        # Rows without a GloVe vector keep their random initialisation.
        emb_mat = torch.randn((vocab_size, 100))
        with open(base_dir + 'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
            # Stream the file instead of materialising it with readlines().
            for line in fin:
                fields = line.rstrip('\n').split(' ')
                emb_mat[vocab.encode(fields[0])[1]] = torch.tensor(
                    list(map(float, fields[1:])))
        torch.save(emb_mat, cache_path)
        return emb_mat

    def forward(self, x, deps):
        # x -> [batch_size, max_seq_len]        | deps -> [batch_size, max_seq_len, K]
        x = self.emb(x.long())

        # one structure vector per token, appended to the word embedding
        deps = self.base_embedding(deps)
        x = torch.cat((x, deps), dim=-1)
        x = self.emb_nn(x)

        # only the per-position outputs are needed by the attention module
        outputs, _ = self.bi_GRU(x)
        scores, context = self.attention(outputs)

        return scores, self.classifier(context)

In [None]:
def test_on_data(data_loader, model):
    model.eval()
    true_labels = []
    pred_labels = []
    
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels, deps = data

            # forward + backward + optimize
            _, outputs = net(inputs.to(device), deps.to(device))
            
            true_labels.extend(labels.tolist())
            pred_labels.extend(torch.argmax(outputs, dim=-1).tolist())

    # get confusion matrix
    tn, fp, fn, tp = confusion_matrix(true_labels, pred_labels).ravel()
    return (tp + tn) / (tn + fp + fn + tp)

In [None]:
model_name = 'SEDT_newIMDB_True.pt'
model_path = base_dir + 'cache/' + model_name
num_epochs = 10

# Both branches need the same architecture, so build it once.
vocab_size = len(dataset.tokenizer.get_vocab())
net = BaseClassifier_1(vocab_size, dataset.tokenizer.encode('<pad>')[1], dataset.tokenizer).to(device)

if not os.path.exists(model_path):
    # The classifier head ends in Softmax, i.e. it returns probabilities.
    # nn.CrossEntropyLoss expects raw logits and would apply a second
    # (log-)softmax on top, so use NLLLoss on log-probabilities instead.
    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
    print('Start training ...')
    for epoch in range(num_epochs):
        # test_on_data() switches to eval mode, so re-enable training mode
        # at the start of every epoch.
        net.train()

        for inputs, labels, deps in train_loader:
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            _, outputs = net(inputs.to(device), deps.to(device))
            # clamp avoids log(0) when a probability underflows
            loss = criterion(torch.log(outputs.clamp_min(1e-8)), labels.to(device))
            loss.backward()
            optimizer.step()

        # report accuracy after every epoch (the total is the real epoch count)
        print('[%d/%d] Validation = %0.4f | Test = %0.4f' % (
            epoch + 1,
            num_epochs,
            test_on_data(valid_loader, net),
            test_on_data(test_loader, net)
        ))

    print('Finished Training')
    torch.save(net.state_dict(), model_path)
else:
    net.load_state_dict(torch.load(model_path))
    print('Model Loaded')
    print(test_on_data(test_loader, net))

# BERT Classifier

In [None]:
class AttentionLayer(nn.Module):
    """Attention pooling over a sequence of hidden states.

    Args:
        emb_dim: size of the last dimension of the input.
        device: device on which the attention vector is created.

    ``forward`` returns ``(weights, context)`` with ``weights`` of shape
    [batch, seq_len, 1] (summing to 1 over seq_len) and ``context`` of shape
    [batch, emb_dim].
    """

    def __init__(self, emb_dim, device):
        super(AttentionLayer, self).__init__()

        # Registered as an nn.Parameter so that it is returned by
        # `parameters()` (and therefore optimised), follows `.to()` and is
        # written to `state_dict()`. A plain tensor is none of these, so the
        # attention vector would stay at its random initial value.
        # NOTE: state dicts saved before this change have no `att_layer`
        # entry and need `load_state_dict(..., strict=False)`.
        self.att_layer = nn.Parameter(torch.randn((emb_dim, 1)).to(device))

    def forward(self, x):
        # x.shape = [batch, seq_len, emb_size]
        # matmul broadcasts the [emb_size, 1] vector over the batch dimension
        scores = torch.matmul(torch.tanh(x), self.att_layer)
        weights = torch.softmax(scores, dim=-2)
        context = torch.bmm(x.transpose(1, 2), weights).squeeze(-1)
        return weights, torch.tanh(context)

class BaseEmbedding_1(nn.Module):
    """Encode every token's dependency path into one fixed-size vector.

    Each token carries ``K`` vocabulary ids. The ids are embedded with
    GloVe-initialised vectors and read by a bidirectional GRU; the final
    forward and backward hidden states are concatenated, giving a
    ``2 * hidden_size`` (= 100) dimensional vector per token.

    Args:
        vocab_size: number of rows of the embedding table.
        pad_idx: id of the padding token (forwarded to ``nn.Embedding``).
        vocab: optional tokenizer exposing ``encode(word)``. Only needed when
            the cached embedding matrix does not exist yet; when omitted, a
            module-level ``vocab`` is used if one is defined.
    """

    def __init__(self, vocab_size, pad_idx, vocab=None):
        super(BaseEmbedding_1, self).__init__()

        # load glove and initialize
        emb_mat = self._load_embedding_matrix(vocab_size, vocab)

        # define layers
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        self.emb.weight.data.copy_(emb_mat)
        self.bi_GRU = nn.GRU(100, 50, 1, batch_first=True, bidirectional=True)

    @staticmethod
    def _load_embedding_matrix(vocab_size, vocab=None):
        """Return the [vocab_size, 100] GloVe matrix, building the cache if needed."""
        cache_path = base_dir + 'cache/biGRU_embeddings.pt'
        if os.path.exists(cache_path):
            return torch.load(cache_path)

        if vocab is None:
            # Backward compatibility: the original code read a global `vocab`.
            vocab = globals().get('vocab')
        if vocab is None:
            raise ValueError(
                'No cached embedding matrix at %r: pass `vocab` so it can be '
                'built from the GloVe file.' % cache_path)

        # Rows without a GloVe vector keep their random initialisation.
        emb_mat = torch.randn((vocab_size, 100))
        with open(base_dir + 'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
            # Stream the file instead of materialising it with readlines().
            for line in fin:
                fields = line.rstrip('\n').split(' ')
                # NOTE(review): index 1 is presumably the first real token
                # after a leading special token -- confirm for the tokenizer.
                emb_mat[vocab.encode(fields[0])[1]] = torch.tensor(
                    list(map(float, fields[1:])))
        torch.save(emb_mat, cache_path)
        return emb_mat

    def forward(self, deps):
        """Map ``deps`` [batch_size, max_seq_len, K] to [batch_size, max_seq_len, 100]."""
        batch_size, seq_len, path_len = deps.shape

        # Every (sentence, position) pair is an independent length-K sequence,
        # so fold both leading dims together and run the GRU once instead of
        # once per position.
        flat_paths = deps.reshape(batch_size * seq_len, path_len).long()
        _, hidden = self.bi_GRU(self.emb(flat_paths))

        # hidden -> [2, batch_size * max_seq_len, hidden_size]
        path_vectors = torch.cat((hidden[0], hidden[1]), dim=-1)
        return path_vectors.view(batch_size, seq_len, -1)


class BERT_Classifier(nn.Module):
    """BiGRU + attention classifier over [BERT token state ; dependency-path vector].

    Args:
        vocab_size: number of rows of the GloVe embedding table.
        pad_idx: id of the padding token.
        vocab: tokenizer exposing ``encode(word)``; used only to build the
            cached GloVe matrix when it does not exist yet.
        bert_model: Hugging Face encoder called as
            ``bert_model(ids, attention_mask=..., output_hidden_states=True)``
            and exposing ``config.hidden_size``.

    ``forward`` returns ``(attention_weights, class_probabilities)``.
    """

    def __init__(self, vocab_size, pad_idx, vocab, bert_model):
        # The original passed `BaseClassifier_2` to super(), which raises a
        # TypeError because this class does not derive from it.
        super(BERT_Classifier, self).__init__()

        # load glove and initialize (this also creates the cache file that the
        # dependency-path encoder below reads)
        emb_mat = self._load_embedding_matrix(vocab_size, vocab)

        # define layers
        # NOTE: `emb` is kept from the original definition but is not used by
        # forward(); token representations come from `bert`.
        self.emb = nn.Embedding(vocab_size, 100, padding_idx=pad_idx)
        self.emb.weight.data.copy_(emb_mat)

        self.bert = bert_model

        # forward() needs the layers below; they were missing from __init__.
        # Sizes mirror BaseClassifier_2, with the projection input widened to
        # the BERT hidden size plus the structure-vector size.
        self.base_embedding = BaseEmbedding_1(vocab_size, pad_idx)
        deps_dim = 2 * self.base_embedding.bi_GRU.hidden_size
        self.emb_nn = nn.Linear(bert_model.config.hidden_size + deps_dim, 100)
        self.bi_GRU = nn.GRU(100, 256, 1, batch_first=True, bidirectional=True)
        self.attention = AttentionLayer(512, device)

        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=-1)
        )

    @staticmethod
    def _load_embedding_matrix(vocab_size, vocab):
        """Return the [vocab_size, 100] GloVe matrix, building the cache if needed."""
        cache_path = base_dir + 'cache/biGRU_embeddings.pt'
        if os.path.exists(cache_path):
            return torch.load(cache_path)

        # Rows without a GloVe vector keep their random initialisation.
        emb_mat = torch.randn((vocab_size, 100))
        with open(base_dir + 'datasets/glove.6B.100d.txt', 'r', encoding='utf-8') as fin:
            for line in fin:
                fields = line.rstrip('\n').split(' ')
                emb_mat[vocab.encode(fields[0])[1]] = torch.tensor(
                    list(map(float, fields[1:])))
        torch.save(emb_mat, cache_path)
        return emb_mat

    def forward(self, x, deps, input_mask):
        # x -> [batch_size, max_seq_len]        | deps -> [batch_size, max_seq_len, K]
        # NOTE(review): assumes `deps` is aligned with the BERT token
        # positions (same max_seq_len) -- confirm against the data pipeline.
        bert_out = self.bert(x,
                             attention_mask=input_mask,
                             output_hidden_states=True)
        # Base encoders expose `last_hidden_state`; models with a task head
        # only expose the per-layer `hidden_states` tuple.
        token_states = getattr(bert_out, 'last_hidden_state', None)
        if token_states is None:
            token_states = bert_out.hidden_states[-1]

        # one structure vector per token, appended to the BERT state
        deps = self.base_embedding(deps)
        x = torch.cat((token_states, deps), dim=-1)
        x = self.emb_nn(x)

        # only the per-position outputs are needed by the attention module
        outputs, _ = self.bi_GRU(x)
        scores, context = self.attention(outputs)

        return scores, self.classifier(context)

In [None]:
from transformers import DistilBertForSequenceClassification, AdamW, BertConfig
PRE_TRAINED_MODEL_NAME = 'distilbert-base-uncased'
# Load DistilBertForSequenceClassification from the pretrained checkpoint
# named above, configured for two output labels.
model = DistilBertForSequenceClassification.from_pretrained(
    PRE_TRAINED_MODEL_NAME, # Checkpoint name: 'distilbert-base-uncased'.
    num_labels = 2, # The number of output labels--2 for binary classification.
                    # You can increase this for multi-class tasks.   
    output_attentions = False, # Whether the model returns attentions weights.
    output_hidden_states = False, # Whether the model returns all hidden-states.
)

# Move the model to `device`. As the last expression of the cell, the
# returned module is also displayed.
model.to(device)

In [None]:
# Optimizer over all parameters of `model`.
optimizer = AdamW(model.parameters(),
                  lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5
                  eps = 1e-8 # args.adam_epsilon  - default is 1e-8.
                )
from transformers import get_linear_schedule_with_warmup

# Number of training epochs (authors recommend between 2 and 4)
epochs = 2

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_loader) * epochs

# Create the learning rate scheduler: no warm-up steps, scheduled over
# `total_steps` optimizer steps.
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

In [None]:
import numpy as np

def flat_accuracy(preds, labels):
    """Fraction of rows in ``preds`` whose arg-max column equals the label.

    Args:
        preds: 2-D numpy array of per-class scores, one row per sample.
        labels: numpy array of integer class ids; it is flattened first.

    Returns:
        Number of matching predictions divided by the number of labels.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    expected_classes = labels.flatten()
    n_matches = np.sum(predicted_classes == expected_classes)
    return n_matches / len(expected_classes)

In [None]:
# Average training loss of every epoch, in order.
loss_values = []

for epoch_i in range(0, epochs):

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')
    # Sum of the per-batch losses of this epoch.
    total_loss = 0
    model.train()
    for step, batch in enumerate(train_loader):

        # Batch layout read here: (input ids, attention mask, labels).
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        # Labels get an extra trailing dimension before being passed to the model.
        b_labels = batch[2].unsqueeze(1).to(device)

        # Clear gradients left over from the previous step.
        model.zero_grad()        

        # With `labels` supplied, the first element of the output is the loss.
        outputs = model(b_input_ids, 
                    attention_mask=b_input_mask, 
                    labels=b_labels)

        loss = outputs[0]
        total_loss += loss.item()
        loss.backward()
        # Clip the gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # The learning-rate schedule advances once per batch.
        scheduler.step()
    avg_train_loss = total_loss / len(train_loader)            
    loss_values.append(avg_train_loss)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))        

    print("")
    print("Running Validation...")

    model.eval()

    # Tracking variables 
    # NOTE(review): eval_loss and nb_eval_examples are never updated below.
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    # NOTE(review): this "validation" pass iterates `test_loader` -- confirm
    # that a separate validation loader was not intended.
    for batch in test_loader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():        
            outputs = model(b_input_ids, 
                            attention_mask=b_input_mask)
        # Without `labels`, the first element of the output is the logits.
        logits = outputs[0]
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        eval_accuracy += tmp_eval_accuracy
        nb_eval_steps += 1

    # Mean of the per-batch accuracies (each batch weighs the same).
    print("  Accuracy: {0:.4f}".format(eval_accuracy/nb_eval_steps))

print("")
print("Training complete!")

In [None]:
# Accuracy of `model` on `test_loader`, averaged over batches.
# NOTE(review): eval_loss and nb_eval_examples are never updated below, and
# model.eval() is not called in this cell -- it relies on the mode left by
# earlier cells.
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
for batch in test_loader:
    batch = tuple(t.to(device) for t in batch)
    # Batch layout read here: (input ids, attention mask, labels).
    b_input_ids, b_input_mask, b_labels = batch
    with torch.no_grad():        
        outputs = model(b_input_ids, 
                        attention_mask=b_input_mask)
    # Without `labels`, the first element of the output is the logits.
    logits = outputs[0]
    logits = logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()
    
    tmp_eval_accuracy = flat_accuracy(logits, label_ids)
    eval_accuracy += tmp_eval_accuracy
    nb_eval_steps += 1

# Mean of the per-batch accuracies (each batch weighs the same).
print("  Accuracy: {0:.4f}".format(eval_accuracy/nb_eval_steps))

In [None]:
def eval_model(model, data_loader, loss_fn, device, n_examples):
    """Evaluate ``model`` on ``data_loader``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``.
            It may return a logits tensor, an object with a ``logits``
            attribute, or a tuple whose first element is the logits.
        data_loader: iterable of batches read as
            ``(input_ids, targets, attention_mask)``.
            NOTE(review): this order differs from loops elsewhere in the
            notebook that read (ids, mask, labels) -- confirm the loader.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        device: device the batches are moved to.
        n_examples: denominator used for the accuracy.

    Returns:
        ``(accuracy, mean_loss)``: accuracy is a 0-dim double tensor,
        mean_loss the mean of the per-batch losses (NaN for an empty loader).
    """
    model = model.eval()

    losses = []
    # Tensor accumulator so that `.double()` below also works when the
    # loader yields no batches (a plain int has no `.double()`).
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)

    with torch.no_grad():
        for d in data_loader:
            input_ids = d[0].to(device).long()
            attention_mask = d[2].to(device).long()
            targets = d[1].to(device).long()

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            # Hugging Face models return an output object or a tuple, not a
            # bare tensor; pull the logits out before using them.
            if not torch.is_tensor(outputs):
                outputs = outputs.logits if hasattr(outputs, 'logits') else outputs[0]

            _, preds = torch.max(outputs, dim=1)

            loss = loss_fn(outputs, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

    mean_loss = np.mean(losses) if losses else float('nan')
    return correct_predictions.double() / n_examples, mean_loss

In [None]:
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup

# Number of epochs used to size the schedule below (separate from the
# lower-case `epochs` used by the earlier training loop).
EPOCHS = 10

# Rebinds `optimizer` to a new AdamW instance over `model`'s parameters.
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
# One scheduler step per training batch, over all epochs.
total_steps = len(train_loader) * EPOCHS

# Rebinds `scheduler`: linear schedule with no warm-up steps.
scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)

# Cross-entropy loss module, moved to `device`.
loss_fn = nn.CrossEntropyLoss().to(device)