## Component/Abstraction Representation Learning using TransE
- Stage: Cambrian
- Version: Charniodiscus

In [1]:
import logging 
logging.basicConfig(level=logging.CRITICAL)

In [2]:
import os
import itertools
import time
import copy
import random
import pickle
import numpy as np

# Make CUDA kernel launches synchronous so that device errors are reported at
# the call that caused them (a debugging aid; it slows GPU execution). It is set
# here, before `import torch` in the next cell.
os.environ["CUDA_LAUNCH_BLOCKING"]="1"

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# from tensorboardX import SummaryWriter

# Global CUDA flag for this notebook; set it to False to run everything on CPU.
use_cuda = bool(torch.cuda.is_available())
print("use_cuda: {}".format(use_cuda))

use_cuda: True


In [4]:
# Morpheus Version
from utils_morpheus import *
from ProgramSpace import *

In [5]:
# Display the installed PyTorch version (the recorded run used 1.0.0).
torch.__version__

'1.0.0'

In [6]:
class MorphGenDataset(Dataset):
    """On-the-fly dataset of positive/negative (input, function, output) triplets.

    Every ``__getitem__`` call ignores its index and generates a fresh random
    example: a random input table is fed to the program generator until it
    yields a single function application. Alongside that positive triplet, a
    corrupted triplet with the same input/output but a different function id
    is returned as a negative sample.

    Args:
        p_spec: DSL spec; stored on the instance, not used directly here.
        p_generator: object whose ``generate(fixed_depth=..., example=...)``
            returns ``(program, example)``.
        p_interpreter: object providing ``random_table()``.
        p_dps: program space providing ``get_neighboring_shells()``,
            ``prod_list`` and ``node_list``.
        p_len: nominal length reported by ``__len__`` (samples per pass).
    """

    def __init__(self, p_spec=None, p_generator=None, p_interpreter=None, p_dps=None, p_len=None):
        self.spec = p_spec
        self.generator = p_generator
        self.interpreter = p_interpreter
        self.dps = p_dps
        self.len = p_len

        # Give every function call (shell) a stable integer id.
        self.shell_list = self.dps.get_neighboring_shells()
        self.shell_dict = {shell: i for i, shell in enumerate(self.shell_list)}

    def __len__(self):
        return self.len

    def _sample_negative_fn(self, pos_fn):
        """Return a shell id drawn uniformly from every id except ``pos_fn``.

        Raises:
            ValueError: if fewer than two shells exist, so no negative id is
                possible (a rejection loop would never terminate).
        """
        n_shells = len(self.shell_list)
        if n_shells < 2:
            raise ValueError(
                "need at least 2 shells to sample a negative function, got {}".format(n_shells)
            )
        # Draw from the n-1 remaining slots and step over pos_fn: uniform over
        # the other ids with a single RNG call instead of a rejection loop.
        neg_fn = random.randrange(n_shells - 1)
        if neg_fn >= pos_fn:
            neg_fn += 1
        return neg_fn

    def __getitem__(self, p_ind):
        """Generate one random sample; ``p_ind`` is ignored.

        Returns:
            ``(tmp_in, tmp_func, tmp_out, rep_in, rep_func, rep_out)`` where
            ``tmp_in``/``tmp_out`` are the ``camb_get_abs`` abstraction maps of
            the generated example's input and output, ``tmp_func`` is the
            positive shell id, and the ``rep_*`` triplet reuses the same
            input/output with a different shell id. Assumes a single-input,
            single-output example.
        """
        # Resample until the generator produces a single function application.
        while True:
            tmp_input = self.interpreter.random_table()
            tmp_prog, tmp_example = self.generator.generate(
                fixed_depth=2,  # fixed to 2 since we learn size-1 programs
                example=Example(input=[tmp_input], output=None),
            )
            if tmp_prog is not None and tmp_prog.is_apply():
                break

        tmp_in = camb_get_abs(tmp_example.input[0])
        tmp_out = camb_get_abs(tmp_example.output)

        # A shell is keyed by (production index, tuple of argument node indices).
        shell_key = (
            self.dps.prod_list.index(tmp_prog.production),
            tuple(self.dps.node_list.index(arg) for arg in tmp_prog.args),
        )
        tmp_func = self.shell_dict[shell_key]

        # Negative sample: corrupt only the function id.
        # NOTE: not guaranteed to be a true negative -- rep_func(rep_in) may
        # still equal rep_out; this is not verified.
        rep_func = self._sample_negative_fn(tmp_func)

        return (tmp_in, tmp_func, tmp_out,
                tmp_in, rep_func, tmp_out)

