## Purpose

Train a model with implicit language Q-learning (ILQL).

To keep things fair, sparse zero-one rewards will be used as the training signal.

## Inputs:

- Offline dataset from the task of interest.

In [11]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['HF_HOME'] = '/home/seobrien/.cache/'

import json
import random
from collections import Counter

import yaml
from omegaconf import OmegaConf

from ilql_utils import train_ilql

%load_ext autoreload
%autoreload 2

In [77]:
# YAML file that the next cell loads as the ILQL run configuration.
config_path = './configs/ilql/default.yaml'

In [78]:
# Load the run configuration from `config_path` with OmegaConf.
config = OmegaConf.load(config_path)

In [79]:
# Substitute the task name into the path templates stored in the config.
task_name = config.task
config.saving.save_basedir = config.saving.save_basedir.format(task=task_name)
config.data_path = config.data_path.format(task=task_name)

# Checkpoints for this run live under <save_basedir>/<run_group_name>/<run_name>.
run_dir_parts = (
    config.saving.save_basedir,
    config.run_group_name,
    config.run_name,
)
config.saving.save_dir = os.path.join(*run_dir_parts)

# Create the run directory up front; an existing directory is not an error.
os.makedirs(config.saving.save_dir, exist_ok=True)

In [80]:
# next step -- format 20Q data according to what we expect to take in

In [81]:
# Data preparation for the twenty-questions task: pick held-out secrets and
# write the transformed train/eval files.
task = 'twenty-questions'

# Output locations. Each file is reused if it already exists and is only
# generated when missing.
held_out_secrets_path = f'input_data/{task}/held_out_secrets.json'
filtered_train_path = f'input_data/{task}/train_transformed.json'
filtered_eval_path = f'input_data/{task}/eval_transformed.json'

def transform_datapoint(d):
    """Convert one raw game record into the flat turn-based format.

    Every entry of ``d['lines']`` holds a question and its answer in a single
    string. The *last* ``'? '`` in the string separates the two, so a question
    that itself contains ``'? '`` is kept intact.

    Args:
        d: Mapping with the keys ``'lines'`` (iterable of strings), ``'word'``
            (a string, or a list whose first element is used as the secret)
            and ``'correct'``.

    Returns:
        A dict with ``'turns'`` (alternating question / answer strings),
        ``'secret'`` and ``'guessed'`` (the value of ``d['correct']``), or
        ``None`` if any line contains no ``'? '`` separator.
    """
    turns = []
    for line in d['lines']:
        # A line without the separator cannot be split into question/answer;
        # the whole record is rejected.
        if '? ' not in line:
            return None
        question, _, answer = line.rpartition('? ')
        # rpartition drops the separator, so restore the question mark.
        turns.extend([question + '?', answer])

    word = d['word']
    return {
        'turns': turns,
        'secret': word[0] if isinstance(word, list) else word,
        'guessed': d['correct'],
    }
    
    

# Read the raw (untransformed) train and eval records for the task.
with open(f'./input_data/{task}/train.json', 'r') as train_file:
    train_data = json.load(train_file)

with open(f'./input_data/{task}/eval.json', 'r') as eval_file:
    eval_data = json.load(eval_file)


# Reuse the held-out secrets if they were already chosen; otherwise choose
# them once from the eval set and persist the choice.
if os.path.exists(held_out_secrets_path):
    with open(held_out_secrets_path, 'r') as f:
        held_out_secrets = json.load(f)
else:
    # Canonical name of a secret: the alphabetically first alias when the
    # record lists several, otherwise the word itself.
    eval_secrets = {
        sorted(d['word'])[0] if isinstance(d['word'], list) else d['word']
        for d in eval_data
    }

    # Sort before shuffling: the iteration order of a set of strings changes
    # between interpreter runs (hash randomisation), so seeding alone would
    # not make the selection reproducible.
    secrets = sorted(eval_secrets)
    # A dedicated generator gives the same permutation as seeding the global
    # one, without re-seeding `random` only on the run that creates the file.
    random.Random(42).shuffle(secrets)
    held_out_secrets = secrets[:10]

    with open(held_out_secrets_path, 'w') as f:
        json.dump(held_out_secrets, f)

# Hold out some secrets from the training set to test for generalization.
# NOTE: the result is cached on disk; delete `filtered_train_path` to rebuild
# it after changing the held-out secrets or this filter.
if not os.path.exists(filtered_train_path):
    held_out_set = set(held_out_secrets)

    # A record is held out when its secret, or ANY alias of it when `word` is
    # a list, is a held-out secret. Checking only the first alias would leak
    # held-out secrets, and indexing a string secret with [0] would compare
    # its first character.
    kept_train_data = [
        t for t in train_data
        if held_out_set.isdisjoint(t['word'] if isinstance(t['word'], list) else [t['word']])
    ]

    # transform_datapoint returns None for records it cannot split into
    # question/answer turns; drop those instead of writing `null` entries.
    transformed_train = (transform_datapoint(t) for t in kept_train_data)
    filtered_train_data = [t for t in transformed_train if t is not None]

    with open(filtered_train_path, 'w') as f:
        json.dump(filtered_train_data, f)



