## TransNeo
- AlphaNeo using pre-trained TransE embeddings
- Stage: Cambrian
- Version: Charniodiscus
#### Related Commands
- tensorboard --logdir runs
- nohup jupyter lab > jupyter.log &

In [1]:
import logging 
logging.basicConfig(level=logging.CRITICAL)

In [2]:
import os
import itertools
import copy
import random
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from tensorboardX import SummaryWriter

use_cuda = torch.cuda.is_available()
print("use_cuda: {}".format(use_cuda))

use_cuda: True


In [3]:
import tyrell.spec as S
from tyrell.interpreter import Interpreter, PostOrderInterpreter, GeneralError, InterpreterError
from tyrell.enumerator import Enumerator, SmtEnumerator, RandomEnumerator, DesignatedEnumerator, RandomEnumeratorS, ExhaustiveEnumerator
from tyrell.decider import Example, ExampleConstraintPruningDecider, ExampleDecider, TestDecider
from tyrell.synthesizer import Synthesizer
from tyrell.logger import get_logger
from sexpdata import Symbol
from tyrell import dsl as D
from typing import Callable, NamedTuple, List, Any

In [4]:
# Morpheus Version
from utils_morpheus import *
from ProgramSpace import *

In [5]:
torch.__version__

'1.0.0'

In [6]:
DBG_PS = None

In [7]:
class ValueEncoder(nn.Module):
    def __init__(self, p_config=None):
        super(ValueEncoder, self).__init__()
        self.config = p_config
        
        self.vocab_size = self.config["val"]["vocab_size"]
        self.embd_dim = self.config["val"]["embd_dim"]
        self.embedding = nn.Embedding(
            self.vocab_size,
            self.embd_dim,
            self.config["val"]["IDX_PAD"],
        )
        
        self.conv = nn.Conv2d(
            in_channels = self.config["val"]["embd_dim"],
            out_channels = self.config["val"]["conv_n_kernels"],
            kernel_size = self.config["val"]["conv_kernel_size"],
        )
        
        self.pool = nn.MaxPool2d(
            kernel_size = self.config["val"]["pool_kernel_size"],
            padding = self.config["val"]["IDX_PAD"],
        )
        
        self.fc = nn.Linear(
            self.config["val"]["conv_n_kernels"],
            self.config["embd_dim"],
        )
        
    def forward(self, bp_map):
        # batched maps, (B, map_r, map_c)
        # in this version, every value only contains 1 map
        B = bp_map.shape[0]
        
        # (B, map_r, map_c, val_embd_dim) -> (B, val_embd_dim, map_r, map_c)
        d_embd = self.embedding(bp_map).permute(0,3,1,2)
        
        # (B, n_kernel, map_r, 1)
        d_conv = F.relu(self.conv(d_embd))
        
        # (B, n_kernel)
        d_pool = self.pool(d_conv).view(B,self.config["val"]["conv_n_kernels"])
        
        # (B, embd_dim)
        d_out = torch.sigmoid(
            self.fc(d_pool)
        )
        
        return d_out
        

In [8]:
class MorphTransE(nn.Module):
    def __init__(self, p_config=None):
        super(MorphTransE, self).__init__()
        self.config = p_config
        
        self.value_encoder = ValueEncoder(p_config=p_config)
        
        self.fn_vocab_size = self.config["fn"]["vocab_size"]
        self.embd_dim = self.config["embd_dim"]
        
        self.fn_embedding = nn.Embedding(
            self.fn_vocab_size,
            self.embd_dim,
        )
        
        nn.init.xavier_uniform_(self.fn_embedding.weight.data)
        self.fn_embedding.weight.data = F.normalize(
            self.fn_embedding.weight.data, p=2, dim=1,
        )
# ----> skip the forward part since we don't need it <---- #

In [9]:
class TransNeo(nn.Module):
    def __init__(self, p_config=None):
        super(TransNeo, self).__init__()
        self.config = p_config
        
        # predict a fixed number of shells
        self.policy = nn.Linear(
            self.config["embd_dim"],
            self.config["fn"]["vocab_size"],
        )
        
        # deeper
#         self.policy0 = nn.Linear(
#             self.config["embd_dim"],
#             2048,
#         )
#         self.policy1 = nn.Linear(
#             2048,
#             self.config["fn"]["vocab_size"],
#         )
        
        self.mte = MorphTransE(p_config=p_config)
        # then load the parameters
        self.mte.load_state_dict(torch.load(self.config["mte_path"]))
        
    def forward(self, p_mapin, p_mapout):
        # p_mapin/p_mapout: (B, map_r, map_c)
        v_in = self.mte.value_encoder(p_mapin).detach() # (B, embd_dim)
        v_out= self.mte.value_encoder(p_mapout).detach() # (B, embd_dim)
        v_delta = v_out - v_in
        tmp_out = torch.log_softmax(
            self.policy(v_delta),dim=1
        )
