## Component/Abstraction Representation Learning using TransE
- Stage: Cambrian
- Version: Charniodiscus

In [1]:
import logging 
logging.basicConfig(level=logging.CRITICAL)

In [2]:
import os
import itertools
import time
import copy
import random
import pickle
import numpy as np

os.environ["CUDA_LAUNCH_BLOCKING"]="1"

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# from tensorboardX import SummaryWriter

# Use the GPU whenever PyTorch can see one; uncomment the override below to
# force CPU-only execution.
use_cuda = torch.cuda.is_available()
# use_cuda = False
print("use_cuda: {}".format(use_cuda))

use_cuda: True


In [4]:
from sklearn.metrics.pairwise import pairwise_distances

In [5]:
# Display the installed PyTorch version.
torch.__version__

'1.0.0'

In [6]:
class AlgoDataset(Dataset):
    """Randomly sampled (input, shift-function, output) triples for TransE.

    Values are the integers in ``[p_min_num, p_max_num]``. Functions are the
    integer shifts ``+1, -1, +2, -2, ..., +p_max_shift, -p_max_shift`` in that
    order, so function id ``2*(k-1)`` is ``+k`` and ``2*(k-1)+1`` is ``-k``.

    ``__getitem__`` ignores its index and draws a fresh sample from the
    ``random`` module on every call. It returns a positive triple plus a
    negative triple that keeps the same input/output ids but uses a function
    that does NOT map the input to the output.

    Args:
        p_len: Nominal dataset length reported by ``__len__``.
        p_max_num: Largest representable value (default 32).
        p_min_num: Smallest representable value (default -32).
        p_max_shift: Largest shift magnitude; yields ``2 * p_max_shift``
            functions (default 6).

    Raises:
        ValueError: If the value range holds fewer than two values or
            ``p_max_shift < 1``. Either would make ``__getitem__`` loop
            forever.
    """

    def __init__(self, p_len, p_max_num=32, p_min_num=-32, p_max_shift=6):
        if p_max_num <= p_min_num:
            raise ValueError("p_max_num must be greater than p_min_num")
        if p_max_shift < 1:
            raise ValueError("p_max_shift must be at least 1")
        self.len = p_len
        self.max_num = p_max_num
        self.min_num = p_min_num
        self.vals = list(range(self.min_num, self.max_num + 1))
        # Bind each shift through a default argument. A bare closure over the
        # loop variable would make every lambda use the final shift.
        self.fns = [
            (lambda x, d=delta: self.rectify(x + d))
            for k in range(1, p_max_shift + 1)
            for delta in (k, -k)
        ]
        self.val2id = {v: i for i, v in enumerate(self.vals)}

    def rectify(self, x):
        """Coerce a function result to ``int``.

        Out-of-range results are not clamped here; ``__getitem__`` rejects
        them instead.
        """
        return int(x)

    def __len__(self):
        return self.len

    def __getitem__(self, p_ind):
        """Return ``(in, fn, out, neg_in, neg_fn, neg_out)`` as integer ids.

        ``p_ind`` is ignored; each call samples independently.
        """
        # Positive triple: rejection-sample until the function's output stays
        # inside the representable value range.
        while True:
            tmp_in = random.choice(range(len(self.vals)))
            tmp_fn = random.choice(range(len(self.fns)))
            tmp_out = self.fns[tmp_fn](self.vals[tmp_in])  # still a value
            if self.min_num <= tmp_out <= self.max_num:
                tmp_out = self.val2id[tmp_out]
                break

        # Negative triple: keep input/output, corrupt only the function.
        rep_in = tmp_in
        rep_out = tmp_out
        while True:
            rep_fn = random.choice(range(len(self.fns)))
            if self.vals[rep_out] != self.fns[rep_fn](self.vals[rep_in]):
                break

        return (tmp_in, tmp_fn, tmp_out,
                rep_in, rep_fn, rep_out,)
        

