## Purpose

Train a model with Implicit Language Q-Learning (ILQL).

To keep things fair, sparse zero-one rewards are used as the training signal.

## Inputs:

- Offline dataset from the task of interest.

In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['HF_HOME'] = '/teamspace/studios/this_studio/Agents-Course-Final-Project/.cache/'

import json
import random
from collections import Counter

import yaml
from omegaconf import OmegaConf

from ilql_utils import train_ilql
from ilql_eval import run_evaluations, compute_cross_run_metrics

%load_ext autoreload
%autoreload 2

In [17]:
# Location of the YAML file holding the ILQL run configuration.
config_path = './configs/ilql/default.yaml'

In [18]:
# Parse the YAML file at `config_path` into an OmegaConf config object.
config = OmegaConf.load(config_path)

In [19]:
# Substitute the task name into the templated paths (the `.format(task=...)`
# calls presumably fill a `{task}` placeholder in the YAML values).
task_name = config.task
config.saving.save_basedir = config.saving.save_basedir.format(task=task_name)
config.data_path = config.data_path.format(task=task_name)

# This run's outputs go under <save_basedir>/<run_group_name>/<run_name>.
save_dir_parts = (
    config.saving.save_basedir,
    config.run_group_name,
    config.run_name,
)
config.saving.save_dir = os.path.join(*save_dir_parts)

# Create the directory up front; exist_ok makes re-running this cell harmless.
os.makedirs(config.saving.save_dir, exist_ok=True)

In [24]:
# next step -- format 20Q data according to what we expect to take in

In [20]:
# Build the held-out secrets list and the transformed train/eval files
# for the twenty-questions task.
task = 'twenty-questions'

# Cached artefacts: each file below is generated once and reused on later runs.
held_out_secrets_path = f'input_data/{task}/held_out_secrets.json'
filtered_train_path = f'input_data/{task}/train_transformed.json'
filtered_eval_path = f'input_data/{task}/eval_transformed.json'

def transform_datapoint(d):
    """Convert one raw twenty-questions record into a flat list of turns.

    Each entry of ``d['lines']`` is expected to look like
    ``"<question>? <answer>"``. The line is split at its *last* ``'? '`` so
    that a question which itself contains ``'? '`` stays intact.

    Args:
        d: Raw record with keys ``'lines'`` (iterable of str), ``'word'``
            (a str, or a list of str whose first element is used) and
            ``'correct'``.

    Returns:
        A dict with ``'turns'`` (alternating question / answer strings),
        ``'secret'`` and ``'guessed'``; or ``None`` as soon as a line without
        a ``'? '`` separator is encountered.

    Raises:
        ValueError: If a line contains ``'? '`` but cannot be split as a
            string (e.g. it is not a ``str``).
    """
    turns = []
    for line in d['lines']:
        # A line without the separator cannot be split into question/answer;
        # the whole record is rejected.
        if '? ' not in line:
            return None
        try:
            # Split on the last separator only: everything before it is the
            # question (the '?' consumed by the split is restored below).
            question_body, a_str = line.rsplit('? ', 1)
        except (AttributeError, TypeError, ValueError) as err:
            raise ValueError(
                f'Could not split line into question/answer: {line!r}'
            ) from err

        turns.extend([question_body + '?', a_str])

    word = d['word']
    return {
        'turns': turns,
        'secret': word[0] if isinstance(word, list) else word,
        'guessed': d['correct'],
    }
    
    

# Number of eval secrets withheld from training, and the seed used to pick them.
NUM_HELD_OUT_SECRETS = 10
HELD_OUT_SEED = 42


def _secret_aliases(d):
    """Return every secret word attached to raw record ``d`` as a list.

    ``d['word']`` is either a single string or a list of strings.
    """
    word = d['word']
    return list(word) if isinstance(word, list) else [word]


with open(f'./input_data/{task}/train.json', 'r') as f:
    train_data = json.load(f)

with open(f'./input_data/{task}/eval.json', 'r') as f:
    eval_data = json.load(f)


# NOTE(review): all three generated files are reused as-is once they exist;
# delete them to regenerate after changing the raw data or held-out secrets.
if os.path.exists(held_out_secrets_path):
    with open(held_out_secrets_path, 'r') as f:
        held_out_secrets = json.load(f)
else:
    # One representative name per eval record: the alphabetically first alias.
    # NOTE(review): transform_datapoint uses the *first listed* alias as
    # 'secret'; confirm which of the two the evaluation code compares against.
    eval_secrets = {sorted(_secret_aliases(d))[0] for d in eval_data}

    # Sort before shuffling: iteration order of a set of strings varies between
    # Python processes (hash randomisation), so seeding alone would not make
    # the selection reproducible.
    secrets = sorted(eval_secrets)
    # A private RNG keeps the global `random` state untouched.
    random.Random(HELD_OUT_SEED).shuffle(secrets)
    held_out_secrets = secrets[:NUM_HELD_OUT_SECRETS]

    with open(held_out_secrets_path, 'w') as f:
        json.dump(held_out_secrets, f)

