## TransNeo
- AlphaNeo using pre-trained TransE embeddings
- Stage: Cambrian
- Version: Charniodiscus
- **0713: with DeepPath style rollback at training**
#### Related Commands
- `tensorboard --logdir runs`
- `nohup jupyter lab > jupyter.log &`

In [1]:
import logging 
logging.basicConfig(level=logging.CRITICAL)

In [2]:
import os
import itertools
import copy
import random
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from tensorboardX import SummaryWriter

use_cuda = torch.cuda.is_available()
print("use_cuda: {}".format(use_cuda))

use_cuda: True


In [3]:
import tyrell.spec as S
from tyrell.interpreter import Interpreter, PostOrderInterpreter, GeneralError, InterpreterError
from tyrell.enumerator import Enumerator, SmtEnumerator, RandomEnumerator, DesignatedEnumerator, RandomEnumeratorS, ExhaustiveEnumerator
from tyrell.decider import Example, ExampleConstraintPruningDecider, ExampleDecider, TestDecider
from tyrell.synthesizer import Synthesizer
from tyrell.logger import get_logger
from sexpdata import Symbol
from tyrell import dsl as D
from typing import Callable, NamedTuple, List, Any

In [4]:
# Morpheus Version
from utils_morpheus import *
from ProgramSpace import *

In [5]:
# Show the PyTorch version the kernel is running.
torch.__version__

'1.0.0'

In [6]:
# Debug handle: TransNeoTrainer stores the current episode's solution
# ProgramSpace here so it can be inspected from other cells.
DBG_PS = None

In [7]:
class ValueEncoder(nn.Module):
    """Encode a batch of integer token maps into fixed-size vectors.

    Pipeline: token embedding -> 2-D convolution + ReLU -> max-pooling ->
    linear layer -> sigmoid.

    Config keys read from ``p_config``:
        ``["val"]["vocab_size"]``, ``["val"]["embd_dim"]``,
        ``["val"]["IDX_PAD"]``, ``["val"]["conv_n_kernels"]``,
        ``["val"]["conv_kernel_size"]``, ``["val"]["pool_kernel_size"]``
        and the top-level ``["embd_dim"]`` (size of the returned vectors).
    """

    def __init__(self, p_config=None):
        super(ValueEncoder, self).__init__()
        self.config = p_config
        val_config = self.config["val"]

        self.vocab_size = val_config["vocab_size"]
        self.embd_dim = val_config["embd_dim"]
        self.embedding = nn.Embedding(
            self.vocab_size,
            self.embd_dim,
            padding_idx=val_config["IDX_PAD"],
        )

        self.conv = nn.Conv2d(
            in_channels=val_config["embd_dim"],
            out_channels=val_config["conv_n_kernels"],
            kernel_size=val_config["conv_kernel_size"],
        )

        # IDX_PAD is the id of the padding *token*; it is unrelated to spatial
        # padding. Reusing it as the pooling padding only worked because it
        # happened to be 0 -- any other value makes MaxPool2d reject the input
        # or breaks the reshape in forward(). Pooling is therefore unpadded.
        self.pool = nn.MaxPool2d(
            kernel_size=val_config["pool_kernel_size"],
            padding=0,
        )

        self.fc = nn.Linear(
            val_config["conv_n_kernels"],
            self.config["embd_dim"],
        )

    def forward(self, bp_map):
        """Encode ``bp_map``, a long tensor of token ids shaped (B, map_r, map_c).

        In this version every value consists of exactly one map.

        Returns:
            Tensor of shape (B, embd_dim) with entries in [0, 1] (sigmoid).

        Note:
            The reshape after pooling requires conv + pool to collapse both
            spatial dimensions to 1x1, i.e. the kernel sizes in the config
            must match the map size.
        """
        B = bp_map.shape[0]

        # (B, map_r, map_c, val_embd_dim) -> (B, val_embd_dim, map_r, map_c)
        d_embd = self.embedding(bp_map).permute(0, 3, 1, 2)

        # (B, n_kernel, map_r, 1) when the conv kernel spans a full row
        d_conv = F.relu(self.conv(d_embd))

        # (B, n_kernel)
        d_pool = self.pool(d_conv).view(B, self.config["val"]["conv_n_kernels"])

        # (B, embd_dim)
        d_out = torch.sigmoid(self.fc(d_pool))

        return d_out
        

