In [8]:
import torch
from torch import nn

class MLPS(nn.Sequential):
    """Fully connected network assembled as an ``nn.Sequential``.

    With no hidden layers the model is a single linear map named
    ``lin_layer``. Otherwise it is a stack of ``layer{i}`` /
    ``{activation}{i}`` pairs (1-indexed) followed by ``out_layer``.

    Args:
        input_size: Number of input features of the first linear layer.
        hidden_sizes: Widths of the hidden layers, in order. Any sequence
            (list, tuple, ...) is accepted; an empty sequence yields a
            purely linear model.
        output_size: Number of outputs. ``None`` falls back to 1.
        activation: ``"tanh"`` or ``"relu"``.
        flatten: If true, an ``nn.Flatten`` module named ``flatten`` is
            inserted before the first linear layer.
        bias: Forwarded to every ``nn.Linear``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, input_size, hidden_sizes, output_size, activation="tanh", flatten=False, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.activation = activation
        self.output_size = output_size if output_size is not None else 1

        # Resolve the activation class up front so an invalid name fails
        # before any layer is registered.
        activations = {"relu": nn.ReLU, "tanh": nn.Tanh}
        if activation not in activations:
            raise ValueError('invalid activation')
        act = activations[activation]

        # Work on a list copy: the original `[input_size] + hidden_sizes[:-1]`
        # raised TypeError for tuples and other non-list sequences.
        widths = list(hidden_sizes)

        if flatten:
            self.add_module('flatten', nn.Flatten())

        if len(widths) == 0:
            # Linear model: input maps straight to the output.
            self.add_module('lin_layer', nn.Linear(self.input_size, self.output_size, bias=bias))
        else:
            # MLP: pair each layer's fan-in with its width.
            fan_ins = [self.input_size] + widths[:-1]
            for i, (in_size, out_size) in enumerate(zip(fan_ins, widths)):
                self.add_module(f'layer{i+1}', nn.Linear(in_size, out_size, bias=bias))
                self.add_module(f'{activation}{i+1}', act())
            self.add_module('out_layer', nn.Linear(widths[-1], self.output_size, bias=bias))


# Network configuration: the model consumes `input_dim - input_start` features.
input_dim = 9
input_start = 1
hidden_sizes = [100, 100]

map_net = MLPS(
    input_size=input_dim - input_start,
    hidden_sizes=hidden_sizes,
    output_size=1,
    activation='tanh',
    flatten=False,
    bias=True,
)

# Report how many direct submodules the model exposes.
print(f'len = {len([*map_net.children()])}')

# Everything except the last child, repackaged as its own Sequential.
first_layers = [*map_net.children()][:-1]
ck = torch.nn.Sequential(*first_layers)

# Leading dimension of the last parameter tensor of the truncated network.
llp = [*ck.parameters()][-1].shape[0]

len = 5


In [15]:
## Set up neural network.
# Number of hidden units; passed to oneLayerMLP as the size of its hidden Linear layer.
width = 50

class oneLayerMLP(torch.nn.Module):
    """MLP with one input feature, one SiLU hidden layer and one output.

    Args:
        width: Number of units in the hidden layer.
    """

    def __init__(self, width):
        super().__init__()
        # Create the layers in input-to-output order, then chain them.
        hidden = torch.nn.Linear(1, width)
        readout = torch.nn.Linear(width, 1)
        self.net = torch.nn.Sequential(hidden, torch.nn.SiLU(), readout)
        self.output_size = 1

    def forward(self, x):
        """Apply the wrapped Sequential to ``x`` and return its result."""
        return self.net(x)
    

## MSE Model
map_net = oneLayerMLP(width)

# Direct submodules of the model, enumerated once and reused below.
child_list = list(map_net.children())
print(f'len = {len(child_list)}') # if len of children == 1, then the child is a Sequential container, and we need to index it.

if len(child_list) > 1:
    # Layers are registered directly on the model: drop the last one.
    first_layers = child_list[:-1]
elif len(child_list) == 1:
    # The single child wraps the layers: slice inside it instead.
    first_layers = child_list[0][:-1]
else:
    # Without this branch a `first_layers` left over from an earlier cell
    # would be silently reused to build `ck` for the wrong model.
    raise ValueError('map_net has no child modules; cannot split off the final layer')
ck = torch.nn.Sequential(*first_layers)

# Leading dimension of the last parameter tensor of the truncated network.
llp = list(ck.parameters())[-1].shape[0]
print(f'final layer width : {llp}')

len = 1
final layer width : 50


In [13]:
# Take the model's first child and slice off its last layer, displaying what remains.
list(map_net.children())[0][:-1]

Sequential(
  (0): Linear(in_features=1, out_features=50, bias=True)
  (1): SiLU()
)

In [38]:
# NOTE(review): `final_parameters` is not assigned in any cell visible here — presumably it
# came from a cell that was removed or lives elsewhere; confirm this runs on a fresh kernel.
len(final_parameters) == 2

True

In [4]:
import models as model
import torch
from torch import nn
from functorch import make_functional, make_functional_with_buffers

class LeNet5(nn.Module):
    """LeNet-5 style CNN: two conv/sigmoid/avg-pool stages, then three linear layers.

    Takes single-channel images and produces 10 outputs. The flatten size of
    ``16 * 5 * 5`` corresponds to 28x28 inputs given the padding and pooling
    used here. The last linear layer has no bias.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        conv_stage = [
            nn.Conv2d(1, 6, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        # Fully connected head operating on the flattened feature maps.
        dense_stage = [
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10, bias=False),
        ]
        # Single flat container so the layers keep indices 0..11 under `net`.
        self.net = torch.nn.Sequential(*conv_stage, *dense_stage)

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        logits = self.net(x)
        return logits

In [14]:
# NOTE(review): scratch arithmetic — what 512, 4 and 10 stand for is not stated anywhere
# in view; name the factors or remove this cell.
512 * 4* 10

20480

In [None]:
network = model.LeNet5()
params = tuple(network.parameters())

# A lone child is taken to be the container holding the real layers, so
# unwrap it; otherwise the direct children are used as they are.
child_list = list(network.children())
if len(child_list) == 1:
    child_list = child_list[0]

# Split into everything before the last layer, and the last layer itself.
first_layers, final_layer = child_list[:-1], child_list[-1]

# Number of entries in the final layer's first parameter tensor.
llp = list(final_layer.parameters())[0].numel()
print(llp)
print(f'final_layer = {list(final_layer.parameters())[0].shape}')
print(f'num_output = {list(final_layer.parameters())[0].shape[0]}')

840
final_layer = torch.Size([10, 84])
num_output = 10
