## Component/Abstraction Representation Learning using TransE
- Stage: Cambrian
- Version: Charniodiscus

In [1]:
import logging 
logging.basicConfig(level=logging.CRITICAL)

In [2]:
import os
import itertools
import copy
import random
import pickle
import numpy as np

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# from tensorboardX import SummaryWriter

# Probe once for a usable CUDA device; the flag is read later when moving the
# model, the loss and each batch to the GPU.
use_cuda = torch.cuda.is_available()
print("use_cuda: {}".format(use_cuda))

use_cuda: False


In [4]:
# Morpheus Version
from utils_morpheus import *
from ProgramSpace import *

In [5]:
torch.__version__

'1.0.0'

In [6]:
class MorphGenDataset(Dataset):
    """On-the-fly generator of TransE training samples.

    Every item is a positive triple ``(input, function id, output)`` followed
    by a corrupted copy of the same triple in which either the input or the
    output abstraction has been swapped for that of a fresh random table.

    Args:
        p_spec: specification object; stored, not used by this class.
        p_generator: object whose ``generate(fixed_depth=..., example=...)``
            returns a ``(program, example)`` pair.
        p_interpreter: object providing ``random_table()``.
        p_dps: program space providing ``get_neighboring_shells()``,
            ``prod_list`` and ``node_list``.
    """

    def __init__(self, p_spec=None, p_generator=None, p_interpreter=None, p_dps=None):
        self.spec = p_spec
        self.generator = p_generator
        self.interpreter = p_interpreter
        self.dps = p_dps

        # Every shell (function call) gets an integer id: its position in the
        # list returned by the program space.
        self.shell_list = self.dps.get_neighboring_shells()
        self.shell_dict = {
            shell: position for position, shell in enumerate(self.shell_list)
        }

    def __len__(self):
        # Samples are generated randomly, so this is a nominal epoch size.
        return 1024

    def __getitem__(self, p_ind):
        """Return ``(in, func, out, neg_in, neg_func, neg_out)``.

        ``p_ind`` is ignored: each call draws a new random sample.
        Single input / single output programs are assumed.
        NOTE(review): the original comments describe each abstraction as
        (n_maps, CAMB_NROW, CAMB_NCOL) - confirm against camb_get_abs.
        """
        # Keep sampling until the generator yields a program that is an
        # application node.
        tmp_prog = None
        while tmp_prog is None or not tmp_prog.is_apply():
            seed_table = self.interpreter.random_table()
            tmp_prog, tmp_example = self.generator.generate(
                fixed_depth=2,  # fixed to 2 because size-1 programs are learned
                example=Example(input=[seed_table], output=None),
            )

        pos_in = camb_get_abs(tmp_example.input[0])
        pos_out = camb_get_abs(tmp_example.output)

        # A shell key is (production index, tuple of argument node indices).
        prod_pos = self.dps.prod_list.index(tmp_prog.production)
        arg_pos = tuple(
            [self.dps.node_list.index(arg) for arg in tmp_prog.args]
        )
        func_id = self.shell_dict[(prod_pos, arg_pos)]

        # ---------------------------------------------------------
        # Negative sample: corrupt the input or the output, chosen with
        # equal probability.
        corrupt_input = random.random() < 0.5
        reference = pos_in if corrupt_input else pos_out

        # Redraw until the replacement abstraction differs from the one it
        # replaces.
        # Replacing the output is safe: one input cannot yield two outputs.
        # Replacing the input is NOT safe: different inputs can map to the
        # same output, so the "negative" may actually be valid.
        # TODO: check and refine later
        fake_enc = camb_get_abs(self.interpreter.random_table())
        while np.array_equal(fake_enc, reference):
            fake_enc = camb_get_abs(self.interpreter.random_table())

        if corrupt_input:
            return (pos_in, func_id, pos_out,
                    fake_enc, func_id, pos_out)
        return (pos_in, func_id, pos_out,
                pos_in, func_id, fake_enc)
        