# Hold out some secrets from training to test for generalization.
held_out_lookup = set(held_out_secrets)

if not os.path.exists(filtered_train_path):
    # Drop a training record if ANY of its aliases is held out, so a held-out
    # secret cannot leak in under a different alias position.
    kept_train_data = [
        d for d in train_data
        if held_out_lookup.isdisjoint(_secret_aliases(d))
    ]
    # transform_datapoint returns None for records it cannot parse; skip those
    # instead of writing null entries into the output file.
    filtered_train_data = [
        t for t in map(transform_datapoint, kept_train_data) if t is not None
    ]

    with open(filtered_train_path, 'w') as f:
        json.dump(filtered_train_data, f)


if not os.path.exists(filtered_eval_path):
    # Eval keeps the held-out secrets; only unparseable records are dropped.
    filtered_eval_data = [
        t for t in map(transform_datapoint, eval_data) if t is not None
    ]
    with open(filtered_eval_path, 'w') as f:
        json.dump(filtered_eval_data, f)


# Always read back from disk so both branches yield the same objects.
with open(filtered_train_path, 'r') as f:
    train_transformed = json.load(f)

with open(filtered_eval_path, 'r') as f:
    eval_transformed = json.load(f)

print('Data prepared!')

Data prepared!


## Training loop

In [21]:
# Run ILQL training with the resolved config. The logged output below shows
# per-epoch train/validation passes and a final checkpoint being saved.
train_ilql(config)

Loading data...


100%|██████████| 100/100 [00:00<00:00, 938.19it/s]


Data loaded!
Loading data...


100%|██████████| 100/100 [00:00<00:00, 990.41it/s]


