In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

In [2]:
# Problem sizes for the toy regression: BATCH_SIZE samples with IN_DIM
# features each, mapped to OUT_DIM targets.
BATCH_SIZE = 4
IN_DIM = 5
OUT_DIM = 2

# Seed torch's global RNG so every random draw below (layer construction,
# `torch.randn`, `xavier_normal_`) is reproducible on Restart & Run All.
# The trailing `;` suppresses the display of the returned Generator.
SEED = 0
torch.manual_seed(SEED);

In [3]:
# Toy regression setup: one linear layer, plain SGD over its parameters, and
# a single random batch. The optimizer holds references to these same
# Parameter objects, so later in-place re-initializations change exactly the
# tensors it will update.
model = nn.Linear(IN_DIM, OUT_DIM, bias=True)
optimizer = optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(BATCH_SIZE, IN_DIM)   # inputs:  (BATCH_SIZE, IN_DIM)
y = torch.randn(BATCH_SIZE, OUT_DIM)  # targets: (BATCH_SIZE, OUT_DIM)

In [4]:
# Show the parameters the layer was constructed with, before any explicit
# re-initialization. Detaching is enough for display; no copy is needed.
for template, param in (
    ('model weights:\n{}', model.weight),
    ('model biases:\n{}', model.bias),
):
    print('*' * 79)
    print(template.format(param.detach().numpy()))

*******************************************************************************
model weights:
[[ 0.13402611 -0.17495948 -0.33387586  0.2108041   0.02582273]
 [ 0.39484465 -0.27511713 -0.21096924  0.11630732 -0.30228046]]
*******************************************************************************
model biases:
[ 0.17120987 -0.3511966 ]


In [5]:
# Replace the layer's constructed parameters in place: Xavier-normal weights
# and zero biases. The Parameter objects themselves are unchanged, so the
# optimizer still points at them.
# Output is suppressed with a trailing `;` rather than `_ = ...`: assigning
# `_` in the user namespace makes IPython stop updating `_`/`__`/`___`.
nn.init.xavier_normal_(model.weight)
nn.init.zeros_(model.bias);

In [6]:
# One forward/backward pass on the random batch with a mean-squared-error
# loss. This cell only populates `.grad`; it does not call optimizer.step(),
# so the parameter values are left as they are.
model.train()
optimizer.zero_grad()
y_pred = model(x)
loss = (y - y_pred).pow(2).mean()
loss.backward()

In [7]:
# Save old weights and biases, i.e. the values the gradients were computed
# against. `.clone()` is required: `.numpy()` shares memory with the tensor,
# and the parameters are overwritten in place later (re-init and step).
weights_before_reinit = model.weight.detach().clone().numpy()
biases_before_reinit = model.bias.detach().clone().numpy()

In [8]:
# Display the parameter snapshot taken before the second re-initialization.
for template, values in (
    ('model weights before re-initialization:\n{}', weights_before_reinit),
    ('model biases before re-initialization:\n{}', biases_before_reinit),
):
    print('*' * 79)
    print(template.format(values))

*******************************************************************************
model weights before re-initialization:
[[ 1.226915    1.0614384   0.3598972   0.8370321  -0.14699934]
 [-0.1565047   0.4408919   0.60147136  0.28267437  0.23321863]]
*******************************************************************************
model biases before re-initialization:
[0. 0.]


In [9]:
# Save gradients from the backward pass as independent NumPy copies, so the
# expected SGD update can be recomputed by hand later.
d_weights = model.weight.grad.detach().clone().numpy()
d_biases = model.bias.grad.detach().clone().numpy()

In [10]:
# Display the saved gradients of the MSE loss w.r.t. weights and biases.
for template, values in (
    ('delta weights:\n{}', d_weights),
    ('delta biases:\n{}', d_biases),
):
    print('*' * 79)
    print(template.format(values))

*******************************************************************************
delta weights:
[[0.2993092  1.2013772  0.6499523  0.85826516 0.40363598]
 [0.19460805 0.6190603  0.49644375 0.22114085 0.1711989 ]]
