In [12]:
from gensim.models import KeyedVectors
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import random
from transformers import BertTokenizer, BertModel
import json
import numpy as np
from tqdm import tqdm
import pickle
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
# Seed every RNG this notebook imports, not just torch: later cells draw
# out-of-vocabulary vectors with np.random, so without this the saved
# embedding matrices change on every run.
# torch.manual_seed stays last so the cell's displayed value is unchanged.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x19be1eb3790>

In [3]:
def _read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes instead of being left to the garbage collector.
    """
    with open(path, 'r') as handle:
        return json.load(handle)


# The three BIO-tagged splits. Later cells index each one as
# data[case]['text'] and data[case]['labels'].
train_data = _read_json('../Dataset/BIO_Tagged/ATE_train.json')
test_data = _read_json('../Dataset/BIO_Tagged/ATE_test.json')
val_data = _read_json('../Dataset/BIO_Tagged/ATE_val.json')

#### Word to Index and Tag to Index

In [4]:
# Vocabulary lookup: each distinct space-separated token found in any split is
# given the next free integer id, in first-seen order (train, then test, then val).
word_to_idx = {}

for split in (train_data, test_data, val_data):
    for key in split:
        for token in split[key]['text'].split(' '):
            # setdefault only inserts when the token is new, so existing ids never change.
            word_to_idx.setdefault(token, len(word_to_idx))

In [5]:
# Label lookup: each distinct tag found in any split is given the next free
# integer id, in first-seen order (train, then test, then val).
tag_to_ix = {}

for split in (train_data, test_data, val_data):
    for key in split:
        for label in split[key]['labels']:
            # setdefault only inserts when the label is new, so existing ids never change.
            tag_to_ix.setdefault(label, len(tag_to_ix))

# Two extra entries are appended after all labels found in the data.
for marker in ('START_TAG', 'END_TAG'):
    tag_to_ix[marker] = len(tag_to_ix)

In [6]:
# Persist both lookup tables. The context managers close (and therefore flush)
# each file before the next cell reads it back.
with open('word_to_idx.pkl', 'wb') as f:
    pickle.dump(word_to_idx, f)
with open('tag_to_ix.pkl', 'wb') as f:
    pickle.dump(tag_to_ix, f)

In [7]:
# Reload the lookup tables from disk.
# NOTE: pickle.load can execute arbitrary code; only load files produced by this notebook.
with open('word_to_idx.pkl', 'rb') as f:
    word_to_idx = pickle.load(f)
# NOTE(review): the tag map is bound to `tag_to_idx` here, while the cell that
# builds it uses the name `tag_to_ix`. Both names are kept as they were.
with open('tag_to_ix.pkl', 'rb') as f:
    tag_to_idx = pickle.load(f)

#### Extracting BERT Embeddings

In [8]:
# The tokenizer and the encoder are both loaded from the same checkpoint id.
LEGAL_BERT_CHECKPOINT = 'nlpaueb/legal-bert-base-uncased'
bert_model = BertModel.from_pretrained(LEGAL_BERT_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(LEGAL_BERT_CHECKPOINT)

In [13]:
def _encode_word(text):
    """Tokenise a single string on its own, without special tokens, as PyTorch tensors."""
    return tokenizer.batch_encode_plus([text], return_tensors='pt', add_special_tokens=False)


def _last_hidden_state(tokens):
    """Run ``bert_model`` on ``tokens`` without tracking gradients and return ``last_hidden_state``."""
    with torch.no_grad():
        return bert_model(**tokens).last_hidden_state


def _mean_pool(hidden):
    """Drop the leading batch axis and average over the remaining first axis."""
    return hidden.squeeze(0).mean(dim=0).numpy()


# One row per vocabulary entry. The width is read from the model config rather
# than hard-coded.
embedding_mat = np.zeros((len(word_to_idx), bert_model.config.hidden_size))

# Embedding of the literal string 'unk', used when the model cannot process a
# word. Computed at most once, on first need, instead of on every failure.
# NOTE(review): this reuse assumes the forward pass is deterministic (model in eval mode) - confirm.
unk_vector = None

for word, idx in tqdm(word_to_idx.items()):
    try:
        tokens = _encode_word(word)
    except Exception:
        # Tokenisation failed: this word's row is left as zeros.
        # `Exception` (not a bare except) so Ctrl-C still interrupts the loop.
        continue

    try:
        hidden = _last_hidden_state(tokens)
    except Exception:
        # The model rejected this input: fall back to the 'unk' embedding.
        if unk_vector is None:
            unk_vector = _mean_pool(_last_hidden_state(_encode_word('unk')))
        embedding_mat[idx] = unk_vector
        continue

    # A word may be split into several sub-word pieces; store their mean vector.
    embedding_mat[idx] = _mean_pool(hidden)

100%|██████████| 3495/3495 [02:08<00:00, 27.25it/s]


In [14]:
# Save the Legal-BERT embedding matrix; the context manager closes the file
# so the write is flushed to disk.
with open('../Extracted Word Embeddings/legal_bert_embedding_mat.pkl', 'wb') as f:
    pickle.dump(embedding_mat, f)

#### Extracting Word2Vec Embeddings

In [15]:
# Load the pre-trained word2vec vectors.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open('../Original Word Embeddings/word2vec.pkl', 'rb') as f:
    word2vec_embeddings = pickle.load(f)

In [16]:
# Build the word2vec embedding matrix: row `idx` holds the pre-trained vector
# for the word with that id when one exists.
# Words without a pre-trained vector get a uniform [0, 1) random vector. It is
# drawn from a locally seeded generator so the saved matrix is identical on
# every run (the global NumPy RNG is not seeded by this cell).
oov_rng = np.random.default_rng(1)

embedding_mat = np.zeros((len(word_to_idx), 300))
for word, idx in word_to_idx.items():
    if word in word2vec_embeddings:
        embedding_mat[idx] = word2vec_embeddings[word]
    else:
        embedding_mat[idx] = oov_rng.random(300)

with open('../Extracted Word Embeddings/word2vec_embedding_mat.pkl', 'wb') as f:
    pickle.dump(embedding_mat, f)

#### Extracting GloVe Embeddings

In [17]:
# Load the pre-trained GloVe vectors.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open('../Original Word Embeddings/glove.pkl', 'rb') as f:
    glove_embeddings = pickle.load(f)

In [18]:
# Build the GloVe embedding matrix: row `idx` holds the pre-trained vector for
# the word with that id when one exists.
# Words without a pre-trained vector get a uniform [0, 1) random vector. It is
# drawn from a locally seeded generator so the saved matrix is identical on
# every run (the global NumPy RNG is not seeded by this cell).
oov_rng = np.random.default_rng(1)

embedding_mat = np.zeros((len(word_to_idx), 300))
for word, idx in word_to_idx.items():
    if word in glove_embeddings:
        embedding_mat[idx] = glove_embeddings[word]
    else:
        embedding_mat[idx] = oov_rng.random(300)

with open('../Extracted Word Embeddings/glove_embedding_mat.pkl', 'wb') as f:
    pickle.dump(embedding_mat, f)