Data loaded!
TRAIN DATA LENGTH: 100
VAL DATA LENGTH: 100
##################################################
EXAMPLE TRAJECTORY:
##################################################
('user: Is the object an animal?\n'
 'assistant: No.\n'
 'user: Is the object man-made?\n'
 'assistant: No.\n'
 'user: Is the object a mineral?\n'
 'assistant: No.\n'
 'user: Is the object a plant?\n'
 'assistant: Yes.\n'
 'user: Is the object a tree?\n'
 'assistant: No.\n'
 'user: Is the object a fruit?\n'
 'assistant: No.\n'
 'user: Is the object a flower?\n'
 'assistant: No.\n'
 'user: Is the object a vegetable?\n'
 'assistant: Yes.\n'
 'user: Is the object a root vegetable?\n'
 'assistant: No.\n'
 'user: Is the object an ornamental vegetable?\n'
 'assistant: No.\n'
 'user: Is the object an edible berry?\n'
 'assistant: No.\n'
 'user: Is the object a leafy green vegetable?\n'
 'assistant: Yes.\n'
 'user: Is the object spinach?\n'
 'assistant: No.\n'
 'user: Is the object kale?\n'
 'assistant: No.\n'
 'user:



MODEL LOADED


  0%|          | 0/100 [00:00<?, ?it/s]

##################################################
EXAMPLE TRAJECTORY:
('user: Is it an animal?\n'
 'assistant: No.\n'
 'user: Is it an inanimate object?\n'
 'assistant: Yes.\n'
 'user: Is it man-made?\n'
 'assistant: Yes.\n'
 'user: Is it a tool?\n'
 'assistant: No.\n'
 'user: Is it a piece of technology?\n'
 'assistant: No.\n'
 'user: Is it a piece of furniture?\n'
 'assistant: No.\n'
 'user: Is it an item of clothing?\n'
 'assistant: No.\n'
 'user: Is it a kitchen appliance?\n'
 'assistant: No.\n'
 'user: Is it a household item?\n'
 'assistant: No.\n'
 'user: Is it a decoration?\n'
 'assistant: No.\n'
 'user: Is it a toy?\n'
 'assistant: No.\n'
 'user: Is it something used in the outdoors?\n'
 'assistant: Yes.\n'
 'user: Is it a recreational vehicle?\n'
 'assistant: No.\n'
 'user: Is it a sporting equipment?\n'
 'assistant: Yes.\n'
 'user: Is it a bat?\n'
 'assistant: No.\n'
 'user: Is it a glove?\n'
 'assistant: No.\n'
 'user: Is it a ball?\n'
 'assistant: Yes.\n'
 'user: Is it a s

Loss: 241.5545: 100%|██████████| 100/100 [01:28<00:00,  1.13it/s]
100%|██████████| 100/100 [00:46<00:00,  2.14it/s]
Loss: 255.5099: 100%|██████████| 100/100 [01:29<00:00,  1.11it/s]
100%|██████████| 100/100 [00:46<00:00,  2.14it/s]
Loss: 228.4615: 100%|██████████| 100/100 [01:29<00:00,  1.11it/s]
100%|██████████| 100/100 [00:46<00:00,  2.14it/s]
Loss: 237.5196: 100%|██████████| 100/100 [01:29<00:00,  1.11it/s]
100%|██████████| 100/100 [00:46<00:00,  2.14it/s]


Checkpoint saved at ./checkpoints/twenty-questions/ilql/DEBUG-GROUP/gpt2-xl_42_100/final_checkpoint


0,1
cql_loss,▃▃▃▂▂▂█▄▇▂▃▃▆▂▂▃▃▂▃▃▅▇▃▂█▂▂▃▃▃▁▃▄▄▃▃▂▃▂▂
cql_loss_val,█▃▁▂
expectile_loss,▄▂▄▇▃▃▃▁▁█▂▃▄▃▄▂▂▃▃▂▄▂▃▂▂▄▁▅▅▂▃▂▃▂▃▄▃▃▃▃
expectile_loss_val,█▂▁▁
global_step,▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇█████
loss,▅▂▄▄▅▄▄▄▆▃▄▁▃▄▆▆▃▅▇▂▆▂▅▅▃▄█▇▅▃▃▃▁▃▂▄▃▃▃▃
loss_val,█▂▁▁
lr,██████▇▇▇▇▇▇▇▆▆▆▆▅▅▄▄▄▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁
mc_loss,▄▃▃▃▃▄▄▂▅▆▂▆▃▂▁▂▁▂▁▁▆▁▂█▁▁▇▁▇▂▂▁█▁▇▁▁▁█▂
mc_loss_val,█▂▁▁

0,1
cql_loss,0.94141
cql_loss_val,0.83375
expectile_loss,54.51488
expectile_loss_val,53.72224
global_step,400.0
loss,237.51958
loss_val,233.8264
lr,0.0
mc_loss,0.82812
mc_loss_val,0.65648


In [24]:
# Evaluate on the transformed eval file. Per the log below, this loads and
# evaluates one ILQL checkpoint per run (gpt2 through gpt2-xl).
run_evaluations(config, filtered_eval_path)

Start Evaluating: gpt2_42_100
ILQL CHECKPOINT LOADED


100%|██████████| 1000/1000 [00:59<00:00, 16.77it/s]


Evaluated: gpt2_42_100
Start Evaluating: gpt2-medium_42_100
ILQL CHECKPOINT LOADED


100%|██████████| 1000/1000 [02:07<00:00,  7.86it/s]


Evaluated: gpt2-medium_42_100
Start Evaluating: gpt2-large_42_100
ILQL CHECKPOINT LOADED


100%|██████████| 1000/1000 [04:04<00:00,  4.09it/s]


Evaluated: gpt2-large_42_100
Start Evaluating: gpt2-xl_42_100
ILQL CHECKPOINT LOADED


100%|██████████| 1000/1000 [07:40<00:00,  2.17it/s]

Evaluated: gpt2-xl_42_100





In [25]:
# Aggregate metrics across runs from ./evaluation_results. The displayed result
# is a dict with per-turn 'variance_by_turn' and 'agreement_by_turn' entries.
compute_cross_run_metrics("./evaluation_results")

{'variance_by_turn': {0: 0.7506887267056481,
  1: 3.0827336093902113,
  2: 3.991752503770039,
  3: 4.541182183367169,
  4: 4.726416334320138,
  5: 4.598426892908739,
  6: 4.513711400745505,
  7: 4.253859160601736,
  8: 4.049204411483638,
  9: 3.900609631599493,
  10: 3.7932281815596096,
  11: 3.7300817876877868,
  12: 3.6663655050961603,
  13: 3.631174634382329,
  14: 3.5915843489278894,
  15: 3.5693184446837556,
  16: 3.585712945475949,
  17: 3.546538530571067,
  18: 3.545031387282631},
 'agreement_by_turn': {0: 0.3866666666666631,
  1: 0.3546666666666644,
  2: 0.34084836339345187,
  3: 0.3991228070175414,
  4: 0.4617597292724178,
  5: 0.4789311408016436,
  6: 0.48813263525305356,
  7: 0.4930851063829785,
  8: 0.4956568946796959,
  9: 0.496111111111111,
  10: 0.4979023646071701,
  11: 0.49664694280078886,
  12: 0.4975698663426487,
  13: 0.49584889995848896,
  14: 0.49724342663273957,
  15: 0.49583333333333324,
  16: 0.49303683737646,
  17: 0.4944751381215468,
  18: 0.49285714285714277