In [119]:
import torch
import torch.nn as nn

# A custom network is a subclass of nn.Module: layers are registered as
# attributes in __init__, and the computation is described in forward().
class myLinearNet(nn.Module):
    """Two-layer MLP: Linear(in_feature -> hidden) -> ReLU -> Linear(hidden -> out_feature)."""

    def __init__(self, in_feature, hidden, out_feature):
        super().__init__()
        # Children are registered in assignment order; that order is what
        # print(net) shows and how parameters()/state_dict() are ordered.
        self.linear1 = nn.Linear(in_feature, hidden)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden, out_feature)

    def forward(self, X):
        """Apply linear1, then the ReLU, then linear2 to X."""
        pre_act = self.linear1(X)
        act = self.relu(pre_act)
        return self.linear2(act)


net = myLinearNet(in_feature=3, hidden=2, out_feature=1)
print(net)

myLinearNet(
  (linear1): Linear(in_features=3, out_features=2, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=2, out_features=1, bias=True)
)


In [89]:
# mySequential re-implements the core of nn.Sequential: each positional
# argument is registered as a child module named by its index ("0", "1", ...).
class mySequential(nn.Module):
    """Minimal re-implementation of nn.Sequential.

    Args:
        *args: nn.Module instances; forward() applies them in the given order.
    """

    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            # add_module() is the public equivalent of writing into
            # self._modules[str(idx)]; it also type-checks `module`.
            self.add_module(str(idx), module)

    def forward(self, X):
        """Feed X through every registered child, in registration order."""
        # .values() must be *called*: iterating the bound method object
        # itself raises TypeError, which made the original forward unusable.
        for layer in self._modules.values():
            X = layer(X)
        return X

    def __getitem__(self, idx):
        """Return the child at position idx; negative ints count from the end.

        Raises:
            KeyError: if no child is registered at that position.
        """
        if isinstance(idx, int) and idx < 0:
            idx += len(self._modules)
        return self._modules[str(idx)]


sq1 = mySequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
sq2 = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

print(sq1, sq2)
print(sq1[0], sq2[0])

mySequential(
  (0): Linear(in_features=3, out_features=2, bias=True)
  (1): ReLU()
  (2): Linear(in_features=2, out_features=1, bias=True)
) Sequential(
  (0): Linear(in_features=3, out_features=2, bias=True)
  (1): ReLU()
  (2): Linear(in_features=2, out_features=1, bias=True)
)
Linear(in_features=3, out_features=2, bias=True) Linear(in_features=3, out_features=2, bias=True)


