In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

import random

from collections import OrderedDict

In [2]:
# Tensor sizes for the toy regression problem:
#   inputs  x: [BATCH_SIZE, IN_DIM]
#   hidden  h: [BATCH_SIZE, HIDDEN_DIM]  (encoder / fc1 output)
#   targets y: [BATCH_SIZE, OUT_DIM]
BATCH_SIZE, IN_DIM, HIDDEN_DIM, OUT_DIM = 4, 5, 3, 2
# Number of full-batch SGD steps taken in each training cell.
NUM_EPOCHS = 100

In [3]:
# Seed every RNG the notebook touches (PyTorch, Python's `random`, NumPy)
# with one shared value so the weight draws in the next cell are reproducible.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(43865)

In [4]:
# Reference weights for the standalone encoder and decoder, stored under the
# bare 'weight' key that a bias-free nn.Linear expects. Shapes follow
# nn.Linear's (out_features, in_features) layout. The encoder tensor is drawn
# before the decoder tensor; with the seeding above, this order fixes the values.
encoder_state_dict = OrderedDict(weight=torch.randn(HIDDEN_DIM, IN_DIM))
decoder_state_dict = OrderedDict(weight=torch.randn(OUT_DIM, HIDDEN_DIM))
print('encoder state:\n{}'.format(encoder_state_dict['weight']))
print('decoder state:\n{}'.format(decoder_state_dict['weight']))

encoder state:
tensor([[-0.4795, -0.8184, -1.7956, -0.7481,  0.4085],
        [ 0.6789,  0.0540, -0.5776,  0.4270, -0.0400],
        [ 0.8674, -0.5720, -1.3464, -0.0660,  1.0107]])
decoder state:
tensor([[-0.6348,  0.1551, -0.2081],
        [-1.1093,  0.7112, -1.0249]])


In [5]:
# Re-seed with the same value so the two draws below reproduce the
# encoder/decoder weights exactly; only the keys differ, using the
# `fc1.` / `fc2.` prefixes of the combined Model's layers.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(43865)
model_state_dict = OrderedDict([
    ('fc1.weight', torch.randn(HIDDEN_DIM, IN_DIM)),
    ('fc2.weight', torch.randn(OUT_DIM, HIDDEN_DIM)),
])
for param_name in model_state_dict:
    print('{} state:\n{}'.format(param_name, model_state_dict[param_name]))

fc1.weight state:
tensor([[-0.4795, -0.8184, -1.7956, -0.7481,  0.4085],
        [ 0.6789,  0.0540, -0.5776,  0.4270, -0.0400],
        [ 0.8674, -0.5720, -1.3464, -0.0660,  1.0107]])
fc2.weight state:
tensor([[-0.6348,  0.1551, -0.2081],
        [-1.1093,  0.7112, -1.0249]])


In [6]:
# Standalone layers: `encoder` maps IN_DIM -> HIDDEN_DIM features and
# `decoder` maps HIDDEN_DIM -> OUT_DIM. Both omit the bias, so each layer's
# entire state is a single 'weight' tensor.
encoder = nn.Linear(IN_DIM, HIDDEN_DIM, bias=False)
decoder = nn.Linear(HIDDEN_DIM, OUT_DIM, bias=False)

class Model(nn.Module):
    r"""Two-layer bias-free linear model computing ``fc2(fc1(x))``.

    It has the same structure as chaining the standalone ``encoder`` and
    ``decoder`` layers, but its parameters are registered under the ``fc1`` /
    ``fc2`` attributes, so its state dict has the keys ``'fc1.weight'`` and
    ``'fc2.weight'``.
    """

    def __init__(self, in_dim=None, hidden_dim=None, out_dim=None):
        r"""The initializer.

        Parameters
        ----------
        in_dim: int, optional
            Number of input features. Defaults to the notebook-level
            ``IN_DIM``, looked up when the model is constructed.
        hidden_dim: int, optional
            Width of the intermediate representation. Defaults to
            ``HIDDEN_DIM``.
        out_dim: int, optional
            Number of output features. Defaults to ``OUT_DIM``.
        """
        super().__init__()
        # Resolve defaults at call time rather than in the signature, so the
        # class can be defined without the notebook constants in scope.
        in_dim = IN_DIM if in_dim is None else in_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        out_dim = OUT_DIM if out_dim is None else out_dim
        self.fc1 = nn.Linear(
            in_features=in_dim, out_features=hidden_dim, bias=False
        )
        self.fc2 = nn.Linear(
            in_features=hidden_dim, out_features=out_dim, bias=False
        )

    def forward(self, x):
        r"""Implements the forward pass.

        Parameters
        ----------
        x:
            Input tensor.
            SHAPE: [B, in_dim].

        Returns
        -------
        feature (implicit):
            The tensor of features of the input.
            SHAPE: [B, out_dim].
        """
        return self.fc2(self.fc1(x))

