In [None]:
import random

import numpy as np
import torch
import tqdm
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [None]:
#Class that defines the search space
class MLPSearchSpace(object):
  """Discrete search space of fully-connected layers.

  ``self.vocab`` maps an integer token to a ``(units, activation)`` pair.
  Tokens are numbered row-major over (units, activation), so token
  ``3 * u + a`` is the u-th width combined with the a-th activation.
  """

  def __init__(self):
    self.vocab = self.vocab_dict()

  def vocab_dict(self):
    """Build the token -> (units, activation) vocabulary.

    Could be extended to contain CNN layers or more complicated structures.
    """
    unit_options = [8, 16, 32, 64, 128, 256, 512, 1024]
    activation_options = ['relu', 'elu', 'tanh']
    n_act = len(activation_options)
    return {
        n_act * u_idx + a_idx: (units, act)
        for u_idx, units in enumerate(unit_options)
        for a_idx, act in enumerate(activation_options)
    }

  def random_sample_architecture(self, architecture_num=10):
    """Randomly sample ``architecture_num`` token sequences (default 10).

    Each sequence has between 3 and 10 layers (inclusive). Every token is drawn
    uniformly from ``0 .. len(vocab) - 1``.
    """
    top_token = len(self.vocab) - 1
    samples = []
    for _ in range(architecture_num):
      depth = random.randint(3, 10)
      samples.append([random.randint(0, top_token) for _ in range(depth)])
    return samples

  def encode_sequence(self, sequence):
    """Map a list of ``(units, activation)`` layer configs to their tokens.

    Raises ValueError if a config is not in the vocabulary.
    """
    token_list = list(self.vocab)
    config_list = list(self.vocab.values())
    return [token_list[config_list.index(cfg)] for cfg in sequence]

  def decode_sequence(self, sequence):
    """Map a list of tokens to their ``(units, activation)`` layer configs.

    Raises ValueError if a token is not in the vocabulary.
    """
    token_list = list(self.vocab)
    config_list = list(self.vocab.values())
    return [config_list[token_list.index(tok)] for tok in sequence]

