In [14]:
import train_add_token_model
import spacy
from spacy.tokens import DocBin
from datasets import Dataset, load_metric, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
import numpy as np
import wandb
import sys
import train_sentence_classifier
import torch
import pandas as pd

In [3]:
train, labels = train_add_token_model.load_data()
dev, _ = train_add_token_model.load_data('data/dev.spacy')

tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True)
# For our custom tokens, let's add them
tokenizer.add_tokens(['<PREAMBLE>', '<JUDGEMENT>'])

model, trainer = train_add_token_model.create_model_and_trainer(train=train,
                                              dev=dev,
                                              all_labels=labels,
                                              tokenizer=tokenizer,
                                              batch_size=40,
                                              epochs=40,
                                              run_name='final_train',
                                              pretrained='./output')

Loading data...
Loading data...
Creating model...


In [28]:
def tokenize(batch):
    processed_batch = {'input_ids': [], 'attention_mask': [], 'id': [], 'text': [], 'meta': []}

    for text, meta, row_id in zip(batch['text'], batch['meta'], batch['id']):
        is_preamble = 'preamble' in meta
        if is_preamble:
            text = '<PREAMBLE> ' + text
        else:
            text = '<JUDGEMENT> ' + text
        tokenized = tokenizer(text, truncation=False, is_split_into_words=False)
        
        for token_index in range(0, len(tokenized['input_ids']), 512):
            processed_batch['input_ids'].append(tokenized['input_ids'][token_index:token_index+512])
            processed_batch['attention_mask'].append(tokenized['attention_mask'][token_index:token_index+512])
            processed_batch['id'].append(row_id)
            processed_batch['text'].append(text)
            processed_batch['meta'].append(meta)
    return processed_batch

def load_test_data(tokenizer):
    all_rows = []
    for index, row in pd.read_json("../data/NER_TEST_DATA_FS.json").iterrows():
        all_rows.append({'text': row['data']['text'], 'meta': row['meta']['source'], 'id': row['id']})

    test = Dataset.from_list(all_rows)

    return test.map(tokenize, batched=True)

test = load_test_data(tokenizer)
test

  0%|          | 0/5 [00:00<?, ?ba/s]

Dataset({
    features: ['text', 'meta', 'id', 'input_ids', 'attention_mask'],
    num_rows: 5037
})

In [37]:
preds = trainer.predict(test)
preds = np.argmax(preds[0], axis=2)
preds

The following columns in the test set don't have a corresponding argument in `RobertaForTokenClassification.forward` and have been ignored: meta, text, id. If meta, text, id are not expected by `RobertaForTokenClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 5037
  Batch size = 40


array([[28, 28, 28, ..., 28, 28, 28],
       [28, 28, 28, ..., 28, 28, 28],
       [28,  7, 21, ..., 28, 28, 28],
       ...,
       [26, 28, 28, ...,  0,  0,  0],
       [28, 28, 28, ...,  0,  0,  0],
       [28, 28, 28, ...,  0,  0,  0]])

In [38]:
import json
file = open('../data/NER_TEST_DATA_FS.json')
data = json.load(file)
data[0]['annotations'] = [{'result': ['a']}]
data[0]

