<a href="https://colab.research.google.com/github/erikmcguire/textworld_light/blob/main/TextWorld_LIGHT_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Dependencies and Imports

##### Installs - restart the runtime after running these cells (before running the imports)

In [None]:
%%shell
# Ubuntu no longer distributes chromium-browser outside of snap
#
# Proposed solution: https://askubuntu.com/questions/1204571/how-to-install-chromium-without-snap
#
# This cell only configures apt (sources, signing keys, pinning); the actual
# `apt-get update` / `apt-get install chromium chromium-driver` run in the next cell.

# Add debian buster
# Each entry is tied to its own keyring file through signed-by=.
cat > /etc/apt/sources.list.d/debian.list <<'EOF'
deb [arch=amd64 signed-by=/usr/share/keyrings/debian-buster.gpg] http://deb.debian.org/debian buster main
deb [arch=amd64 signed-by=/usr/share/keyrings/debian-buster-updates.gpg] http://deb.debian.org/debian buster-updates main
deb [arch=amd64 signed-by=/usr/share/keyrings/debian-security-buster.gpg] http://deb.debian.org/debian-security buster/updates main
EOF

# Add keys
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys DCC9EFBF77E11517
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 648ACFD622F3D138
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 112695A0E562B32A

# Export each received key (referenced by the last 8 hex digits of its id) into
# the keyring file named in the matching signed-by= entry above.
apt-key export 77E11517 | gpg --dearmour -o /usr/share/keyrings/debian-buster.gpg
apt-key export 22F3D138 | gpg --dearmour -o /usr/share/keyrings/debian-buster-updates.gpg
apt-key export E562B32A | gpg --dearmour -o /usr/share/keyrings/debian-security-buster.gpg

# Prefer debian repo for chromium* packages only
# Note the double-blank lines between entries
# Priorities written below: release a=eoan -> 500, anything from deb.debian.org -> 300,
# chromium* from deb.debian.org -> 700 (the highest, so Debian's chromium wins).
cat > /etc/apt/preferences.d/chromium.pref << 'EOF'
Package: *
Pin: release a=eoan
Pin-Priority: 500


Package: *
Pin: origin "deb.debian.org"
Pin-Priority: 300


Package: chromium*
Pin: origin "deb.debian.org"
Pin-Priority: 700
EOF


In [None]:
!apt-get update &> /dev/null
!apt-get install chromium chromium-driver &> /dev/null
!pip3 install selenium &> /dev/null
!pip install urllib3 &> /dev/null
!pip install textworld textworld[vis] &> /dev/null
!pip3 install deepsig &> /dev/null
!pip3 install scipy==1.10.0 &> /dev/null
!pip install pingouin pyyaml==5.4.1 &> /dev/null

##### Imports

In [None]:
import warnings
warnings.filterwarnings(category=UserWarning,
                                            action='ignore')
warnings.filterwarnings(category=DeprecationWarning,
                                            action='ignore')
import pandas as pd
import glob
from os.path import join as pjoin
from collections import OrderedDict, defaultdict, Counter
import textworld
import locale
from textworld import GameMaker, g_rng

from textworld.generator.data import KnowledgeBase
from textworld.logic import GameLogic
from textworld.generator.game import GameOptions, GrammarOptions
from textworld.generator.text_grammar import Grammar
from textworld.generator.text_grammar import Grammar as GrammarO
from textworld.envs.wrappers import Recorder
from textworld.generator.game import Game, World, Quest, Event, EntityInfo

import textworld.gym
from sklearn.model_selection import train_test_split
import shutil
import gym
from more_itertools import powerset

from time import time
import random
import re
import sys
import os
import copy

import ipywidgets as widgets
from ipywidgets import interact, interactive, interactive_output, fixed, interact_manual
from IPython.display import display

from selenium import webdriver
from selenium.webdriver.chrome.options import Options

from deepsig import aso

from google.colab import runtime


import scipy
scipy.__version__
import spacy
from scipy import stats
from scipy.stats import shapiro, kstest
from collections import Counter

import numpy as np
from itertools import product

### Data

In [None]:
# Directory holding the pickled LIGHT data frames. The leading "/.." resolves to
# "/" at the filesystem root, so this is /content/drive/MyDrive/data/light_data/.
# NOTE(review): presumably requires Google Drive to be mounted at /content/drive - confirm.
DATA_PTH = "/../content/drive/MyDrive/data/light_data/"
# Pre-computed tables; the "Parse objects..." section below shows how they were produced.
dfqq_quest_global_graph = pd.read_pickle(f"{DATA_PTH}dfqq_quest_global_graph.pkl")
vod_df_qs = pd.read_pickle(f"{DATA_PTH}vod_df_qs.pkl")

### Parse objects, affordances, create data files used above

In [None]:
import spacy
import random
import locale

In [None]:
def getpreferredencoding(do_setlocale=True):
    """Stand-in for ``locale.getpreferredencoding`` that always reports UTF-8.

    Args:
        do_setlocale: Accepted for signature compatibility; ignored.

    Returns:
        The string ``"UTF-8"``.
    """
    forced_encoding = "UTF-8"
    return forced_encoding


# Monkey-patch the locale module so later lookups of the preferred encoding get UTF-8.
setattr(locale, "getpreferredencoding", getpreferredencoding)

In [None]:
!pip install spacy-transformers &> /dev/null

In [None]:
!python3 -m spacy download en_core_web_trf

In [None]:
import spacy_transformers

In [None]:
# Ask spaCy to use the GPU when one is available, then load the transformer
# English pipeline downloaded above. `nlp` is used as a global by get_o / get_vod_qs.
spacy.prefer_gpu()
nlp = spacy.load('en_core_web_trf')

In [None]:
#@title Utils

# Command verbs recognised when parsing quest actions (the 15 LIGHT commands).
vset = {'drink','drop','eat','follow','get','give',
        'go','hit','hug','put','remove','steal',
        'use','wear','wield'}

def get_object(doc):
    """Return the slice of ``doc`` spanned by the first direct object's subtree.

    The first token whose dependency label contains "dobj" is located, and the
    slice from the first to the last token of its subtree is returned.
    Returns ``None`` when no token carries such a label.
    """
    # https://subscription.packtpub.com/book/data/9781838987312/2/ch02lvl1sec16/extracting-subjects-and-objects-of-the-sentence
    direct_object = next((tok for tok in doc if "dobj" in tok.dep_), None)
    if direct_object is None:
        return None
    members = list(direct_object.subtree)
    first_ix, last_ix = members[0].i, members[-1].i
    return doc[first_ix:last_ix + 1]

def get_o(vo, doc):
    """Extract the object phrase from a LIGHT action string.

    Args:
        vo: Action text whose first word is the command verb, e.g. "get the sword".
        doc: Parse of ``vo``. Tokens must expose ``dep_``; ``doc.noun_chunks`` is
            used when it is available.

    Returns:
        The object text with the articles "the"/"a" removed. Whenever no object
        can be determined (no command token found, or no direct object), the
        words of ``vo`` after the first one are returned instead, so the result
        is always a string.

    Notes:
        If the command verb is not parsed as ROOT, ``vo`` is re-parsed with the
        global ``nlp`` pipeline after inserting "a" (and then "the") behind the
        verb, which usually makes the parser treat the first word as the verb.
    """
    import re

    # The 15 Light commands
    light_commands = frozenset((
        'drink', 'drop', 'eat', 'follow', 'get', 'give',
        'go', 'hit', 'hug', 'put', 'remove', 'steal',
        'use', 'wear', 'wield',
    ))
    rest_of_action = " ".join(vo.split()[1:])

    def _strip_articles(text):
        # Whole-word match only: a plain str.replace("a ", "") also eats the
        # tail of ordinary words ("tea cup" -> "tecup", "bathe child" -> "bachild").
        return re.sub(r"\b(?:the|a) ", "", text)

    def _is_command(token):
        return str(token) in light_commands

    # Best effort: keep the first noun chunk of the ORIGINAL parse if there is one.
    try:
        nc = next(doc.noun_chunks)
    except Exception:
        # No chunks (StopIteration) or noun chunks unavailable for this doc.
        nc = None

    o = None
    # The loop iterates the original parse even though `doc` is rebound inside.
    for token in doc:
        if _is_command(token) and str(token.dep_) != "ROOT":
            # probably needs det if first verb not root
            doc = nlp(vo.split()[0] + " a " + rest_of_action)
            for retry_token in doc:
                # try def vs indef articles if still doesn't work
                if _is_command(retry_token) and str(retry_token.dep_) != "ROOT":
                    doc = nlp(vo.split()[0] + " the " + rest_of_action)

    for token in doc:
        if not _is_command(token):
            continue
        if str(token.dep_) != "ROOT":
            # give up and use latter part of original action
            o = rest_of_action
        else:
            # str(None) == "None" when there is no direct object; handled below.
            o = _strip_articles(str(get_object(doc)))
            if nc is not None:
                nctr = _strip_articles(nc.text)
                if nctr.split(" ")[0] in light_commands:
                    nctr = " ".join(nctr.split(" ")[1:])
                # Prefer the noun chunk unless it already contains the object.
                if o not in nctr:
                    o = nctr

    # No command token at all (o is None) or no direct object ("None"):
    # fall back to everything after the verb.
    if o is None or o == "None":
        o = rest_of_action
    return o

# Command verbs recognised when parsing quest actions (the 15 LIGHT commands).
vset = {'drink', 'drop', 'eat', 'follow', 'get', 'give',
        'go', 'hit', 'hug', 'put', 'remove', 'steal',
        'use', 'wear', 'wield'}

def get_vod_qs(df, ky = "questl"):
    """Group quest indices by gender, command verb and object.

    For every row of ``df`` and every action string in ``df[ky]``, the first
    word is taken as the verb and ``get_o`` supplies the object. The row index
    is appended to ``vod[gender][verb][object]``, so the number of occurrences
    of a pairing is the length of that list, e.g.
    ``len(vod["M"]["wield"]["sword"])``.

    Args:
        df: Frame with a ``gender`` column ("M", "F" or "N") and a column ``ky``
            holding an iterable of action strings per row.
        ky: Name of the column containing the action lists.

    Returns:
        dict: ``{gender: {verb: {object: [row index, ...]}}}`` with every verb
        of ``vset`` present for each gender.

    Raises:
        KeyError: If an action's first word is not one of the verbs in ``vset``.
    """
    genders = ["M", "F", "N"]
    vod = {g: {v: dict() for v in vset} for g in genders}
    # Count
    for g in genders:
        vo_pairs_g = df[df.gender == g]
        for ix in vo_pairs_g.index:
            vol = vo_pairs_g[ky][ix]
            for vo in vol:
                v = vo.split()[0]
                doc = nlp(vo)
                o = get_o(vo, doc)
                # get_o may yield None (no command token recognised) or a blank
                # string; in both cases fall back to the text after the verb.
                if o is None or not o.strip():
                    o = " ".join(vo.split()[1:])
                vod[g][v].setdefault(o, []).append(ix)
    return vod

def get_simplified_qs(vvod_df):
    """Collapse multi-word objects onto single-word objects.

    ``vvod_df`` has one column per gender ("M", "F", "N") and one row per
    verb; each cell is a dict ``{object: [quest index, ...]}``. For every
    multi-word object (e.g. 'hot tea') the text after the first connective
    (" because ", " and ", " from ", ...) is dropped and the entry is folded
    into a single-word key, so 'hot tea' is counted under 'tea'.

    Args:
        vvod_df: Frame as described above. It is not modified.

    Returns:
        tuple: ``(simplified_frame, simplified_frame.to_dict())``.
    """
    # Applied in this order; each keeps only the text before the connective.
    connectives = (" because ", " and ", " from ", " to ", " with ",
                   " in ", " inside ", " into ", " on ", " onto ")
    vod_df = vvod_df.copy(deep=True)
    for g in ["M", "F", "N"]:
        # create initial set of single word objects
        objects = set()
        for _, od in vod_df[g].items():
            for o in od:
                if len(o.split(" ")) == 1:
                    objects.add(o)
        # replace instances where duplicates hidden by modifiers
        # e.g. as in multiword yet contains object above
        # ex: 'hot tea' can be seen as dupe of 'tea'
        # Materialise the column first: cells are rewritten while we loop.
        for v, od in list(vod_df[g].items()):
            new_od = copy.deepcopy(od)  # can't change od size during loop
            for o, qixs in od.items():
                if len(o.split(" ")) <= 1:
                    continue
                head = o
                for connective in connectives:
                    head = head.split(connective)[0]
                words = head.split(" ")  # ex: ['hot', 'tea']
                matches = [w for w in words if w in objects]
                # With two or more known objects use the last one (earlier ones
                # are assumed to be modifiers); otherwise use the final word.
                # NOTE(review): with exactly one match that is not the final
                # word (e.g. 'tea cup' when only 'tea' is known) the final word
                # is still used - this mirrors the existing behaviour; confirm
                # it is the intended head-noun heuristic.
                target = matches[-1] if len(matches) > 1 else words[-1]
                if target not in new_od:
                    # Store a flat copy: wrapping the list ([qixs]) nested the
                    # indices and aliased the input's list.
                    new_od[target] = list(qixs)
                else:
                    new_od[target].extend(qixs)  # increment simple
                if o in new_od:
                    new_od.pop(o)  # remove multiword
            # .at writes into the frame itself; chained vod_df[g][v] = ... may
            # assign to a temporary (always under pandas copy-on-write).
            vod_df.at[v, g] = new_od
    vod = vod_df.to_dict()
    return vod_df, vod