In [None]:
#class to generate model based on the given architecture
class MLPGenerator(MLPSearchSpace):
  """Builds and trains MLPs described by token sequences from ``MLPSearchSpace``."""

  # activation name -> module constructor; any other name falls back to nn.Tanh
  _ACTIVATIONS = {'relu': nn.ReLU, 'elu': nn.ELU}

  def __init__(self):
    super().__init__()
    self.mlp_one_shot = True
    self.mlp_optimizer = 'Adam'
    self.mlp_lr = 1e-4
    self.mlp_loss_func = 'mse'

  def create_model(self, sequence, mlp_input_shape=3):
    """Create an MLP from an encoded architecture.

    Args:
      sequence: list of vocabulary tokens, one per hidden layer.
      mlp_input_shape: number of input features (default 3).

    Returns:
      ``nn.Sequential`` of ``Linear -> activation`` per hidden layer, followed
      by a final ``Linear(..., 2)`` output layer. An empty ``sequence`` gives a
      single ``Linear(mlp_input_shape, 2)``.
    """
    layers = []
    in_features = mlp_input_shape
    for units, activation in self.decode_sequence(sequence):
      layers.append(nn.Linear(in_features, units))
      layers.append(self._ACTIVATIONS.get(activation, nn.Tanh)())
      in_features = units
    # output head has 2 units, as in the original design
    layers.append(nn.Linear(in_features, 2))
    return nn.Sequential(*layers)

  def train_PINN(self, model, X_pinn=None, X_semigroup=None, X_smooth=None, T=None):
    """Wrap ``model`` in ``PINN`` and return whatever ``PINN.train()`` returns.

    NOTE(review): ``PINN`` and the default data (``X_pinn``, ``X_semigroup``,
    ``X_smooth``, ``T``) are not defined in this cell. They are expected to be
    notebook globals defined elsewhere. Any argument left as ``None`` is looked
    up in the global namespace at call time. The previous defaults were
    evaluated at class-definition time, so defining this class raised NameError
    on a fresh kernel.
    """
    def _from_globals(name, value):
      # explicit argument wins; otherwise fall back to the notebook global
      if value is not None:
        return value
      try:
        return globals()[name]
      except KeyError:
        raise NameError(
            f"train_PINN: no '{name}' argument given and no global named '{name}'"
        ) from None

    pinn_model = PINN(_from_globals('X_pinn', X_pinn),
                      _from_globals('X_semigroup', X_semigroup),
                      _from_globals('X_smooth', X_smooth),
                      _from_globals('T', T),
                      model)
    return pinn_model.train()

  def train_model(self, model, train_dataloader, log_every=None):
    """Low-fidelity (single-epoch) MSE training of ``model`` with Adam.

    Args:
      model: network to train in place.
      train_dataloader: iterable of ``(inputs, targets)`` batches; both are
        cast to float.
      log_every: if a positive int, print the mean batch loss every
        ``log_every`` batches. Logging does not affect the return value.

    Returns:
      Sum of the per-batch losses over the epoch (0.0 for an empty loader).
    """
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=self.mlp_lr)
    model.train()
    total_loss = 0.0
    window_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
      inputs, targets = inputs.float(), targets.float()
      optimizer.zero_grad()
      loss = loss_function(model(inputs), targets)
      loss.backward()
      optimizer.step()
      batch_loss = loss.item()
      total_loss += batch_loss
      # window_loss is only for progress logging; total_loss is never reset
      window_loss += batch_loss
      if log_every and (batch_idx + 1) % log_every == 0:
        print(f"[batch {batch_idx + 1}] loss: {window_loss / log_every:.3f}")
        window_loss = 0.0
    return total_loss

  #TODO: modify to adapt to semigroup PINN
  def low_fidelity_evaluation(self, train_dataloader, sample_space=None, input_size=3):
    """Train each candidate architecture for one epoch and record its loss.

    Args:
      train_dataloader: iterable of ``(inputs, targets)`` batches.
      sample_space: list of encoded architectures. When empty or None,
        ``random_sample_architecture()`` supplies 10 random ones.
      input_size: number of input features for every candidate (default 3).

    Returns:
      ``(model_train_log, architecture_history)``: dicts keyed by the string
      form of the decoded / encoded sequence. Each value is the summed
      one-epoch training loss (previously only the remainder since the last
      200-batch log reset was recorded).
    """
    search_space = sample_space if sample_space else self.random_sample_architecture()
    model_train_log = {}
    architecture_history = {}
    for arch_idx, encoded_sequence in enumerate(search_space):
      print("training architecture: ", arch_idx)
      decoded_sequence = self.decode_sequence(encoded_sequence)
      model = self.create_model(encoded_sequence, input_size)
      total_loss = self.train_model(model, train_dataloader, log_every=200)
      print("training loss: ", total_loss)
      model_train_log[f'{decoded_sequence}'] = total_loss
      architecture_history[f'{encoded_sequence}'] = total_loss
    return model_train_log, architecture_history

In [None]:
#Create the controller model based on a 1-layer LSTM
class Controller_Model(nn.Module):
  """LSTM controller producing a probability distribution over vocabulary tokens.

  Args:
    input_size: size of the last dimension of the input tensor.
    hidden_size: LSTM hidden units.
    num_class: number of output classes (vocabulary size).
    num_layers: stacked LSTM layers (default 1).
  """

  def __init__(self, input_size, hidden_size, num_class, num_layers=1):
    super().__init__()
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
    self.sequence_generator = nn.Sequential(
        nn.Linear(hidden_size, 256),
        nn.ReLU(),
        nn.Linear(256, num_class),
        # Normalise over the class axis explicitly. The implicit-dim Softmax()
        # uses dim 0 for 3-D input, i.e. it normalised across samples.
        nn.Softmax(dim=-1),
    )

  def forward(self, input):
    """Return class probabilities with the same leading dims as the LSTM output."""
    output, _ = self.lstm(input)
    return self.sequence_generator(output)

