In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
from src.attentionmlp import AttentionMLP

In [None]:
# load tokenized data
path_dev ='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_dev_tokenized.pkl'
path_train ='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_train_tokenized.pkl'
path_test ='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_test_tokenized.pkl'

df_train = pd.read_pickle(path_train)
df_dev = pd.read_pickle(path_dev)
df_test = pd.read_pickle(path_test)

In [None]:
df_train.head()

In [None]:
documents = df_train[['input_ids', 'attention_mask', 'label']]

In [None]:
# convert the series into a list
input_ids = documents.input_ids.tolist()
attention_mask = documents.attention_mask.tolist()
labels = documents.label.tolist()

In [None]:
input_ids = [torch.stack(i) for i in input_ids]
attention_mask = [torch.stack(i) for i in attention_mask]


In [None]:
input_ids =[torch.squeeze(i, dim=1) for i in input_ids]
attention_mask =[torch.squeeze(i, dim=1) for i in attention_mask]

In [None]:
attention_mask[0].size(0)

In [None]:
lengths =[i.size(0) for i in input_ids]

In [163]:
def collate_fn(data):
    """
       data: is a list of tuples with (input_ids, attention mask, label, length)
    """
    input_ids = [i[0] for i in data]
    attention_mask = [i[1] for i in data]
    labels = [i[2] for i in data]
    lengths = [i[3] for i in data]


    labels = torch.tensor(labels)
    lengths = torch.tensor(lengths)
    max_length = torch.max(lengths)
    # pad the input_ids and attention_mask so that they have the same length [max_length, 512]
    for i in range(len(input_ids)):
        pad = torch.zeros((max_length - lengths[i],512), dtype=torch.long)
        input_ids[i] = torch.cat((input_ids[i], pad), dim=0, )
        attention_mask[i] = torch.cat((attention_mask[i], pad), dim=0)

    return torch.stack(input_ids), torch.stack(attention_mask), labels, lengths

In [164]:
class ECHRDataset(torch.utils.data.Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attention_mask[idx], self.labels[idx], self.input_ids[idx].size(0) # last one is the length of the input_ids, used for padding

In [165]:
# create a dataset
dataset = ECHRDataset(input_ids, attention_mask, labels)

In [173]:
# create a dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

In [174]:
x=next(iter(dataloader))

len(x)

4

In [175]:
x[0].shape

torch.Size([1, 8, 512])

In [169]:
# try to run the model
from transformers import BertModel
bert = BertModel.from_pretrained('nlpaueb/legal-bert-base-uncased')

Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [203]:
def make_mask(data, lengths):
    # create the right attention mask for the data
    mask = torch.zeros(data.size(0), data.size(1)) # [batch_size, max_length]
    for i in range(data.size(0)):
        mask[i, :lengths[i]] = 1
    return mask


class Hierbert(nn.Module):
    def __init__(self, bert, hidden_sizes):
        super(Hierbert, self).__init__()
        self.bert = bert
        self.attention_mlp = AttentionMLP(768, hidden_sizes)

    def forward(self, input_ids, attention_masks, lengths):

        max_l = input_ids.size(1)# inputs are already padded
        bert_output=[]
        for i in range(max_l):
            print(input_ids[:,i].shape)
            bert_output.append(
                self.bert(input_ids[:,i], attention_masks[:,i]).pooler_output
            )

        bert_output = torch.stack(bert_output)
        print(bert_output.shape)
        sentence_mask = make_mask(input_ids, lengths)

        return self.attention_mlp(bert_output, sentence_mask)









In [204]:
model = Hierbert(bert,  [768, 50] )

In [205]:
x[0].shape

torch.Size([1, 8, 512])

In [206]:
model(x[0], x[1], x[3])

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])


RuntimeError: [enforce fail at ..\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 12582912 bytes.