In [1]:
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable

# _concat

In [2]:
def _concat(xs):
    """Concats the arrays

    Args:
        xs (array): the 2d array to be passed

    Returns:
        array: concated array
    """
    return torch.cat([x.view(-1) for x in xs])

In [3]:
# Sample inputs for _concat: a 4-D tensor of shape (2, 2, 3, 1) and a 1-D tensor.
a = torch.tensor([
    [[[1], [2], [3]], [[4], [5], [6]]],
    [[[2], [4], [6]], [[8], [10], [12]]],
])
b = torch.tensor([1, 2])

In [4]:
# Flatten both sample tensors and join them end to end into one 1-D tensor.
_concat([a,b])

tensor([ 1,  2,  3,  4,  5,  6,  2,  4,  6,  8, 10, 12,  1,  2])

# Architect

In [5]:
class Architect(object):
    """Updates the architecture parameters of a model.

    Attributes:
      network_momentum(float): momentum factor read from ``args.momentum``
      network_weight_decay(float): weight decay read from ``args.weight_decay``
      model(Network): Network architecture with cells
      optimizer(torch.optim.Adam): Adam optimiser over ``model.arch_parameters()``
    """

    def __init__(self, model, args):
        """Initialises the architect

        Args:
            model (Network): Network architecture with cells
            args (object): cli args, read by attribute: ``momentum``,
                ``weight_decay``, ``arch_learning_rate``, ``arch_weight_decay``
        """
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

    def _compute_unrolled_model(self, input, target, eta, network_optimizer):
        """Build a copy of the model after one virtual weight update.

        The new weights are ``theta - eta * (moment + dtheta)`` where
        ``dtheta`` is the gradient of the loss on ``(input, target)`` plus
        ``network_weight_decay * theta``.

        Args:
            input (tensor): batch of inputs
            target (tensor): batch of targets
            eta (float): step size of the virtual update
            network_optimizer (optimiser): optimiser whose per-parameter
                ``momentum_buffer`` state is read, when available

        Returns:
            Network: new model loaded with the updated weights
        """
        loss = self.model._loss(input, target)
        theta = _concat(self.model.parameters()).data
        try:
            moment = _concat(network_optimizer.state[v]['momentum_buffer']
                             for v in self.model.parameters()).mul_(self.network_momentum)
        except (KeyError, AttributeError):
            # No usable momentum buffer (key missing, or the stored buffer
            # is None): fall back to a zero momentum term.
            moment = torch.zeros_like(theta)
        dtheta = _concat(torch.autograd.grad(
            loss, self.model.parameters())).data + self.network_weight_decay*theta
        # Keyword ``alpha`` replaces the deprecated sub(Number, Tensor) overload.
        unrolled_model = self._construct_model_from_theta(
            theta.sub(moment + dtheta, alpha=eta))
        return unrolled_model

    def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
        """Compute one optimiser step on the architecture parameters

        Args:
            input_train (tensor): batch of training inputs
            target_train (tensor): batch of training targets
            input_valid (tensor): batch of validation inputs
            target_valid (tensor): batch of validation targets
            eta (float): step size of the virtual weight update (presumably
                the network learning rate - TODO confirm against the caller)
            network_optimizer (optimiser): network optimiser for network
            unrolled (bool): True to use the unrolled (second-order) gradient,
                False to backpropagate the validation loss directly
        """
        self.optimizer.zero_grad()
        if unrolled:
            self._backward_step_unrolled(
                input_train, target_train, input_valid, target_valid, eta, network_optimizer)
        else:
            self._backward_step(input_valid, target_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        """Backpropagate the validation loss through the current model

        Args:
            input_valid (tensor): batch of validation inputs
            target_valid (tensor): batch of validation targets
        """
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()

    def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
        """Write the unrolled architecture gradient into ``arch_parameters().grad``

        Args:
            input_train (tensor): batch of training inputs
            target_train (tensor): batch of training targets
            input_valid (tensor): batch of validation inputs
            target_valid (tensor): batch of validation targets
            eta (float): step size of the virtual weight update
            network_optimizer (optimiser): network optimiser for network
        """
        unrolled_model = self._compute_unrolled_model(
            input_train, target_train, eta, network_optimizer)
        unrolled_loss = unrolled_model._loss(input_valid, target_valid)

        unrolled_loss.backward()
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        vector = [v.grad.data for v in unrolled_model.parameters()]

        implicit_grads = self._hessian_vector_product(
            vector, input_train, target_train)

        # dalpha <- dalpha - eta * implicit_grads (in place).
        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta)

        # Copy the result onto the real model so self.optimizer can use it.
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = Variable(g.data)
            else:
                v.grad.data.copy_(g.data)

    def _construct_model_from_theta(self, theta):
        """Construct a new model whose parameters are taken from theta

        Args:
            theta (tensor): flat 1-D tensor holding every parameter of
                ``self.model``, in ``named_parameters()`` order

        Returns:
            Network: return new model
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = v.numel()
            params[k] = theta[offset: offset+v_length].view(v.size())
            offset += v_length

        # theta must be consumed exactly; anything else means a size mismatch.
        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new

    def _hessian_vector_product(self, vector, input, target, r=1e-2):
        """Central finite-difference estimate used for the implicit gradient.

        Perturbs the weights to ``w + R*vector`` and ``w - R*vector`` with
        ``R = r / ||vector||``, takes the gradient of the loss w.r.t. the
        architecture parameters at both points, and returns their difference
        divided by ``2*R``. The weights are restored before returning.

        Args:
            vector (list of tensor): one tensor per model parameter
            input (tensor): batch of inputs
            target (tensor): batch of targets
            r (float): numerator of the perturbation scale

        Returns:
            list of tensor: one tensor per architecture parameter
        """
        # Tensor division keeps the original inf result for a zero-norm vector.
        R = (r / _concat(vector).norm()).item()
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)
        loss = self.model._loss(input, target)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        # From w + R*v to w - R*v.
        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(v, alpha=2*R)
        loss = self.model._loss(input, target)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        # Restore the original weights.
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)

        return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]


## Testing

In [6]:
from model_search import Network

In [7]:
# Hyper-parameters that Architect.__init__ reads by attribute name.
args = dict(
    momentum=0.9,
    weight_decay=3e-4,
    arch_learning_rate=3e-4,
    arch_weight_decay=1e-3,
)

# Search network under test, built with a binary cross-entropy criterion.
criterion = nn.BCELoss()
model = Network(3, 2, 3, criterion)

class Struct:
    """Expose keyword arguments as attributes (``Struct(a=1).a == 1``)."""

    def __init__(self, **entries):
        # Every keyword becomes an instance attribute of the same name.
        vars(self).update(entries)

In [8]:
# Wrap the args dict in a Struct so Architect can read it by attribute.
arc = Architect(model, Struct(**args))

In [None]:
# Print the shape of every named parameter of the model.
for _, v in model.named_parameters():
    print(v.size())

In [12]:
# Print every parameter whose shape is exactly (14, 8).
# Comparing against a plain tuple avoids allocating a random tensor on each
# iteration, which also advanced the global RNG state.
for p in model.parameters():
    if tuple(p.shape) == (14, 8):
        print(p)