In [8]:
class MorphTransE(nn.Module):
    """Container for the value encoder and the function (shell) embedding table.

    Config keys read from ``p_config``: ``["fn"]["vocab_size"]`` (number of
    rows of the function embedding) and ``["embd_dim"]`` (its width). The
    whole config is also forwarded to ``ValueEncoder``.
    """

    def __init__(self, p_config=None):
        super().__init__()
        self.config = p_config

        self.value_encoder = ValueEncoder(p_config=p_config)

        self.fn_vocab_size = self.config["fn"]["vocab_size"]
        self.embd_dim = self.config["embd_dim"]
        self.fn_embedding = nn.Embedding(self.fn_vocab_size, self.embd_dim)

        # Xavier-initialise the table in place, then rescale every row to
        # unit L2 norm.
        fn_weights = self.fn_embedding.weight.data
        nn.init.xavier_uniform_(fn_weights)
        self.fn_embedding.weight.data = F.normalize(fn_weights, p=2, dim=1)
# ----> skip the forward part since we don't need it <---- #

In [9]:
class TransNeo(nn.Module):
    """Policy network scoring every candidate shell for an (input, output) pair.

    Both maps are encoded with the ``MorphTransE`` value encoder; the
    difference of the two encodings is fed to a single linear layer whose
    output has ``p_config["fn"]["vocab_size"]`` entries.
    """

    def __init__(self, p_config=None):
        super().__init__()
        self.config = p_config

        # predict over a fixed number of shells
        n_features = self.config["embd_dim"]
        n_shells = self.config["fn"]["vocab_size"]
        self.policy = nn.Linear(n_features, n_shells)
        # A deeper head (embd_dim -> 2048 -> ReLU -> vocab_size) was tried
        # here and is currently disabled.

        self.mte = MorphTransE(p_config=p_config)
        # Loading pre-trained parameters is currently disabled:
        # self.mte.load_state_dict(torch.load(self.config["mte_path"]))

    def forward(self, p_mapin, p_mapout):
        """Return log-probabilities over shells, shape (B, fn_vocab_size).

        ``p_mapin`` / ``p_mapout`` are batched maps shaped (B, map_r, map_c).
        """
        encode = self.mte.value_encoder
        source_vec = encode(p_mapin)   # (B, embd_dim)
        target_vec = encode(p_mapout)  # (B, embd_dim)
        shell_scores = self.policy(target_vec - source_vec)
        return torch.log_softmax(shell_scores, dim=1)

In [10]:
def get_new_program(p_config, p_interpreter, p_generator):
    """Keep sampling until the generator returns a program with a function call.

    Every round draws a fresh table from ``p_interpreter.random_table()`` and
    asks ``p_generator.generate`` for a program of depth
    ``p_config["fixed_depth"]`` on it. A result is accepted once the program
    is not ``None`` and ``is_apply()`` is truthy.

    Returns:
        The ``(program, example)`` pair from the accepted ``generate`` call.
    """
    prog, example = None, None
    accepted = False
    while not accepted:
        seed_table = p_interpreter.random_table()
        prog, example = p_generator.generate(
            fixed_depth=p_config["fixed_depth"],
            example=Example(input=[seed_table], output=None),
        )
        # make sure there is at least one function call
        accepted = prog is not None and prog.is_apply()
    return prog, example

def modify_shell(p_shell, p_id_from, p_id_to):
    """Return a copy of ``p_shell`` with one node id substituted.

    ``p_shell[0]`` is passed through untouched. Every entry of ``p_shell[1]``
    equal to ``p_id_from`` is replaced by ``p_id_to``; the result is returned
    as a tuple.
    """
    head, body = p_shell[0], p_shell[1]
    swapped = tuple(p_id_to if node == p_id_from else node for node in body)
    return (head, swapped)

