# Embeddings with LEGAL-BERT

In [None]:
%pip install pandas
%pip install torch
%pip install transformers
%pip install tqdm

In [6]:
import os

import pandas
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers
from tqdm import tqdm

# Load LEGAL-BERT

In [7]:
# Instantiate the encoder from the 'nlpaueb/legal-bert-base-uncased' checkpoint
# (fetched on first use, as the download progress bar in the cell output shows).
bert = transformers.BertModel.from_pretrained(
    pretrained_model_name_or_path='nlpaueb/legal-bert-base-uncased',
)

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [32]:
# Select the GPU when CUDA is usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always
# truthy, which would pick 'cuda' even on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [None]:
# Move the model onto the selected device.
bert.to(device=device)

In [59]:
# Load the three pickled, tokenized DataFrames (train / dev / test, going by
# the file names).
# NOTE: unpickling executes arbitrary code -- only load files you trust.
df_train = pd.read_pickle(
    filepath_or_buffer='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_train_tokenized.pkl',
)
df_dev = pd.read_pickle(
    filepath_or_buffer='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_dev_tokenized.pkl',
)
df_test = pd.read_pickle(
    filepath_or_buffer='../ECHR_Dataset_Tokenized/legal-bert-base-uncased/df_test_tokenized.pkl',
)

