In [30]:
import torch
import torch.nn as nn

In [31]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [32]:
class SimpleResNet(nn.Module):
    def __init__(self):
        super(SimpleResNet, self).__init__()

        self.relu = nn.ReLU()

        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )

        self.block11 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16)
        )

        self.block12 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16)
        )

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2, bias=False)

        self.block21 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(32)
        )

        self.block22 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(32)
        )

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, stride=2, bias=False)

        self.block31 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(64)
        )

        self.block32 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(64)
        )

        self.avg_pool = nn.AvgPool2d(8)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        out0 = self.conv0(x)
        out1 = self.block11(out0)
        out1 = self.block12(out1)

        res2 = self.conv2(out1)
        out2 = self.block21(out1)
        out2 = self.block22(out2)
        out2 += res2
        out2 = self.relu(out2)

        res3 = self.conv3(out2)
        out3 = self.block31(out2)
        out3 = self.block32(out3)
        out3 += res3
        out3 = self.relu(out3)

        out3 = self.avg_pool(out3)
        out3 = self.flatten(out3)
        out = self.fc(out3)

        return out

In [33]:
model = SimpleResNet().to(device)

In [34]:
optimizer = torch.optim.Adam(model.parameters())

In [35]:
checkpoint = torch.load("checkpoint/resnet_cifar10_checkpoint_epoch_1.ckpt")

In [36]:
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [37]:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [38]:
opt_list = []
epochs = 50

for i in range(1, epochs + 1):
    path = f"checkpoint/resnet_cifar10_checkpoint_epoch_{i}.ckpt"
    optimizer = torch.optim.Adam(model.parameters())
    optimizer.load_state_dict((torch.load(path))['optimizer_state_dict'])
    opt_list.append(optimizer)

In [40]:
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.001
    lr: 1.52587890625e-08
    weight_decay: 0
)

In [41]:
model

SimpleResNet(
  (relu): ReLU()
  (conv0): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (block11): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block12): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, 

In [64]:
for weight in model.parameters():
    print(weight)

Parameter containing:
tensor([[[[ 1.5131e-01,  9.5123e-03, -1.2547e-01],
          [ 4.0993e-03,  4.8594e-02, -6.2425e-02],
          [ 9.7352e-02,  6.3569e-02,  7.4559e-02]],

         [[-7.3820e-02,  8.8872e-02, -1.9021e-01],
          [-1.0396e-01,  1.0133e-01, -1.8919e-01],
          [ 1.3992e-01,  5.3560e-02, -8.1676e-02]],

         [[ 1.4386e-01, -9.3858e-02, -1.9801e-01],
          [ 1.8050e-01,  1.3746e-01, -4.0536e-02],
          [ 1.8444e-01, -1.2090e-01, -1.6080e-01]]],


        [[[-4.4122e-02, -1.2965e-02,  1.4578e-01],
          [-1.8020e-01, -1.6643e-01, -5.3410e-02],
          [ 1.5892e-01, -1.4168e-01,  1.2104e-01]],

         [[-1.9841e-02,  1.4268e-01, -1.4450e-01],
          [ 6.5984e-02,  1.5953e-01,  1.5550e-01],
          [ 1.0923e-01, -6.9869e-02, -1.3199e-02]],

         [[-6.6539e-02, -1.3752e-01,  1.2914e-02],
          [-3.5294e-02, -6.8667e-02,  1.9269e-01],
          [ 1.3438e-02,  1.1676e-02, -1.2238e-01]]],


        [[[ 1.6678e-01,  2.7475e-02,  1.4112

In [65]:
for name, weight in model.named_parameters():
    print(name, weight)

conv0.0.weight Parameter containing:
tensor([[[[ 1.5131e-01,  9.5123e-03, -1.2547e-01],
          [ 4.0993e-03,  4.8594e-02, -6.2425e-02],
          [ 9.7352e-02,  6.3569e-02,  7.4559e-02]],

         [[-7.3820e-02,  8.8872e-02, -1.9021e-01],
          [-1.0396e-01,  1.0133e-01, -1.8919e-01],
          [ 1.3992e-01,  5.3560e-02, -8.1676e-02]],

         [[ 1.4386e-01, -9.3858e-02, -1.9801e-01],
          [ 1.8050e-01,  1.3746e-01, -4.0536e-02],
          [ 1.8444e-01, -1.2090e-01, -1.6080e-01]]],


        [[[-4.4122e-02, -1.2965e-02,  1.4578e-01],
          [-1.8020e-01, -1.6643e-01, -5.3410e-02],
          [ 1.5892e-01, -1.4168e-01,  1.2104e-01]],

         [[-1.9841e-02,  1.4268e-01, -1.4450e-01],
          [ 6.5984e-02,  1.5953e-01,  1.5550e-01],
          [ 1.0923e-01, -6.9869e-02, -1.3199e-02]],

         [[-6.6539e-02, -1.3752e-01,  1.2914e-02],
          [-3.5294e-02, -6.8667e-02,  1.9269e-01],
          [ 1.3438e-02,  1.1676e-02, -1.2238e-01]]],


        [[[ 1.6678e-01,  2.74