## TransNeo/AlphaNeo Ranker
- Tells which pair of tables is functionally/logically closer on the execution chain
- Stage: Cambrian
- Version: Spriggina

In [1]:
import logging 
# Configure the root logger so that only CRITICAL records are emitted,
# hiding lower-level log output in this notebook.
logging.basicConfig(level=logging.CRITICAL)

In [2]:
import os
import itertools
import copy
import random
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# from tensorboardX import SummaryWriter

# Detected once: decides whether the model and batches are placed on the GPU.
use_cuda = bool(torch.cuda.is_available())
print("use_cuda:", use_cuda)

use_cuda: True


In [3]:
import tyrell.spec as S
from tyrell.decider import Example

# Morpheus Version
from MorpheusInterpreter import *
from ProgramSpace import *

In [4]:
# Notebook display: show the installed PyTorch version.
torch.__version__

'1.0.0'

In [5]:
class RankerDataset(Dataset):
    """Dataset of (input table, output table) abstraction maps for the ranker.

    Each item is a positive pair ``(left, right)`` taken from one example of
    the dataset, plus ``n_neg`` freshly generated random tables used as
    negatives.

    Args:
        p_config: config dict; reads ``p_config["ranker"]["n_neg"]``,
            ``["n_row"]`` and ``["n_col"]``.
        p_dataset: dict mapping a key to a list of entries; element ``[1]``
            of each entry is an example exposing ``.input`` and ``.output``.
        p_interpreter: interpreter providing ``CAMB_DICT``,
            ``load_data_into_var``, ``camb_get_abs`` and ``random_table``.
    """

    def __init__(self, p_config=None, p_dataset=None, p_interpreter=None):
        self.interpreter = p_interpreter
        self.dataset = p_dataset
        self.config = p_config

        # Flatten the dataset into a plain list of examples; the program that
        # produced each example is not needed here.
        self.str_examples = [
            entry[1]
            for entries in self.dataset.values()
            for entry in entries
        ]
        self.n_exp = len(self.str_examples)
        self.n_neg = self.config["ranker"]["n_neg"]
        self.n_row = self.config["ranker"]["n_row"]
        self.n_col = self.config["ranker"]["n_col"]

        # == only works for chain: 1 input, 1 output ==
        # Pre-compute the abstraction maps of every input/output table.
        print("# Parsing Dataset...")
        self.LMTX = self._pad_matrix(self.n_exp)
        self.RMTX = self._pad_matrix(self.n_exp)
        for i, d_example in enumerate(self.str_examples):
            print("\r## in:{}/{}".format(i + 1, self.n_exp), end="")
            var_input = self.interpreter.load_data_into_var(d_example.input[0])  # == chain ==
            var_output = self.interpreter.load_data_into_var(d_example.output)
            self.LMTX[i, :, :] = self.interpreter.camb_get_abs(var_input)
            self.RMTX[i, :, :] = self.interpreter.camb_get_abs(var_output)
        print()

    def _pad_matrix(self, n):
        """Return an array of shape (n, n_row, n_col) filled with the <PAD> index."""
        # Explicit int64: ``dtype=int`` is int32 on some platforms (e.g.
        # Windows), while nn.Embedding requires int64 (Long) indices.
        return np.full(
            (n, self.n_row, self.n_col),
            self.interpreter.CAMB_DICT["<PAD>"],
            dtype=np.int64,
        )

    def __len__(self):
        return self.n_exp

    def __getitem__(self, p_ind):
        """Return ``(left, right, negatives)`` for example ``p_ind``.

        ``left``/``right`` have shape (n_row, n_col); ``negatives`` has shape
        (n_neg, n_row, n_col) and is regenerated randomly on every call.
        Always use ``batch_size=1`` so that each batch keeps exactly one
        positive pair per ``n_neg`` negatives.
        """
        tmpWMTX = self._pad_matrix(self.n_neg)
        for i in range(self.n_neg):
            var_table = self.interpreter.random_table()
            tmpWMTX[i, :, :] = self.interpreter.camb_get_abs(var_table)

        return (
            self.LMTX[p_ind, :, :],  # (map_r, map_c)
            self.RMTX[p_ind, :, :],  # (map_r, map_c)
            tmpWMTX,  # (n_neg, map_r, map_c)
        )

