In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from memformers.models.membart import MemBartForConditionalGeneration

import os
import nltk
import pandas as pd
import torch
import numpy as np
from jinja2 import Template
import pickle
from collections import defaultdict

from fuzzywuzzy import fuzz

from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from util_research import *

import matplotlib.pyplot as plt
from math import log

import pdb

In [2]:
path_andersen = "/kuacc/users/bozyurt20/hpc_run/ChildrenStories/Andersen"
path_fanny = "/kuacc/users/bozyurt20/hpc_run/ChildrenStories/Fanny Fern"
path_annotations = "/kuacc/users/bozyurt20/hpc_run/ChildrenStories/Annotations"

dir_list_andersen = os.listdir(path_andersen)
dir_list_fanny = os.listdir(path_fanny)
dir_list_annotations = os.listdir(path_annotations)

In [3]:
all_annotations = {}

for item in dir_list_annotations:
    
    f = open(os.path.join(path_annotations, item), 'r')
    annotations = pd.read_csv(f, sep="\t")
    annotations = annotations.values
    f.close()
    
    all_annotations[item] = annotations

In [4]:
story_info = {}
for story in all_annotations:
    story_info[story] = {}
    for character_list in all_annotations[story]:
        character_name = character_list[1]
        story_info[story].setdefault(character_name, [])
        story_info[story][character_name].append((character_list[0], character_list[2], character_list[3]))
        story_info[story][character_name].sort(key=lambda x: x[0])

In [6]:
all_annotations["Andersen_story3.txt"]

array([[178, 'the prince', 'all over the world', 'singular'],
       [453, 'the prince',
        "his palace/the prince's palace/the palace/at home/home/his home/his castle/the castle",
        'singular'],
       [797, 'the king',
        'the palace/outside/out/the castle/in his castle', 'singular'],
       [850, 'the princess',
        'the palace/the castle/outside the door/outside/out', 'singular'],
       [1079, 'the queen', 'the palace/the castle/at home', 'singular'],
       [1173, 'the queen', 'the bedroom/the palace/the castle/at home',
        'singular']], dtype=object)

In [7]:
story_info["Andersen_story3.txt"]

{'the prince': [(178, 'all over the world', 'singular'),
  (453,
   "his palace/the prince's palace/the palace/at home/home/his home/his castle/the castle",
   'singular')],
 'the king': [(797,
   'the palace/outside/out/the castle/in his castle',
   'singular')],
 'the princess': [(850,
   'the palace/the castle/outside the door/outside/out',
   'singular')],
 'the queen': [(1079, 'the palace/the castle/at home', 'singular'),
  (1173, 'the bedroom/the palace/the castle/at home', 'singular')]}

In [5]:
class MemoryEnhanced():
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.model = self.model.to(device)
        self.device = device
        
class MemoryEnhancedResults():
    def __init__(self, prompt, model_out, gold_locations, character, k, story_no, query_point, exact_match, fuzzy_match):
        self.prompt = prompt
        self.model_out = model_out
        self.gold_locations = gold_locations
        self.character = character
        self.k = k
        self.story_no = story_no
        self.query_point = query_point
        self.exact_match = exact_match
        self.fuzzy_match = fuzzy_match

In [10]:
def get_memory_states(memory_enhanced_instance, all_prompt_drafts):
    
    model = memory_enhanced_instance.model
    tokenizer = memory_enhanced_instance.tokenizer
    device = memory_enhanced_instance.device
    
    all_memory_states = []
    memory_states = model.construct_memory(batch_size=1)
    all_memory_states.append(memory_states)
    
    for prompt_draft in all_prompt_drafts:
    
        input_ids = torch.LongTensor([tokenizer.encode(prompt_draft)])
        input_ids = input_ids.to(device)
        encoder_outputs = model.model.encoder(input_ids=input_ids, memory_states=memory_states, attention_mask=None)
        memory_states = encoder_outputs.memory_states
        all_memory_states.append(memory_states)
        
    return all_memory_states       
        

