<a href="https://colab.research.google.com/github/erikmcguire/textworld_light/blob/main/LIGHT_Quest_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Quest Generation

Steps:
1. Set data path for loading/saving files.
2. Install dependencies, import modules.
3. Choose options and run cell to load primary model, tokenizer.
4. Run cell which defines utility functions for processing data.
5. Data:
    * A. Load data for processing if necessary.
    * B. Process data if necessary.
6. LM fine-tuning of quests given LIGHT persona + context prompts.
7. Generation
    * A. Single example generation
    * B. Bulk generation for analysis

### I. Define data path to drive folder

In [None]:
DATA_PTH = "/../content/drive/MyDrive/data/light_data/"

### II. Dependencies, Imports

In [None]:
!pip install transformers sentencepiece datasets &> /dev/null

In [None]:
from transformers import (T5TokenizerFast as T5Tokenizer, T5ForConditionalGeneration,
                          BartTokenizer, BartForConditionalGeneration)
from transformers import (AutoTokenizer, DataCollatorWithPadding, Seq2SeqTrainer,
                          DataCollatorForSeq2Seq, Seq2SeqTrainingArguments,
                          Trainer, TrainingArguments)

In [None]:
from datasets import load_dataset, DatasetDict
from collections import defaultdict
from datasets import load_from_disk

import pandas as pd
import random, json
import numpy as np
import datasets
import torch

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interactive, interactive_output, fixed, interact_manual
from IPython.display import display, clear_output

In [None]:
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
from google.colab import data_table

data_table.enable_dataframe_formatter()

In [None]:
data_table.DataTable.num_rows_per_page = 10

In [None]:
def set_seed(seed):
    """Seed the `random`, NumPy and PyTorch generators with the same value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [None]:
# random.randint(1, 10000)

In [None]:
seed = 42
set_seed(seed)

### III. Load model, tokenizer

* ```sel``` - Choose BART vs. T5, base pre-trained or fine-tuned models.
* ```sz``` - Base size or large.
* ```g``` - Choose the model fine-tuned on all genders, or on male, female, or gender-neutral characters only.

Choose options and run cell. Relies on models saved in ```light_data``` folder.

May compare to second model defined [here](#scrollTo=ujH228SC-gp8).

In [None]:
#@title Primary model, tokenizer selection

# Colab form cell: choose the architecture (`sel`), size (`sz`), gender subset
# (`g`) and training seed (`sd`), then load `tokenizer` and `model`.
sel = "bart" #@param ["bart", "t5", "bart_custom", "t5_custom"]
sz = "base" #@param ["base", "large"]
g = "Male" #@param ["All", "Male", "Female", "Neutral"]
# Map the form label to the short code used in the saved-model folder names.
gd = {"Male": "m", "Female": "f", "Neutral": "n", "All": "all"}
g = gd[g]
sd = "7696" #@param [42, 7696, 7304]
# Checkpoint step number that appears in the path for each gender code.
cd = {"all": 1640, "m": 205, "f": 180, "n": 1305}
cp = cd[g]
# NOTE(review): the Text widget is never displayed; `.value` presumably just
# returns the string passed in, so `pth` is the plain checkpoint path.
pth = widgets.Text(f"{DATA_PTH}models/quest_model_{g}_{sd}/checkpoint-{cp}").value
display(pth)

# NOTE(review): `pth` is a non-empty string, so this condition is always true.
if pth or "custom" not in sel:
    # Selection key -> (hub model name, (tokenizer class, model class)).
    mt = {"t5": (f"t5-{sz}", (T5Tokenizer, T5ForConditionalGeneration)),
        "bart": (f"facebook/bart-{sz}", (BartTokenizer, BartForConditionalGeneration)),
        "bart_custom": (f"facebook/bart-{sz}", (BartTokenizer, BartForConditionalGeneration)),
        "t5_custom": (f"t5-{sz}", (T5Tokenizer, T5ForConditionalGeneration))}

    if sel in ["t5", "bart"]:
        # Plain pre-trained model: tokenizer and weights both come from the hub name.
        tokenizer = mt[sel][1][0].from_pretrained(f"{mt[sel][0]}")
        model = mt[sel][1][1].from_pretrained(f"{mt[sel][0]}")
    else:
        # Fine-tuned ("_custom") model: weights come from the checkpoint at `pth`;
        # the tokenizer is always the *base*-size one, regardless of `sz`.
        if "t5" in sel:
            tpth = "t5-base"
        elif "bart_" in sel:
            tpth = "facebook/bart-base"
        tokenizer = mt[sel][1][0].from_pretrained(f"{tpth}")
        model = mt[sel][1][1].from_pretrained(pth)

    #tokenizer.pad_token = tokenizer.eos_token

Move the model to the GPU (requires a GPU runtime):

In [None]:
model.cuda()
clear_output()

### IV. Run to define functions for data processing.

In [None]:
#@title Functions

def change_gender(x):
    """Override the gender label for plural / mixed-group character names.

    A character whose last word is "men", or whose name mentions "salesmen",
    is labelled "M". Any name containing "men and women" is then labelled
    "N" (this second rule wins when both apply). Other records are returned
    unchanged.
    """
    name = x["character"]
    ends_with_men = name.split()[-1].lower() == "men"
    lowered = name.lower()
    if ends_with_men or "salesmen" in lowered:
        x["gender"] = "M"
    if "men and women" in lowered:
        x["gender"] = "N"
    return x

def make_quest(x):
    """Build the lower-cased, comma-separated action string for a quest row.

    The timeline's actions are taken in order and the row's goal action is
    spliced in at position 3 (it lands at the end when the timeline has fewer
    than three entries). The result is stored under ``"quest"``.
    """
    steps = pd.DataFrame(x.timeline).action.tolist()
    steps[3:3] = [x.goal]
    x["quest"] = ", ".join(steps).lower()
    return x

def make_questl(x):
    """Build the lower-cased action *list* for a quest row.

    Same ordering as ``make_quest`` (goal action spliced in at position 3),
    but the result is kept as a list of strings under ``"questl"`` instead of
    being joined into one string.
    """
    steps = pd.DataFrame(x.timeline).action.tolist()
    steps[3:3] = [x.goal]
    x["questl"] = list(map(str.lower, steps))
    return x

def replace_all(x):
    """Normalise the record's gender label to a one-letter code (N / M / F).

    Raises ``KeyError`` for any label outside the known aliases.
    """
    aliases_by_code = {
        "N": ("gender-neutral", "N"),
        "M": ("male", "M"),
        "F": ("F", "female"),
    }
    code_for = {
        alias: code
        for code, aliases in aliases_by_code.items()
        for alias in aliases
    }
    x["gender"] = code_for[x["gender"]]
    return x

def get_verbs(x):
    """Get main verbs from action list.

    Stores, under ``x["verbs"]``, the first word of every action in
    ``x["questl"]``. Empty and whitespace-only actions are skipped; they
    previously hit a bare ``except`` that printed the actions and then failed
    on an unbound ``verbs`` variable.

    Returns the same record ``x`` with the ``"verbs"`` entry added.
    """
    # If loaded from disk, may need to eval(x["questl"])
    verbs = []
    for action in x["questl"]:
        if not action:
            continue
        words = action.split()
        if words:  # whitespace-only actions have no verb to extract
            verbs.append(words[0])
    x["verbs"] = verbs
    return x

def get_goalverb(x):
    """Record the first whitespace-separated word of the goal as ``gverb``."""
    goal_words = x["goal"].split()
    x["gverb"] = goal_words[0]
    return x

def get_qs_flat(x: str) -> list:
    """Return every quest verb for gender ``x`` as one flat, pluralised list.

    Reads the module-level ``dfqq`` frame, keeps the rows whose ``gender``
    equals ``x``, concatenates their ``verbs`` lists in row order and passes
    each verb through ``pluralize`` (a helper defined outside this cell).
    """
    subset = dfqq[dfqq.gender == x]
    flat_verbs = [verb for row_verbs in subset.verbs for verb in row_verbs]
    return [pluralize(verb) for verb in flat_verbs]

def get_qs(x: str) -> list:
    """Return one pluralised verb list per quest for gender ``x``.

    Like ``get_qs_flat`` but keeps the per-quest grouping: the result has one
    inner list for each ``dfqq`` row whose ``gender`` equals ``x``.
    """
    subset = dfqq[dfqq.gender == x]
    return [[pluralize(verb) for verb in row_verbs] for row_verbs in subset.verbs]

def comb_csm(x):
    """Join character name and short motivation into ``character_sm``."""
    who = x["character"]
    why = x["short_motivation"]
    x["character_sm"] = f"{who} - {why}"
    return x

def comb_clsm(x):
    """Append the location to ``character_sm``, storing the prompt as ``clsm``."""
    persona_part = x["character_sm"]
    place = x["location"]
    x["clsm"] = f"{persona_part} - {place}"
    return x

def comb_sd(x):
    """Join a location's setting and description into ``setdesc``."""
    setting = x["setting"]
    description = x["description"]
    x["setdesc"] = f"{setting} - {description}"
    return x