def TransNeoTrainer(p_config, p_spec, p_interpreter, p_generator, p_model, p_optim, p_writer):
    """Train ``p_model`` by policy gradient on freshly generated synthesis tasks.

    Every episode draws a new program with ``get_new_program`` and makes up to
    ``p_config["max_attempts"]`` attempts to rebuild it. An attempt starts from
    an empty ``ProgramSpace`` and keeps choosing shells until
    ``p_config["max_steps"]`` shells have been added or ``check_eq()`` reports
    a solution. A shell is chosen

    * from the reference solution, with probability ``hint_rate(episode)``;
    * otherwise uniformly at random, with probability
      ``exploration_rate(episode, attempt)``;
    * otherwise by sampling from the model's predicted distribution.

    Loss terms:

    * accepted shells contribute ``-reward * decay * log_prob`` where the
      reward is ``+1`` for a solved attempt and ``-1`` otherwise, and
      ``decay = decay_rate ** (steps remaining after this one)``;
    * rejected shells (DeepPath-style rollback: the step is retried without
      ending the attempt) contribute ``+log_prob``.

    The optimizer is updated as soon as an attempt solves the task or once
    ``p_config["batch_size"]`` attempts have accumulated.

    Args:
        p_config: dict providing ``n_episode``, ``max_attempts``, ``max_steps``,
            ``batch_size``, ``decay_rate``, ``fixed_depth`` and the callables
            ``hint_rate(episode)`` / ``exploration_rate(episode, attempt)``.
        p_spec: DSL spec passed to ``ProgramSpace``.
        p_interpreter: interpreter passed to ``ProgramSpace`` and the generator.
        p_generator: program generator used by ``get_new_program``.
        p_model: network called as ``p_model(current_map, output_map)`` and
            returning log-probabilities of shape (1, number of shells).
        p_optim: optimizer over ``p_model``'s parameters.
        p_writer: object with ``add_scalar`` (e.g. a SummaryWriter) or ``None``
            to disable logging.

    Returns:
        None. The function updates ``p_model`` in place, prints a progress
        line, and stores the current solution ``ProgramSpace`` in the global
        ``DBG_PS``.

    Note:
        A rejected shell does not consume a step, so an attempt only ends once
        enough shells have been accepted.
    """
    global DBG_PS
    nth_attempt = 0  # attempts accumulated since the last optimizer update
    batch_lossA_list = []  # loss terms of accepted shells
    batch_lossD_list = []  # loss terms of rejected ("dead") shells

    n_solved = 0  # track the number of solved problems
    n_attempt_list = []  # last attempt index reached in every episode

    for d_episode in range(p_config["n_episode"]):
        p_model.train()

        # in every episode, generate a new program (maze/ps) to learn
        p_prog, p_example = get_new_program(
            p_config, p_interpreter, p_generator,
        )

        ps_solution = ProgramSpace(
            p_spec, p_interpreter, p_example.input, p_example.output,
        )
        ps_solution.init_by_prog(p_prog)  # this constructs a solution for this problem
        DBG_PS = ps_solution

        is_solved = False

        for d_attempt in range(p_config["max_attempts"]):
            if is_solved:
                # already solved in the last attempt, stop
                break

            nth_attempt += 1
            attempt_reward = None

            # Trajectories are tracked per attempt. Keeping them across
            # attempts (when batch_size > 1) would re-add earlier attempts to
            # the loss with this attempt's reward and a wrong decay exponent.
            selected_neurons = []
            dead_neurons = []  # DeepPath: log-probs of shells that failed to apply

            # in every new attempt, initialize a new Program Space
            ps_current = ProgramSpace(
                p_spec, p_interpreter, p_example.input, p_example.output,
            )
            # then initialize a shell template
            tmp_shell_list = ps_current.get_neighboring_shells()
            tmp_node_to_replace = ps_current.node_dict["ParamNode"][0]  # for chain only
            # replace the Param Node id in shells with -1 to make them templates
            template_list = [
                modify_shell(shell, tmp_node_to_replace, -1)
                for shell in tmp_shell_list
            ]

            d_step = 0
            while d_step < p_config["max_steps"]:
                exploration_rate = p_config["exploration_rate"](d_episode, d_attempt)

                # print the training progress
                print("\r# AC/EP:{}/{}, AT:{}, SP:{}, DN:{}, avg.attempt:{:.2f}, er:{:.2f}".format(
                    n_solved, d_episode, d_attempt, d_step,
                    len(dead_neurons),
                    sum(n_attempt_list)/len(n_attempt_list) if len(n_attempt_list)>0 else -1,
                    exploration_rate,
                ),end="")

                # assume chain execution, so there is only one frontier;
                # at d_step=0 this should be input[0]
                id_current = ps_current.get_frontiers()[0]
                var_current = ps_current.node_list[id_current].ps_data  # need the real var name in r env
                var_output = p_example.output

                map_current = camb_get_abs(var_current)
                map_output = camb_get_abs(var_output)

                # make current shell list
                current_shell_list = [
                    modify_shell(template, -1, id_current)
                    for template in template_list
                ]

                # wrap in B=1 (plain tensors; the Variable wrapper is a no-op)
                td_current = torch.tensor([map_current], dtype=torch.long)
                td_output = torch.tensor([map_output], dtype=torch.long)
                if use_cuda:
                    td_current = td_current.cuda()
                    td_output = td_output.cuda()

                # (B=1, fn_vocab_size) log-probabilities
                td_pred = p_model(td_current, td_output)

                if random.random() <= p_config["hint_rate"](d_episode):
                    # hint: take the shell the reference solution uses at this step
                    tmp_id = current_shell_list.index(ps_solution.shells[d_step])
                elif random.random() <= exploration_rate:
                    # exploration
                    tmp_id = random.choice(range(len(current_shell_list)))
                else:
                    # exploitation: sample from the predicted distribution
                    tmp_id = torch.multinomial(td_pred.exp().flatten(), 1).item()

                # update ps_current
                update_status = ps_current.add_neighboring_shell(
                    current_shell_list[tmp_id]
                )

                if update_status:
                    # record selected neuron
                    selected_neurons.append(td_pred[0, tmp_id])
                    d_step += 1

                    if ps_current.check_eq() is not None:
                        # and solved!
                        is_solved = True
                        n_solved += 1
                        attempt_reward = 1.0
                        break
                    # otherwise not yet solved, just move to the next step
                else:
                    # DeepPath: the shell could not be applied; remember it
                    # and retry the same step
                    dead_neurons.append(td_pred[0, tmp_id])

            # <END_FOR_STEP>

            if attempt_reward is None:
                # max_steps shells were added without reaching the output
                attempt_reward = -1.

            # loss of the accepted shells; later steps are decayed less
            n_selected = len(selected_neurons)
            for i, log_prob in enumerate(selected_neurons):
                d_decay = p_config["decay_rate"] ** (n_selected - 1 - i)
                batch_lossA_list.append(d_decay * attempt_reward * (-log_prob))

            # loss of the rejected shells (fixed reward of -1)
            for log_prob in dead_neurons:
                batch_lossD_list.append((-1.) * (-log_prob))

            if is_solved or nth_attempt >= p_config["batch_size"]:
                # Both losses go through ONE backward/step. Stepping the
                # optimizer between two backward passes would differentiate
                # the second loss through a graph built with the pre-step
                # weights (stale gradients on old PyTorch, a RuntimeError on
                # newer releases).
                batch_losses = []
                if batch_lossD_list:
                    batch_losses.append(sum(batch_lossD_list) / len(batch_lossD_list))
                if batch_lossA_list:
                    batch_losses.append(sum(batch_lossA_list) / len(batch_lossA_list))

                if batch_losses:
                    p_optim.zero_grad()
                    sum(batch_losses).backward()
                    p_optim.step()

                nth_attempt = 0
                batch_lossA_list = []
                batch_lossD_list = []

        # <END_FOR_ATTEMPT>

        # after all the attempts
        n_attempt_list.append(d_attempt)
        if p_writer is not None:
            # use the writer that was passed in, not a notebook-level global
            p_writer.add_scalar(
                'avg.attempt',
                sum(n_attempt_list) / len(n_attempt_list),
                len(n_attempt_list),
            )