def _flatten_quest_ixs(qixs):
    """Yield every quest index in ``qixs``, descending into nested lists."""
    for item in qixs:
        if isinstance(item, list):
            yield from _flatten_quest_ixs(item)
        else:
            yield item


def get_ix2v_global(df, index=None):
    """Map each quest index to its objects and the verbs used with them globally.

    For each gender, every object is first associated with the set of all verbs
    it occurs with anywhere in that gender's data. Each quest index then gets,
    for every object it mentions, that gender-wide verb set (the first gender
    in M, F, N order that supplies the object for a quest wins).

    Args:
        df: Mapping/frame indexed by gender ("M", "F", "N") whose values map
            verb -> {object: [quest index, ...]}. Index lists may be nested.
        index: Quest indices to create entries for. Defaults to the
            module-level ``dfqq.index``.

    Returns:
        dict: ``{quest index: {"pairs": {object: set of verbs}}}``.

    Raises:
        KeyError: If a quest index found in ``df`` is not in ``index``.
    """
    if index is None:
        index = dfqq.index
    genders = ["M", "F", "N"]
    ix2v = {ix: {"pairs": dict()} for ix in index}
    o2v = dict()
    for g in genders:
        o2v[g] = dict()
        for v, od in df[g].items():
            for o in od:
                o2v[g].setdefault(o, set()).add(v)
    for g in genders:
        for v, od in df[g].items():
            for o, qixs in od.items():
                # Visit every index, including all members of nested lists.
                for ix in _flatten_quest_ixs(qixs):
                    ix2v[ix]["pairs"].setdefault(o, o2v[g][o])
    return ix2v

def get_ix2v(vod_df_qs, index=None):
    """Map each quest index to the objects it mentions and the verbs applied to them.

    Args:
        vod_df_qs: Mapping/frame indexed by gender ("M", "F", "N") whose values
            map verb -> {object: [quest index, ...]}. Index lists may be nested.
        index: Quest indices to create entries for. Defaults to the
            module-level ``dfqq.index``.

    Returns:
        dict: ``{quest index: {"pairs": {object: set of verbs}}}`` where a verb
        is included only if that quest index is listed under it.

    Raises:
        KeyError: If a quest index found in ``vod_df_qs`` is not in ``index``.
    """
    if index is None:
        index = dfqq.index
    ix2v = {ix: {"pairs": dict()} for ix in index}
    for g in ["M", "F", "N"]:
        for v, od in vod_df_qs[g].items():
            for o, qixs in od.items():
                # Visit every index, including all members of nested lists.
                for ix in _flatten_quest_ixs(qixs):
                    ix2v[ix]["pairs"].setdefault(o, set()).add(v)
    return ix2v

In [None]:
#vod_df_qs = pd.DataFrame.from_dict(get_vod_qs(dfqq, "questl"))

In [None]:
# ix2v_df = pd.DataFrame.from_dict(get_ix2v(vod_df_qs))
# ix2v_df = ix2v_df.T
# Load the saved table rather than recomputing it; the commented lines above
# presumably show how it was originally built - verify before relying on that.
ix2v_df = pd.read_pickle(f"{DATA_PTH}ix2v.pkl")

In [None]:
# get_ix2v_global reads the module-level `dfqq` for its quest index, so the
# frame has to exist before the call; on a fresh kernel it is otherwise only
# loaded by a later cell and this cell fails with NameError.
dfqq = pd.read_pickle(f"{DATA_PTH}dfqq_gmn.pkl")
# One row per quest index, with the global object -> verbs mapping in "pairs_global".
ix2v_df_global = pd.DataFrame.from_dict(get_ix2v_global(vod_df_qs))
ix2v_df_global = ix2v_df_global.T.rename({"pairs": "pairs_global"}, axis=1)

In [None]:
#vod_df_qss, vod_qs = get_simplified_qs(vod_df_qs)
# Load the saved simplified table instead of recomputing it with the line above.
vod_df_qss = pd.read_pickle(f"{DATA_PTH}vod_df_qss.pkl")

In [None]:
merged_df = pd.read_pickle(f"{DATA_PTH}merged_dfo.pkl") # json vers of graph

In [None]:
# Add the "pairs_global" column to the quest frame (index-aligned join), then
# left-merge the graph data on "quest"; clashing merged_df columns get a "_y" suffix.
dfqq_global = dfqq.join(ix2v_df_global)
dfqq_global_graph = dfqq_global.merge(merged_df, how="left", on="quest", suffixes=(None, "_y"))

In [None]:
# dfqq_quest_based = dfqq.join(ix2v_df)

In [None]:
dfqq = pd.read_pickle(f"{DATA_PTH}dfqq_gmn.pkl")

In [None]:
# Index-aligned joins: attach the columns of ix2v_df and then those of ix2v_df_global.
dfqq_quest_global = dfqq.join(ix2v_df).join(ix2v_df_global)

In [None]:
# Left-merge the graph data on "quest", keeping every quest row; clashing
# merged_df columns get a "_y" suffix.
dfqq_quest_global_graph = dfqq_quest_global.merge(merged_df, how="left", on="quest", suffixes=(None, "_y"))

In [None]:
# ix2v_df.T.to_pickle(f"{DATA_PTH}ix2v.pkl")
# ix2v_df_global.to_pickle(f"{DATA_PTH}ix2v_global.pkl")
# vod_df_qs.to_pickle(f"{DATA_PTH}vod_df_qs.pkl")
# vod_df_qss.to_pickle(f"{DATA_PTH}vod_df_qss.pkl")
# dfqq_quest_based.to_pickle(f"{DATA_PTH}dfqq_quest_based.pkl")
# dfqq_quest_global.to_pickle(f"{DATA_PTH}dfqq_quest_global.pkl")
# dfqq_global.to_pickle(f"{DATA_PTH}dfqq_global.pkl")
# dfqq_global_graph.to_pickle(f"{DATA_PTH}dfqq_global_graph.pkl")
# dfqq_quest_global_graph.to_pickle(f"{DATA_PTH}dfqq_quest_global_graph.pkl")
# vod_df_qss.to_pickle(f"{DATA_PTH}vod_df_qs_simplified.pkl") # after running fn to replace multiword on loaded vod_df.pkl

### Logic, grammar for game creation

#### Set custom logic (broken across cells due to length)

In [None]:
#@title twl 1

