In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import transformers

import sys
sys.path.append('..')
from src.attentionmlp import AttentionMLP


In [4]:
# load tokenized data
# NOTE(review): pd.read_pickle executes whatever the pickle contains; only
# point these paths at files produced by a trusted tokenization run.

# training split
path_train ='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_train_tokenized.pkl'
df_train = pd.read_pickle(path_train)

# development (validation) split
path_dev ='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_dev_tokenized.pkl'
df_dev = pd.read_pickle(path_dev)

# test split
path_test ='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_test_tokenized.pkl'
df_test = pd.read_pickle(path_test)

In [5]:
# Preview the first rows of the training frame to check which columns were loaded.
df_train.head()

Unnamed: 0,ITEMID,LANGUAGEISOCODE,RESPONDENT,BRANCH,DATE,DOCNAME,IMPORTANCE,CONCLUSION,JUDGES,text,VIOLATED_ARTICLES,VIOLATED_PARAGRAPHS,VIOLATED_BULLETPOINTS,NON_VIOLATED_ARTICLES,NON_VIOLATED_PARAGRAPHS,NON_VIOLATED_BULLETPOINTS,label,input_ids,attention_mask
0,001-100005,ENG,TUR,ADMISSIBILITY,2010,SHAMSI v. TURKEY,4,Inadmissible,Françoise Tulkens;Ireneu Cabral Barreto;Kristi...,"[The applicant, Mr Maher Muhilddin Gazel Al Sh...",[],[],[],[],[],[],0,"[[[tensor(101), tensor(207), tensor(272), tens...","[[[tensor(1), tensor(1), tensor(1), tensor(1),..."
1,001-100024,ENG,ARM,CHAMBER,2010,CASE OF HOVHANNISYAN AND SHIROYAN v. ARMENIA,3,Reminder inadmissible;Violation of P1-1;Just s...,Alvina Gyulumyan;Elisabet Fura;Ineta Ziemele;J...,"[4. The applicants were born in 1976, 1973 and...",[],[],[],[],[],[],0,"[[[tensor(101), tensor(201), tensor(117), tens...","[[[tensor(1), tensor(1), tensor(1), tensor(1),..."
2,001-100026,ENG,ARM,CHAMBER,2010,CASE OF YERANOSYAN AND OTHERS v. ARMENIA,4,Violation of P1-1,Alvina Gyulumyan;Elisabet Fura;Ineta Ziemele;J...,"[4. The applicants were born in 1976, 1975, 19...",[],[],[],[],[],[],0,"[[[tensor(101), tensor(201), tensor(117), tens...","[[[tensor(1), tensor(1), tensor(1), tensor(1),..."
3,001-100029,ENG,RUS,CHAMBER,2010,CASE OF AKHMATKHANOVY v. RUSSIA,4,Violation of Art. 2 (substantive aspect);Viola...,Anatoly Kovler;Christos Rozakis;Dean Spielmann...,"[4. The applicants are:, 1) Ms Bilat Akhmatkha...","[13, 2, 3, 5]",[],[],[],[],[],1,"[[[tensor(101), tensor(201), tensor(117), tens...","[[[tensor(1), tensor(1), tensor(1), tensor(1),..."
4,001-100038,ENG,NLD,CHAMBER,2010,CASE OF A. v. THE NETHERLANDS,3,Violation of Art. 3 (in case of expulsion to L...,Alvina Gyulumyan;Corneliu Bîrsan;Egbert Myjer;...,[7. The applicant was born in 1972 and lives i...,[3],[],[],[13],[],[],1,"[[[tensor(101), tensor(204), tensor(117), tens...","[[[tensor(1), tensor(1), tensor(1), tensor(1),..."


In [6]:
# Keep only the three columns the cells below read: token ids, attention masks and the label.
documents = df_train[['input_ids', 'attention_mask', 'label']]

In [14]:
# Turn each pandas column into a plain Python list (one entry per document).
input_ids = documents['input_ids'].tolist()
attention_mask = documents['attention_mask'].tolist()
labels = documents['label'].tolist()

In [15]:
# Stack the per-document sequence of chunk tensors into a single tensor per document.
# presumably each chunk is (1, 512), giving (n_chunks, 1, 512) - TODO confirm
input_ids = list(map(torch.stack, input_ids))
attention_mask = list(map(torch.stack, attention_mask))


In [16]:
# Drop dimension 1 of every document tensor when it has size one
# (squeeze leaves the tensor unchanged otherwise).
input_ids = [doc.squeeze(1) for doc in input_ids]
attention_mask = [doc.squeeze(1) for doc in attention_mask]

In [17]:
# Size of the first dimension of every document tensor (its number of chunks).
lengths = [doc.shape[0] for doc in input_ids]

In [18]:
def collate_fn(data):
    """
    Collate variable-length documents into one zero-padded batch.

    Args:
        data: list of tuples ``(input_ids, attention_mask, label, length)``,
            the layout returned by ``ECHRDataset.__getitem__``. ``input_ids``
            and ``attention_mask`` are tensors whose first dimension (the
            number of chunks in the document) may differ between samples;
            every trailing dimension must agree across the batch.

    Returns:
        Tuple ``(input_ids, attention_mask, labels, lengths)``:
            * ``input_ids`` / ``attention_mask``: stacked tensors of shape
              ``[batch, max_length, ...]`` where ``max_length`` is the largest
              first dimension in the batch; shorter documents are padded with
              zeros at the end. The chunk width is taken from the tensors
              themselves rather than assumed to be 512, and the input dtype
              is preserved.
            * ``labels``: 1-D tensor built from the per-sample labels.
            * ``lengths``: 1-D tensor built from the per-sample lengths.
    """
    # Fresh lists, so the tensors stored in the dataset are never modified.
    input_ids = [sample[0] for sample in data]
    attention_mask = [sample[1] for sample in data]
    labels = torch.tensor([sample[2] for sample in data])
    lengths = torch.tensor([sample[3] for sample in data])

    # pad_sequence pads along the first dimension up to the longest sample and
    # stacks the result; batch_first=True yields [batch, max_length, ...].
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    padded_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    padded_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)

    return padded_ids, padded_mask, labels, lengths

In [19]:
class ECHRDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over three parallel, index-aligned sequences.

    Item ``idx`` is the 4-tuple
    ``(input_ids[idx], attention_mask[idx], labels[idx], input_ids[idx].size(0))``.
    """

    def __init__(self, input_ids, attention_mask, labels):
        # The sequences are stored as given (no copy is made).
        self.input_ids, self.attention_mask, self.labels = input_ids, attention_mask, labels

    def __len__(self):
        """Number of samples, taken from the ``input_ids`` sequence."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` together with the first-dimension size of its ids."""
        sample_ids = self.input_ids[idx]
        sample_mask = self.attention_mask[idx]
        sample_label = self.labels[idx]
        # The trailing element is the length of the ids tensor; it is used for padding.
        return sample_ids, sample_mask, sample_label, sample_ids.size(0)

In [20]:
# create a dataset
# Wraps the three parallel lists built above; each item also carries the
# first-dimension size of its input_ids tensor (see ECHRDataset.__getitem__).
dataset = ECHRDataset(input_ids, attention_mask, labels)

In [47]:
# Build the batch iterator; collate_fn assembles the samples of each batch.
# NOTE(review): shuffle=True with no seed set in the visible cells, so batch
# order will differ between runs.
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=2,
)