In [7]:
def predictor(memory_enhanced_instance, write_in_file=False, out_path_prefix="Text_injection_MEMBART_"):

    model = memory_enhanced_instance.model
    tokenizer = memory_enhanced_instance.tokenizer
    device = memory_enhanced_instance.device
    all_result_objects = defaultdict(list)

    version = 0

    for item in dir_list_annotations:

        print(item)

        story_no = item[len("Andersen_story"):-len(".txt")]

        f = open(os.path.join(path_andersen, item), 'r') 
        story = f.read()
        f.close()

        characters = story_info[item].keys()
        story = remove_new_lines(story)

        big_prompt = "Story:\n" + story
        big_prompt_tokens = tokenizer.encode(big_prompt)
        toks_per_segment = 512
        no_segments = len(big_prompt_tokens) // toks_per_segment

        all_tokens = [ big_prompt_tokens[512*i : 512*(i+1)] for i in range(no_segments) ]

        all_tokens.append(big_prompt_tokens[512*(no_segments): ])

        all_prompt_drafts = [ tokenizer.decode(x, skip_special_tokens=True) for x in all_tokens ]
        
        all_memory_states = get_memory_states(memory_enhanced_instance, all_prompt_drafts)

        x = 0
        story_lens = []
        for prompt_draft in all_prompt_drafts:
            if prompt_draft[:len("Story:\n")] == "Story:\n":
                x += len(prompt_draft) - len("Story:\n")
            else:
                x += len(prompt_draft)
            story_lens.append(x)

        for i, prompt_draft in enumerate(all_prompt_drafts):
            
            print(i)

            for character in characters:
                
                print(character)

                tuples = story_info[item][character]

                grammatical_number = tuples[0][2]
                if i != 0:
                    pos = story_lens[i-1]
                else:
                    pos = 0

                prompt, pos_last = create_membart_prompt(tokenizer, version, prompt_draft, character, grammatical_number, max_no_tokens=512)

                pos += pos_last

                check = False

                for num_tupl, tupl in enumerate(tuples):

                    if pos < tupl[0]:
                        break
                    gold_location = tupl[1]
                    check = True

                if check:
                    input_ids = torch.LongTensor([tokenizer.encode(prompt)])
                    encoder_outputs = model.model.encoder(input_ids=input_ids, memory_states=all_memory_states[i], attention_mask=None)
                    #memory_states = encoder_outputs.memory_states
                    
                    outputs = model.generate(
                        encoder_outputs=encoder_outputs,
                        decoder_start_token_id=tokenizer.bos_token_id,
                        max_length=64,
                        num_beams=1,
                        do_sample=False,
                        return_dict_in_generate=True,
                    )

                    out = tokenizer.decode(outputs.sequences[0])
                    match1, match2 = exactly_or_fuzzily_matched(out, gold_locations)
                    result_object =  MemoryEnhancedResults(prompt, out, gold_locations, character, version, story_no, pos, match1, match2)
                    all_result_objects[version].append(result_object)                 

    return all_result_objects

In [8]:
model = MemBartForConditionalGeneration.from_pretrained("qywu/membart-large")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

memory_enhanced_instance = MemoryEnhanced(model, tokenizer, device)

In [None]:
all_result_objects = predictor(memory_enhanced_instance, write_in_file=False, out_path_prefix="Text_injection_MEMBART_")

Andersen_story11.txt
here
here
here2
here2
here2
here2
here2


In [None]:
def accuracy_calculator(all_result_objects):
    
    exact_accuracies = defaultdict(list)
    fuzzy_accuracies = defaultdict(list)
    exact_averages = defaultdict(list)
    fuzzy_averages = defaultdict(list)

    for k in all_result_objects:
        for result_object in all_result_objects[k]:
            """prompt = result_object.prompt
            out = result_object.out
            story_no = result_object.story_no
            k = result_object.k"""
            match1 = result_object.match1
            match2 = result_object.match2
            if match1 == "Yes":
                exact_accuracies[k].append(1)
            else:
                exact_accuracies[k].append(0)
            if match2 == "Yes":
                fuzzy_accuracies[k].append(1)
            else:
                fuzzy_accuracies[k].append(0)
            
        exact_averages[k] = np.mean(np.array(exact_accuracies[k]))
        fuzzy_averages[k] = np.mean(np.array(fuzzy_accuracies[k]))
        
    best_exact_k = np.array(list(exact_averages.values())).argmax() + 1
    best_fuzzy_k = np.array(list(exact_averages.values())).argmax() + 1
    
    return exact_averages, fuzzy_averages, best_exact_k, best_fuzzy_k


