In [1]:
from naslib.search_spaces.core.graph import Graph, EdgeData
from naslib.search_spaces.core import primitives as ops
from naslib.search_spaces.core.primitives import AbstractPrimitive
from naslib.optimizers import DARTSOptimizer
from naslib.utils import utils, setup_logger, get_config_from_args, set_seed, log_args
from torch import nn
import torch
from IPython.display import clear_output
import logging
from copy import deepcopy

device: cuda:0
device: cpu
device: cuda:0
device: cuda:0
device: cuda:0
device: cuda:0


In [2]:
class DartsSearchSpaceA(Graph):

    OPTIMIZER_SCOPE = [
        'a_stage_1',
        'a_stage_2'
    ]

    QUERYABLE = False

    def __init__(self):
        super().__init__()

        channels = [(16 * 5 * 5, 120), (120, 84), (84, 10)]
        stages = ['a_stage_1', 'a_stage_2', 'a_stage_3']

        # cell definition
        activation_cell = Graph()
        activation_cell.name = 'activation_cell'
        activation_cell.add_node(1,node_name="input") # input node
        activation_cell.add_node(2,node_name="unarry1") # unarry node 1
        activation_cell.add_node(3,node_name="unarry2") # unarry node 2
        activation_cell.add_node(4,node_name="recombine") # recombination node
        activation_cell.nodes[4]['comb_op']=stack()
        activation_cell.add_node(5,node_name="ouptut") # output
        activation_cell.add_edges_from([(1, 2, EdgeData({"edge_name":"unary 1"}))]) # unary op 1
        activation_cell.add_edges_from([(1, 3, EdgeData({"edge_name":"unary 2"}))]) # unary op2
        activation_cell.add_edges_from([(2, 4, EdgeData({"edge_name":"identity1"}))]) # identity
        activation_cell.add_edges_from([(3, 4, EdgeData({"edge_name":"identity2"}))]) # identity
        activation_cell.add_edges_from([(4, 5, EdgeData({"edge_name":"binary"}))]) # binary op

        # macroarchitecture definition
        self.name = 'makrograph'
        self.add_node(1) # input node
        self.add_node(2) # CNN+Linear output node
        self.add_node(3, subgraph=deepcopy(activation_cell).set_scope('a_stage_1').set_input([2])) # activation node 1
        self.add_node(4, subgraph=deepcopy(activation_cell).set_scope('a_stage_2').set_input([3])) # activation node 2
        self.nodes[3]['subgraph'].name = self.nodes[3]['subgraph'].scope
        self.nodes[4]['subgraph'].name = self.nodes[4]['subgraph'].scope
        self.add_node(5) # output node
        
        self.add_edges_from([(i, i+1, EdgeData()) for i in range(1, 5)])
        self.edges[1, 2].set('op',
            ops.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.MaxPool2d(2),
                nn.Conv2d(6, 16, 5),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(16*5*5,120)
            )) # convolutional edge
        
        self.edges[3,4].set('op',ops.Sequential(nn.Linear(120,84)))
        self.edges[4,5].set('op',ops.Sequential(nn.Linear(84,10), nn.Softmax(dim=1)))

        
        for stage in stages:
            self.update_edges(
                update_func=lambda edge: self._set_ops(edge),
                scope=stage,
                private_edge_data=True,
            )

    def _set_ops(self, edge):
        if "unary" in edge.data["edge_name"]:
            edge.data.set('op', [ops.Identity(), abs_op(), exp(2), exp(3), sin(), cos(), sign()])
        elif "binary" in edge.data["edge_name"]:
            edge.data.set('op', [add(), sub(), mul(), maximum(), minimum()])
        else:
            edge.data.set('op', [ops.Identity()])


In [3]:
class exp(AbstractPrimitive):
    def __init__(self,power):
        super().__init__(locals())
        self.power=power
    
    def forward(self,x, edge_data):
        return torch.pow(x,self.power)
    def get_embedded_ops(self):
        return None