{'id': '0f8e4fc0fdff428f993cf8507f3606e4',
 'annotations': [{'result': ['a']}],
 'data': {'text': 'In The High Court Of Kerala At Ernakulam\n\n                                       Present\n\n                    The Honourable Mrs. Justice M.R.Anitha\n\n           Wednesday, The 10Th Day Of June 2020 / 20Th Jyaishta, 1942\n\n                          Crl.Rev.Pet.No.767 Of 2012\n\n    Crmp 1176/2011 Dated 16-03-2012 Of Judicial Magistrate Of First Class ,\n                                 Kunnamkulam\n\nRevision Petitioner/Complainant\n                A. Rajesh Aged 35 Years,\n                S/O.Raman Nair, Ammasom Veettil, Punnayurkulam Village, Cherayi\n                Desom, Andathodu P.O., Chavakkad, Thrissur District, Pin\n                679564.\n                By Advs.\n                Dr.V.N.Sankarjee\n                Sri.V.N.Madhusudanan\n                Sri.S.Sidhardhan\n                Smt.R.Udaya Jyothi\n                Sri.M.M.Vinod\n                Smt.M.Suseela\n      

In [40]:
from tqdm import tqdm

for sent_index in tqdm(range(len(data))):
    annotations = []
    
    # Find all the matching predictions, which may come from multiple rows (if the sentence was split up)
    sent_id = data[sent_index]['id']
    
    pred = []
    for input_index in range(len(test)):
        if test[input_index]['id'] == sent_id:
            pred += list(preds[input_index])
    
    
    original_text = data[sent_index]['data']['text']
    # We have a list of preds, we need to match them up with tokens and find the original character range
    # Retokenize the text so we can figure out what words correspond to what char ranges
    # If we had included the class token, it would be tougher to align the char indices with the original text
    sent_tokenized = tokenizer(original_text, truncation=False, is_split_into_words=False)
    
    # We may need to keep track of a label over multiple tokens
    current_label = None # This will be a tuple (label name, start index, end index)
    
    for token_index in range(1, len(sent_tokenized['input_ids'])-1):
        # Iterate through each token in the sentence (skip the first)
        tag = labels[pred[token_index + 1]] # We are off by one because our predictions include a prediction for the class token
        
        token_indices = sent_tokenized.token_to_chars(token_index)
        start_index = token_indices.start
        end_index = token_indices.end
        
        if 'I' in tag:
            # We must be following a B tag or we made an error
            # So there should be a current_label
            if current_label:
                current_label = (current_label[0], current_label[1], end_index)
        elif current_label:
            # If we previously were tracking a label, we need to end it, since we are now looking at a B or O tag
            annotations.append({'value': {'start': current_label[1], 
                                          'end': current_label[2], 
                                          'text': original_text[current_label[1]: current_label[2]],
                                          'labels': [current_label[0]]},
                                'id': f"{sent_index}{token_index}",
                                'from_name': 'label',
                                'to_name': 'label',
                                'type': 'labels'
                               })
            current_label = None
        
        if 'B' in tag:
            current_label = (tag[2:], start_index, end_index)
    
    data[sent_index]['annotations'] = [{'result': annotations}]


100%|███████████████████████████████████████| 4501/4501 [40:16<00:00,  1.86it/s]


In [41]:
with open("NER_TEST_DATA_FS.json", "w") as outfile:
    outfile.write(json.dumps(data, indent=4))

In [42]:
for (token, label) in zip(test[0]['input_ids'], preds[0]):
    print(tokenizer.convert_ids_to_tokens([token]), labels[label])

['<s>'] O
['<PREAMBLE>'] O
['ĠIn'] O
['ĠThe'] O
['ĠHigh'] B-COURT
['ĠCourt'] I-COURT
['ĠOf'] I-COURT
['ĠKerala'] I-COURT
['ĠAt'] I-COURT
['ĠErn'] I-COURT
['ak'] I-COURT
['ul'] I-COURT
['am'] I-COURT
['ĊĊ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['ĠPresent'] O
['ĊĊ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['ĠThe'] O
['ĠHonour'] O
['able'] O
['ĠMrs'] O
['.'] O
['ĠJustice'] O
['ĠM'] B-JUDGE
['.'] I-JUDGE
['R'] I-JUDGE
['.'] I-JUDGE
['An'] I-JUDGE
['ith'] I-JUDGE
['a'] I-JUDGE
['ĊĊ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['Ġ'] O
['ĠWednesday'] O
[','] O
['ĠThe'] O
['Ġ10'] O
['Th'

In [270]:
def train_tokenize(row, idx):
    # Add special token for document type
    is_preamble = True
    if is_preamble:
        row['tokens'].insert(0, '<PREAMBLE>')
    else:
        row['tokens'].insert(0, '<JUDGEMENT>')
    row['tags'].insert(0, 'O')

    tokenized = tokenizer(row['tokens'], truncation=True, is_split_into_words=True)
    aligned_labels = []
    last_i = None
    for i in tokenized.word_ids():
        if i is None:
            aligned_labels.append(-100)
            continue
        
        aligned_label = row['tags'][i] # Find the appropriate label index
        if not i == last_i:
            aligned_labels.append(labels.index(aligned_label))
        else:
            aligned_labels.append(labels.index(aligned_label.replace('B', 'I')))
        last_i = i
    tokenized['labels'] = aligned_labels

    return tokenized

train_tokenized = train_tokenize(train[0], 0)
train_preds = np.argmax(trainer.predict([train_tokenized])[0], axis=2)[0]
for (token, label, real) in zip(train_tokenized['input_ids'], train_preds, train_tokenized['labels']):
    print(tokenizer.convert_ids_to_tokens([token]), labels[label], labels[real if real >= 0 else 28])

***** Running Prediction *****
  Num examples = 1
  Batch size = 40


['<s>'] O O
['<PREAMBLE>'] O O
['Ġ'] O O
['ĊĊ'] O O
['Ġ('] O O
['Ġ7'] O O
['Ġ)'] O O
['ĠOn'] O O
['Ġspecific'] O O
['Ġquery'] O O
['Ġby'] O O
['Ġthe'] O O
['ĠBench'] O O
['Ġabout'] O O
['Ġan'] O O
['Ġentry'] O O
['Ġof'] O O
['ĠRs'] O O
['Ġ.'] O O
['Ġ1'] O O
[','] O O
['31'] O O
[','] O O
['37'] O O
[','] O O
['500'] O O
['Ġon'] O O
['Ġdeposit'] O O
['Ġside'] O O
['Ġof'] O O
['ĠHong'] B-ORG B-ORG
['k'] B-ORG I-ORG
['ong'] B-ORG I-ORG
['ĠBank'] I-ORG I-ORG
['Ġaccount'] O O
['Ġof'] O O
['Ġwhich'] O O
['Ġa'] O O
['Ġphoto'] O O
['Ġcopy'] O O
['Ġis'] O O
['Ġappearing'] O O
['Ġat'] O O
['Ġp'] O O
['.'] O O
['Ġ40'] O O
['Ġof'] O O
['Ġass'] O O
['essee'] O O
["Ġ'"] O O
['s'] O O
['Ġpaper'] O O
['Ġbook'] O O
['Ġ,'] O O
['Ġlearned'] O O
['Ġauthorised'] O O
['Ġrepresentative'] O O
['Ġsubmitted'] O O
['Ġthat'] O O
['Ġit'] O O
['Ġwas'] O O
['Ġrelated'] O O
['Ġto'] O O
['Ġloan'] O O
['Ġfrom'] O O
['Ġbroker'] O O
['Ġ,'] O O
['ĠRahul'] B-ORG B-ORG
['Ġ&'] I-ORG I-ORG
['ĠCo'] I-ORG I-ORG
['.'] I-ORG I-OR

In [247]:
print(train[-100])

{'tokens': ['Wa-305', '-', '2007', '\n         ', '(', 'Purushottam', 'Lal', 'Vs', 'The', 'State', 'Of', 'Madhya', 'Pradesh', ')', '\n\n\n', '15', '-', '10', '-', '2015', '\n    ', 'High', 'Court', 'Of', 'Madhya', 'Pradesh', 'Principal', '\n              ', 'Seat', 'At', 'Jabalpur', '\n                 ', 'Writ', 'Appeal', 'No.305/2007', '\n                 ', 'Purushottam', 'Lal', 'and', 'others', '\n                             ', 'Vs', '.', '\n                   ', 'State', 'of', 'M.P.', '&', 'Others', '\n', 'Present', ':', 'Honâ\x80\x99ble', 'Shri', 'Rajendra', 'Menon', ',', 'J.', '&', '\n', "Hon'ble", 'Shri', 'C.', 'V.', 'Sirpurkar', ',', 'J.', '\n', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '\n', 'Shri', 'Vivek', 'Tankha', ',', 'learned', 'Senior', 'Counse

In [8]:
for item in test:
    if len(item['input_ids']) > 512:
        long = item
        print(item)
        break

{'text': 'In The High Court Of Kerala At Ernakulam\n\n                                       Present\n\n                    The Honourable Mrs. Justice M.R.Anitha\n\n           Wednesday, The 10Th Day Of June 2020 / 20Th Jyaishta, 1942\n\n                          Crl.Rev.Pet.No.767 Of 2012\n\n    Crmp 1176/2011 Dated 16-03-2012 Of Judicial Magistrate Of First Class ,\n                                 Kunnamkulam\n\nRevision Petitioner/Complainant\n                A. Rajesh Aged 35 Years,\n                S/O.Raman Nair, Ammasom Veettil, Punnayurkulam Village, Cherayi\n                Desom, Andathodu P.O., Chavakkad, Thrissur District, Pin\n                679564.\n                By Advs.\n                Dr.V.N.Sankarjee\n                Sri.V.N.Madhusudanan\n                Sri.S.Sidhardhan\n                Smt.R.Udaya Jyothi\n                Sri.M.M.Vinod\n                Smt.M.Suseela\n                Sri.Sudhakaran V.\n                Smt.Arya Balachandran\n                Smt. 

In [12]:
tokenizer.convert_ids_to_tokens(long['input_ids'][500:512])

['V', 'in', 'od', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ']