In [1]:
import sys
import os
root_dir = '/home/qingli/Desktop/Closed-Loop-Learning/HINT/'
os.chdir(root_dir)
from train import *
from data.domain import OPERATORS
import pandas as pd 
from collections import OrderedDict
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_theme()
sns.set_context("notebook", font_scale=1.5)
sns.color_palette('colorblind')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Blank out the command line so parse_args() only sees its own defaults,
# not whatever arguments the notebook kernel was launched with.
sys.argv = []
args = parse_args()

# Select the input modality and model variant first: the checkpoint path
# below is derived from both.
vars(args).update(input='symbol', model='TRAN.relative_universal')
args.resume = f'./models/{args.input}.{args.model}/model_100000.p'

# Architecture hyper-parameters used for this run.
vars(args).update(
    nhead=8,
    enc_layers=6,
    dec_layers=1,
    hid_dim=512,
    max_rel_pos=15,
)
args

Namespace(batch_size=128, cos_sim_margin=0.2, curriculum='no', dec_layers=1, dropout=0.5, early_stop=None, emb_dim=128, enc_layers=6, epochs=10, epochs_eval=1, fewshot=None, grad_clip=5.0, hid_dim=512, input='symbol', iterations=None, iterations_eval=None, layers=1, lr=0.001, lr_scheduler='constant', main_dataset_ratio=0, max_op_train=None, max_rel_pos=15, model='TRAN.relative_universal', nhead=8, output_attentions=False, output_dir='outputs/', perception_pretrain='data/perception_pretrain/model.pth.tar_78.2_match', pos_emb_type='sin', result_encoding='decimal', resume='./models/symbol.TRAN.relative_universal/model_100000.p', save_model=False, seed=0, train_size=None, wandb='HINT', warmup_steps=100)

In [3]:
# Build the three dataset splits. Only the training split takes the
# sub-sampling / curriculum related options.
train_set = HINT(
    'train',
    input=args.input,
    fewshot=args.fewshot,
    n_sample=args.train_size,
    max_op=args.max_op_train,
    main_dataset_ratio=args.main_dataset_ratio,
)
val_set, test_set = (
    HINT(split, input=args.input, fewshot=args.fewshot)
    for split in ('val', 'test')
)
print('train:', len(train_set), 'val:', len(val_set), 'test:', len(test_set))

# Attach the splits to args so code that receives args can reach them.
args.train_set, args.val_set, args.test_set = train_set, val_set, test_set

train: 998000 val: 4698 test: 46620


In [4]:
# Build the model from args and, when a checkpoint path is configured,
# restore its weights from disk.
args.res_enc = ResultEncoding(args.result_encoding)
model = make_model(args)
if args.resume:
    print('Load checkpoint from ' + args.resume)
    # map_location='cpu' lets a checkpoint saved on a GPU be loaded on a
    # machine without that device; the weights are moved to DEVICE below.
    ckpt = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(ckpt['model_state_dict'])
model.to(DEVICE)

# Total number of parameter elements in the model.
n_params = sum(p.numel() for p in model.parameters())
print('Num params:', n_params)

Load checkpoint from ./models/symbol.TRAN.relative_universal/model_100000.p
Num params: 8019469


In [5]:
# Run the model over the test split, collect predicted and ground-truth
# results for every sample, then compute overall and per-attribute accuracy.
dataset = args.test_set
# shuffle=False keeps the concatenated predictions in the dataloader's
# iteration order; the index-based lookups further down (`res_all[ids]`)
# presumably rely on that matching dataset order -- confirm.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                            shuffle=False, num_workers=4, collate_fn=HINT_collate)

model.eval() 
# Per-batch ground-truth results and predictions; concatenated after the loop.
res_all = []
res_pred_all = []

# NOTE(review): expr_pred_all and dep_pred_all are never filled in this cell
# (the lines that would extend them are commented out below).
expr_all = []
expr_pred_all = []

dep_all = []
dep_pred_all = []

# Metric name -> value, in insertion order.
metrics = OrderedDict()

