In [1]:
import torch
import torch.nn as nn

import numpy as np

import random

import copy

from collections import OrderedDict

In [2]:
# Layer widths of the demo network: input -> hidden -> output.
IN_DIM, HIDDEN_DIM, OUT_DIM = 5, 3, 2

In [3]:
# Set seeds.
# One shared constant drives every RNG used in the notebook (torch, the
# stdlib `random` module and NumPy), so the seed can be changed in one place.
SEED = 43865
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [4]:
class Model(nn.Module):
    r"""Basic model: two stacked bias-free linear layers (``fc1`` then
    ``fc2``) with no non-linearity in between.

    Its ``state_dict`` therefore holds exactly two entries, ``fc1.weight``
    and ``fc2.weight``.
    """

    def __init__(self, in_dim=None, hidden_dim=None, out_dim=None):
        r"""The initializer.

        Parameters
        ----------
        in_dim:
            Number of input features of ``fc1``. When ``None`` (default),
            the module-level ``IN_DIM`` is used.
        hidden_dim:
            Number of output features of ``fc1`` and input features of
            ``fc2``. When ``None`` (default), the module-level
            ``HIDDEN_DIM`` is used.
        out_dim:
            Number of output features of ``fc2``. When ``None`` (default),
            the module-level ``OUT_DIM`` is used.
        """
        super(Model, self).__init__()
        # Fall back to the notebook-level constants only when a size is not
        # given, so ``Model()`` keeps behaving exactly as before.
        in_dim = IN_DIM if in_dim is None else in_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        out_dim = OUT_DIM if out_dim is None else out_dim

        self.fc1 = nn.Linear(
            in_features=in_dim, out_features=hidden_dim, bias=False
        )
        self.fc2 = nn.Linear(
            in_features=hidden_dim, out_features=out_dim, bias=False
        )

    def forward(self, x):
        r"""Implements the forward pass.

        Parameters
        ----------
        x:
            Input tensor.
            SHAPE: [B, in_dim].

        Returns
        -------
        feature (implicit):
            The tensor of features of the input, i.e. ``fc2(fc1(x))``.
            SHAPE: [B, out_dim].
        """
        return self.fc2(self.fc1(x))

# Build the demo network; its initial weights are drawn from torch's RNG,
# which was seeded earlier in the notebook.
model = Model()

In [5]:
# Keep an independent snapshot of the initial weights. The deep copy makes
# the snapshot unaffected by later ``load_state_dict`` calls on ``model``.
original_state_dict = copy.deepcopy(model.state_dict())
for param, value in original_state_dict.items():
    # One separator line followed by the entry's name and tensor value.
    print('\n'.join(['*'*79, 'name: {}\nvalue:\n{}'.format(param, value)]))

*******************************************************************************
name: fc1.weight
value:
tensor([[-0.3728, -0.3287,  0.2641,  0.2815, -0.2725],
        [-0.1719,  0.0949,  0.2583, -0.0935, -0.3799],
        [ 0.1419, -0.0349, -0.2499,  0.0175, -0.3181]])
*******************************************************************************
name: fc2.weight
value:
tensor([[ 0.1687,  0.5752, -0.3429],
        [-0.5647,  0.2967, -0.4464]])


In [6]:
# Single-entry state dicts used to overwrite one layer at a time.
# ``torch.full_like`` builds the constant tensor directly with the shape and
# dtype of the target weight, avoiding the float64 NumPy round-trip and cast.
partial_state_dict_fc1 = OrderedDict([
    ('fc1.weight', torch.full_like(model.fc1.weight.data, 2.0)),
])

partial_state_dict_fc2 = OrderedDict([
    ('fc2.weight', torch.full_like(model.fc2.weight.data, 4.0)),
])

In [7]:
# Overwrite only fc1.weight. The dict has no fc2.weight entry, so loading is
# done with strict=False.
model.load_state_dict(partial_state_dict_fc1, strict=False)
for param, value in model.state_dict().items():
    print('\n'.join(['*'*79, 'name: {}\nvalue:\n{}'.format(param, value)]))

*******************************************************************************
name: fc1.weight
value:
tensor([[2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.]])
*******************************************************************************
name: fc2.weight
value:
tensor([[ 0.1687,  0.5752, -0.3429],
        [-0.5647,  0.2967, -0.4464]])


In [8]:
# Overwrite only fc2.weight. The dict has no fc1.weight entry, so loading is
# done with strict=False; fc1 keeps the values loaded in the previous cell.
model.load_state_dict(partial_state_dict_fc2, strict=False)
for param, value in model.state_dict().items():
    print('\n'.join(['*'*79, 'name: {}\nvalue:\n{}'.format(param, value)]))

*******************************************************************************
name: fc1.weight
value:
tensor([[2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.]])
*******************************************************************************
name: fc2.weight
value:
tensor([[4., 4., 4.],
        [4., 4., 4.]])


In [9]:
# Restore the initial weights from the snapshot. The snapshot holds every
# key of the model, so the default-strength check (strict=True) is kept.
model.load_state_dict(original_state_dict, strict=True)
for param, value in model.state_dict().items():
    print('\n'.join(['*'*79, 'name: {}\nvalue:\n{}'.format(param, value)]))

*******************************************************************************
name: fc1.weight
value:
tensor([[-0.3728, -0.3287,  0.2641,  0.2815, -0.2725],
        [-0.1719,  0.0949,  0.2583, -0.0935, -0.3799],
        [ 0.1419, -0.0349, -0.2499,  0.0175, -0.3181]])
*******************************************************************************
name: fc2.weight
value:
tensor([[ 0.1687,  0.5752, -0.3429],
        [-0.5647,  0.2967, -0.4464]])


In [10]:
def _print_banner():
    r"""Prints the two 79-star lines that open each stage of the output.
    """
    print('*'*79)
    print('*'*79)


def _print_state_dict(module):
    r"""Prints every entry of ``module.state_dict()``.

    Parameters
    ----------
    module:
        The module whose parameters are shown; each entry is preceded by
        a 39-star separator line.
    """
    for param, value in module.state_dict().items():
        print('*'*39)
        print('name: {}\nvalue:\n{}'.format(param, value))


# Stages of the demo, as (state dict to load, strict flag) pairs:
#   1. overwrite fc2 only, 2. restore the full snapshot, 3. overwrite fc1 only.
load_sequence = [
    (partial_state_dict_fc2, False),
    (original_state_dict, True),
    (partial_state_dict_fc1, False),
]
for stage_state_dict, stage_strict in load_sequence:
    model.load_state_dict(stage_state_dict, strict=stage_strict)
    _print_banner()
    _print_state_dict(model)

*******************************************************************************
*******************************************************************************
***************************************
name: fc1.weight
value:
tensor([[-0.3728, -0.3287,  0.2641,  0.2815, -0.2725],
        [-0.1719,  0.0949,  0.2583, -0.0935, -0.3799],
        [ 0.1419, -0.0349, -0.2499,  0.0175, -0.3181]])
***************************************
name: fc2.weight
value:
tensor([[4., 4., 4.],
        [4., 4., 4.]])
*******************************************************************************
*******************************************************************************
***************************************
name: fc1.weight
value:
tensor([[-0.3728, -0.3287,  0.2641,  0.2815, -0.2725],
        [-0.1719,  0.0949,  0.2583, -0.0935, -0.3799],
        [ 0.1419, -0.0349, -0.2499,  0.0175, -0.3181]])
***************************************
name: fc2.weight
value:
tensor([[ 0.1687,  0.5752, -0.3429],
        [-