# RNNs with Law2Vec embeddings

## Importing data, pre-trained embeddings

In [1]:
import copy
from itertools import product
from numpy import isnan
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext.data as data
import torchtext.vocab as vocab
import sys
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../data_pipeline/')
import preprocessing as pre
from training import TrainingModule

SEED = 1312
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [2]:
train_data, test_data, val_data, TEXT, LABEL = pre.get_data(
    'train_small.csv', 'val_small.csv', 'test_small.csv')

Connected!


In [3]:
USE_CUDA = torch.cuda.is_available()

vectors = vocab.Vectors('../embeds/Law2Vec.100d.txt') # Law2Vec available from https://archive.org/details/Law2Vec

TEXT.build_vocab(train_data, vectors=vectors,
                 unk_init = torch.Tensor.normal_)
LABEL.build_vocab(train_data)

BATCH_SIZE = 5

train_it, test_it, val_it = data.BucketIterator.splits(
    (train_data, test_data, val_data), 
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.alj_text),
    sort_within_batch=True,
    device = torch.device('cuda' if USE_CUDA else 'cpu'))

## Checking pretrained vectors have been applied

In [4]:
vectors['medicare']

tensor([ 0.6415, -0.5367, -0.3537, -0.0634, -0.1798,  0.0626, -0.1836, -0.2705,
         0.2504,  0.5061,  0.4746, -0.2351, -0.0465,  0.3184,  0.8974,  0.0470,
        -0.2594,  0.3485, -0.3356,  0.1163,  0.2207,  0.2707,  0.4748,  0.1122,
        -0.1188, -0.0790,  0.4377, -0.4711,  0.1401, -0.0234, -0.2009, -0.2143,
         0.1335, -0.4407,  0.4077, -0.0634,  0.5104,  0.1820, -0.4729, -0.1758,
         0.6194,  0.5708, -0.3034, -0.3658,  0.1609,  0.0753, -0.2024, -0.1472,
         0.0665,  0.1823,  0.3091, -0.0913,  0.2495,  0.0777, -0.1873, -0.5850,
        -0.3243,  0.1540, -0.5094,  0.6227,  0.1163, -0.6202, -0.4416, -0.3509,
        -0.5760, -0.4837, -0.6283,  0.0938,  0.3528, -0.0674, -0.7097, -0.2053,
        -0.6007, -0.1306,  0.0146, -0.0830,  0.5486, -0.2328, -0.3193,  0.1496,
        -0.1635,  0.0755, -0.2594, -0.0317,  0.1249, -0.5599,  0.0722, -0.0369,
         0.3139,  0.0102, -0.3353,  0.1142, -0.1163,  0.1505,  0.0952,  0.0206,
        -0.0733, -0.4851,  0.4995,  0.04

In [5]:
TEXT.vocab.vectors[TEXT.vocab.stoi['medicare']]

tensor([ 0.6415, -0.5367, -0.3537, -0.0634, -0.1798,  0.0626, -0.1836, -0.2705,
         0.2504,  0.5061,  0.4746, -0.2351, -0.0465,  0.3184,  0.8974,  0.0470,
        -0.2594,  0.3485, -0.3356,  0.1163,  0.2207,  0.2707,  0.4748,  0.1122,
        -0.1188, -0.0790,  0.4377, -0.4711,  0.1401, -0.0234, -0.2009, -0.2143,
         0.1335, -0.4407,  0.4077, -0.0634,  0.5104,  0.1820, -0.4729, -0.1758,
         0.6194,  0.5708, -0.3034, -0.3658,  0.1609,  0.0753, -0.2024, -0.1472,
         0.0665,  0.1823,  0.3091, -0.0913,  0.2495,  0.0777, -0.1873, -0.5850,
        -0.3243,  0.1540, -0.5094,  0.6227,  0.1163, -0.6202, -0.4416, -0.3509,
        -0.5760, -0.4837, -0.6283,  0.0938,  0.3528, -0.0674, -0.7097, -0.2053,
        -0.6007, -0.1306,  0.0146, -0.0830,  0.5486, -0.2328, -0.3193,  0.1496,
        -0.1635,  0.0755, -0.2594, -0.0317,  0.1249, -0.5599,  0.0722, -0.0369,
         0.3139,  0.0102, -0.3353,  0.1142, -0.1163,  0.1505,  0.0952,  0.0206,
        -0.0733, -0.4851,  0.4995,  0.04

## RNN Model

In [6]:
class RNNPtEmbeds(nn.Module):
    
    def __init__(self, rnn_type, input_size, embedding_size,
                 hidden_size, output_size, num_layers,
                 dropout, bidirectional, padding_idx):
    
        super().__init__()
        self.embedding = nn.Embedding\
                           .from_pretrained(TEXT.vocab.vectors)

        self.rnn = getattr(nn, rnn_type.upper())\
                          (embedding_size, hidden_size, num_layers,
                           dropout=(dropout if num_layers > 1 else 0),
                           bidirectional=bidirectional)
        
        self.leakyrelu = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)
        
        linear_inp = (hidden_size * 2 if bidirectional else hidden_size)
        self.linear = nn.Linear(linear_inp, output_size)
             
    def forward(self, input):
        embed = self.embedding(input)
        rnn_out, hidden = self.rnn(embed)
        rnn_out = rnn_out[-1]
        rnn_out = self.leakyrelu(rnn_out)
        rnn_out = self.dropout(rnn_out)
        linear_out = self.linear(rnn_out)
        return linear_out

## Training models

In [7]:
# Store training results
df = pd.DataFrame(columns=['architecture', 'model_type', 'embeddings',
                           'hidden', 'num_layers', 'dropouts',
                           'bidirectional', 'learning_rate', 'epochs',
                           'dev_acc', 'dev_prec', 'dev_recall',
                           'metric'])

# Model architecture parameters
RNN_TYPES = ['RNN', 'LSTM']
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_SIZE = TEXT.vocab.vectors.size(1)
HIDDEN_SIZES = [1/3, 2/3]
OUTPUT_SIZE = 1
NUM_LAYERS = [1, 2]
DROPOUTS = [0.5, 0.75]
BIDIRECTIONALS = [False, True]
PADDING_IDX = TEXT.vocab.stoi[TEXT.pad_token]

# Model training hyperparameters
LEARNING_RATE = [0.01, 0.0001]
train_len = 0
train_pos = 0
for batch in train_it:
    train_len += len(batch.decision_binary)
    train_pos += batch.decision_binary.sum().item()
POS_WEIGHT = torch.tensor([(train_len - train_pos) / train_pos])
if USE_CUDA:
    POS_WEIGHT = POS_WEIGHT.cuda()
EPOCHS = 10

# Iterator over various model parameters
param_iter = product (RNN_TYPES, HIDDEN_SIZES, NUM_LAYERS,
                      DROPOUTS, BIDIRECTIONALS, LEARNING_RATE)

# Magic loop
best_acc = (None, None)
best_rec = (None, None)
best_prec = (None, None)
for i, (rnn_type, hidden_size, num_layers, dropout, bidirectional,\
    lr) in enumerate(param_iter):
    print(f'Architecture #{i}\n' + '-' * 20)
    hidden_dim = int(hidden_size * EMBEDDING_SIZE)
    model = RNNPtEmbeds(rnn_type, INPUT_DIM, EMBEDDING_SIZE,
                        hidden_dim, OUTPUT_SIZE, num_layers,
                        dropout, bidirectional, PADDING_IDX)
    
    tm = TrainingModule(model, lr, POS_WEIGHT, USE_CUDA, EPOCHS)
    
    best_models = tm.train_model(train_it, val_it)
    
    for metric, best_model in best_models.items():
        row = [i, rnn_type, 'Law2Vec', hidden_size, num_layers,
               dropout, bidirectional, lr, EPOCHS,
               best_model.accuracy, best_model.precision,
               best_model.recall, metric]
        df.loc[len(df)] = row
        if best_acc[0] is None or isnan(best_acc[1]) or\
           best_model.accuracy > best_acc[1]:
            best_acc = (copy.deepcopy(best_model.model), best_model.accuracy)
        if best_rec[0] is None or isnan(best_rec[1]) or\
           best_model.recall > best_rec[1]:
            best_rec = (copy.deepcopy(best_model.model), best_model.recall)
        if best_prec[0] is None or isnan(best_prec[1]) or\
           best_model.precision > best_prec[1]:
            best_prec = (copy.deepcopy(best_model.model), best_model.precision)
    
    print('-' * 20 + '\n')


Architecture #0
--------------------
Epoch 0: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1642
Epoch 1: Dev Accuracy: 0.3051; Dev Precision: 0.1489; Dev Recall: 0.8750; Dev Loss:0.1704
Epoch 2: Dev Accuracy: 0.2542; Dev Precision: 0.1087; Dev Recall: 0.6250; Dev Loss:0.1748
Epoch 3: Dev Accuracy: 0.2203; Dev Precision: 0.0870; Dev Recall: 0.5000; Dev Loss:0.2065
Epoch 4: Dev Accuracy: 0.7627; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.1480
Epoch 5: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1771
Epoch 6: Dev Accuracy: 0.5763; Dev Precision: 0.1852; Dev Recall: 0.6250; Dev Loss:0.1625
Epoch 7: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1448
Epoch 8: Dev Accuracy: 0.5932; Dev Precision: 0.1000; Dev Recall: 0.2500; Dev Loss:0.1578
Epoch 9: Dev Accuracy: 0.3220; Dev Precision: 0.1364; Dev Recall: 0.7500; Dev Loss:0.1650
--------------------

Architecture #1
--------------------
Epoch 0

Epoch 6: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1738
Epoch 7: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1557
Epoch 8: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1767
Epoch 9: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1543
--------------------

Architecture #9
--------------------
Epoch 0: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.1657
Epoch 1: Dev Accuracy: 0.1695; Dev Precision: 0.1404; Dev Recall: 1.0000; Dev Loss:0.1646
Epoch 2: Dev Accuracy: 0.1695; Dev Precision: 0.1404; Dev Recall: 1.0000; Dev Loss:0.1643
Epoch 3: Dev Accuracy: 0.1695; Dev Precision: 0.1404; Dev Recall: 1.0000; Dev Loss:0.1638
Epoch 4: Dev Accuracy: 0.1695; Dev Precision: 0.1404; Dev Recall: 1.0000; Dev Loss:0.1635
Epoch 5: Dev Accuracy: 0.1695; Dev Precision: 0.1404; Dev Recall: 1.0000; Dev Loss:0.1625
Epoch 6: Dev Accuracy: 0.1695; Dev Precision: 0

Epoch 2: Dev Accuracy: 0.7966; Dev Precision: 0.1667; Dev Recall: 0.1250; Dev Loss:0.1606
Epoch 3: Dev Accuracy: 0.8644; Dev Precision: 0.5000; Dev Recall: 0.1250; Dev Loss:0.1604
Epoch 4: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.1612
Epoch 5: Dev Accuracy: 0.8305; Dev Precision: 0.2500; Dev Recall: 0.1250; Dev Loss:0.1609
Epoch 6: Dev Accuracy: 0.7288; Dev Precision: 0.1000; Dev Recall: 0.1250; Dev Loss:0.1611
Epoch 7: Dev Accuracy: 0.7627; Dev Precision: 0.1250; Dev Recall: 0.1250; Dev Loss:0.1614
Epoch 8: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.1611
Epoch 9: Dev Accuracy: 0.7119; Dev Precision: 0.0909; Dev Recall: 0.1250; Dev Loss:0.1620
--------------------

Architecture #18
--------------------
Epoch 0: Dev Accuracy: 0.2373; Dev Precision: 0.1509; Dev Recall: 1.0000; Dev Loss:0.1667
Epoch 1: Dev Accuracy: 0.8136; Dev Precision: 0.2000; Dev Recall: 0.1250; Dev Loss:0.1534
Epoch 2: Dev Accuracy: 0.8644; Dev Preci

Epoch 9: Dev Accuracy: 0.6949; Dev Precision: 0.0833; Dev Recall: 0.1250; Dev Loss:0.1630
--------------------

Architecture #26
--------------------
Epoch 0: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.2252
Epoch 1: Dev Accuracy: 0.5085; Dev Precision: 0.1613; Dev Recall: 0.6250; Dev Loss:0.1597
Epoch 2: Dev Accuracy: 0.3051; Dev Precision: 0.1333; Dev Recall: 0.7500; Dev Loss:0.1863
Epoch 3: Dev Accuracy: 0.4068; Dev Precision: 0.1538; Dev Recall: 0.7500; Dev Loss:0.1876
Epoch 4: Dev Accuracy: 0.2203; Dev Precision: 0.1200; Dev Recall: 0.7500; Dev Loss:0.1808
Epoch 5: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1398
Epoch 6: Dev Accuracy: 0.5763; Dev Precision: 0.1600; Dev Recall: 0.5000; Dev Loss:0.1566
Epoch 7: Dev Accuracy: 0.6780; Dev Precision: 0.2105; Dev Recall: 0.5000; Dev Loss:0.1527
Epoch 8: Dev Accuracy: 0.3898; Dev Precision: 0.1667; Dev Recall: 0.8750; Dev Loss:0.1696
Epoch 9: Dev Accuracy: 0.8305; Dev Precisio

Epoch 5: Dev Accuracy: 0.7458; Dev Precision: 0.1111; Dev Recall: 0.1250; Dev Loss:0.1603
Epoch 6: Dev Accuracy: 0.8814; Dev Precision: 1.0000; Dev Recall: 0.1250; Dev Loss:0.1568
Epoch 7: Dev Accuracy: 0.8644; Dev Precision: 0.5000; Dev Recall: 0.1250; Dev Loss:0.1587
Epoch 8: Dev Accuracy: 0.8814; Dev Precision: 1.0000; Dev Recall: 0.1250; Dev Loss:0.1450
Epoch 9: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1598
--------------------

Architecture #35
--------------------
Epoch 0: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1667
Epoch 1: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1664
Epoch 2: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1663
Epoch 3: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1661
Epoch 4: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1659
Epoch 5: Dev Accuracy: 0.1356; Dev Precisio

Epoch 1: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1669
Epoch 2: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1664
Epoch 3: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1662
Epoch 4: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1659
Epoch 5: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1658
Epoch 6: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1653
Epoch 7: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1649
Epoch 8: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1648
Epoch 9: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1645
--------------------

Architecture #44
--------------------
Epoch 0: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1595
Epoch 1: Dev Accuracy: 0.1356; Dev Precisio

Epoch 8: Dev Accuracy: 0.6949; Dev Precision: 0.0833; Dev Recall: 0.1250; Dev Loss:0.1629
Epoch 9: Dev Accuracy: 0.6949; Dev Precision: 0.0833; Dev Recall: 0.1250; Dev Loss:0.1627
--------------------

Architecture #52
--------------------
Epoch 0: Dev Accuracy: 0.6949; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.1582
Epoch 1: Dev Accuracy: 0.1525; Dev Precision: 0.1379; Dev Recall: 1.0000; Dev Loss:0.1657
Epoch 2: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1658
Epoch 3: Dev Accuracy: 0.7627; Dev Precision: 0.0000; Dev Recall: 0.0000; Dev Loss:0.1607
Epoch 4: Dev Accuracy: 0.7119; Dev Precision: 0.0909; Dev Recall: 0.1250; Dev Loss:0.1588
Epoch 5: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1702
Epoch 6: Dev Accuracy: 0.1695; Dev Precision: 0.1404; Dev Recall: 1.0000; Dev Loss:0.1630
Epoch 7: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1645
Epoch 8: Dev Accuracy: 0.7458; Dev Preci

Epoch 4: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1706
Epoch 5: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1779
Epoch 6: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1663
Epoch 7: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1853
Epoch 8: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1607
Epoch 9: Dev Accuracy: 0.1356; Dev Precision: 0.1356; Dev Recall: 1.0000; Dev Loss:0.1706
--------------------

Architecture #61
--------------------
Epoch 0: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1599
Epoch 1: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1604
Epoch 2: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1607
Epoch 3: Dev Accuracy: 0.8644; Dev Precision: nan; Dev Recall: 0.0000; Dev Loss:0.1609
Epoch 4: Dev Accuracy: 0.8644; Dev Precision: nan; Dev 

## Save results of model training

In [8]:
SAVE_PREFIX = '../results/RNNLaw2Vec_'
df.to_csv(f'{SAVE_PREFIX}models.csv')
torch.save(best_acc[0], f'{SAVE_PREFIX}best_acc.pt')
torch.save(best_rec[0], f'{SAVE_PREFIX}best_rec.pt')
torch.save(best_prec[0], f'{SAVE_PREFIX}best_prec.pt')

## Confirming embeddings have not been trained

In [9]:
medicare_tensor = torch.LongTensor([TEXT.vocab.stoi['medicare']])
if USE_CUDA:
    medicare_tensor = medicare_tensor.cuda()
model.embedding(medicare_tensor)

tensor([[ 0.6415, -0.5367, -0.3537, -0.0634, -0.1798,  0.0626, -0.1836, -0.2705,
          0.2504,  0.5061,  0.4746, -0.2351, -0.0465,  0.3184,  0.8974,  0.0470,
         -0.2594,  0.3485, -0.3356,  0.1163,  0.2207,  0.2707,  0.4748,  0.1122,
         -0.1188, -0.0790,  0.4377, -0.4711,  0.1401, -0.0234, -0.2009, -0.2143,
          0.1335, -0.4407,  0.4077, -0.0634,  0.5104,  0.1820, -0.4729, -0.1758,
          0.6194,  0.5708, -0.3034, -0.3658,  0.1609,  0.0753, -0.2024, -0.1472,
          0.0665,  0.1823,  0.3091, -0.0913,  0.2495,  0.0777, -0.1873, -0.5850,
         -0.3243,  0.1540, -0.5094,  0.6227,  0.1163, -0.6202, -0.4416, -0.3509,
         -0.5760, -0.4837, -0.6283,  0.0938,  0.3528, -0.0674, -0.7097, -0.2053,
         -0.6007, -0.1306,  0.0146, -0.0830,  0.5486, -0.2328, -0.3193,  0.1496,
         -0.1635,  0.0755, -0.2594, -0.0317,  0.1249, -0.5599,  0.0722, -0.0369,
          0.3139,  0.0102, -0.3353,  0.1142, -0.1163,  0.1505,  0.0952,  0.0206,
         -0.0733, -0.4851,  