with torch.no_grad():
    for sample in tqdm(dataloader):
        if args.input == 'image':
            src = sample['img_seq']
        elif args.input == 'symbol':
            # Flatten the batch's sentences into one tensor of symbols.
            src = torch.tensor([x for s in sample['sentence'] for x in s])
        # NOTE(review): any other args.input value leaves `src` undefined here.
        res = sample['res']
        if args.result_encoding == 'sin':
            tgt = res.unsqueeze(1)
        else:
            # Encode each result as a target sequence via the result encoder.
            tgt = torch.tensor(args.res_enc.res2seq_batch(res.numpy()))
        expr = sample['expr']
        dep = sample['head']
        src_len = sample['len']

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # The model is fed the target sequence without its last position.
        # NOTE(review): this looks like teacher forcing during evaluation --
        # confirm this is the intended protocol.
        output = model(src, tgt[:, :-1], src_len)
        pred = torch.argmax(output, -1).detach().cpu().numpy()
        if args.result_encoding == 'sin':
            res_pred = pred
        else:
            # Decode predicted sequences back into results.
            res_pred = args.res_enc.seq2res_batch(pred)
        res_pred_all.append(res_pred)
        res_all.append(res)

        # expr_pred_all.extend(expr_preds)
        expr_all.extend(expr)
        # dep_pred_all.extend(dep_preds)
        dep_all.extend(dep)

# Stitch the per-batch pieces into single arrays covering every sample.
res_pred_all = np.concatenate(res_pred_all, axis=0)
res_all = np.concatenate(res_all, axis=0)
# Fraction of samples whose predicted result equals the ground truth.
result_acc = (res_pred_all == res_all).mean()
metrics['result_acc/avg'] = result_acc

# Accuracy broken down by attribute value. For each attribute the dataset
# exposes an `<attr>2ids` mapping from attribute value to ids (presumably
# sample indices into the dataset -- confirm).
tracked_attrs = ['length', 'symbol', 'digit', 'result', 'eval', 'tree_depth', 'ps_depth', 'max_dep']
for attr in tracked_attrs:
    # print(f"result accuracy by {attr}:")
    attr2ids = getattr(dataloader.dataset, f'{attr}2ids')
    for k, ids in sorted(attr2ids.items()):
        res = res_all[ids]
        res_pred = res_pred_all[ids]
        # An empty id list is reported as 0 accuracy instead of a NaN mean.
        res_acc = (res == res_pred).mean() if ids else 0.
        # Rename the '/' key to 'div', presumably so the metric name does not
        # gain an extra path separator.
        k = 'div' if k == '/' else k
        metrics[f'result_acc/{attr}/{k}'] = res_acc
        # print(k, "(%2d%%)"%(100*len(ids)//len(dataloader.dataset)), "%5.2f"%(100 * res_acc))

100%|█████████████████████████████████████████| 729/729 [00:09<00:00, 74.57it/s]


In [74]:
def compute_max_dep(heads):
    """Return the longest distance between a position and its head.

    `heads[i]` is the index of the head of position `i`; entries equal to
    -1 are skipped. Returns 0 when there is nothing to measure.
    """
    longest = 0
    for position, parent in enumerate(heads):
        if parent == -1:
            continue
        longest = max(longest, abs(position - parent))
    return longest

from functools import lru_cache
def compute_tree_depth(head):
    """Return the depth of the tree described by a head (parent-index) list.

    `head[i]` is the index of the parent of node `i`; a node whose entry is
    -1 is a root and has depth 1. Every other node has the depth of its
    parent plus one. The result is the largest depth of any node.

    Returns 0 for an empty list, consistent with `compute_max_dep`.
    Depths are computed iteratively, so long parent chains do not hit the
    interpreter's recursion limit.

    Raises:
        ValueError: if following parents from some node never reaches a root.
    """
    n_nodes = len(head)
    depths = {}  # node index -> depth, filled in as chains are resolved
    for start in range(n_nodes):
        # Climb towards the root until reaching a node of known depth.
        path = []
        node = start
        while node not in depths:
            parent = int(head[node])
            if parent == -1:
                depths[node] = 1
                break
            path.append(node)
            if len(path) > n_nodes:
                # More steps than nodes means a node was revisited.
                raise ValueError('head contains a cycle; expected a tree')
            node = parent
        # Walk back down, assigning depths to everything visited on the way.
        base = depths[node]
        for steps_below, visited in enumerate(reversed(path), start=1):
            depths[visited] = base + steps_below
    return max(depths.values(), default=0)