#         if d_episode%100==0:
#             # save the model
#             torch.save(
#                 p_model.state_dict(),
#                 "./saved_models/0713CAMB_RL2_camb3_ep{}.pt".format(d_episode)
#             )

    # <END_FOR_EPISODE>
    

In [11]:
# Build the interpreter, the DSL spec and the program generator.
m_interpreter = MorpheusInterpreter()
m_spec = S.parse_file('./example/camb3.tyrell')
m_generator = MorpheusGenerator(
    spec=m_spec,
    interpreter=m_interpreter,
    sfn=m_interpreter.sanity_check,
)
# Program space without real input/output; used below only to count the
# neighboring shells, which sets the size of the policy output layer.
m_ps = ProgramSpace(
    m_spec, m_interpreter, [None], None,
)

m_config = {
    "val":{
        "vocab_size": len(CAMB_LIST),
        "embd_dim": 16, # embedding dim of CAMB abstract token
        "conv_n_kernels": 512,
        "conv_kernel_size": (1,CAMB_NCOL), 
        "pool_kernel_size": (CAMB_NROW,1), 
        "IDX_PAD": 0,
    },
    "fn":{
        "vocab_size": len(m_ps.get_neighboring_shells())
    },
    "embd_dim": 128,
    "mte_path": "./saved_models/0712CAMB_TransE_camb3_ep{}.pt".format(150),
    "batch_size": 1, # number of attempts accumulated between back-prop updates
    "fixed_depth": 2, # means size of 3
    "n_episode": 100000000,
    "max_attempts": 1, # max number of attempts in every episode
    "max_steps": 1, # max number of function calls
    # "exploration_rate": 0.1, # fixed exp rate
    # "exploration_rate": lambda x:0.9-0.8*(min(1,x/2500)), # from 0.9 to 0.1
    # "exploration_rate": lambda xep,xat:(0.9-0.8*(min(1,xep/300)))*(1-xat/5), # from 0.9 to 0.1
    "exploration_rate": lambda xep,xat:0.1, # constant 0.1 for every episode/attempt
    "decay_rate": 0.9,
    # "hint_rate": lambda x:1 if x<1000 and x%2==0 else 0, # hints on even episodes among the first 1000
    "hint_rate": lambda x:1 if x<10000 else 0, # always hint during the first 10000 episodes, never afterwards
    # "hint_rate": lambda x:0,
}