class sin(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    
    def forward(self,x, edge_data):
        return torch.sin(x)
    def get_embedded_ops(self):
        return None
    
class cos(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    
    def forward(self,x, edge_data):
        return torch.cos(x)
    def get_embedded_ops(self):
        return None

class abs_op(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    
    def forward(self,x, edge_data):
        return torch.abs(x)
    def get_embedded_ops(self):
        return None

class sign(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    def forward(self,x, edge_data):
        return x*-1
    def get_embedded_ops(self):
        return None

In [4]:
class add(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    def forward(self,x, edge_data):
        return torch.add(x[0],x[1])
    def get_embedded_ops(self):
        return None

class sub(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    def forward(self,x, edge_data):
        return torch.sub(x[0],x[1])
    def get_embedded_ops(self):
        return None

class mul(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    def forward(self,x, edge_data):
        return torch.mul(x[0],x[1])
    def get_embedded_ops(self):
        return None
class maximum(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    def forward(self,x, edge_data):
        return torch.maximum(x[0],x[1])
    def get_embedded_ops(self):
        return None
class minimum(AbstractPrimitive):
    def __init__(self):
        super().__init__(locals())
    def forward(self,x, edge_data):
        return torch.minimum(x[0],x[1])
    def get_embedded_ops(self):
        return None

In [5]:
class stack():
    def __init__(self):
        pass
    def __call__(self, tensors, edges_data=None):
        return torch.stack(tensors)

In [6]:
search_space = DartsSearchSpaceA() 

Update function could not be verified. Be cautious with the setting of `private_edge_data` in `update_edges()`
Update function could not be verified. Be cautious with the setting of `private_edge_data` in `update_edges()`
Update function could not be verified. Be cautious with the setting of `private_edge_data` in `update_edges()`


In [7]:
search_space.nodes[3]["subgraph"][1]

AtlasView({2: private: <{'_final': False, 'op': [Identity(), abs_op(), exp(), exp(), sin(), cos(), sign()], 'edge_name': 'unary 1'}>, shared: <{'_deleted': False}>, 3: private: <{'_final': False, 'op': [Identity(), abs_op(), exp(), exp(), sin(), cos(), sign()], 'edge_name': 'unary 2'}>, shared: <{'_deleted': False}>})

In [8]:
config = utils.get_config_from_args(config_type='nas')
config.optimizer = 'darts'

In [9]:
optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)

Update function could not be verified. Be cautious with the setting of `private_edge_data` in `update_edges()`
Update function could not be verified. Be cautious with the setting of `private_edge_data` in `update_edges()`


In [10]:
optimizer.graph.nodes

NodeView((1, 2, 3, 4, 5))

In [11]:
optimizer.graph.nodes[3]['subgraph'].edges[1,2]

private: <{'_final': False, 'op': DARTSMixedOp(
  (primitive-0): Identity()
  (primitive-1): abs_op()
  (primitive-2): exp()
  (primitive-3): exp()
  (primitive-4): sin()
  (primitive-5): cos()
  (primitive-6): sign()
), 'edge_name': 'unary 1'}>, shared: <{'_deleted': False, 'alpha': Parameter containing:
tensor([ 0.0004, -0.0006,  0.0009, -0.0019, -0.0016,  0.0011,  0.0006],
       requires_grad=True)}>

In [12]:
optimizer.architectural_weights

ParameterList(
    (0): Parameter containing: [torch.FloatTensor of size 7]
    (1): Parameter containing: [torch.FloatTensor of size 7]
    (2): Parameter containing: [torch.FloatTensor of size 1]
    (3): Parameter containing: [torch.FloatTensor of size 1]
    (4): Parameter containing: [torch.FloatTensor of size 5]
    (5): Parameter containing: [torch.FloatTensor of size 7]
    (6): Parameter containing: [torch.FloatTensor of size 7]
    (7): Parameter containing: [torch.FloatTensor of size 1]
    (8): Parameter containing: [torch.FloatTensor of size 1]
    (9): Parameter containing: [torch.FloatTensor of size 5]
)