In [130]:
# Modules nest freely: containers can be attributes of a custom module,
# and a custom module can itself be an element of a container.
class seq_in_net(nn.Module):
    """Custom module whose children are themselves containers.

    linear1: nn.Sequential(Linear(4, 5), ReLU, Linear(5, 2))
    relu:    nn.ReLU
    linear2: mySequential(Linear(2, 1))
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Sequential(nn.Linear(4, 5), nn.ReLU(), nn.Linear(5, 2))
        self.relu = nn.ReLU()
        self.linear2 = mySequential(nn.Linear(2, 1))

    def forward(self, X):
        """Apply linear1 -> relu -> linear2 to X."""
        h = self.linear1(X)
        h = self.relu(h)
        return self.linear2(h)


nest1 = seq_in_net()
nest2 = nn.Sequential(nest1, nn.Linear(1, 1))

print(nest1, nest2, sep="\n")

seq_in_net(
  (linear1): Sequential(
    (0): Linear(in_features=4, out_features=5, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5, out_features=2, bias=True)
  )
  (relu): ReLU()
  (linear2): mySequential(
    (0): Linear(in_features=2, out_features=1, bias=True)
  )
)
Sequential(
  (0): seq_in_net(
    (linear1): Sequential(
      (0): Linear(in_features=4, out_features=5, bias=True)
      (1): ReLU()
      (2): Linear(in_features=5, out_features=2, bias=True)
    )
    (relu): ReLU()
    (linear2): mySequential(
      (0): Linear(in_features=2, out_features=1, bias=True)
    )
  )
  (1): Linear(in_features=1, out_features=1, bias=True)
)


In [116]:
# Blocks can be built by factory functions and nested to arbitrary depth.
def block1(features=2):
    """Return nn.Sequential(Linear(features, features), ReLU()).

    Args:
        features: input and output width of the Linear layer (default 2).
    """
    return nn.Sequential(nn.Linear(features, features), nn.ReLU())


def block2(num_blocks=4, features=2):
    """Return an nn.Sequential of `num_blocks` block1() children.

    The children are named 'block 0', 'block 1', ... via add_module(), so
    those names appear in print() and in state_dict keys
    (e.g. 'block 0.0.weight').

    Args:
        num_blocks: number of block1() children to stack (default 4).
        features: width passed to every block1() (default 2).
    """
    stacked = nn.Sequential()
    for i in range(num_blocks):
        stacked.add_module(f'block {i}', block1(features))
    return stacked


block_net = block2()
nested_net = nn.Sequential(nn.Linear(2, 2), block_net)
print(nested_net)

Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Sequential(
    (block 0): Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): ReLU()
    )
    (block 1): Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): ReLU()
    )
    (block 2): Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): ReLU()
    )
    (block 3): Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): ReLU()
    )
  )
)


In [113]:
# parameters() yields the Parameter objects only (no names);
# named_parameters() yields (name, Parameter) pairs, where the name is the
# dotted path '<child name>.<parameter name>', e.g. '0.weight'.
# The ReLU at index 1 owns no parameters, so no '1.*' entries appear.
print("=====parameters,named_parameters ======")
print( [param.shape for param in sq2.parameters()] )
print( [(name, param.shape) for name, param in sq2.named_parameters()] )

# state_dict() returns an OrderedDict mapping parameter names to tensors.
# sq2[2] is the child at index 2 (the second Linear); the "layer2" labels
# printed below refer to that child.
print("===== state_dict ======")
print("layer2 state_dict:", sq2[2].state_dict())

# The same weight values are reachable three ways: module.weight.data,
# module.state_dict()['weight'], or the parent's state_dict() under the
# dotted key '2.weight'. module.weight itself is an nn.Parameter (its repr
# shows requires_grad=True); the other three print plain tensors.
print("===== weight_data ======")
print("layer2 weight:", sq2[2].weight)
print("layer2 weight data:",sq2[2].weight.data)
print("layer2 weight data:", sq2[2].state_dict()['weight'])
print("layer2 weight data:", sq2.state_dict()['2.weight'])

[torch.Size([2, 3]), torch.Size([2]), torch.Size([1, 2]), torch.Size([1])]
[('0.weight', torch.Size([2, 3])), ('0.bias', torch.Size([2])), ('2.weight', torch.Size([1, 2])), ('2.bias', torch.Size([1]))]
layer2 state_dict: OrderedDict([('weight', tensor([[0.2316, 0.6217]])), ('bias', tensor([-0.2668]))])
layer2 weight: Parameter containing:
tensor([[0.2316, 0.6217]], requires_grad=True)
layer2 weight data: tensor([[0.2316, 0.6217]])
layer2 weight data: tensor([[0.2316, 0.6217]])
layer2 weight data: tensor([[0.2316, 0.6217]])


In [118]:
# Parameter sharing: putting the *same* module object into two slots of a
# container makes both slots use (and train) a single set of parameters.
share = nn.Linear(3, 3)
net_share = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), share, share)
# Independently initialised layers: values differ.
print(net_share[0].weight.data == net_share[2].weight.data)
# Shared layer: values match...
print(net_share[2].weight.data == net_share[3].weight.data)
# ...but equal values could also come from a copy. Identity is the real
# proof: both slots hold the very same module and Parameter objects.
print(net_share[2] is net_share[3], net_share[2].weight is net_share[3].weight)
# Hence an in-place write through one slot is visible through the other.
with torch.no_grad():
    net_share[2].weight[0, 0] = 100.0
print(net_share[3].weight[0, 0].item() == 100.0)

tensor([[False, False, False],
        [False, False, False],
        [False, False, False]])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


In [127]:
# Hand-written equivalent of nn.Linear: wrapping tensors in nn.Parameter
# registers them with the module (parameters(), state_dict(), optimizers).
class myLinear(nn.Module):
    """Affine layer computing x @ weight + bias.

    Args:
        in_f: size of the input's last dimension.
        out_f: size of the output's last dimension.

    Note: weight is stored as (in_f, out_f), i.e. transposed relative to
    nn.Linear's (out_features, in_features) layout, so forward uses
    x @ weight directly. Both parameters are drawn from torch.randn.
    """

    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = nn.Parameter(torch.randn((in_f, out_f)))
        self.bias = nn.Parameter(torch.randn(out_f))

    def forward(self, x):
        """Return x @ weight + bias (bias broadcasts over leading dimensions)."""
        # Use the Parameters themselves, not .data: .data is detached from
        # autograd, so the previous version produced outputs with no gradient
        # path back to weight/bias and the layer could never be trained.
        return torch.matmul(x, self.weight) + self.bias


l1 = myLinear(3, 3)
l2 = nn.Linear(3, 3, dtype=torch.float32)
l1.weight

Parameter containing:
tensor([[ 0.4585,  0.6741,  1.1821],
        [ 0.4694, -0.7802,  0.2946],
        [-0.0970,  0.4984,  1.0811]], requires_grad=True)

In [138]:
# torch.save / torch.load round-trip a tensor, a list of tensors, and a
# module's state_dict. All three demos write to (and overwrite) one file.
# NOTE(review): torch.load unpickles the file - only load files you trust.
SAVE_PATH = 'x_file'

x, y = torch.tensor([1, 2]), torch.tensor([3, 4])

print("========save tensor===========")
torch.save(x, SAVE_PATH)
xx = torch.load(SAVE_PATH)
print(x == xx)

print("====save list of tensor=======")
torch.save([x, y], SAVE_PATH)
xx, yy = torch.load(SAVE_PATH)
print(x == xx, y == yy)

print("========save model============")
# Only the state_dict is saved, so restoring needs a freshly built
# instance of the same architecture to load it into.
torch.save(nest1.state_dict(), SAVE_PATH)
nest1_load = seq_in_net()
loaded_state = torch.load(SAVE_PATH)
nest1_load.load_state_dict(loaded_state)
print(nest1_load.linear1[0].weight.data == nest1.linear1[0].weight.data)

tensor([True, True])
tensor([True, True]) tensor([True, True])
tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])
