In [None]:
import os
from typing import Optional

import gym
from gym import spaces
import numpy as np

## NasBench101

In [None]:
from nasbench import api

# Location of the NAS-Bench-101 tfrecord. Set the NASBENCH_PATH environment variable
# to point elsewhere instead of editing this machine-specific default.
NASBENCH_PATH = os.environ.get("NASBENCH_PATH", "/scratch2/sem22hs2/nasbench_full.tfrecord")
dataset = api.NASBench(NASBENCH_PATH)

In [None]:
def objective_function(adjacency_mat, labeling, budget=108):
    """Score one cell architecture by querying the NAS-Bench-101 ``dataset``.

    Args:
        adjacency_mat: square 0/1 matrix over all cell vertices, input and output
            included (any array-like; it is converted to an integer ndarray).
        labeling: operation names for the intermediate vertices only; ``'input'`` and
            ``'output'`` are prepended/appended here.
        budget: number of training epochs to query the benchmark at.

    Returns:
        ``(validation_accuracy, training_time)`` from the benchmark, or ``(0, 0)`` when
        the architecture is rejected as invalid or outside the benchmark's search space.
    """
    labeling = ['input'] + list(labeling) + ['output']
    # Cast to int: NAS-Bench-101 looks models up by a hash built from the text of the
    # matrix's edge counts, so a float matrix (e.g. from np.zeros) would produce a key
    # that is not in the dataset.
    matrix = np.asarray(adjacency_mat, dtype=int)
    try:
        # ModelSpec presumably raises ValueError for malformed matrices (e.g. not strictly
        # upper triangular); treat those like out-of-domain architectures.
        model_spec = api.ModelSpec(matrix, labeling)
        data = dataset.query(model_spec, epochs=budget)
    except (api.OutOfDomainError, ValueError):
        return 0, 0

    return data["validation_accuracy"], data["training_time"]

In [None]:
# TODO: sample architectures.
# TODO: check architecture validity and give a negative reward if the architecture is invalid.

class NasBench101(gym.Env):
    """Gym environment that edits a NAS-Bench-101 cell one action at a time.

    State: an integer adjacency matrix over ``v`` vertices whose strictly-upper-triangular
    entries are the edges (vertex 0 is the input, vertex ``v-1`` the output), plus one
    operation name per intermediate vertex.

    Actions are integers in ``[0, action_space.n)``:
      * ``a < num_edge_slots`` toggles the ``a``-th strictly-upper-triangular entry;
      * otherwise ``a - num_edge_slots`` is unravelled into
        ``(intermediate vertex, op index)`` and that vertex's op is set to ``ops[op index]``.

    The reward is the validation accuracy returned by ``objective_function``.
    """
    metadata = {"render_modes": [], "render_fps": 1}

    def __init__(self, v=7, e=9, ops=('conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3'), step_max=1000,
                 render_mode: Optional[str] = None):
        """Build the environment with every edge off and every intermediate op set to ``ops[0]``.

        Args:
            v: number of vertices in the cell, including input and output.
            e: maximum number of edges (stored as ``max_edges``; not enforced here).
            ops: candidate operation names for the intermediate vertices.
            step_max: number of steps after which an episode reports ``done``.
            render_mode: must be None; no render modes are implemented.
        """
        assert render_mode is None  # or render_mode in self.metadata["render_modes"]
        # Environment definition
        self.max_edges = e
        self.vertices = v
        self.ops = list(ops)  # private copy so the caller's sequence is never aliased

        # Edge slots are the entries strictly above the diagonal: NAS-Bench-101's ModelSpec
        # rejects matrices with non-zero diagonal (self-loop) or lower-triangular entries.
        self.idx_upper = np.triu_indices(v, k=1)
        self.num_edge_slots = len(self.idx_upper[0])

        # Current state. The matrix is integer-typed because NAS-Bench-101 looks models up
        # by a hash of the matrix; a float matrix would presumably hash to a missing key.
        self.adjacency_mat = np.zeros([v, v], dtype=int)
        self.labeling = (v - 2) * [self.ops[0]]  # op names for all non-input/output vertices

        self.num_step = 0
        self.step_max = step_max

        self.action_space = spaces.Discrete(self.num_edge_slots + (v - 2) * len(self.ops))
        self.observation_space = spaces.Dict(
            {
                "adjacency_mat": spaces.MultiBinary(self.num_edge_slots),
                "labels": spaces.MultiDiscrete(np.array((v - 2) * [len(self.ops)])),
            }
        )

    def _get_obs(self):
        """Return the current state in the layout declared by ``observation_space``."""
        return {
            "adjacency_mat": self.adjacency_mat[self.idx_upper].astype(np.int8),
            "labels": np.array([self.ops.index(op) for op in self.labeling]),
        }

    def step(self, action):
        """Apply one edit action and score the resulting cell.

        Returns:
            ``(observation, reward, done, info)``: ``reward`` is the benchmark's validation
            accuracy (0 for invalid cells) and ``done`` becomes True once ``step_max``
            steps have been taken since the last reset.

        Raises:
            ValueError: if ``action`` is outside ``[0, action_space.n)``.
        """
        action = int(action)
        if not 0 <= action < self.action_space.n:
            raise ValueError(f"action {action} out of range [0, {self.action_space.n})")

        n = self.num_edge_slots
        if action < n:
            row, col = self.idx_upper[0][action], self.idx_upper[1][action]
            self.adjacency_mat[row, col] = 1 - self.adjacency_mat[row, col]
        else:
            # Only the v-2 intermediate vertices carry a selectable op.
            label_row, op = np.unravel_index(action - n, (self.vertices - 2, len(self.ops)))
            self.labeling[label_row] = self.ops[op]

        y, _training_time = objective_function(self.adjacency_mat, self.labeling)
        reward = y

        self.num_step += 1
        done = self.num_step >= self.step_max

        return self._get_obs(), reward, done, {}

    def reset(self):
        """Sample a random cell, restart the step counter and return the first observation."""
        self.num_step = 0
        self.adjacency_mat[self.idx_upper] = np.random.randint(0, 2, self.num_edge_slots)
        op_indices = np.random.randint(0, len(self.ops), self.vertices - 2)
        self.labeling = [self.ops[i] for i in op_indices]
        return self._get_obs()