In [23]:
import sys
import os
root_dir = '/home/qing/Desktop/Closed-Loop-Learning/HINT/'
os.chdir(root_dir)
from train import *
from data.domain import OPERATORS
import pandas as pd 
from collections import OrderedDict
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

from dataset import HINT_collate

In [13]:
sys.argv = []
args = parse_args()
args.input = 'symbol'
args.model = 'TRAN.relative_universal'
args.resume = f'./models/{args.input}.{args.model}/model_100000.p'
args.nhead = 8
args.enc_layers = 6
args.dec_layers = 1
args.hid_dim = 512
args.max_rel_pos = 15
args

Namespace(batch_size=128, cos_sim_margin=0.2, curriculum='no', dec_layers=1, dropout=0.5, early_stop=None, emb_dim=128, enc_layers=6, epochs=10, epochs_eval=1, fewshot=None, grad_clip=5.0, hid_dim=512, input='symbol', iterations=None, iterations_eval=None, layers=1, lr=0.001, lr_scheduler='constant', main_dataset_ratio=0, max_op_train=None, max_rel_pos=15, model='TRAN.relative_universal', nhead=8, output_attentions=False, output_dir='outputs/', perception_pretrain='data/perception_pretrain/model.pth.tar_78.2_match', pos_emb_type='sin', result_encoding='decimal', resume='./models/symbol.TRAN.relative_universal/model_100000.p', save_model=False, seed=0, train_size=None, wandb='HINT', warmup_steps=100)

In [16]:
train_set = HINT('train', input=args.input, fewshot=args.fewshot, 
                    n_sample=args.train_size, max_op=args.max_op_train,
                    main_dataset_ratio=args.main_dataset_ratio)
val_set = HINT('val', input=args.input, fewshot=args.fewshot)
test_set = HINT('test', input=args.input, fewshot=args.fewshot)
print('train:', len(train_set), 'val:', len(val_set), 'test:', len(test_set))

args.train_set = train_set
args.val_set = val_set
args.test_set = test_set

train: 998000 val: 4698 test: 46620


In [17]:
args.res_enc = ResultEncoding(args.result_encoding)
model = make_model(args)
if args.resume:
    print('Load checkpoint from ' + args.resume)
    ckpt = torch.load(args.resume)
    model.load_state_dict(ckpt['model_state_dict'])
model.to(DEVICE)

n_params = sum(p.numel() for p in model.parameters())
print('Num params:', n_params)

Load checkpoint from ./models/symbol.TRAN.relative_universal/model_100000.p
Num params: 8019469


In [18]:
dataset = args.test_set
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                            shuffle=False, num_workers=4, collate_fn=HINT_collate)

model.eval() 
res_all = []
res_pred_all = []

expr_all = []
expr_pred_all = []

dep_all = []
dep_pred_all = []

metrics = OrderedDict()

with torch.no_grad():
    for sample in tqdm(dataloader):
        if args.input == 'image':
            src = sample['img_seq']
        elif args.input == 'symbol':
            src = torch.tensor([x for s in sample['sentence'] for x in s])
        res = sample['res']
        if args.result_encoding == 'sin':
            tgt = res.unsqueeze(1)
        else:
            tgt = torch.tensor(args.res_enc.res2seq_batch(res.numpy()))
        expr = sample['expr']
        dep = sample['head']
        src_len = sample['len']

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        output = model(src, tgt[:, :-1], src_len)
        pred = torch.argmax(output, -1).detach().cpu().numpy()
        if args.result_encoding == 'sin':
            res_pred = pred
        else:
            res_pred = args.res_enc.seq2res_batch(pred)
        res_pred_all.append(res_pred)
        res_all.append(res)

        # expr_pred_all.extend(expr_preds)
        expr_all.extend(expr)
        # dep_pred_all.extend(dep_preds)
        dep_all.extend(dep)

res_pred_all = np.concatenate(res_pred_all, axis=0)
res_all = np.concatenate(res_all, axis=0)
result_acc = (res_pred_all == res_all).mean()
metrics['result_acc/avg'] = result_acc

tracked_attrs = ['length', 'symbol', 'digit', 'result', 'eval', 'tree_depth', 'ps_depth', 'max_dep']
for attr in tracked_attrs:
    # print(f"result accuracy by {attr}:")
    attr2ids = getattr(dataloader.dataset, f'{attr}2ids')
    for k, ids in sorted(attr2ids.items()):
        res = res_all[ids]
        res_pred = res_pred_all[ids]
        res_acc = (res == res_pred).mean() if ids else 0.
        k = 'div' if k == '/' else k
        metrics[f'result_acc/{attr}/{k}'] = res_acc
        # print(k, "(%2d%%)"%(100*len(ids)//len(dataloader.dataset)), "%5.2f"%(100 * res_acc))