# First part of the custom TextWorld logic definition, kept as one string.
# It declares the types container (c), inventory (I), supporter (s), door (d),
# object (o), player (P), room (r), key (k) and thing (t), each with its
# predicates, rules, constraints and Inform 7 bindings. The object type adds
# the wear / wield / hit / hug / follow / use commands. The next cell appends
# the food type to this string.
# The `\'` in constraint link2 is a Python escape: it keeps the three primes of
# r''' from terminating the triple-quoted string.
twl = '''
# container
type c : t {
    predicates {
        open(c);
        closed(c);
        locked(c);

        in(o, c);
    }

    rules {
        lock/c   :: $at(P, r) & $at(c, r) & $in(k, I) & $match(k, c) & closed(c) -> locked(c);
        unlock/c :: $at(P, r) & $at(c, r) & $in(k, I) & $match(k, c) & locked(c) -> closed(c);

        open/c  :: $at(P, r) & $at(c, r) & closed(c) -> open(c);
        close/c :: $at(P, r) & $at(c, r) & open(c) -> closed(c);
    }

    reverse_rules {
        lock/c :: unlock/c;
        open/c :: close/c;
    }

    constraints {
        c1 :: open(c)   & closed(c) -> fail();
        c2 :: open(c)   & locked(c) -> fail();
        c3 :: closed(c) & locked(c) -> fail();
    }

    inform7 {
        type {
            kind :: "container";
            definition :: "containers are openable, lockable and fixed in place. containers are usually closed.";
        }

        predicates {
            open(c) :: "The {c} is open";
            closed(c) :: "The {c} is closed";
            locked(c) :: "The {c} is locked";

            in(o, c) :: "The {o} is in the {c}";
        }

        commands {
            open/c :: "open {c}" :: "opening the {c}";
            close/c :: "close {c}" :: "closing the {c}";

            lock/c :: "lock {c} with {k}" :: "locking the {c} with the {k}";
            unlock/c :: "unlock {c} with {k}" :: "unlocking the {c} with the {k}";
        }
    }
}

# Inventory
type I {
    predicates {
        in(o, I);
    }

    rules {
        inventory :: at(P, r) -> at(P, r);  # Nothing changes.

        take :: $at(P, r) & at(o, r) -> in(o, I);
        drop :: $at(P, r) & in(o, I) -> at(o, r);

        take/c :: $at(P, r) & $at(c, r) & $open(c) & in(o, c) -> in(o, I);
        insert :: $at(P, r) & $at(c, r) & $open(c) & in(o, I) -> in(o, c);

        take/s :: $at(P, r) & $at(s, r) & on(o, s) -> in(o, I);
        put    :: $at(P, r) & $at(s, r) & in(o, I) -> on(o, s);

        examine/I :: in(o, I) -> in(o, I);  # Nothing changes.
        examine/s :: at(P, r) & $at(s, r) & $on(o, s) -> at(P, r);  # Nothing changes.
        examine/c :: at(P, r) & $at(c, r) & $open(c) & $in(o, c) -> at(P, r);  # Nothing changes.
    }

    reverse_rules {
        inventory :: inventory;

        take :: drop;
        take/c :: insert;
        take/s :: put;

        examine/I :: examine/I;
        examine/s :: examine/s;
        examine/c :: examine/c;
    }

    inform7 {
        predicates {
            in(o, I) :: "The player carries the {o}";
        }

        commands {
            take :: "take {o}" :: "taking the {o}";
            drop :: "drop {o}" :: "dropping the {o}";

            take/c :: "take {o} from {c}" :: "removing the {o} from the {c}";
            insert :: "insert {o} into {c}" :: "inserting the {o} into the {c}";

            take/s :: "take {o} from {s}" :: "removing the {o} from the {s}";
            put :: "put {o} on {s}" :: "putting the {o} on the {s}";

            inventory :: "inventory" :: "taking inventory";

            examine/I :: "examine {o}" :: "examining the {o}";
            examine/s :: "examine {o}" :: "examining the {o}";
            examine/c :: "examine {o}" :: "examining the {o}";
        }
    }
}

# supporter
type s : t {
    predicates {
        on(o, s);
    }

    inform7 {
        type {
            kind :: "supporter";
            definition :: "supporters are fixed in place.";
        }

        predicates {
            on(o, s) :: "The {o} is on the {s}";
        }
    }
}



# door
type d : t {
    predicates {
        open(d);
        closed(d);
        locked(d);

        link(r, d, r);
    }

    rules {
        lock/d   :: $at(P, r) & $link(r, d, r') & $link(r', d, r) & $in(k, I) & $match(k, d) & closed(d) -> locked(d);
        unlock/d :: $at(P, r) & $link(r, d, r') & $link(r', d, r) & $in(k, I) & $match(k, d) & locked(d) -> closed(d);

        open/d   :: $at(P, r) & $link(r, d, r') & $link(r', d, r) & closed(d) -> open(d) & free(r, r') & free(r', r);
        close/d  :: $at(P, r) & $link(r, d, r') & $link(r', d, r) & open(d) & free(r, r') & free(r', r) -> closed(d);

        examine/d :: at(P, r) & $link(r, d, r') -> at(P, r);  # Nothing changes.
    }

    reverse_rules {
        lock/d :: unlock/d;
        open/d :: close/d;

        examine/d :: examine/d;
    }

    constraints {
        d1 :: open(d)   & closed(d) -> fail();
        d2 :: open(d)   & locked(d) -> fail();
        d3 :: closed(d) & locked(d) -> fail();

        # A door can't be used to link more than two rooms.
        link1 :: link(r, d, r') & link(r, d, r'') -> fail();
        link2 :: link(r, d, r') & link(r'', d, r''\') -> fail();

        # There's already a door linking two rooms.
        link3 :: link(r, d, r') & link(r, d', r') -> fail();

        # There cannot be more than four doors in a room.
        too_many_doors :: link(r, d1: d, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();

        # There cannot be more than four doors in a room.
        dr1 :: free(r, r1: r) & link(r, d2: d, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
        dr2 :: free(r, r1: r) & free(r, r2: r) & link(r, d3: d, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
        dr3 :: free(r, r1: r) & free(r, r2: r) & free(r, r3: r) & link(r, d4: d, r4: r) & link(r, d5: d, r5: r) -> fail();
        dr4 :: free(r, r1: r) & free(r, r2: r) & free(r, r3: r) & free(r, r4: r) & link(r, d5: d, r5: r) -> fail();

        free1 :: link(r, d, r') & free(r, r') & closed(d) -> fail();
        free2 :: link(r, d, r') & free(r, r') & locked(d) -> fail();
    }

    inform7 {
        type {
            kind :: "door";
            definition :: "door is openable and lockable.";
        }

        predicates {
            open(d) :: "The {d} is open";
            closed(d) :: "The {d} is closed";
            locked(d) :: "The {d} is locked";
            link(r, d, r') :: "";  # No equivalent in Inform7.
        }

        commands {
            open/d :: "open {d}" :: "opening {d}";
            close/d :: "close {d}" :: "closing {d}";

            unlock/d :: "unlock {d} with {k}" :: "unlocking {d} with the {k}";
            lock/d :: "lock {d} with {k}" :: "locking {d} with the {k}";

            examine/d :: "examine {d}" :: "examining {d}";
        }
    }
}

# object
type o : t {
    predicates {
        wearable(o);
        wieldable(o);
        huggable(o);
        hittable(o);
        followable(o);
        worn(o);
        wielded(o);
    }
    constraints {
        obj1 :: in(o, I) & in(o, c) -> fail();
        obj2 :: in(o, I) & on(o, s) -> fail();
        obj3 :: in(o, I) & at(o, r) -> fail();
        obj4 :: in(o, c) & on(o, s) -> fail();
        obj5 :: in(o, c) & at(o, r) -> fail();
        obj6 :: on(o, s) & at(o, r) -> fail();
        obj7 :: at(o, r) & at(o, r') -> fail();
        obj8 :: in(o, c) & in(o, c') -> fail();
        obj9 :: on(o, s) & on(o, s') -> fail();
    }

    rules {
        wear :: $in(o, I) & $wearable(o) -> worn(o);
        wield :: $in(o, I) & $wieldable(o) -> wielded(o);
        hit :: at(P, r) & $hittable(o) & $at(o, r) -> at(P, r);
        hug :: at(P, r) & $huggable(o) & $at(o, r) -> at(P, r);
        follow :: at(P, r) & $followable(o) & $at(o, r) -> at(P, r);
        use :: in(o, I) -> in(o, I);
    }


    inform7 {
        type {
            kind :: "object-like";
            definition :: "object-like is portable. object-like can be huggable. object-like can be hittable. object-like can be followable. object-like can be wearable. object-like can be wieldable. object-like can be wielded.";
        }
        commands {
            wear :: "wear {o}" :: "You put on the {o}";
            wield :: "wield {o}" :: "You swing the {o} around a few times.";
            hit :: "hit {o}" :: "You hit the {o}.";
            hug :: "hug {o}" :: "You hug the {o}.";
            use :: "use {o}" :: "You use the {o}.";
            follow :: "follow {o}" :: "You follow the {o}.";
        }

        predicates {
            wearable(o) :: "The {o} is wearable";
            wieldable(o) :: "The {o} is wieldable";
            worn(o) :: "The {o} is worn";
            wielded(o) :: "The {o} is wielded";
            huggable(o) :: "The {o} is huggable";
            hittable(o) :: "The {o} is hittable";
            followable(o) :: "The {o} is followable";
        }

        code :: """
            Understand "steal [something]" as taking.
            Understand "hug [something]" as kissing.
            Understand "give [something] to [something]" as giving it to.
            Understand "give [something] [something]" as giving it to (with nouns reversed).
            Understand "drop [something worn]" as taking off.
            Wielding is an action applying to one carried thing.
            Understand "wield [something]" as wielding.
            Check an actor wielding:
                if the noun is not wieldable:
                    say "[The noun] is not a weapon!";
                    stop the action.
            After wielding the noun:
                now the noun is wielded;
                say "You swing [the noun] around a few times.";

            Following is an action applying to one thing.
            Understand "follow [something]" as following.
            Check an actor following:
                if the noun is not followable:
                    say "[The noun] can not be followed!";
                    stop the action.
            After following the noun:
                say "You follow [the noun].";

            Using is an action applying to one thing.
            Understand "use [something]" as using.
            After using the noun:
                say "You use [the noun].";

        """;
    }
}

# Player
type P {
    rules {
        look :: at(P, r) -> at(P, r);  # Nothing changes.
    }

    reverse_rules {
        look :: look;
    }

    inform7 {
        commands {
            look :: "look" :: "looking";
        }
    }
}

# room
type r {
    predicates {
        at(P, r);
        at(t, r);

        north_of(r, r);
        west_of(r, r);

        north_of/d(r, d, r);
        west_of/d(r, d, r);

        free(r, r);

        south_of(r, r') = north_of(r', r);
        east_of(r, r') = west_of(r', r);

        south_of/d(r, d, r') = north_of/d(r', d, r);
        east_of/d(r, d, r') = west_of/d(r', d, r);
    }

    rules {
        go/north :: at(P, r) & $north_of(r', r) & $free(r, r') & $free(r', r) -> at(P, r');
        go/south :: at(P, r) & $south_of(r', r) & $free(r, r') & $free(r', r) -> at(P, r');
        go/east  :: at(P, r) & $east_of(r', r) & $free(r, r') & $free(r', r) -> at(P, r');
        go/west  :: at(P, r) & $west_of(r', r) & $free(r, r') & $free(r', r) -> at(P, r');
    }

    reverse_rules {
        go/north :: go/south;
        go/west :: go/east;
    }

    constraints {
        r1 :: at(P, r) & at(P, r') -> fail();
        r2 :: at(s, r) & at(s, r') -> fail();
        r3 :: at(c, r) & at(c, r') -> fail();

        # An exit direction can only lead to one room.
        nav_rr1 :: north_of(r, r') & north_of(r'', r') -> fail();
        nav_rr2 :: south_of(r, r') & south_of(r'', r') -> fail();
        nav_rr3 :: east_of(r, r') & east_of(r'', r') -> fail();
        nav_rr4 :: west_of(r, r') & west_of(r'', r') -> fail();

        # Two rooms can only be connected once with each other.
        nav_rrA :: north_of(r, r') & south_of(r, r') -> fail();
        nav_rrB :: north_of(r, r') & west_of(r, r') -> fail();
        nav_rrC :: north_of(r, r') & east_of(r, r') -> fail();
        nav_rrD :: south_of(r, r') & west_of(r, r') -> fail();
        nav_rrE :: south_of(r, r') & east_of(r, r') -> fail();
        nav_rrF :: west_of(r, r')  & east_of(r, r') -> fail();
    }

    inform7 {
        type {
            kind :: "room";
        }

        predicates {
            at(P, r) :: "The player is in {r}";
            at(t, r) :: "The {t} is in {r}";
            free(r, r') :: "";  # No equivalent in Inform7.

            north_of(r, r') :: "The {r} is mapped north of {r'}";
            south_of(r, r') :: "The {r} is mapped south of {r'}";
            east_of(r, r') :: "The {r} is mapped east of {r'}";
            west_of(r, r') :: "The {r} is mapped west of {r'}";

            north_of/d(r, d, r') :: "South of {r} and north of {r'} is a door called {d}";
            south_of/d(r, d, r') :: "North of {r} and south of {r'} is a door called {d}";
            east_of/d(r, d, r') :: "West of {r} and east of {r'} is a door called {d}";
            west_of/d(r, d, r') :: "East of {r} and west of {r'} is a door called {d}";
        }

        commands {
            go/north :: "go north" :: "going north";
            go/south :: "go south" :: "going south";
            go/east :: "go east" :: "going east";
            go/west :: "go west" :: "going west";
        }
    }
}

# key
type k : o {
    predicates {
        match(k, c);
        match(k, d);
    }

    constraints {
        k1 :: match(k, c) & match(k', c) -> fail();
        k2 :: match(k, c) & match(k, c') -> fail();
        k3 :: match(k, d) & match(k', d) -> fail();
        k4 :: match(k, d) & match(k, d') -> fail();
    }

    inform7 {
        type {
            kind :: "key";
        }

        predicates {
            match(k, c) :: "The matching key of the {c} is the {k}";
            match(k, d) :: "The matching key of the {d} is the {k}";
        }
    }
}

# thing
type t {
    rules {
        examine/t :: at(P, r) & $at(t, r) -> at(P, r);
    }

    reverse_rules {
        examine/t :: examine/t;
    }

    inform7 {
        type {
            kind :: "thing";
        }

        commands {
            examine/t :: "examine {t}" :: "examining the {t}";
        }
    }
}
'''

In [None]:
#@title twl 2
# Second part of the logic string: appends the food type (f, a kind of object)
# with edible / drinkable / eaten predicates and the eat / drink rules. Its
# Inform 7 code unlists the block drinking rule and marks food as eaten after
# it is drunk or eaten.
twl += '''
# food
type f : o {
    predicates {
        edible(f);
		drinkable(f);
        eaten(f);
    }

    rules {
        eat :: in(f, I) & edible(f) -> eaten(f);
		drink :: in(f, I) & drinkable(f) -> eaten(f);
    }

    constraints {
        eaten1 :: eaten(f) & in(f, I) -> fail();
        eaten2 :: eaten(f) & in(f, c) -> fail();
        eaten3 :: eaten(f) & on(f, s) -> fail();
        eaten4 :: eaten(f) & at(f, r) -> fail();
    }

    inform7 {
        type {
            kind :: "food";
            definition :: "food can be edible. food can be drinkable. food can be eaten.";
        }

        predicates {
            edible(f) :: "The {f} is edible";
			drinkable(f) :: "The {f} is drinkable";
            eaten(f) :: "The {f} is eaten";
        }

        commands {
            eat :: "eat {f}" :: "eating the {f}";
			drink :: "drink {f}" :: "drinking the {f}";
        }

		code :: """
            [Drinking liquid]
            The block drinking rule is not listed in any rulebook.
            Report an actor drinking carried thing (this is the report drinking rule):
                if the actor is the player:
                    say "You drink [the noun]. Not bad.";
                otherwise:
                    say "[The person asked] just drank [the noun].".
            After drinking the noun:
                now the noun is eaten;
            After eating the noun:
                now the noun is eaten;
        """;

    }
}
'''

In [None]:
#@title Update grammar files with props from new logic
write_grammar = False #@param {'type': 'boolean'}
if write_grammar:
    path = "/../content/grammars/"
    if not os.path.exists(path):
        os.makedirs(path)
    # Rules written out each time a source line starts with "##actions".
    action_rules = [f"{verb}:{verb} (o).\n"
                    for verb in ("wear", "steal", "wield", "drink", "eat",
                                 "hit", "hug", "use", "follow")]
    for i, grammar_file in enumerate(vars(grammar)["grammar_files"]):
        # One output file per source grammar: <theme><index>.twg
        with open(grammar_file) as f, \
                open(f"{path}{options.grammar.theme}{i}.twg", "w") as f2:
            for line in f.readlines():
                # Only the action rules are emitted; neither the "##actions"
                # header nor any other source line is copied across.
                if line.startswith("##actions"):
                    f2.writelines(action_rules)
# %cp /../content/grammars/ /../content/drive/MyDrive/data/light_data/grammars/ -r

print("Done.")

In [None]:
# Parse the logic definition string assembled in `twl` above.
logic = GameLogic.parse(twl)
options = GameOptions()
# Knowledge base built from the parsed logic and the grammar folder under DATA_PTH.
options.kb = KnowledgeBase(logic, f"{DATA_PTH}grammars/")
rngs = options.rngs
# Random generator used below for the coin flips made while building quests.
rng_quest = rngs['quest']

### Create games

In [None]:
#@title ### Make games current v3 affordance-based
# Form parameters and lookup tables consumed by the generation loop below.
M = GameMaker(options)
min_quest_len = 5 #@param {'type': 'integer'}
max_quest_len = 5 #@param {'type': 'integer'}

# `ix` selects a single row when gen_bulk is off; `maxix` bounds the bulk run
# (0 is replaced by the number of rows further down).
ix = 101 #@param {type:"slider", min:0, max:7500, step:1}
maxix = 0 #@param {'type': 'integer'}

g = "F" #@param ["M", "F"]
rewards = "balanced" #@param ["sparse", "balanced", "dense"]
gen_bulk = False #@param {'type': 'boolean'}
restrict_len = False #@param {'type': 'boolean'}
save = False #@param {'type': 'boolean'}
global_based = False #@param {'type': 'boolean'}
quest_based = True #@param {'type': 'boolean'}
alt_distractors = True #@param {'type': 'boolean'}
merged_dff = dfqq_quest_global_graph
qt = "w" #@param ["w", "f"]
# Keep only the rows for the selected gender and renumber them from 0.
merged_dff = merged_dff[dfqq_quest_global_graph.gender == g]
merged_dff = merged_dff.reset_index(inplace=False)
# Command verb -> fact name used as the quest's winning condition.
post_d = {"wear": "worn", "wield": "wielded",
          "eat": "eaten", "drink": "eaten"}
supporters = ["huggable", "hittable", "followable", "goable"]
# The two target verbs for this quest type ("w": wield/wear, "f": eat/drink).
if qt == "w":
    va, vb = "wield", "wear"
elif qt == "f":
    va, vb = "eat", "drink"
# Verb -> property added to an object that affords it ("o" for plain-object verbs).
pset = {'drink': "drinkable",'drop': "o",'eat': "edible",
        'follow': "followable", 'get': "o", 'give': "o",
        'go': "goable", 'hit': "hittable", 'hug': "huggable",
        'put': "o", 'remove': "o", 'steal': "o",
        'use': "o", 'wear': "wearable", 'wield': "wieldable"}
# Verb -> names of the two containers created for objects with that affordance.
container_d = {"wield": ['wooden chest', 'metal locker'],
                'wear': ['large closet', 'antique wardrobe'],
                'drink': ['icebox', 'cooler'],
                'eat': ['large pantry', 'cabinet']}
# Use pairs if wield/wear in quest, otherwise use global, else skip

# Work out which rows to process (`rnge`) and allocate the result store (`dct`).
n_rows = len(merged_dff)
# 0 means "all rows"; a larger request is clamped so the bulk loop can never
# index past the end of merged_dff.
if maxix == 0 or maxix > n_rows:
    maxix = n_rows
