In [2]:
#!pip install datasets

In [59]:
#Load the Dataset
from datasets import load_dataset
import pandas as pd
import random
import re
import torch
import os
import torch.nn.functional as F

In [4]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [5]:
# Load the 'us_bills'
dataset = load_dataset("pile-of-law/pile-of-law", "us_bills")

Loading Dataset Infos from /Users/urvimidha/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60
Overwrite dataset info from restored data version if exists.
Loading Dataset info from /Users/urvimidha/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60
Found cached dataset pile-of-law (/Users/urvimidha/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60)
Loading Dataset info from /Users/urvimidha/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


In [6]:
len(set(dataset['train']['text']))

84362

In [7]:
def clean_text(bills):
    clean_bills = []
    for bill in bills:
      # Remove all newline and tab characters
      text = bill.replace('\n', ' ').replace('\t', ' ')
      # Remove exactly three lowercase letters helps in (removing roman numeral numbering )
      text = re.sub(r'\[([a-z]{1,3})\]', ' ', text, flags=re.IGNORECASE)
      clean_bills.append(text.lower())

    return clean_bills

In [8]:
print(clean_text(dataset['train']['text'][:2]))

['    113 s2875 is: national guard investigations transparency and improvement act of 2014 u.s. senate 2014-09-18 text/xml en pursuant to title 17 section 105 of the united states code, this file is not subject to copyright protection and is in the public domain.      ii   113th congress2d session   s. 2875   in the senate of the united states       september 18, 2014    mr. begich introduced the following bill; which was read twice and referred to the committee on armed services      a bill   to codify in law the establishment and duties of the office of complex administrative     investigations in the national guard bureau, and for other purposes.    1.short titlethis act may be cited as the national guard investigations transparency and improvement act of 2014.2.codification in law of establishment and duties of the office of complex administrative     investigations in the national guard bureau(a)in generalthere is in the office of the chief of the national guard bureau the office 

In [99]:
# Load GPT-2 model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")



In [75]:
def predict_next_word(text, actual_next_word, model=model, tokenizer=tokenizer, top_k=3,):
    """
    Predicts the next word for a given text using GPT-2.
    """
    total_log_prob = 0.0
    input_ids = text

    #### print("Predict next word, input shape: ", input_ids.shape)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #print(device)
    #print("input_ids device: ", input_ids.device)

    with torch.no_grad():
        outputs = model(input_ids.to(device))
        #print("In predict next word Outputs: ", outputs, outputs.logits.shape)
        logits = outputs.logits

#     # Predict the next token (old approach of using only 1 word)
#     predicted_token_id = torch.argmax(logits[:, -1, :], dim=-1).item()
#     predicted_token = tokenizer.decode(predicted_token_id)

    # Predict top k words
    top_k_token_ids = torch.topk(logits[:, -1, :], top_k, dim=-1).indices.squeeze().cpu().tolist()
    top_k_tokens = [tokenizer.decode(token_id) for token_id in top_k_token_ids]
    
    
    #perplexity
    
    log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
    actual_next_word_log_prob = log_probs[0, actual_next_word].item()
    total_log_prob += actual_next_word_log_prob
    total_words= len(input_ids)
    
    # Calculate perplexity
    perplexity = torch.exp(torch.tensor(-total_log_prob / total_words)) if total_words > 0 else float('inf')
    

    return top_k_tokens, perplexity #predicted_token


In [76]:
def predict_every_n_words(bill_text, bill_number, model, interval_min=20, interval_max=40):
    words = tokenizer.encode(bill_text, return_tensors="pt", truncation=True, max_length=512) #bill_text.split()
    n = words.shape[1]
    results = []

    #### print(words.shape)
    #### print(words)

    random.seed(42)

    # Start from the first 20 words and predict every 20-40 words
    i = 15
    while i < n and i < 512:
        text_chunk = words[:,:i] #' '.join(words[:i])
        actual_next_word = words[:,i]
        #print("Current text chunk: ",  tokenizer.decode(text_chunk[0], skip_special_tokens=True))
        predicted_next_word, perplexity = predict_next_word(text_chunk, model=model, actual_next_word=actual_next_word) #can add parameter top_k=k
        english_actual_next_word = tokenizer.decode(actual_next_word, skip_special_tokens=True)
        
        
        #Modification in prediction for top k
        pred_res = False
        for word in predicted_next_word:
            if word.lower() == english_actual_next_word.lower():
                pred_res = True
       
        res = {
            "bill_number": bill_number,
            "input_text_length": i,
            "actual_next_word": english_actual_next_word,
            "predicted_next_word": predicted_next_word,
            "top_k_word_prediction": pred_res, #english_actual_next_word.lower() == predicted_next_word.lower()
            "top_word_prediction": english_actual_next_word.lower() == predicted_next_word[0].lower(),
            #"perplexity": perplexity
            }
        results.append(res)
        #### print("Appending row: ", res)
        # Move to the next interval of 20-40 words
        interval = random.randint(interval_min, interval_max)
        i += interval

    return results


In [94]:
def evaluate_bills(bills, model):
    overall_predictions = []
    total_correct = 0
    total_total = 0
    total_correct_topk=0
    

    for idx, bill_text in enumerate(bills):
        print(f"Processing bill {idx} of length {len(bill_text)} characters...")


        results = predict_every_n_words(bill_text, idx, model=model)
        display(pd.DataFrame(results).drop('perplexity', axis=1))

        overall_predictions.extend(results)

        #top k accuracy for the current bill
        correct_predictions_topk = sum(1 for result in results if result['top_k_word_prediction'])
        bill_accuracy_topk = correct_predictions_topk / len(results) if len(results) > 0 else 0
#         tot_perplexity = 0.0
#         for result in results:
#             tot_perplexity += result['perplexity']
#         perplexity = tot_perplexity/len(results)
        
        #accuracy for the current bill
        correct_predictions = sum(1 for result in results if result['top_word_prediction'])
        bill_accuracy = correct_predictions / len(results) if len(results) > 0 else 0

        print(f"Bill {idx} accuracy: {bill_accuracy * 100:.2f}%\n")
        print(f"Bill {idx} top k accuracy: {bill_accuracy_topk * 100:.2f}%\n")
        #print(f"Bill {idx} perplexity: {perplexity:.2f}\n")
        
        total_correct += correct_predictions
        total_total += len(results)
        total_correct_topk += correct_predictions_topk

    df_overall_predictions = pd.DataFrame(overall_predictions) #dataframe
    
    overall_accuracy = total_correct / total_total if total_total > 0 else 0
    overall_accuracy_topk = total_correct_topk / total_total if total_total > 0 else 0
#    avg_perplexity = df_overall_predictions["perplexity"].mean()
    
    print(f"Overall accuracy for all bills: {overall_accuracy * 100:.2f}%")
    print(f"Overall topk accuracy for all bills: {overall_accuracy_topk * 100:.2f}%")
#    print(f"Overall average perplexity all bills: {avg_perplexity:.2f}")
    return df_overall_predictions

In [100]:
cleaned_validation = clean_text(dataset["validation"]["text"][:15])

In [101]:
df_predictions_bills = evaluate_bills(cleaned_validation, model)

Processing bill 0 of length 6705 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,0,15,act,"[,, ., (]",False,False
1,0,55,the,"[ the, im, no]",True,True
2,0,78,the,"[ the, :, title]",True,True
3,0,98,(,"[,, 2016, 2015]",False,False
4,0,126,penn,"[ ut, the, h]",False,False
5,0,153,of,"[ man, lady, woman]",False,False
6,0,180,of,"[,, of, y]",True,False
7,0,204,m,"[ m, , and]",True,True
8,0,227,.,"[., ,, \n]",True,True
9,0,264,m,"[ m, n, j]",True,True


Bill 0 accuracy: 57.89%

Bill 0 top k accuracy: 63.16%

Processing bill 1 of length 48228 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,1,15,the,"[ fiscal, the, FY]",True,False
1,1,55,the,"[ title, the, Public]",True,False
2,1,78,i,"[ , 116, 117]",False,False
3,1,98,united,"[ United, united, state]",True,True
4,1,126,second,"[ second, third, next]",True,True
5,1,153,,"[________________, , �]",False,False
6,1,180,budget,"[ , S, s]",False,False
7,1,204,for,"[ for, of, and]",True,True
8,1,227,(,"[ , (, 2]",True,False
9,1,264,1,"[1, 01, 2]",True,True


Bill 1 accuracy: 36.84%

Bill 1 top k accuracy: 57.89%

Processing bill 2 of length 4566 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,2,15,and,"[ and, to, ,]",True,True
1,2,55,and,"[., ,, and]",True,False
2,2,78,session,"[ s, reading, amendment]",False,False
3,2,98,,"[ , 110, 111]",True,True
4,2,126,,"[ , congress, 1]",True,True
5,2,153,and,"[ and, ,, ]",True,True
6,2,180,with,"[ in, that, ]",False,False
7,2,204,,"[ , The, (]",True,True
8,2,227,service,"[ education, loan, school]",False,False
9,2,264,the,"[ , terms, practice]",False,False


Bill 2 accuracy: 57.89%

Bill 2 top k accuracy: 68.42%

Processing bill 3 of length 3199 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,3,15,,"[ , of, ,]",True,True
1,3,55,",","[), , ,]",True,False
2,3,78,.,"[., ., ,]",True,True
3,3,98,,"[ , , \n]",True,True
4,3,126,,"[ , m, t]",True,True
5,3,153,,"[ , , e]",True,True
6,3,180,,"[ , , .]",True,True
7,3,204,a,"[ , m, t]",False,False
8,3,227,security,"[ security, safety, insurance]",True,True
9,3,264,,"[ , The, (]",True,True


Bill 3 accuracy: 68.42%

Bill 3 top k accuracy: 78.95%

Processing bill 4 of length 7822 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,4,15,(,"[., of, ,]",False,False
1,4,55,liability,"[ enforcement, remed, compensation]",False,False
2,4,78,trade,"[ environmental, Environmental, act]",False,False
3,4,98,other,"[ inj, to, other]",True,False
4,4,126,5,"[), 4, 3]",False,False
5,4,153,for,"[ to, for, provide]",True,False
6,4,180,the,"[ the, title, this]",True,True
7,4,204,117,"[., ., ]",False,False
8,4,227,j,"[ix, , ____]",False,False
9,4,264,oro,"[oro, or, oros]",True,True


Bill 4 accuracy: 47.37%

Bill 4 top k accuracy: 63.16%

Processing bill 5 of length 5447 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,5,15,of,"[ of, to, and]",True,True
1,5,55,u,"[\n, The, (]",False,False
2,5,78,the,"[ the, title, FIFA]",True,True
3,5,98,,"[\n, 2018, The]",False,False
4,5,126,city,"[ congress, United, same]",False,False
5,5,153,,"[ en, -, ,]",False,False
6,5,180,the,"[ the, to, this]",True,True
7,5,204,me,"[ me, , can]",True,True
8,5,227,most,"[ most, fastest, largest]",True,True
9,5,264,1994,"[ FIFA, first, world]",False,False


Bill 5 accuracy: 36.84%

Bill 5 top k accuracy: 57.89%

Processing bill 6 of length 8875 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,6,15,,"[ the, God, a]",False,False
1,6,55,the,"[ the, no, im]",True,True
2,6,78,.,"[., ., ]",True,True
3,6,98,,"[ , 111, h]",True,True
4,6,126,representatives,"[ , representatives, s]",True,False
5,6,153,a,"[ a, well, an]",True,True
6,6,180,,"[ , , t]",True,True
7,6,204,national,"[ linguistic, ethnic, cultural]",False,False
8,6,227,north,"[ , and, north]",True,False
9,6,264,about,"[ about, with, ]",True,True


Bill 6 accuracy: 57.89%

Bill 6 top k accuracy: 73.68%

Processing bill 7 of length 8487 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,7,15,,"[ , of, ,]",True,True
1,7,55,himself,"[ , ), s]",False,False
2,7,78,,"[ bill, resolution, amendment]",False,False
3,7,98,homeland,"[ the, commerce, ]",False,False
4,7,126,,"[ , , \n]",True,True
5,7,153,,"[ , The, \n]",True,True
6,7,180,.,"[., ., ,]",True,True
7,7,204,(,"[ is, ,, .]",False,False
8,7,227,,"[ , , ________]",True,True
9,7,264,after,"[ paragraph, , (]",False,False


Bill 7 accuracy: 63.16%

Bill 7 top k accuracy: 63.16%

Processing bill 8 of length 16413 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,8,15,,"[ , of, (]",True,True
1,8,55,,"[ , i, o]",True,True
2,8,78,means,"[ , s, t]",False,False
3,8,98,,"[ , , .]",True,True
4,8,126,jurisdiction,"[ jurisdiction, authority, pur]",True,True
5,8,153,of,"[ of, ,, to]",True,True
6,8,180,by,"[,, ., of]",False,False
7,8,204,.,"[st, 2, 1]",False,False
8,8,227,standards,"[ , health, care]",False,False
9,8,264,)(,"[)(, ), ),]",True,True


Bill 8 accuracy: 68.42%

Bill 8 top k accuracy: 78.95%

Processing bill 9 of length 2863 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,9,15,u,"[ives, i, 1]",False,False
1,9,55,,"[ , \n, ]",True,True
2,9,78,r,"[title, , r]",False,False
3,9,98,;,"[ a, the, from]",False,False
4,9,126,short,"[ , 1, The]",False,False
5,9,153,general,"[ the, this, accordance]",False,False
6,9,180,facilities,"[ pipelines, transportation, pipeline]",False,False
7,9,204,any,"[ the, any, ,]",True,False
8,9,227,,"[ , 5, 5]",True,True
9,9,264,environmental,"[ environmental, pipeline, final]",True,True


Bill 9 accuracy: 42.11%

Bill 9 top k accuracy: 47.37%

Processing bill 10 of length 845 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,10,15,res,"[o, , re]",False,False
1,10,55,respect,"[ respect, regard, the]",True,True
2,10,78,6,"[ 6, 7, 5]",True,True
3,10,98,committee,"[ committee, Committee, committees]",True,True
4,10,126,10,"[ 10, 9, 11]",True,True
5,10,153,",","[., and, of]",False,False
6,10,180,for,"[ to, for, so]",True,False


Bill 10 accuracy: 57.14%

Bill 10 top k accuracy: 71.43%

Processing bill 11 of length 3764 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,11,15,u,"[,, (, .]",False,False
1,11,55,public,"[ public, Public, custody]",True,True
2,11,78,the,"[ the, :, title]",True,True
3,11,98,the,"[ the, bill, a]",True,True
4,11,126,child,"[ child, payments, bills]",True,True
5,11,153,collection,"[ bill, collection, enforcement]",True,False
6,11,180,re,"[re, b, me]",True,True
7,11,204,.,"[., s, d]",True,True
8,11,227,,"[ , "", ``]",True,True
9,11,264,payment,"[ child, unpaid, ar]",False,False


Bill 11 accuracy: 63.16%

Bill 11 top k accuracy: 73.68%

Processing bill 12 of length 23184 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,12,15,,"[ , of, (]",True,True
1,12,55,ak,"[., ia, t]",False,False
2,12,78,the,"[ , the, by]",True,False
3,12,98,,"[ , , �]",True,True
4,12,126,ricting,"[ricting, rict, ribut]",True,True
5,12,153,,"[ and, to, ]",True,False
6,12,180,authority,"[ , amendment, error]",False,False
7,12,204,fairness,"[ "", ``, National]",False,False
8,12,227,it,"[ the, , this]",False,False
9,12,264,of,"[ of, ., and]",True,True


Bill 12 accuracy: 42.11%

Bill 12 top k accuracy: 63.16%

Processing bill 13 of length 4446 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,13,15,res,"[ , of, s]",False,False
1,13,55,concurrent,"[ , m, .]",False,False
2,13,78,iet,"[iet, iv, .]",True,True
3,13,98,,"[ The, \n, ]",True,False
4,13,126,;,"[,, and, .]",False,False
5,13,153,result,"[ be, continue, not]",False,False
6,13,180,whereas,"[ , , whereas]",True,False
7,13,204,the,"[ the, all, both]",True,True
8,13,227,accounting,"[ the, its, all]",False,False
9,13,264,11,"[03, 02, 04]",False,False


Bill 13 accuracy: 26.32%

Bill 13 top k accuracy: 57.89%

Processing bill 14 of length 2984 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,14,15,of,"[ i, i, a]",False,False
1,14,55,in,"[ hereby, for, copyrighted]",False,False
2,14,78,house,"[ United, senate, House]",True,False
3,14,98,the,"[ the, committee, and]",True,True
4,14,126,,"[ , 114, 112]",True,True
5,14,153,1,"[ , The, i]",False,False
6,14,180,body,"[ in, a, body]",True,False
7,14,204,agency,"[ agency, officer, ]",True,True
8,14,227,prescribing,"[ requiring, implementing, enforcing]",False,False
9,14,264,3,"[ , 3, 4]",True,False


Bill 14 accuracy: 42.11%

Bill 14 top k accuracy: 68.42%

Overall accuracy for all bills: 50.92%
Overall topk accuracy for all bills: 65.57%


In [36]:
#### Fine Tuning ####

In [37]:
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
from torch.utils.data import DataLoader

In [88]:
cleaned_train = clean_text(dataset["train"]["text"][:60])
os.environ["WANDB_DISABLED"] = "true"

In [89]:
# Custom function to generate chunks from documents
def generate_chunks_from_text(text, min_length=100, max_length=512, tokenizer = tokenizer):
    tokens = tokenizer.encode(text, truncation=False, padding=False)  # Encode without truncation
    chunks = []
    while len(tokens) >= min_length:
        chunk_length = 512 #random.randint(min_length, min(max_length, len(tokens)))  # Random length for each chunk
        chunk = tokens[:chunk_length]
        chunks.append(chunk)
        tokens = tokens[chunk_length:]
    print("Generated n chunks: ", len(chunks))
    return chunks

# Function to manually expand the dataset by adding multiple rows for each document
def split_into_chunks(train_bills, min_length=100, max_length=512, tokenizer=tokenizer):
    all_input_ids = []
    all_labels = []
    all_attention_masks = []  # To store attention masks

    for text in train_bills:
        # Split the document into multiple chunks
        chunks = generate_chunks_from_text(text, min_length, max_length, tokenizer)
        for chunk in chunks:
            all_input_ids.append(chunk)
            all_labels.append(chunk)  # For causal language modeling, labels are the same as input_ids

            # Generate attention mask: 1 for real tokens, 0 for padding (for now, we assume no padding in chunks)
            attention_mask = [1] * len(chunk)
            all_attention_masks.append(attention_mask)

    print("Raw input ids and label ids dims:", len(all_input_ids), " x ", len(all_input_ids[0]), " or ", len(all_input_ids[1]), ' and ', len(all_labels), ' x ', len(all_labels[0]), ' or ', len(all_labels[1]))

    input_ids_padded = tokenizer.pad(
        {"input_ids": all_input_ids},  
        padding='max_length',  
        max_length=max_length, 
        return_tensors="pt"  
    )
    labels_padded = tokenizer.pad(
        {"input_ids": all_labels}, 
        padding='max_length',
        max_length=max_length,
        return_tensors="pt"
    )
    attention_masks_padded = tokenizer.pad(
        {"input_ids": all_attention_masks}, 
        padding='max_length',
        max_length=max_length,
        return_tensors="pt"
    )
    print("In split into chunks, size of all 3 columns: ", input_ids_padded['input_ids'].shape)
    return {
        "input_ids": input_ids_padded["input_ids"],
       # "labels": labels_padded["input_ids"],
       # "attention_mask": attention_masks_padded["input_ids"]  # Add the attention mask
    }

# Function to train the model with explicit parameters
def train_model(clean_tr, clean_val, tokenizer, model, epochs=3, output_dir="./drive/MyDrive/nlp_proj_results"):
    train_bills = clean_tr
    val_bills = clean_val

    tokenizer.pad_token = tokenizer.eos_token

    train_dataset = Dataset.from_dict({"text": train_bills})
    train_dataset = Dataset.from_dict(split_into_chunks(train_bills, min_length=100, max_length=512))
    #print("Train dataset size: ", train_dataset['input_ids'].shape)
    print("Train dataset: ", train_dataset)
    data_collator = DataCollatorForLanguageModeling(
      tokenizer=tokenizer,
      mlm=False  # This is not a masked language model task, so set to False
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        max_grad_norm=1.0,
        logging_dir="./logs",
        logging_steps=100,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()

    model.save_pretrained(output_dir)
    return model

In [90]:
#torch.cuda.empty_cache()
fine_tuned_gpt2 = train_model(cleaned_train, cleaned_validation, tokenizer, model, epochs=10)

I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  3
I be
Doubting
Generated n chunks:  1
I be
Doubting
Generated n chunks:  1
I be
Doubting
Generated n chunks:  3
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  3
I be
Doubting
Generated n chunks:  5
I be
Doubting
Generated n chunks:  1
I be
Doubting
Generated n chunks:  5
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  8
I be
Doubting
Generated n chunks:  1
I be
Doubting
Generated n chunks:  4
I be
Doubting
Generated n chunks:  1
I be
Doubting
Generated n chunks:  1
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  3
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  5
I be
Doubting
Generated n chunks:  3
I be
Doubting
Generated n chunks:  1
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  2
I be
Doubting
Generated n chunks:  13


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In split into chunks, size of all 3 columns:  torch.Size([227, 512])
Train dataset:  Dataset({
    features: ['input_ids'],
    num_rows: 227
})




Step,Training Loss
100,2.1173
200,1.7109


In [102]:
df_predictions_bills_finetuned = evaluate_bills(cleaned_validation, model=fine_tuned_gpt2)

Processing bill 0 of length 6705 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,0,15,act,"[ , act, of]",True,False
1,0,55,the,"[ the, the, an]",True,True
2,0,78,the,"[ the, house, congress]",True,True
3,0,98,(,"[ introduced, (, ,]",True,False
4,0,126,penn,"[ cal, fl, m]",False,False
5,0,153,of,"[,, of, )]",True,False
6,0,180,of,"[ of, ,, son]",True,True
7,0,204,m,"[ m, and, ms]",True,True
8,0,227,.,"[., .., ,]",True,True
9,0,264,m,"[ m, ms, and]",True,True


Bill 0 accuracy: 63.16%

Bill 0 top k accuracy: 78.95%

Processing bill 1 of length 48228 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,1,15,the,"[ the, fiscal, 2018]",True,True
1,1,55,the,"[ the, title, the]",True,True
2,1,78,i,"[ , ii, i]",True,False
3,1,98,united,"[ united, senate, house]",True,True
4,1,126,second,"[ next, bill, second]",True,False
5,1,153,,"[ , for, u]",True,True
6,1,180,budget,"[short, to, def]",False,False
7,1,204,for,"[ for, of, ]",True,True
8,1,227,(,"[ , (, 2]",True,False
9,1,264,1,"[1, 2, 3]",True,True


Bill 1 accuracy: 47.37%

Bill 1 top k accuracy: 68.42%

Processing bill 2 of length 4566 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,2,15,and,"[ and, ,, services]",True,True
1,2,55,and,"[ and, &, ]",True,True
2,2,78,session,"[ session, sessions, hr]",True,True
3,2,98,,"[ , 1, m]",True,True
4,2,126,,"[ , u, (]",True,True
5,2,153,and,"[ and, , care]",True,True
6,2,180,with,"[ that, in, of]",False,False
7,2,204,,"[ , the, such]",True,True
8,2,227,service,"[ education, dental, hygiene]",False,False
9,2,264,the,"[the, there, section]",True,True


Bill 2 accuracy: 78.95%

Bill 2 top k accuracy: 84.21%

Processing bill 3 of length 3199 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,3,15,,"[ , h, s]",True,True
1,3,55,",","[,, and, ]",True,True
2,3,78,.,"[., .., s]",True,True
3,3,98,,"[ , m, and]",True,True
4,3,126,,"[ , legislation, bills]",True,True
5,3,153,,"[ , m, the]",True,True
6,3,180,,"[ , the, i]",True,True
7,3,204,a,"[ , and, �]",False,False
8,3,227,security,"[ security, safety, insurance]",True,True
9,3,264,,"[ , the, 1]",True,True


Bill 3 accuracy: 78.95%

Bill 3 top k accuracy: 78.95%

Processing bill 4 of length 7822 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,4,15,(,"[ to, , ;]",False,False
1,4,55,liability,"[ enforcement, remed, control]",False,False
2,4,78,trade,"[ environmental, disaster, toxic]",False,False
3,4,98,other,"[ other, cert, enforcement]",True,True
4,4,126,5,"[), 4, 3]",False,False
5,4,153,for,"[ for, to, provide]",True,True
6,4,180,the,"[ the, title, this]",True,True
7,4,204,117,"[ , ., 117]",False,False
8,4,227,j,"[ , (, in]",False,False
9,4,264,oro,"[oro, or, o]",True,True


Bill 4 accuracy: 52.63%

Bill 4 top k accuracy: 63.16%

Processing bill 5 of length 5447 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,5,15,of,"[ of, made, to]",True,True
1,5,55,u,"[ , u, h]",True,False
2,5,78,the,"[ the, title, law]",True,True
3,5,98,,"[ , d, This]",True,True
4,5,126,city,"[ house, residence, office]",False,False
5,5,153,,"[th, is, (]",False,False
6,5,180,the,"[ the, can, ]",True,True
7,5,204,me,"[ , united, and]",False,False
8,5,227,most,"[ most, fastest, largest]",True,True
9,5,264,1994,"[ f, 2018, world]",False,False


Bill 5 accuracy: 42.11%

Bill 5 top k accuracy: 63.16%

Processing bill 6 of length 8875 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,6,15,,"[ representatives, the, ]",True,False
1,6,55,the,"[ the, its, ]",True,True
2,6,78,.,"[., 18, 18]",True,True
3,6,98,,"[ , jan, apr]",True,True
4,6,126,representatives,"[ representatives, congress, ]",True,True
5,6,153,a,"[ a, , refugees]",True,True
6,6,180,,"[ , (, u]",True,True
7,6,204,national,"[ ethnic, cultural, linguistic]",False,False
8,6,227,north,"[ whereas, and, while]",False,False
9,6,264,about,"[ for, about, over]",True,False


Bill 6 accuracy: 57.89%

Bill 6 top k accuracy: 73.68%

Processing bill 7 of length 8487 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,7,15,,"[ , h, s]",True,True
1,7,55,himself,"[ m, , himself]",True,False
2,7,78,,"[ , bill, resolution]",True,True
3,7,98,homeland,"[ foreign, banking, armed]",False,False
4,7,126,,"[ , (, �]",True,True
5,7,153,,"[ , 1, i]",True,True
6,7,180,.,"[., .., .(]",True,True
7,7,204,(,"[ (, is, provides]",True,True
8,7,227,,"[ , (, of]",True,True
9,7,264,after,"[ , the, (]",False,False


Bill 7 accuracy: 73.68%

Bill 7 top k accuracy: 78.95%

Processing bill 8 of length 16413 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,8,15,,"[ , h, h]",True,True
1,8,55,,"[ , (, of]",True,True
2,8,78,means,"[ , means, ways]",True,False
3,8,98,,"[ , the, a]",True,True
4,8,126,jurisdiction,"[ , jurisdiction, sense]",True,False
5,8,153,of,"[ of, to, ,]",True,True
6,8,180,by,"[,, of, ;]",False,False
7,8,204,.,"[., st, ]",True,True
8,8,227,standards,"[ , care, patient]",False,False
9,8,264,)(,"[), )(, ),]",True,False


Bill 8 accuracy: 57.89%

Bill 8 top k accuracy: 84.21%

Processing bill 9 of length 2863 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,9,15,u,"[ of, uary, u]",True,False
1,9,55,,"[ , ii, i]",True,True
2,9,78,r,"[ r, res, r]",True,True
3,9,98,;,"[ the, a, an]",False,False
4,9,126,short,"[short, cong, closure]",True,True
5,9,153,general,"[ general, , generals]",True,True
6,9,180,facilities,"[ infrastructure, transportation, water]",False,False
7,9,204,any,"[ the, a, design]",False,False
8,9,227,,"[(, the, ]",True,False
9,9,264,environmental,"[ environmental, final, proposed]",True,True


Bill 9 accuracy: 52.63%

Bill 9 top k accuracy: 63.16%

Processing bill 10 of length 845 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,10,15,res,"[ r, res, r]",True,False
1,10,55,respect,"[ respect, , regard]",True,True
2,10,78,6,"[ 6, (, 7]",True,True
3,10,98,committee,"[ committee, committees, subcommittee]",True,True
4,10,126,10,"[ 10, 9, 8]",True,True
5,10,153,",","[ for, ., and]",False,False
6,10,180,for,"[ to, for, ]",True,False


Bill 10 accuracy: 57.14%

Bill 10 top k accuracy: 85.71%

Processing bill 11 of length 3764 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,11,15,u,"[ u, of, ]",True,True
1,11,55,public,"[ public, interest, improper]",True,True
2,11,78,the,"[ the, congress, ]",True,True
3,11,98,the,"[ the, , a]",True,True
4,11,126,child,"[ child, payments, federal]",True,True
5,11,153,collection,"[ collection, enforcement, ]",True,True
6,11,180,re,"[re, ri, ab]",True,True
7,11,204,.,"[., s, ..]",True,True
8,11,227,,"[ , (, 4]",True,True
9,11,264,payment,"[ child, such, ]",False,False


Bill 11 accuracy: 73.68%

Bill 11 top k accuracy: 78.95%

Processing bill 12 of length 23184 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,12,15,,"[ , h, .]",True,True
1,12,55,ak,"[ol, avis, ale]",False,False
2,12,78,the,"[ the, , ,]",True,True
3,12,98,,"[ , m, the]",True,True
4,12,126,ricting,"[ricting, rict, rav]",True,True
5,12,153,,"[ , and, to]",True,True
6,12,180,authority,"[ violations, violation, ]",False,False
7,12,204,fairness,"[ , state, states]",False,False
8,12,227,it,"[ the, states, ]",False,False
9,12,264,of,"[ of, , or]",True,True


Bill 12 accuracy: 57.89%

Bill 12 top k accuracy: 73.68%

Processing bill 13 of length 4446 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,13,15,res,"[ res, r, ]",True,True
1,13,55,concurrent,"[ , m, and]",False,False
2,13,78,iet,"[iet, ik, iv]",True,True
3,13,98,,"[ , which, m]",True,True
4,13,126,;,"[,, ., and]",False,False
5,13,153,result,"[ be, strengthen, continue]",False,False
6,13,180,whereas,"[ whereas, , and]",True,True
7,13,204,the,"[ the, all, v]",True,True
8,13,227,accounting,"[ the, its, accordance]",False,False
9,13,264,11,"[03, 04, 07]",False,False


Bill 13 accuracy: 52.63%

Bill 13 top k accuracy: 63.16%

Processing bill 14 of length 2984 characters...


Unnamed: 0,bill_number,input_text_length,actual_next_word,predicted_next_word,top_k_word_prediction,top_word_prediction
0,14,15,of,"[ u, no, of]",True,False
1,14,55,in,"[ in, not, for]",True,True
2,14,78,house,"[ house, senate, united]",True,True
3,14,98,the,"[ the, committee, a]",True,True
4,14,126,,"[ officers, agencies, ]",True,False
5,14,153,1,"[ , 1, (]",True,False
6,14,180,body,"[ in, , cameras]",False,False
7,14,204,agency,"[ agency, , department]",True,True
8,14,227,prescribing,"[ superv, implementing, requiring]",False,False
9,14,264,3,"[ , 3, (]",True,False


Bill 14 accuracy: 47.37%

Bill 14 top k accuracy: 78.95%

Overall accuracy for all bills: 59.71%
Overall topk accuracy for all bills: 73.99%