lps = '('
rps = ')'
def compute_ps_depth(expr):
    """Return the maximum parenthesis nesting level reached in `expr`.

    `expr` is iterated symbol by symbol: each `lps` opens a level, each
    `rps` closes one, and every other symbol is ignored. The result is the
    highest level seen, and 0 if no parenthesis is ever opened.
    """
    deepest = 0
    level = 0
    for symbol in expr:
        if symbol == lps:
            level += 1
            # A new maximum can only appear right after opening a level.
            deepest = max(deepest, level)
        elif symbol == rps:
            level -= 1
    return deepest

def compute_n_op(expr):
    """Count the symbols of `expr` that are members of `OPERATORS`."""
    return sum(1 for symbol in expr if symbol in OPERATORS)

# Per-sample attributes to collect into the analysis table, in column order.
tracked_attrs = [
    'expr',
    'eval',
    'length',
    'tree_depth',
    'ps_depth',
    'max_dep',
    'n_op',
    'result',
    'max_itm',
]
# Column name -> per-sample values; 'pred' marks whether the predicted
# result equals the ground-truth result.
attr2data = {'pred': res_pred_all == res_all}

In [75]:
def _max_intermediate(sample):
    """Largest value left in `res_all` after removing one occurrence of the
    final result `res`; 0 when nothing is left."""
    # Work on a shallow copy so the sample's own sequence is left untouched.
    values = sample['res_all'][:]
    values.remove(sample['res'])
    return max(values) if values else 0


# One extractor per supported attribute: takes a dataset sample and returns
# the value stored in that attribute's column.
attr_extractors = {
    'length': lambda sample: len(sample['expr']),
    'expr': lambda sample: sample['expr'],
    'result': lambda sample: sample['res'],
    'max_itm': _max_intermediate,
    'tree_depth': lambda sample: compute_tree_depth(sample['head']),
    'max_dep': lambda sample: compute_max_dep(sample['head']),
    'ps_depth': lambda sample: compute_ps_depth(sample['expr']),
    'n_op': lambda sample: compute_n_op(sample['expr']),
    'eval': lambda sample: sample['eval'],
}

# Columns already present in attr2data (e.g. from an earlier run of this
# cell) are kept as they are; only the missing ones are computed.
pending_attrs = list(dict.fromkeys(
    attr for attr in tracked_attrs if attr not in attr2data
))
unsupported_attrs = [attr for attr in pending_attrs if attr not in attr_extractors]
if unsupported_attrs:
    # Fail with a clear message before doing any work.
    raise ValueError(f'No extractor defined for tracked attributes: {unsupported_attrs}')

if pending_attrs:
    # Single pass over the dataset: every sample is fetched once and all
    # pending attributes are read from it, instead of one pass per attribute.
    columns = {attr: [] for attr in pending_attrs}
    for sample in dataset:
        for attr in pending_attrs:
            columns[attr].append(attr_extractors[attr](sample))
    # Insert in tracked_attrs order so the resulting column order is stable.
    attr2data.update(columns)

In [76]:
# One row per sample, one column per entry of attr2data.
df = pd.DataFrame.from_dict(attr2data)
df