def tok_func(x):
    """Tokenize character prompt, short motivation completion.

    Uses the module-level ``tokenizer``. Writes ``input_ids`` and
    ``attention_mask`` (from ``x["character"]``) and ``labels`` (from
    ``x["short_motivation"]``) into ``x`` and returns it.

    Works for a single example (string fields) and for a batch (lists of
    strings, as passed by ``Dataset.map(..., batched=True)``). In the batched
    case the labels are a list of sequences, so the pad-token masking has to
    be applied per token inside each sequence; comparing each whole sequence
    to the pad id (the previous behaviour) never masked anything.
    """
    tok = tokenizer(x["character"], padding=True, truncation=True)
    x["input_ids"] = tok["input_ids"]
    x["attention_mask"] = tok["attention_mask"]
    labels = tokenizer(x['short_motivation'], padding=True, truncation=True)["input_ids"]
    pad_id = tokenizer.pad_token_id
    # -100 is the label value the Hugging Face seq2seq loss ignores.
    if labels and isinstance(labels[0], (list, tuple)):
        labels = [[tok_id if tok_id != pad_id else -100 for tok_id in seq]
                  for seq in labels]
    else:
        labels = [tok_id if tok_id != pad_id else -100 for tok_id in labels]
    x["labels"] = labels
    return x

def tok_func2(x):
    """Tokenize combined prompt, action sequence completion.

    Uses the module-level ``tokenizer``. Writes ``input_ids`` and
    ``attention_mask`` (from ``x["clsm"]``) and ``labels`` (from
    ``x["quest"]``) into ``x`` and returns it.

    Works for a single example (string fields) and for a batch (lists of
    strings, as passed by ``Dataset.map(..., batched=True)``). In the batched
    case the labels are a list of sequences, so the pad-token masking has to
    be applied per token inside each sequence; comparing each whole sequence
    to the pad id (the previous behaviour) never masked anything, leaving
    padded label positions in the training loss.
    """
    tok = tokenizer(x["clsm"], padding=True, truncation=True)
    x["input_ids"] = tok["input_ids"]
    x["attention_mask"] = tok["attention_mask"]
    labels = tokenizer(x['quest'], padding=True, truncation=True)["input_ids"]
    pad_id = tokenizer.pad_token_id
    # -100 is the label value the Hugging Face seq2seq loss ignores.
    if labels and isinstance(labels[0], (list, tuple)):
        labels = [[tok_id if tok_id != pad_id else -100 for tok_id in seq]
                  for seq in labels]
    else:
        labels = [tok_id if tok_id != pad_id else -100 for tok_id in labels]
    x["labels"] = labels
    return x

def filt_g(x, g):
    """Tokenize combined prompt, action sequence completion.

    NOTE(review): despite the name, this does no filtering and the ``g``
    argument is unused; the body mirrors ``tok_func2``. The parameter is kept
    so existing callers keep working.

    Uses the module-level ``tokenizer``. Writes ``input_ids`` and
    ``attention_mask`` (from ``x["clsm"]``) and ``labels`` (from
    ``x["quest"]``) into ``x`` and returns it. Pad-token positions in the
    labels are replaced by -100 per token, for both a single example and a
    batch of examples (a list of label sequences).
    """
    tok = tokenizer(x["clsm"], padding=True, truncation=True)
    x["input_ids"] = tok["input_ids"]
    x["attention_mask"] = tok["attention_mask"]
    labels = tokenizer(x['quest'], padding=True, truncation=True)["input_ids"]
    pad_id = tokenizer.pad_token_id
    # -100 is the label value the Hugging Face seq2seq loss ignores.
    if labels and isinstance(labels[0], (list, tuple)):
        labels = [[tok_id if tok_id != pad_id else -100 for tok_id in seq]
                  for seq in labels]
    else:
        labels = [tok_id if tok_id != pad_id else -100 for tok_id in labels]
    x["labels"] = labels
    return x