In [7]:
class ListModule(object):
    """List-like view over sub-modules registered on a host ``nn.Module``.

    Each appended module is attached to ``module`` under the attribute name
    ``prefix + str(index)``, where ``index`` counts appended modules from 0.

    Args:
        module: host module that receives the children via ``add_module``.
        prefix: string prepended to the running index to form child names.
        *args: modules to append immediately, in order.
    """

    def __init__(self, module, prefix, *args):
        self.module = module
        self.prefix = prefix
        self.num_module = 0
        for child in args:
            self.append(child)

    def _child_name(self, index):
        # Attribute name under which the child at ``index`` is registered.
        return self.prefix + str(index)

    def append(self, new_module):
        """Register ``new_module`` on the host under the next free index.

        Raises:
            ValueError: if ``new_module`` is not an ``nn.Module``.
        """
        if not isinstance(new_module, nn.Module):
            raise ValueError('Not a Module')
        self.module.add_module(self._child_name(self.num_module), new_module)
        self.num_module += 1

    def __len__(self):
        return self.num_module

    def __getitem__(self, i):
        """Return the i-th appended module; negative indices are rejected.

        Raises:
            IndexError: if ``i`` is negative or not below ``len(self)``.
        """
        out_of_range = i < 0 or i >= self.num_module
        if out_of_range:
            raise IndexError('Out of bound')
        return getattr(self.module, self._child_name(i))

In [8]:
class ValEncoder(nn.Module):
    """Encode a batch of abstraction maps into fixed-size value vectors.

    Pipeline: token embedding -> Conv3d -> MaxPool3d -> ReLU -> Linear ->
    sigmoid.

    Config keys read from ``p_config``:
        ``["val"]["vocab_size"]``       number of abstraction tokens
        ``["val"]["embd_dim"]``         token embedding size (conv in_channels)
        ``["val"]["IDX_PAD"]``          token id used as the embedding padding index
        ``["val"]["conv_n_kernels"]``   conv out_channels
        ``["val"]["conv_kernel_size"]`` Conv3d kernel, (k_map, k_row, k_col)
        ``["val"]["pool_kernel_size"]`` MaxPool3d kernel, (k_map, k_row, k_col)
        ``["embd_dim"]``                size of the returned value vector
    """

    def __init__(self, p_config=None):
        super(ValEncoder, self).__init__()
        self.config = p_config
        val_cfg = self.config["val"]

        # abstraction tokens must be defined (vocab) before they are embedded
        self.vocab_size = val_cfg["vocab_size"]
        self.embd_dim = val_cfg["embd_dim"]
        self.embedding = nn.Embedding(
            self.vocab_size,
            self.embd_dim,
            padding_idx=val_cfg["IDX_PAD"],
        )

        self.conv = nn.Conv3d(
            in_channels=val_cfg["embd_dim"],
            out_channels=val_cfg["conv_n_kernels"],
            kernel_size=val_cfg["conv_kernel_size"],
        )
        # IDX_PAD is a token id, not a spatial padding amount. It used to be
        # passed as the pooling padding, which only worked because it was 0;
        # any other value would pad the feature volume (or be rejected by
        # MaxPool3d) and break the reshape in forward(). Pooling is unpadded.
        self.pool = nn.MaxPool3d(
            kernel_size=val_cfg["pool_kernel_size"],
            padding=0,
        )

        # temporarily a single fully connected layer, as a test
        self.fc = nn.Linear(
            val_cfg["conv_n_kernels"],
            self.config["embd_dim"],  # output size; differs from the token embd_dim
        )

    def forward(self, bp_maps):
        """Encode a batch of maps.

        Args:
            bp_maps: integer token tensor, (B, n_maps, map_r, map_c).

        Returns:
            Tensor of shape (B, config["embd_dim"]) with entries in (0, 1).
        """
        B = bp_maps.shape[0]

        # (B, n_maps, map_r, map_c, abs_embd_dim)
        m_embd = self.embedding(bp_maps)

        # (B, abs_embd_dim, n_maps, map_r, map_c): channels-first for Conv3d
        m_embd = m_embd.permute(0, 4, 1, 2, 3)

        # (B, conv_n_kernels, n_maps, map_r, 1) when the conv kernel is
        # (1, 1, map_c)
        m_conv = self.conv(m_embd)

        # The pooling kernel must collapse the volume to (1, 1, 1) so that
        # the view below is valid: (B, conv_n_kernels, 1, 1, 1) -> (B, conv_n_kernels).
        # relu filters out irrelevant (negative) activations
        m_pool = F.relu(
            self.pool(m_conv)
        ).view(B, self.config["val"]["conv_n_kernels"])

        # (B, embd_dim)
        # torch.sigmoid replaces the deprecated F.sigmoid; same values.
        # NOTE(review): the original comment said sigmoid is used "to preserve
        # negative info", but sigmoid outputs lie in (0, 1) - tanh may have
        # been intended. Left unchanged to keep the model's behavior.
        m_out = torch.sigmoid(self.fc(m_pool))

        return m_out