In [48]:
# Pull a single batch from the loader to inspect what collate_fn produces.
x=next(iter(dataloader))

# Number of elements in the collated batch tuple.
len(x)

4

In [49]:
# Shape of the first element of the batch, i.e. the stacked input_ids returned by collate_fn.
x[0].shape

torch.Size([2, 3, 512])

In [50]:
# try to run the model
# Load the pretrained encoder that is later wrapped by Hierbert.
# NOTE(review): from_pretrained may need network access on first use, if the
# weights are not already cached locally.
from transformers import BertModel
bert = BertModel.from_pretrained('nlpaueb/legal-bert-base-uncased')

In [51]:
def make_mask(data, lengths, batch_first=True):
    """
    Build a boolean validity mask for a padded batch of sequences.

    Args:
        data: padded tensor. Only its first two dimensions are read:
            ``(batch, max_length, ...)`` when ``batch_first`` is True,
            ``(max_length, batch, ...)`` otherwise.
        lengths: 1-D tensor or sequence of ints with one entry per batch
            element, giving the number of valid (non-padded) positions.
        batch_first: layout of ``data``; the mask is returned in the same layout.

    Returns:
        ``torch.bool`` tensor on ``data.device`` with shape
        ``(batch, max_length)`` if ``batch_first`` else ``(max_length, batch)``.
        An entry is True where the position index is smaller than the length
        of the corresponding sequence.

    Raises:
        ValueError: if ``lengths`` does not hold exactly one entry per batch element.
    """
    if batch_first:
        batch_size, max_length = data.size(0), data.size(1)
    else:
        max_length, batch_size = data.size(0), data.size(1)

    lengths = torch.as_tensor(lengths, device=data.device).reshape(-1)
    if lengths.numel() != batch_size:
        raise ValueError(
            f"expected {batch_size} lengths (one per batch element), got {lengths.numel()}"
        )

    # Compare every position index with every sequence length:
    # row b, column t is valid exactly when t < lengths[b].
    positions = torch.arange(max_length, device=data.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_length)

    # Return the mask in the same layout as the data it describes.
    return mask if batch_first else mask.t()