def get_char_name(x):
    """Attach the distinct names of the characters present in a location.

    Each id in ``x["in_characters"]`` is looked up in the module-level ``dfc``
    frame (first row whose ``character_id`` matches); the de-duplicated names
    are stored under ``"characters"``. The order of that list is whatever the
    intermediate set yields.
    """
    names = [
        dfc[dfc.character_id == cid].name.values[0]
        for cid in x["in_characters"]
    ]
    x["characters"] = list(set(names))
    return x

def extract_room(x):
    """Derive a title-cased location name from the description's first sentence.

    Takes the text up to the first "." and drops as many leading characters
    as the prefix "You are in " has; the prefix itself is not checked, so the
    description is assumed to start with it.
    """
    first_sentence, _, _ = x["description"].partition(".")
    prefix_len = len("You are in ")
    x["location"] = first_sentence[prefix_len:].title()
    return x

def check_gender(x):
    """Extract crowdsourced gender of character or N/A.

    Each lower-cased word of ``x["character"]`` is looked up, in order, in the
    ``word`` column of the module-level ``gdf`` frame. The gender of the first
    matching row is stored in ``x["gender"]``. When no word matches (or the
    name has no words at all) the gender is set to "NA".

    Always returns ``x``; the earlier version raised on an empty name because
    the lookup result was never bound.
    """
    for w in (word.lower() for word in x["character"].split()):
        res = gdf.loc[:, ["word", "gender"]][gdf.word == w]
        if len(res) > 0:
            x["gender"] = res.gender.values[0]
            return x
    # No word of the character name is in the gender list.
    x["gender"] = "NA"
    return x

def check_gender_new(x):
    """Extract crowdsourced gender of character, defaulting to "N".

    Each lower-cased word of ``x["character"]`` is compared, in order, with
    the lower-cased ``shortname`` column of the module-level ``persona_id_df``
    frame. The gender of the first matching row is stored in ``x["gender"]``.
    When no word matches (or the name has no words at all) the gender is set
    to "N".

    Always returns ``x``; the earlier version raised on an empty name because
    the lookup result was never bound.
    """
    words = [word.lower() for word in x["character"].split()]
    if words:
        # Lower-case the lookup column once per call rather than once per word.
        shortnames = persona_id_df.shortname.str.lower()
        for w in words:
            res = persona_id_df.loc[:, ["shortname", "gender"]][shortnames == w]
            if len(res) > 0:
                x["gender"] = res.gender.values[0]
                return x
    # No word of the character name matches a persona short name.
    x["gender"] = "N"
    return x

def get_agencies(x):
    """Attach a ``{word: agency}`` mapping for the words of ``x["quest"]``.

    Every whitespace-separated word of the quest string is passed through
    ``pluralize`` (defined outside this cell) and the two ``replace`` calls,
    which tidy up forms where the word carried a trailing comma. Each
    resulting word is then looked up by exact match in the ``verb`` column of
    the module-level ``rdf`` frame; words without a usable match are left out.

    Returns ``x`` with the mapping stored under ``"agencies"``, or ``None``
    when the quest is not a string (an ``AttributeError`` is raised while
    splitting or pluralising).
    """
    d = dict()
    s = x["quest"]
    try:
        sett = set(map(lambda w: pluralize(w).replace("s,s", "es").replace(",s", "s"),
                   s.split()))
    except AttributeError:
        return None
    for w in sett:
        try:
            d[w] = rdf.query(f"verb == '{w}'").agency.values[0]
        except Exception:
            # Best effort: skip words that are absent or cannot be queried.
            # `Exception` (not a bare except) lets Ctrl-C still interrupt.
            pass
    x["agencies"] = d
    return x

def get_agencies_i(s):
    """Map the leading verb of each comma-separated action in ``s`` to an agency score.

    For every non-blank segment of ``s.split(",")`` the first word is passed
    through ``pluralize`` (defined outside this cell). Each resulting verb is
    looked up in the module-level ``agency_power_df`` frame and scored
    1 for "agency_pos", 0 for "agency_equal" and -1 for anything else; verbs
    without a usable exact match keep the default score 0.

    Blank segments (e.g. from a trailing comma) are skipped instead of raising
    ``IndexError``. Returns the ``{verb: score}`` dict, or ``None`` when ``s``
    is not a string.
    """
    try:
        sett = {
            pluralize(part.split()[0]).replace("s,s", "es").replace(",s", "s")
            for part in s.split(",")
            if part.strip()  # a blank segment has no first word
        }
    except AttributeError:
        return None

    d = {w: 0 for w in sett}
    dd = {"agency_pos": 1, "agency_equal": 0}
    for w in sett:
        # Narrow to rows whose verb contains `w`, then score the agency labels.
        rdf = agency_power_df[agency_power_df.verb.str.contains(w)].fillna("power_equal")
        rdf["agency"] = rdf.agency.apply(lambda label: dd.get(label, -1))
        try:
            d[w] = rdf.query(f"verb == '{w}'").agency.values[0]
        except Exception:
            # Best effort: keep the default 0 when there is no exact match.
            # `Exception` (not a bare except) lets Ctrl-C still interrupt.
            pass
    return d

### V. Data

#### A. Load original data for processing if not already processed and saved.

###### Load [environment](https://github.com/interactive-fiction-class/interactive-fiction-class-data/tree/master/) data
and create dataframe with locations as combined settings and descriptions, with associated characters within locations.

In [None]:
json_filename = f"{DATA_PTH}light_environment_train.json"
# Context manager so the file handle is closed once the JSON is parsed.
with open(json_filename, encoding="utf-8") as f:
    light_environment = json.load(f)

