## Imports and Device

In [1]:
import pandas as pd
import numpy as np
import requests
from bs4 import BeautifulSoup
import torch
import os
import re

from wordcloud import WordCloud
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM
from torch import cuda, nn, optim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import GPT2LMHeadModel, BertTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
manual_seed = 595
torch.manual_seed(manual_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Read the Cleaned Data

### Define the paths

In [3]:
# run locally
text_path = '../formatted_cases/'
file = '../../annotated_data.xlsx'
REGEX = r';+'
sup_path = '../annotated_sup/'
multi_path = text_path + 'multiple_files/'

In [None]:
# run on Google Colab
from google.colab import drive
drive.mount('/content/gdrive')
text_path = '/content/gdrive/595/formatted_cases/'
file = '/content/gdrive/595/annotated_data.xlsx'
REGEX = r';+'
sup_path = '/content/gdrive/595/annotated_sup/'
multi_path = text_path + 'multiple_files/'

### Clean the Dataframe

In [4]:
df = pd.read_excel(file)
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace(' and ', ';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace(' ', ';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace('/', ';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.strip(';')
df['What is the file number of the case?'] = df['What is the file number of the case?'].apply(lambda x: re.sub(REGEX, ';', x))
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace('File;number:;', '')
df['What is the file number of the case?'] = df['What is the file number of the case?'].str.replace('TET-89650-18;TET-89650-18', 'TET-89650-18;TEL-90138-18')
df = df.fillna('Not stated')
df = df.replace('Not applicable', 'Not stated')

print(df.shape)
df.head(6)

(702, 54)


Unnamed: 0,Timestamp,Email Address,What is the file number of the case?,What was the date of the hearing? [mm/dd/yyyy],What was the date of the decision? [mm/dd/yyyy],Who was the member adjudicating the decision?,What was the location of the landlord tenant board?,Did the decision state the landlord was represented?,Did the decision state the landlord attended the hearing?,Did the decision state the tenant was represented?,...,"Did the decision mention the tenant’s difficulty finding alternative housing for any reason e.g.physical limitations, reliance on social assistance, etc.?","If yes to the previous question, which of the following were applicable to the tenant?",Did the decision state the tenant was given prior notice for the eviction?,"If the tenant was given prior notice for the eviction, how much notice was given?",Did the decisions state postponement would result in the tenant accruing additional arrears?,Which other specific applications of the landlord or the tenant were mentioned?,Did the decision mention the validity of an N4 eviction notice?,Were there detail(s) in the decision not captured by this questionnaire that should be included?,Exec Review,Review Status
0,2020-11-18 00:31:40.706,dylan.juschko@mail.utoronto.ca,CEL-87788-19,2019-10-16 00:00:00,2020-06-04 00:00:00,Sonia Anwar-Ali,Toronto,Yes,Not stated,No,...,No,Not stated,No,Not stated,No,L2: Application to End a Tenancy and Evict a T...,No,Tenant was a single mother with no support fro...,Not stated,Not stated
1,2020-11-18 19:26:29.581,dylan.juschko@mail.utoronto.ca,CEL-90549-19,2020-01-22 00:00:00,2020-01-10 00:00:00,Shelby Whittick,Mississauga,Yes,Yes,No,...,No,Not stated,Yes,Not stated,Yes,No other specific applications were mentioned,No,Not stated,AW,Complete
2,2020-12-24 09:19:21.479,kayly.machado@mail.utoronto.ca,TEL-94478-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,...,No,Not stated,Yes,Not stated,No,N13: Notice to End your Tenancy Because the La...,No,Previous decision TEL-92736-18 < This decision...,AW,Complete
3,2020-12-24 06:13:17.400,kayly.machado@mail.utoronto.ca,TEL-94493-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,...,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,There were 7 previous application for non-paym...,AW,Complete
4,2020-11-19 17:02:36.702,joseph.galinsky@mail.utoronto.ca,CEL-72994-18,2018-03-07 00:00:00,2018-03-14 00:00:00,Avril Cardoso,Mississauga,Yes,No,Yes,...,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,Third Application by Landlord in past 6 months...,AW,Complete
5,2020-11-19 17:14:22.294,joseph.galinsky@mail.utoronto.ca,CEL-73021-18,2018-06-15 00:00:00,2018-06-18 00:00:00,Avril Cardoso,Mississauga,Yes,No,No,...,No,Not stated,Yes,Not stated,No,L1: Application to Evict a Tenant for Non-paym...,No,Tenant did not show up because hearing took pl...,AW,Complete


In [5]:
df.columns   #`Timestamp` is not the time of the case

Index(['Timestamp', 'Email Address', 'What is the file number of the case?',
       'What was the date of the hearing? [mm/dd/yyyy]',
       'What was the date of the decision? [mm/dd/yyyy]',
       'Who was the member adjudicating the decision?',
       'What was the location of the landlord tenant board?',
       'Did the decision state the landlord was represented?',
       'Did the decision state the landlord attended the hearing?',
       'Did the decision state the tenant was represented?',
       'Did the decision state the tenant attended the hearing?',
       'Did the decision state the landlord was a not-for-profit landlord (e.g. Toronto Community Housing)?',
       'Did the decision state the tenant was collecting a subsidy?',
       'What was the outcome of the case?',
       'What was the length of the tenancy, or in other words, how long had the tenants lived at the residence in question? ',
       'What was the monthly rent?',
       'What was the amount of the rental de

In [6]:
df_unique = df.drop_duplicates(subset=['What is the file number of the case?'])
df_unique = df_unique.reset_index(drop=True)

print(df_unique.shape)
df_unique.head(6)

(682, 54)


Unnamed: 0,Timestamp,Email Address,What is the file number of the case?,What was the date of the hearing? [mm/dd/yyyy],What was the date of the decision? [mm/dd/yyyy],Who was the member adjudicating the decision?,What was the location of the landlord tenant board?,Did the decision state the landlord was represented?,Did the decision state the landlord attended the hearing?,Did the decision state the tenant was represented?,...,"Did the decision mention the tenant’s difficulty finding alternative housing for any reason e.g.physical limitations, reliance on social assistance, etc.?","If yes to the previous question, which of the following were applicable to the tenant?",Did the decision state the tenant was given prior notice for the eviction?,"If the tenant was given prior notice for the eviction, how much notice was given?",Did the decisions state postponement would result in the tenant accruing additional arrears?,Which other specific applications of the landlord or the tenant were mentioned?,Did the decision mention the validity of an N4 eviction notice?,Were there detail(s) in the decision not captured by this questionnaire that should be included?,Exec Review,Review Status
0,2020-11-18 00:31:40.706,dylan.juschko@mail.utoronto.ca,CEL-87788-19,2019-10-16 00:00:00,2020-06-04 00:00:00,Sonia Anwar-Ali,Toronto,Yes,Not stated,No,...,No,Not stated,No,Not stated,No,L2: Application to End a Tenancy and Evict a T...,No,Tenant was a single mother with no support fro...,Not stated,Not stated
1,2020-11-18 19:26:29.581,dylan.juschko@mail.utoronto.ca,CEL-90549-19,2020-01-22 00:00:00,2020-01-10 00:00:00,Shelby Whittick,Mississauga,Yes,Yes,No,...,No,Not stated,Yes,Not stated,Yes,No other specific applications were mentioned,No,Not stated,AW,Complete
2,2020-12-24 09:19:21.479,kayly.machado@mail.utoronto.ca,TEL-94478-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,...,No,Not stated,Yes,Not stated,No,N13: Notice to End your Tenancy Because the La...,No,Previous decision TEL-92736-18 < This decision...,AW,Complete
3,2020-12-24 06:13:17.400,kayly.machado@mail.utoronto.ca,TEL-94493-18,2018-10-31 00:00:00,2018-11-21 00:00:00,Ruth Carey (Vice Chair),Toronto,Yes,Yes,No,...,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,There were 7 previous application for non-paym...,AW,Complete
4,2020-11-19 17:02:36.702,joseph.galinsky@mail.utoronto.ca,CEL-72994-18,2018-03-07 00:00:00,2018-03-14 00:00:00,Avril Cardoso,Mississauga,Yes,No,Yes,...,No,Not stated,Yes,Not stated,No,No other specific applications were mentioned,No,Third Application by Landlord in past 6 months...,AW,Complete
5,2020-11-19 17:14:22.294,joseph.galinsky@mail.utoronto.ca,CEL-73021-18,2018-06-15 00:00:00,2018-06-18 00:00:00,Avril Cardoso,Mississauga,Yes,No,No,...,No,Not stated,Yes,Not stated,No,L1: Application to Evict a Tenant for Non-paym...,No,Tenant did not show up because hearing took pl...,AW,Complete


In [7]:
info_lst = df_unique.columns[2:]

raw_file_text = []
# cases_info = []
# raw_file_name = []

for i in range(len(df_unique)):
    # cases_info.append({})
    answers = df_unique.iloc[i,2:]
    # to complete the scraping function
    # text = scraping(file_no)
    file_no = answers[0]
    if not os.path.isfile(text_path+file_no+'.txt'):
        print(f'{file_no} not found. Going to the supplement directory.')
        # passed_cases.append(file_no)
        if not os.path.isfile(sup_path+file_no+'.txt'):
            print(f'{file_no} not found. Going to the multiple directory.')
            with open (multi_path+file_no+'.txt') as t:
                # print(t.read())
                # cases_info[-1]['text'] = t.read()
                raw_file_text.append(t.read())
                # raw_file_name.append(file_no+'.txt')
        else:
            with open (sup_path+file_no+'.txt') as t:
                # cases_info[-1]['text'] = t.read()
                raw_file_text.append(t.read())
                # raw_file_name.append(file_no+'.txt')
    else:
        with open (text_path+file_no+'.txt') as t:
            # cases_info[-1]['text'] = t.read()
            raw_file_text.append(t.read())
            # raw_file_name.append(file_no+'.txt')

TET-89650-18;TEL-90138-18 not found. Going to the supplement directory.
TNL-00793-18;TNL-01183-18 not found. Going to the supplement directory.
TNL-00793-18;TNL-01183-18 not found. Going to the multiple directory.
TNL-03299-18;TNT-00589-17 not found. Going to the supplement directory.
TNL-03299-18;TNT-00589-17 not found. Going to the multiple directory.
TNL-04435-18;TNL-03907-18 not found. Going to the supplement directory.
HOL-02144-17;HOT-02146-17 not found. Going to the supplement directory.
TEL-87475-18;TET-86819-17;TET-88355-18 not found. Going to the supplement directory.
TEL-87475-18;TET-86819-17;TET-88355-18 not found. Going to the multiple directory.
SWL-08112-17;SWL-08113-17 not found. Going to the supplement directory.
SWL-12547-18;SWL-12548-18 not found. Going to the supplement directory.
SWL-12547-18;SWL-12548-18 not found. Going to the multiple directory.
SWL-13901-18;SWT-14627-18 not found. Going to the supplement directory.
TEL-77442-17;TET-77790-17 not found. Going to 

### Split the Train Dataframe and Validation Dataframe

In [8]:
train_df = df_unique.iloc[:620, 2:]
val_df = df_unique.iloc[620:, 2:].reset_index(drop=True)
train_df.shape, val_df.shape

((620, 52), (62, 52))

In [9]:
for i, q in enumerate(train_df.columns):
    print(i, q)

0 What is the file number of the case?
1 What was the date of the hearing? [mm/dd/yyyy]
2 What was the date of the decision? [mm/dd/yyyy]
3 Who was the member adjudicating the decision?
4 What was the location of the landlord tenant board?
5 Did the decision state the landlord was represented?
6 Did the decision state the landlord attended the hearing?
7 Did the decision state the tenant was represented?
8 Did the decision state the tenant attended the hearing?
9 Did the decision state the landlord was a not-for-profit landlord (e.g. Toronto Community Housing)?
10 Did the decision state the tenant was collecting a subsidy?
11 What was the outcome of the case?
12 What was the length of the tenancy, or in other words, how long had the tenants lived at the residence in question? 
13 What was the monthly rent?
14 What was the amount of the rental deposit? 
15 If any rent increases occurred, what was the rent after the increase(s)?
16 If any rent increases occurred, when did the rent increa

## Initialize the Tokenizer and the Model

In [169]:
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', padding_side='left')
# tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForCausalLM.from_pretrained("gpt2")

## Preprocessing

In [170]:
stop_words = [
    "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in",
    "into", "is", "it",# "no", "not",
    "of", "on", "or", "such", "that", "the",
    "their", "then", "there", "these", "they", "this", "to", "was", "will", "with",
    "about", "all", "also", "any", "can", "could", "do", "does", "from", "has",
    "have", "how", "however", "i", "if", "may", "might", "my", "need", "our",
    "should", "so", "some", "than", "their", "them", "there", "these", "thing",
    "things", "think", "us", "want", "way", "we", "what", "when", "where", "which",
    "who", "why", "would", "you",
    'canlii'
]

In [171]:
first = [0, 1, 4, 5, 6, 7, 8, 9, 12]
second = []
third = []
fourth = [11,]
fifth = [2, 3, 6, 11]

In [172]:
def prompt(dataframe, q_no, raw_texts):
    input_texts = []
    outputs = []
    long_cases = 0
    
    questions = dataframe.columns
    answers = dataframe.iloc[:,q_no]
    # print(len(raw_texts), len(answers))
    assert len(raw_texts) == len(answers)

    for i in range(len(answers)):
        full_text = raw_texts[i]
        text = full_text[full_text.find('Content:')+len('Content:'):]
        
        if len(text) > 20000:
            # print(len(text))
            text = text[:20000]
            long_cases += 1
        
        text = text.replace('\n', ' ')
        text = text.replace('\xa0', ' ')
        text = text.replace('\t', ' ')
        text = text.replace('   ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ')
        # text
        for word in stop_words:
            text = text.replace(' '+word+' ', ' ')
        
        # There would be overlaps, so in each if statement, 
        # the input and output need to be appended
        if q_no in first:
            text = text[:int(0.2*len(text))]
            input_text = text + '\n' + questions[q_no]
            input_texts.append(input_text)
            output = str(answers[i])
            outputs.append(output)
       
        if q_no in second:
            text = text[int(0.2*len(text)):int(0.4*len(text))]
            input_text = text + '\n' + questions[q_no]
            input_texts.append(input_text)
            output = str(answers[i])
            outputs.append(output)
            
        if q_no in third:
            text = text[int(0.4*len(text)):int(0.6*len(text))]
            input_text = text + '\n' + questions[q_no]
            input_texts.append(input_text)
            output = str(answers[i])
            outputs.append(output)
            
        if q_no in fourth:
            text = text[int(0.6*len(text)):int(0.8*len(text))]
            input_text = text + '\n' + questions[q_no]
            input_texts.append(input_text)
            output = str(answers[i])
            outputs.append(output)
        
        if q_no in fifth:
            text = text[int(0.8*len(text)):]
            input_text = text + '\n' + questions[q_no]
            input_texts.append(input_text)
            output = str(answers[i])
            outputs.append(output)
        
    # print(len(input_texts), len(outputs))
    # print(input_texts[0], outputs[0]) 
    print(long_cases)
    return input_texts, outputs

            
def preprocess(dataframe, q_no, tokenizer, raw_texts):
    input_texts, outputs = prompt(dataframe, q_no, raw_texts)   
        
    # # for AutoTokenizer
    # concat_inputs = tokenizer(
    #     input_texts, #outputs, 
    #     return_token_type_ids=False
    # ) 
    
    # for BertTokenizer
    concat_train = tokenizer(
        input_texts, outputs, 
        return_token_type_ids=False
    )
    
    # val = tokenizer(
    #     input_texts[620:], outputs[620:], 
    #     return_token_type_ids=False
    # )
    
    #  concat_inputs = [tokenizer(
    #     input_text, output, 
    #     return_token_type_ids=False
    # ) for input_text, output in zip(input_texts, outputs)]
    
    return concat_train

In [173]:
train_raw_texts = raw_file_text[:620]
val_raw_texts = raw_file_text[620:]

In [174]:
# q1_val = preprocess(val_df, 1, tokenizer, val_raw_texts)

In [175]:
q1_train = preprocess(train_df, 0, tokenizer, train_raw_texts)
q1_val = preprocess(val_df, 0, tokenizer, val_raw_texts)

19


Token indices sequence length is longer than the specified maximum sequence length for this model (614 > 512). Running this sequence through the model will result in indexing errors


3


In [176]:
# # for AutoTokenizer
# print("Input length:", len(q1_train))
# print(" ")
# print("Input example:\n", tokenizer.decode(q1_train['input_ids'][0]))
# print(" ")
# print("Input ID example:\n", q1_train['input_ids'][0])
# print(" ")
# print("Tokens:\n", [tokenizer.convert_ids_to_tokens(id) for id in q1_train['input_ids'][0]])
# print(" ")
# print("Attention Mask:", q1_train['attention_mask'][0])


In [177]:
# for BertTokenizer
print("Input length:", len(q1_train))
print(" ")
print("Input example:\n", tokenizer.decode(q1_train['input_ids'][0]))
print(" ")
print("Input ID example:\n", q1_train['input_ids'][0])
print(" ")
print("Tokens:\n", [tokenizer.convert_ids_to_tokens(id) for id in q1_train['input_ids'][0]])
print(" ")
print("Attention Mask:", q1_train['attention_mask'][0])


Input length: 2
 
Input example:
 [CLS] Order under Section 69 Residential Tenancies Act, 2006 File Number : CEL - 87788 - 19 M. C. ( the'Landlord') applied order terminate tenancy evict M. K. ( the'Tenant') because he been persistently late paying his rent. The Landlord claimed compensation each day Tenant remained unit after termination date. This application heard Toronto October 16, 2019. The Landlord Tenant attended hearing. The Landlord represented S. K. Also attendance Landlord ’ s property manager, B. A. Determinations : 1. By background, month month tenancy rent due first month amount $ 1398. 79. This tenancy commenced June 1, 2012. 2. The Landlord ’ s L2 application based N8 notice termination served Tenant April 24, 2019 termination date June 30, 2019 alleging Tenant been persistently late paying rent since January 2018. 3. Since N8 notice termination served Tenant, Tenant continued pay rent late months May through October 2 What is the file number of the case? [SEP] CEL - 8

## Create the Dataset

In [178]:
PAD = tokenizer.pad_token_id
SEP = tokenizer.sep_token_id

In [179]:
# for BertTokenizer
class CaseDataset(Dataset):

    def __init__(self, data) -> None:
        super().__init__()
        self.data = data
        self.labels = self._get_label(data['input_ids'])
        # self.is_inference = False
    
    def __len__(self):
        return len(self.data['input_ids'])
    
    def _get_label(self, inputs):
        labels = []
        for inp in inputs:
            sep_idx = inp.index(SEP)
            label = [-100] * len(inp)
            label[sep_idx+1:] = inp[sep_idx+1:]
            labels.append(label)
        # print(len(labels))
        return labels
    
#     def inference(self):
#         self.is_inference = True

#     def train(self):
#         self.is_inference = False
        
#     def is_inference(self):
#         return self.is_inference

    def __getitem__(self, idx):
        input_ids = self.data["input_ids"][idx]
        att_mask = self.data["attention_mask"][idx]
        # print(len(self.labels), idx)
        label = self.labels[idx]
        return {"input_ids": input_ids, "attention_mask":att_mask, "labels":label}
      
        # if not self.is_inference:
        #     return {"input_ids": self.data["input_ids"][idx], "attention_mask": self.data["attention_mask"][idx], "labels": self.labels[idx]}
        # else:
        #     sep_idx = self.data["input_ids"][idx].index(SEP)
        #     input_ids = self.data["input_ids"][idx][:sep_idx+1]
        #     att_mask = self.data["attention_mask"][idx][:sep_idx+1]
        #     # print(len(self.labels), idx)
        #     label = self.labels[idx][sep_idx+1:]
        #     # print(len(input_ids), len(att_mask), len(label))
        #     return {"input_ids": input_ids, "attention_mask":att_mask, "labels":label}


def collate_fn(batch):
    batch_input_ids = [torch.LongTensor(example["input_ids"]) for example in batch]
    batch_att_mask = [torch.LongTensor(example["attention_mask"]) for example in batch]
    batch_label = [torch.LongTensor(example["labels"]) for example in batch]
    
    padded_batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=PAD)
    padded_batch_att_mask = pad_sequence(batch_att_mask, batch_first=True, padding_value=PAD)
    padded_batch_label = pad_sequence(batch_label, batch_first=True, padding_value=-100)
    
    return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "labels": padded_batch_label}

def to_device(data, device):
    new_data = {}
    for k in data:
        new_data[k] = data[k].to(device)
    return new_data

## Training

In [180]:
# Experiment
q1_train_dataset = CaseDataset(q1_train)
q1_train_loader = DataLoader(q1_train_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

q1_val_dataset = CaseDataset(q1_val)
q1_val_loader = DataLoader(q1_val_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)


In [181]:
def train(model:nn.Module, train_loader:DataLoader, optimizer:optim.Optimizer, log_step=50):
    model.train()
    epoch_loss = 0.0
    log_loss = 0.0
    for idx, batch in enumerate(train_loader):
        model.zero_grad()
        batch = to_device(batch, device)
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        log_loss += loss.item()
        if idx % log_step == 0:
            print(f"Train Step: {idx} Loss: {log_loss / log_step}")
            log_loss = 0.0
    return epoch_loss / len(train_loader)
        

@torch.no_grad()
def evaluate(model:nn.Module, eval_loader:DataLoader):
    eval_loss = 0.0
    correct = 0
    total = 0
    model.eval()
    for batch in eval_loader:
        batch = to_device(batch, device)
        output = model(**batch)
        loss = output.loss
        eval_loss += loss.item()
        pred = output.logits.argmax(-1)[..., :-1]
        label = batch["labels"][..., 1:]
        correct += torch.where(label!=-100, pred==label, 0).sum().item()
        total += torch.sum(label != -100).item()
    
    print(total, correct)

    eval_acc = correct / total
    eval_loss = eval_loss / len(eval_loader) 
    return eval_acc, eval_loss


In [182]:
# experiment
epochs = 1
optimizer = optim.Adam(model.parameters(), lr=5e-5)

model.train()

for epoch in range(epochs):
    print(f"Training Epoch {epoch+1}")
    
    train_loss = train(model, q1_train_loader, optimizer)
    print(f"Epoch {epoch+1} Training Loss: {train_loss}")
    
    eval_acc, eval_loss = evaluate(model, q1_val_loader)
    print(f"Epoch {epoch} Eval Acc: {eval_acc}; Eval Loss: {eval_loss}")


Training Epoch 1
Train Step: 0 Loss: 0.11970531463623046


KeyboardInterrupt: 

In [None]:
def get_dataloader(df, q_no, tokenizer, raw_texts):
    data = preprocess(df, q_no, tokenizer, raw_texts)
    dataset = CaseDataset(data)
    dataloader = DataLoader(dataset, 
                            batch_size=2, 
                            collate_fn=collate_fn, 
                            shuffle=False)
    return dataloader
    
def train_qs(train_df, val_df, q_no, tokenizer, optimizer):
    train_loader = get_dataloader(train_df, q_no, tokenizer, train_raw_texts)
    val_loader = get_dataloader(val_df, q_no, tokenizer, val_raw_texts)
    
    questions = train_df.columns
    print(questions[q_no])
    
    # train 1 epoch only, given the small data
    train_loss = train(model, train_loader, optimizer)
    print(f"Epoch {epoch} Training Loss: {train_loss}")
    
    eval_acc, eval_loss = evaluate(model, val_loader)
    print(f"Epoch {epoch} Eval Acc: {eval_acc}; Eval Loss: {eval_loss}")
    
    print('')

In [None]:
# starting from 1 because the first question has been trained on
for i in range(1, train_df.shape[0]):
    train_qs(train_df, val_df, i, tokenizer, optimizer)

In [None]:
@torch.no_grad()
def answer(model, loader):
    all_preds = []
    all_labels = []
    model.eval()
    for batch in loader:   
        batch = to_device(batch, device)
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        
        # truncated_input = []
        # truncated_attn = []
        # for input_id, attn_mask in zip(input_ids, attention_mask):
        #     # print(input_id.shape)
        #     sep_idx = torch.nonzero(input_id == SEP)[0]
        #     # print(sep_idx)
        #     truncated_input.append(input_id[:sep_idx])
        #     truncated_attn.append(attn_mask[:sep_idx])
        # truncated_input_ids = torch.stack(truncated_input)
        # truncated_attention = torch.stack(truncated_attn)
        # # print(truncated_input_ids.shape)
        
        # pos_ids = batch["position_ids"]
        labels = batch["labels"]
        outputs = model.generate(input_ids=input_ids, #truncated_input_ids, 
                                 attention_mask=attention_mask, 
                                 return_dict_in_generate=True, 
                                 pad_token_id=PAD, #50256, 
                                 max_length=1024, 
                                 top_k=10) 
                                 #stopping_criteria=stop_criteria_list)
        pred_start = torch.nonzero(input_ids==SEP, as_tuple=True)[1][0] + 1
        truncated_outputs = []
        for out in outputs["sequences"]:
            sep_idxs = torch.nonzero(out==SEP, as_tuple=True)[0]
            if len(sep_idxs) == 1:
                end_idx = -1
            else:
                end_idx = sep_idxs[1]
            truncated_outputs.append(out[pred_start:end_idx])
        decode_texts = tokenizer.batch_decode(truncated_outputs)
        gold_texts = tokenizer.batch_decode([l[l != -100][:-1] for l in labels])

        for gold, decode in zip(gold_texts, decode_texts):
            all_labels.append(gold)
            all_preds.append(decode)
        # all_preds = process_sys(all_preds)
    
    return all_preds, all_labels


def accuracy(sys, gold):
    total = 0
    correct = 0
    for s, g in zip(sys, gold):
        if s == g:
            correct += 1
        total += 1
            
    accuracy = correct / total
    return accuracy, correct, total


# def left_pad_sequence(sequence, batch_first, padding_value=0):
#     padded = []
#     max_len = max(len(each) for each in sequence)
#     for each in sequence:
#         if not isinstance(each, torch.LongTensor):
#             each = torch.LongTensor(each)
#         pad = torch.full((max_len-len(each),), fill_value=padding_value,dtype=each.dtype)
#         padded.append(torch.cat([pad, each]))
#     padded = torch.vstack(padded)
#     if not batch_first:
#         padded = padded.permute(1, 0, 2)
#     return padded
        
# def inference_colate_fn(batch):
#     batch_input_ids = [torch.LongTensor(each["input_ids"]) for each in batch]
#     batch_att_mask = [torch.LongTensor(each["attention_mask"]) for each in batch]
#     batch_label = [torch.LongTensor(each["labels"]) for each in batch]
#     batch_position_ids = [torch.arange(len(each["input_ids"]), dtype=torch.long) for each in batch]
    
#     padded_batch_input_ids = left_pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
#     padded_batch_att_mask = left_pad_sequence(batch_att_mask, batch_first=True, padding_value=0)
#     padded_batch_label = pad_sequence(batch_label, batch_first=True, padding_value=-100)
#     # padded_batch_position_ids = left_pad_sequence(batch_position_ids, batch_first=True, padding_value=0)
#     # return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "position_ids":padded_batch_position_ids, "labels": padded_batch_label}   
#     return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "labels": padded_batch_label}    

# val_dataset.inference()


In [None]:
def preprocess_pred(dataframe, q_no, tokenizer, raw_texts):
    input_texts, outputs = prompt(dataframe, q_no, raw_texts)   
        
    # # for AutoTokenizer
    # concat_inputs = tokenizer(
    #     input_texts, #outputs, 
    #     return_token_type_ids=False
    # ) 
    
    # for BertTokenizer
    test = tokenizer(
        input_texts, 
        return_token_type_ids=False
    )
    
    # val = tokenizer(
    #     input_texts[620:], outputs[620:], 
    #     return_token_type_ids=False
    # )
    
    #  concat_inputs = [tokenizer(
    #     input_text, output, 
    #     return_token_type_ids=False
    # ) for input_text, output in zip(input_texts, outputs)]
    
    return test

def get_test_dataloader(df, q_no, tokenizer, raw_texts):
    data = preprocess_pred(df, q_no, tokenizer, raw_texts)
    dataset = CaseDataset(data)
    dataloader = DataLoader(dataset, 
                            batch_size=32, 
                            collate_fn=collate_fn, 
                            shuffle=False)
    return dataloader

In [None]:
def answer_qs(val_df, q_no, tokenizer):
    loader = get_test_dataloader(val_df, q_no, tokenizer, val_raw_texts)
    
    questions = val_df.columns
    print(questions[q_no])
    
    preds, golds = answer(model, loader)
    acc, correct, total = accuracy(preds, golds)
    acc = round(acc, 5)
    
    print(f"Accuracy for this question is: {acc*100}%")
    print('')
    
    return acc, preds

In [None]:
acc_lst = []
for i in range(0, val_df.shape[0]):
    accuracy, _ = answer_qs(val_df, i, tokenizer)
    acc_lst.append(accuracy)
avg_acc = sum(acc_lst) / len(acc_lst)

In [None]:
torch.save(model.state_dict(), "gpt2_1epoch_law.pt")