#         tmp_out = torch.log_softmax(
#             self.policy1(
#                 F.relu(
#                     self.policy0(
#                         v_delta
#                     )
#                 )
#             ),dim=1
#         )
        
        return tmp_out

In [10]:
def get_new_program(p_config, p_interpreter, p_generator):
    # initialize a program first
    while True:
        p_input = p_interpreter.random_table()
        p_prog, p_example = p_generator.generate(
            fixed_depth=p_config["fixed_depth"],
            example=Example(input=[p_input], output=None),
        )
        # make sure at least one function call
        if p_prog is not None and p_prog.is_apply():
            break
    return p_prog, p_example

# replace certain node id with certain value
def modify_shell(p_shell, p_id_from, p_id_to):
    d_prod = p_shell[0]
    d_rhs = p_shell[1]
    ld_rhs = [p_id_to if d_rhs[i]==p_id_from else d_rhs[i]
             for i in range(len(d_rhs))]
    return (d_prod, tuple(ld_rhs))    

In [11]:
def TransNeoBeam(p_config, p_spec, p_interpreter, p_model, p_exp, verbose=False):
    p_model.eval()
    p_example = p_exp
    is_solved = False
    
    # in every new attempt, initialize a new Program Space
    ps_current = ProgramSpace(
        p_spec, p_interpreter, p_example.input, p_example.output,
    )
    
    seed_list = [
        (0.0, ps_current, True) # (score, ps, is_alive)
    ]
    beam_list = []
    dead_list = []
    
    # start the beam search
    for d_step in range(p_config["max_steps"]):
        if verbose:
            print("# Analyzing step#{}".format(d_step))
        
        for sd_score, sd_ps, sd_sts in seed_list:
            
            if not sd_sts:
                # already dead, append it then move on to the next
                dead_list.append(
                    (sd_score, sd_ps, sd_sts)
                )
                continue
            
            # initialized customized shell template
            sd_shell_list = sd_ps.get_neighboring_shells()
            sd_node_to_replace = sd_ps.node_dict["ParamNode"][0] # for chain only
            # replace the Param Node id in shells with -1 to make them templates
            sd_template_list = [
                modify_shell(sd_shell_list[i],sd_node_to_replace,-1)
                for i in range(len(sd_shell_list))
            ]
            
            # ### assume chain execution, so only 1 possible returns
            # ### at d_step=0, this should be input[0]
            id_current = sd_ps.get_frontiers()[0]
            var_current = sd_ps.node_list[id_current].ps_data # need the real var name in r env
            var_output = sd_ps.output

            map_current = camb_get_abs(var_current)
            map_output = camb_get_abs(var_output)

            # make current shell list
            current_shell_list = [
                modify_shell(sd_template_list[i],-1,id_current)
                for i in range(len(sd_template_list))
            ]

            # wrap in B=1
            if use_cuda:
                td_current = Variable(torch.tensor([map_current],dtype=torch.long)).cuda()
                td_output = Variable(torch.tensor([map_output],dtype=torch.long)).cuda()
            else:
                td_current = Variable(torch.tensor([map_current],dtype=torch.long))
                td_output = Variable(torch.tensor([map_output],dtype=torch.long))

            # (B=1, fn_vocab_size)
            # fn_vocab_size
            td_pred = p_model(td_current, td_output)
            d_pred = td_pred.data.flatten().cpu().numpy()
            asorted_pred = np.argsort(d_pred)[::-1]
            
            # pick and add the top beam_size actions
            for i in range(p_config["test"]["beam_size"]):
                d_ind = asorted_pred[i]
                tmp_ps = copy.deepcopy(sd_ps)
                tmp_status = tmp_ps.add_neighboring_shell(
                    sd_shell_list[d_ind]
                )
                # check if successful
                if tmp_status:
                    # succeed
                    beam_list.append(
                        (sd_score+d_pred[d_ind], tmp_ps, True)
                    )
                else:
                    # fail
                    beam_list.append(
                        (sd_score+d_pred[d_ind], tmp_ps, False)
                    )
        
        # <END_SEED_LIST>
        # sort the beam list and keep the top seeded ones
        sorted_beam_list = sorted(beam_list, key=lambda p:p[0], reverse=True)
        seed_list = sorted_beam_list[:p_config["test"]["seed_size"]] if len(seed_list)>=p_config["test"]["seed_size"] else sorted_beam_list
        beam_list = []

    # <END_STEP>
    # filter out the beam list once again
    for sd_score, sd_ps, sd_sts in seed_list:
        if not sd_sts:
            # already dead, append it then move on to the next
            dead_list.append(
                (sd_score, sd_ps, sd_sts)
            )
            continue
        else:
            beam_list.append(
                (sd_score, sd_ps, sd_sts)
            )
    seed_list = beam_list
    
    return seed_list, dead_list

