In [9]:
import sys
sys.path.append('../')
from utils.notebook_imports import *
from tactic_and_execution.thought_action_observation_data_structure import *
from tactic_and_execution.traj_utils import *
from collections import Counter
from traj_runner import TrajRunner

In [11]:
# Trajectory runner over the tactics directory; it is passed to `make_perfect_traj` further below.
tactics_dir = '../data/tactics'
runner = TrajRunner(tactics_dir)

# In this notebook, we show how to build trainable sequences for fine-tuning different versions of Tiger-8B (PJ, IPJ and router training)

## First, load the data

In [4]:
# Standalone and hybrid training splits (the hybrid split holds the routing trajectories).
# These JSON files come from the HF repo -- download them into ../data first if you haven't.
train_files = ('../data/standalone_train.json', '../data/hybrid_train.json')
strain, htrain = (jload(path) for path in train_files)

## The key function is `make_perfect_traj`: it builds a perfect trajectory from the raw `sample['gt_traj']` and also returns the indices of the exps that are correct in the raw trajectory, so that we can mask out the incorrect ones for imperfect-trajectory training

### To visualize the process, let's sample a subset

In [7]:
# Draw a reproducible sample of 10 training examples to inspect.
# A fixed seed keeps the subset (and the printed statuses) stable across Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
sample_idx = rng.choice(len(strain), size=10, replace=False)
# Keep references to the original dicts, so fields added to `strain` entries later
# (e.g. 'log_ls' in the make_perfect_traj loop) are visible through `subset` as well.
subset = [strain[i] for i in sample_idx]
# visualize the status of each step of the raw ground-truth trajectory
for e in subset:
    print(Trajectory.from_json(e['gt_traj']).show_traj_status())

[('Plan', 'exec ok'), ('Write program', 'exec wrong'), ('Revise code', 'exec ok'), ('Aggregate and answer', 'exec wrong'), ('Aggregate and answer', 'exec solved')]
[('Plan', 'exec ok'), ('Build FOL model', 'exec wrong'), ('Revise code', 'exec wrong'), ('Revise code', 'exec ok'), ('Aggregate and answer', 'exec solved')]
[('Plan', 'exec ok'), ('Build graph model', 'exec ok'), ('Aggregate and answer', 'exec solved')]
[('Build math model', 'exec ok'), ('Aggregate and answer', 'exec wrong'), ('Aggregate and answer', 'exec solved')]
[('Plan', 'exec ok'), ('Write program', 'exec wrong'), ('Revise code', 'exec ok'), ('Reason', 'exec ok'), ('Aggregate and answer', 'exec wrong'), ('Aggregate and answer', 'exec wrong'), ('Aggregate and answer', 'exec solved')]
[('Plan', 'exec ok'), ('Build graph model', 'exec ok'), ('Revise code', 'exec ok'), ('Aggregate and answer', 'exec solved')]
[('Plan', 'exec ok'), ('Build FOL model', 'exec ok'), ('Revise code', 'exec ok'), ('Aggregate and answer', 'exec wr

## As you can see, some of the trajectories above solve the problem in the end but make mistakes along the way. We propose two ways to handle this:
- PJ: we make a perfect trajectory out of the raw one by replacing the content (action) of the bad steps with the later good ones
- IPJ: we keep the trajectory as is but mask out the sequence labels of the bad steps

## To do so, we use the function `make_perfect_traj`

In [17]:
# Annotate every standalone sample in place with its perfect trajectory ('pj'), the indices of
# the raw steps that are correct ('trainable_inds'), and the merge log ('log_ls').
# Samples for which make_perfect_traj returns pj=None are left without these keys and are
# skipped when the training data is built below.
for e in nbtqdm(strain):
    pj, trainable_inds, log_ls = make_perfect_traj(e, 'gt_traj', runner)
    if pj is None:
        # Drop results left over from a previous run of this cell, so re-running it
        # (e.g. after editing make_perfect_traj) cannot leave stale keys behind.
        for key in ('pj', 'trainable_inds', 'log_ls'):
            e.pop(key, None)
        continue
    e.update(pj=pj, trainable_inds=trainable_inds, log_ls=log_ls)

  0%|          | 0/4303 [00:00<?, ?it/s]

## Now let's look at the newly created perfect trajectories

In [19]:
# Print the merge log of each sampled example. `subset` holds the same dicts as `strain`,
# so the 'log_ls' field added by the make_perfect_traj loop is visible here.
for e in subset:
    # samples for which make_perfect_traj returned no perfect traj have no 'log_ls'
    print(e.get('log_ls', 'no perfect traj (sample skipped)'))

[(0, 'Plan', 'exec ok'), ('1/2', 'Write program/Revise code', 'exec wrong/exec ok'), ('3/4', 'Aggregate and answer/Aggregate and answer', 'exec wrong/exec solved')]
[(0, 'Plan', 'exec ok'), ('1/3', 'Build FOL model/Revise code', 'exec wrong/exec ok'), (4, 'Aggregate and answer', 'exec solved')]
[(0, 'Plan', 'exec ok'), (1, 'Build graph model', 'exec ok'), (2, 'Aggregate and answer', 'exec solved')]
[(0, 'Build math model', 'exec ok'), ('1/2', 'Aggregate and answer/Aggregate and answer', 'exec wrong/exec solved')]
[(0, 'Plan', 'exec ok'), ('1/2', 'Write program/Revise code', 'exec wrong/exec ok'), (3, 'Reason', 'exec ok'), ('4/6', 'Aggregate and answer/Aggregate and answer', 'exec wrong/exec solved')]
[(0, 'Plan', 'exec ok'), ('1/2', 'Build graph model/Revise code', 'exec ok/exec ok'), (3, 'Aggregate and answer', 'exec solved')]
[(0, 'Plan', 'exec ok'), ('1/2', 'Build FOL model/Revise code', 'exec ok/exec ok'), ('3/4', 'Aggregate and answer/Aggregate and answer', 'exec wrong/exec solved

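## As you can see, some of the exps are merged: an entry such as `('1/3', 'Build FOL model/Revise code', 'exec wrong/exec ok')` means raw steps 1 through 3 were merged into a single step, and the names and statuses shown are those of the first and last raw step in that span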

## Now we can build the training sequences using `traj2seq_simple`

In [20]:
# Metadata fields copied verbatim from each source sample into its training record.
STANDALONE_META_KEYS = ('input', 'label', 'src', 'dataset', 'orig_data')
# NOTE(review): routing records do not copy 'orig_data' (kept as before); confirm whether the
# hybrid samples carry it, since PJ and RPJ records end up in the same file below.
HYBRID_META_KEYS = ('input', 'label', 'src', 'dataset')


def make_seq_record(sample, meta_keys, traj, **seq_kwargs):
    """Build one training record: the `meta_keys` fields of `sample` plus the serialized `traj`.

    `seq_kwargs` (e.g. `trainable_inds`, `obs_trainable`) are forwarded to `traj2seq_simple`
    on top of the options shared by every dataset in this notebook.
    """
    record = {key: sample[key] for key in meta_keys}
    record['traj'] = traj2seq_simple(
        traj,
        no_exp_ok=False,
        # Lines starting with the keyword "HINTS" contain content injected during data generation
        # to guide the generation; the model should not learn to predict them, so remove them.
        remove_hints=True,
        add_resp_head=False,
        **seq_kwargs,
    )
    return record


pj_data, ipj_data, rpj_data = [], [], []

for sample in strain:
    # samples without a perfect traj were skipped by make_perfect_traj above
    if 'pj' not in sample:
        continue
    # PJ: the perfect trajectory produced by make_perfect_traj
    pj_data.append(make_seq_record(sample, STANDALONE_META_KEYS, sample['pj'], trainable_inds=None))
    # IPJ: the raw trajectory; the trainable_inds returned by make_perfect_traj mask out
    # the incorrect steps so the imperfect trajectory becomes trainable
    ipj_data.append(make_seq_record(
        sample,
        STANDALONE_META_KEYS,
        Trajectory.from_json(sample['gt_traj']),
        trainable_inds=sample['trainable_inds'],
    ))

# also gather the trainable seqs of the routing trajs (hybrid split)
for sample in htrain:
    rpj_data.append(make_seq_record(
        sample,
        HYBRID_META_KEYS,
        Trajectory.from_json(sample['hyb_syn_traj']),
        trainable_inds=None,
        obs_trainable=False,
    ))

In [21]:
# Sizes of the PJ, IPJ and RPJ datasets; PJ and IPJ are equal by construction
# (both are appended once per kept standalone sample).
print(*(len(ds) for ds in (pj_data, ipj_data, rpj_data)))

4294 4294 987


In [22]:
# Output file -> records it receives:
#   PJ training uses PJ only; IPJ training uses PJ+IPJ; router training uses PJ+RPJ.
training_outputs = {
    '../data/pj_data.json': pj_data,
    '../data/ipj_data.json': pj_data + ipj_data,
    '../data/rpj_data.json': pj_data + rpj_data,
}
for out_path, records in training_outputs.items():
    jdump(records, out_path)