# Necessary Imports

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import Dataset, DataLoader, Sampler
import os
import nltk
import pandas as pd

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from tqdm.auto import tqdm

import numpy as np
from jinja2 import Template
import pickle
import random
from random import shuffle
import torchtext

from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

In [3]:
# English stop words from NLTK, held as a set.
stop_words = set(stopwords.words("english"))

# Number of prompt template versions; the dataset cell keys its dict on 1..num_templates.
num_templates = 23

# Corpus locations (absolute paths on the cluster file system).
path_andersen, path_fanny, path_annotations = (
    "/kuacc/users/bozyurt20/ChildrenStories/Andersen",
    "/kuacc/users/bozyurt20/ChildrenStories/Fanny Fern",
    "/kuacc/users/bozyurt20/ChildrenStories/Annotations",
)

# File names found in each corpus folder, in the order os.listdir returns them.
dir_list_andersen, dir_list_fanny, dir_list_annotations = (
    os.listdir(folder)
    for folder in (path_andersen, path_fanny, path_annotations)
)

def text_clean_ending(example_text):
    """Strip trailing separators from a text and terminate it with a period.

    Trailing commas, spaces, semicolons, hyphens and newlines are removed.
    If what remains does not already end with ".", one is appended.

    Args:
        example_text: The text excerpt to clean.

    Returns:
        The cleaned text. An input that is empty, or that consists only of
        stripped characters, yields "" (there is no sentence to terminate).
    """
    example_text = example_text.rstrip(", ;-\n")
    # Guard against the empty string: indexing [-1] on it raised IndexError.
    if example_text and not example_text.endswith("."):
        example_text += "."
    return example_text

def remove_new_lines(text):
    """Unwrap hard line breaks while keeping paragraph boundaries.

    Paragraphs are the pieces obtained by splitting on a blank line ("\\n\\n").
    Inside each paragraph every newline becomes a space, and the paragraphs
    are then joined back together with a single newline.
    """
    return "\n".join(
        chunk.replace("\n", " ") for chunk in text.split("\n\n")
    )

# Load Tokenizer

In [2]:
# Module-level tokenizer shared by the cells below; create_prompt_clipped uses it
# to measure prompt length in tokens and to trim contexts.
tokenizer = T5Tokenizer.from_pretrained("t5-base")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [3]:
# Sample book used only for the tokenizer preview below (path is relative to the
# notebook's working directory).
path = "litbank/original/105_persuasion.txt"

In [4]:
# Read the whole sample book into memory as one string.
# NOTE(review): no encoding is passed, so the platform default is used — confirm
# this matches the file's encoding.
with open(path, "r") as f:
    book = f.read()

In [19]:
# Encode the entire book, keep the first 3890 token ids, and decode them back to
# text so the cell displays what that prefix covers.
# NOTE(review): 3890 is an unexplained magic number — document where it comes from.
tokens = tokenizer.encode(book)[:3890]
tokenizer.decode(tokens)


'Chapter 1 Sir Walter Elliot, of Kellynch Hall, in Somersetshire, was a man who, for his own amusement, never took up any book but the Baronetage; there he found occupation for an idle hour, and consolation in a distressed one; there his faculties were roused into admiration and respect, by contemplating the limited remnant of the earliest patents; there any unwelcome sensations, arising from domestic affairs changed naturally into pity and contempt as he turned over the almost endless creations of the last century; and there, if every other leaf were powerless, he could read his own history with an interest which never failed. This was the page at which the favourite volume always opened: "ELLIOT OF KELLYNCH HALL. "Walter Elliot, born March 1, 1760, married, July 15, 1784, Elizabeth, daughter of James Stevenson, Esq. of South Park, in the county of Gloucester, by which lady (who died 1800) he has issue Elizabeth, born June 1, 1785; Anne, born August 9, 1787; a still-born son, November

# Prompt Creating Function

