## Imports and Device

In [1]:
# ! pip install transformers
# ! pip3 install wandb
# ! pip install rouge_score

import pandas as pd
import numpy as np
import torch
import os, gc
import re

from transformers import AutoTokenizer, LongformerTokenizer, RobertaTokenizer
from transformers import LongformerForQuestionAnswering, AutoModelForSeq2SeqLM

from torch import cuda, nn, optim
# from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# import rouge_score
# import wandb


/bin/bash: pip: command not found


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
manual_seed = 595
torch.manual_seed(manual_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Read the Cleaned Data

### Define the paths

In [3]:
# run locally
text_path = '../formatted_cases/'
file = '../../annotated_data.xlsx'
REGEX = r';+'
sup_path = '../annotated_sup/'
multi_path = text_path + 'multiple_files/'

In [4]:
# # run on Google Colab
# from google.colab import drive
# drive.mount('/content/gdrive')
# text_path = '/content/gdrive/My Drive/595/formatted_cases/'
# file = '/content/gdrive/My Drive/595/annotated_data.xlsx'
# REGEX = r';+'
# sup_path = '/content/gdrive/My Drive/595/annotated_sup/'
# multi_path = text_path + 'multiple_files/'

In [5]:
# wandb.login()
# wandb.init(project="RTB_Cases", entity="qmygrace")

### Clean the Dataframe

In [6]:
df = pd.read_excel(file)
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace(' and ', ';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace(' ', ';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace('/', ';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.strip(';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].apply(lambda x: re.sub(REGEX, ';', x))
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace('File;number:;', '')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace('TET-89650-18;TET-89650-18', 'TET-89650-18;TEL-90138-18')
df = df.fillna('Not stated')
df = df.replace('Not applicable', 'Not stated')
df.rename(columns={
    'If yes to the previous question, did the decision state these conditions would make moving particularly burdensome?':
    'If any of the children had mental, medical or physical conditions, did the decision state these conditions would make moving particularly burdensome?',
    'If yes to the previous question, which of the following were applicable to the tenant?':
    'If the tenant had difficulty finding alternative housing for any reason, which of the following were applicable to the tenant?'    
}, inplace=True)

df = df.iloc[:, 2:-2]

print(df.shape)
df.head(6)

(702, 50)


Unnamed: 0,What is the file number of the case?,What was the date of the hearing? [mm/dd/yyyy],What was the date of the decision? [mm/dd/yyyy],Who was the member adjudicating the decision?,What was the location of the landlord tenant board?,Did the decision state the landlord was represented?,Did the decision state the landlord attended the hearing?,Did the decision state the tenant was represented?,Did the decision state the tenant attended the hearing?,Did the decision state the landlord was a not-for-profit landlord (e.g. Toronto Community Housing)?,...,"If the tenant did propose a payment plan, did the member accept the proposed payment plan?","If a payment plan was ordered, what was the length of the payment plan?","Did the decision mention the tenant’s difficulty finding alternative housing for any reason e.g.physical limitations, reliance on social assistance, etc.?","If the tenant had difficulty finding alternative housing for any reason, which of the following were applicable to the tenant?",Did the decision state the tenant was given prior notice for the eviction?,"If the tenant was given prior notice for the eviction, how much notice was given?",Did the decisions state postponement would result in the tenant accruing additional arrears?,Which other specific applications of the landlord or the tenant were mentioned?,Did the decision mention the validity of an N4 eviction notice?,Were there detail(s) in the decision not captured by this questionnaire that should be included?
0,CEL-87788-19,2019-10-16 00:00:00,2020-06-04 00:00:00,Sonia Anwar-Ali,Toronto,Yes,Not stated,No,Not stated,No,...,Not stated,12,No,Not stated,No,Not stated,No,L2: Application to End a Tenancy and Evict a T...,No,Tenant was a single mother with no support fro...
1,CEL-90549-19,2020-01-22 00:00:00,2020-01-10 00:00:00,Shelby Whittick,Mississauga,Yes,Yes,No,Yes,No,...,No,Not stated,No,Not stated,Yes,Not stated,Yes,No other specific applications were mentioned,No,Not stated
2,TEL-94478-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,Yes,No,...,Not stated,Not stated,No,Not stated,Yes,Not stated,No,N13: Notice to End your Tenancy Because the La...,No,Previous decision TEL-92736-18 < This decision...
3,TEL-94493-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,Yes,No,...,Yes,1,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,There were 7 previous application for non-paym...
4,CEL-72994-18,2018-03-07 00:00:00,2018-03-14 00:00:00,Avril Cardoso,Mississauga,Yes,No,Yes,No,No,...,No,Not stated,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,Third Application by Landlord in past 6 months...
5,CEL-73021-18,2018-06-15 00:00:00,2018-06-18 00:00:00,Avril Cardoso,Mississauga,Yes,No,No,No,No,...,Not stated,Not stated,No,Not stated,Yes,Not stated,No,L1: Application to Evict a Tenant for Non-paym...,No,Tenant did not show up because hearing took pl...


In [7]:
# df.columns   #`Timestamp` is not the time of the case

In [8]:
df_unique = df.drop_duplicates(subset=['What is the file number of the case?'])
df_unique = df_unique.reset_index(drop=True)

print(df_unique.shape)
df_unique.head(6)

(682, 50)


Unnamed: 0,What is the file number of the case?,What was the date of the hearing? [mm/dd/yyyy],What was the date of the decision? [mm/dd/yyyy],Who was the member adjudicating the decision?,What was the location of the landlord tenant board?,Did the decision state the landlord was represented?,Did the decision state the landlord attended the hearing?,Did the decision state the tenant was represented?,Did the decision state the tenant attended the hearing?,Did the decision state the landlord was a not-for-profit landlord (e.g. Toronto Community Housing)?,...,"If the tenant did propose a payment plan, did the member accept the proposed payment plan?","If a payment plan was ordered, what was the length of the payment plan?","Did the decision mention the tenant’s difficulty finding alternative housing for any reason e.g.physical limitations, reliance on social assistance, etc.?","If the tenant had difficulty finding alternative housing for any reason, which of the following were applicable to the tenant?",Did the decision state the tenant was given prior notice for the eviction?,"If the tenant was given prior notice for the eviction, how much notice was given?",Did the decisions state postponement would result in the tenant accruing additional arrears?,Which other specific applications of the landlord or the tenant were mentioned?,Did the decision mention the validity of an N4 eviction notice?,Were there detail(s) in the decision not captured by this questionnaire that should be included?
0,CEL-87788-19,2019-10-16 00:00:00,2020-06-04 00:00:00,Sonia Anwar-Ali,Toronto,Yes,Not stated,No,Not stated,No,...,Not stated,12,No,Not stated,No,Not stated,No,L2: Application to End a Tenancy and Evict a T...,No,Tenant was a single mother with no support fro...
1,CEL-90549-19,2020-01-22 00:00:00,2020-01-10 00:00:00,Shelby Whittick,Mississauga,Yes,Yes,No,Yes,No,...,No,Not stated,No,Not stated,Yes,Not stated,Yes,No other specific applications were mentioned,No,Not stated
2,TEL-94478-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,Yes,No,...,Not stated,Not stated,No,Not stated,Yes,Not stated,No,N13: Notice to End your Tenancy Because the La...,No,Previous decision TEL-92736-18 < This decision...
3,TEL-94493-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,Yes,No,...,Yes,1,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,There were 7 previous application for non-paym...
4,CEL-72994-18,2018-03-07 00:00:00,2018-03-14 00:00:00,Avril Cardoso,Mississauga,Yes,No,Yes,No,No,...,No,Not stated,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,Third Application by Landlord in past 6 months...
5,CEL-73021-18,2018-06-15 00:00:00,2018-06-18 00:00:00,Avril Cardoso,Mississauga,Yes,No,No,No,No,...,Not stated,Not stated,No,Not stated,Yes,Not stated,No,L1: Application to Evict a Tenant for Non-paym...,No,Tenant did not show up because hearing took pl...


In [9]:
info_lst = df_unique.columns[2:-2]

raw_file_text = []

for i in range(len(df_unique)):
    file_no = df_unique.iloc[i,0]
    if not os.path.isfile(text_path+file_no+'.txt'):
        print(f'{file_no} not found. Going to the supplement directory.')
        # passed_cases.append(file_no)
        if not os.path.isfile(sup_path+file_no+'.txt'):
            print(f'{file_no} not found. Going to the multiple directory.')
            with open (multi_path+file_no+'.txt') as t:
                # file_no_lst = file_no.split(';')
                # print(file_no_lst)
                raw_file_text.append(t.read())
        else:
            with open (sup_path+file_no+'.txt') as t:
                raw_file_text.append(t.read())
    else:
        with open (text_path+file_no+'.txt') as t:
            # cases_info[-1]['text'] = t.read()
            raw_file_text.append(t.read())
            # raw_file_name.append(file_no+'.txt')

TET-89650-18;TEL-90138-18 not found. Going to the supplement directory.
TNL-00793-18;TNL-01183-18 not found. Going to the supplement directory.
TNL-00793-18;TNL-01183-18 not found. Going to the multiple directory.
TNL-03299-18;TNT-00589-17 not found. Going to the supplement directory.
TNL-03299-18;TNT-00589-17 not found. Going to the multiple directory.
TNL-04435-18;TNL-03907-18 not found. Going to the supplement directory.
HOL-02144-17;HOT-02146-17 not found. Going to the supplement directory.
TEL-87475-18;TET-86819-17;TET-88355-18 not found. Going to the supplement directory.
TEL-87475-18;TET-86819-17;TET-88355-18 not found. Going to the multiple directory.
SWL-08112-17;SWL-08113-17 not found. Going to the supplement directory.
SWL-12547-18;SWL-12548-18 not found. Going to the supplement directory.
SWL-12547-18;SWL-12548-18 not found. Going to the multiple directory.
SWL-13901-18;SWT-14627-18 not found. Going to the supplement directory.
TEL-77442-17;TET-77790-17 not found. Going to 

In [10]:
# remove columns that have too little information
little_info_col = [15, 16, 26, 27, 28, 29, 30, 31, 41, 43, 45]
to_del = [df_unique.columns[i] for i in little_info_col]
for col in to_del:
    del df_unique[col]
to_del

['If any rent increases occurred, what was the rent after the increase(s)?',
 'If any rent increases occurred, when did the rent increase(s) come into effect? ',
 'How many total children did the tenant have living with them? ',
 'How many total children aged 17 or younger did the tenant have living with them?',
 'How many total children aged 13 or younger did the tenant have living with them? ',
 'How many total children aged 4 or younger did the tenant have living with them?',
 'Did the decision state any of the children had mental, medical or physical conditions?',
 'If any of the children had mental, medical or physical conditions, did the decision state these conditions would make moving particularly burdensome?',
 'If a payment plan was ordered, what was the length of the payment plan? ',
 'If the tenant had difficulty finding alternative housing for any reason, which of the following were applicable to the tenant?',
 'If the tenant was given prior notice for the eviction, how mu

In [11]:
del to_del
gc.collect()

1075

### Split the Train Dataframe and Validation Dataframe

In [12]:
train_df = df_unique.iloc[:620, :]
val_df = df_unique.iloc[620:, :].reset_index(drop=True)
train_df.shape, val_df.shape

((620, 39), (62, 39))

In [13]:
for i, q in enumerate(train_df.columns):
    print(i, q)

0 What is the file number of the case?
1 What was the date of the hearing? [mm/dd/yyyy]
2 What was the date of the decision? [mm/dd/yyyy]
3 Who was the member adjudicating the decision?
4 What was the location of the landlord tenant board?
5 Did the decision state the landlord was represented?
6 Did the decision state the landlord attended the hearing?
7 Did the decision state the tenant was represented?
8 Did the decision state the tenant attended the hearing?
9 Did the decision state the landlord was a not-for-profit landlord (e.g. Toronto Community Housing)?
10 Did the decision state the tenant was collecting a subsidy?
11 What was the outcome of the case?
12 What was the length of the tenancy, or in other words, how long had the tenants lived at the residence in question? 
13 What was the monthly rent?
14 What was the amount of the rental deposit? 
15 What was the total amount of arrears?
16 Over how many months did the arrears accumulate? 
17 If the tenant made a payment on the ar

## Initialize the Tokenizer and the Model

In [14]:
# @article{Beltagy2020Longformer,
#   title={Longformer: The Long-Document Transformer},
#   author={Iz Beltagy and Matthew E. Peters and Arman Cohan},
#   journal={arXiv:2004.05150},
#   year={2020},
# }
tokenizer1 = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model1 = LongformerForQuestionAnswering.from_pretrained('allenai/longformer-base-4096', gradient_checkpointing=True)

tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384", gradient_checkpointing=True, use_cache=False)

# ref: https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_tune_Longformer_Encoder_Decoder_(LED)_for_Summarization_on_pubmed.ipynb#scrollTo=jpUr9QeebZ-n

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForQuestionAnswering: ['lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing LongformerForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerForQuestionAnswering were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mo

## A Test before Finetuning

In [15]:
def prompt(dataframe, raw_texts):
    input_texts = []
    outputs = []
    # long_cases = 0
    
    questions = dataframe.columns
    
    for q_no in range(len(questions)):
        answers = dataframe.iloc[:,q_no]
        # print(len(raw_texts), len(answers))
        assert len(raw_texts) == len(answers)

        for i in range(len(answers)):
            full_text = raw_texts[i]
            text = full_text[full_text.find('Content:')+len('Content:'):]

            # if len(text) > 26000:
            #     # print(len(text))
            #     text = text[:26000]
            #     long_cases += 1

            text = text.replace('\n', ' ')
            text = text.replace('\xa0', ' ')
            text = text.replace('\t', ' ')
            text = text.replace('   ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ')
            # text
            # for word in stop_words:
            #     text = text.replace(' '+word+' ', ' ')

            if 'Schedule 1' in text:
                s_idx = text.find('Schedule 1')
                text = text[:s_idx]

            input_text = f'Question: {questions[q_no]} Text: {text}'  
            input_texts.append(input_text)

            output = str(answers[i])
            outputs.append(output)
        
    # print(len(input_texts), len(outputs))
    # print(input_texts[0], outputs[0]) 
    # print(long_cases)
    return input_texts, outputs

### Longformer for Question Answering

In [16]:
q1_lst, a1_lst = prompt(df_unique, raw_file_text)
q1 = q1_lst[0]
a1 = a1_lst[0]
# # print(q1)
print(a1)
encoding = tokenizer1.encode_plus(text=q1,
                                  text_pair=a1) 
                                 # add_special=True)
inputs = torch.LongTensor(encoding['input_ids']).unsqueeze(0)  #Token embeddings
attention_mask = torch.LongTensor(encoding['attention_mask']).unsqueeze(0)
# print(len(attention_mask))

# sentence_embedding = encoding['token_type_ids']  #Segment embeddings
tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids']) #input tokens
outputs = model1(input_ids=inputs, 
                attention_mask=attention_mask)
start_scores, end_scores = outputs[0], outputs[1]
answer_tokens = tokens[torch.argmax(start_scores):torch.argmax(end_scores)+1]
answer = tokenizer1.decode(tokenizer1.convert_tokens_to_ids(answer_tokens))
answer

CEL-87788-19


''

In [17]:
torch.argmax(start_scores),torch.argmax(end_scores)

(tensor(999), tensor(655))

It shows that the encoder models can not really get what we need for most columns. Therefore we will

In [18]:
del tokenizer1, model1#, answer
gc.collect()

21

###  Longformer Encoder-Decoder (LED) 

In [19]:
q1_lst, a1_lst = prompt(df_unique, raw_file_text)
q1 = q1_lst[0]
a1 = a1_lst[0]
# # print(q1)
input_encoding = tokenizer(q1)
output_encoding = tokenizer(a1)
input_ids = torch.LongTensor(encoding['input_ids']).unsqueeze(0)  # batch of size 1
attention_mask = torch.LongTensor(encoding['attention_mask']).unsqueeze(0)
# attention_mask[:, [1, 4, 21,]] =  2  # Set global attention based on the task. For example,
                                     # classification: the <s> token
                                     # QA: question tokens
print(input_ids.shape, attention_mask.shape)
# input_ids, attention_mask = pad_to_window_size(
#         input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id)

output = model.generate(
    input_ids=input_ids, 
    attention_mask=attention_mask,
    return_dict_in_generate=True, 
    output_scores=False, 
    max_length=512,
    temperature=0.5,
    do_sample=True,
    repetition_penalty=3.0,
    top_k=10)


torch.Size([1, 1364]) torch.Size([1, 1364])




In [20]:
tokenizer.batch_decode(output['sequences'])

["</s><s>Question: What is the file number of this case? Text.  Order under Section 69 Residential Tenancies Act, 2006 File Number (a) CEL-87788 and a copy for use in an application to terminate tenancy The Landlord may apply at any time on or compositional basis that he/she has been persistently late paying his rent since 2012. 1 Pertinently subjecting her tenant(s), M., K.(the 'Tenant') could not be convicted by reason thereof; it would have required him as proof otherwise than evidence before me if I was permitted against my will with respect thereto whereupon there were no documentary evidences presented during trial proceedings relating only one point after each day until such period expires from June 30th 2017 through September 29’ 2018 which shall also require compensation when due but do so prior notice being made within three months following payment date thereafter unless further notices are issued pursuantto section 78). 2 BUDGETTY THE LEGAL ORDER TO END ANTIENSION 3 In acco

## Preprocess the Data

In [21]:
def preprocess(dataframe, tokenizer, raw_texts):
    input_texts, outputs = prompt(dataframe, raw_texts)   
    
    input_toks = tokenizer.batch_encode_plus(input_texts,
                                             add_special_tokens=False, 
                                             return_token_type_ids=False)
    output_toks = tokenizer.batch_encode_plus(outputs, 
                                              add_special_tokens=False,
                                              return_token_type_ids=False)
    # print(len(q1_train_input['input_ids']), len(q1_train_output['input_ids']))
    return input_toks, output_toks
    

In [22]:
train_raw_texts = raw_file_text[:620]
val_raw_texts = raw_file_text[620:]

In [23]:
train_input, train_output = preprocess(train_df, tokenizer, train_raw_texts)
val_input, val_output = preprocess(val_df, tokenizer, val_raw_texts)

In [24]:
# len(q1_train_input['input_ids']), len(q1_train_output['input_ids'])

In [25]:
# for BertTokenizer
print("Input length:", len(train_input))
# print(" ")
print("Input example:\n", tokenizer.decode(train_input['input_ids'][0])[100:])
print(" ")
print("Input ID example:\n", train_input['input_ids'][0][100:])
print(" ")
print("Tokens:\n", [tokenizer.convert_ids_to_tokens(id) for id in train_input['input_ids'][0]][100:])
print(" ")
print("Attention Mask:", train_input['attention_mask'][0])
print(" ")
print("Output example:\n", tokenizer.decode(train_output['input_ids'][0])[100:])
print(" ")
print("Output ID example:\n", train_output['input_ids'][0])
print(" ")

Input length: 2
Input example:
 ct, 2006 File Number: CEL-87788-19 M.C. (the 'Landlord') applied for an order to terminate the tenancy and evict M.K. (the 'Tenant') because he has been persistently late in paying his rent. The Landlord also claimed compensation for each day the Tenant remained in the unit after the termination date. This application was heard in Toronto on October 16, 2019. The Landlord and the Tenant attended the hearing. The Landlord was represented by S.K. Also in attendance was the Landlord’s property manager, B.A. Determinations: 1. By way of background, this is a month to month tenancy in which rent is due on the first of the month in the amount of $1398.79. This tenancy commenced on June 1, 2012. 2. The Landlord’s L2 application is based on a N8 notice of termination served to the Tenant on April 24, 2019 with a termination date of June 30, 2019 alleging that the Tenant has been persistently late in paying the rent since January 2018. 3. Since the N8 notice of t

## Create the Dataset

In [26]:
PAD = tokenizer.pad_token_id
SEP = tokenizer.sep_token_id
PAD, SEP

(1, 2)

In [27]:
class CaseDataset(Dataset):

    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs

    def __len__(self):
        return len(self.inputs["input_ids"])

    def __getitem__(self, idx):
        input_ids = self.inputs['input_ids'][idx]
        attention_mask = self.inputs['attention_mask'][idx]

        target_ids = self.outputs['input_ids'][idx]
        # target_attention_mask = self.outputs['attention_mask'][idx]
        return {"input_ids": input_ids, "attention_mask":attention_mask, "output_ids":target_ids}


def collate_fn(batch):
    batch_input = [torch.LongTensor(example['input_ids']) for example in batch]
    batch_output = [torch.LongTensor(example['output_ids']) for example in batch]
    batch_mask = [torch.LongTensor(example['attention_mask']) for example in batch]

    padded_batch_input_ids = pad_sequence(batch_input, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_batch_label = pad_sequence(batch_output, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_batch_att_mask = pad_sequence(batch_mask, batch_first=True, padding_value=-100)

    return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "labels": padded_batch_label}


def to_device(data, device):
    new_data = {}
    for k in data:
        new_data[k] = data[k].to(device)
    return new_data

## Prepare the Functions for Training and Evaluation

In [28]:
def train(model:nn.Module, train_loader:DataLoader, optimizer:optim.Optimizer, log_step=200):
    model.train()
    epoch_loss = 0.0
    log_loss = 0.0
    for idx, batch in enumerate(train_loader):
        # try:
        model.zero_grad()
        batch = to_device(batch, device)
        loss = model(**batch).loss
        # print(loss)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        log_loss += loss.item()

        # wandb.log({'batch':idx, 'train_loss': loss.item()})
        # wandb.log({'batch':idx, 'accumulated_train_loss_in_this_Q': log_loss})

        if idx % log_step == 0:
            print(f"Train Step: {idx} Loss: {log_loss / log_step}")
            log_loss = 0.0
        # except:
        #     print(f'The text is too long. Passing for now. Step No: {idx}')
        #     pass

    return epoch_loss / len(train_loader)
        

@torch.no_grad()
def evaluate(model:nn.Module, eval_loader:DataLoader):
    eval_loss = 0.0
    correct = 0
    total = 0
    model.eval()
    for batch in eval_loader:
        batch = to_device(batch, device)
        output = model(**batch)
        loss = output.loss
        eval_loss += loss.item()
        pred = output.logits.argmax(-1)
        label = batch["labels"]
        correct += torch.where(label!=-100, pred==label, 0).sum().item()
        total += torch.sum(label != -100).item()
    
    print(total, correct)

    eval_acc = correct / total
    eval_loss = eval_loss / len(eval_loader) 
    return eval_acc, eval_loss


In [29]:
@torch.no_grad()
def answer(model, loader):
    all_preds = []
    all_labels = []
    model.eval()
    for batch in loader:
        batch = to_device(batch, device)
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        outputs = model.generate(input_ids=input_ids, 
                                 attention_mask=attention_mask, 
                                 return_dict_in_generate=True, 
                                 pad_token_id=tokenizer.pad_token_id, 
                                 max_length=512, 
                                 top_k=15)
        
        decode_texts = tokenizer.batch_decode([l[l != 0] for l in outputs['sequences']])
        gold_texts = tokenizer.batch_decode([l[l != 0] for l in labels])
        # print(decode_texts, gold_texts)
        for gold, decode in zip(gold_texts, decode_texts):
            l = gold.replace(' ', '').replace('</s>', '').replace('<pad>','').replace('<s>', '')
            p = decode.replace(' ', '').replace('</s>', '').replace('<pad>','').replace('<s>', '')

            # if '<pad>' in gold:
            #     l_pad_idx = gold.index('<pad>')
            #     l = gold[:l_pad_idx].replace(' ', '').replace('</s>', '').replace('<pad>','').replace('<s>', '')
            # else:
            #     l = gold.replace(' ', '').replace('</s>', '').replace('<pad>','').replace('<s>', '')
            
            # if '<pad>' in decode:
            #     p_pad_idx = decode.index('<pad>')
            #     p = decode[:p_pad_idx].replace(' ', '').replace('</s>', '').replace('<pad>','').replace('<s>', '')
            # else:
            #     p = decode.replace(' ', '').replace('</s>', '').replace('<pad>','').replace('<s>', '')

            # print(l, p)
            all_labels.append(l)
            all_preds.append(p)
    
    return all_preds, all_labels


def accuracy(sys, gold):
    total = 0
    correct = 0
    for s, g in zip(sys, gold):
        if s == g:
            correct += 1
        total += 1
            
    accuracy = correct / total
    return accuracy, correct, total


## Train the Model

In [30]:
model.to(device)

LEDForConditionalGeneration(
  (led): LEDModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): LEDEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): LEDLearnedPositionalEmbedding(16384, 768)
      (layers): ModuleList(
        (0-5): 6 x LEDEncoderLayer(
          (self_attn): LEDEncoderAttention(
            (longformer_self_attn): LEDEncoderSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
              (value_global): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Linear(in_features=768, out_features=768, bias=True)
          )
     

In [31]:
# Experiment
train_dataset = CaseDataset(train_input, train_output)
train_loader = DataLoader(train_dataset, batch_size=2, collate_fn=collate_fn, shuffle=True)

val_dataset = CaseDataset(val_input, val_output) 
val_loader = DataLoader(val_dataset, batch_size=2, collate_fn=collate_fn, shuffle=False)


In [32]:
del df, train_dataset
gc.collect()

12

In [33]:
# experiment
epochs = 2
optimizer = optim.Adam(model.parameters(), lr=5e-5)

model.train()

# print(f"Training Question 1")
for epoch in range(epochs):
    print(f"Epoch {epoch+1}:")
    
    train_loss = train(model, train_loader, optimizer)
    print(f"Epoch {epoch+1} Training Loss: {train_loss}")

    eval_acc, eval_loss = evaluate(model, val_loader)
    print(f"Epoch {epoch} Eval Acc: {eval_acc}; Eval Loss: {eval_loss}")


Epoch 1:
Train Step: 0 Loss: 0.0910161018371582
Train Step: 200 Loss: 3.823288514055312
Train Step: 400 Loss: 2.2288485485315324
Train Step: 600 Loss: 1.926533841714263
Train Step: 800 Loss: 1.9272761825472116
Train Step: 1000 Loss: 1.7187148754298687
Train Step: 1200 Loss: 2.223682520762086
Train Step: 1400 Loss: 1.9094089418277145
Train Step: 1600 Loss: 1.424167960844934
Train Step: 1800 Loss: 1.2141001358721406
Train Step: 2000 Loss: 1.2233759127464146
Train Step: 2200 Loss: 1.2804372095828875
Train Step: 2400 Loss: 1.240971080409363
Train Step: 2600 Loss: 1.113220043713227
Train Step: 2800 Loss: 1.0740728620020672
Train Step: 3000 Loss: 1.0001478788349778
Train Step: 3200 Loss: 0.9772752594109625
Train Step: 3400 Loss: 0.9708926041750238
Train Step: 3600 Loss: 1.0541693927533924
Train Step: 3800 Loss: 1.0435213529877365
Train Step: 4000 Loss: 0.9857545538153499
Train Step: 4200 Loss: 0.8617260059528052
Train Step: 4400 Loss: 1.12155135453213
Train Step: 4600 Loss: 1.052642973223701

In [34]:
torch.save(model.state_dict(), 'led_2epoch_law_allqs.pt')

In [35]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 161,844,480 trainable parameters


## Evaluate the Model on Validation Set

In [42]:
def q_prompt(dataframe, q_no, raw_texts):
    input_texts = []
    outputs = []
    # long_cases = 0
    
    questions = dataframe.columns
    answers = dataframe.iloc[:,q_no]
    # print(len(raw_texts), len(answers))
    assert len(raw_texts) == len(answers)

    for i in range(len(answers)):
        full_text = raw_texts[i]
        text = full_text[full_text.find('Content:')+len('Content:'):]

        # if len(text) > 26000:
        #     # print(len(text))
        #     text = text[:26000]
        #     long_cases += 1

        text = text.replace('\n', ' ')
        text = text.replace('\xa0', ' ')
        text = text.replace('\t', ' ')
        text = text.replace('   ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ')
        # text
        # for word in stop_words:
        #     text = text.replace(' '+word+' ', ' ')

        if 'Schedule 1' in text:
            s_idx = text.find('Schedule 1')
            text = text[:s_idx]

        input_text = f'Question: {questions[q_no]} Text: {text}'  
        input_texts.append(input_text)

        output = str(answers[i])
        outputs.append(output)
        
    # print(len(input_texts), len(outputs))
    # print(input_texts[0], outputs[0]) 
    # print(long_cases)
    return input_texts, outputs

In [43]:
def q_preprocess(dataframe, q_no, tokenizer, raw_texts):
    input_texts, outputs = q_prompt(dataframe, q_no, raw_texts)   
    
    input_toks = tokenizer.batch_encode_plus(input_texts,
                                             add_special_tokens=False, 
                                             return_token_type_ids=False)
    output_toks = tokenizer.batch_encode_plus(outputs, 
                                              add_special_tokens=False,
                                              return_token_type_ids=False)
    # print(len(q1_train_input['input_ids']), len(q1_train_output['input_ids']))
    return input_toks, output_toks
    

In [54]:
def get_test_dataloader(df, q_no, tokenizer, raw_texts):
    input_toks, output_toks = q_preprocess(df, q_no, tokenizer, raw_texts)
    dataset = CaseDataset(input_toks, output_toks)
    dataloader = DataLoader(dataset, 
                            batch_size=4, 
                            collate_fn=collate_fn, 
                            shuffle=False)
    return dataloader

In [55]:
def answer_qs(val_df, q_no, tokenizer):
    loader = get_test_dataloader(val_df, q_no, tokenizer, val_raw_texts)
    # print(len(loader))
    
    questions = val_df.columns
    print(f'Q{q_no+1}: {questions[q_no]}')
    
    preds, golds = answer(model, loader)
    acc, correct, total = accuracy(preds, golds)
    acc = round(acc, 5)
    
    print(f"Accuracy for this question is: {acc*100}%")
    print('')
    
    return acc, preds

In [56]:
# del train_df, train_loader, count_parameters
# gc.collect()

In [57]:
del preds
gc.collect()

NameError: name 'preds' is not defined

In [58]:
acc_lst = []
with open ('LED_allqs_preds.txt', 'w', encoding='utf-8') as p:
    for i in range(0, val_df.shape[1]):
        p.write(f'Q{i+1}: {val_df.columns[i]}\n')
        acc, preds = answer_qs(val_df, i, tokenizer)
        acc_lst.append(acc)
        print(preds)
        print(' ')
        p.write(str(preds)+'\n')
        p.write('\n')
        del preds
        gc.collect()
avg_acc = sum(acc_lst) / len(acc_lst)
avg_acc

Q1: What is the file number of the case?
Accuracy for this question is: 54.839000000000006%

['TSL-0538-19', 'TEL-80084-17', 'TEL-80169-17', 'TEL-80248-17', 'TEL-80320-17-RVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVEL-VVVVVVVVVVEL-VVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVEL-VVVVVEL-VVVVEL-VVVVEL-VVVVVVVVVVVVVVVVVVVVEL-VVVVVVVVVEL-VVVVEL-VVVVVVVEL-VVVVVVVVVVEL-VVVVVEL-VVVVEL-VVVVEL-VVVVEL-VVVEL-VVVVVEL-VVVVEL-VVVVEL-VVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVEL-VVVVVVVVV', 'TEL-80483-17-RVVVV2VVVVV2-RV2-RVVV2-RVV2-RVV2VV2VVVVVV2VVVVVVVVVVVVVVVVVVVVVVVVVVV2-RVVVVVVVVVVVV2-RVV2VVVV2VVVVVVVVV2-RVV2-RVVVVVVVVV2V2VV2V2-RVVV2VV2V2-RVV2VVVVV2VVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV2VVVV2-RVVV2VVVVV2VVVVV2-RVV2-RV2-RV2V2-RV2V2-RV2-RVV2V2-RV2-RVV2-RV2-RV2-RVV2-RVV2-RVV2-RV2-RVV2

0.7249794871794872

In [59]:
assert len(acc_lst) == 39
acc_lst

[0.54839,
 0.70968,
 0.79032,
 0.90323,
 0.85484,
 0.51613,
 0.45161,
 0.90323,
 0.37097,
 0.98387,
 0.91935,
 0.56452,
 0.82258,
 0.53226,
 0.30645,
 0.12903,
 0.17742,
 0.67742,
 0.95161,
 0.91935,
 0.87097,
 0.98387,
 0.82258,
 0.80645,
 0.74194,
 0.85484,
 0.80645,
 0.91935,
 0.90323,
 0.96774,
 0.82258,
 0.80645,
 0.82258,
 0.85484,
 0.93548,
 0.5,
 0.37097,
 0.8871,
 0.56452]

In [60]:
q1_loader = get_test_dataloader(val_df, 0, tokenizer, val_raw_texts)
# print(len(loader))

questions = val_df.columns
print(f'Q1: {questions[0]}')

preds, golds = answer(model, q1_loader)
preds, golds

Q1: What is the file number of the case?


(['TSL-0538-19',
  'TEL-80084-17',
  'TEL-80169-17',
  'TEL-80248-17',
  'TEL-80320-17-RVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVEL-VVVVVVVVVVEL-VVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVEL-VVVVVEL-VVVVEL-VVVVEL-VVVVVVVVVVVVVVVVVVVVEL-VVVVVVVVVEL-VVVVEL-VVVVVVVEL-VVVVVVVVVVEL-VVVVVEL-VVVVEL-VVVVEL-VVVVEL-VVVEL-VVVVVEL-VVVVEL-VVVVEL-VVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVEL-VVVVVVVVV',
  'TEL-80483-17-RVVVV2VVVVV2-RV2-RVVV2-RVV2-RVV2VV2VVVVVV2VVVVVVVVVVVVVVVVVVVVVVVVVVV2-RVVVVVVVVVVVV2-RVV2VVVV2VVVVVVVVV2-RVV2-RVVVVVVVVV2V2VV2V2-RVVV2VV2V2-RVV2VVVVV2VVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV2VVVV2-RVVV2VVVVV2VVVVV2-RVV2-RV2-RV2V2-RV2V2-RV2-RVV2V2-RV2-RVV2-RV2-RV2-RVV2-RVV2-RVV2-RV2-RVV2-RVV2V2-RVV2V2V2V2-RVV2V2V2V2-RV2V2V2VV2V2V2V2-RV2V2-RV2V2V2V2-RV2V2V2V2-RV2V2-RV2