class Hierbert(nn.Module):
    """
    Hierarchical document model: an encoder embeds every chunk of a document
    and an ``AttentionMLP`` consumes the resulting per-chunk embeddings.

    Args:
        bert: encoder module called as ``bert(input_ids, attention_mask)``;
            its output must expose a ``pooler_output`` attribute.
        hidden_sizes: forwarded unchanged to ``AttentionMLP``.
    """

    def __init__(self, bert, hidden_sizes ):
        super().__init__()
        self.bert = bert
        # Use the encoder's own embedding width when its config exposes one;
        # otherwise fall back to 768, the value that used to be hard-coded.
        embed_dim = getattr(getattr(bert, 'config', None), 'hidden_size', 768)
        self.attention_mlp = AttentionMLP(embed_dim, hidden_sizes)

    def forward(self, input_ids, attention_masks, lengths, bert_require_grad=True):
        """
        Encode every chunk with the encoder, then run the attention head.

        Args:
            input_ids: already padded batch, sliced chunk by chunk along
                dimension 1 (presumably ``(batch, max_chunks, tokens)``, the
                layout produced by ``collate_fn`` - TODO confirm).
            attention_masks: tensor sliced the same way as ``input_ids``.
            lengths: per-document chunk counts, passed on to ``make_mask``.
            bert_require_grad: when False the encoder runs in eval mode and
                without gradient tracking. When True it follows this module's
                own train/eval mode and gradients are tracked whenever
                autograd is enabled.

        Returns:
            Whatever ``self.attention_mlp`` returns for the chunk embeddings
            and the chunk mask.
        """
        # Follow the wrapper's mode so that ``model.eval()`` is honoured:
        # an unconditional ``train()`` would re-enable dropout at evaluation time.
        self.bert.train(bool(bert_require_grad) and self.training)

        num_chunks = input_ids.size(1)  # inputs are already padded

        # eval() alone does not stop autograd from recording the encoder, so
        # gradient tracking is switched off explicitly when it is not wanted.
        track_grad = bool(bert_require_grad) and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            pooled = [
                self.bert(input_ids[:, i], attention_masks[:, i]).pooler_output
                for i in range(num_chunks)
            ]

        # Stack along dim 1 -> (batch, num_chunks, embed_dim); same values as
        # stacking on dim 0 and then permuting (1, 0, 2).
        chunk_embeddings = torch.stack(pooled, dim=1)

        sentence_mask = make_mask(input_ids, lengths)

        return self.attention_mlp(chunk_embeddings, sentence_mask)