###### Load [gender](https://aclanthology.org/2020.emnlp-main.656.pdf) for Light

In [None]:
persona_id_df = pd.read_pickle(f"{DATA_PTH}updated_genderation.pkl")

##### Or create by loading gender annotations, combining with persona data in dataframe, save in order to load above:

In [None]:
!cp $DATA_PTH/genderation_bias.tar .
!tar -xf genderation_bias.tar

In [None]:
gender_df = pd.read_csv("gendered_list.tsv", sep="\t")

In [None]:
# Read the "old" persona set; the context manager closes the file afterwards.
with open("/../content/data_to_release/light/personas.json", 'rb') as personas_file:
    personas = json.load(personas_file)['old']

In [None]:
# Index every persona by its integer character id. `personas` maps a gender
# label to a list of persona dicts; each dict is tagged in place with that
# label, and the stored copy omits the "flagged" and "char_id" fields.
persona_map_by_no = {}
for gender, lst in personas.items():
    for x in lst:
        x["gender"] = gender
        char_no = int(x['char_id'])
        record = dict(x)
        record.pop("flagged", None)
        record.pop("char_id", None)
        persona_map_by_no[char_no] = record

In [None]:
persona_id_df = pd.DataFrame.from_dict(persona_map_by_no, orient='index')
persona_id_df = persona_id_df.sort_index(ascending=True)

##### Load quest data

with persona information and quests, combining character and location information with quest action sequences.

In [None]:
dfq = pd.read_json(f"{DATA_PTH}light_quests.jsonl", lines=True)
quest_dataset = load_dataset("json",
                             data_files=f"{DATA_PTH}light_quests.jsonl",
                             split="train")

In [None]:
display(dfq.shape)
dfq.head()

(7486, 8)

Unnamed: 0,character,persona,description,goal,short_motivation,mid_motivation,long_motivation,timeline
0,The Empress,I am the ruler of three kingdoms. I am known f...,You are in the Temple main room.\nThe massive ...,give coin to monk,I want to give offering to the monk,I want the monk to offer prayers today for my ...,I hope to conquer a fourth kingdom in the comi...,"[{'label': '2 hours ago', 'action': 'wear Arro..."
1,The Bedbug,I am a bug that lives in the bed of a small in...,You are in the Bedroom.\nThe bedroom is simple...,get wall,I need to get the wall so that I can crawl clo...,I will be going to the bed of the town baker s...,I am going to find a nice place where I can bu...,"[{'label': '1 hour ago', 'action': 'follow mic..."
2,A Gamekeeper,I am the gamekeeper of the countryside to the ...,"You are in the kitchen.\nNeat, and well kept. ...",get utensils,I need to skin a fox,I want a fox pelt to use in training my huntin...,I must train my dogs so that when the king vis...,"[{'label': '1 hour ago', 'action': 'go kitchen..."
3,The King,I am a King who rules a vast and mighty land. ...,You are in the The room at the top of the towe...,remove diamond ring,I want to give the diamond ring to the knight,I want the knight to take the diamond ring to ...,I hope to lead my country right during this tr...,"[{'label': '15 minutes ago', 'action': 'get Gl..."
4,The Witch,I only mastered one spell in witch school. I c...,You are in the Behind the Servant Quarters.\nD...,drop rock,I need to drop the rock so that I can pick up ...,I am going to look through the filth in the Se...,I am going to perform really well so that I ca...,"[{'label': '4 hours ago', 'action': 'wear unif..."


In [None]:
display(quest_dataset)
pd.DataFrame(quest_dataset[0])

Dataset({
    features: ['character', 'persona', 'description', 'goal', 'short_motivation', 'mid_motivation', 'long_motivation', 'timeline'],
    num_rows: 7486
})

Unnamed: 0,character,persona,description,goal,short_motivation,mid_motivation,long_motivation,timeline
0,The Empress,I am the ruler of three kingdoms. I am known f...,You are in the Temple main room.\nThe massive ...,give coin to monk,I want to give offering to the monk,I want the monk to offer prayers today for my ...,I hope to conquer a fourth kingdom in the comi...,"{'label': '2 hours ago', 'action': 'wear Arrow '}"
1,The Empress,I am the ruler of three kingdoms. I am known f...,You are in the Temple main room.\nThe massive ...,give coin to monk,I want to give offering to the monk,I want the monk to offer prayers today for my ...,I hope to conquer a fourth kingdom in the comi...,"{'label': '30 minutes ago', 'action': 'get Coin'}"
2,The Empress,I am the ruler of three kingdoms. I am known f...,You are in the Temple main room.\nThe massive ...,give coin to monk,I want to give offering to the monk,I want the monk to offer prayers today for my ...,I hope to conquer a fourth kingdom in the comi...,"{'label': '10 minutes ago', 'action': 'go Temp..."
3,The Empress,I am the ruler of three kingdoms. I am known f...,You are in the Temple main room.\nThe massive ...,give coin to monk,I want to give offering to the monk,I want the monk to offer prayers today for my ...,I hope to conquer a fourth kingdom in the comi...,"{'label': '10 minutes from now', 'action': 'dr..."
4,The Empress,I am the ruler of three kingdoms. I am known f...,You are in the Temple main room.\nThe massive ...,give coin to monk,I want to give offering to the monk,I want the monk to offer prayers today for my ...,I hope to conquer a fourth kingdom in the comi...,"{'label': '30 minutes from now', 'action': 'fo..."
5,The Empress,I am the ruler of three kingdoms. I am known f...,You are in the Temple main room.\nThe massive ...,give coin to monk,I want to give offering to the monk,I want the monk to offer prayers today for my ...,I hope to conquer a fourth kingdom in the comi...,"{'label': '1 hour from now', 'action': 'go The..."


#### B. Process LIGHT data

###### a. Customize and add columns with functions defined in [4](#scrollTo=IDV5U1wJGGZz).

In [None]:
# Row-wise enrichment steps. Order matters: get_verbs reads the "questl"
# column written by make_questl, and comb_clsm reads "character_sm" and
# "location" written by comb_csm and extract_room.
row_steps = (
    make_quest,
    make_questl,
    get_verbs,
    comb_csm,
    extract_room,
    comb_clsm,
    check_gender_new,
    replace_all,
    get_goalverb,
    get_agencies,
)
dfqq = dfq
for row_step in row_steps:
    dfqq = dfqq.apply(row_step, axis=1)