100%|█████████████████████████████████████████| 729/729 [00:22<00:00, 31.70it/s]


In [76]:
model.model.output_attentions = True
sample = dataset[3000]
print(sample['expr'], sample['res'])
sample = HINT_collate([sample])
if args.input == 'image':
    src = sample['img_seq']
elif args.input == 'symbol':
    src = torch.tensor([x for s in sample['sentence'] for x in s])
res = sample['res']
if args.result_encoding == 'sin':
    tgt = res.unsqueeze(1)
else:
    tgt = torch.tensor(args.res_enc.res2seq_batch(res.numpy()))
expr = sample['expr']
dep = sample['head']
src_len = sample['len']

src = src.to(DEVICE)
tgt = tgt.to(DEVICE)

output, attentions = model(src, tgt[:, :-1], src_len)
pred = torch.argmax(output, -1).detach().cpu().numpy()
if args.result_encoding == 'sin':
    res_pred = pred
else:
    res_pred = args.res_enc.seq2res_batch(pred)
print('pred: ', res_pred[0])

(0/8+4)*6 24
pred:  24


In [84]:
from bertviz import model_view, head_view
# utils.logging.set_verbosity_error()  # Remove line to see warnings

encoder_text = ['<SOS>'] + list(expr[0]) + ['<EOS>']
decoder_text = [args.res_enc.vocab[x] for x in tgt[0, :-1]]

encoder_attentions, decoder_attentions, cross_attentions = attentions
# model view require enc_layers == dec_layers, so we just duplicate
if len(decoder_attentions) < len(encoder_attentions):
    decoder_attentions = decoder_attentions * len(encoder_attentions)
    cross_attentions = cross_attentions * len(encoder_attentions)

head_view(
    encoder_attention=encoder_attentions,
    decoder_attention=decoder_attentions,
    cross_attention=cross_attentions,
    encoder_tokens= encoder_text,
    decoder_tokens=decoder_text,
)

model_view(
    encoder_attention=encoder_attentions,
    decoder_attention=decoder_attentions,
    cross_attention=cross_attentions,
    encoder_tokens= encoder_text,
    decoder_tokens=decoder_text,
)
    

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
html_head_view = head_view(
    encoder_attention=encoder_attentions,
    decoder_attention=decoder_attentions,
    cross_attention=cross_attentions,
    encoder_tokens= encoder_text,
    decoder_tokens=decoder_text,
    html_action='return'
)

html_model_view = model_view(
    encoder_attention=encoder_attentions,
    decoder_attention=decoder_attentions,
    cross_attention=cross_attentions,
    encoder_tokens= encoder_text,
    decoder_tokens=decoder_text,
    html_action='return'
)

with open("results/vis_head_view.html", 'w') as file:
    file.write(html_head_view.data)
    
with open("results/vis_model_view.html", 'w') as file:
    file.write(html_model_view.data)

In [8]:
def compute_max_dep(heads):
    return max([0] + [abs(i-h) for i, h in enumerate(heads) if h != -1])

from functools import lru_cache
def compute_tree_depth(head):
    @lru_cache
    def depth(i):
        """The depth of node i."""
        if head[i] == -1:
            return 1
        return depth(head[i]) + 1

    return max(depth(i) for i in range(len(head)))

lps = '('
rps = ')'
def compute_ps_depth(expr):
    depth = 0
    max_depth = 0
    for x in expr:
        if x == lps:
            c = 1
        elif x == rps:
            c = -1
        else:
            c = 0
        depth += c
        if depth > max_depth:
            max_depth = depth
    return max_depth

def compute_n_op(expr):
    return len([1 for x in expr if x in OPERATORS])

tracked_attrs = ['eval', 'length', 'tree_depth', 'ps_depth', 'max_dep', 'n_op', 'result']
attr2data = {}
for attr in tracked_attrs:
    data = []
    for sample in dataset:
        if attr == 'length':
            d = len(sample['expr'])
        elif attr == 'result':
            d = sample['res']
        elif attr == 'tree_depth':
            d = compute_tree_depth(sample['head'])
        elif attr == 'max_dep':
            d = compute_max_dep(sample['head'])
        elif attr == 'ps_depth':
            d = compute_ps_depth(sample['expr'])
        elif attr == 'n_op':
            d = compute_n_op(sample['expr'])
        elif attr == 'eval':
            d = sample['eval']
        else:
            assert False
        data.append(d)
    attr2data[attr] = data
attr2data['pred'] = res_pred_all == res_all

In [9]:
df = pd.DataFrame(attr2data)
df

Unnamed: 0,eval,length,tree_depth,ps_depth,max_dep,n_op,result,pred
0,I,1,1,0,0,0,7,True
1,I,1,1,0,0,0,1,True
2,I,1,1,0,0,0,0,True
3,I,1,1,0,0,0,6,True
4,I,1,1,0,0,0,4,True
...,...,...,...,...,...,...,...,...
46615,LL,67,10,6,44,20,3,False
46616,LL,61,8,4,30,20,1638,False
46617,LL,61,9,4,47,20,1,True
46618,LL,55,9,2,28,20,5,False