#Class to define the controller
class Controller(MLPSearchSpace):
  """Proposes architectures from the ``MLPSearchSpace`` vocabulary with an LSTM controller."""

  def __init__(self, max_architecture_length):
    super().__init__()
    self.max_length = max_architecture_length
    self.controller_classes = len(self.vocab)
    # every sequence sampled so far; used to reject duplicates
    self.sequence_data = []

  def control_model(self):
    """Create the controller network: input size ``max_length``, 10 LSTM
    hidden units, one output class per vocabulary token."""
    model = Controller_Model(self.max_length, 10, num_class=self.controller_classes)
    return model

  def train_control_model(self, model, x_data, y_data, loss_func, controller_training_epoch):
    """Run ``controller_training_epoch`` full-batch Adam steps (lr 1e-4) on ``model``.

    ``loss_func`` is called as ``loss_func(y_data, outputs)``. Its result is
    summed before backpropagation.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    for epoch in range(controller_training_epoch):
      print("Training Controller Model in epoch: ", epoch)
      optimizer.zero_grad()
      outputs = model(x_data)
      loss = loss_func(y_data, outputs)
      loss.sum().backward()
      optimizer.step()

  def sample_architrecture_sequences(self, model, number_of_samples,
                                     by_prob=True, max_attempts=None):
    """Sample ``number_of_samples`` new architectures from the controller.

    Each architecture is built token by token. The current prefix is
    zero-padded to ``max_length`` and fed to ``model`` as a
    ``(1, max_length)`` tensor. The next token is drawn from the returned
    probabilities (``by_prob=True``) or taken as the argmax. Sequences already
    in ``self.sequence_data`` are rejected and resampled.

    Args:
      model: controller network returning probabilities over the vocabulary.
      number_of_samples: number of new, distinct sequences to return.
      by_prob: sample stochastically (True) or greedily (False).
      max_attempts: cap on candidate sequences generated (default
        ``100 * number_of_samples``). Greedy decoding of an unchanged model
        repeats the same sequence, so without a cap the loop never ends once
        that sequence has been recorded.

    Raises:
      RuntimeError: if ``max_attempts`` candidates are generated without
        collecting enough new sequences.
    """
    samples = []
    print("Generate Architecture Samples...")
    print('--------------------------------')
    tokens = list(self.vocab.keys())
    if max_attempts is None:
      max_attempts = 100 * max(number_of_samples, 1)
    attempts = 0

    while len(samples) < number_of_samples:
      if attempts >= max_attempts:
        raise RuntimeError(
            f"only {len(samples)} of {number_of_samples} new architectures found "
            f"after {attempts} attempts (by_prob={by_prob})")
      attempts += 1
      seed = []
      # number of layers we want in the architecture
      while len(seed) < self.max_length:
        seed_pad = np.pad(seed, pad_width=(0, self.max_length - len(seed)))
        sequence = torch.Tensor(seed_pad).reshape(1, -1)
        with torch.no_grad():  # sampling only; no graph needed
          probs = model(sequence)
        if by_prob:
          # renormalise in float64 so np.random.choice never rejects float32 rounding
          proba = probs.numpy()[0].astype(np.float64)
          proba /= proba.sum()
          next_token = np.random.choice(tokens, size=1, p=proba).item()
        else:
          next_token = torch.argmax(probs).item()
        seed.append(next_token)
      # only record architectures that have not been generated before
      if seed not in self.sequence_data:
        samples.append(seed)
        self.sequence_data.append(seed)
    return samples

In [None]:
# MLPNAS: samples architectures with the LSTM controller, trains each
# candidate, and optionally updates the controller with REINFORCE.
class MLPNAS(Controller):
  """MLP neural architecture search driven by an LSTM controller.

  ``self.data`` accumulates ``[decoded_sequence, loss]`` records for every
  evaluated architecture, in evaluation order.
  """

  def __init__(self, max_architecture_length):
    super().__init__(max_architecture_length)
    self.model_generator = MLPGenerator()
    self.controller_model = self.control_model()
    self.controller_sampling_epochs = 3
    self.controller_loss_alpha = 0.9  # discount factor used in get_discounted_reward
    self.samples_per_controller_epoch = 10
    self.controller_training_epoch = 3
    self.data = []

  def append_model_metrics(self, sequence, history):
    """Record one evaluated architecture and its training loss."""
    self.data.append([sequence, history])

  def prepare_controller_data(self, sequences):
    """Build controller training data from equal-length token sequences.

    Returns:
      x_padded: float tensor ``(N, 1, max_length)``, each sequence minus its
        last token, right-padded with a single zero.
      y: ``(N, controller_classes)`` uint8 one-hot array of each last token.
      val_loss_target: losses of the last ``samples_per_controller_epoch`` records.
    """
    def to_categorical(y, num_classes):
      return np.eye(num_classes, dtype='uint8')[y]
    seq_tensor = torch.Tensor(sequences)
    x = seq_tensor[:, :-1].reshape(len(sequences), 1, self.max_length - 1)
    padding = torch.zeros([len(sequences), 1, 1])
    x_padded = torch.cat([x, padding], dim=2)
    y = to_categorical(seq_tensor[:, -1].type(torch.int64).numpy(), self.controller_classes)
    val_loss_target = [item[1] for item in self.data[-self.samples_per_controller_epoch:]]
    return x_padded, y, val_loss_target

  def get_discounted_reward(self, rewards):
    """Discounted, standardised returns along axis 0.

    ``G_t = r_t + alpha * G_{t+1}`` with ``alpha = controller_loss_alpha``,
    then ``(G - mean) / std``. If all returns are equal (std == 0) they are
    only centred, which avoids the previous 0/0 NaN.
    Returns a float32 array with the same shape as ``rewards``.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    discounted_r = np.zeros_like(rewards, dtype=np.float32)
    running_add = np.zeros_like(rewards[0]) if len(rewards) else 0.0
    for t in reversed(range(len(rewards))):
      running_add = rewards[t] + self.controller_loss_alpha * running_add
      discounted_r[t] = running_add
    centred = discounted_r - discounted_r.mean()
    std = discounted_r.std()
    return centred / std if std > 0 else centred

  def custom_loss(self, target, output):
    """REINFORCE loss for the controller.

    Args:
      target: ``(N, controller_classes)`` one-hot array of the chosen (last) tokens.
      output: controller probabilities, reshaped to ``(N, controller_classes)``
        (``(N, 1, C)`` when built from ``prepare_controller_data``).

    Returns:
      Per-sample tensor ``-log p(chosen token) * advantage``. It stays attached
      to ``output`` so gradients reach the controller. The previous version
      detached the output and re-wrapped it in a fresh Variable, so no
      gradient ever reached the controller's parameters. The reward is the
      negated training loss, so lower-loss architectures are reinforced; the
      advantage is the standardised discounted reward.
    """
    losses = [float(item[1]) for item in self.data[-self.samples_per_controller_epoch:]]
    reward = -np.asarray(losses, dtype=np.float64).reshape(self.samples_per_controller_epoch, 1)
    advantage = torch.as_tensor(self.get_discounted_reward(reward).reshape(-1), dtype=output.dtype)
    target_t = torch.as_tensor(np.asarray(target), dtype=output.dtype)
    # clamp keeps log finite if a probability underflows to 0
    log_probs = torch.log(output.clamp_min(1e-12)).reshape(target_t.shape[0], -1)
    chosen_log_prob = (target_t * log_probs).sum(dim=-1)
    return -chosen_log_prob * advantage

  def _search(self, input_size, evaluate, REINFORCE):
    """Shared search loop: sample, evaluate each model via ``evaluate(model)``
    (which must return its loss), and optionally train the controller."""
    for controller_epoch in range(self.controller_sampling_epochs):
      sequences = self.sample_architrecture_sequences(
          self.controller_model, number_of_samples=self.samples_per_controller_epoch)
      print("Evaluating architectures in controller_sampling_epoch: ", controller_epoch)
      for i, sequence in enumerate(sequences):
        print("  training architecture: ", i)
        decoded_sequence = self.model_generator.decode_sequence(sequence)
        model = self.model_generator.create_model(sequence, mlp_input_shape=input_size)
        self.append_model_metrics(decoded_sequence, evaluate(model))
      # the controller is not updated after the final sampling round
      if REINFORCE and controller_epoch != self.controller_sampling_epochs - 1:
        print("Training Controller...")
        x, y, _ = self.prepare_controller_data(sequences)
        self.train_control_model(self.controller_model, x, y, self.custom_loss,
                                 self.controller_training_epoch)
    return self.data

  def search(self, input_size, train_dataloader, REINFORCE=False):
    """Search architectures for the baseline model, trained on ``train_dataloader``."""
    return self._search(
        input_size,
        lambda model: self.model_generator.train_model(model, train_dataloader),
        REINFORCE)

  def search_PINN(self, input_size, REINFORCE=False):
    """Search architectures for the PINN model (no dataloader; uses ``train_PINN``)."""
    return self._search(input_size, self.model_generator.train_PINN, REINFORCE)