##### b. Save if necessary:

In [None]:
# dfqq.to_csv(f"{DATA_PTH}dfqq_new.csv", index=False)

##### c. Combine with quest data

In [None]:
# Copy the derived dataframe columns onto the Hugging Face dataset, one
# column at a time and in this order.
quest_datasett = quest_dataset
for column_name in ("quest", "character_sm", "location", "clsm",
                    "gender", "verbs", "gverb"):
    quest_datasett = quest_datasett.add_column(column_name, dfqq[column_name])

###### d. Split and tokenize

Split into training, validation, test sets and tokenize.

In [None]:
# First split: hold out 30% of the quests (seeded shuffle), keep 70% for training.
train_devtest = quest_datasett.train_test_split(shuffle = True,
                                                seed = seed, test_size=0.3)
# Second split: cut the held-out 30% in half.
qdev_test = train_devtest['test'].train_test_split(shuffle = True,
                                                   seed = seed, test_size=0.50)
display(train_devtest, qdev_test)
# Final splits: the 'train' half of the second split becomes 'dev',
# its 'test' half becomes 'test'.
qtrain_dev_test_dataset = DatasetDict({
    'train': train_devtest['train'],
    'test': qdev_test['test'],
    'dev': qdev_test['train']})

In [None]:
# qtrain_dev_test_dataset.save_to_disk(f"{DATA_PTH}light_dataset")

In [None]:
tokenized_datasets = qtrain_dev_test_dataset.map(tok_func2, batched=True,
                            remove_columns=['character', 'persona', 'description', 'goal', 'short_motivation', 'mid_motivation',
                                            'long_motivation', 'timeline', 'quest', 'character_sm', 'location',
                                            'clsm', 'gender', 'verbs', 'gverb'])

### VI. Training

##### a. Load and tokenize processed splits after saving:

In [None]:
qtrain_dev_test_dataset = load_from_disk(f"{DATA_PTH}light_dataset")
tokenized_datasets = qtrain_dev_test_dataset.map(tok_func2, batched=True,
                            remove_columns=['character', 'persona', 'description', 'goal', 'short_motivation', 'mid_motivation',
                                            'long_motivation', 'timeline', 'quest', 'character_sm', 'location',
                                            'clsm', 'gender', 'verbs', 'gverb'])

In [None]:
mdataset = qtrain_dev_test_dataset.filter(lambda x: x["gender"].startswith("M"))
fdataset = qtrain_dev_test_dataset.filter(lambda x: x["gender"].startswith("F"))
ndataset = qtrain_dev_test_dataset.filter(lambda x: x["gender"].startswith("N"))

In [None]:
mdataset.save_to_disk(f"{DATA_PTH}light_mdataset")
fdataset.save_to_disk(f"{DATA_PTH}light_fdataset")
ndataset.save_to_disk(f"{DATA_PTH}light_ndataset")

In [None]:
fdataset = datasets.concatenate_datasets([fdataset["train"], fdataset["dev"], fdataset["test"]])

In [None]:
#tokenized_datasets.save_to_disk(f"{DATA_PTH}light_tokenized_dataset")
tokenized_datasets = load_from_disk(f"{DATA_PTH}light_tokenized_dataset")

In [None]:
#tokenized_mdatasets.save_to_disk(f"{DATA_PTH}light_tokenized_mdataset")
#tokenized_fdatasets.save_to_disk(f"{DATA_PTH}light_tokenized_fdataset")
#tokenized_ndatasets.save_to_disk(f"{DATA_PTH}light_tokenized_ndataset")

In [None]:
tokenized_mdatasets = load_from_disk(f"{DATA_PTH}light_tokenized_mdataset")
tokenized_fdatasets = load_from_disk(f"{DATA_PTH}light_tokenized_fdataset")
tokenized_ndatasets = load_from_disk(f"{DATA_PTH}light_tokenized_ndataset")

In [None]:
tokenized_mdatasets = mdataset.map(tok_func2, batched=True,
                            remove_columns=['character', 'persona', 'description', 'goal', 'short_motivation', 'mid_motivation',
                                            'long_motivation', 'timeline', 'quest', 'character_sm', 'location',
                                            'clsm', 'gender', 'verbs', 'gverb'])
tokenized_fdatasets = fdataset.map(tok_func2, batched=True,
                            remove_columns=['character', 'persona', 'description', 'goal', 'short_motivation', 'mid_motivation',
                                            'long_motivation', 'timeline', 'quest', 'character_sm', 'location',
                                            'clsm', 'gender', 'verbs', 'gverb'])
tokenized_ndatasets = ndataset.map(tok_func2, batched=True,
                            remove_columns=['character', 'persona', 'description', 'goal', 'short_motivation', 'mid_motivation',
                                            'long_motivation', 'timeline', 'quest', 'character_sm', 'location',
                                            'clsm', 'gender', 'verbs', 'gverb'])

In [None]:
shuffle_sd = 42 #@param {'type': 'integer'}
# Size every subset to the number of female examples. `tokenized_fdatasets`
# may be a single Dataset (train/dev/test concatenated) or a DatasetDict
# (e.g. after reloading from disk); len() of a DatasetDict is its number of
# splits, not rows, so count rows explicitly in that case.
if isinstance(tokenized_fdatasets, DatasetDict):
    n_female = sum(len(part) for part in tokenized_fdatasets.values())
else:
    n_female = len(tokenized_fdatasets)
mdataset_train_subset = tokenized_mdatasets["train"].shuffle(seed=shuffle_sd).select(range(n_female))
ndataset_train_subset = tokenized_ndatasets["train"].shuffle(seed=shuffle_sd).select(range(n_female))
all_train_subset = tokenized_datasets["train"].shuffle(seed=shuffle_sd).select(range(n_female))

In [None]:
mdataset_train_subset["input_ids"]

In [None]:
#mdataset_train_subset.save_to_disk(f"{DATA_PTH}light_tokenized_mdataset_tok_train_subset")
#ndataset_train_subset.save_to_disk(f"{DATA_PTH}light_tokenized_ndataset_tok_train_subset")
#all_train_subset.save_to_disk(f"{DATA_PTH}light_tokenized_all_train_subset")