In [6]:
class ValueEncoder(nn.Module):
    """Encode a batch of abstraction maps into fixed-size vectors.

    Pipeline: token embedding -> Conv2d -> ReLU -> MaxPool2d -> Linear ->
    sigmoid.

    Config keys read from ``p_config["val"]``: ``vocab_size``, ``embd_dim``,
    ``conv_n_kernels``, ``conv_kernel_size``, ``pool_kernel_size``,
    ``IDX_PAD`` and optionally ``pool_padding`` (default 0). The output width
    is ``p_config["embd_dim"]``.
    """

    def __init__(self, p_config=None):
        super(ValueEncoder, self).__init__()
        self.config = p_config
        val_cfg = self.config["val"]

        self.vocab_size = val_cfg["vocab_size"]
        self.embd_dim = val_cfg["embd_dim"]
        # The <PAD> token index maps to a zero embedding that is not trained.
        self.embedding = nn.Embedding(
            self.vocab_size,
            self.embd_dim,
            padding_idx=val_cfg["IDX_PAD"],
        )

        self.conv = nn.Conv2d(
            in_channels=val_cfg["embd_dim"],
            out_channels=val_cfg["conv_n_kernels"],
            kernel_size=val_cfg["conv_kernel_size"],
        )

        # Pooling padding is a spatial amount and must not be tied to the
        # <PAD> token index (IDX_PAD); a non-zero pad index would otherwise
        # make the pooling layer invalid.
        self.pool = nn.MaxPool2d(
            kernel_size=val_cfg["pool_kernel_size"],
            padding=val_cfg.get("pool_padding", 0),
        )

        self.fc = nn.Linear(
            val_cfg["conv_n_kernels"],
            self.config["embd_dim"],
        )

    def forward(self, bp_map):
        """Encode token-index maps of shape (B, map_r, map_c) into (B, embd_dim).

        Each map is encoded independently (one map per value). The conv and
        pool kernels must reduce every map to a 1x1 spatial output (conv
        width == map_c, pool height == map_r) so that the pooled tensor can
        be viewed as (B, conv_n_kernels).
        """
        B = bp_map.shape[0]

        # (B, map_r, map_c, val_embd_dim) -> (B, val_embd_dim, map_r, map_c)
        d_embd = self.embedding(bp_map).permute(0, 3, 1, 2)

        # (B, n_kernel, map_r, 1)
        d_conv = F.relu(self.conv(d_embd))

        # (B, n_kernel)
        d_pool = self.pool(d_conv).view(B, self.config["val"]["conv_n_kernels"])

        # (B, embd_dim), values in (0, 1)
        return torch.sigmoid(self.fc(d_pool))
    

