# In [63]:
from torch.nn import Linear,Sequential,Identity
from typing import Tuple, List, Callable, Union, Optional
from itertools import chain
import torch

class MLP:
    """Multi-layer perceptron whose weights come from one flat parameter tensor.

    On every call a fresh ``Sequential`` network is built and the weights
    (and biases) of its ``Linear`` layers are overwritten with consecutive
    slices of ``params``, so one ``MLP`` object can evaluate many different
    parameter vectors.

    Flat parameter layout, layer by layer::

        [W_0 (out_0 * in_0 values, reshaped to (out_0, in_0)),
         b_0 (out_0 values, only if add_bias),
         W_1, b_1, ...]

    Args:
        Ls: layer widths ``[in_size, hidden_1, ..., out_size]``; must contain
            at least two entries.
        add_bias: whether every ``Linear`` layer carries a bias vector.
        nonlinearity: activation inserted between consecutive ``Linear``
            layers (never after the last one). Accepted forms are a
            ``torch.nn.Module`` subclass (instantiated once per layer), a
            ``Module`` instance (shared by all layers), or any plain callable
            such as ``torch.tanh``. ``None`` means identity.
    """

    class _Lambda(torch.nn.Module):
        """Wrap a plain callable so it can be placed inside ``Sequential``."""

        def __init__(self, fn: Callable):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    def __init__(
        self,
        Ls: List[int],
        add_bias: bool = False,
        nonlinearity: Optional[Callable] = None,
    ):
        """Inits MLP.

        Raises:
            ValueError: if ``Ls`` has fewer than two entries, since there
                would be no layer to build.
        """
        if len(Ls) < 2:
            raise ValueError(f"Ls must contain at least two layer sizes, got {Ls!r}")

        self.Ls = Ls
        self.add_bias = add_bias
        # (in_size, out_size) of every Linear layer, in network order.
        self.weight_sizes = [
            (in_size, out_size) for in_size, out_size in zip(Ls[:-1], Ls[1:])
        ]
        # Bias length per layer; always defined (empty when there is no bias).
        self.bias_sizes = list(Ls[1:]) if add_bias else []
        # Total number of scalars the flat ``params`` tensor must hold.
        self.len_params = sum(
            (in_size + int(add_bias)) * out_size
            for in_size, out_size in self.weight_sizes
        )
        # Stored unconditionally; turned into a Module per layer by
        # ``_make_activation`` when the network is built.
        self.nonlinearity = Identity if nonlinearity is None else nonlinearity

    def _make_activation(self) -> torch.nn.Module:
        """Return a Module that applies ``self.nonlinearity``."""
        act = self.nonlinearity
        if isinstance(act, type) and issubclass(act, torch.nn.Module):
            return act()
        if isinstance(act, torch.nn.Module):
            return act
        # Sequential only accepts Modules, so wrap plain functions.
        return self._Lambda(act)

    def reset_weights(self, net, params):
        """Copy consecutive slices of ``params`` into the Linear layers of ``net``.

        Args:
            net: a network as returned by ``build_net``. Activation modules
                interleaved with the ``Linear`` layers are skipped.
            params: tensor holding exactly ``self.len_params`` elements; it
                is flattened before slicing.

        Returns:
            ``net`` itself, modified in place.

        Raises:
            ValueError: if ``params`` does not hold exactly ``len_params``
                elements.
        """
        flat = params.reshape(-1)
        if flat.numel() != self.len_params:
            raise ValueError(
                f"expected {self.len_params} parameters, got {flat.numel()}"
            )

        # Only Linear layers own weights; activations sit between them.
        linears = [module for module in net if isinstance(module, Linear)]

        offset = 0
        for layer, (in_size, out_size) in zip(linears, self.weight_sizes):
            n_weights = in_size * out_size
            # Linear stores its weight as (out_features, in_features).
            layer.weight.data = flat[offset:offset + n_weights].reshape(out_size, in_size)
            offset += n_weights

            if self.add_bias:
                layer.bias.data = flat[offset:offset + out_size]
                offset += out_size

        return net

    def build_net(self):
        """Build a freshly initialised ``Sequential`` matching ``self.Ls``.

        Every ``Linear`` layer except the last is followed by an activation
        from ``_make_activation``. The output layer has no nonlinearity.
        """
        layers = []
        for in_size, out_size in self.weight_sizes[:-1]:
            layers.append(Linear(in_size, out_size, bias=self.add_bias))
            layers.append(self._make_activation())

        # The last layer has no nonlinearity.
        layers.append(Linear(*self.weight_sizes[-1], bias=self.add_bias))

        return Sequential(*layers)

    def __call__(self, states, params):
        """Evaluate the MLP defined by the flat ``params`` on ``states``.

        A new network is built on each call and its randomly initialised
        weights are replaced by slices of ``params``. The weights are
        assigned through ``.data``, so autograd does not track the
        dependency of the output on ``params``.

        Args:
            states: input tensor whose last dimension is ``Ls[0]``.
            params: flat tensor of ``self.len_params`` elements.

        Returns:
            Tensor with the same leading dimensions as ``states`` and last
            dimension ``Ls[-1]``.
        """
        net = self.build_net()
        net = self.reset_weights(net, params)
        return net(states)
    
    
# Demo: a single-layer MLP (8 -> 2, with bias) applied to a (10, 5, 8) batch
# with randomly drawn flat parameters.
mlp = MLP(Ls=[8, 2], add_bias=True)
params = torch.rand((mlp.len_params,))
states = torch.rand((10, 5, 8))
mlp(states, params).shape  # -> torch.Size([10, 5, 2])

    


# Out[63]: torch.Size([10, 5, 2])