In [12]:
m_interpreter = MorpheusInterpreter()
m_spec = S.parse_file('./example/camb3.tyrell')
m_generator = MorpheusGenerator(
    spec=m_spec,
    interpreter=m_interpreter,
    sfn=m_interpreter.sanity_check,
)
m_ps = ProgramSpace(
    m_spec, m_interpreter, [None], None,
)

m_config = {
    "val":{
        "vocab_size": len(CAMB_LIST),
        "embd_dim": 16, # embedding dim of CAMB abstract token
        "conv_n_kernels": 512,
        "conv_kernel_size": (1,CAMB_NCOL), 
        "pool_kernel_size": (CAMB_NROW,1), 
        "IDX_PAD": 0,
    },
    "fn":{
        "vocab_size": len(m_ps.get_neighboring_shells())
    },
    "embd_dim": 128,
    "mte_path": "./saved_models/0712CAMB_TransE_camb3_ep{}.pt".format(150),
    "agent_path": "./saved_models/0712CAMB_RL2_camb3_ep{}.pt".format(4300),
    "batch_size": 8, # number of steps between every back-prop by n_attempt
    "fixed_depth": 3, # means size of 3
    "n_episode": 100000000,
    "max_attempts": 100, # max number of attepmts in every episode
    "max_steps": 2, # max number of function calls
    # "exploration_rate": 0.1, # fixed exp rate
    # "exploration_rate": lambda x:0.9-0.8*(min(1,x/2500)), # from 0.9 to 0.1
    "exploration_rate": lambda xep,xat:(0.9-0.8*(min(1,xep/1000)))*(1-xat/100), # from 0.9 to 0.1
    "decay_rate": 0.9,
    # "hint_rate": lambda x:1 if x<1000 and x%2==0 else 0, # first 100 episodes have hints
    # "hint_rate": lambda x:1 if x<10000 else 0, # first 100 episodes have hints
    "hint_rate": lambda x:0,
    "test":{
        "beam_size":50, # how many candidates to expand in every step
        "seed_size":20, # how many candidates are there in every final list
    }
}

trans_neo = TransNeo(p_config=m_config)
trans_neo.load_state_dict(torch.load(m_config["agent_path"]))
if use_cuda:
    trans_neo = trans_neo.cuda()

In [13]:
m_config

{'val': {'vocab_size': 150,
  'embd_dim': 16,
  'conv_n_kernels': 512,
  'conv_kernel_size': (1, 15),
  'pool_kernel_size': (15, 1),
  'IDX_PAD': 0},
 'fn': {'vocab_size': 120},
 'embd_dim': 128,
 'mte_path': './saved_models/0712CAMB_TransE_camb3_ep150.pt',
 'agent_path': './saved_models/0712CAMB_RL2_camb3_ep4300.pt',
 'batch_size': 8,
 'fixed_depth': 3,
 'n_episode': 100000000,
 'max_attempts': 100,
 'max_steps': 2,
 'exploration_rate': <function __main__.<lambda>(xep, xat)>,
 'decay_rate': 0.9,
 'hint_rate': <function __main__.<lambda>(x)>,
 'test': {'beam_size': 50, 'seed_size': 20}}

In [14]:
# TransNeoTrainer(m_config, m_spec, m_interpreter, m_generator, trans_neo, optimizer, writer)

In [22]:
m_prog, m_example = get_new_program(
    m_config, m_interpreter, m_generator,
)
alist, dlist = TransNeoBeam(m_config, m_spec, m_interpreter, trans_neo, m_example, verbose=True)
# visualize
print("====INPUT====")
print(robjects.r(m_example.input[0]))
print("====OUTPUT====")
print(robjects.r(m_example.output))
print("====Original Program====")
print(m_prog)
print("=====Beam Solutions=====")
for i in range(len(alist)):
    is_solved = None
    if alist[i][1].check_eq() is not None:
        # solved
        is_solved = True
    else:
        is_solved = False
    tmp_node_id = alist[i][1].get_frontiers()[0]
    print("#{}({},{}): {}".format(
        i+1, is_solved, alist[i][0], alist[i][1].node_list[tmp_node_id]
        # get_frontiers only works for chain execution
    ))
