## PolyNeo
- Introduces a poly policy for fast adaptation
- Stage: Cambrian
- Version: Inaria
- Update log
    - 0713: DeepPath-style rollback during training
    - 0716: new learning paradigm; see memo for details
    - 0724: poly policy structure for fast online adaptation

#### Related Commands
- tensorboard --logdir runs
- nohup jupyter lab > jupyter.log &

In [1]:
import logging 
# Only CRITICAL records pass the root logger, keeping notebook output quiet.
logging.basicConfig(level="CRITICAL")

In [2]:
import os
import itertools
import copy
import random
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# from tensorboardX import SummaryWriter

# Global switch read by the training loops: move tensors/models to the GPU
# whenever PyTorch can see one.
use_cuda = torch.cuda.is_available()
print("use_cuda: %s" % (use_cuda,))

use_cuda: True


In [3]:
import tyrell.spec as S
from tyrell.decider import Example

# Morpheus Version
from MorpheusInterpreter import *
from ProgramSpace import *

In [4]:
torch.__version__

'1.0.0'

In [5]:
class ListModule(object):
    """Attach a growing, indexable sequence of sub-modules to a host module.

    Every appended module is registered on ``module`` via ``add_module`` under
    the name ``prefix + str(index)``; indexing looks it back up by that name.
    """

    def __init__(self, module, prefix, *args):
        self.module = module
        self.prefix = prefix
        self.num_module = 0
        for child in args:
            self.append(child)

    def _name(self, index):
        """Return the attribute name used for the child at ``index``."""
        return self.prefix + str(index)

    def append(self, new_module):
        """Register ``new_module`` as the next child; reject non-modules."""
        if isinstance(new_module, nn.Module):
            self.module.add_module(self._name(self.num_module), new_module)
            self.num_module += 1
            return
        raise ValueError('Not a Module')

    def __len__(self):
        return self.num_module

    def __getitem__(self, i):
        # Only non-negative, in-range indices are accepted (no negative indexing).
        if not 0 <= i < self.num_module:
            raise IndexError('Out of bound')
        return getattr(self.module, self._name(i))

In [6]:
class ValueEncoder(nn.Module):
    """Encode a batch of abstract value maps into fixed-size feature vectors.

    Pipeline: token embedding -> Conv2d -> ReLU -> MaxPool2d -> Linear -> ReLU.

    Reads from ``p_config``:
        ["val"]["vocab_size"], ["val"]["embd_dim"]: token embedding table
        ["val"]["IDX_PAD"]: padding token index of the embedding
        ["val"]["conv_n_kernels"], ["val"]["conv_kernel_size"]: conv layer
        ["val"]["pool_kernel_size"]: max-pool window
        ["val"]["pool_padding"] (optional, default 0): max-pool padding
        ["embd_dim"]: size of the output vector
    """

    def __init__(self, p_config=None):
        super(ValueEncoder, self).__init__()
        self.config = p_config
        val_cfg = self.config["val"]

        self.vocab_size = val_cfg["vocab_size"]
        self.embd_dim = val_cfg["embd_dim"]
        self.embedding = nn.Embedding(
            self.vocab_size,
            self.embd_dim,
            padding_idx=val_cfg["IDX_PAD"],
        )

        self.conv = nn.Conv2d(
            in_channels=self.embd_dim,
            out_channels=val_cfg["conv_n_kernels"],
            kernel_size=val_cfg["conv_kernel_size"],
        )

        # Pool padding is a spatial size, not a token id: it used to be read
        # from IDX_PAD, which only worked because IDX_PAD happened to be 0.
        self.pool = nn.MaxPool2d(
            kernel_size=val_cfg["pool_kernel_size"],
            padding=val_cfg.get("pool_padding", 0),
        )

        self.fc = nn.Linear(
            val_cfg["conv_n_kernels"],
            self.config["embd_dim"],
        )

    def forward(self, bp_map):
        """Encode a batch of value maps.

        Args:
            bp_map: (B, map_r, map_c) long tensor of abstract-token ids; in this
                version every value contributes exactly one map.
        Returns:
            (B, config["embd_dim"]) tensor, non-negative (ReLU output).
        """
        B = bp_map.shape[0]

        # (B, map_r, map_c, val_embd_dim) -> (B, val_embd_dim, map_r, map_c)
        d_embd = self.embedding(bp_map).permute(0, 3, 1, 2)

        # (B, n_kernel, map_r, 1) when the conv kernel spans a full map row
        d_conv = F.relu(self.conv(d_embd))

        # (B, n_kernel); requires the pool to reduce the spatial dims to 1x1
        d_pool = self.pool(d_conv).view(B, self.config["val"]["conv_n_kernels"])

        # (B, embd_dim)
        return F.relu(self.fc(d_pool))
        