In [120]:
df_filter = df[df['eval'] != 'I']
df_filter.groupby(['length', 'n_op', 'result']).get_group((35, 15, 10)).groupby(['max_dep', 'tree_depth'])['pred'].mean()

KeyError: (35, 15, 10)

In [7]:
error_ids = [i for i, (x, y) in enumerate(zip(res_all, res_pred_all)) if x != y ]

In [21]:
symbol_images_dir = root_dir + 'data/symbol_images/'
def render_img(img_paths):
    images = [Image.open(symbol_images_dir + x) for x in img_paths]
    widths, heights = zip(*(i.size for i in images))

    total_width = sum(widths)
    max_height = max(heights)

    new_im = Image.new('L', (total_width, max_height))

    x_offset = 0
    for im in images:
        new_im.paste(im, (x_offset,0))
        x_offset += im.size[0]
    return new_im

def show_sample(sample, show_image=True):
    print(sample['expr'], len(sample['expr']))
    if show_image:
        img = render_img(sample['img_paths'])
        display(img)

In [25]:
for i in range(len(dataset)):
    sample = dataset.dataset[i]
#     if sample['eval'] not in ['LL']: continue
    if sample['res'] < 100: continue
    print(sample['expr'], len(sample['expr']), max(sample['res_all']))
    print(res_all[i], res_pred_all[i])
    input()

8*(8*9) 7 576
576 64


 


6*(9+9) 7 108
108 98


 


6*(9+8) 7 102
102 92


 


9*(2*7) 7 126
126 36


 


8*8*2 5 128
128 64


 


(9+5)*8 7 112
112 96


 


8*7*6 5 336
336 56


 


9*4*9 5 324
324 36


 


9*(5*6) 7 270
270 90


 


8*3*6 5 144
144 84


 


8*(6+8) 7 112
112 96


 


7*5*7 5 245
245 35


 


7*5*8 5 280
280 70


 


(8+9)*9 7 153
153 63


 


7*(3*5) 7 105
105 95


 


2*6*9 5 108
108 78


 


9*4*6 5 216
216 96


 


5*8*9 5 360
360 90


 


7*2*8 5 112
112 96


 


9*8*8 5 576
576 64


 


6*(6*8) 7 288
288 48


 


6*(5*5) 7 150
150 90


 


6*7*8 5 336
336 96


 


5*7*9 5 315
315 95


 


(6+9)*8 7 120
120 90


 


6*4*8 5 192
192 96


 


(8+9)*8 7 136
136 96


 


(6+9)*7 7 105
105 95


 


9*(5*3) 7 135
135 15


 


9*(9+9) 7 162
162 72


 


(6+7)*8 7 104
104 94


 


6*9*6 5 324
324 54


 


5*9*8 5 360
360 80


 


8*7*7 5 392
392 56


 


7*(6*8) 7 336
336 56


 


4*(9*4) 7 144
144 36


 


8*(7+8) 7 120
120 90


 


4*(4*7) 7 112
112 96


 


4*4*7 5 112
112 96


 


7*9*7 5 441
441 63


 


6*(6*7) 7 252
252 42


 


9*(4*4) 7 144
144 54


 


7*4*5 5 140
140 100


 


9*7*8 5 504
504 64


 


9*(6*9) 7 486
486 54


 


3*9*9 5 243
243 81


 


7*(7*9) 7 441
441 63


 


9*4*7 5 252
252 72


 


6*(9*6) 7 324
324 54


 


6*(9*3) 7 162
162 72


 


8*(7*7) 7 392
392 96


 


8*(9*5) 7 360
360 90


 


(9+9)*9 7 162
162 72


 


9*(7*5) 7 315
315 35


 


7*5*3 5 105
105 95


 


4*5*6 5 120
120 90


 


8*(8+9) 7 136
136 56


 


4*9*6 5 216
216 96


 


3*8*9 5 216
216 36


 


6*4*7 5 168
168 98


 


8*(6+9) 7 120
120 90


 


9*(4*6) 7 216
216 36


 


9*(6*8) 7 432
432 72


 


3*9*8 5 216
216 64


 


6*(7*8) 7 336
336 96


 


6*(5*7) 7 210
210 90


 


8*(7*4) 7 224
224 56


 


8*8*7 5 448
448 56


 


8*(9+8) 7 136
136 56


 


8*(8+5) 7 104
104 94


 


4*8*8 5 256
256 96


 


(8+4)*9 7 108
108 72


 


9*(7*4) 7 252
252 72


 


9*(6*5) 7 270
270 90


KeyboardInterrupt: Interrupted by user