if gen_bulk:
    rnge = range(maxix)
    n_slots = maxix
else:
    # A single game: fall back to the last row when the slider overshoots.
    rnge = [n_rows - 1] if ix >= n_rows else [ix]
    # Slots are indexed by row number, so ix + 1 of them are needed.
    n_slots = ix + 1
dct = {key: [None] * n_slots
       for key in ("walkthrough", "objective", "quests")}
# Build one game per selected row: rooms come from the row's graph, target
# objects are hidden in closed containers, and quests/walkthroughs are derived
# from the two target verbs (va, vb).
if not (quest_based or global_based):
    # Without either flag there is no source for the affordance pairs.
    raise ValueError("Enable quest_based or global_based to choose the affordance pairs.")

for ix in rnge:
    if ix >= len(merged_dff):
        # Stop cleanly instead of letting .iloc raise IndexError.
        print(f"Max index of {len(merged_dff)-1} exceeded at index {ix}. Stopping.")
        break
    row = merged_dff.iloc[ix]
    pairs = row.pairs
    global_pairs = row.pairs_global
    # quest_based takes precedence when both flags are ticked.
    pairs_d = pairs if quest_based else global_pairs
    va_avail = any(va in vs for vs in pairs_d.values()) # any wield cmds
    vb_avail = any(vb in vs for vs in pairs_d.values()) # '' wear ''
    if not va_avail and not vb_avail:
        continue
    walkthrough = []
    quests = []
    va_list = []
    vb_list = []
    rooms = []
    room_descs = dict()
    M = GameMaker(options)

    # A float here is presumably NaN for a row without a graph - TODO confirm.
    if isinstance(row.graph_json, float):
        continue
    rix = 0
    for node, d in row.graph_json["nodes"].items():
        if d["room"]: # hacky replacements to prevent errors
            room = M.new_room(name = d["name"].replace("outside", "ouitside")) #, desc=d["desc"])
            room_descs[f"r_{rix}"] = d["desc"].replace("\"", "\'").replace("[", "").replace("]", "")
            rix += 1
            rooms.append(room)
    if not rooms:
        # Nothing to place the player or containers in.
        continue
    M.set_player(rooms[0])

    # Two candidate containers for objects affording `va` ...
    chest = M.new(type='c', name=container_d[va][0])
    locker = M.new(type='c', name=container_d[va][1])
    rooms[0].add(chest, locker)

    if rng_quest.rand() > 0.5:
        lholder = chest
    else:
        lholder = locker

    # ... and two for objects affording `vb`.
    closet = M.new(type='c', name=container_d[vb][0])
    wardrobe = M.new(type='c', name=container_d[vb][1])
    rooms[0].add(closet, wardrobe)

    # Every container starts closed (added once per container).
    for container in M.findall(type="c"):
        container.add_property("closed")
    if rng_quest.rand() > 0.5:
        holder = closet
    else:
        holder = wardrobe

    # Object name -> set of target properties it affords.
    tsd = dict()
    for o, vs in pairs_d.items():
        for verb in (va, vb):
            if verb in vs:
                tsd.setdefault(o, set()).add(pset[verb])

    for o, ts in tsd.items():
        old_o = o # may modify later but still want to look up w/ orig
        if any([t in supporters for t in ts]): # fixed in place
            t = "s" # will change to "o" if target affordance(s)
            if ts == {'goable'}: # e.g. don't treat a castle as portable
                continue
        else:
            if "edible" in ts or "drinkable" in ts:
                t = "f"
            else:
                t = "o"
        properties = [] # reset to prevent properties carrying over to supporters
        if not o[-1].isalpha():
            o = o[:-1]

        if pset[va] in ts and pset[vb] in ts: # choose 1 based on freq
            # to avoid conflicts
            if len(vod_df_qs[g][va][old_o]) > len(vod_df_qs[g][vb][old_o]):
                properties = [pset[va]]
            elif len(vod_df_qs[g][va][old_o]) < len(vod_df_qs[g][vb][old_o]):
                properties = [pset[vb]]
            else:
                if rng_quest.rand() > 0.5:
                    properties = [pset[va]]
                else:
                    properties = [pset[vb]]
        elif pset[va] in ts: # e.g. only has wieldable/wearable
            properties = [pset[va]]
            if va not in ["eat", "drink"]:
                t = "o" # e.g., wine can be wielded or consumed, prioritize target, esp. since consuming destroys
            else:
                t = "f"
        elif pset[vb] in ts:
            properties = [pset[vb]]
            if vb not in ["eat", "drink"]:
                t = "o"
            else:
                t = "f"
        else: # free to add non-target w/o fear of container/quest conflicts
            properties = [t for t in ts if len(t) > 1]
        o2 = M.new(type=t, name=o)
        if properties:
            for p in properties:
                o2.add_property(p)
        o2.infos.desc = o # default description is name - use affordances?

        for _, d in row.graph_json["nodes"].items():
            if o in d["name"].lower(): # use original description if available
                o2.infos.desc = d["desc"].replace("\"", "\'").strip()
                break
        in_holder = False
        in_lholder = False
        if o2.has_property(pset[va]):
            va_list.append(o2)
            if not o2 in lholder.content:
                lholder.add(o2)
                in_lholder = True
                if alt_distractors:
                    # Rejection-sample a name that only appears under `vb`.
                    # NOTE: loops forever if every `vb` object also appears under `va`.
                    distractor = random.choice(list(vod_df_qs[g][vb].keys()))
                    while distractor in vod_df_qs[g][va]:
                        distractor = random.choice(list(vod_df_qs[g][vb].keys()))
                    distractor = M.new(type=lholder.content[0].type, name=distractor)
                    distractor.add_property(pset[va])
                    lholder.add(distractor)
        elif o2.has_property(pset[vb]):
            vb_list.append(o2)
            if not o2 in holder.content:
                holder.add(o2)
                in_holder = True
                if alt_distractors:
                    # Same rejection sampling, mirrored for the other verb.
                    distractor = random.choice(list(vod_df_qs[g][va].keys()))
                    while distractor in vod_df_qs[g][vb]:
                        distractor = random.choice(list(vod_df_qs[g][va].keys()))
                    distractor = M.new(type=holder.content[0].type, name=distractor)
                    distractor.add_property(pset[vb])
                    holder.add(distractor)

    # create walkthrough, quest
    if len(holder.content) > 0:
        if vb_list:
            walkthrough.append(f"open {holder.name}")
            if rewards == "dense":
                quests.append(
                    Quest(win_events=[
                        Event(conditions={M.new_fact("open", holder)})
                    ])
                )
    if len(lholder.content) > 0:
        if va_list:
            walkthrough.append(f"open {lholder.name}")
            if rewards == "dense":
                quests.append(
                    Quest(win_events=[
                        Event(conditions={M.new_fact("open", lholder)})
                    ])
                )
    # balanced, dense only reward higher-level, not conditioning affordances
    if vb_list:
        for w in vb_list:
            if w in holder.content:
                walkthrough.append(f"take {w.name} from {holder.name}")
                if rewards in ["balanced", "dense"]:
                    quests.append(
                        Quest(win_events=[
                            Event(conditions={M.new_fact("in", w,
                                                        M.inventory)})
                        ])
                    )

                walkthrough.append(f"{vb} {w.name}")
                quests.append(
                    Quest(win_events=[
                        Event(conditions={M.new_fact(f"{post_d[vb]}", w)})
                    ])
                )
    if va_list:
        for w in va_list:
            if w in lholder.content:
                walkthrough.append(f"take {w.name} from {lholder.name}")
                if rewards in ["balanced", "dense"]:
                    quests.append(
                        Quest(win_events=[
                            Event(conditions={M.new_fact("in", w, M.inventory)})
                        ])
                    )

                walkthrough.append(f"{va} {w.name}")
                quests.append(
                    Quest(win_events=[
                        Event(conditions={M.new_fact(f"{post_d[va]}", w)})
                    ])
                )

    if restrict_len: # e.g. only quests of length 5 in min=max=5
        cond = min_quest_len <= len(quests) <= max_quest_len
    else:
        cond = len(quests) > 0
    if not cond: # no quests available (or not of the requested length)
        continue
    M.quests = quests
    if gen_bulk and ix > 1 and ix % 50 == 0: # log
        print(ix, walkthrough)
    if not gen_bulk:
        print(ix, g, walkthrough)

    if rewards not in ["balanced", "dense"]: # buggy on 'take'
        M.set_walkthrough(walkthrough)

    game = M.build()
    # Put the original room description in front of the built one.
    for rid, desc in room_descs.items():
        game.infos[rid].desc = desc + "\n" + game.infos[rid].desc
    game.metadata["walkthrough"] = walkthrough
    game.objective = ""
    if va_list:
        for w in va_list:
            game.objective += f"Find and {va} {w.name}. "
    if vb_list:
        for w in vb_list:
            game.objective += f"Find and {vb} {w.name}. "
    if save:
        # Same precedence as the pair selection above, so the folder name
        # always reflects the pairs that were actually used.
        ssuffix = "_qb" if quest_based else "_gb"
        ssuffix += "_nod"
        pth = f"/../content/{g.lower()}_quests_{qt}_{rewards}{ssuffix}/"
        os.makedirs(pth, exist_ok=True)
        try:
            M.compile(f"{pth}/test_{ix}")
        except Exception as err:
            # Best effort: keep generating, but report which game was lost.
            print(f"Could not compile game {ix}: {err!r}")

    dct["walkthrough"][ix] = walkthrough
    dct["quests"][ix] = quests
    dct["objective"][ix] = game.objective


In [None]:
def getpreferredencoding(do_setlocale = True):
    """Return "UTF-8" regardless of `do_setlocale` (the argument is ignored)."""
    return "UTF-8"

# Replace the standard-library function so every caller sees UTF-8.
locale.getpreferredencoding = getpreferredencoding

In [None]:
# !zip -r /../content/m_quests_f_balanced_qb_nod.zip /../content/m_quests_f_balanced_qb_nod/ &> /dev/null
# !zip -r /../content/f_quests_f_balanced_qb_nod.zip /../content/f_quests_f_balanced_qb_nod/ &> /dev/null

In [None]:
# %cp /../content/m_quests_f_balanced_qb_nod.zip {{DATA_PTH}}
# %cp /../content/f_quests_f_balanced_qb_nod.zip {{DATA_PTH}}

In [None]:
# Try out the most recently created GameMaker `M`.
# NOTE(review): presumably this starts an interactive play-through - confirm.
M.test()

In [None]:
#@title Load/save walkthroughs

gender = "f" #@param ["m", "f"]
basis = "_qb" #@param ["", "_qb", "_gb"]
distractors = "_nod" #@param ["", "_nod", "_d", "_df"]
mode = "w" #@param ["w", "f"]

save_walkthroughs = False #@param {'type': 'boolean'}
load_walkthroughs = True #@param {'type': 'boolean'}

# Persist the walkthrough/objective/quest lists collected in `dct`.
if save_walkthroughs:
    pd.DataFrame(dct).to_pickle(f"{DATA_PTH}quests/{gender}_quests_{mode}{basis}{distractors}_walkthroughs.pkl")

# Load both genders and drop the rows that contain missing values.
# NOTE: pickle files can execute code on load; only read files you created.
if load_walkthroughs:
    m_walkthroughs = pd.read_pickle(f"{DATA_PTH}quests/m_quests_{mode}{basis}{distractors}_walkthroughs.pkl")
    # `axis` must be passed by keyword; pandas 2.x rejects dropna(0).
    m_walkthroughs = m_walkthroughs.dropna(axis=0)
    f_walkthroughs = pd.read_pickle(f"{DATA_PTH}quests/f_quests_{mode}{basis}{distractors}_walkthroughs.pkl")
    f_walkthroughs = f_walkthroughs.dropna(axis=0)

### Agents

In [None]:
#@title Definitions

seed = 42 #@param {'type': 'integer'}

def statistic(x, y, axis):
    """Difference of means (x minus y) along `axis`; used by the permutation test."""
    mean_x = np.mean(x, axis=axis)
    mean_y = np.mean(y, axis=axis)
    return mean_x - mean_y

def get_res(a, b, alpha=0.01, alt='greater', seed=42, num_its=1000):
    """Compare samples `a` and `b` with a permutation test plus effect sizes.

    The test uses the mean difference (`statistic`) with
    permutation_type='samples', `num_its` resamples and `seed` as random state.

    Returns a dict with the observed difference ("diff"), the p-value ("pval"),
    whether it is below `alpha` ("sig"), Cohen's d ("d"), the common-language
    effect size ("cles", 0 for any `alt` other than 'greater'/'less') and `alt`.
    """
    outcome = scipy.stats.permutation_test(
        data=[a, b],
        statistic=statistic,
        n_resamples=num_its,
        permutation_type='samples',
        alternative=alt,
        random_state=seed,
    )
    cohen_d = pg.compute_effsize(a, b, eftype='cohen')

    # CLES is computed in the direction of the alternative hypothesis.
    cles_operands = {'greater': (a, b), 'less': (b, a)}.get(alt)
    if cles_operands is None:
        cles = 0
    else:
        cles = pg.compute_effsize(*cles_operands, eftype='cles')

    return {
        "diff": outcome.statistic,
        "pval": outcome.pvalue,
        "sig": outcome.pvalue < alpha,
        "d": cohen_d,
        "cles": cles,
        'alt': alt,
    }