In [7]:
class PolyNeo(nn.Module):
    """Two-stage policy over function shells: a *core* and a *poly* head.

    The core policy scores every function shell from the difference between
    the encoded output value and the encoded input value.  The poly policy
    re-scores them from the (ReLU'd) core scores concatenated with the
    embeddings of the functions selected in previous steps.

    Reads from ``p_config``:
        everything :class:`ValueEncoder` reads ("val" section, "embd_dim")
        ["fn"]["vocab_size"]: number of candidate function shells; this id is
            also the padding index of the function embedding
        ["adaptation"]["maxn_step"]: program size; ``maxn_step - 1`` previous
            selections are fed to the poly policy
    """

    def __init__(self, p_config=None):
        super(PolyNeo, self).__init__()
        self.config = p_config
        fn_vocab = self.config["fn"]["vocab_size"]
        embd_dim = self.config["embd_dim"]
        n_history = self.config["adaptation"]["maxn_step"] - 1

        self.value_encoder = ValueEncoder(self.config)
        # one extra row: id ``fn_vocab`` pads unused history slots
        self.fn_embedding = nn.Embedding(
            fn_vocab + 1,
            embd_dim,
            padding_idx=fn_vocab,
        )
        self.core_policy = nn.Linear(embd_dim, fn_vocab)
        self.poly_policy = nn.Linear(
            fn_vocab + n_history * embd_dim,
            fn_vocab,
        )

    def _value_delta(self, p_mapin, p_mapout):
        """Return encode(p_mapout) - encode(p_mapin), shape (B, embd_dim)."""
        v_in = self.value_encoder(p_mapin)
        v_out = self.value_encoder(p_mapout)
        return v_out - v_in

    def forward(self, p_mapin, p_mapout, p_fns):
        """Score function shells given the value pair and the selection history.

        Args:
            p_mapin / p_mapout: (B, map_r, map_c) long tensors (B=1 in this file).
            p_fns: (B, maxn_step-1) long tensor of previously selected function
                ids; empty slots hold ``fn.vocab_size`` (the padding id).
        Returns:
            (B, fn_vocab_size) log-probabilities.
        """
        v_core = F.relu(self.core_policy(self._value_delta(p_mapin, p_mapout)))
        # (B, fn_vocab_size)

        v_fns = self.fn_embedding(p_fns)  # (B, maxn_step-1, embd_dim)
        # flatten(1) rather than view(B, -1): with maxn_step == 1 the history
        # holds 0 elements and view(B, -1) is ambiguous (raises in recent PyTorch).
        vi_fns = v_fns.flatten(1)  # (B, (maxn_step-1)*embd_dim)

        v_po = torch.cat([v_core, vi_fns], dim=1)
        # (B, fn_vocab_size + (maxn_step-1)*embd_dim)

        # raw logits from the poly head; log_softmax turns them into log-probs
        v_pred = self.poly_policy(v_po)  # (B, fn_vocab_size)
        return F.log_softmax(v_pred, dim=1)

    def pretrain(self, p_mapin, p_mapout):
        """Score function shells with the core policy alone (no history).

        Args:
            p_mapin / p_mapout: (B, map_r, map_c) long tensors (B=1 in this file).
        Returns:
            (B, fn_vocab_size) log-probabilities.
        """
        return F.log_softmax(
            self.core_policy(self._value_delta(p_mapin, p_mapout)), dim=1
        )

In [8]:
def modify_shell(p_shell, p_id_from, p_id_to):
    """Return a copy of ``p_shell`` whose RHS has ``p_id_from`` replaced by ``p_id_to``.

    A shell is read as ``(production, rhs)``; only those first two items are
    kept, and the rhs is always returned as a tuple.
    """
    production, rhs = p_shell[0], p_shell[1]
    swapped = tuple(p_id_to if node == p_id_from else node for node in rhs)
    return (production, swapped)