In [None]:
exact_averages, fuzzy_averages, best_exact_k, best_fuzzy_k = accuracy_calculator(all_result_objects)

In [13]:
x = 1 
y = 2
breakpoint()
z = 4

In [26]:
version = 0
all_result_objects = {}
all_result_objects[version] = []

for item in dir_list_annotations:

    print(item)

    story_no = item[len("Andersen_story"):-len(".txt")]

    f = open(os.path.join(path_andersen, item), 'r') 
    story = f.read()
    f.close()

    characters = story_info[item].keys()
    story = remove_new_lines(story)

    big_prompt = "Story:\n" + story
    big_prompt_tokens = tokenizer.encode(big_prompt)
    toks_per_segment = 512
    no_segments = len(big_prompt_tokens) // toks_per_segment

    all_tokens = [ big_prompt_tokens[512*i : 512*(i+1)] for i in range(no_segments) ]

    all_tokens.append(big_prompt_tokens[512*(no_segments): ])

    all_prompt_drafts = [ tokenizer.decode(x, skip_special_tokens=True) for x in all_tokens ]

    all_memory_states = get_memory_states(memory_enhanced_instance, all_prompt_drafts)
    
    x = 0
    story_lens = []
    for prompt_draft in all_prompt_drafts:
        if prompt_draft[:len("Story:\n")] == "Story:\n":
            x += len(prompt_draft) - len("Story:\n")
        else:
            x += len(prompt_draft)
        story_lens.append(x)

    for i, prompt_draft in enumerate(all_prompt_drafts):

        print(i)

        for character in characters:

            print(character)

            tuples = story_info[item][character]

            grammatical_number = tuples[0][2]
            if i != 0:
                pos = story_lens[i-1]
            else:
                pos = 0
                
            print(pos)

            prompt, pos_last = create_membart_prompt(tokenizer, version, prompt_draft, character, grammatical_number, max_no_tokens=512)

            pos += pos_last
            
            print(pos)

            check = False

            for num_tupl, tupl in enumerate(tuples):

                if pos < tupl[0]:
                    break
                gold_locations = tupl[1]
                check = True

            if check:
                input_ids = torch.LongTensor([tokenizer.encode(prompt)])
                input_ids = input_ids.to(device)
                encoder_outputs = model.model.encoder(input_ids=input_ids, memory_states=all_memory_states[i], attention_mask=None)
                #memory_states = encoder_outputs.memory_states

                outputs = model.generate(
                    encoder_outputs=encoder_outputs,
                    decoder_start_token_id=tokenizer.bos_token_id,
                    max_length=64,
                    num_beams=1,
                    do_sample=False,
                    return_dict_in_generate=True,
                )

                out = tokenizer.decode(outputs.sequences[0])
                match1, match2 = exactly_or_fuzzily_matched(out, gold_locations)
                result_object =  MemoryEnhancedResults(prompt, out, gold_locations, character, version, story_no, pos, match1, match2)
                all_result_objects[version].append(result_object)                 



    

Andersen_story11.txt
0
two old snails
0
2082
the ant
0
2087
the gnats
0
2085
Miss Snail
0
2085
1
two old snails
2138
4024
the ant
2138
4036
the gnats
2138
4032
Miss Snail
2138
4032
2
two old snails
4097
5901
the ant
4097
5906
the gnats
4097
5902
Miss Snail
4097
5902
3
two old snails
5972
6902
the ant
5972
6902
the gnats
5972
6902
Miss Snail
5972
6902
Andersen_story12.txt
0
the poor old man
0
1968
the mother
0
1980
the little child
0
1975
Death
0
1985
the woman in black clothes
0
1962
the thorn-bush
0
1968
the old woman
0
1975
1
the poor old man
2007
3893
the mother
2007
3905
the little child
2007
3899
Death
2007
3906
the woman in black clothes
2007
3891
the thorn-bush
2007
3893
the old woman
2007
3899
2
the poor old man
3983
5811
the mother
3983
5821
the little child
3983
5816
Death
3983
5826
the woman in black clothes
3983
5809
the thorn-bush
3983
5811
the old woman
3983
5816
3
the poor old man
5908
7860
the mother
5908
7867
the little child
5908
7863
Death
5908
7875
the woman in blac