model = Model()

# Compare parameter names: each standalone layer exposes a bare 'weight' key,
# while the combined model prefixes each key with its layer attribute name.
for template, module in (
    ('encoder state dict:\n{}', encoder),
    ('decoder state dict:\n{}', decoder),
    ('model state dict:\n{}', model),
):
    print(template.format(module.state_dict().keys()))

encoder state dict:
odict_keys(['weight'])
decoder state dict:
odict_keys(['weight'])
model state dict:
odict_keys(['fc1.weight', 'fc2.weight'])


In [7]:
# Set the states for all the models.
# `load_state_dict` defaults to strict=True, so a missing or unexpected key
# raises; the returned key report carries no information and is not stored.
# In particular, it is not bound to `_`: once user code rebinds `_`, IPython
# stops updating it with the last cell output.
for net, weights in (
    (encoder, encoder_state_dict),
    (decoder, decoder_state_dict),
    (model, model_state_dict),
):
    net.load_state_dict(state_dict=weights)

In [8]:
# One shared learning rate for all three optimizers. The split encoder/decoder
# and the combined model only take identical SGD steps (and so produce
# matching losses) if their learning rates agree.
LEARNING_RATE = 1e-3
encoder_optim = optim.SGD(params=encoder.parameters(), lr=LEARNING_RATE)
decoder_optim = optim.SGD(params=decoder.parameters(), lr=LEARNING_RATE)
model_optim = optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [9]:
# Encoder-decoder training: the two standalone layers are chained by hand and
# each one is updated by its own SGD optimizer.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(43865)
# Create data; x is drawn before y, right after seeding.
x = torch.randn(size=(BATCH_SIZE, IN_DIM))
y = torch.randn(size=(BATCH_SIZE, OUT_DIM))
# Set to training mode (`train()` switches the module in place).
encoder.train()
decoder.train()
split_optims = (encoder_optim, decoder_optim)
encoder_decoder_loss_list = []
for epoch in range(NUM_EPOCHS):
    for opt in split_optims:
        opt.zero_grad()
    y_pred = decoder(encoder(x))
    loss = (y_pred - y).pow(2).mean()
    # Scalar loss as a 0-d NumPy array, kept for the final comparison.
    loss_value = loss.detach().numpy()
    print('epoch: {}\t\tloss: {}'.format(epoch, loss_value))
    encoder_decoder_loss_list.append(loss_value)
    loss.backward()
    for opt in split_optims:
        opt.step()

epoch: 0		loss: 8.983654022216797
epoch: 1		loss: 8.795275688171387
epoch: 2		loss: 8.612491607666016
epoch: 3		loss: 8.435084342956543
epoch: 4		loss: 8.262846946716309
epoch: 5		loss: 8.09557819366455
epoch: 6		loss: 7.933091163635254
epoch: 7		loss: 7.775203704833984
epoch: 8		loss: 7.621746063232422
epoch: 9		loss: 7.472555160522461
epoch: 10		loss: 7.327475070953369
epoch: 11		loss: 7.186354160308838
epoch: 12		loss: 7.049053192138672
epoch: 13		loss: 6.915436267852783
epoch: 14		loss: 6.785372734069824
epoch: 15		loss: 6.6587395668029785
epoch: 16		loss: 6.5354180335998535
epoch: 17		loss: 6.415295124053955
epoch: 18		loss: 6.298261642456055
epoch: 19		loss: 6.184215545654297
epoch: 20		loss: 6.073055267333984
epoch: 21		loss: 5.964685440063477
epoch: 22		loss: 5.8590168952941895
epoch: 23		loss: 5.755959510803223
epoch: 24		loss: 5.655431270599365
epoch: 25		loss: 5.557348728179932
epoch: 26		loss: 5.4616379737854
epoch: 27		loss: 5.368220329284668
epoch: 28		loss: 5.27702856063