In [9]:
def _pretrain_sample_loss(p_spec, p_interpreter, p_model, d_prog, dstr_example):
    """Return the supervised loss -log p(reference first shell) for one sample."""
    d_example = Example(
        input=[p_interpreter.load_data_into_var(p) for p in dstr_example.input],
        output=p_interpreter.load_data_into_var(dstr_example.output),
    )

    # reference solution for this problem; only its first shell (the hint) is used
    ps_solution = ProgramSpace(
        p_spec, p_interpreter, d_example.input, d_example.output,
    )
    ps_solution.init_by_prog(d_prog)

    # fresh ProgramSpace: turn its neighboring shells into templates by
    # replacing the ParamNode id with -1 (chain programs only) ...
    ps_current = ProgramSpace(
        p_spec, p_interpreter, d_example.input, d_example.output,
    )
    tmp_shell_list = ps_current.get_neighboring_shells()
    tmp_node_to_replace = ps_current.node_dict["ParamNode"][0]
    template_list = [
        modify_shell(shell, tmp_node_to_replace, -1) for shell in tmp_shell_list
    ]
    # ... and instantiate them on the current strict frontier
    id_current = ps_current.get_strict_frontiers()[0]
    current_shell_list = [
        modify_shell(template, -1, id_current) for template in template_list
    ]

    map_input = p_interpreter.camb_get_abs(ps_current.inputs[0])
    map_output = p_interpreter.camb_get_abs(ps_current.output)
    # wrap in B=1
    td_input = torch.tensor([map_input], dtype=torch.long)
    td_output = torch.tensor([map_output], dtype=torch.long)
    if use_cuda:
        td_input, td_output = td_input.cuda(), td_output.cuda()

    td_pred = p_model.pretrain(td_input, td_output)  # (B=1, fn_vocab_size)
    # position of the reference shell among the candidates (ValueError if absent)
    tmp_id = current_shell_list.index(ps_solution.shells[0])
    # supervised / always correct: reward +1, so loss = (+1) * (-log p)
    return (+1) * (-td_pred[0, tmp_id])


def Pretrain(p_config, p_spec, p_interpreter, p_model, p_data, p_optim, p_writer):
    """Meta-train the agent's core policy in a supervised way.

    Every epoch draws ``p_config["pretrain"]["n_truncated"]`` random samples
    from ``p_data`` (one attempt with the hint per sample) and minimizes the
    negative log-probability ``p_model.pretrain`` assigns to the first shell
    of the reference program.  Per-sample losses are averaged and
    back-propagated every ``p_config["pretrain"]["batch_size"]`` samples; the
    trailing partial batch of an epoch is back-propagated as well.

    NOTICE: only valid for size-1 training programs, because only
    ``shells[0]`` of the reference solution is used as the target.

    Args:
        p_config: config dict; reads the "pretrain" section.
        p_spec, p_interpreter: passed to ``ProgramSpace``; the interpreter also
            converts example values (``load_data_into_var``, ``camb_get_abs``).
        p_model: model exposing ``train()`` and ``pretrain(mapin, mapout)``.
        p_data: list of ``(program, example)`` pairs; not modified.
        p_optim: optimizer over ``p_model``'s parameters.
        p_writer: unused.
    """
    print("# Start Pretraining...")
    n_truncated = p_config["pretrain"]["n_truncated"]
    batch_size = p_config["pretrain"]["batch_size"]
    for d_epoch in range(p_config["pretrain"]["n_epoch"]):
        p_model.train()

        epoch_loss_list = []
        batch_loss_list = []
        # random subset without shuffling the caller's list in place
        train_data = random.sample(p_data, min(n_truncated, len(p_data)))
        n_train = len(train_data)

        for d_ind, (d_prog, dstr_example) in enumerate(train_data):
            print("\r# epoch:{}, index:{}/{}, avg.loss:{:.2f}".format(
                d_epoch, d_ind, len(train_data),
                sum(epoch_loss_list)/len(epoch_loss_list)
                if len(epoch_loss_list)>0 else 0,
            ),end="")

            d_loss = _pretrain_sample_loss(
                p_spec, p_interpreter, p_model, d_prog, dstr_example,
            )
            batch_loss_list.append(d_loss)
            epoch_loss_list.append(d_loss.item())

            # Flush a full batch, or whatever remains after the last sample.
            # (Comparing the batch length to len(train_data) never fired once
            # the batch list had been reset, so a partial batch was dropped.)
            if len(batch_loss_list) >= batch_size or d_ind == n_train - 1:
                batch_loss = sum(batch_loss_list) / len(batch_loss_list)
                p_optim.zero_grad()
                batch_loss.backward()
                p_optim.step()
                batch_loss_list = []

        print()
    

