In [None]:
import sys
import re
import torch
import torchtext
import os
import random
from pprint import pprint
from collections import Counter
import pandas as pd

sys.path.insert(0, '../preprocess/')
sys.path.insert(0, '../../coarse2fine.git/src')

from sketch_generation import Sketch
from tree import SketchRepresentation
import table
import tree

In [None]:
# Special vocabulary tokens paired with their fixed ids.
# Each id equals the token's position in SPECIAL_TOKEN_LIST below.
UNK_WORD, UNK = '<unk>', 0      # unknown token
PAD_WORD, PAD = '<blank>', 1    # padding token
BOS_WORD, BOS = '<s>', 2        # beginning of sequence
EOS_WORD, EOS = '</s>', 3       # end of sequence
SKP_WORD, SKP = '<sk>', 4       # skip token
RIG_WORD, RIG = '<]>', 5        # right-bracket token
LFT_WORD, LFT = '<[>', 6        # left-bracket token

SPECIAL_TOKEN_LIST = [
    UNK_WORD,
    PAD_WORD,
    BOS_WORD,
    EOS_WORD,
    SKP_WORD,
    RIG_WORD,
    LFT_WORD,
]

## Explore TableDataset fields

In [None]:
def print_field(data, field):
    """Print a field name, its space-joined value from ``data``, and a separator line.

    Args:
        data: any object; the value is read with ``getattr(data, field)``.
        field: attribute name to display.

    Prints "N/A" instead of the value when ``data`` has no such attribute
    (AttributeError) or the value is not an iterable of strings (TypeError).
    Other exceptions, including KeyboardInterrupt, propagate instead of being
    silently swallowed.
    """
    print(field)
    try:
        print(" ".join(getattr(data, field)))
    except (AttributeError, TypeError):
        print("N/A")
    print("-" * 64)

In [None]:
# Name of the preprocessed dataset directory under coarse2fine's data_model/.
# Must be a string literal: the bare name `conala` raised NameError on a fresh kernel.
dataset = 'conala'

file = '../../coarse2fine.git/data_model/%s/train.pt' % dataset
# Sorted names of all TableDataset fields, used for exploration below.
fields = sorted(list(table.IO.TableDataset.get_fields().keys()))

# NOTE: torch.load unpickles the file -- only load trusted, locally produced data.
data = torch.load(file)
# Re-attach field definitions to the loaded dataset.
data.fields = table.IO.TableDataset.get_fields()

print(len(data))

In [None]:
# Pick a random training example whose layout contains at least one FUNC#k
# placeholder (k in 0..9), and show its source, target and layout tokens.
ex = [example for example in data.examples
      if any('FUNC#%d' % j in example.lay for j in range(10))]
assert ex, 'no training example has a FUNC#k token in its layout'

# random.randrange excludes the upper bound; random.randint(0, len(ex)) could
# return len(ex) and raise IndexError.
i = random.randrange(len(ex))

print(i)
print_field(ex[i], 'src')
print_field(ex[i], 'tgt_loss')
print_field(ex[i], 'lay')

### Their sketch (coarse2fine's `tree.SketchRepresentation`)

In [None]:
# Read the raw JSON-lines training file, then turn it into a plain list of
# per-row Series so later cells can index rows as df[i]['src'], df[i]['token'], ...
file = '../../coarse2fine.git/data_model/%s/train.json' % dataset
train_frame = pd.read_json(file, lines=True)
df = [train_frame.iloc[row_idx] for row_idx in range(len(train_frame))]

In [None]:
# Pick a random raw example, rebuild coarse2fine's sketch representation for it,
# and check which target tokens also appear in the source.
# random.randrange excludes the upper bound; random.randint(0, len(df)) could
# return len(df) and raise IndexError. Set a fixed i (e.g. i = 4654) to reproduce.
i = random.randrange(len(df))

print(i)
print("src     ", df[i]['src'])
print("token   ", df[i]['token'])

t = tree.SketchRepresentation((df[i]['token'], df[i]['type']))

lay = t.layout(add_skip=False)
lay_skip = t.layout(add_skip=True)
tgt = t.target()

print("lay     ", lay)
print("lay_skip", lay_skip)
print("tgt     ", tgt)

print("\n", "-"*64, "\n")

# Target tokens present in the source keep their text; others become UNK_WORD.
r_list = []
src_set = set(df[i]['src'])

for tk_tgt in tgt:
    if tk_tgt in src_set:
        print("ok\t", tk_tgt)
        r_list.append(tk_tgt)
    else:
        print("unk\t", tk_tgt)
        r_list.append(UNK_WORD)

# Compare the 'copy_to_tgt' and 'src' annotations read from the same rows.
cttgt = table.IO.TableDataset._read_annotated_file(None, None, df, 'copy_to_tgt', None)
src = table.IO.TableDataset._read_annotated_file(None, None, df, 'src', None)
list(cttgt) == list(src)

### Vocab

In [None]:
# Count source-token frequencies over the training set and show the 20 most common.
counter = Counter()
sources = [data.src]

# Use dedicated loop names: the previous `for data in sources` rebound the
# global `data` (the loaded TableDataset), so re-running this cell failed.
for source in sources:
    for tokens in source:
        counter.update(tokens)

counter.most_common(20)