trans_neo = TransNeo(p_config=m_config)
if use_cuda:
    trans_neo = trans_neo.cuda()
optimizer = torch.optim.Adam(list(trans_neo.parameters()))
# TensorBoard logging is disabled; uncomment the next line (and drop the
# `writer = None`) to log under runs/.
# writer = SummaryWriter("runs/0713CAMB_RL2_camb3")
writer = None

In [12]:
# Inspect the function-embedding table before training. These are the random,
# row-normalised initial values: the load_state_dict call in TransNeo is
# commented out, so no pre-trained weights have been loaded.
trans_neo.mte.fn_embedding.weight

Parameter containing:
tensor([[ 0.1104, -0.0382,  0.1183,  ..., -0.0477,  0.1137, -0.1407],
        [ 0.1155, -0.1479,  0.0249,  ..., -0.0798,  0.0866,  0.0500],
        [ 0.0953, -0.0905, -0.0041,  ...,  0.0773, -0.0427, -0.0332],
        ...,
        [ 0.0135, -0.0911,  0.1300,  ...,  0.0313, -0.0339, -0.1344],
        [ 0.0006,  0.0730, -0.1305,  ..., -0.0735,  0.0657,  0.1509],
        [-0.0037,  0.0926,  0.0345,  ...,  0.0767,  0.1287, -0.0668]],
       device='cuda:0', requires_grad=True)

In [13]:
# Show the resolved configuration (the vocab sizes are computed at runtime).
m_config

{'val': {'vocab_size': 150,
  'embd_dim': 16,
  'conv_n_kernels': 512,
  'conv_kernel_size': (1, 15),
  'pool_kernel_size': (15, 1),
  'IDX_PAD': 0},
 'fn': {'vocab_size': 120},
 'embd_dim': 128,
 'mte_path': './saved_models/0712CAMB_TransE_camb3_ep150.pt',
 'batch_size': 1,
 'fixed_depth': 2,
 'n_episode': 100000000,
 'max_attempts': 1,
 'max_steps': 1,
 'exploration_rate': <function __main__.<lambda>(xep, xat)>,
 'decay_rate': 0.9,
 'hint_rate': <function __main__.<lambda>(x)>}

In [14]:
# Start training. With n_episode = 100000000 this effectively runs until it
# is interrupted manually.
TransNeoTrainer(m_config, m_spec, m_interpreter, m_generator, trans_neo, optimizer, writer)

# AC/EP:11684/20434, AT:0, SP:0, DN:15, avg.attempt:0.00, er:0.10

KeyboardInterrupt: 