In [10]:
def _adapt_to_batch(values):
    """Wrap one example as a (B=1, ...) long tensor, moved to GPU when ``use_cuda``."""
    td = torch.tensor([values], dtype=torch.long)
    return td.cuda() if use_cuda else td


def _adapt_reinforce(optim, td_log_prob, reward):
    """Take one optimizer step on the REINFORCE loss ``reward * (-td_log_prob)``.

    A negative ``reward`` lowers ``td_log_prob``; a positive one raises it.
    """
    optim.zero_grad()
    (reward * (-td_log_prob)).backward()
    optim.step()


def Adaptation(p_config, p_spec, p_interpreter, p_generator, p_model, p_optim, p_writer):
    """Meta-test an agent: run directly into testing / online adaptation.

    For every episode a random chain program of depth
    ``p_config["adaptation"]["fixed_depth"]`` is generated and its
    inputs/output become the task.  A deep copy of ``p_model`` then makes up
    to ``maxn_attempt`` attempts, each building at most ``maxn_step`` shells:

    * a shell that cannot be added ("dead") or that fails the interpreter's
      sanity check is punished immediately (reward -1) and the same step is
      retried; ``maxn_dead`` / ``maxn_sanity`` failures end the attempt with
      reward -0.5;
    * a shell that passes the sanity check is kept (provisional reward -1);
    * a program for which ``check_eq()`` is not None solves the episode.

    After an unsolved attempt, the kept choices are trained with reward
    ``attempt_reward`` and weights ``decay_rate ** (n_kept - 1 - i)``.

    The per-episode copy is updated by a fresh optimizer of the same class
    and default hyperparameters as ``p_optim``.  ``p_model`` and ``p_optim``
    are never stepped: ``p_optim`` is bound to ``p_model``'s parameters, so
    stepping it would not adapt the copy.  NOTE(review): per-param-group
    overrides in ``p_optim`` (if any) are not carried over, only ``defaults``.

    Args:
        p_config: config dict; reads the "adaptation" and "fn" sections.
        p_spec, p_interpreter: passed to ``ProgramSpace``; the interpreter also
            provides ``camb_get_abs`` and ``sanity_check``.
        p_generator: provides ``get_new_chain_program(depth)``.
        p_model: meta-trained model; deep-copied, never modified.
        p_optim: template optimizer (class and ``defaults`` are reused).
        p_writer: unused.
    """
    print("# Start Adaptation...")
    a_cfg = p_config["adaptation"]
    maxn_step = a_cfg["maxn_step"]
    # fn vocab_size is the padding id of the function embedding
    fn_pad = p_config["fn"]["vocab_size"]

    n_solved = 0 # track the number of solved problem
    n_attempt_list = [] # number of attempts (1-based) used by every solved episode

    for d_episode in range(a_cfg["n_episode"]):

        # adapt a fresh copy of the meta-trained model, with its own optimizer
        test_model = copy.deepcopy(p_model)
        test_model.train()
        test_optim = type(p_optim)(test_model.parameters(), **p_optim.defaults)

        # random meta-testing: generate a program to be solved
        ps_solution = p_generator.get_new_chain_program(a_cfg["fixed_depth"])

        is_solved = False
        for d_attempt in range(a_cfg["maxn_attempt"]):

            attempt_reward = None

            # every attempt starts from a new ProgramSpace
            ps_current = ProgramSpace(
                p_spec, p_interpreter, ps_solution.inputs, ps_solution.output,
            )
            # shell templates: the ParamNode id replaced with -1 (chain only)
            tmp_shell_list = ps_current.get_neighboring_shells()
            tmp_node_to_replace = ps_current.node_dict["ParamNode"][0]
            template_list = [
                modify_shell(shell, tmp_node_to_replace, -1)
                for shell in tmp_shell_list
            ]

            n_dead = 0
            n_sanity = 0
            d_step = 0
            kept_steps = [] # (history tensor, chosen id) of every kept choice
            # the last maxn_step-1 chosen ids, initially all padding
            selected_functions = [fn_pad] * (maxn_step - 1)

            td_input = _adapt_to_batch(p_interpreter.camb_get_abs(ps_current.inputs[0]))
            td_output = _adapt_to_batch(p_interpreter.camb_get_abs(ps_current.output))

            while d_step < maxn_step:

                # print the training progress
                print("\r# AC/EP:{}/{}, AT:{}, SP:{}, ND:{}, NS:{}, avg.attempt:{:.2f}".format(
                    n_solved, d_episode, d_attempt, d_step,
                    n_dead, n_sanity,
                    sum(n_attempt_list)/len(n_attempt_list) if len(n_attempt_list)>0 else -1,
                ),end="")

                # chain execution: only one frontier (input[0] at d_step == 0)
                id_current = ps_current.get_strict_frontiers()[0]
                current_shell_list = [
                    modify_shell(template, -1, id_current)
                    for template in template_list
                ]

                td_fns = _adapt_to_batch(selected_functions)
                td_pred = test_model(td_input, td_output, td_fns) # (B=1, fn_vocab_size)

                # no hints: explore uniformly with probability exploration_rate
                # (strict `<`, so a rate of 0 never explores); otherwise sample
                # from the policy distribution
                if random.random() < a_cfg["exploration_rate"]:
                    tmp_id = random.randrange(len(current_shell_list))
                else:
                    tmp_id = int(torch.multinomial(td_pred.exp().flatten(), 1).item())

                ps_backup = ps_current.make_copy() # undo point for a failed sanity check
                update_status = ps_current.add_neighboring_shell(
                    current_shell_list[tmp_id]
                )

                if not update_status:
                    # dead shell: punish immediately and retry the same step
                    _adapt_reinforce(test_optim, td_pred[0, tmp_id], -1.)
                    n_dead += 1
                    if n_dead >= a_cfg["maxn_dead"]:
                        # perhaps all choices are dead: give up this attempt
                        attempt_reward = -0.5
                        break
                    continue

                if ps_current.check_eq() is not None:
                    # solved!
                    d_step += 1
                    is_solved = True
                    n_solved += 1
                    attempt_reward = 1. # not used for training
                    break

                if p_interpreter.sanity_check(ps_current)[0]:
                    # keep the shell and move on to the next step
                    d_step += 1
                    kept_steps.append((td_fns, tmp_id))
                    selected_functions = (selected_functions + [tmp_id])[1:]
                    attempt_reward = -1. # provisional: not solved (yet)
                    continue

                # failed sanity check: punish, undo, and retry the same step
                _adapt_reinforce(test_optim, td_pred[0, tmp_id], -1.)
                ps_current = ps_backup
                n_sanity += 1
                if n_sanity >= a_cfg["maxn_sanity"]:
                    # perhaps all choices fail: give up this attempt
                    attempt_reward = -0.5
                    break

            # <END_WHILE_STEP>

            if is_solved:
                n_attempt_list.append(d_attempt + 1)
                break

            if kept_steps:
                # Recompute the log-probs with the current weights: graphs built
                # at selection time may reference parameters that test_optim has
                # since updated in place, which autograd cannot back-prop through.
                n_kept = len(kept_steps)
                weighted_log_prob = 0.
                for i, (td_fns_i, id_i) in enumerate(kept_steps):
                    d_decay = a_cfg["decay_rate"] ** (n_kept - 1 - i)
                    td_log_prob = test_model(td_input, td_output, td_fns_i)[0, id_i]
                    weighted_log_prob = weighted_log_prob + d_decay * td_log_prob
                _adapt_reinforce(test_optim, weighted_log_prob, attempt_reward)
            # else: nothing was kept, nothing to train

        # <END_FOR_ATTEMPT>

    # <END_FOR_EPISODE>
    