In [7]:
class MarginLoss(nn.Module):
    """Hinge (margin ranking) loss for distance-style scores (lower is better).

    Computes ``sum(max(pos - neg + margin, 0))`` elementwise over the scores.
    """

    def __init__(self):
        super(MarginLoss, self).__init__()

    def forward(self, pos, neg, margin):
        """Return the summed hinge loss.

        Args:
            pos: scores of positive samples.
            neg: scores of negative samples, broadcastable with ``pos``.
            margin: margin tensor (or scalar), broadcastable with ``pos``.

        Returns:
            0-d tensor with the summed loss.
        """
        # clamp(min=0) is the elementwise hinge; unlike the previous
        # torch.tensor(pos.size()) "zeros" (which held the *sizes* as values),
        # it works for any input shape and stays on the input's device.
        return torch.sum(torch.clamp(pos - neg + margin, min=0))

In [8]:
def norm_loss(embeddings, dim=1):
    """Penalize embeddings whose squared L2 norm exceeds 1.

    Args:
        embeddings: float tensor of embeddings.
        dim: dimension along which the squared norm is taken.

    Returns:
        0-d tensor ``sum(max(||e||^2 - 1, 0))`` over all embeddings; zero when
        every embedding lies inside the unit ball.
    """
    norm = torch.sum(embeddings ** 2, dim=dim, keepdim=True)
    # Scalar arithmetic + clamp keeps the result on embeddings' own device, so
    # the global use_cuda flag no longer has to match where the tensor lives.
    return torch.sum(torch.clamp(norm - 1.0, min=0))

In [9]:
class ValueEncoder(nn.Module):
    """Encode a batch of token maps into fixed-size value embeddings.

    Pipeline: token embedding -> Conv2d + ReLU -> MaxPool2d -> Linear -> sigmoid.

    Config keys read: ``config["val"]`` with ``vocab_size``, ``embd_dim``,
    ``IDX_PAD`` (embedding ``padding_idx``), ``conv_n_kernels``,
    ``conv_kernel_size``, ``pool_kernel_size``; and ``config["embd_dim"]`` (the
    output size). The conv + pool kernels must reduce each map to 1x1, because
    ``forward`` flattens the pooled output to ``(B, conv_n_kernels)``.
    """

    def __init__(self, p_config=None):
        super(ValueEncoder, self).__init__()
        self.config = p_config
        val_cfg = self.config["val"]

        self.vocab_size = val_cfg["vocab_size"]
        self.embd_dim = val_cfg["embd_dim"]
        self.embedding = nn.Embedding(
            self.vocab_size,
            self.embd_dim,
            val_cfg["IDX_PAD"],  # padding_idx
        )

        self.conv = nn.Conv2d(
            in_channels=val_cfg["embd_dim"],
            out_channels=val_cfg["conv_n_kernels"],
            kernel_size=val_cfg["conv_kernel_size"],
        )

        # No padding: IDX_PAD is a token index for the embedding, not a pooling
        # size. Using it as pool padding only worked while it happened to be 0.
        self.pool = nn.MaxPool2d(
            kernel_size=val_cfg["pool_kernel_size"],
        )

        self.fc = nn.Linear(
            val_cfg["conv_n_kernels"],
            self.config["embd_dim"],
        )

    def forward(self, bp_map):
        """Encode ``bp_map`` (B, map_r, map_c) token ids into (B, embd_dim) in (0, 1)."""
        # In this version every value contains exactly one map.
        B = bp_map.shape[0]

        # (B, map_r, map_c, val_embd_dim) -> (B, val_embd_dim, map_r, map_c)
        d_embd = self.embedding(bp_map).permute(0, 3, 1, 2)

        # (B, n_kernel, map_r', map_c'); with a (1, map_c) kernel this is (B, n_kernel, map_r, 1)
        d_conv = F.relu(self.conv(d_embd))

        # (B, n_kernel); requires the pool to reduce the spatial dims to 1x1
        d_pool = self.pool(d_conv).view(B, self.config["val"]["conv_n_kernels"])

        # (B, embd_dim)
        return torch.sigmoid(self.fc(d_pool))
        

