In [None]:
# Dependencies.
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

from collections import OrderedDict

import copy

In [None]:
# Notebook-wide constants.
BATCH_SIZE: int = 4    # leading (batch) dimension of the random tensors created below
NUM_EPOCHS: int = 100  # number of iterations of the training loop

In [None]:
class Model(nn.Module):
    r"""Minimal convolutional block: Conv2d -> BatchNorm2d -> ReLU -> AvgPool2d.

    The sub-modules are registered as ``conv1``, ``bn1``, ``act1`` and
    ``pool1``; these attribute names are the prefixes of the keys found in
    the model's state dict.
    """

    def __init__(self):
        r"""Create and register the four layers of the block.
        """
        super().__init__()
        # 1 input channel -> 2 feature maps, 3x3 kernel, unit stride, with bias.
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=2,
            kernel_size=(3, 3),
            stride=(1, 1),
            bias=True,
        )
        # Normalises each of the 2 feature maps produced by `conv1`.
        self.bn1 = nn.BatchNorm2d(num_features=2)
        self.act1 = nn.ReLU()
        # 2x2 averaging window moved with unit stride.
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2), stride=(1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Apply convolution, batch normalisation, ReLU and average pooling.

        Parameters
        ----------
        x:
            Input tensor.
            SHAPE: [B, C_in, H, W].

        Returns
        -------
        feature:
            Tensor of features of the input.
            SHAPE: [B, C_out, H_out, W_out].
        """
        convolved = self.conv1(x)
        normalised = self.bn1(convolved)
        activated = self.act1(normalised)
        feature = self.pool1(activated)
        return feature

In [None]:
# Instantiate the network and show its layer structure.
model = Model()
model_report_template = 'model:\n{}'
print(model_report_template.format(model))

In [None]:
# Fetch the model's state dict and print every entry under its own banner.
state_dict = model.state_dict()
separator = '*'*79
for entry_name, entry_tensor in state_dict.items():
    print(separator)
    print('{}:\n{}'.format(entry_name, entry_tensor))

In [None]:
# Print the state dict of the model as a single object.
state_dict_report_template = 'state dict of the model:\n{}'
print(state_dict_report_template.format(state_dict))

In [None]:
# Explore the state dictionary: its size, its keys, its values and the
# Python types of both.
separator = '*'*79

print(separator)
print('state dict length: {}'.format(len(state_dict)))

print(separator)
for entry_name in state_dict.keys():
    print('key: {}\n\tkey type: {}'.format(entry_name, type(entry_name)))

print(separator)
for entry_name, entry_tensor in state_dict.items():
    print('key: {}\nvalue: {}'.format(entry_name, entry_tensor))

print(separator)
for entry_name, entry_tensor in state_dict.items():
    print('key: {}\n\tvalue type: {}'.format(entry_name, type(entry_tensor)))

In [None]:
# Build a plain list of (key, value) pairs, then turn it into an ordered dict.
separator = '*'*79

data_list: list = [('entry_0', 0), ('entry_1', '1')]
print(separator)
print('data as list:\n{}'.format(data_list))

# `OrderedDict` accepts an iterable of (key, value) pairs directly.
data_ord_dict: OrderedDict = OrderedDict(data_list)
print(separator)
print('data as ordered dict:\n{}'.format(data_ord_dict))


In [None]:
# Create a state-dict on the fly.
# Every entry is a constant-filled tensor built with `torch.full_like` from the
# matching entry of the model's current state dict, so it has the same shape
# AND the same dtype as the tensor it will be loaded into. In particular
# `bn1.num_batches_tracked` stays an integer tensor instead of being promoted
# to float by a multiplication with a Python float.
custom_fill_values: list = [
    ('conv1.weight', 1.0),
    ('conv1.bias', 2.0),
    ('bn1.weight', 3.0),
    ('bn1.bias', 4.0),
    ('bn1.running_mean', 5.0),
    ('bn1.running_var', 6.0),
    ('bn1.num_batches_tracked', 7),
]
reference_state_dict = model.state_dict()
custom_state_dict: OrderedDict = OrderedDict(
    (entry_name, torch.full_like(reference_state_dict[entry_name], fill_value))
    for entry_name, fill_value in custom_fill_values
)

In [None]:
# Load the custom state dictionary into the model.
# The call is the last expression of the cell, so the notebook displays its
# return value as the cell output.
model.load_state_dict(custom_state_dict)

In [None]:
# Take an independent snapshot of the state dict after the manual update
# (`copy.deepcopy` duplicates the tensors themselves), then explore it.
manually_updated_state_dict = copy.deepcopy(model.state_dict())
separator = '*'*79

print(separator)
print('state dict length: {}'.format(len(manually_updated_state_dict)))

print(separator)
for entry_name in manually_updated_state_dict.keys():
    print('key: {}\n\tkey type: {}'.format(entry_name, type(entry_name)))

print(separator)
for entry_name, entry_tensor in manually_updated_state_dict.items():
    print('key: {}\nvalue: {}'.format(entry_name, entry_tensor))

print(separator)
for entry_name, entry_tensor in manually_updated_state_dict.items():
    print('key: {}\n\tvalue type: {}'.format(entry_name, type(entry_tensor)))

In [None]:
# Check output shape.
# The model is switched to evaluation mode once and stays in that mode after
# the cell has run.
model = model.eval()
x_trial = torch.randn(size=(BATCH_SIZE, 1, 5, 5))
with torch.no_grad():
    # No gradients are needed for a pure shape check.
    trial_output = model(x_trial)
print('output shape: {}'.format(trial_output.shape))

In [None]:
# Create input-output data.
# A dedicated, explicitly seeded generator makes `x` and `y` (and therefore
# every loss printed below) identical on each Restart & Run All, without
# touching the global random number generator.
DATA_SEED = 0
data_generator = torch.Generator().manual_seed(DATA_SEED)
x = torch.randn(size=(BATCH_SIZE, 1, 5, 5), generator=data_generator)
y = torch.randn(size=(BATCH_SIZE, 2, 2, 2), generator=data_generator)

In [None]:
# Keep a handle on the state dict as it is before training. No deep copy is
# made here, so `old_state` is whatever `model.state_dict()` hands back.
old_state = model.state_dict()
# Plain SGD over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=1e-3)

In [None]:
# Explore the state dictionary captured right before training.
separator = '*'*79

print(separator)
print('state dict length: {}'.format(len(old_state)))

print(separator)
for entry_name in old_state.keys():
    print('key: {}\n\tkey type: {}'.format(entry_name, type(entry_name)))

print(separator)
for entry_name, entry_tensor in old_state.items():
    print('key: {}\nvalue: {}'.format(entry_name, entry_tensor))

print(separator)
for entry_name, entry_tensor in old_state.items():
    print('key: {}\n\tvalue type: {}'.format(entry_name, type(entry_tensor)))

In [None]:
# Train on the single (x, y) batch with a mean-squared-error loss and
# remember the loss of the very first iteration for the comparison made later.
first_loss = None
model = model.train()

for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    y_pred = model(x)
    squared_error = (y - y_pred)**2
    loss = torch.mean(squared_error)
    loss.backward()
    optimizer.step()

    # Convert the loss to NumPy once and reuse it for logging and bookkeeping.
    loss_value = loss.detach().numpy()
    print('epoch: {}\t\tloss: {}'.format(epoch, loss_value))
    if first_loss is None:
        first_loss = loss_value

In [None]:
# Explore the state dictionary of the model after training.
trained_model_state_dict = model.state_dict()
separator = '*'*79

print(separator)
print('state dict length: {}'.format(len(trained_model_state_dict)))

print(separator)
for entry_name in trained_model_state_dict.keys():
    print('key: {}\n\tkey type: {}'.format(entry_name, type(entry_name)))

print(separator)
for entry_name, entry_tensor in trained_model_state_dict.items():
    print('key: {}\nvalue: {}'.format(entry_name, entry_tensor))

print(separator)
for entry_name, entry_tensor in trained_model_state_dict.items():
    print('key: {}\n\tvalue type: {}'.format(entry_name, type(entry_tensor)))

In [None]:
"""
State dicts obtained earlier pass the tensors by reference.
Thus, when the model updates, so does the state dict obtained earlier.
In order to create a separate copy of the whole state dict, `copy.deepcopy` should be used.
"""
# Restore the model from the deep-copied snapshot taken before training.
model.load_state_dict(manually_updated_state_dict)
state_dict_after_reset = model.state_dict()

# Explore the state dictionary obtained after the reset.
separator = '*'*79

print(separator)
print('state dict length: {}'.format(len(state_dict_after_reset)))

print(separator)
for entry_name in state_dict_after_reset.keys():
    print('key: {}\n\tkey type: {}'.format(entry_name, type(entry_name)))

print(separator)
for entry_name, entry_tensor in state_dict_after_reset.items():
    print('key: {}\nvalue: {}'.format(entry_name, entry_tensor))

print(separator)
for entry_name, entry_tensor in state_dict_after_reset.items():
    print('key: {}\n\tvalue type: {}'.format(entry_name, type(entry_tensor)))

In [None]:
# Restore the model from the hand-built custom state dictionary.
model.load_state_dict(custom_state_dict)
state_dict_after_reset_to_custom = model.state_dict()

# Explore the state dictionary obtained after the reset.
separator = '*'*79

print(separator)
print('state dict length: {}'.format(len(state_dict_after_reset_to_custom)))

print(separator)
for entry_name in state_dict_after_reset_to_custom.keys():
    print('key: {}\n\tkey type: {}'.format(entry_name, type(entry_name)))

print(separator)
for entry_name, entry_tensor in state_dict_after_reset_to_custom.items():
    print('key: {}\nvalue: {}'.format(entry_name, entry_tensor))

print(separator)
for entry_name, entry_tensor in state_dict_after_reset_to_custom.items():
    print('key: {}\n\tvalue type: {}'.format(entry_name, type(entry_tensor)))

In [None]:
# Recompute the loss on the same (x, y) batch with the restored weights, in
# training mode like the first training iteration, and compare it with the
# loss recorded at that first iteration.
model = model.train()
y_pred = model(x)
loss = torch.mean((y - y_pred)**2)
first_loss_after_reset = loss.detach().numpy()

separator = '*'*79
loss_reports = (
    ('first loss before reset:\n{}', first_loss),
    ('first loss after reset:\n{}', first_loss_after_reset),
)
for report_template, loss_value in loss_reports:
    print(separator)
    print(report_template.format(loss_value))

print(separator)
loss_gap = np.linalg.norm(first_loss_after_reset - first_loss)
print('discrepancy: {}'.format(loss_gap))