def set_seed(seed):
    """Seed Python's `random`, NumPy's global generator and PyTorch with `seed`."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def get_split_files(pth, split_indices):
    """Return `{pth}test_{ix}.ulx` for each index whose file exists, in input order."""
    existing = []
    for ix in split_indices:
        candidate = f"{pth}test_{ix}.ulx"
        if os.path.isfile(candidate):
            existing.append(candidate)
    return existing

def get_split_indices(pth, g, split):
    """Read the lines of a split file as a list of strings.

    The file name is built from `g`, `split` and the notebook-level settings
    `qtype`, `rewards`, `basis` and `distractors`, which must be defined
    before this is called.
    """
    split_file = f"{pth}{g}_{split}_{qtype}_{rewards}{basis}{distractors}.txt"
    with open(split_file, "r") as handle:
        return handle.read().splitlines()

def get_combined(m_ulx, f_ulx):
    """Merge two lists of game files into one gender-balanced list.

    Files are grouped by the `test_<n>.` fragment of their path. The larger
    side is cut down to the number of distinct games on the smaller side,
    keeping the games that appear first. The result lists the kept `m_ulx`
    files followed by the kept `f_ulx` files, each in input order, so the
    output is the same on every run.

    Raises TypeError if a path contains no `test_<n>.` fragment.
    """
    def _group_by_game(paths):
        # Game key -> its files; dicts keep first-seen order.
        groups = {}
        for game_path in paths:
            key = re.search(r"test_[0-9]+\.", game_path)[0]
            groups.setdefault(key, []).append(game_path)
        return groups

    m_groups = _group_by_game(m_ulx)
    f_groups = _group_by_game(f_ulx)
    n_keep = min(len(m_groups), len(f_groups))

    combined = []
    for groups in (m_groups, f_groups):
        for key in list(groups)[:n_keep]:
            combined.extend(groups[key])
    return combined

# Same override as the one defined in an earlier cell of this notebook;
# repeating it here has no further effect.
def getpreferredencoding(do_setlocale = True):
    """Return "UTF-8" regardless of `do_setlocale` (the argument is ignored)."""
    return "UTF-8"

# Replace the standard-library function so every caller sees UTF-8.
locale.getpreferredencoding = getpreferredencoding

# modified via https://github.com/microsoft/TextWorld/blob/main/notebooks/Building%20a%20simple%20agent.ipynb
def play(agent, path, gamefiles = None, max_step=100, nb_episodes=10, verbose=True, seed=42):
    """Let `agent` play TextWorld games and collect per-episode statistics.

    Args:
        agent: object exposing `infos_to_request` and `act(obs, score, done, infos)`.
        path: a single game file, a directory of `.ulx` games, or a list of
            game files (used when `gamefiles` is not given).
        gamefiles: optional explicit list of game files; only `.ulx` entries
            are kept.
        max_step: step limit per episode passed to the environment.
        nb_episodes: number of episodes to play.
        verbose: print a label, one dot per episode and a summary line.
        seed: value passed to `set_seed` before playing.

    Returns:
        (moves per episode, normalized score per episode, commands per episode).

    Raises:
        ValueError: if no game file is left to play.
    """
    # For reproducibility
    set_seed(seed)
    infos_to_request = agent.infos_to_request
    infos_to_request.max_score = True  # Needed to normalize the scores.

    path_is_list = isinstance(path, list)
    # os.path.isdir/basename raise TypeError on a list, so only probe real paths.
    path_is_dir = (not path_is_list) and os.path.isdir(path)

    if not gamefiles:
        gamefiles = path if path_is_list else [path]
    if isinstance(gamefiles, list) and gamefiles and gamefiles[0] != path:
        # An explicit list of games: keep the playable .ulx files only.
        gamefiles = [g for g in gamefiles if g.endswith(".ulx")]
        if gamefiles:
            print(f"Using provided list of gamefiles e.g., {gamefiles[0]}.")
    elif path_is_dir: # *.z8
        gamefiles = glob(os.path.join(path, "*.ulx")) # should find .json on its own for metadata
    if not gamefiles:
        raise ValueError(f"No .ulx game files to play for path={path!r}.")

    env_id = textworld.gym.register_games(gamefiles,
                                          request_infos=infos_to_request,
                                          max_episode_steps=max_step)
    env = gym.make(env_id)  # Create a Gym environment to play the text game.
    if verbose:
        if path_is_list:
            print(f"{len(gamefiles)} games", end="")
        elif path_is_dir:
            print(os.path.dirname(path), end="")
        else:
            print(os.path.basename(path), end="")

    # Collect some statistics: nb_steps, final reward.
    avg_moves, avg_scores, avg_norm_scores = [], [], []
    trajs = []
    try:
        for no_episode in range(nb_episodes):
            obs, infos = env.reset()  # Start new episode.
            score = 0
            done = False
            nb_moves = 0
            traj = []
            while not done:
                command = agent.act(obs, score, done, infos)
                traj.append(command)
                obs, score, done, infos = env.step(command)
                nb_moves += 1

            agent.act(obs, score, done, infos)  # Let the agent know the game is done.

            if verbose:
                print(".", end="")
            avg_moves.append(nb_moves)
            avg_scores.append(score)
            avg_norm_scores.append(score / infos["max_score"])
            trajs.append(traj)
    finally:
        # Release the environment even if the agent or the game raises.
        env.close()
    if verbose:
        if path_is_dir or path_is_list:
            # Several games may have different maximum scores: report normalized.
            msg = "  \tavg. steps: {:5.1f}; avg. normalized score: {:4.1f} / {}."
            print(msg.format(np.mean(avg_moves), np.mean(avg_norm_scores), 1))
        else:
            msg = "  \tavg. steps: {:5.1f}; avg. score: {:4.1f} / {}."
            print(msg.format(np.mean(avg_moves), np.mean(avg_scores), infos["max_score"]))
    return avg_moves, avg_norm_scores, trajs


#### Agent (Neural)
- Adapted from the TextWorld notebook [Building a simple agent](https://github.com/microsoft/TextWorld/blob/main/notebooks/Building%20a%20simple%20agent.ipynb)

In [None]:
#@title #### Classes, Imports
# modified via: https://github.com/microsoft/TextWorld/blob/main/notebooks/Building%20a%20simple%20agent.ipynb
import re
from typing import List, Mapping, Any, Optional
from collections import defaultdict
# Silence UserWarning and DeprecationWarning messages for the rest of the session.
warnings.filterwarnings(category=UserWarning,
                                            action='ignore')
warnings.filterwarnings(category=DeprecationWarning,
                                            action='ignore')
import numpy as np

import textworld
import textworld.gym
from textworld import EnvInfos
import torchtext
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CommandScorer(nn.Module):
    """Recurrent model that scores candidate commands and estimates a state value.

    The observation and each command are embedded with a shared embedding and
    encoded by separate GRUs. A third GRU carries `state_hidden` from one call
    to the next until `reset_hidden` is called.
    """

    def __init__(self, train_seed, input_size, hidden_size):
        super().__init__()
        # Seed first so that parameter initialisation is reproducible.
        set_seed(train_seed)
        self.embedding = nn.Embedding(input_size, hidden_size, device=device)
        self.encoder_gru = nn.GRU(hidden_size, hidden_size, device=device)
        self.cmd_encoder_gru = nn.GRU(hidden_size, hidden_size, device=device)
        self.state_gru = nn.GRU(hidden_size, hidden_size, device=device)
        self.hidden_size = hidden_size
        self.state_hidden = torch.zeros(1, 1, hidden_size, device=device)
        self.critic = nn.Linear(hidden_size, 1, device=device)
        # Scores one (state, command) pair, hence twice the hidden size.
        self.att_cmd = nn.Linear(hidden_size * 2, 1, device=device)

    def forward(self, obs, commands, **kwargs):
        """Return (scores, sampled command index, value).

        `obs` and `commands` are token-id tensors; dimension 1 of `obs` is
        read as the batch size and dimension 1 of `commands` as the number of
        commands. Extra keyword arguments are accepted and ignored.
        """
        n_batch = obs.size(1)
        n_cmds = commands.size(1)

        # Summarise the observation and advance the recurrent state.
        _, obs_summary = self.encoder_gru(self.embedding(obs))
        state_output, state_hidden = self.state_gru(obs_summary, self.state_hidden)
        self.state_hidden = state_hidden
        value = self.critic(state_output)

        # Encode every command; keep the final hidden state of each.
        _, cmd_summaries = self.cmd_encoder_gru(self.embedding(commands))  # 1 x cmds x hidden

        # Pair the same state with every command, and the same commands with
        # every batch entry: both become 1 x batch x cmds x hidden.
        state_per_cmd = torch.stack([state_hidden for _ in range(n_cmds)], 2)
        cmds_per_batch = torch.stack([cmd_summaries for _ in range(n_batch)], 1)
        selector_input = torch.cat([state_per_cmd, cmds_per_batch], dim=-1)

        # One non-negative score per command.
        scores = F.relu(self.att_cmd(selector_input)).squeeze(-1)  # 1 x batch x cmds

        # Sample a command from the softmax over the scores.
        probs = F.softmax(scores, dim=2)
        index = probs[0].multinomial(num_samples=1).unsqueeze(0)  # 1 x batch x 1
        return scores, index, value

    def reset_hidden(self, batch_size):
        """Replace the carried state with zeros for `batch_size` entries."""
        self.state_hidden = torch.zeros(1, batch_size, self.hidden_size, device=device)

class NeuralAgent:
    """ Simple Neural Agent for playing TextWorld games.

    The agent keeps a growing vocabulary (capped at MAX_VOCAB_SIZE), scores
    the admissible commands with a `CommandScorer` and, in training mode,
    updates the model every UPDATE_FREQUENCY calls to `act` using discounted
    returns (factor GAMMA).
    """
    MAX_VOCAB_SIZE = 1000
    UPDATE_FREQUENCY = 10
    LOG_FREQUENCY = 1000
    GAMMA = 0.9

    def __init__(self, train_seed=42) -> None:
        self._initialized = False
        self._epsiode_has_started = False
        self.id2word = ["<PAD>", "<UNK>"]
        self.word2id = {w: i for i, w in enumerate(self.id2word)}

        self.model = CommandScorer(train_seed=train_seed, input_size=self.MAX_VOCAB_SIZE, hidden_size=128)
        self.optimizer = optim.Adam(self.model.parameters(), 0.00003)

        self.results = []
        # The agent starts in training mode; initialise the training
        # bookkeeping too so that act() works without an explicit train() call.
        self.train()

    def train(self):
        """Switch to training mode and reset all training bookkeeping."""
        self.model.train(True)
        self.mode = "train"
        self.stats = {"max": defaultdict(list), "mean": defaultdict(list)}
        self.transitions = []
        self.model.reset_hidden(1)
        self.last_score = 0
        self.no_train_step = 0

    def test(self):
        """Switch to evaluation mode: no learning, fresh recurrent state."""
        self.mode = "test"
        self.model.reset_hidden(1)
        self.model.eval()

    @property
    def infos_to_request(self) -> EnvInfos:
        """Extra information the environment must provide at each step."""
        return EnvInfos(description=True, inventory=True, admissible_commands=True,
                        won=True, lost=True)

    def _get_word_id(self, word):
        """Return the id of `word`, adding it to the vocabulary while there is room."""
        if word not in self.word2id:
            if len(self.word2id) >= self.MAX_VOCAB_SIZE:
                return self.word2id["<UNK>"]

            self.id2word.append(word)
            self.word2id[word] = len(self.word2id)

        return self.word2id[word]

    def _tokenize(self, text):
        """Split `text` on whitespace and map each word to its id."""
        # Simple tokenizer: everything except letters, digits, hyphens and
        # spaces is replaced by a space.
        text = re.sub(r"[^a-zA-Z0-9\- ]", " ", text)
        word_ids = list(map(self._get_word_id, text.split()))
        return word_ids

    def _process(self, texts):
        """Tokenize `texts` and return a padded Seq x Batch tensor of word ids."""
        texts = list(map(self._tokenize, texts))
        max_len = max(len(l) for l in texts)
        padded = np.ones((len(texts), max_len)) * self.word2id["<PAD>"]

        for i, text in enumerate(texts):
            padded[i, :len(text)] = text

        padded_tensor = torch.from_numpy(padded).type(torch.long).to(device)
        padded_tensor = padded_tensor.permute(1, 0) # Batch x Seq => Seq x Batch
        return padded_tensor

    def _discount_rewards(self, last_values):
        """Compute discounted returns and advantages for the stored transitions.

        Works backwards from `last_values` (the value predicted for the
        current state) and returns both lists in chronological order.
        """
        returns, advantages = [], []
        R = last_values.data
        for t in reversed(range(len(self.transitions))):
            rewards, _, _, values = self.transitions[t]
            R = rewards + self.GAMMA * R
            adv = R - values
            returns.append(R)
            advantages.append(adv)

        return returns[::-1], advantages[::-1]

    def act(self, obs: str, score: int, done: bool, infos: Mapping[str, Any]) -> Optional[str]:
        """Choose the next command and, in training mode, learn from the last one.

        Args:
            obs: feedback text from the previous command.
            score: current game score.
            done: whether the episode has ended.
            infos: must contain "description", "inventory" and
                "admissible_commands"; training mode also reads "won" and "lost".

        Returns:
            One of `infos["admissible_commands"]`.
        """

        # Build agent's observation: feedback + look + inventory.
        input_ = "{}\n{}\n{}".format(obs, infos["description"], infos["inventory"])

        # Tokenize and pad the input and the commands to chose from.
        input_tensor = self._process([input_])
        commands_tensor = self._process(infos["admissible_commands"])

        # Get our next action and value prediction.
        if self.mode == "test":
            # No learning happens in test mode, so skip building the autograd
            # graph (it would otherwise grow through the carried hidden state).
            with torch.no_grad():
                outputs, indexes, values = self.model(input_tensor, commands_tensor)
        else:
            outputs, indexes, values = self.model(input_tensor, commands_tensor)
        action = infos["admissible_commands"][indexes[0]]

        if self.mode == "test":
            if done:
                self.model.reset_hidden(1)
            return action

        self.no_train_step += 1

        if self.transitions:
            reward = score - self.last_score  # Reward is the gain/loss in score.
            self.last_score = score
            if infos["won"]:
                reward += 100
            if infos["lost"]:
                reward -= 100

            self.transitions[-1][0] = reward  # Update reward information.

        self.stats["max"]["score"].append(score)
        if self.no_train_step % self.UPDATE_FREQUENCY == 0:
            # Update model
            returns, advantages = self._discount_rewards(values)

            loss = 0
            for transition, ret, advantage in zip(self.transitions, returns, advantages):
                reward, indexes_, outputs_, values_ = transition

                advantage        = advantage.detach() # Block gradients flow here.
                probs            = F.softmax(outputs_, dim=2)
                log_probs        = torch.log(probs)
                log_action_probs = log_probs.gather(2, indexes_)
                policy_loss      = (-log_action_probs * advantage).sum()
                value_loss       = (.5 * (values_ - ret) ** 2.).sum()
                entropy     = (-probs * log_probs).sum()
                # The entropy term is subtracted to encourage exploration.
                loss += policy_loss + 0.5 * value_loss - 0.1 * entropy

                self.stats["mean"]["reward"].append(reward)
                self.stats["mean"]["policy"].append(policy_loss.item())
                self.stats["mean"]["value"].append(value_loss.item())
                self.stats["mean"]["entropy"].append(entropy.item())
                self.stats["mean"]["confidence"].append(torch.exp(log_action_probs).item())

            if self.no_train_step % self.LOG_FREQUENCY == 0:
                msg = "{:6d}. ".format(self.no_train_step)
                msg += "  ".join("{}: {: 3.3f}".format(k, np.mean(v)) for k, v in self.stats["mean"].items())
                msg += "  " + "  ".join("{}: {:2d}".format(k, np.max(v)) for k, v in self.stats["max"].items())
                msg += "  vocab: {:3d}".format(len(self.id2word))
                print(msg)
                # Archive this window's statistics and start a new window.
                self.results.append(self.stats)
                self.stats = {"max": defaultdict(list), "mean": defaultdict(list)}

            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 40)
            self.optimizer.step()
            self.optimizer.zero_grad()

            self.transitions = []
            self.model.reset_hidden(1)
        else:
            # Keep information about transitions for Truncated Backpropagation Through Time.
            self.transitions.append([None, indexes, outputs, values])  # Reward will be set on the next call

        if done:
            self.last_score = 0  # Will be starting a new episode. Reset the last score.

        return action

### Create splits

In [None]:
#@title Important folder paths before split, train, eval

basis = "_qb" #@param ["", "_qb", "_gb"]
qtype = "w" #@param ["w", "f"]
rewards = "balanced" #@param ["sparse", "balanced", "dense"]
distractors = "_nod" #@param ["", "_nod", "_d", "_df"]
local_folder_pths = True #@param {'type': 'boolean'}
# Split folders live either on the local disk or under DATA_PTH.
fprefix = "/../content/" if local_folder_pths else f"{DATA_PTH}quests/"

# Tail shared by every per-gender folder name.
quest_tag = f"{qtype}_{rewards}{basis}{distractors}"

m_train_folder = f"{fprefix}m_train_{quest_tag}/"
m_dev_folder = f"{fprefix}m_dev_{quest_tag}/"
m_test_folder = f"{fprefix}m_test_{quest_tag}/"
f_train_folder = f"{fprefix}f_train_{quest_tag}/"
f_dev_folder = f"{fprefix}f_dev_{quest_tag}/"
f_test_folder = f"{fprefix}f_test_{quest_tag}/"
# The combined folders are not keyed by quest type or reward scheme.
all_dev_folder = f"{fprefix}all_dev{basis}{distractors}/"
all_test_folder = f"{fprefix}all_test{basis}{distractors}/"
m_checkpoints_folder = f"{DATA_PTH}m_checkpoints/"
f_checkpoints_folder = f"{DATA_PTH}f_checkpoints/"
# Folders the generated games are written to and unzipped into.
m_local_pth = f"/../content/m_quests_{quest_tag}/"
f_local_pth = f"/../content/f_quests_{quest_tag}/"

In [None]:
%cp {{DATA_PTH}}m_quests_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.zip /../content/
%cp {{DATA_PTH}}f_quests_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.zip /../content/

In [None]:
# File names of the quest archives for the selected variant.
m_quest_zip = f"m_quests_{qtype}_{rewards}{basis}{distractors}.zip"
f_quest_zip = f"f_quests_{qtype}_{rewards}{basis}{distractors}.zip"
# Extract both archives with "/../" as the target directory; all unzip output,
# including error messages, is discarded by the redirect.
# NOTE(review): the archive names are relative, so this presumably relies on
# the working directory being the folder the zips were copied into, and on the
# archives storing paths that start with content/... (verify both).
!unzip "$m_quest_zip" -d /../ &> /dev/null
!unzip "$f_quest_zip" -d /../ &> /dev/null

In [None]:
#@title Create splits for each gender

# For each gender, collect the quest indices found in the unzipped local
# folder and split them 70/15/15 into train/dev/test index lists
# (m_train, m_dev, m_test and f_train, f_dev, f_test).
for g in ["m", "f"]:
    g_local_pth = m_local_pth if g == "m" else f_local_pth

    # Game files of three extensions are globbed; the same quest index can
    # therefore appear several times and is de-duplicated below.
    games = glob(os.path.join(g_local_pth, "*.json")) + \
            glob(os.path.join(g_local_pth, "*.ulx")) + \
            glob(os.path.join(g_local_pth, "*.ni"))
    # Sort the unique indices numerically. Iteration order of a set of strings
    # changes between interpreter sessions (hash randomization) and
    # train_test_split shuffles relative to the input order, so without the
    # sort the fixed random_state would not give the same split after a
    # kernel restart.
    game_idxs = sorted({re.search(r"test_[0-9]+", game_pth)[0].split("_")[-1]
                        for game_pth in games},
                       key=int)
    # 70% train, then the remaining 30% is halved into dev and test.
    train, test = train_test_split(game_idxs, shuffle = True,
                                   random_state = 42, test_size=0.3)
    dev, test = train_test_split(test, shuffle = True,
                                 random_state = 42, test_size=0.50)
    if g == "m":
        m_train, m_dev, m_test = train, dev, test
    elif g == "f":
        f_train, f_dev, f_test = train, dev, test


# Split sizes: female train/dev/test, then male train/dev/test.
display(len(f_train), len(f_dev), len(f_test),
        len(m_train), len(m_dev), len(m_test))

In [None]:
#@title Write gamefiles to text file and save to drive
save_to_drive = False #@param {'type': 'boolean'}

for g in ["m", "f"]:
    g_train = m_train if g == "m" else f_train
    g_dev = m_dev if g == "m" else f_dev
    g_test = m_test if g == "m" else f_test

    with open(f"{g}_train_{qtype}_{rewards}{basis}{distractors}.txt", "w") as f:
        for i in g_train:
            f.write(f"{i}\n")
    with open(f"{g}_dev_{qtype}_{rewards}{basis}{distractors}.txt", "w") as f:
        for i in g_dev:
            f.write(f"{i}\n")
    with open(f"{g}_test_{qtype}_{rewards}{basis}{distractors}.txt", "w") as f:
        for i in g_test:
            f.write(f"{i}\n")
if save_to_drive:
    %cp m_train_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt {{DATA_PTH}}
    %cp m_dev_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt {{DATA_PTH}}
    %cp m_test_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt {{DATA_PTH}}
    %cp f_train_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt {{DATA_PTH}}
    %cp f_dev_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt {{DATA_PTH}}
    %cp f_test_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt {{DATA_PTH}}

### Train

In [None]:
# Fixed list of ten training seeds, kept literal so runs can be repeated.
# The commented-out line below shows how such a list can be drawn at random.
# train_seeds = [random.randint(1, 10000) for _ in range(10)]
train_seeds = [858, 4386, 1429, 6673, 4368, 7131, 1719, 834, 2968, 7897]

Copy the train index lists to this runtime. They are consulted when building the file lists, which reference the game files unzipped into local folders after the zips were copied here from Drive.

In [None]:
# Locations under DATA_PTH of the male, female and combined training list
# files for the selected variant.
m_train_txt = f"{DATA_PTH}m_train_{qtype}_{rewards}{basis}{distractors}.txt"
f_train_txt = f"{DATA_PTH}f_train_{qtype}_{rewards}{basis}{distractors}.txt"
all_train_txt = f"{DATA_PTH}all_train_{qtype}_{rewards}{basis}{distractors}.txt"

In [None]:
# Copy the three training list files to the local runtime folder.
%cp {{m_train_txt}} /../content
%cp {{f_train_txt}} /../content
%cp {{all_train_txt}} /../content

In [None]:
#@title Train male, female neural agents w/ multiple initial seeds
# Trains one NeuralAgent per seed on the male games, the female games, or
# (use_all) an equal mix of both, and optionally saves checkpoints/results.
ckpt = 100 # param {'type': 'integer'}
k = 5 #@param {'type': 'integer'} # multiplier of training_games[:n] for nb. episodes
continue_from_chkpt = False #@param {'type': 'boolean'}
save = False #@param {'type': 'boolean'}
save_res = False #@param {'type': 'boolean'}
limit_n = 50 #@param {'type': 'integer'}
use_limit = False #@param {'type': 'boolean'}
use_all = True #@param {'type': 'boolean'}
create_train_ulxs = True #@param {'type': 'boolean'}
prefix = "/../content/"

# Resolve the per-gender lists of training game files.
if create_train_ulxs:
    # Build them from the split index files: first the local copies, then
    # the ones stored under DATA_PTH if the local ones are missing.
    try:
        m_train_ulx = get_split_files(m_local_pth,
                                    get_split_indices("/../content/",
                                                        "m", "train"))
        f_train_ulx = get_split_files(f_local_pth,
                                    get_split_indices("/../content/",
                                                        "f", "train"))
    except FileNotFoundError:
        split_pref = f"{DATA_PTH}"
        m_train_ulx = get_split_files(m_local_pth,
                                    get_split_indices(split_pref,
                                                        "m", "train"))
        f_train_ulx = get_split_files(f_local_pth,
                                    get_split_indices(split_pref,
                                                        "f", "train"))
else:
    # Reuse lists already present in the kernel; read them from disk only
    # when f_train_ulx is undefined or empty. Only those two conditions are
    # caught so that unrelated errors are not hidden.
    try:
        f_train_ulx[0]
    except (NameError, IndexError):
        # NOTE(review): the male list is read from the split file while the
        # female list is read from a saved "_ulx" file with a fixed name;
        # confirm this asymmetry is intended.
        with open(f"{DATA_PTH}m_train_{qtype}_{rewards}{basis}{distractors}.txt", "r") as f:
            m_train_ulx = f.read().splitlines()
        with open(f"{DATA_PTH}train_res/f_train_qb_nod_ulx.txt", "r") as f:
            # f"{DATA_PTH}f_train_{qtype}_{rewards}{basis}{distractors}.txt", "r")
            f_train_ulx = f.read().splitlines()

train_mode = "w_balanced_qb_nod" #@param ["w_balanced", "wear_balanced_qb_nod", "wield_balanced_qb_nod", "w_balanced_qb_nod", "w_balanced_qb_d", "w_balanced_gb_nod", "w_balanced_gb_d", "w_balanced_gb_df", "f_balanced_qb_nod"]
# The "wear"/"wield" modes train on the games selected by the corresponding
# index lists; every other mode uses the full per-gender training lists.
if "wear" in train_mode:
    m_train_files = get_split_files(m_local_pth, m_train_wear_ixs)
    f_train_files = get_split_files(f_local_pth, f_train_wear_ixs)
elif "wield" in train_mode:
    m_train_files = get_split_files(m_local_pth, m_train_wield_ixs)
    f_train_files = get_split_files(f_local_pth, f_train_wield_ixs)
else:
    m_train_files = m_train_ulx
    f_train_files = f_train_ulx

# n = number of training games: the size of the smaller gender list, so both
# genders train on equally many games, optionally capped at limit_n.
if use_limit:
    n = min(limit_n, min(len(m_train_files), len(f_train_files)))
else:
    n = min(len(m_train_files), len(f_train_files))

# Checked only after n has been computed: the check used to run before n was
# assigned and raised NameError on a fresh kernel.
if continue_from_chkpt:
    assert n - ckpt > 0, f"n ({n}) must be larger than ckpt ({ckpt}) to continue from a checkpoint"

single_seed  = False #@param {'type': 'boolean'}
seed_from_list = "858" #@param train_seeds = [858, 4386, 1429, 6673, 4368, 7131, 1719, 834, 2968, 7897]
seed_from_list = int(seed_from_list)
seed_from_list = [seed_from_list]

# One entry per trained group, each a list of agent.results (one per seed).
g_results = []

for ggg in ["M", "F"]:
    results = []
    # NOTE(review): pth is built from the loop gender before ggg is switched
    # to "All" below, so the combined run uses the "m_..." path; confirm
    # that this is what play() expects.
    if continue_from_chkpt:
        pth = f"{prefix}{ggg.lower()}_training_games_{ckpt}_{n}{basis}{distractors}/"
    else:
        pth = f"{prefix}{ggg.lower()}_training_games_{train_mode}_{n}{basis}{distractors}/"
    if not use_all:
        if ggg == "M":
            training_games = m_train_files
        else:
            training_games = f_train_files
        training_games = training_games[:n]
    else:
        ggg = "All"
        # For parity with m/f, will want to cap at largest (m or f)
        # For all, use half of each to reach cap
        # e.g. cap at 240 b/c one gender only has 240,
        # use 120 m, 120 f
        j = n//2
        training_games = m_train_ulx[:j] + f_train_ulx[:j]
    seed_range = train_seeds if not single_seed else seed_from_list
    overall_res = {"m": {ts: [] for ts in seed_range},
                "f": {ts: [] for ts in seed_range}}
    # The combined run records its results under "all"; without this entry
    # the per-seed assignment below raised KeyError whenever use_all was set.
    overall_res.setdefault(ggg.lower(), {ts: [] for ts in seed_range})
    for train_seed in seed_range: # train_seeds:
        overall_res[ggg.lower()][train_seed] = []
        train_seed = int(train_seed)
        set_seed(train_seed)
        agent = NeuralAgent(train_seed=train_seed)
        if continue_from_chkpt:
            # NOTE(review): the load_state_dict call below is commented out,
            # so no checkpoint weights are actually restored here.
            print(f"Loading from checkpoint trained on first {ckpt} games.")
            # agent.model.load_state_dict(torch.load(f'{m_checkpoints_folder}{ggg.lower()}_agent_{train_mode}_{train_seed}_{ckpt}_{k}'))
        print(f"Training seed {train_seed} on {n} {ggg} games")
        agent.train()
        starttime = time()
        nb_episodes = n * k
        print(nb_episodes)
        avg_moves, avg_norm_scores, trajs = play(agent, path=pth, gamefiles=training_games, nb_episodes=nb_episodes,
             verbose=True, seed=train_seed) # n * k episodes over n games
        results.append(agent.results)
        print("Trained in {:.2f} secs".format(time() - starttime))
        if save:
            # Save the whole agent object, its model weights and its
            # per-episode results.
            spth = f'{DATA_PTH}{ggg.lower()}_checkpoints/'
            if not os.path.exists(spth):
                os.makedirs(spth)
            torch.save(agent, f"{spth}{ggg.lower()}_agent_{train_mode}_{train_seed}_{n}_{k}.pt")
            torch.save(agent.model.state_dict(), f"{spth}{ggg.lower()}_agent_{train_mode}_{train_seed}_{n}_{k}")
            res_df = pd.DataFrame(agent.results).T
            res_df.to_pickle(f"{DATA_PTH}{ggg.lower()}_train_{train_mode}_res_{n}_{k}_{train_seed}.pkl")
            print(f"Saved {ggg} checkpoint with seed {train_seed}")
        if save_res:
            print("Saving results.")
            overall_res[ggg.lower()][train_seed] = [avg_moves, avg_norm_scores, trajs]
            pd.DataFrame(overall_res).to_pickle(f"{DATA_PTH}{ggg.lower()}_train_{train_mode}_{train_seed}_{n}_{k}_res.pkl")
    g_results.append(results)
    # The combined run covers both genders at once, so one pass is enough.
    if use_all:
        break

In [None]:
# overall_res = pd.read_pickle(f"{DATA_PTH}train_res/f_train_w_balanced_qb_nod_6673_240_5_res.pkl")

### Combine genders for dev/test splits

In [None]:
# Build the per-gender dev and test game lists from the split index files:
# the local copies are tried first; if one is missing, all four lists are
# rebuilt from the index files under DATA_PTH.
# NOTE(review): get_split_indices / get_split_files are defined elsewhere;
# presumably the former reads a split's index file and the latter maps the
# indices to game files in the given folder (verify).
try:
    m_dev_ulx = get_split_files(m_local_pth, get_split_indices("/../content/", "m", "dev"))
    f_dev_ulx = get_split_files(f_local_pth, get_split_indices("/../content/", "f", "dev"))
    m_test_ulx = get_split_files(m_local_pth, get_split_indices("/../content/", "m", "test"))
    f_test_ulx = get_split_files(f_local_pth, get_split_indices("/../content/", "f", "test"))
except FileNotFoundError as e:
    #print(e)
    split_pref = f"{DATA_PTH}"
    m_dev_ulx = get_split_files(m_local_pth, get_split_indices(f"{split_pref}", "m", "dev"))
    f_dev_ulx = get_split_files(f_local_pth, get_split_indices(f"{split_pref}", "f", "dev"))
    m_test_ulx = get_split_files(m_local_pth, get_split_indices(f"{split_pref}", "m", "test"))
    f_test_ulx = get_split_files(f_local_pth, get_split_indices(f"{split_pref}", "f", "test"))

In [None]:
# Build the per-gender training game lists from the split index files: the
# local copies are tried first, then the ones under DATA_PTH.
try:
    m_train_ulx = get_split_files(m_local_pth,
                                get_split_indices("/../content/",
                                                    "m", "train"))
    f_train_ulx = get_split_files(f_local_pth,
                                get_split_indices("/../content/",
                                                    "f", "train"))
except FileNotFoundError as e:
    print(e)
    split_pref = f"{DATA_PTH}"
    # No inner try/except: if the fallback fails as well, the original error
    # (which names the missing file) propagates instead of being replaced by
    # an empty Exception.
    m_train_ulx = get_split_files(m_local_pth,
                                get_split_indices(split_pref,
                                                    "m", "train"))
    f_train_ulx = get_split_files(f_local_pth,
                                get_split_indices(split_pref,
                                                    "f", "train"))

In [None]:
# Combine the male and female game lists of each evaluation split.
# NOTE(review): get_combined is defined elsewhere; the evaluation cell expects
# combined entries to carry "_m." / "_f." markers (verify against the helper).
all_dev = get_combined(m_dev_ulx, f_dev_ulx)
all_test = get_combined(m_test_ulx, f_test_ulx)

In [None]:
# Combine the male and female training game lists (get_combined is defined
# elsewhere in the notebook).
all_train = get_combined(m_train_ulx, f_train_ulx)

In [None]:
#@title save all eval data

# Write the combined dev and test game lists to the working directory, one
# entry per line.
with open(f"all_dev_{qtype}_{rewards}{basis}{distractors}.txt", "w") as f:
    f.write("".join(line + '\n' for line in all_dev))
with open(f"all_test_{qtype}_{rewards}{basis}{distractors}.txt", "w") as f:
    f.write("".join(line + '\n' for line in all_test))

In [None]:
#@title save all train data

# Write the combined training game list to the working directory, one entry
# per line.
with open(f"all_train_{qtype}_{rewards}{basis}{distractors}.txt", "w") as f:
    f.write("".join(line + '\n' for line in all_train))

In [None]:
# Write the female training game-file list to the working directory, one
# entry per line.
with open(f"f_train_{qtype}_{rewards}{basis}{distractors}_ulx.txt", "w") as f:
    f.write("".join(line + '\n' for line in f_train_ulx))

In [None]:
# Copy the female training game-file list to Drive. Note that the destination
# name omits the quest type and reward scheme that the source name contains.
%cp /../content/f_train_{{qtype}}_{{rewards}}{{basis}}{{distractors}}_ulx.txt /../content/drive/MyDrive/data/light_data/f_train{{basis}}{{distractors}}_ulx.txt

In [None]:
#@title load all train data

try:
    with open(f"{DATA_PTH}all_train_{qtype}_{rewards}{basis}{distractors}.txt", "r") as f:
        all_train = f.read().splitlines()
except:
    %cp {{DATA_PTH}}all_train_{qtype}_{rewards}{{basis}}{{distractors}}.txt /../content/all_train_{qtype}_{rewards}{{basis}}{{distractors}}.txt
    with open(f"{DATA_PTH}all_train_{qtype}_{rewards}{basis}{distractors}.txt", "r") as f:
        all_train = f.read().splitlines()

In [None]:
#@title load all eval data

# Read the combined dev and test game lists from DATA_PTH, one entry per line.
with open(f"{DATA_PTH}all_dev_{qtype}_{rewards}{basis}{distractors}.txt", "r") as f:
    all_dev = f.read().splitlines()

with open(f"{DATA_PTH}all_test_{qtype}_{rewards}{basis}{distractors}.txt", "r") as f:
    all_test = f.read().splitlines()

In [None]:
# Copy the per-gender training lists to Drive. Note that the destination
# names omit the quest type and reward scheme that the source names contain.
%cp /../content/m_train_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt /../content/drive/MyDrive/data/light_data/m_train{{basis}}{{distractors}}.txt
%cp /../content/f_train_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt /../content/drive/MyDrive/data/light_data/f_train{{basis}}{{distractors}}.txt

In [None]:
# Copy the combined dev and test lists to Drive under the same file names.
%cp /../content/all_dev_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt /../content/drive/MyDrive/data/light_data/all_dev_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt
%cp /../content/all_test_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt /../content/drive/MyDrive/data/light_data/all_test_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt

In [None]:
# Copy the combined training list to Drive under the same file name.
%cp /../content/all_train_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt /../content/drive/MyDrive/data/light_data/all_train_{{qtype}}_{{rewards}}{{basis}}{{distractors}}.txt

### Evaluate

In [None]:
#@title Evaluate male, female on combined or swapped data
# Plays saved agents on dev/test games, k episodes per game. Either a single
# seed/split is evaluated into plain lists (bulk off), or every seed is
# evaluated on both splits and the results are written to CSV (bulk on).

k = 10 #@param {'type': 'integer'}
# game_idx = 105 #@param {type:"slider", min:0, max:1000, step:1}
load_seed = 858 #@param train_seeds = [858, 4386, 1429, 6673, 4368, 7131, 1719, 834, 2968, 7897]
load_k = 5 #@param
load_n = 240 #@param
local = False #@param {'type': 'boolean'}
local_eval = False #@param {'type': 'boolean'}
bulk = True #@param {'type': 'boolean'}
save = True #@param {'type': 'boolean'}
eval_all = True #@param {'type': 'boolean'}
ext = "ulx" #@param ["ulx"]
split = "dev" #@param ["dev_m", "dev_f", "test_m", "test_f", "dev", "test"]
single_seed  = False #@param {'type': 'boolean'}
seed_from_list = "4386" #@param train_seeds = [858, 4386, 1429, 6673, 4368, 7131, 1719, 834, 2968, 7897]
seed_from_list = [int(seed_from_list)]
load_seed = int(load_seed)
train_seeds = [858, 4386, 1429, 6673, 4368, 7131, 1719, 834, 2968, 7897]

# Folder on Drive that receives the result CSVs.
eval_out_dir = "/../content/drive/MyDrive/data/light_data/"


def _load_agent(kind, seed):
    """Load the saved agent of group `kind` ("m", "f" or "all") for `seed`.

    The file is read from the local runtime folder when `local` is set and
    from the group's checkpoint folder under DATA_PTH otherwise.
    """
    fname = f"{kind}_agent_{qtype}_{rewards}{basis}{distractors}_{seed}_{load_n}_{load_k}.pt"
    folder = "/../content/" if local else f"{DATA_PTH}{kind}_checkpoints/"
    return torch.load(f"{folder}{fname}", map_location=device)


def _strip_gender_marker(game):
    """Map an entry of a combined list back to its per-gender game file.

    An entry containing "_m." or "_f." loses that marker, and an "all_dev" /
    "all_test" part of its path is replaced by the matching
    "<gender>_<split>_<qtype>_<rewards>" one. Returns (path, gender), with
    gender None when the entry carries no marker.
    """
    for gender in ("m", "f"):
        marker = f"_{gender}."
        if marker in game:
            game = game.replace(marker, ".")
            if "all_dev" in game:
                game = game.replace("all_dev", f"{gender}_dev_{qtype}_{rewards}")
            elif "all_test" in game:
                game = game.replace("all_test", f"{gender}_test_{qtype}_{rewards}")
            return game, gender
    return game, None


def _game_index(game, with_gender):
    """Return the quest index found in a game path ("test_<n>...").

    With `with_gender`, a "_m"/"_f" marker that directly precedes the file
    extension is kept (e.g. "12_m"), so that male and female games sharing a
    number stay distinguishable in combined lists.
    """
    found = re.search(r"test_([0-9]+)(_[mf](?=\.))?", game)
    number, gender = found.group(1), found.group(2)
    return number + gender if (with_gender and gender) else number


if not bulk:
    if eval_all:
        agent_all = _load_agent("all", load_seed)
        agent_all.test()
    # NOTE(review): the per-gender agents are loaded even when eval_all is
    # set, as before; confirm whether that is needed.
    agent = _load_agent("m", load_seed)
    agentB = _load_agent("f", load_seed)
    # Switch to test mode for local and remote checkpoints alike; this used
    # to happen only for checkpoints loaded from DATA_PTH.
    agent.test()
    agentB.test()

    # Select the games of the requested split.
    if "dev" in split:
        if split == "dev_m":
            dtpth = m_dev_ulx
        elif split == "dev_f":
            dtpth = f_dev_ulx
        else:
            dtpth = all_dev
    else:
        if split == "test_m":
            dtpth = m_test_ulx
        elif split == "test_f":
            dtpth = f_test_ulx
        else:
            dtpth = all_test

    if isinstance(dtpth, str): # dtpth is folder location
        if local_eval:
            dtpth = all_dev
            dtpth = dtpth.replace(f"{DATA_PTH}quests", "/../content")
        games = glob(os.path.join(dtpth, f"*.{ext}"))
    else: # dtpth already a list of games
        games = dtpth
    # One index per game, in the order of `games`, so that it lines up with
    # the per-game result lists filled below.
    indices = [_game_index(g, with_gender=("m" not in split and "f" not in split))
               for g in games]

    a_steps = []
    a_scores = []
    b_steps = []
    b_scores = []
    a_trajs = []
    b_trajs = []
    all_steps = []
    all_scores = []
    all_trajs = []

    for ix, game in enumerate(games):
        game, game_gender = _strip_gender_marker(game)
        if game_gender is not None:
            print(game_gender.upper(), end=" ")
        if not eval_all:
            avg_moves_neuralA, avg_norm_scores_neuralA, trajsA = play(agent, game,
                                                                nb_episodes = k,
                                                                seed = 123)
            avg_moves_neuralB, avg_norm_scores_neuralB, trajsB = play(agentB, game,
                                                            nb_episodes = k,
                                                            seed = 123)
            a_steps.append(np.mean(avg_moves_neuralA))
            a_scores.append(np.mean(avg_norm_scores_neuralA))
            a_trajs.append(trajsA)
            b_steps.append(np.mean(avg_moves_neuralB))
            b_scores.append(np.mean(avg_norm_scores_neuralB))
            b_trajs.append(trajsB)
        else:
            avg_moves_neural_all, avg_norm_scores_neural_all, trajs_all = play(agent_all, game,
                                                            nb_episodes = k,
                                                            seed = 123)
            all_steps.append(np.mean(avg_moves_neural_all))
            all_scores.append(np.mean(avg_norm_scores_neural_all))
            all_trajs.append(trajs_all)
else:
    # The form's split only selects the gender subset here ("_m", "_f" or
    # combined); both the dev and the test split are always evaluated.
    load_split = split
    print(load_split)
    all_res_d = {"dev": dict(), "test": dict()}
    result_columns = ["game_idx", "avg_steps", "avg_scores", "trajs"]
    for split in ["dev", "test"]:
        print(split)
        all_res_d[split] = {s: dict() for s in train_seeds}
        seed_range = train_seeds if not single_seed else seed_from_list
        for load_seed in seed_range:
            print(load_seed)
            seed_res = all_res_d[split][load_seed]
            if not eval_all:
                seed_res["a_steps"] = []
                seed_res["a_trajs"] = []
                seed_res["a_scores"] = []
                seed_res["b_steps"] = []
                seed_res["b_trajs"] = []
                seed_res["b_scores"] = []
            else:
                seed_res["all_steps"] = []
                seed_res["all_trajs"] = []
                seed_res["all_scores"] = []
            if eval_all:
                agent_all = _load_agent("all", load_seed)
                print("Loaded agent.")
                agent_all.test()
            else:
                agent = _load_agent("m", load_seed)
                agentB = _load_agent("f", load_seed)
                print("Loaded agents.")
                agent.test()
                agentB.test()

            if "dev" in split:
                if "_m" in load_split:
                    dtpth = m_dev_ulx
                elif "_f" in load_split:
                    dtpth = f_dev_ulx
                else:
                    dtpth = all_dev
            else:
                if "_m" in load_split:
                    dtpth = m_test_ulx
                elif "_f" in load_split:
                    dtpth = f_test_ulx
                else:
                    dtpth = all_test

            if isinstance(dtpth, str): # dtpth is folder location
                games = glob(os.path.join(dtpth, f"*.ulx"))
            else: # dtpth already a list of games
                games = dtpth

            # One index per game, in game order. The previous list(set(...))
            # was unordered and de-duplicated, so zipping it with the
            # per-game results mislabelled rows and dropped rows whenever a
            # male and a female game shared an index.
            combined = "_m" not in load_split and "_f" not in load_split
            indices = [_game_index(g, with_gender=combined) for g in games]
            seed_res["indices"] = indices

            for ix, game in enumerate(games):
                game, _ = _strip_gender_marker(game)
                if not eval_all:
                    avg_moves_neuralA, avg_norm_scores_neuralA, trajsA = play(agent, game,
                                                                        nb_episodes = k,
                                                                        seed = 123)
                    avg_moves_neuralB, avg_norm_scores_neuralB, trajsB = play(agentB, game,
                                                                    nb_episodes = k,
                                                                    seed = 123)
                    seed_res["a_steps"].append(np.mean(avg_moves_neuralA))
                    seed_res["a_scores"].append(np.mean(avg_norm_scores_neuralA))
                    seed_res["a_trajs"].append(trajsA)
                    seed_res["b_steps"].append(np.mean(avg_moves_neuralB))
                    seed_res["b_scores"].append(np.mean(avg_norm_scores_neuralB))
                    seed_res["b_trajs"].append(trajsB)
                else:
                    avg_moves_neural_all, avg_norm_scores_neural_all, trajs_all = play(agent_all, game,
                                                                    nb_episodes = k,
                                                                    seed = 123)
                    seed_res["all_steps"].append(np.mean(avg_moves_neural_all))
                    seed_res["all_scores"].append(np.mean(avg_norm_scores_neural_all))
                    seed_res["all_trajs"].append(trajs_all)
            if save:
                # File-name part shared by all result files of this run.
                out_tag = f"{qtype}_{rewards}{basis}{distractors}_{split}_{load_seed}_{load_n}_{load_k}_{k}"
                if not eval_all:
                    a_df = pd.DataFrame(list(zip(seed_res["indices"], seed_res["a_steps"],
                                                 seed_res["a_scores"], seed_res["a_trajs"])),
                                        columns=result_columns)
                    b_df = pd.DataFrame(list(zip(seed_res["indices"], seed_res["b_steps"],
                                                 seed_res["b_scores"], seed_res["b_trajs"])),
                                        columns=result_columns)
                    # Results on a single-gender subset get that gender as a
                    # file-name suffix.
                    if "_m" in load_split:
                        subset_sfx = "_m"
                    elif "_f" in load_split:
                        subset_sfx = "_f"
                    else:
                        subset_sfx = ""
                    a_df.to_csv(f"{eval_out_dir}m_{out_tag}{subset_sfx}.csv", index=False)
                    b_df.to_csv(f"{eval_out_dir}f_{out_tag}{subset_sfx}.csv", index=False)
                    print(f"Saved male and female {split} results for seed {load_seed}.")
                else:
                    all_df = pd.DataFrame(list(zip(seed_res["indices"], seed_res["all_steps"],
                                                   seed_res["all_scores"], seed_res["all_trajs"])),
                                          columns=result_columns)
                    all_df.to_csv(f"{eval_out_dir}all_{out_tag}.csv", index=False)
                    print("Saved all.")
print("Done.")

In [None]:
#@title Load eval results
# Reads the evaluation CSVs of the male and female agents: either one
# seed/split into a_test_df / b_test_df, or (bulk_load) every seed of both
# splits into all_eval_d[split]["m" | "f"][seed].
test_k = 10 #@param
bulk_load = False #@param {'type': 'boolean'}
switch_m = False #@param {'type': 'boolean'}
switch_f = False #@param {'type': 'boolean'}
train_seed = "1429" #@param train_seeds = [858, 4386, 1429, 6673, 4368, 7131, 1719, 834, 2968, 7897]
train_seed = int(train_seed)
train_k = 5 #@param
train_n = 240 #@param
split = "test" #@param ["dev", "test"]
eval_t = "" #@param
if not bulk_load:
    a_pth = f"{DATA_PTH}m_{qtype}_{rewards}{basis}{distractors}_{split}_{train_seed}_{train_n}_{train_k}_{test_k}"
    b_pth = f"{DATA_PTH}f_{qtype}_{rewards}{basis}{distractors}_{split}_{train_seed}_{train_n}_{train_k}_{test_k}"
    if switch_m:
        a_pth += "_sm"
        b_pth += "_sm"
    elif switch_f:
        a_pth += "_sf"
        b_pth += "_sf"
    a_test_df = pd.read_csv(f"{a_pth}.csv")
    b_test_df = pd.read_csv(f"{b_pth}.csv")
else:
    all_eval_d = dict()
    # The "orig" results live in their own sub-folder, cover the first five
    # seeds only, and carry an "_old" (or, failing that, "_orig") name tag.
    if eval_t == "orig":
        results_dir = f"{DATA_PTH}eval_results/"
        name_tags = ["_old", "_orig"]
    else:
        results_dir = f"{DATA_PTH}"
        name_tags = [""]
    if switch_m:
        switch_sfx = "_sm"
    elif switch_f:
        switch_sfx = "_sf"
    else:
        switch_sfx = ""
    for split in ["dev", "test"]:
        all_eval_d[split] = {"m": dict(), "f": dict()}
        eval_seeds = [858, 4386, 1429, 6673, 4368, 7131, 1719, 834, 2968, 7897]
        if eval_t == "orig":
            eval_seeds = [858, 4386, 1429, 6673, 4368]
        for train_seed in eval_seeds:
            stem = f"{qtype}_{rewards}{basis}{distractors}_{split}_{train_seed}_{train_n}_{train_k}_{test_k}"
            for name_tag in name_tags:
                a_pth = f"{results_dir}m_{stem}{name_tag}{switch_sfx}"
                b_pth = f"{results_dir}f_{stem}{name_tag}{switch_sfx}"
                try:
                    m_eval_df = pd.read_csv(f"{a_pth}.csv")
                    f_eval_df = pd.read_csv(f"{b_pth}.csv")
                except FileNotFoundError:
                    # Try the next name tag. Previously the "_orig" path was
                    # built here but never read.
                    continue
                # Store both frames together so the male and female entries
                # always cover the same seeds.
                all_eval_d[split]["m"][train_seed] = m_eval_df
                all_eval_d[split]["f"][train_seed] = f_eval_df
                break
            else:
                # Still best-effort: a missing seed is reported and skipped
                # rather than aborting the whole load.
                print(f"No {split} eval results found for seed {train_seed}; skipping.")

### Significance Tests

- [ASO documentation](https://deep-significance.readthedocs.io/en/latest/#id3)

In [None]:
# NOTE(review): deepsig is already installed by the Installs cell at the top
# of the notebook; this looks redundant unless the section is run on its own.
!pip install deepsig

In [None]:
import scipy
import pandas as pd
import numpy as np
from itertools import product
from deepsig import aso

In [None]:
# Seed for the significance tests.
# NOTE(review): the aso() calls in this section pass seed=42 literally rather
# than this variable; confirm whether it is meant to be used there.
seed = 42

In [None]:
axis = 0 #@param {'type': 'integer'} # avg seed or per seed


def _mean_over_axis(split_name, gender, metric):
    """Average one metric column of the loaded eval results along `axis`.

    Takes the per-seed result frames of `gender` for `split_name` from
    all_eval_d, pulls `metric` out of each of them and averages the
    resulting table along the axis chosen above.
    """
    per_seed_frames = pd.DataFrame(all_eval_d[split_name])[gender]
    metric_table = per_seed_frames.apply(lambda seed_df: seed_df[metric])
    return metric_table.mean(axis)


# Dev split: scores, then steps, for the male and the female agents.
all_dev_m_scores = _mean_over_axis("dev", "m", "avg_scores")
all_dev_f_scores = _mean_over_axis("dev", "f", "avg_scores")
all_dev_m_steps = _mean_over_axis("dev", "m", "avg_steps")
all_dev_f_steps = _mean_over_axis("dev", "f", "avg_steps")
# Test split, same order.
all_test_m_scores = _mean_over_axis("test", "m", "avg_scores")
all_test_f_scores = _mean_over_axis("test", "f", "avg_scores")
all_test_m_steps = _mean_over_axis("test", "m", "avg_steps")
all_test_f_steps = _mean_over_axis("test", "f", "avg_steps")


In [None]:
# Avg. game results (scores and steps)
# Shown in this order: dev scores (m, f), test scores (m, f),
# dev steps (m, f), test steps (m, f).
for aggregated in (all_dev_m_scores, all_dev_f_scores,
                   all_test_m_scores, all_test_f_scores,
                   all_dev_m_steps, all_dev_f_steps,
                   all_test_m_steps, all_test_f_steps):
    display(aggregated.mean())

0.9942244224422442

0.9481683168316831

0.9946782178217822

0.9513201320132014

17.360099009900992

30.85168316831683

16.44584158415842

29.22316831683168

In [None]:
# ASO for avg seed or per seed setting

# Compares the male and female agents' results. Step counts are negated
# before the comparison; scores are passed unchanged in the commented-out
# alternatives below.
# NOTE(review): the negation assumes aso() treats larger values as better;
# confirm against the deepsig documentation.
aso(all_dev_m_steps * -1, all_dev_f_steps * -1, seed=42) # lower for steps
# aso(all_test_m_steps * -1, all_test_f_steps * -1, seed=42)
#aso(all_dev_m_scores, all_dev_f_scores, seed=42) # higher for scores
#aso(all_test_m_scores, all_test_f_scores, seed=42) # higher for scores