In [7]:
class Ranker(nn.Module):
    """Score a (left, right) table pair against (left, negative) pairs.

    All tables are encoded by one shared ``ValueEncoder``. The left encoding
    is concatenated with each right/negative encoding, and the result is fed
    through a two-layer MLP that produces 2 logits per pair.
    """

    def __init__(self, p_config=None):
        super(Ranker, self).__init__()
        self.config = p_config
        self.value_encoder = ValueEncoder(p_config=p_config)
        self.fc0 = nn.Linear(
            self.config["embd_dim"] * 2,
            2048,
        )
        self.fc1 = nn.Linear(
            2048,
            2,
        )

    def forward(self, pL, pR, pN):
        """Return raw logits of shape (n_neg + 1, 2).

        Args:
            pL: (B=1, map_r, map_c) left table.
            pR: (B=1, map_r, map_c) table paired with ``pL``.
            pN: (B=1, n_neg, map_r, map_c) negative tables.

        Row 0 scores (pL, pR); rows 1..n_neg score pL against each negative.
        Only B=1 is supported (the left encoding is broadcast over the rows).
        n_neg, map_r and map_c are read from ``pN`` itself, so they do not
        need to match the config.
        """
        vL = self.value_encoder(pL)  # (1, embd_dim)
        vR = self.value_encoder(pR)  # (1, embd_dim)
        # Fold the batch dim into the negatives: (n_neg, map_r, map_c).
        vN = self.value_encoder(
            pN.reshape(-1, pN.shape[-2], pN.shape[-1]),
        )  # (n_neg, embd_dim)

        vRN = torch.cat([vR, vN], dim=0)  # (n_neg+1, embd_dim)
        vLL = vL.expand(vRN.shape[0], -1)  # (n_neg+1, embd_dim)

        vII = torch.cat([vLL, vRN], dim=1)  # (n_neg+1, embd_dim * 2)

        # == Notice: no activation on the last layer; the trainer applies
        # log_softmax to these logits. ==
        return self.fc1(F.relu(self.fc0(vII)))  # (n_neg+1, 2)
        

In [8]:
def RankerTrainer(p_config, p_model, p_ld_train_data, p_optim, p_lossfn):
    """Train ``p_model`` for ``p_config["ranker"]["n_ep"]`` epochs.

    Every batch from ``p_ld_train_data`` is ``(left, right, negatives)``.
    The model must return (1 + n_neg, 2) logits in which row 0 is the
    positive pair (label 1) and the remaining rows are negatives (label 0).
    ``p_lossfn`` receives log-probabilities (``log_softmax`` over dim 1),
    e.g. ``nn.NLLLoss``. The running sum of batch losses for the current
    epoch is printed on a single, continuously overwritten line.
    """
    n_neg = p_config["ranker"]["n_neg"]
    # Move batches to wherever the model's parameters live (CPU if it has none).
    device = next((p.device for p in p_model.parameters()), torch.device("cpu"))
    # Labels are identical for every batch: the positive first, then negatives.
    td_label = torch.tensor([1] + [0] * n_neg, device=device)

    for d_ep in range(p_config["ranker"]["n_ep"]):
        p_model.train()
        epoch_loss = 0.0

        for batch_idx, (d_left, d_right, d_negative) in enumerate(p_ld_train_data):
            td_left = d_left.to(device)  # (B=1, map_r, map_c)
            td_right = d_right.to(device)  # (B=1, map_r, map_c)
            td_negative = d_negative.to(device)  # (B=1, n_neg, map_r, map_c)

            # (n_neg+1, 2)
            d_output = p_model(td_left, td_right, td_negative)
            p_optim.zero_grad()
            d_loss = p_lossfn(
                F.log_softmax(d_output, dim=1),
                td_label,
            )
            epoch_loss += d_loss.item()
            d_loss.backward()
            p_optim.step()

            print("\r# EP:{}, B:{}, ep.loss:{:.2f}".format(
                d_ep, batch_idx, epoch_loss,
            ), end="")

        # end of epoch: move past the carriage-return progress line
        print()

In [9]:
m_interpreter = MorpheusInterpreter()
m_spec = S.parse_file('./example/camb3.tyrell')

m_config = {
    "val":{
        "vocab_size": len(m_interpreter.CAMB_LIST),
        "embd_dim": 16, # embedding dim of CAMB abstract token
        "conv_n_kernels": 512,
        "conv_kernel_size": (1,m_interpreter.CAMB_NCOL), 
        "pool_kernel_size": (m_interpreter.CAMB_NROW,1), 
        "IDX_PAD": 0,
    },
    "embd_dim": 128,
    "ranker":{
        "data_path": "./0716MDsize1.pkl",
        "n_row": m_interpreter.CAMB_NROW,
        "n_col": m_interpreter.CAMB_NCOL,
        "n_neg": 10,
        "n_ep": 1000000,
    },
}