In [9]:
class MorphTransE(nn.Module):
    """TransE-style scorer over (input value, function, output value) triples.

    Values are encoded by ``ValEncoder`` and L2-normalized; functions are
    looked up in an embedding table and used as translation vectors. The
    score of a triple is ``|| norm(in) + func - norm(out) ||_2``.

    Config keys read from ``p_config``:
        ``["func"]["vocab_size"]`` number of function ids
        ``["embd_dim"]``           shared size of value and function vectors
    plus everything ``ValEncoder`` reads.
    """

    def __init__(self, p_config=None):
        super(MorphTransE, self).__init__()
        self.config = p_config

        # function (translation) embeddings live here
        self.vocab_size = self.config["func"]["vocab_size"]
        self.embd_dim = self.config["embd_dim"]

        self.function_embedding = nn.Embedding(
            self.vocab_size,
            self.embd_dim,
            # no padding index: every function id is a real function
        )
        self.value_encoder = ValEncoder(p_config=p_config)

    def forward(self, batch_triplets):
        """Score a batch of positive triples and their corrupted twins.

        Args:
            batch_triplets: sequence of six batched tensors
                [0] (B, n_maps, map_r, map_c)  positive input
                [1] (B,)                       positive function id
                [2] (B, n_maps, map_r, map_c)  positive output
                [3], [4], [5]                  the same for the negative triple

        Returns:
            ``(pos_score, neg_score)``, each of shape (B,).
        """
        v_in = self.value_encoder(batch_triplets[0])  # (B, embd_dim)
        v_out = self.value_encoder(batch_triplets[2])  # (B, embd_dim)
        v_func = self.function_embedding(batch_triplets[1])  # (B, embd_dim)

        vr_in = self.value_encoder(batch_triplets[3])
        vr_out = self.value_encoder(batch_triplets[5])
        vr_func = self.function_embedding(batch_triplets[4])

        # L2-normalize every value vector along the feature axis
        nv_in = F.normalize(v_in, p=2, dim=1)
        nv_out = F.normalize(v_out, p=2, dim=1)
        nvr_in = F.normalize(vr_in, p=2, dim=1)
        # Fixed: this used to normalize `nvr_out` itself (read before
        # assignment), which raised UnboundLocalError on every call.
        nvr_out = F.normalize(vr_out, p=2, dim=1)

        # default p -> 2-norm over the feature axis
        pos_score = torch.norm(nv_in + v_func - nv_out, dim=1)
        neg_score = torch.norm(nvr_in + vr_func - nvr_out, dim=1)

        # (B,), (B,)
        return (pos_score, neg_score)

In [10]:
def MTETrainer(p_nep, p_model, p_ld_data, p_optim, p_lossfn):
    """Train ``p_model`` with a ranking loss for ``p_nep`` epochs.

    Args:
        p_nep: number of epochs.
        p_model: model whose forward takes the list of batch tensors and
            returns ``(pos_score, neg_score)``.
        p_ld_data: iterable of batches; each batch is a sequence of tensors.
        p_optim: optimizer over ``p_model``'s parameters.
        p_lossfn: callable ``(pos_score, neg_score, target) -> scalar loss``
            (e.g. ``nn.MarginRankingLoss``).

    Prints a single-line progress report that is overwritten per batch and
    terminated with a newline at the end of each epoch. Reads the
    module-level ``use_cuda`` flag to decide whether to move data to the GPU.
    """
    # Target -1 asks the ranking loss to make the second argument (negative
    # score) larger than the first (positive score), i.e. true triples get
    # the smaller distance. It is loop-invariant, so build it once.
    y = torch.tensor([-1], dtype=torch.float)
    if use_cuda:
        y = y.cuda()

    for d_ep in range(p_nep):
        # Plain Python floats only: storing the loss tensors themselves would
        # keep every batch's autograd graph alive and grow memory unboundedly.
        train_loss_list = []
        for batch_idx, bts in enumerate(p_ld_data):
            p_model.train()
            # The deprecated Variable wrapper is a no-op on tensors and has
            # been dropped.
            if use_cuda:
                td_bts = [t.cuda() for t in bts]
            else:
                td_bts = list(bts)

            d_scores = p_model(td_bts)  # (pos, neg)
            p_optim.zero_grad()
            d_loss = p_lossfn(d_scores[0], d_scores[1], y)
            d_loss.backward()
            p_optim.step()

            loss_val = d_loss.item()
            train_loss_list.append(loss_val)

            # "\r" + end="" keeps the report on one line (a leftover debug
            # print that broke this has been removed).
            print("\r# EP{}, B:{}, L:{:.4f}, AvgL:{:.4f}".format(
                d_ep, batch_idx, loss_val,
                sum(train_loss_list) / len(train_loss_list)
            ), end="")
        print()
            