In [7]:
class MarginLoss(nn.Module):
    """Summed hinge (margin ranking) loss for TransE-style distance scores.

    ``forward`` returns ``sum(max(pos - neg + margin, 0))``. A positive
    triple's score should sit at least ``margin`` below its negative
    triple's score; any shortfall is penalised.
    """

    def __init__(self):
        super(MarginLoss, self).__init__()

    def forward(self, pos, neg, margin):
        """Compute the summed hinge loss.

        Args:
            pos: Scores of positive triples (lower is better).
            neg: Scores of negative triples, same shape as ``pos``.
            margin: Python number or tensor broadcastable against ``pos``.

        Returns:
            A 0-d tensor on the same device as ``pos``.
        """
        # clamp(min=0) is the elementwise hinge and stays on pos's device.
        # The previous version built its "zero" tensor with
        # torch.tensor(pos.size()), i.e. from the size *values*, which only
        # broadcast correctly for 1-D scores.
        return torch.sum(torch.clamp(pos - neg + margin, min=0))

In [8]:
def norm_loss(embeddings, dim=1, max_norm=1.0):
    """Penalise embeddings whose squared L2 norm exceeds ``max_norm``.

    This is a soft version of the TransE unit-norm constraint. It returns
    ``sum(max(||e||^2 - max_norm, 0))`` over the vectors laid out along
    ``dim``.

    Args:
        embeddings: Tensor of embedding vectors.
        dim: Dimension holding the embedding components (default 1).
        max_norm: Squared-norm threshold; 1.0 reproduces the original
            behaviour.

    Returns:
        A 0-d tensor on the same device as ``embeddings``.
    """
    sq_norm = torch.sum(embeddings ** 2, dim=dim, keepdim=True)
    # clamp keeps the computation on embeddings' device, so no separate
    # CPU/CUDA branches (or the global use_cuda flag) are needed.
    return torch.sum(torch.clamp(sq_norm - max_norm, min=0.0))

