In [22]:
from torch.nn import Linear,Sequential,Identity
from typing import Tuple, List, Callable, Union, Optional
from copy import deepcopy
from itertools import chain
import torch

class MLP(torch.nn.Module):

    """Functional MLP that is differentiable w.r.t. its parameters.

    The module owns no weights. Every call receives a flat parameter
    vector (optionally with leading batch dimensions) that is sliced into
    per-layer weight matrices and bias vectors, so gradients can flow back
    to the caller's ``params`` tensor.

    Layout of the last axis of ``params``, layer by layer:
    ``out_size * in_size`` weight entries (row-major, one row per output
    unit), followed by ``out_size`` bias entries when ``add_bias`` is true.
    """

    def __init__(
                self,
                Ls: List[int],
                add_bias: bool = False,
                nonlinearity: Optional[Callable] = None,
                ):

        """Describe the network architecture.

        Args:
            Ls: layer widths ``[input, hidden..., output]``; consecutive
                pairs define the ``(in_size, out_size)`` of each layer.
            add_bias: whether each layer reads a bias from ``params``.
            nonlinearity: activation applied after every layer except the
                last one. Defaults to ``torch.nn.ReLU()``.
        """

        super(MLP, self).__init__()

        weight_sizes = list(zip(Ls[:-1], Ls[1:]))
        bias_width = 1 if add_bias else 0

        if nonlinearity is None:
            nonlinearity = torch.nn.ReLU()

        # Explicit assignments instead of ``self.__dict__.update(locals())``:
        # this keeps ``self``/``__class__`` out of the instance dict and lets
        # ``Module.__setattr__`` register a Module activation as a submodule.
        self.Ls = Ls
        self.add_bias = add_bias
        self.weight_sizes = weight_sizes
        self.n_layers = len(weight_sizes)
        # Required length of the last axis of ``params``.
        self.len_params = sum(
            (in_size + bias_width) * out_size
            for in_size, out_size in weight_sizes
        )
        self.nonlinearity = nonlinearity


    def create_weights(self, params):

        """Slice a flat parameter tensor into per-layer weights and biases.

        Args:
            params: tensor of shape ``(*batch, len_params)``.

        Returns:
            ``(weights, biases)``: two lists with one entry per layer, of
            shapes ``(*batch, out_size, in_size)`` and ``(*batch, out_size)``.
            Without bias the bias entries are zeros.
        """

        batch_shape = params.shape[:-1]
        weights, biases = [], []
        offset = 0

        for in_size, out_size in self.weight_sizes:

            n_weight = in_size * out_size
            weight = params[..., offset:offset + n_weight].reshape(
                *batch_shape, out_size, in_size
            )
            offset += n_weight

            if self.add_bias:
                bias = params[..., offset:offset + out_size].reshape(
                    *batch_shape, out_size
                )
                offset += out_size
            else:
                # Follow params' dtype/device so the later addition does not
                # mix devices when params live off the default device.
                bias = torch.zeros(
                    *batch_shape, out_size,
                    dtype=params.dtype, device=params.device,
                )

            weights.append(weight)
            biases.append(bias)

        return weights, biases


    def forward(self, params, states):

        """Apply the MLP defined by ``params`` to ``states``.

        Args:
            params: tensor of shape ``(*batch, len_params)``.
            states: tensor of shape ``(n_states, input_size)``.

        Returns:
            Tensor of shape ``(*batch, output_size, n_states)``.
        """

        weights, biases = self.create_weights(params)
        # Features go on the second-to-last axis so ``w @ outputs`` works;
        # identical to ``states.T`` for 2-D input.
        outputs = states.transpose(-2, -1)
        last_layer = self.n_layers - 1

        for i, (w, b) in enumerate(zip(weights, biases)):

            # The bias broadcasts over the trailing n_states axis.
            outputs = w @ outputs + b.unsqueeze(-1)

            # No nonlinearity for the last layer (enumerate is 0-based, so
            # the last index is n_layers - 1).
            if i != last_layer:
                outputs = self.nonlinearity(outputs)

        return outputs
    
    
# Smoke test: evaluate a batch of 10 parameter vectors on 1000 four-dimensional states.
mlp = MLP(Ls=[4, 2], add_bias=True)
params = torch.rand((10, mlp.len_params))
states = torch.rand((1000, 4))
mlp(params, states).shape

torch.Size([10, 2, 1000])

In [25]:
# Check that the summed network output can be differentiated w.r.t. the flat parameter vector.
mlp = MLP(Ls=[4, 2], add_bias=True)

print(mlp.len_params)
params = torch.rand(mlp.len_params, requires_grad=True)
states = torch.rand((100, 4))
outputs = mlp(params, states).sum()
gradient = torch.autograd.grad(outputs=outputs, inputs=params)
# autograd.grad returns a tuple with one gradient per input tensor.
assert gradient[0].shape == params.shape

10


In [26]:
# Display the tuple returned by torch.autograd.grad in the previous cell.
gradient

(tensor([ 51.0128,  53.7022,  47.2828,  47.0763,  51.0128,  53.7022,  47.2828,
          47.0763, 100.0000, 100.0000]),)