In [11]:
# ---- Experiment setup: interpreter, DSL spec, generator, program space ----
m_interpreter = MorpheusInterpreter()
# DSL specification parsed from a path relative to the working directory
m_spec = S.parse_file('./example/set_select.tyrell')
# Random program generator; the interpreter's sanity_check is passed as `sfn`
m_generator = MorpheusGenerator(
    spec=m_spec,
    interpreter=m_interpreter,
    sfn=m_interpreter.sanity_check,
)
# Dummy program space: built with no examples (last two arguments are None);
# used here for its prod_list / node_list / shell enumeration.
m_dps = ProgramSpace(
    m_spec, m_interpreter, None, None
)
# Hyper-parameters shared by ValEncoder ("val"), MorphTransE ("func",
# "embd_dim") and the ranking loss ("margin").
m_config = {
    "val":{
        "vocab_size": len(CAMB_LIST),
        "embd_dim": 10, # embedding dim of CAMB abstract token
        "conv_n_kernels": 200,
        # conv spans a full row of columns; pool then spans 2 maps x all rows,
        # so the pooled volume collapses to one value per kernel
        "conv_kernel_size": (1,1,CAMB_NCOL), # (k_map, k_row, k_col)
        "pool_kernel_size": (2,CAMB_NROW,1), # (k_map, k_row, k_col)
        "IDX_PAD": 0,
    },
    "func":{
        # NOTE(review): MorphGenDataset emits function ids that index
        # m_dps.get_neighboring_shells(), while this sizes the embedding
        # table by len(m_dps.prod_list). Confirm the two lengths agree,
        # otherwise the function embedding lookup can go out of range.
        "vocab_size": len(m_dps.prod_list),
    },
    "embd_dim":128,
    "margin":0.01,
}

# Samples are generated randomly per item, so shuffle=False loses nothing.
dt_mg = MorphGenDataset(p_spec=m_spec, p_generator=m_generator, p_interpreter=m_interpreter, p_dps=m_dps)
ld_mg = DataLoader(dataset=dt_mg, batch_size=4, shuffle=False)

# Model, margin ranking loss and optimizer (Adam with default settings);
# model and loss are moved to the GPU when one is available.
mte = MorphTransE(p_config=m_config)
m_loss = nn.MarginRankingLoss(margin=m_config["margin"])
if use_cuda:
    mte = mte.cuda()
    m_loss = m_loss.cuda()
optimizer = torch.optim.Adam(mte.parameters())
    
# writer = SummaryWriter("runs/0704CAMB_TransE_select")


In [None]:
m_config

{'embd_dim': 128,
 'func': {'vocab_size': 72},
 'margin': 0.01,
 'val': {'IDX_PAD': 0,
  'conv_kernel_size': (1, 1, 15),
  'conv_n_kernels': 200,
  'embd_dim': 10,
  'pool_kernel_size': (2, 15, 1),
  'vocab_size': 5}}

In [None]:
# Launch training with an epoch budget of 1,000,000; no checkpointing or
# early stopping is configured in this call.
MTETrainer(1000000, mte, ld_mg, optimizer, m_loss)

In [None]:
# for i in ld_mg:
#     print(i)
#     break