In [None]:
mdataset_train_subset = load_from_disk(f"{DATA_PTH}light_tokenized_mdataset_tok_train_subset")
ndataset_train_subset = load_from_disk(f"{DATA_PTH}light_tokenized_ndataset_tok_train_subset")
all_train_subset = load_from_disk(f"{DATA_PTH}light_tokenized_all_train_subset")

In [None]:
mdataset_train_subset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 570
})

In [None]:
# NOTE(review): this rebinds `ndataset_train_subset` to whatever is stored at
# light_ndataset, which an earlier cell writes from `ndataset` (the gender
# filter result, before tokenization). Running it replaces the tokenized
# train subset loaded or built above; confirm that is intended.
ndataset_train_subset = load_from_disk(f"{DATA_PTH}light_ndataset")

In [None]:
display(
    qtrain_dev_test_dataset,
    tokenized_datasets,
    tokenized_mdatasets,
    tokenized_fdatasets,
    tokenized_ndatasets
)

##### b. Define/subclass trainer, trainer args, dataset collator, instantiate objects.

In [None]:
class LightTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that takes the loss straight from the model's forward pass."""

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
        """How the loss is computed by Trainer.

        Args:
            model: the seq2seq model being trained.
            inputs: batch with ``input_ids``, ``labels`` and
                ``attention_mask`` entries.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: accepted so the override stays callable from
                newer ``Trainer`` versions that pass this keyword; unused.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        outputs = model(input_ids=inputs["input_ids"],
                        labels=inputs["labels"],
                        attention_mask=inputs["attention_mask"],
                        output_hidden_states=True,
                        output_attentions=True,
                        return_dict=True)
        # Dict-like outputs expose the loss by key; tuple outputs put it first.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss

In [None]:
# Collator used by the trainer to assemble batches with the current tokenizer.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)
# Evaluate, log and save once per epoch, for 5 epochs with batch size 16.
# Checkpoints are written under ./light-trainer and logs under ./log.
training_args = Seq2SeqTrainingArguments(
                        output_dir="light-trainer",
                        do_train=True,
                        run_name=f'light_{seed}',
                        do_eval=True,
                        report_to='all',
                        evaluation_strategy="epoch",
                        logging_strategy="epoch",
                        save_strategy="epoch",
                        per_device_train_batch_size=16,
                        per_device_eval_batch_size=16,
                        learning_rate=5e-5,
                        num_train_epochs=5,
                        seed=seed,
                        logging_dir="./log")

# Trains on `all_train_subset` and evaluates on the "dev" split.
# Swap `train_dataset` to fine-tune one of the other subsets.
trainer = LightTrainer(
    model,
    training_args,
    train_dataset=all_train_subset,
    eval_dataset=tokenized_datasets["dev"],
    data_collator=data_collator,
    tokenizer=tokenizer
)

##### c. Train, save

In [None]:
trainer.train()

In [None]:
%cp -r light-trainer $DATA_PTH/models/quest_model_all_subset_{{seed}}/

In [None]:
%cp -r log $DATA_PTH/models/quest_model_all_subset_{{seed}}/

In [None]:
#trainer.save_model("light-checkpoints")

#### Inference with trainer

In [None]:
trainer.evaluate(eval_dataset=trainer.eval_dataset)

In [None]:
trainer.evaluate(eval_dataset=tokenized_datasets["test"])

### VII. Generation

In [None]:
model.eval()

#### A. Single example generation

##### a. Load data for prompts

In [None]:
dfqq = pd.read_pickle(f"{DATA_PTH}dfqq_gmn.pkl")
qtrain_dev_test_dataset = load_from_disk(f"{DATA_PTH}light_dataset")

In [None]:
# dfqq.drop(["timeline", "character_sm", "clsm", "gverb", "verbs", "agencies",
  #          "locations", "characters", "foods", "objects", "questl"], axis=1)

##### b. Get example prompt

In [None]:
#@title ###### i. Select prompt
dchar_idx = ids = char_idx = csm = character = short_motivation = location = qst = None

def f(ix=559):
    """Select training example ``ix`` as the current generation prompt.

    Publishes the example's fields through module-level variables
    (``char_idx``, ``character``, ``csm``, ``short_motivation``, ``qst``,
    ``location``), the index of the first ``dfqq`` row whose ``character_sm``
    contains the prompt's character/motivation text (``dchar_idx``), and the
    tokenized prompt on the GPU (``ids``). Prints the example's gender and
    the decoded prompt.
    """
    global dchar_idx, char_idx, csm, character, short_motivation, location, ids, qst
    row = qtrain_dev_test_dataset["train"][ix]  # fetch the example once
    char_idx = ix
    character = row['character']
    csm = row['character_sm']
    gender = row['gender']
    short_motivation = row['short_motivation']
    qst = row['quest']
    location = row["location"]
    input_context_vector = csm + " - " + location
    # regex=False: the prompt is free text and must be matched literally.
    dchar_idx = dfqq[dfqq.character_sm.str.contains(csm, regex=False)].index.tolist()[0]
    inputs = tokenizer([input_context_vector], padding=True, truncation=True, return_tensors="pt")

    ids = inputs["input_ids"].cuda()
    res = tokenizer.decode(ids[0], skip_special_tokens=True)
    print(f"Gender: {gender}")
    print(res)

# The keyword must match f's parameter name (`ix`) for the slider to be used.
# The range is bounded by the train split, which is what f indexes into; the
# start value matches f's default.
slider = interact(f, ix=widgets.IntSlider(value=559, min=0,
                        max=len(qtrain_dev_test_dataset["train"]) - 1,
                        continuous_update=True))

interactive(children=(IntSlider(value=559, description='ix', max=1677, min=-559), Output()), _dom_classes=('wi…

###### ii. Explore:

In [None]:
dfqq.loc[dchar_idx:dchar_idx].loc[:, [i for i in dfqq.columns if "motiv" in i]].values

In [None]:
import ast

print("\n".join(dfqq.loc[dchar_idx:dchar_idx].clsm.tolist()[0].split(", ")))
print("\n")
vrbs = pd.DataFrame(dfqq.loc[dchar_idx]).T.verbs.tolist()[0]
ll = pd.DataFrame(dfqq.loc[dchar_idx]).T.questl.tolist()[0]
# `questl` may come back as the string form of a list when the frame was
# round-tripped through a file. Parse it with literal_eval, which only
# accepts Python literals, rather than eval, which would run arbitrary code.
quest_steps = ll
if isinstance(ll, str):
    try:
        parsed = ast.literal_eval(ll)
    except (ValueError, SyntaxError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        quest_steps = parsed
print("\n".join(quest_steps))

In [None]:
for c in dfqq.loc[dchar_idx:dchar_idx].columns:
    if "motiv" in c:
        print(dfqq.loc[dchar_idx:dchar_idx, c].values[0])

In [None]:
for ix in dfqq[dfqq.character_sm.str.contains(csm)].index.tolist():
    print(dfqq.loc[ix:ix].loc[:, [i for i in dfqq.columns if "motiv" in i]].values)

##### c. Generate from example prompt

###### a. Sampling

In [None]:
# Make sure _custom model loaded else will return prompt
with torch.no_grad():
    genids = model.generate(ids, repetition_penalty=1.4, temperature=0.75, num_beams=1, # else beam_sample()
                            top_p=0.25, top_k=30, do_sample=True, min_length=0, max_length=100)
res = tokenizer.batch_decode(genids, skip_special_tokens=True)
print("\n".join(res[0].split(", ")))

go valley of doom
follow guest
hit visitor
get bag from ground
put bag on trunk
go goblin lair
eat creature's body


###### b. Beam search

In [None]:
with torch.no_grad():
    genids = model.generate(
        ids,
        max_length=100,
        num_beams=5,
        early_stopping=True
    )
res = tokenizer.batch_decode(genids, skip_special_tokens=True)
print("\n".join(res[0].split(", ")))

go valley of doom
follow guest
hit guest
get guest's possessions
go goblin lair
put guest's belongings on goblin lair


###### c. Scratch cell

In [None]:
# Pass the prompt as a plain string. Wrapping it in a list makes `encode`
# treat the whole prompt as one already-tokenized token (or reject it,
# depending on the tokenizer class) -- NOTE(review): confirm against the
# tokenizer in use.
ids = tokenizer.encode("King - Castle - <mask> sword",
              padding=True, truncation=True, return_tensors="pt").cuda()
genids = model.generate(
    ids,
    max_length=1000,
    do_sample=True,
)
res = tokenizer.batch_decode(genids, skip_special_tokens=True)
print("\n".join(res[0].split(", ")))

This code would extract encoder states for representation analysis:

In [None]:
# One forward pass yields all three sets of states; running the model once
# per attribute tripled the compute for the same input.
model_outputs = model(ids, output_hidden_states=True)
encoder_last_hidden_state = model_outputs.encoder_last_hidden_state
encoder_hidden_states = model_outputs.encoder_hidden_states
decoder_hidden_states = model_outputs.decoder_hidden_states

##### d. Generate with model B

* ```sel``` - Choose BART vs. T5, base pre-trained or fine-tuned models.
* ```sz``` - Base size or large.
* ```g``` - Choose the model fine-tuned on all genders, or on male, female, or gender-neutral characters only.

Choose options and run cell for comparison to [another model](#scrollTo=ZXuD31Bc7XjH). Relies on models saved in ```light_data``` folder.

In [None]:
#@title Load comparison model, tokenizer

# Colab form cell: same choices as the primary model cell, but loads the
# comparison pair `tokenizerB` / `modelB`.
selB = "bart_custom" #@param ["bart", "t5", "bart_custom", "t5_custom"]
szB = "base" #@param ["base", "large"]
gB = "All" #@param ["All", "Male", "Female", "Neutral"]
# Map the form label to the short code used in the saved-model folder names.
gdB = {"Male": "m", "Female": "f", "Neutral": "n", "All": "all"}
gB = gdB[gB]
# Checkpoint step number that appears in the path for each gender code.
cdB = {"all": 1640, "m": 205, "f": 180, "n": 1305}
cpB = cdB[gB]
# NOTE(review): this folder layout (no seed, extra light-trainer_<g> level)
# differs from the path used for the primary model; confirm it exists.
pthB = widgets.Text(f"{DATA_PTH}models/quest_model_{gB}/light-trainer_{gB}/checkpoint-{cpB}").value
display(pthB)

# NOTE(review): `pthB` is a non-empty string, so this condition is always true.
if pthB or "custom" not in selB:
    # Selection key -> (hub model name, (tokenizer class, model class)).
    mtB = {"t5": (f"t5-{szB}", (T5Tokenizer, T5ForConditionalGeneration)),
        "bart": (f"facebook/bart-{szB}", (BartTokenizer, BartForConditionalGeneration)),
        "bart_custom": (f"facebook/bart-{szB}", (BartTokenizer, BartForConditionalGeneration)),
        "t5_custom": (f"t5-{szB}", (T5Tokenizer, T5ForConditionalGeneration))}

    if selB in ["t5", "bart"]:
        # Plain pre-trained model: tokenizer and weights both come from the hub name.
        tokenizerB = mtB[selB][1][0].from_pretrained(f"{mtB[selB][0]}")
        modelB = mtB[selB][1][1].from_pretrained(f"{mtB[selB][0]}")
    else:
        # Fine-tuned ("_custom") model: weights come from the checkpoint at `pthB`;
        # the tokenizer is always the *base*-size one, regardless of `szB`.
        if "t5" in selB:
            tpthB = "t5-base"
        elif "bart_" in selB:
            tpthB = "facebook/bart-base"
        tokenizerB = mtB[selB][1][0].from_pretrained(f"{tpthB}")
        modelB = mtB[selB][1][1].from_pretrained(pthB)

    #tokenizerB.pad_token = tokenizerB.eos_token

Move the comparison model to the GPU (requires a GPU runtime):

In [None]:
modelB.cuda()
clear_output()

###### a. Sampling

In [None]:
# Make sure _custom model loaded else will return prompt
genidsB = modelB.generate(ids, repetition_penalty=1.4, temperature=0.75, num_beams=1,
                        top_p=0.25, top_k=30, do_sample=True, min_length=0, max_length=100)
resB = tokenizerB.batch_decode(genidsB, skip_special_tokens=True)
print("\n".join(resB[0].split(", ")))

###### b. Beam search

In [None]:
# Beam search with the comparison model.
genidsB = modelB.generate(
    ids,
    max_length=100,
    num_beams=5,
    early_stopping=True
)
resB = tokenizerB.batch_decode(genidsB, skip_special_tokens=True)
# Print model B's result (`resB`); printing `res` showed the primary model's
# earlier output instead.
print("\n".join(resB[0].split(", ")))

This code would extract encoder states for representation analysis:

In [None]:
# One forward pass yields all three sets of states; running the model once
# per attribute tripled the compute for the same input.
model_outputsB = modelB(ids, output_hidden_states=True)
encoder_last_hidden_stateB = model_outputsB.encoder_last_hidden_state
encoder_hidden_statesB = model_outputsB.encoder_hidden_states
decoder_hidden_statesB = model_outputsB.decoder_hidden_states

#### B. Bulk generate quests

Steps:
1. Load datasets
2. Load data collator
3. For each seed, model type, split:
    * a.  run with sampling, beam
    * b.  save results
    

##### 1. Load splits, tokenize

In [None]:
qtrain_dev_test_dataset = load_from_disk(f"{DATA_PTH}light_dataset")

In [None]:
tokenized_datasets = load_from_disk(f"{DATA_PTH}light_tokenized_dataset")

In [None]:
tokenized_datasets = qtrain_dev_test_dataset.map(tok_func2, batched=True, load_from_cache_file=False,
                            remove_columns=['character', 'persona', 'description', 'goal',
                                            'short_motivation', 'mid_motivation',
                                            'long_motivation', 'timeline', 'quest',
                                            'character_sm', 'location',
                                            'clsm', 'gender', 'verbs', 'gverb'])

##### 2. Load collator for data loader iteration of prompts

In [None]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

##### 3. Generate, merge, save

In [None]:
#@title ###### 3a., 3d. Generate quests given prompts for split, method
save_file = "false" #@param ["true", "false"]

# Settings that do not change across seeds or models, so they are set up
# (and the tokenizer loaded) once instead of on every iteration.
sel = "bart_custom"
sz = "base"
# Model tag -> checkpoint step number used in the saved-model path.
cd = {"all": 1640, "m": 205, "f": 180, "n": 1305,
      "m_subset": 180, "n_subset": 180, "all_subset": 180}
mt = {"bart_custom": (f"facebook/bart-{sz}", (BartTokenizer, BartForConditionalGeneration))}
tpth = "facebook/bart-base"
tokenizer = mt[sel][1][0].from_pretrained(f"{tpth}")

for gen_sd in [42, 7304, 7696]:
    set_seed(gen_sd)
    # NOTE(review): this overrides the form value above, so results are
    # always written; confirm whether the form parameter should win.
    save_file = "true"
    print(f"Set seed to {gen_sd}.")
    for g in cd.keys():
        cp = cd[g.lower()]
        pth = f"{DATA_PTH}models/quest_model_{g.lower()}_{gen_sd}/checkpoint-{cp}"
        model = mt[sel][1][1].from_pretrained(pth)
        model.cuda()
        model.eval()
        print("Loaded model.")
        for split in ["train", "dev", "test"]:
            # Reference quests keyed by prompt text.
            # NOTE(review): building these tables through dicts keeps only one
            # row per distinct prompt, and the inner merge below drops any
            # generated row whose decoded prompt does not equal the stored
            # prompt exactly; confirm that losing those rows is acceptable.
            dff_orig = pd.DataFrame.from_dict(dict(zip(qtrain_dev_test_dataset[split]["clsm"],
                                                    qtrain_dev_test_dataset[split]["quest"])),
                                            orient="index")
            dff_orig.reset_index(level=0, inplace=True)
            dff_orig.columns = ["prompt", "quest_orig"]

            data_loader = torch.utils.data.DataLoader(tokenized_datasets[split],
                                                    collate_fn=data_collator,
                                                    batch_size=16, shuffle=False)
            # Collect per-batch frames and concatenate once at the end;
            # concatenating inside the loop re-copies all earlier rows each time.
            batch_frames = []
            for batch in data_loader:
                ids = batch["input_ids"].cuda()
                prompts = tokenizer.batch_decode(ids, skip_special_tokens=True)
                with torch.no_grad():
                    # Sampled continuation.
                    genids = model.generate(ids, do_sample=True,
                                            repetition_penalty=1.4,
                                            num_beams=1,
                                            temperature=0.75,
                                            top_p=0.25, top_k=30,
                                            min_length=0, max_length=100)
                    # Beam-search continuation.
                    genidsb = model.generate(
                        ids,
                        do_sample=False,
                        max_length=100,
                        num_beams=5,
                        early_stopping=True
                    )
                res = tokenizer.batch_decode(genids,
                                             skip_special_tokens=True)
                res2 = tokenizer.batch_decode(genidsb,
                                              skip_special_tokens=True)
                d = dict(zip(prompts, zip(res, res2)))
                df = pd.DataFrame.from_dict(d, orient="index")
                df.reset_index(level=0, inplace=True)
                df.columns = ["prompt", "quest_samp", "quest_beam"]
                batch_frames.append(df)
            if batch_frames:
                dff = pd.concat(batch_frames, axis=0, ignore_index=True)
            else:
                # Empty split: keep the expected columns so the merge still works.
                dff = pd.DataFrame(columns=["prompt", "quest_samp", "quest_beam"])
            dff = pd.merge(dff, dff_orig, on="prompt")
            if save_file == "true":
                dff.to_pickle(f"{DATA_PTH}generated_{split}_merged_{g.lower()}_{gen_sd}.pkl")

In [None]:
#@title Load merged continuations
sed = "7696" #@param [42, 7304, 7696]
split = "test" #@param ["train", "dev", "test"]
gend = "n_subset" #@param ["all", "m", "f", "n", "m_subset", "n_subset", "all_subset"]
dff = pd.read_pickle(f"{DATA_PTH}generated_{split}_merged_{gend}_{sed}.pkl")
dff