Unnamed: 0,pred,expr,eval,length,tree_depth,ps_depth,max_dep,n_op,result,max_itm
0,True,7,I,1,1,0,0,0,7,0
1,True,1,I,1,1,0,0,0,1,0
2,True,0,I,1,1,0,0,0,0,0
3,True,6,I,1,1,0,0,0,6,0
4,True,4,I,1,1,0,0,0,4,0
...,...,...,...,...,...,...,...,...,...,...
46615,False,(4+1)/8/(4/((2+4*(8*2))*(3+(9+3)-(2-(6-(8+2)*4...,LL,67,10,6,44,20,3,858
46616,False,(1*(6/5)/(8*2*((4-8/5)/(0+2)))+(9+4/(3*4))*9)*...,LL,61,8,4,30,20,1638,91
46617,True,(4-(0+1-(1+4)-(9-(8+2-8))-4)*(5/3*(5*6/(6/4+9)...,LL,61,9,4,47,20,1,2160
46618,False,6*4*7/((7+8-3+7)*(8/7)/(3/5+1*0))+(2-5-5)+((4-...,LL,55,9,2,28,20,5,168


In [54]:
# Accuracy on the rows whose 'eval' is 'SL': first all of them, then split
# into results of at most 100 and results above 100.
sl_rows = df['eval'] == 'SL'
subset_masks = (
    sl_rows,
    sl_rows & (df['result'] <= 100),
    sl_rows & (df['result'] > 100),
)
for subset_mask in subset_masks:
    df_filter = df[subset_mask]
    print(df_filter['pred'].mean())

0.11423611111111111
0.669606512890095
0.0


In [64]:
# Inspect 'SL' rows with results of at most 100, split by whether the
# prediction was correct.
sl_small = (df['eval'] == 'SL') & (df['result'] <= 100)


def _peek(frame, n=20):
    """Return up to `n` rows of `frame`, sampled reproducibly.

    The sample size is capped at the frame's length so small subsets do not
    raise, and the seed comes from `args.seed` so a re-run shows the same rows.
    """
    return frame.sample(n=min(n, len(frame)), random_state=args.seed)


# Correctly predicted rows: the whole subset, then a random peek.
df_filter = df[sl_small & df['pred']]
display(df_filter)
display(_peek(df_filter))

# Incorrectly predicted rows: random peek only.
df_filter = df[sl_small & ~df['pred']]
display(_peek(df_filter))

Unnamed: 0,expr,eval,length,tree_depth,ps_depth,max_dep,n_op,result,max_res,pred
4638,1/(8*(4*8)),SL,11,4,2,6,3,1,256,True
4657,1/(3*8*9),SL,9,4,1,5,3,1,216,True
4661,3/(5*4*8),SL,9,4,1,5,3,1,160,True
4713,3/(8*(7+7)),SL,11,4,2,6,3,1,112,True
4714,7/(7*(3*9)),SL,11,4,2,6,3,1,189,True
...,...,...,...,...,...,...,...,...,...,...
26583,(9/(2*8-8)-6*7*(7/3))*7+8*6,SL,27,7,2,11,10,48,126,True
26585,5*5*9*(0*8)*3+(4/(8+8)+2*6),SL,27,6,2,9,10,13,225,True
26599,((9-0)/(3*5/9)+1*6)/(8*3*7*4),SL,29,6,2,14,10,1,672,True
26600,5/(6*(9*4-(5-6+6)*(8/5-(4+3)))),SL,31,7,4,26,10,1,216,True


Unnamed: 0,expr,eval,length,tree_depth,ps_depth,max_dep,n_op,result,max_res,pred
23195,0+(4/1+9/9)/((8*4-7)*(1*5)),SL,27,6,2,10,9,1,125,True
8168,9/(8*(3*5)/3),SL,13,5,2,9,4,1,120,True
17249,7*6-(3+5/(7*(8*5)))/6,SL,21,7,3,16,7,41,280,True
26059,9+(9-0-1+5*((0-0)*(7*(8*2)-3))),SL,31,8,4,22,10,17,112,True
23586,1*9+3-4/5/5*0/(7*9*3),SL,21,6,1,8,9,12,189,True
5383,7*4*4/8,SL,7,4,0,2,3,14,112,True
20273,(1+7*6)/((6+5)*2/(1/5)*7),SL,25,6,2,15,8,1,154,True
19635,5+1+(0-(0+9*7*(2+8/1))),SL,23,7,3,16,8,6,630,True
17563,((5+7)/(5*(4+3)*3)+1)/6,SL,23,7,3,18,7,1,105,True
20600,5*(3/2*8/((4+8)*9-6)/4),SL,23,7,3,19,8,5,108,True


Unnamed: 0,expr,eval,length,tree_depth,ps_depth,max_dep,n_op,result,max_res,pred
14572,5*(8*6)/4/(0+0+1),SL,17,5,1,6,6,60,240,False
17146,(4+6+(6+9)*7/4*5)/8,SL,19,7,2,13,7,19,145,False
26406,(9-0)*(9+2+4*4*3)/((9-4+8)*6),SL,29,6,2,12,10,7,531,False
7700,7*(7*3)/(2*6),SL,13,4,1,6,4,13,147,False
23254,(4+(0-0)+0/2)*((4+7)*4)/(6*7),SL,29,6,2,10,9,5,176,False
4823,9*(7*3)/4,SL,9,4,1,6,3,48,189,False
17131,5+6*(4*5)/3+9/(8/4),SL,19,6,1,10,7,50,120,False
17438,4*((6+7)*4+5)/(9/7)/5,SL,21,7,2,12,7,23,228,False
26569,5*9*3/(9*4-1/2-(8-8)*(3-0)),SL,27,5,2,12,10,4,135,False
14472,5*8*(2+(8-4))/(2*2),SL,19,5,2,10,6,60,240,False