In [9]:
class MorphTransE(nn.Module):
    """TransE embeddings for integer values and the functions between them.

    Values and functions live in one ``embd_dim``-dimensional space. A triple
    (input, fn, output) is scored by the squared L2 distance
    ``||e_in + e_fn - e_out||^2``; lower means more plausible.

    Args:
        p_config: Dict with ``["val"]["vocab_size"]``, ``["fn"]["vocab_size"]``
            and ``["embd_dim"]``.

    Raises:
        ValueError: If ``p_config`` is None.
    """

    def __init__(self, p_config=None):
        super(MorphTransE, self).__init__()
        if p_config is None:
            raise ValueError("MorphTransE requires a p_config dict")
        self.config = p_config

        self.val_vocab_size = self.config["val"]["vocab_size"]
        self.fn_vocab_size = self.config["fn"]["vocab_size"]
        self.embd_dim = self.config["embd_dim"]

        self.val_embedding = nn.Embedding(self.val_vocab_size, self.embd_dim)
        self.fn_embedding = nn.Embedding(self.fn_vocab_size, self.embd_dim)

        # Xavier-uniform init, then rescale every row to unit L2 norm.
        # Values are initialised first, then functions, as before.
        for emb in (self.val_embedding, self.fn_embedding):
            nn.init.xavier_uniform_(emb.weight.data)
            emb.weight.data = F.normalize(emb.weight.data, p=2, dim=1)

    def forward(self, batch_triplets):
        """Score positive and negative triples.

        Args:
            batch_triplets: Sequence of six index tensors
                ``(in, fn, out, neg_in, neg_fn, neg_out)``. Presumably each
                has shape (B,), as produced by default DataLoader collation.

        Returns:
            ``(pos_score, neg_score)``: per-row squared distances summed over
            the embedding dimension, each of shape (B,) for (B,) inputs.
        """
        v_in = self.val_embedding(batch_triplets[0])
        v_fn = self.fn_embedding(batch_triplets[1])
        v_out = self.val_embedding(batch_triplets[2])

        r_in = self.val_embedding(batch_triplets[3])
        r_fn = self.fn_embedding(batch_triplets[4])
        r_out = self.val_embedding(batch_triplets[5])

        pos_score = torch.sum((v_in + v_fn - v_out) ** 2, 1)
        neg_score = torch.sum((r_in + r_fn - r_out) ** 2, 1)
        return (pos_score, neg_score)

    def dbg_infer_fn(self, ios):
        """Debug variant of ``infer_fn`` computed with sklearn on NumPy arrays.

        Args:
            ios: ``(input_ids, output_ids)`` index tensors.

        Returns:
            NumPy array (B, fn_vocab_size) of Euclidean distances between
            ``e_out - e_in`` and each function embedding.
        """
        v_in = self.val_embedding(ios[0])
        v_out = self.val_embedding(ios[1])
        est_fn = v_out - v_in  # (B, embd_dim)

        nest_fn = est_fn.data.cpu().numpy()
        nest_embd = self.fn_embedding.weight.data.cpu().numpy()
        return pairwise_distances(nest_fn, nest_embd, metric='euclidean')

    def infer_fn(self, ios):
        """Distance from each estimated function vector to every function.

        Args:
            ios: ``(input_ids, output_ids)`` index tensors of shape (B,).

        Returns:
            CPU tensor (B, fn_vocab_size) of Euclidean distances between
            ``e_out - e_in`` and each function embedding, with no autograd
            history.
        """
        with torch.no_grad():
            est_fn = self.val_embedding(ios[1]) - self.val_embedding(ios[0])
            return self._pairwise_l2(est_fn, self.fn_embedding.weight).cpu()

    def infer_output(self, ifs):
        """Distance from each estimated output vector to every value.

        Args:
            ifs: ``(input_ids, fn_ids)`` index tensors of shape (B,).

        Returns:
            CPU tensor (B, val_vocab_size) of Euclidean distances between
            ``e_in + e_fn`` and each value embedding, with no autograd
            history.
        """
        with torch.no_grad():
            est_out = self.val_embedding(ifs[0]) + self.fn_embedding(ifs[1])
            return self._pairwise_l2(est_out, self.val_embedding.weight).cpu()

    @staticmethod
    def _pairwise_l2(queries, table):
        """Return the (B, V) L2 distances between rows of queries and table."""
        # (B, 1, D) - (1, V, D) broadcasts to (B, V, D); reduce over D.
        return (queries.unsqueeze(1) - table.unsqueeze(0)).norm(p=2, dim=2)
        
        