In [None]:
def get_parent_index(tk_list):
    """Compute a parent index for every token of a bracketed token sequence.

    A token's parent is the 1-based position of the most recent still-open token
    that starts with '(' (0 at the top level). A '(' token itself belongs to the
    enclosing scope; the matching ')' belongs to the scope it closes. One extra
    trailing 0 is appended for the EOS token (</s>).

    Example:
        get_parent_index('x = func ( a + func2 ( x + y ) )'.split())
        -> [0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 4, 0]

    Args:
        tk_list: list of string tokens.

    Returns:
        list of ints of length len(tk_list) + 1.
    """
    stack = [0]
    r_list = []

    for i, tk in enumerate(tk_list):
        r_list.append(stack[-1])

        if tk.startswith('('):
            # +1: because the parent of the top level is 0
            stack.append(i + 1)
        elif tk == ')':
            stack.pop()

    # for EOS (</s>)
    r_list.append(0)
    return r_list


# get_parent_index('x = func ( a + func2 ( x + y ) )'.split())

In [None]:
def get_lay_index(lay_skip):
    """Map each layout token to its position among the non-skipped tokens.

    The result starts with an extra 0 for the leading <s> token. Tokens equal to
    SKP_WORD or RIG_WORD map to 0; every other token receives the next
    consecutive index, counting from 0.
    """
    indices = [0]  # slot for the leading <s> token
    next_idx = 0
    for token in lay_skip:
        if token not in (SKP_WORD, RIG_WORD):
            indices.append(next_idx)
            next_idx += 1
            continue
        indices.append(0)
    return indices


get_lay_index('NAME = FUNC#1 ( NAME )'.split())

In [None]:
# Load the saved vocabularies and report the size of the 'tgt' field's vocabulary.
# NOTE: torch.load unpickles the file -- only load trusted, locally produced data.
file = '../../coarse2fine.git/data_model/%s/vocab.pt' % dataset
fields = table.IO.TableDataset.load_fields(torch.load(file))

# The number of stoi entries is the vocabulary size; sorting the keys first was unnecessary.
len(fields['tgt'].vocab.stoi)

## Model

In [None]:
# Load the training split with its saved vocabularies, then print every field of
# one randomly chosen single-example batch, decoded back to tokens.
train_data = torch.load('../../coarse2fine.git/data_model/%s/train.pt' % dataset)
fields = table.IO.TableDataset.load_fields(torch.load('../../coarse2fine.git/data_model/%s/vocab.pt' % dataset))
# Keep only the fields that the stored examples actually carry.
train_data.fields = {k: f for k, f in fields.items() if k in train_data.examples[0].__dict__}
train_iter = table.IO.OrderedIterator(train_data, 1, device=-1, repeat=False)

fields = train_data.fields
# randrange excludes the upper bound; randint(0, len(train_data)) could pick an
# index past the last batch, so the loop would silently stop at the final batch.
_i = random.randrange(len(train_data))

for i, batch in enumerate(train_iter):
    if i == _i:
        break

print(batch.indices)


def _print_decoded(vocab, rows):
    """Print each row's first token id decoded through ``vocab.itos``, then a blank line."""
    for row in rows:
        print(vocab.itos[row[0]], end=" ")
    print('\n')


# batch.src / batch.lay are indexed with [0] because they are (data, lengths) pairs.
print('> src')
_print_decoded(fields['src'].vocab, batch.src[0].data)

print('> lay')
_print_decoded(fields['lay'].vocab, batch.lay[0].data)

# print(batch.copy_to_ext)

attr = sorted([x for x in dir(batch) if x[:2] != '__' and x not in ['src', 'lay']])

# Decode every other batch attribute that has tensor data and a vocab-bearing field.
for a in attr:
    print('>', a)
    value = getattr(batch, a)
    if hasattr(value, 'data') and a in fields and hasattr(fields[a], 'vocab'):
        _print_decoded(fields[a].vocab, value.data)
    else:
        print('\n')

In [None]:
# Sanity check: FUNC#0..FUNC#9 must be present in the layout vocabularies
# ('lay', 'lay_e') and absent from every other field's vocabulary.
for a in attr:
    if a not in fields.keys() or not hasattr(fields[a], 'vocab'):
        continue
    vocab_keys = fields[a].vocab.stoi.keys()
    print(a)
    expect_present = a in ['lay', 'lay_e']
    for i in range(10):
        func_tok = 'FUNC#%d' % i
        if expect_present:
            assert func_tok in vocab_keys, str(i)
        else:
            assert func_tok not in vocab_keys, str(i)
                
# pred   <s> t1 t2 t2 t5 </s>
# target <s> t1 t2 t3 t4 </s>
# mask    1   0  0  0  0  1
# p=t     1   1  1  0  0  1

In [None]:
# Toy check of masked token accuracy: positions where mask == 1 (the first and
# last slots here) are excluded, and correctness is measured on the rest.
mask = torch.ByteTensor([1,0,0,0,0,0,1])
pred =     torch.Tensor([0,1,2,3,3,4,4])
goal =     torch.Tensor([0,1,1,2,3,4,5])

keep = mask.ne(1)                        # 0 1 1 1 1 1 0
x = pred.eq(goal).masked_select(keep)    # correctness at the kept positions
y = keep.sum()                           # number of kept positions

print(x)
# pred.eq(goal) = 1 1 0 0 1 1 0  ->  x = 1 0 0 1 1  (3 correct out of y = 5)
print(y)