*******************************************************************************
delta biases:
[0.00441459 0.01844712]


In [11]:
# Re-initialize weights in place, after the gradients were computed. Only the
# values are overwritten: the Parameter objects, the optimizer's references
# to them, and their stored `.grad` are left alone.
# Output is suppressed with a trailing `;` rather than `_ = ...`: assigning
# `_` in the user namespace makes IPython stop updating `_`/`__`/`___`.
nn.init.xavier_normal_(model.weight)
nn.init.zeros_(model.bias);

In [12]:
# Save weights after re-initialization. Copied (not just viewed), because
# optimizer.step() later updates these parameters in place.
weights_after_reinit = model.weight.detach().clone().numpy()
biases_after_reinit = model.bias.detach().clone().numpy()

In [13]:
# Display the freshly re-initialized parameters.
for template, values in (
    ('model weights after re-initialization:\n{}', weights_after_reinit),
    ('model biases after re-initialization:\n{}', biases_after_reinit),
):
    print('*' * 79)
    print(template.format(values))

*******************************************************************************
model weights after re-initialization:
[[ 0.42601535 -0.25671273 -0.5048637   1.0244417   0.354806  ]
 [ 0.19252676 -0.1976713   1.0080578  -0.6726875   0.24037957]]
*******************************************************************************
model biases after re-initialization:
[0. 0.]


In [14]:
# Step the optimizer.
# No backward pass has run since the re-initialization above, so SGD applies
# the gradients still stored in `.grad` (computed against the earlier
# weights) to the freshly re-initialized parameter values.
optimizer.step()

In [15]:
# Weights after step: independent NumPy copies of the updated parameters,
# compared below against the hand-computed SGD update.
weights_after_reinit_and_step = model.weight.detach().clone().numpy()
biases_after_reinit_and_step = model.bias.detach().clone().numpy()

In [16]:
# Display the parameters after the optimizer step.
for template, values in (
    ('model weights after re-initialization and step:\n{}', weights_after_reinit_and_step),
    ('model biases after re-initialization and step:\n{}', biases_after_reinit_and_step),
):
    print('*' * 79)
    print(template.format(values))

*******************************************************************************
model weights after re-initialization and step:
[[ 0.42571604 -0.25791413 -0.5055136   1.0235834   0.35440236]
 [ 0.19233215 -0.19829035  1.0075614  -0.6729086   0.24020837]]
*******************************************************************************
model biases after re-initialization and step:
[-4.4145886e-06 -1.8447117e-05]


In [17]:
# Verify that optimizer.step() applied the gradients saved before the
# re-initialization to the re-initialized parameters, i.e. plain SGD:
#     param_after_step = param_after_reinit - lr * grad
# The optimizer was constructed with only `lr`, so SGD's defaults
# (no momentum, no weight decay) apply and this closed form holds.
sgd_lr = optimizer.param_groups[0]['lr']  # read it, don't re-hardcode 1e-3
expected_weights = weights_after_reinit - sgd_lr * d_weights
expected_biases = biases_after_reinit - sgd_lr * d_biases

# Report the discrepancy as a Frobenius/Euclidean norm (0.0 when consistent).
# The default norm is elementwise; ord=2 on a matrix would be the spectral
# norm, which needs an SVD and is not an elementwise error measure.
print('*'*79)
print('weight updates are consistent:\n{}'.format(
    np.linalg.norm(weights_after_reinit_and_step - expected_weights),
))
print('*'*79)
print('bias updates are consistent:\n{}'.format(
    np.linalg.norm(biases_after_reinit_and_step - expected_biases),
))

# Fail loudly instead of only printing a number. The tolerance absorbs
# float32 rounding differences between torch's in-place update and NumPy.
np.testing.assert_allclose(
    weights_after_reinit_and_step, expected_weights, rtol=1e-6, atol=1e-8,
)
np.testing.assert_allclose(
    biases_after_reinit_and_step, expected_biases, rtol=1e-6, atol=1e-8,
)

*******************************************************************************
weight updates are consistent:
0.0
*******************************************************************************
bias updates are consistent:
0.0