In [36]:
class ClassificationHeadWithAttention(nn.Module):
    """Attention-pooling classification head.

    Each input embedding is projected into a key and a value. A learned
    ``selector`` vector scores every position through its key, the scores are
    softmax-normalised over the sequence, and the resulting weighted sum of
    the values is passed through a two-layer MLP that ends in a sigmoid.

    Args:
        input_size: Size of each input embedding (768 by default).
        hidden_size: Width of the hidden fully connected layer.
        output_size: Number of sigmoid outputs (1 by default, i.e. a single
            probability-like score).
    """

    def __init__(self, input_size=768, hidden_size=1024, output_size=1):
        super().__init__()
        # TODO: add a positional encoding for the input sequence.
        # Parameters are created and initialised in the same order as before
        # so that seeded initialisation stays reproducible.
        self.selector = nn.Parameter(torch.empty(input_size, 1))
        nn.init.normal_(self.selector)
        self.keys_matrix = nn.Parameter(torch.empty(input_size, input_size))
        nn.init.normal_(self.keys_matrix)
        self.values_matrix = nn.Parameter(torch.empty(input_size, input_size))
        nn.init.normal_(self.values_matrix)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Pool a sequence of embeddings and classify the pooled vector.

        Args:
            x: Tensor of shape ``(n, input_size)`` or, batched,
                ``(batch, n, input_size)``.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(..., 1, output_size)`` with values in ``[0, 1]`` and
            ``attention_weights`` has shape ``(..., n, 1)`` and sums to 1 over
            the sequence axis.
        """
        values = torch.matmul(x, self.values_matrix)
        keys = torch.matmul(x, self.keys_matrix)
        # Score positions with their *keys*; scoring with the values left
        # ``keys_matrix`` as a dead parameter that never received a gradient.
        scores = torch.matmul(keys, self.selector)
        # Normalise over the sequence axis (second to last) so that the same
        # code handles both unbatched and batched input.
        attention_weights = F.softmax(scores, dim=-2)
        pooled = torch.matmul(attention_weights.transpose(-2, -1), values)
        hidden = F.relu(self.fc1(pooled))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        output = torch.sigmoid(self.fc2(hidden))
        return output, attention_weights


In [60]:
# Keep only the columns used for embedding. `.copy()` makes each result an
# independent frame, so assigning to its columns afterwards does not go
# through pandas' chained-assignment (SettingWithCopyWarning) path.
df_train = df_train[['input_ids', 'attention_mask', 'label']].copy()
df_dev = df_dev[['input_ids', 'attention_mask', 'label']].copy()
df_test = df_test[['input_ids', 'attention_mask', 'label']].copy()

In [61]:
def get_one_embedding(document, attention_mask, model, batch_size = 1):
    """Return the ``pooler_output`` of ``model`` for each (ids, mask) pair.

    Args:
        document: Iterable of inputs, one per model call.
        attention_mask: Iterable of masks, paired with ``document`` via
            ``zip`` (so the shorter of the two bounds the result).
        model: Callable invoked as ``model(ids, mask)`` whose result exposes a
            ``pooler_output`` attribute (presumably a BERT model).
        batch_size: Accepted but not used by this function.

    Returns:
        List with one ``pooler_output`` per pair, in input order. The list is
        also printed before being returned.
    """
    # Gradients are not tracked while the model is called.
    with torch.no_grad():
        pooled = [
            model(ids, mask).pooler_output
            for ids, mask in zip(document, attention_mask)
        ]
    print(pooled)
    return pooled


In [62]:
# Every cell holds a sequence of tensors: stack them into a single tensor and
# drop axis 1 (squeeze only removes it when its size is 1).
df_train['input_ids'] = df_train['input_ids'].apply(
    lambda chunks: torch.stack(chunks).squeeze(dim=1)
)

In [63]:
# Every cell holds a sequence of tensors: stack them into a single tensor and
# drop axis 1 (squeeze only removes it when its size is 1).
df_train['attention_mask'] = df_train['attention_mask'].apply(
    lambda chunks: torch.stack(chunks).squeeze(dim=1)
)

In [70]:
# Every cell holds a sequence of tensors: stack them into a single tensor and
# drop axis 1 (squeeze only removes it when its size is 1).
df_test['input_ids'] = df_test['input_ids'].apply(
    lambda chunks: torch.stack(chunks).squeeze(dim=1)
)
df_test['attention_mask'] = df_test['attention_mask'].apply(
    lambda chunks: torch.stack(chunks).squeeze(dim=1)
)

In [73]:
# Every cell holds a sequence of tensors: stack them into a single tensor and
# drop axis 1 (squeeze only removes it when its size is 1).
df_dev['input_ids'] = df_dev['input_ids'].apply(
    lambda chunks: torch.stack(chunks).squeeze(dim=1)
)
df_dev['attention_mask'] = df_dev['attention_mask'].apply(
    lambda chunks: torch.stack(chunks).squeeze(dim=1)
)

In [64]:
# Sanity check: shape of the stacked input_ids tensor in the row labelled 0.
df_train['input_ids'][0].shape

torch.Size([2, 512])

In [65]:
def get_embedding_batched(input_ids, attention_mask, model, batch_size = 1):
    """Run ``model`` over consecutive slices and collect the pooled outputs.

    ``input_ids`` and ``attention_mask`` are sliced in steps of
    ``batch_size``; each slice is moved to the module-level ``device`` and
    passed to ``model``. Gradients are not tracked.

    Args:
        input_ids: Sliceable object supporting ``len()`` whose slices have a
            ``.to(device)`` method (assumed to be a tensor of stacked chunks
            -- confirm against the caller).
        attention_mask: Object sliced with the same bounds as ``input_ids``.
        model: Callable invoked as ``model(ids, mask)`` whose result exposes a
            ``pooler_output`` attribute.
        batch_size: Number of leading-axis entries per model call.

    Returns:
        List with one ``pooler_output`` per slice, in order.
    """
    with torch.no_grad():
        return [
            model(
                input_ids[start:start + batch_size].to(device),
                attention_mask[start:start + batch_size].to(device),
            ).pooler_output
            for start in range(0, len(input_ids), batch_size)
        ]

In [66]:
# Register `progress_apply` on pandas objects (apply with a tqdm progress bar).
tqdm.pandas()


def get_embeddings(model, df:pandas.DataFrame, batch_size: int = 10):
    """Embed every row of ``df`` with ``model``.

    For each row, the ``'input_ids'`` and ``'attention_mask'`` entries are
    passed to ``get_embedding_batched`` in batches of ``batch_size``, and the
    per-batch outputs are then concatenated along dimension 0.

    Args:
        model: Model forwarded to ``get_embedding_batched``.
        df: DataFrame with ``'input_ids'`` and ``'attention_mask'`` columns.
        batch_size: Batch size forwarded to ``get_embedding_batched``
            (defaults to 10, the value that used to be hard-coded).

    Returns:
        A pandas object aligned with ``df`` holding one concatenated tensor
        per row. If concatenation fails, the un-concatenated per-batch lists
        are returned instead and the error is reported on stdout.
    """
    emb = df.progress_apply(
        lambda row: get_embedding_batched(
            row['input_ids'], row['attention_mask'], model, batch_size
        ),
        axis=1,
    )
    try:
        emb = emb.apply(lambda batches: torch.cat(batches, dim=0))
    except Exception as exc:
        # Deliberately best effort: the expensive per-batch embeddings are
        # kept, but the actual failure is now reported instead of hidden.
        print(
            f'get_embeddings: could not concatenate per-batch embeddings '
            f'({exc!r}); returning the un-concatenated lists'
        )
    return emb

In [67]:
# Compute the embeddings for every row of df_train.
emb_tr = get_embeddings(df=df_train, model=bert)

100%|██████████| 7100/7100 [22:51<00:00,  5.18it/s] 


In [74]:
# Compute the embeddings for every row of df_dev.
emb_dev = get_embeddings(df=df_dev, model=bert)

100%|██████████| 1380/1380 [04:41<00:00,  4.90it/s]


In [71]:
# Compute the embeddings for every row of df_test.
emb_test = get_embeddings(df=df_test, model=bert)

100%|██████████| 2998/2998 [08:19<00:00,  6.00it/s]


In [52]:
# Sanity check: number of entries in the input_ids value of the row labelled 0.
len(df_train['input_ids'][0])

11

In [78]:
# Directory (relative to the working directory) that receives the pickled
# embeddings.
folder_path = '../embeddings/legal-bert-base-uncased'

In [79]:
# Create the output directory, including missing parents. `exist_ok=True`
# makes this a no-op when it already exists and avoids the check-then-create
# race of a separate `os.path.exists` test.
os.makedirs(folder_path, exist_ok=True)

In [80]:
# Persist the training embeddings as a pickle inside folder_path.
emb_tr.to_pickle(path=folder_path + '/emb_tr.pkl')

In [81]:
# Persist the dev embeddings as a pickle inside folder_path.
emb_dev.to_pickle(path=folder_path + '/emb_dev.pkl')

In [82]:
# Persist the test embeddings as a pickle inside folder_path.
emb_test.to_pickle(path=folder_path + '/emb_test.pkl')