In [29]:
for result_object in all_result_objects[version]:
    print(result_object.model_out)
    

<s>.
The old white snails were the first persons of distinction in the world, that they knew; the forest was planted for their sake, and the manor-house was there that they might be boiled and laid
In the story above, the current location of two old snails is the manor</s>
<s>. “Or the burdocks have grown up over it, so that they cannot come out. There need not, however, be any haste about that; but you are always in such a tremendous hurry, and the little one is beginning to be the same. Has he not been creeping up that stalk these</s>
<s>.
And so they went and fetched little Miss Snail. It was a whole week before she arrived; but therein was just the very best of it, for one could thus see that she was of the same species.
And then the marriage was celebrated. Six earth-worms shone as well as</s>
<s>.
And so they went and fetched little Miss Snail. It was a whole week before she arrived; but therein was just the very best of it, for one could thus see that she was of the same species

In [None]:
self.prompt = prompt
        self.model_out = model_out
        self.gold_locations = gold_locations
        self.character = character
        self.k = k
        self.story_no = story_no
        self.query_point = query_point
        self.exact_match = exact_match
        self.fuzzy_match = fuzzy_match

In [21]:
def create_membart_prompt(tokenizer, version, context, character, grammatical_number, max_no_tokens=512):

    prompt = context + "\nIn the story above, the current location of " + character + " is"
    pos = len(context)
    
    if len(tokenizer.encode(prompt)) > max_no_tokens:
    
        context = tokenizer.encode(context)
        diff = len(tokenizer.encode(prompt)) - max_no_tokens
        context = context[:-diff]
        context = tokenizer.decode(context, skip_special_tokens=True)
        prompt = context + "\nIn the story above, the current location of " + character + " is"
        pos = len(context)
        
        if len(tokenizer.encode(prompt)) > max_no_tokens:
            context = tokenizer.encode(context)
            context = context[:-2]
            context = tokenizer.decode(context, skip_special_tokens=True)
            prompt = context + "\nIn the story above, the current location of " + character + " is"
            pos = len(context)

    return prompt, pos

In [13]:
print(all_prompt_drafts[0])

Story:
THE HAPPY FAMILY
Really, the largest green leaf in this country is a dock-leaf; if one holds it before one, it is like a whole apron, and if one holds it over one's head in rainy weather, it is almost as good as an umbrella, for it is so immensely large. The burdock never grows alone, but where there grows one there always grow several: it is a great delight, and all this delightfulness is snails' food. The great white snails which persons of quality in former times made fricassees of, ate, and said, “Hem, hem! how delicious!” for they thought it tasted so delicate--lived on dock-leaves, and therefore burdock seeds were sown.
Now, there was an old manor-house, where they no longer ate snails, they were quite extinct; but the burdocks were not extinct, they grew and grew all over the walks and all the beds; they could not get the mastery over them--it was a whole forest of burdocks. Here and there stood an apple and a plum-tree, or else one never would have thought that it was a 

In [23]:
print(story[:2138])

THE HAPPY FAMILY
Really, the largest green leaf in this country is a dock-leaf; if one holds it before one, it is like a whole apron, and if one holds it over one's head in rainy weather, it is almost as good as an umbrella, for it is so immensely large. The burdock never grows alone, but where there grows one there always grow several: it is a great delight, and all this delightfulness is snails' food. The great white snails which persons of quality in former times made fricassees of, ate, and said, “Hem, hem! how delicious!” for they thought it tasted so delicate--lived on dock-leaves, and therefore burdock seeds were sown.
Now, there was an old manor-house, where they no longer ate snails, they were quite extinct; but the burdocks were not extinct, they grew and grew all over the walks and all the beds; they could not get the mastery over them--it was a whole forest of burdocks. Here and there stood an apple and a plum-tree, or else one never would have thought that it was a garden;

In [28]:
story[:2138] == all_prompt_drafts[0][7:]

True