print("=====Dead Solutions=====")
for i in range(len(dlist)):
    tmp_node_id = dlist[i][1].get_frontiers()[0]
    print("#{}(,{}): {}".format(
        i+1, dlist[i][0], dlist[i][1].node_list[tmp_node_id]
        # get_frontiers only works for chain execution
    ))

# Analyzing step#0
# Analyzing step#1
====INPUT====
  OCOL1 OCOL2  OCOL3 OCOL4  OCOL5
1    29    36  61.59    50  65.15
2    29    96  72.77    94  21.47
3    75    85 -42.04    56 -26.65
4    29   -66 -66.84    56 -20.60
5    29     5  27.93    56  49.19

====OUTPUT====
  OCOL1  OCOL3  OCOL5 -66  5 COL179001 COL179002 85 96
1    29 -66.84 -20.60  56 NA      <NA>      <NA> NA NA
2    29  27.93  49.19  NA 56      <NA>      <NA> NA NA
3    29  61.59  65.15  NA NA        50      <NA> NA NA
4    29  72.77  21.47  NA NA      <NA>      <NA> NA 94
5    75 -42.04 -26.65  NA NA      <NA>      <NA> 56 NA

====Original Program====
separate(spread(@param0, 2, 4), 6)
=====Beam Solutions=====
#1(False,-9.442216873168945): spread(@param0, 1, 2)
#2(False,-10.522241115570068): gather(@param0, ['1'])
#3(False,-11.354947090148926): select(@param0, ['4', '5'])
#4(False,-11.543549060821533): gather(@param0, ['1', '2'])
#5(False,-11.753877639770508): gather(@param0, ['1', '2'])
#6(False,-12.008302688598633)

In [26]:
alist[7][1].shells

[(3, (69, 1, 2)), (1, (69, 37))]

In [27]:
alist[7][1].check_eq()

70

In [28]:
alist[7][1].node_list[70]

ApplyNode(spread, [ParamNode(0), AtomNode(2), AtomNode(3)])

In [30]:
alist[7][1].node_list[70].ps_data

'RET_DF11614'

In [31]:
print(robjects.r("RET_DF11614"))

  OCOL1 OCOL4  OCOL5    -66     5    36     85    96
1    29    50  65.15     NA    NA 61.59     NA    NA
2    29    56 -20.60 -66.84    NA    NA     NA    NA
3    29    56  49.19     NA 27.93    NA     NA    NA
4    29    94  21.47     NA    NA    NA     NA 72.77
5    75    56 -26.65     NA    NA    NA -42.04    NA



In [33]:
print(robjects.r("RET_DF11599"))

  OCOL1  OCOL3  OCOL5 -66  5 COL179001 COL179002 85 96
1    29 -66.84 -20.60  56 NA      <NA>      <NA> NA NA
2    29  27.93  49.19  NA 56      <NA>      <NA> NA NA
3    29  61.59  65.15  NA NA        50      <NA> NA NA
4    29  72.77  21.47  NA NA      <NA>      <NA> NA 94
5    75 -42.04 -26.65  NA NA      <NA>      <NA> 56 NA



In [50]:
print(robjects.r("""
tmp1 <- sapply(RET_DF11614, as.character)
tmp2 <- sapply(RET_DF11599, as.character)
compare(tmp1,tmp2, ignoreOrder=TRUE)
"""))

TRUE
  sorted



In [48]:
print(robjects.r("tmp1"))

     OCOL1 OCOL4 OCOL5    -66      5       36      85       96     
[1,] "29"  "50"  "65.15"  NA       NA      "61.59" NA       NA     
[2,] "29"  "56"  "-20.6"  "-66.84" NA      NA      NA       NA     
[3,] "29"  "56"  "49.19"  NA       "27.93" NA      NA       NA     
[4,] "29"  "94"  "21.47"  NA       NA      NA      NA       "72.77"
[5,] "75"  "56"  "-26.65" NA       NA      NA      "-42.04" NA     



In [49]:
print(robjects.r("tmp2"))

     OCOL1 OCOL3    OCOL5    -66  5    COL179001 COL179002 85   96  
[1,] "29"  "-66.84" "-20.6"  "56" NA   NA        NA        NA   NA  
[2,] "29"  "27.93"  "49.19"  NA   "56" NA        NA        NA   NA  
[3,] "29"  "61.59"  "65.15"  NA   NA   "50"      NA        NA   NA  
[4,] "29"  "72.77"  "21.47"  NA   NA   NA        NA        NA   "94"
[5,] "75"  "-42.04" "-26.65" NA   NA   NA        NA        "56" NA  



In [34]:
m_interpreter.equal("RET_DF11614","RET_DF11599")

True

In [32]:
alist[7][1].output

'RET_DF11599'