In [10]:
def MTETester(p_model, p_ld_data):
    """Evaluate function prediction and print rank statistics.

    For every (input, fn, output) sample, ``p_model.infer_fn`` scores all
    functions by distance to ``e_out - e_in``. The position of the true
    function in the ascending order (0 = best) is collected. A summary is
    printed: mean, best, worst, and 50/75/90th percentiles.

    If the dataset samples randomly on every access, the same loader used by
    the trainer can be passed here.

    Args:
        p_model: Module exposing ``infer_fn((input_ids, output_ids))`` that
            returns (B, n_fns) distance scores.
        p_ld_data: Iterable of batches where ``batch[0]``, ``batch[1]`` and
            ``batch[2]`` are the input / function / output id tensors.

    Returns:
        List of integer ranks, one per sample (empty if there were none).
    """
    # Move batches to wherever the model's parameters live rather than
    # consulting the notebook-global use_cuda flag.
    first_param = next(p_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = p_model.training
    p_model.eval()
    rank_list = []
    try:
        with torch.no_grad():
            for batch_idx, bts in enumerate(p_ld_data):
                td_bts = [t.to(device) for t in bts]
                # Feed the true input and output and infer the function:
                # (B, fn_vocab_size) distances, smaller is better.
                d_scores = p_model.infer_fn((td_bts[0], td_bts[2]))
                order = torch.argsort(d_scores, dim=1).cpu().tolist()
                targets = bts[1].tolist()
                # Rank = position of the true function in the sorted order.
                rank_list.extend(
                    row.index(target) for row, target in zip(order, targets)
                )
                print("\r# TEST/FUNCTION B:{}".format(batch_idx), end="")
    finally:
        # Leave the model in the mode the caller had it in.
        p_model.train(was_training)

    print()
    if not rank_list:
        print("# TEST/FUNCTION no samples")
        return rank_list
    print("# TEST/FUNCTION avg.rank:{:.2f}, b:{:.2f}, w:{:.2f}, 50p:{:.2f}, 75p:{:.2f}, 90p:{:.2f}".format(
        sum(rank_list) / len(rank_list),
        min(rank_list),
        max(rank_list),
        np.percentile(rank_list, 50),
        np.percentile(rank_list, 75),
        np.percentile(rank_list, 90)
    ))
    return rank_list
        

In [11]:
def MTETrainer(p_nep, p_model, p_ld_train, p_ld_test, p_optim, p_lossfn,
               p_margin=1.0, p_test_every=10):
    """Train ``p_model`` with a margin loss plus soft unit-norm penalties.

    ``MTETester`` runs once before training and again after every epoch whose
    index is a multiple of ``p_test_every``.

    Args:
        p_nep: Number of epochs.
        p_model: Model returning ``(pos_score, neg_score)`` and exposing
            ``val_embedding`` / ``fn_embedding``.
        p_ld_train: Iterable of six-tensor batches
            ``(in, fn, out, neg_in, neg_fn, neg_out)``.
        p_ld_test: Batches passed to ``MTETester``.
        p_optim: Optimizer over ``p_model``'s parameters.
        p_lossfn: Callable ``(pos, neg, margin) -> loss``.
        p_margin: Margin passed to ``p_lossfn`` (default 1.0).
        p_test_every: Evaluation interval in epochs (default 10).
    """
    # Use the device holding the model's parameters; build the margin once.
    device = next(p_model.parameters()).device
    margin = torch.tensor([p_margin], dtype=torch.float, device=device)
    # Embedding table for each of the six id columns of a batch.
    tables = (p_model.val_embedding, p_model.fn_embedding, p_model.val_embedding,
              p_model.val_embedding, p_model.fn_embedding, p_model.val_embedding)

    MTETester(p_model, p_ld_test)
    for d_ep in range(p_nep):
        p_model.train()
        train_loss_list = []
        for batch_idx, bts in enumerate(p_ld_train):
            td_bts = [t.to(device) for t in bts]

            d_scores = p_model(td_bts)  # (pos, neg)
            p_optim.zero_grad()
            d_loss = p_lossfn(d_scores[0], d_scores[1], margin)
            # Soft unit-norm penalty on every embedding used in the batch.
            for table, ids in zip(tables, td_bts):
                d_loss = d_loss + norm_loss(table(ids))
            d_loss.backward()
            p_optim.step()

            # Store a Python float: appending the tensor itself would keep
            # every batch's autograd graph alive for the whole epoch.
            train_loss_list.append(d_loss.item())
            print("\r# TRAIN EP{}, B:{}, L:{:.4f}, AvgL:{:.4f}".format(
                d_ep, batch_idx, train_loss_list[-1],
                sum(train_loss_list) / len(train_loss_list)
            ), end="")
        print()
        if d_ep % p_test_every == 0:
            MTETester(p_model, p_ld_test)
            

In [12]:

# Train on 980000 random samples per epoch (100 batches of 9800) and
# evaluate on 10000 random samples in batches of 8.
dt_mg_train, dt_mg_test = AlgoDataset(p_len=980000), AlgoDataset(p_len=10000)
ld_mg_train = DataLoader(dataset=dt_mg_train, batch_size=9800, shuffle=False)
ld_mg_test = DataLoader(dataset=dt_mg_test, batch_size=8, shuffle=False)

# Vocabulary sizes come straight from the dataset's value/function lists.
m_config = dict(
    val=dict(vocab_size=len(dt_mg_train.vals)),
    fn=dict(vocab_size=len(dt_mg_train.fns)),
    embd_dim=128,
)

mte, m_loss = MorphTransE(p_config=m_config), MarginLoss()
if use_cuda:
    mte, m_loss = mte.cuda(), m_loss.cuda()
optimizer = torch.optim.Adam(mte.parameters())


In [13]:
m_config

{'val': {'vocab_size': 65}, 'fn': {'vocab_size': 12}, 'embd_dim': 128}

In [14]:
# Print ten sampled (in, fn, out, neg_in, neg_fn, neg_out) id tuples, one per
# line; AlgoDataset.__getitem__ ignores the index it is given.
print(*(dt_mg_train[k] for k in range(10)), sep="\n")

(33, 8, 38, 33, 10, 38)
(10, 10, 16, 10, 6, 16)
(22, 8, 27, 22, 4, 27)
(36, 0, 37, 36, 11, 37)
(43, 0, 44, 43, 4, 44)
(44, 1, 43, 44, 5, 43)
(47, 8, 52, 47, 3, 52)
(23, 4, 26, 23, 11, 26)
(0, 6, 4, 0, 2, 4)
(18, 1, 17, 18, 9, 17)


In [15]:
# p_nep=1000000 effectively means "train until stopped by hand"; the recorded
# run below was ended with a KeyboardInterrupt.
MTETrainer(1000000, mte, ld_mg_train, ld_mg_test, optimizer, m_loss)

# TEST/FUNCTION B:1249
# TEST/FUNCTION avg.rank:5.47, b:0.00, w:11.00, 50p:6.00, 75p:8.00, 90p:10.00
# TRAIN EP0, B:99, L:8954.5430, AvgL:9398.4231
# TEST/FUNCTION B:1249
# TEST/FUNCTION avg.rank:4.88, b:0.00, w:11.00, 50p:4.00, 75p:8.00, 90p:10.00
# TRAIN EP1, B:99, L:8257.6318, AvgL:8631.0800
# TRAIN EP2, B:99, L:7806.0615, AvgL:8023.9056
# TRAIN EP3, B:99, L:7611.5474, AvgL:7667.4606
# TRAIN EP4, B:99, L:7471.4717, AvgL:7521.9869
# TRAIN EP5, B:99, L:7438.4160, AvgL:7459.1431
# TRAIN EP6, B:99, L:7478.4141, AvgL:7420.6763
# TRAIN EP7, B:99, L:7347.0479, AvgL:7408.3438
# TRAIN EP8, B:99, L:7370.2432, AvgL:7393.7238
# TRAIN EP9, B:99, L:7383.8872, AvgL:7390.6769
# TRAIN EP10, B:99, L:7415.7095, AvgL:7396.5363
# TEST/FUNCTION B:1249
# TEST/FUNCTION avg.rank:2.90, b:0.00, w:11.00, 50p:3.00, 75p:4.00, 90p:5.00
# TRAIN EP11, B:99, L:7353.1948, AvgL:7391.9294
# TRAIN EP12, B:99, L:7374.3008, AvgL:7391.2081
# TRAIN EP13, B:99, L:7399.2085, AvgL:7384.8881
# TRAIN EP14, B:99, L:7420.8350, Avg

KeyboardInterrupt: 

In [None]:
# NOTE(review): AlgoDataset defines no write_to_file method, so this cell
# raises AttributeError as written. Also, every dataset access draws a fresh
# random sample. Implement the method (deciding what to persist) or drop
# this cell.
dt_mg_train.write_to_file()

In [None]:
# Learned embedding of value id 0 (the integer -32).
mte.val_embedding.weight[0,:]

In [None]:
# Learned embedding of value id 1 (the integer -31).
mte.val_embedding.weight[1,:]

In [None]:
# Learned embedding of function id 0 (the +1 shift).
mte.fn_embedding.weight[0,:]

In [None]:
# TransE residual e_in + e_fn - e_out for the true triple (-32, +1, -31);
# training drives its squared norm (the positive score) towards zero.
mte.val_embedding.weight[0,:] + mte.fn_embedding.weight[0,:] - mte.val_embedding.weight[1,:]