# The eval set keeps every secret; records are only converted to the turn
# format. The result is cached on disk and reused if it already exists.
if not os.path.exists(filtered_eval_path):
    # transform_datapoint returns None for records it cannot split into
    # question/answer turns; drop those instead of writing `null` entries.
    transformed_eval = (transform_datapoint(d) for d in eval_data)
    filtered_eval_data = [d for d in transformed_eval if d is not None]
    with open(filtered_eval_path, 'w') as f:
        json.dump(filtered_eval_data, f)


# Read the transformed splits back from disk, so the in-memory data is the
# same whether the files were just written or already existed.
with open(filtered_train_path, 'r') as train_file:
    train_transformed = json.load(train_file)

with open(filtered_eval_path, 'r') as eval_file:
    eval_transformed = json.load(eval_file)

print('Data prepared!')

Data prepared!


## Training loop

In [82]:
# Launch ILQL training (implemented in `ilql_utils`) with the loaded config.
train_ilql(config)

Loading data...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2776.46it/s]


Data loaded!
Loading data...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 2888.94it/s]


Data loaded!
TRAIN DATA LENGTH: 1000
VAL DATA LENGTH: 100
##################################################
EXAMPLE TRAJECTORY:
##################################################
('user: Is the object an animal?\n'
 'assistant: No.\n'
 'user: Is the object man-made?\n'
 'assistant: No.\n'
 'user: Is the object a mineral?\n'
 'assistant: No.\n'
 'user: Is the object a plant?\n'
 'assistant: Yes.\n'
 'user: Is the object a tree?\n'
 'assistant: No.\n'
 'user: Is the object a fruit?\n'
 'assistant: No.\n'
 'user: Is the object a flower?\n'
 'assistant: No.\n'
 'user: Is the object a vegetable?\n'
 'assistant: Yes.\n'
 'user: Is the object a root vegetable?\n'
 'assistant: No.\n'
 'user: Is the object an ornamental vegetable?\n'
 'assistant: No.\n'
 'user: Is the object an edible berry?\n'
 'assistant: No.\n'
 'user: Is the object a leafy green vegetable?\n'
 'assistant: Yes.\n'
 'user: Is the object spinach?\n'
 'assistant: No.\n'
 'user: Is the object kale?\n'
 'assistant: No.\n'
 'user



MODEL LOADED


Loss: 18052.7148:   0%|▎                                                                                    | 4/1000 [00:00<00:28, 35.01it/s]

##################################################
EXAMPLE TRAJECTORY:
('user: Is it larger than a breadbox?\n'
 'assistant: No.\n'
 'user: Is it smaller than a tennis ball?\n'
 'assistant: Yes.\n'
 'user: Is it smaller than a marble?\n'
 'assistant: No.\n'
 'user: Is it larger than a walnut?\n'
 'assistant: No.\n'
 'user: Is it smaller than a pea?\n'
 'assistant: No.\n'
 'user: Is it smaller than a penny?\n'
 'assistant: No.\n'
 'user: Is it smaller than a tennis ball ball?\n'
 'assistant: Yes.\n'
 'user: Is it smaller than a coin?\n'
 'assistant: No.\n'
 'user: Is it a tennis ball ball?\n'
 'assistant: No.\n'
 'user: Is it smaller than a marble?\n'
 'assistant: No.\n'
 'user: Is it smaller than a wooden chopstick?\n'
 'assistant: No.\n'
 'user: Is it smaller than a wooden pencil?\n'
 'assistant: No.\n'
 'user: Is it smaller than a wooden chopstick handle?\n'
 'assistant: Yes.\n'
 'user: Is it a wooden pen?\n'
 'assistant: No.\n'
 'user: Is it smaller than a wooden utensil?\n'
 'assis

Loss: 23239.4062: 100%|██████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:27<00:00, 36.48it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 61.64it/s]
Loss: 13603.7998: 100%|██████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:27<00:00, 36.47it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 62.75it/s]
Loss: 8735.9316: 100%|███████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:27<00:00, 36.40it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 61.80it/s]
Loss: 21545.1797: 100%|██████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:27<00:00, 36.23it/s]
100%|█

Checkpoint saved at ./checkpoints/ilql/twenty-questions/DEBUG-GROUP/DEBUG-RUN-NAME/final_checkpoint


In [83]:
# NOTE(review): `AutoTokenizer` is not imported in any cell shown here
# (presumably `from transformers import AutoTokenizer`), so this cell fails
# on a fresh kernel. Confirm and add the import to the imports cell.
tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')

In [84]:
# Inspect the token id(s) the GPT-2 tokenizer assigns to a newline character.
tokenizer.encode('\n')

[198]