In [167]:
import torch
from torch import nn

In [168]:
# Size of the last dimension of the inputs; used as in_features of each nn.Linear.
D_MODEL = 7
# Size of the last dimension of the labels; used as out_features of each nn.Linear.
D_OUTPUT = 4
# Number of samples along the first dimension of the input and label tensors.
BATCH_SIZE = 3

In [169]:
def train(model, x, label, optimizer, loss_fn):
    """Run one optimization step of ``model`` on ``x`` against ``label``.

    The model output is cropped along dimension 1 to its first
    ``label.size(1)`` entries before the loss is computed, so output
    positions beyond the label's length contribute nothing to the loss
    or to the gradients.

    Args:
        model: Module invoked as ``model(x)``; put into training mode first.
        x: Input tensor passed to the model. In this notebook it is shaped
            (batch, positions, features).
        label: Target tensor. Its size along dimension 1 decides how many
            leading output positions are scored.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called;
            expected to manage ``model``'s parameters.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a tensor
            on which ``backward()`` can be called.

    Returns:
        None. The input shape, output shape and loss are printed.
    """
    model.train()
    optimizer.zero_grad()

    y = model(x)

    # Crop dimension 1 to the label's own dimension-1 length. The previous
    # code used label.size(0) (the batch size), which only worked when the
    # batch size happened to equal the label length.
    t = torch.narrow(y, 1, 0, label.size(1))

    loss = loss_fn(t, label)
    loss.backward()

    optimizer.step()

    print("======================================")
    print("Input:", x.shape)
    print("Output:", y.shape)
    print("Loss:", loss)


In [170]:
# Three linear layers of identical shape, constructed in order.
model1, model2, model3 = (nn.Linear(D_MODEL, D_OUTPUT) for _ in range(3))

# Copy model1's initial parameters into the other two so that all three
# models start from exactly the same weights and bias.
w = model1.state_dict()
model2.load_state_dict(w)
model3.load_state_dict(w)

# One SGD optimizer per model, all with the same learning rate.
optimizer1, optimizer2, optimizer3 = (
    torch.optim.SGD(net.parameters(), lr=0.01)
    for net in (model1, model2, model3)
)

# L1 loss, shared by every training call.
loss_fn = nn.L1Loss()

In [171]:
# Full-length random input: 10 positions per sample.
x1 = torch.rand((BATCH_SIZE, 10, D_MODEL))
# x2 and x3 each hold the first 3 positions of a separate copy of x1.
x2 = x1.clone()[:, :3, :]
x3 = x1.clone()[:, :3, :]
# Two independently drawn random targets, each 3 positions long.
label = torch.rand((BATCH_SIZE, 3, D_OUTPUT))
label_alt = torch.rand((BATCH_SIZE, 3, D_OUTPUT))

In [172]:
# model1: full 10-position input, scored against `label`.
train(model=model1, x=x1, label=label, optimizer=optimizer1, loss_fn=loss_fn)
# model2: input truncated to the first 3 positions, same `label`.
train(model=model2, x=x2, label=label, optimizer=optimizer2, loss_fn=loss_fn)
# model3: same truncated input, but scored against `label_alt`.
train(model=model3, x=x3, label=label_alt, optimizer=optimizer3, loss_fn=loss_fn)

Input: torch.Size([3, 10, 7])
Output: torch.Size([3, 10, 4])
Loss: tensor(0.6990, grad_fn=<MeanBackward0>)
Input: torch.Size([3, 3, 7])
Output: torch.Size([3, 3, 4])
Loss: tensor(0.6990, grad_fn=<MeanBackward0>)
Input: torch.Size([3, 3, 7])
Output: torch.Size([3, 3, 4])
Loss: tensor(0.6248, grad_fn=<MeanBackward0>)


In [173]:
# Weight gradients left by each model's single training step, side by side.
model1.weight.grad, model2.weight.grad, model3.weight.grad

(tensor([[ 0.0073, -0.0464,  0.0789,  0.0083, -0.0264,  0.0192,  0.0296],
         [-0.1141, -0.1026, -0.1075, -0.1533, -0.1392, -0.1400, -0.1333],
         [-0.1141, -0.1026, -0.1075, -0.1533, -0.1392, -0.1400, -0.1333],
         [-0.0168,  0.0207, -0.0066,  0.0607,  0.0335,  0.0204, -0.0080]]),
 tensor([[ 0.0073, -0.0464,  0.0789,  0.0083, -0.0264,  0.0192,  0.0296],
         [-0.1141, -0.1026, -0.1075, -0.1533, -0.1392, -0.1400, -0.1333],
         [-0.1141, -0.1026, -0.1075, -0.1533, -0.1392, -0.1400, -0.1333],
         [-0.0168,  0.0207, -0.0066,  0.0607,  0.0335,  0.0204, -0.0080]]),
 tensor([[-0.0481, -0.0457, -0.0202, -0.0101, -0.0994, -0.0031, -0.0788],
         [-0.1141, -0.1026, -0.1075, -0.1533, -0.1392, -0.1400, -0.1333],
         [-0.1125, -0.0997, -0.0572, -0.0992, -0.1367, -0.0907, -0.1332],
         [-0.0920, -0.0543, -0.1065, -0.1146, -0.1232, -0.1060, -0.1271]]))

In [174]:
# Bias gradients left by each model's single training step, side by side.
model1.bias.grad, model2.bias.grad, model3.bias.grad

(tensor([ 0.0278, -0.2500, -0.2500,  0.0278]),
 tensor([ 0.0278, -0.2500, -0.2500,  0.0278]),
 tensor([-0.0833, -0.2500, -0.1944, -0.1944]))

In [175]:
# Parameters of all three models after their single optimizer step.
model1.state_dict(), model2.state_dict(), model3.state_dict()

(OrderedDict([('weight',
               tensor([[ 0.0825, -0.1129,  0.2361,  0.3138, -0.1080,  0.0393, -0.1499],
                       [-0.2254,  0.1682, -0.1074, -0.0833,  0.2779, -0.3085, -0.1837],
                       [-0.0502,  0.0518,  0.0295, -0.0972, -0.2373, -0.1899, -0.1611],
                       [-0.3544,  0.0774, -0.1833,  0.1362,  0.0135,  0.3510,  0.0616]])),
              ('bias', tensor([ 0.2968, -0.3719,  0.3636,  0.1682]))]),
 OrderedDict([('weight',
               tensor([[ 0.0825, -0.1129,  0.2361,  0.3138, -0.1080,  0.0393, -0.1499],
                       [-0.2254,  0.1682, -0.1074, -0.0833,  0.2779, -0.3085, -0.1837],
                       [-0.0502,  0.0518,  0.0295, -0.0972, -0.2373, -0.1899, -0.1611],
                       [-0.3544,  0.0774, -0.1833,  0.1362,  0.0135,  0.3510,  0.0616]])),
              ('bias', tensor([ 0.2968, -0.3719,  0.3636,  0.1682]))]),
 OrderedDict([('weight',
               tensor([[ 0.0830, -0.1129,  0.2371,  0.3140, -0.1073,  0

In [176]:
# Element-wise equality of the updated parameters: model1 against model2 ...
print(torch.eq(model1.weight, model2.weight))
print(torch.eq(model1.bias, model2.bias))
print("==========================================================")
# ... and model3 against model2.
print(torch.eq(model3.weight, model2.weight))
print(torch.eq(model3.bias, model2.bias))

tensor([[True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True]])
tensor([True, True, True, True])
tensor([[False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False]])
tensor([False,  True, False, False])