In [11]:
# Interpreter, DSL spec and random-program generator used for meta-testing.
m_interpreter = MorpheusInterpreter()
m_spec = S.parse_file('./example/camb3.tyrell')
m_generator = MorpheusGenerator(spec=m_spec, interpreter=m_interpreter)

# Placeholder ProgramSpace (dummy input/output) whose only job is to let us
# enumerate the neighboring shells and size the function vocabulary.
m_ps = ProgramSpace(m_spec, m_interpreter, [None], None)

m_config = {
    # ==== TransE Setting ==== #
    "val": dict(
        vocab_size=len(m_interpreter.CAMB_LIST),
        embd_dim=16,  # embedding dim of a CAMB abstract token
        conv_n_kernels=512,
        conv_kernel_size=(1, m_interpreter.CAMB_NCOL),
        pool_kernel_size=(m_interpreter.CAMB_NROW, 1),
        IDX_PAD=0,
    ),
    "fn": dict(vocab_size=len(m_ps.get_neighboring_shells())),
    "embd_dim": 128,
    "pretrain": dict(
        n_epoch=10,
        batch_size=4,  # samples averaged per optimizer step
        data_path="./0716MDsize1.pkl",
        n_truncated=1000,  # samples drawn per epoch
    ),
    "adaptation": dict(
        n_episode=100000,
        fixed_depth=2,
        maxn_attempt=100,
        maxn_step=1,  # program size
        maxn_dead=50,
        maxn_sanity=50,
        exploration_rate=0,
        decay_rate=0.9,
    ),
}

