## Necessary Imports and Setup

In [1]:
#TODO: Add here your imports

from transformers import AutoTokenizer, BartForConditionalGeneration

import os
import nltk
import pandas as pd
import torch
import numpy as np
from jinja2 import Template
import xmltodict
import pickle
from collections import defaultdict

from fuzzywuzzy import fuzz

from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

In [2]:
import sys
sys.path.append('/scratch/users/bozyurt20/hpc_run/')
sys.path.append('/scratch/users/bozyurt20/hpc_run/text_injection')
from text_injection.util_research import *

In [3]:
## TODO: Add the number of template options for prompting

num_templates = 23

path_andersen = "/kuacc/users/bozyurt20/hpc_run/ChildrenStories/Andersen"
path_fanny = "/kuacc/users/bozyurt20/hpc_run/ChildrenStories/Fanny Fern"
path_annotations = "/kuacc/users/bozyurt20/hpc_run/ChildrenStories/Annotations"

dir_list_andersen = os.listdir(path_andersen)
dir_list_fanny = os.listdir(path_fanny)
dir_list_annotations = os.listdir(path_annotations)

stop_words = set(stopwords.words("english"))

In [6]:
## TODO: Add Your model and tokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_instance = ModelInstance(model, tokenizer, device)

## Example Pipeline

In [9]:
##Â TODO: example pipeline

from transformers import AutoTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")

# Generate Summary
summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


'PG&E stated it scheduled the blackouts in response to forecasts for high winds amid'

## Making Predictions

In [7]:
def predictor(model_instance):
    
    model = model_instance.model
    tokenizer = model_instance.tokenizer
    device = model_instance.device
    
    all_result_objects = defaultdict(list)
    
    for item in dir_list_annotations:

        print(item)

        story_no = item[len("Andersen_story"):-len(".txt")]

        f = open(os.path.join(path_andersen, item), 'r') 
        story = f.read()
        f.close()
        
        paragraphs = story.split("\n\n")
        paragraph = paragraphs[0]
        len_title = len(paragraph) + 2   
        
        annotations = all_annotations[item]
        
        for line in annotations:
            
            ind = line[0]
            character = line[1]
            gold_answer = line[2]
            grammatical_number = line[3]

            gold_locations = gold_answer.split("/")
            
            for k in range(1, num_templates+1):
                
                y = line[0]
                x = y - 5120

                if x < len_title:
                    text = story[len_title:y]

                else:
                    x = story[x:y].find(" ") + x
                    text = story[x:y]                
                
                text = text_clean_ending(text)
                text = remove_new_lines(text)                    
                
                ## TODO: Write the predictor lines
                
                prompt, context2 = create_prompt_clipped(tokenizer, k, text, character, grammatical_number, 1024)
                
                inputs = tokenizer([prompt], max_length=1024, return_tensors="pt")
                input_ids = inputs["input_ids"].to(device)
                
                with torch.no_grad():
                    summary_ids = model.generate(input_ids, min_length=0, max_length=20)
                    
                out = tokenizer.decode(summary_ids[0], skip_special_tokens=True)                
                
                ## Ends here
                
                match1, match2 = exactly_or_fuzzily_matched(out, character, gold_locations)
                result_object = ResultObject(prompt, out, ind, character, gold_locations, story_no, k, match1, match2)
                all_result_objects[k].append(result_object)
                
    return all_result_objects
        

In [8]:
all_result_objects = predictor(model_instance)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Andersen_story11.txt


Token indices sequence length is longer than the specified maximum sequence length for this model (1179 > 1024). Running this sequence through the model will result in indexing errors


Andersen_story12.txt
Andersen_story13.txt
Andersen_story15.txt
Andersen_story16.txt
Andersen_story17.txt
Andersen_story18.txt
Andersen_story1.txt
Andersen_story2.txt
Andersen_story3.txt
Andersen_story5.txt
Andersen_story7.txt
Andersen_story8.txt
Andersen_story9.txt
Andersen_story10.txt


In [9]:
## TODO: define a path name for saving the results

with open("bart_large_predictions.txt", "wb") as f:
    pickle.dump(all_result_objects, f)

## Calculating the Accuracy

In [None]:
## TODO: write the path name that has the results

with open("bart_large_predictions.txt", "rb") as f:
    all_result_objects = pickle.load(f)

In [10]:
def accuracy_calculator(all_result_objects):
    
    exact_accuracies = defaultdict(list)
    fuzzy_accuracies = defaultdict(list)
    exact_averages = defaultdict(list)
    fuzzy_averages = defaultdict(list)

    for k in all_result_objects:
        for result_object in all_result_objects[k]:
            """prompt = result_object.prompt
            out = result_object.out
            story_no = result_object.story_no
            k = result_object.k"""
            match1 = result_object.match1
            match2 = result_object.match2
            if match1 == "Yes":
                exact_accuracies[k].append(1)
            else:
                exact_accuracies[k].append(0)
            if match2 == "Yes":
                fuzzy_accuracies[k].append(1)
            else:
                fuzzy_accuracies[k].append(0)
            
        exact_averages[k] = np.mean(np.array(exact_accuracies[k]))
        fuzzy_averages[k] = np.mean(np.array(fuzzy_accuracies[k]))
        
    best_exact_k = np.array(list(exact_averages.values())).argmax() + 1
    best_fuzzy_k = np.array(list(exact_averages.values())).argmax() + 1
    
    return exact_averages, fuzzy_averages, best_exact_k, best_fuzzy_k


In [11]:
exact_averages, fuzzy_averages, best_exact_k, best_fuzzy_k = accuracy_calculator(all_result_objects)

In [13]:
fuzzy_averages

defaultdict(list,
            {1: 0.03614457831325301,
             2: 0.0321285140562249,
             3: 0.04819277108433735,
             4: 0.05622489959839357,
             5: 0.04819277108433735,
             6: 0.04819277108433735,
             7: 0.04819277108433735,
             8: 0.04819277108433735,
             9: 0.05220883534136546,
             10: 0.04417670682730924,
             11: 0.05220883534136546,
             12: 0.03614457831325301,
             13: 0.028112449799196786,
             14: 0.06827309236947791,
             15: 0.04819277108433735,
             16: 0.05220883534136546,
             17: 0.060240963855421686,
             18: 0.060240963855421686,
             19: 0.060240963855421686,
             20: 0.04417670682730924,
             21: 0.060240963855421686,
             22: 0.05622489959839357,
             23: 0.024096385542168676})