In [10]:
class MorphTransE(nn.Module):
    """TransE-style model relating input values, functions and output values.

    Values are encoded by ``ValueEncoder``; each function (shell id) has a
    learned embedding. A triplet (in, fn, out) is scored by
    ``||v_in + v_fn - v_out||^2`` (lower is better).

    Config keys read here: ``config["fn"]["vocab_size"]`` (number of function
    ids) and ``config["embd_dim"]`` (embedding size); the whole config is also
    passed to ``ValueEncoder``.
    """

    def __init__(self, p_config=None):
        super(MorphTransE, self).__init__()
        self.config = p_config

        self.value_encoder = ValueEncoder(p_config=p_config)

        self.fn_vocab_size = self.config["fn"]["vocab_size"]
        self.embd_dim = self.config["embd_dim"]

        self.fn_embedding = nn.Embedding(
            self.fn_vocab_size,
            self.embd_dim,
        )

        # Xavier init, then rescale every function embedding to unit L2 norm.
        nn.init.xavier_uniform_(self.fn_embedding.weight.data)
        self.fn_embedding.weight.data = F.normalize(
            self.fn_embedding.weight.data, p=2, dim=1,
        )

    def forward(self, batch_triplets):
        """Score a batch of positive and corrupted triplets.

        Args:
            batch_triplets: sequence of six tensors
                ``(in, fn, out, neg_in, neg_fn, neg_out)``; value entries go
                through ``self.value_encoder``, fn entries are function ids.

        Returns:
            ``(pos_score, neg_score)``, each of shape ``(B,)``.
        """
        v_in = self.value_encoder(batch_triplets[0])
        v_fn = self.fn_embedding(batch_triplets[1])
        v_out = self.value_encoder(batch_triplets[2])

        r_in = self.value_encoder(batch_triplets[3])
        r_fn = self.fn_embedding(batch_triplets[4])
        r_out = self.value_encoder(batch_triplets[5])

        pos_score = torch.sum((v_in + v_fn - v_out) ** 2, 1)
        neg_score = torch.sum((r_in + r_fn - r_out) ** 2, 1)

        # (B,), (B,) -- summed over dim 1 without keepdim
        return (pos_score, neg_score)

    def infer_fn(self, ios):
        """Distance from the translation ``v_out - v_in`` to every function embedding.

        Args:
            ios: pair ``(in_maps, out_maps)``, each a batch accepted by
                ``self.value_encoder`` (one map per value, i.e. (B, map_r, map_c)).

        Returns:
            Detached CPU float tensor of shape ``(B, fn_vocab_size)``: the
            Euclidean distance between each estimated function vector and each
            function embedding. Smaller means a better match.
        """
        with torch.no_grad():
            v_in = self.value_encoder(ios[0])
            v_out = self.value_encoder(ios[1])
            est_fn = v_out - v_in  # (B, embd_dim)
            fn_weights = self.fn_embedding.weight  # (fn_vocab_size, embd_dim)
            # Broadcast (B, 1, D) - (1, V, D) -> (B, V, D), then L2 over D: the
            # same value as torch.dist for every (i, j) pair, without a Python loop.
            dists = torch.norm(est_fn.unsqueeze(1) - fn_weights.unsqueeze(0), p=2, dim=2)
        # Previous implementation returned a freshly built CPU tensor; keep that.
        return dists.detach().cpu()