# Per-key subset sizes: after shuffling each key's examples, the first
# N_TRAIN_PER_KEY go to training and the next N_TEST_PER_KEY to testing;
# any remaining examples are unused.
N_TRAIN_PER_KEY = 20
N_TEST_PER_KEY = 20

# load the data and dataset
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(m_config["ranker"]["data_path"], "rb") as f:
    m_data = pickle.load(f)

# re-distribute the dataset into disjoint per-key train/test subsets
m_train_data = {}
m_test_data = {}
for dkey in m_data.keys():
    tmp_list = list(range(len(m_data[dkey])))
    random.shuffle(tmp_list)
    tmp_tr = [m_data[dkey][i] for i in tmp_list[:N_TRAIN_PER_KEY]]
    tmp_te = [
        m_data[dkey][i]
        for i in tmp_list[N_TRAIN_PER_KEY:N_TRAIN_PER_KEY + N_TEST_PER_KEY]
    ]
    m_train_data[dkey] = tmp_tr
    m_test_data[dkey] = tmp_te

dt_train = RankerDataset(
    p_config=m_config,
    p_dataset=m_train_data,
    p_interpreter=m_interpreter,
)
dt_test = RankerDataset(
    p_config=m_config,
    p_dataset=m_test_data,
    p_interpreter=m_interpreter,
)
# batch_size must stay 1: the ranker and trainer assume one positive per batch.
ld_train = DataLoader(dataset=dt_train, batch_size=1, shuffle=True)
ld_test = DataLoader(dataset=dt_test, batch_size=1, shuffle=True)

m_ranker = Ranker(p_config=m_config)
if use_cuda:
    m_ranker = m_ranker.cuda()
optimizer = torch.optim.Adam(m_ranker.parameters())
lossfn = nn.NLLLoss()

# Parsing Dataset...
## in:1565/1566
# Parsing Dataset...
## in:1551/1552


In [10]:
# Start training. m_config["ranker"]["n_ep"] is 1000000, so in practice this
# runs until it is interrupted manually.
RankerTrainer(m_config, m_ranker, ld_train, optimizer, lossfn)

# EP:0, B:1565, ep.loss:338.93
# EP:1, B:1565, ep.loss:216.35
# EP:2, B:1565, ep.loss:178.31
# EP:3, B:1565, ep.loss:154.35
# EP:4, B:1565, ep.loss:125.87
# EP:5, B:1565, ep.loss:95.03
# EP:6, B:1565, ep.loss:75.62
# EP:7, B:1565, ep.loss:59.18
# EP:8, B:1565, ep.loss:45.34
# EP:9, B:1565, ep.loss:45.58
# EP:10, B:1565, ep.loss:42.52
# EP:11, B:1565, ep.loss:38.45
# EP:12, B:1565, ep.loss:32.17
# EP:13, B:1565, ep.loss:34.59
# EP:14, B:1565, ep.loss:29.85
# EP:15, B:1565, ep.loss:27.81
# EP:16, B:1565, ep.loss:23.49
# EP:17, B:1565, ep.loss:26.33
# EP:18, B:1565, ep.loss:23.83
# EP:19, B:1565, ep.loss:19.23
# EP:20, B:1565, ep.loss:21.37
# EP:21, B:1565, ep.loss:18.68
# EP:22, B:1565, ep.loss:20.90
# EP:23, B:1565, ep.loss:19.03
# EP:24, B:1565, ep.loss:18.06
# EP:25, B:1565, ep.loss:21.64
# EP:26, B:1565, ep.loss:14.43
# EP:27, B:1565, ep.loss:14.80
# EP:28, B:1565, ep.loss:15.83
# EP:29, B:1565, ep.loss:17.69
# EP:30, B:1565, ep.loss:15.79
# EP:31, B:1565, ep.loss:11.73
# EP:32, B:15

KeyboardInterrupt: 