In [2]:
import torch
import torch.nn.functional as F
from torch import nn


class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X - X.mean()
layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
print(Y)
Y.mean()

tensor([[-2.5748e-01, -1.9514e-02,  2.2997e-01, -5.2434e-02, -1.2737e-01,
          8.0027e-01,  3.9899e-01,  3.8181e-01, -5.4731e-01,  1.9903e-01,
          2.5995e-01, -1.6573e-01, -4.8794e-01, -2.3635e-01,  3.4394e-01,
          6.1264e-02,  2.6811e-02, -1.0170e-01,  1.6806e-01, -3.5956e-01,
         -3.2735e-01, -3.6488e-01, -2.7775e-02, -6.5512e-01,  4.5217e-02,
         -7.4495e-02, -2.6951e-01, -2.6565e-01, -3.2221e-01, -4.6450e-02,
          1.2108e-01,  1.8254e-01, -6.3297e-01,  9.3186e-02,  3.3132e-01,
          2.6284e-01,  2.6009e-01, -7.1419e-01,  2.1833e-01,  5.7977e-02,
          2.4287e-01, -1.2702e-01,  4.0568e-01,  9.7700e-02, -1.5514e-01,
          5.3977e-01, -3.8786e-01, -2.0130e-01, -2.6092e-01, -1.3869e-02,
          3.2287e-01,  2.2958e-02, -1.4447e-01,  6.0348e-01,  1.0386e-01,
         -1.7108e-01, -6.2856e-01,  4.9079e-01,  8.4828e-02,  3.7007e-01,
          3.1681e-01,  1.3684e-01,  6.2040e-03,  4.8857e-01, -2.0883e-01,
          3.7843e-02, -3.6009e-01, -4.

tensor(-7.4506e-09, grad_fn=<MeanBackward0>)

In [3]:
class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))
    def forward(self, X):
        linear = torch.matmul(X, self.weight.data) + self.bias.data
        return F.relu(linear)
linear = MyLinear(5, 3)
linear.weight
linear(torch.rand(2, 5))
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))

tensor([[0.0000],
        [2.7123]])

In [4]:
class Linear5_4_1(nn.Module):
    def __init__(self, in_units, out_units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, in_units, out_units))
    def forward(self, X):
        b = X.shape[0] # b = batch_size
        o = self.weight.shape[2] # o = out_units
        y = torch.zeros(b,o) # y的shape为(batch_size，out_units)
        for k in range(o):
            for i in range(b):
                # 矩阵乘法维度分别为：(1*4), (4*4), (4*1)
                #print(torch.matmul(X[i,:],self.weight[:,:,k]))
                #print(torch.matmul(torch.matmul(X[i,:],self.weight[:,:,k]),X[i,:]))
                y[i,k] = torch.matmul(torch.matmul(X[i,:],self.weight[:,:,k]),X[i,:])
                #print(y[i][k])
                #print(y[i,k])
        return y
X = torch.randn(2, 4)
linear = Linear5_4_1(4, 2)

print(linear.weight.shape)
print(linear(X))
print(linear(X).sum(axis=1))

torch.Size([4, 4, 2])
tensor([[-1.5814, -4.6637],
        [ 2.3336,  0.8622]], grad_fn=<CopySlices>)
tensor([-6.2452,  3.1958], grad_fn=<SumBackward1>)


In [5]:
import torch
from torch import nn
from torch.nn import functional as F

x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
x2

tensor([0, 1, 2, 3])

In [6]:
y = torch.zeros(4)
torch.save([x, y],'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2['x']

tensor([0, 1, 2, 3])

In [19]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
print(Y)
torch.save(net.state_dict(), 'mlp.params')
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
Y_clone = clone(X)
Y_clone == Y

tensor([[ 0.1496,  0.0502, -0.0438,  0.1878, -0.1688, -0.0956,  0.2382,  0.1080,
         -0.0609,  0.1702],
        [ 0.0398, -0.2328, -0.2326, -0.0437, -0.0265, -0.1886,  0.0265,  0.0865,
         -0.0544,  0.1310]], grad_fn=<AddmmBackward0>)


tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

In [20]:
torch.save(net.output.state_dict(), 'mlp.out_put_params')
params=torch.load('mlp.out_put_params')
print(params)

OrderedDict([('weight', tensor([[ 0.0564,  0.0267,  0.0545,  ..., -0.0380,  0.0343,  0.0190],
        [-0.0127, -0.0444, -0.0539,  ..., -0.0445,  0.0140,  0.0315],
        [ 0.0235,  0.0286,  0.0332,  ..., -0.0583, -0.0183, -0.0211],
        ...,
        [-0.0321, -0.0391,  0.0431,  ..., -0.0465, -0.0060,  0.0504],
        [-0.0546,  0.0160, -0.0326,  ..., -0.0229,  0.0531, -0.0302],
        [-0.0409,  0.0568, -0.0100,  ...,  0.0009, -0.0478,  0.0048]])), ('bias', tensor([-0.0223,  0.0433,  0.0002,  0.0180,  0.0437,  0.0105,  0.0472, -0.0081,
         0.0019,  0.0508]))])


In [22]:
clone2=MLP()
clone2.eval()
clone2.output.load_state_dict(torch.load('mlp.out_put_params'))
print(clone2.hidden.weight==net.hidden.weight)
clone2.output.weight==net.output.weight

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])


tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [33]:
#连模型一块保存
model = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
torch.save(clone2, 'model.pt')
m = torch.load('model.pt',weights_only=False)
print(m)

MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)