In [11]:
def MTETester(p_model, p_ld_data):
    """Evaluate function prediction by the rank of the true function.

    Use the same loader as the trainer if the dataset samples randomly on every
    access.

    For each batch ``bts``, ``p_model.infer_fn((bts[0], bts[2]))`` scores every
    function id. Scores are sorted ascending (smaller distance = better), and
    the position of the true id ``bts[1][i]`` is recorded as its rank (0 = best).
    Progress is printed per batch, followed by a summary line: mean, best,
    worst, and 50/75/90th percentile rank.

    Args:
        p_model: model exposing ``eval()``, ``parameters()`` and ``infer_fn``.
        p_ld_data: iterable of batches; each batch is indexable with the input
            maps at 0, true function ids at 1 and output maps at 2.

    Returns:
        list of int ranks, one per evaluated sample (empty if the loader yields
        nothing).
    """
    p_model.eval()
    # Put batches wherever the model lives (CPU if it has no parameters).
    try:
        device = next(p_model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    rank_list = []
    with torch.no_grad():
        for batch_idx, bts in enumerate(p_ld_data):
            td_bts = [part.to(device) for part in bts]

            # Feed the true input and output and infer the function:
            # (B, fn_vocab_size) scores.
            d_scores = p_model.infer_fn((td_bts[0], td_bts[2]))

            sorted_fns = torch.argsort(d_scores, dim=1).cpu().numpy()
            # Compare plain ints rather than 0-d tensors.
            true_fns = bts[1].tolist()
            for row, true_fn in zip(sorted_fns, true_fns):
                rank_list.append(row.tolist().index(true_fn))

            print("\r# TEST/FUNCTION B:{}".format(batch_idx), end="")

    print()
    if not rank_list:
        # Avoid ZeroDivisionError / min() of an empty sequence.
        print("# TEST/FUNCTION no samples evaluated")
        return rank_list

    print("# TEST/FUNCTION avg.rank:{:.2f}, b:{:.2f}, w:{:.2f}, 50p:{:.2f}, 75p:{:.2f}, 90p:{:.2f}".format(
        sum(rank_list)/len(rank_list), 
        min(rank_list), 
        max(rank_list), 
        np.percentile(rank_list,50),
        np.percentile(rank_list,75),
        np.percentile(rank_list,90)
    ))
    return rank_list
        

In [12]:
def MTETrainer(p_nep, p_model, p_ld_train, p_ld_test, p_optim, p_lossfn):
    """Train ``p_model`` for ``p_nep`` epochs, evaluating periodically.

    Each training batch is the six-tuple ``(in, fn, out, neg_in, neg_fn,
    neg_out)``. The loss is ``p_lossfn(pos_score, neg_score, margin=1)`` plus a
    ``norm_loss`` penalty on each of the six embeddings. Batch and running
    average losses are printed. ``MTETester`` runs once before training and
    after every epoch whose index is divisible by 10.

    Args:
        p_nep: number of epochs.
        p_model: model with ``value_encoder``, ``fn_embedding``, ``forward``
            returning ``(pos, neg)`` scores, and ``infer_fn`` (for testing).
        p_ld_train: training batch iterable.
        p_ld_test: evaluation batch iterable passed to ``MTETester``.
        p_optim: optimizer over ``p_model``'s parameters.
        p_lossfn: callable ``(pos, neg, margin) -> loss``.
    """
    # Put batches wherever the model lives (CPU if it has no parameters).
    try:
        device = next(p_model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")
    # Constant margin: build it once instead of every batch.
    margin = torch.tensor([1], dtype=torch.float, device=device)
    # Encoder applied to each entry of the six-tuple, in tuple order.
    encoders = (
        p_model.value_encoder, p_model.fn_embedding, p_model.value_encoder,
        p_model.value_encoder, p_model.fn_embedding, p_model.value_encoder,
    )

    MTETester(p_model, p_ld_test)
    for d_ep in range(p_nep):
        # MTETester leaves the model in eval mode; switch back for training.
        p_model.train()
        train_loss_list = []
        for batch_idx, bts in enumerate(p_ld_train):
            td_bts = [part.to(device) for part in bts]

            d_scores = p_model(td_bts)  # (pos, neg)
            p_optim.zero_grad()
            d_loss = p_lossfn(d_scores[0], d_scores[1], margin)
            # Soft unit-ball penalty on every embedding in the batch.
            # NOTE: this re-encodes values already encoded inside forward().
            for encode, part in zip(encoders, td_bts):
                d_loss = d_loss + norm_loss(encode(part))
            d_loss.backward()
            p_optim.step()

            # Store a float, not the tensor: summing graph-carrying tensors every
            # batch kept growing autograd graphs alive for the whole epoch.
            train_loss_list.append(d_loss.item())
            print("\r# TRAIN EP{}, B:{}, L:{:.4f}, AvgL:{:.4f}".format(
                d_ep, batch_idx, train_loss_list[-1],
                sum(train_loss_list) / len(train_loss_list)
            ), end="")
        print()
        if d_ep % 10 == 0:
            MTETester(p_model, p_ld_test)
            # and also save the model
#             torch.save(
#                 p_model.state_dict(),
#                 "./saved_models/0709CAMB_TransE_camb3_std_ep{}.pt".format(d_ep)
#             )
            

In [13]:
# Experiment setup: interpreter, DSL spec, random program generator, a program
# space used to enumerate function-call shells, the datasets/loaders, the model
# config, the model, the loss and the optimizer.
m_interpreter = MorpheusInterpreter()
# Alternative specs (uncomment one to switch):
# m_spec = S.parse_file('./example/set_select.tyrell')
# m_spec = S.parse_file('./example/set_unite.tyrell')
# m_spec = S.parse_file('./example/camb2.tyrell')
m_spec = S.parse_file('./example/camb3.tyrell')
m_generator = MorpheusGenerator(
    spec=m_spec,
    interpreter=m_interpreter,
    sfn=m_interpreter.sanity_check,
)
# dumb Program Space
# NOTE(review): the meaning of the `[None], None` arguments is not visible
# here -- confirm against ProgramSpace.
m_dps = ProgramSpace(
    m_spec, m_interpreter, [None], None
)


# MorphGenDataset ignores the index and samples randomly on every access, so
# p_len only sets samples per pass (1024 / 8 = 128 training batches,
# 512 / 8 = 64 test batches) and shuffle has no effect.
# NOTE(review): train and test share the same generator and no seed is set
# here, so the "test" set is a fresh random draw each evaluation, not a
# held-out set.
dt_mg_train = MorphGenDataset(p_spec=m_spec, p_generator=m_generator, p_interpreter=m_interpreter, p_dps=m_dps, p_len=1024)
ld_mg_train = DataLoader(dataset=dt_mg_train, batch_size=8, shuffle=False)

dt_mg_test = MorphGenDataset(p_spec=m_spec, p_generator=m_generator, p_interpreter=m_interpreter, p_dps=m_dps, p_len=512)
ld_mg_test = DataLoader(dataset=dt_mg_test, batch_size=8, shuffle=False)

# Must come after the datasets: the function vocabulary size is read from
# dt_mg_train.shell_list.
m_config = {
    "val":{
        "vocab_size": len(CAMB_LIST),
        "embd_dim": 16, # embedding dim of CAMB abstract token
        "conv_n_kernels": 512,
        "conv_kernel_size": (1,CAMB_NCOL), # spans a whole row
        "pool_kernel_size": (CAMB_NROW,1), # spans all rows -> 1x1 after pooling
        "IDX_PAD": 0, # padding_idx of the value-token embedding
    },
    "fn":{
        "vocab_size": len(dt_mg_train.shell_list) # one id per function-call shell
    },
    "embd_dim":128, # shared size of value and function embeddings
}

mte = MorphTransE(p_config=m_config)
m_loss = MarginLoss()
if use_cuda:
    mte = mte.cuda()
    m_loss = m_loss.cuda()
# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(mte.parameters())


In [14]:
# Display the resolved model configuration.
m_config

{'val': {'vocab_size': 150,
  'embd_dim': 16,
  'conv_n_kernels': 512,
  'conv_kernel_size': (1, 15),
  'pool_kernel_size': (15, 1),
  'IDX_PAD': 0},
 'fn': {'vocab_size': 120},
 'embd_dim': 128}

In [15]:
# 1000000 epochs is effectively "train until interrupted" (the recorded run was
# stopped with KeyboardInterrupt). The test pass runs before training and after
# every 10th epoch.
MTETrainer(1000000, mte, ld_mg_train, ld_mg_test, optimizer, m_loss)

# TEST/FUNCTION B:63
# TEST/FUNCTION avg.rank:58.16, b:0.00, w:119.00, 50p:58.00, 75p:85.00, 90p:104.00
# TRAIN EP0, B:127, L:9.2735, AvgL:43.4766
# TEST/FUNCTION B:63
# TEST/FUNCTION avg.rank:22.94, b:0.00, w:118.00, 50p:13.00, 75p:35.00, 90p:57.90
# TRAIN EP1, B:127, L:4.6447, AvgL:5.9019
# TRAIN EP2, B:127, L:3.7905, AvgL:5.2777
# TRAIN EP3, B:127, L:9.0010, AvgL:5.2318
# TRAIN EP4, B:127, L:4.4380, AvgL:5.0906
# TRAIN EP5, B:127, L:3.3485, AvgL:4.5793
# TRAIN EP6, B:127, L:4.4360, AvgL:4.6429
# TRAIN EP7, B:127, L:3.3629, AvgL:4.5348
# TRAIN EP8, B:127, L:4.9942, AvgL:4.4090
# TRAIN EP9, B:127, L:7.0046, AvgL:4.2847
# TRAIN EP10, B:127, L:4.8360, AvgL:4.2373
# TEST/FUNCTION B:63
# TEST/FUNCTION avg.rank:20.10, b:0.00, w:113.00, 50p:10.00, 75p:32.00, 90p:49.90
# TRAIN EP11, B:127, L:5.4690, AvgL:4.0416
# TRAIN EP12, B:127, L:6.7340, AvgL:4.3124
# TRAIN EP13, B:127, L:5.3311, AvgL:4.2229
# TRAIN EP14, B:127, L:2.3463, AvgL:3.9646
# TRAIN EP15, B:127, L:4.4872, AvgL:4.3407
# TRAIN EP1

KeyboardInterrupt: 