# Size-1 supervised data: a mapping of sequences, flattened into one list.
# NOTE: pickle can execute arbitrary code -- only load trusted local files.
with open(m_config["pretrain"]["data_path"], "rb") as f:
    dt_data = pickle.load(f)
m_data = [bucket[i] for bucket in dt_data.values() for i in range(len(bucket))]
print("# Total Meta-Train Data: {}".format(len(m_data)))

poly_neo = PolyNeo(p_config=m_config)
poly_neo = poly_neo.cuda() if use_cuda else poly_neo
optimizer = torch.optim.Adam(list(poly_neo.parameters()))

# TensorBoard logging is disabled (re-enable together with the tensorboardX import).
# writer = SummaryWriter("runs/0713CAMB_RL2_camb3")
writer = None

# Total Meta-Train Data: 77038


In [12]:
m_config

{'val': {'vocab_size': 150,
  'embd_dim': 16,
  'conv_n_kernels': 512,
  'conv_kernel_size': (1, 15),
  'pool_kernel_size': (15, 1),
  'IDX_PAD': 0},
 'fn': {'vocab_size': 120},
 'embd_dim': 128,
 'pretrain': {'n_epoch': 10,
  'batch_size': 4,
  'data_path': './0716MDsize1.pkl',
  'n_truncated': 1000},
 'adaptation': {'n_episode': 100000,
  'fixed_depth': 2,
  'maxn_attempt': 100,
  'maxn_step': 1,
  'maxn_dead': 50,
  'maxn_sanity': 50,
  'exploration_rate': 0,
  'decay_rate': 0.9}}

In [13]:
Pretrain(m_config, m_spec, m_interpreter, poly_neo, m_data, optimizer, writer)

# Start Pretraining...
# epoch:0, index:999/1000, avg.loss:4.37
# epoch:1, index:999/1000, avg.loss:3.72
# epoch:2, index:999/1000, avg.loss:3.28
# epoch:3, index:999/1000, avg.loss:3.10
# epoch:4, index:999/1000, avg.loss:2.98
# epoch:5, index:999/1000, avg.loss:2.83
# epoch:6, index:999/1000, avg.loss:2.81
# epoch:7, index:999/1000, avg.loss:2.73
# epoch:8, index:999/1000, avg.loss:2.64
# epoch:9, index:999/1000, avg.loss:2.65


In [14]:
Adaptation(m_config, m_spec, m_interpreter, m_generator, poly_neo, optimizer, writer)

# Start Adaptation...
# AC/EP:14/17, AT:25, SP:0, ND:1, NS:1, avg.attempt:33.07

KeyboardInterrupt: 