In [5]:
# Intro phrase per template family; versions v and v + 11 share the same intro.
_PROMPT_INTROS = {
    1: "Answer the question depending on the context.",
    2: "What is the answer?",
    3: "Can you tell me ",
    4: "Please tell me ",
    5: "Tell me ",
    6: "From the passage, ",
    7: "I want to know ",
    8: "I want to ask ",
    9: "What is the answer to: ",
    10: "Find the answer to: ",
    11: "Answer: ",
}


def _clip_context_to_fit(render, context, max_no_tokens):
    """Trim tokens from the front of ``context`` until the rendered prompt fits.

    Args:
        render: Callable mapping a context string to the full prompt string.
        context: The context text to embed in the prompt.
        max_no_tokens: Maximum allowed length of ``tokenizer.encode(prompt)``.

    Returns:
        ``(prompt, context)`` where ``context`` is the (possibly shortened)
        text that was actually rendered into ``prompt``.

    Raises:
        ValueError: If the prompt is still too long once the context is empty,
            i.e. the fixed part of the prompt alone exceeds the budget.
    """
    prompt = render(context)
    prompt_len = len(tokenizer.encode(prompt))
    rounds = 0
    while prompt_len > max_no_tokens:
        if not context:
            # Nothing left to trim; without this guard the loop never ends.
            raise ValueError(
                "Prompt does not fit in %s tokens even with an empty context."
                % max_no_tokens
            )
        # Drop as many leading context tokens as the prompt is over budget.
        context_ids = tokenizer.encode(context)[prompt_len - max_no_tokens:]
        rounds += 1
        if rounds > 4:
            # Decoding and re-encoding does not always shrink the prompt by the
            # expected amount; after a few rounds drop one extra token per
            # round to force progress.
            context_ids = context_ids[1:]
        context = tokenizer.decode(context_ids, skip_special_tokens=True)
        prompt = render(context)
        prompt_len = len(tokenizer.encode(prompt))
    return prompt, context


def create_prompt_clipped(version, context, character, grammatical_number, max_no_tokens=512):
    """Build a "where is <character>?" prompt and clip its context to a token budget.

    Templates 1-11 ask the question in different phrasings; templates 12-22
    are the same phrasings with an added instruction to respond
    "unanswerable"; template 23 is a single fixed phrasing. When the prompt is
    longer than ``max_no_tokens`` (measured with the module-level
    ``tokenizer``), tokens are removed from the beginning of ``context``.

    Args:
        version: Template number, 1 to 23.
        context: Passage the question is asked about.
        character: Name inserted into the question.
        grammatical_number: ``'singular'`` (uses "is") or ``'plural'`` (uses "are").
        max_no_tokens: Token budget for the whole prompt.

    Returns:
        ``(prompt, context)``: the final prompt and the context text it contains.

    Raises:
        ValueError: For an unknown ``grammatical_number`` or ``version``, or
            when the prompt cannot fit even with an empty context.
    """
    if grammatical_number == 'singular':
        to_be = 'is'
    elif grammatical_number == 'plural':
        to_be = 'are'
    else:
        raise ValueError(
            "grammatical_number must be 'singular' or 'plural', got %r"
            % (grammatical_number,)
        )

    if version == 23:
        def render(ctx):
            return "Where " + to_be + " " + character + " in the following text: " + ctx + " Answer: "

        return _clip_context_to_fit(render, context, max_no_tokens)

    # Word order, capitalisation and punctuation of the question depend on the
    # intro it is appended to.
    if version in [1, 2, 9, 10, 11, 12, 13, 20, 21, 22]:
        question = "Where " + to_be + " " + character + "?"
    elif version in [4, 5, 7, 8, 15, 16, 18, 19]:
        question = "where " + character + " " + to_be + "."
    elif version in [3, 14]:
        question = "where " + character + " " + to_be + "?"
    elif version in [6, 17]:
        question = "where " + to_be + " " + character + "?"
    else:
        raise ValueError("Unknown prompt version: %r (expected 1-23)" % (version,))

    intro = _PROMPT_INTROS[(version - 1) % 11 + 1]

    if version in [1, 2]:
        tm = Template(
            "{{ intro }}\n"
            "Context: {{context}};\n"
            "Question: {{question}};\n"
            "Answer: "
        )
    elif version in [3, 4, 5, 6, 7, 8, 9, 10, 11]:
        tm = Template("{{context}} {{intro}}{{question}}")
    elif version in [12, 13]:
        tm = Template(
            "{{ intro }}\n"
            "Context: {{context}};\n"
            "Question: {{question}};\n"
            "If you can't find the answer, please respond \"unanswerable\".\n"
            "Answer: "
        )
    else:
        # Versions 14-22.
        # NOTE(review): this template ends with an extra '"' after the final
        # period, unlike the version 12/13 wording. It is kept unchanged so
        # generated prompts stay identical; confirm whether it is intended.
        tm = Template('{{context}} {{intro}}{{question}} If you can\'t find the answer, please respond "unanswerable"."')

    def render(ctx):
        return tm.render(intro=intro, context=ctx, question=question)

    return _clip_context_to_fit(render, context, max_no_tokens)