In [10]:
# Combined-model training: same seeds, data draws, and loss as the
# encoder-decoder cell, but a single `Model` updated by a single optimizer.
torch.manual_seed(43865)
random.seed(43865)
np.random.seed(43865)
# Create data; after re-seeding these draws reproduce `x` / `y`.
x_model = torch.randn(size=(BATCH_SIZE, IN_DIM))
y_model = torch.randn(size=(BATCH_SIZE, OUT_DIM))
# Set to training mode.
model = model.train()
model_loss_list = []
for epoch in range(NUM_EPOCHS):
    model_optim.zero_grad()
    y_pred_model = model(x_model)
    loss_model = torch.mean((y_pred_model - y_model)**2)
    # Scalar loss as a 0-d NumPy array, kept for the final comparison.
    loss_model_value = loss_model.detach().numpy()
    print('epoch: {}\t\tloss: {}'.format(epoch, loss_model_value))
    model_loss_list.append(loss_model_value)
    loss_model.backward()
    model_optim.step()

epoch: 0		loss: 8.983654022216797
epoch: 1		loss: 8.795275688171387
epoch: 2		loss: 8.612491607666016
epoch: 3		loss: 8.435084342956543
epoch: 4		loss: 8.262846946716309
epoch: 5		loss: 8.09557819366455
epoch: 6		loss: 7.933091163635254
epoch: 7		loss: 7.775203704833984
epoch: 8		loss: 7.621746063232422
epoch: 9		loss: 7.472555160522461
epoch: 10		loss: 7.327475070953369
epoch: 11		loss: 7.186354160308838
epoch: 12		loss: 7.049053192138672
epoch: 13		loss: 6.915436267852783
epoch: 14		loss: 6.785372734069824
epoch: 15		loss: 6.6587395668029785
epoch: 16		loss: 6.5354180335998535
epoch: 17		loss: 6.415295124053955
epoch: 18		loss: 6.298261642456055
epoch: 19		loss: 6.184215545654297
epoch: 20		loss: 6.073055267333984
epoch: 21		loss: 5.964685440063477
epoch: 22		loss: 5.8590168952941895
epoch: 23		loss: 5.755959510803223
epoch: 24		loss: 5.655431270599365
epoch: 25		loss: 5.557348728179932
epoch: 26		loss: 5.4616379737854
epoch: 27		loss: 5.368220329284668
epoch: 28		loss: 5.27702856063

In [12]:
# Sanity checks: the split encoder/decoder and the combined model must have
# seen identical data and produced bit-identical losses at every epoch.
# `torch.equal` also requires equal shapes (`torch.eq` would broadcast).
data_match = torch.equal(x, x_model) and torch.equal(y, y_model)
losses_match = len(encoder_decoder_loss_list) == len(model_loss_list) and all(
    split_loss == model_loss
    for split_loss, model_loss in zip(encoder_decoder_loss_list, model_loss_list)
)
print('data created in both cases is the same: {}'.format(data_match))
print('losses from both the cases match: {}'.format(losses_match))
# Fail loudly rather than only printing, and confirm the trained weights
# ended up identical too, which is what the comparison is meant to show.
assert data_match, 'encoder-decoder and model runs used different data'
assert losses_match, 'encoder-decoder and model loss curves differ'
assert torch.equal(encoder.weight, model.fc1.weight) and torch.equal(
    decoder.weight, model.fc2.weight
), 'trained weights differ between the two setups'

data created in both cases is the same: True
losses from both the cases match: True