# Get the Annotations

In [6]:
# Maps each annotation file name to its rows as a 2-D array of values.
# The dataset cell reads row[0] as the end offset into the story text, row[1] as
# the character, row[2] as the gold answer and row[3] as the grammatical number.
all_annotations = {}

for item in dir_list_annotations:

    # The context manager closes the file even if parsing fails.
    with open(os.path.join(path_annotations, item), 'r') as f:
        annotations = pd.read_csv(f, sep="\t").values

    all_annotations[item] = annotations

# Create the Dataset Pickle

In [None]:
# dataset[k] collects the data points rendered with prompt template k.
dataset = {k: [] for k in range(1, num_templates + 1)}

# Running id; it advances once per (annotation, template) pair.
i = 0

# Number of characters of story text taken before each annotated position.
max_context_chars = 5120

for item in dir_list_annotations:

    print(item)

    # Annotation files are looked up under the Andersen folder by the same name.
    with open(os.path.join(path_andersen, item), 'r') as f:
        story = f.read()

    annotations = all_annotations[item]

    # The first paragraph is treated as the title; skip it and the blank line after it.
    paragraphs = story.split("\n\n")
    paragraph = paragraphs[0]
    len_title = len(paragraph) + 2

    for line in annotations:

        character = line[1]
        gold_answer = line[2]
        grammatical_number = line[3]

        # The context depends only on the annotation, not on the template, so
        # it is built once here instead of once per template.
        y = line[0]
        x = y - max_context_chars

        if x < len_title:
            text = story[len_title:y]

        else:
            # Move the window start forward to the first space so the excerpt
            # does not begin in the middle of a word.
            x = story[x:y].find(" ") + x
            text = story[x:y]

        text = text_clean_ending(text)
        text = remove_new_lines(text)

        # Iterate over the same template range the dict was initialised with.
        for k in range(1, num_templates + 1):

            prompt, context2 = create_prompt_clipped(k, text, character, grammatical_number, 512)

            data_point = {
                "prompt": prompt,
                "gold_locations": gold_answer,
                "id": i,
            }

            dataset[k].append(data_point)

            i += 1

            

Token indices sequence length is longer than the specified maximum sequence length for this model (1244 > 512). Running this sequence through the model will result in indexing errors


Andersen_story11.txt
Andersen_story12.txt
Andersen_story13.txt
Andersen_story15.txt
Andersen_story16.txt
Andersen_story17.txt
Andersen_story18.txt
Andersen_story1.txt
Andersen_story2.txt
Andersen_story3.txt
Andersen_story5.txt
Andersen_story7.txt
Andersen_story8.txt
Andersen_story9.txt
Andersen_story10.txt


# Save the Dataset in a Pickle

In [None]:
# Serialise the template-keyed dataset dict with pickle. The file is opened in
# binary mode; despite the ".txt" suffix its content is not plain text.
with open("T5_tokenizer_dataset.txt", "wb") as pickle_file:
    pickle.dump(dataset, pickle_file)