In [1]:
from torch.optim import Adam
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import LongTensor

In [2]:
torch.manual_seed(1000)
# torch.backends.cudnn.enabled = False # disables CUDNN non-determinism
# torch.cuda.manual_seed(1000) # if we are training on gpu

<torch._C.Generator at 0x7f6e901dff48>

In [3]:
BATCH_SIZE = 10
NUM_CLASSES = 10

In [4]:
arch = "CPU"

In [5]:
class Net(nn.Module):

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, NUM_CLASSES)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Net (
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear (400 -> 120)
  (fc2): Linear (120 -> 84)
  (fc3): Linear (84 -> 10)
)


In [6]:
optimizer = Adam(net.parameters())
crit = nn.CrossEntropyLoss()

In [7]:
for epoch in range(10):
    for minibatch_num in range(10):
        batch_x = Variable(torch.randn(BATCH_SIZE, 1, 32, 32))
        batch_y = Variable(LongTensor(BATCH_SIZE).random_(NUM_CLASSES))
        optimizer.zero_grad()
        out = net(batch_x)
        loss = crit(out, batch_y)
        loss.backward()
        optimizer.step()

In [8]:
CHECKPOINT_NAME = "check2.pth.tar"

In [9]:
def save_checkpoint(state, filename=CHECKPOINT_NAME):
    torch.save(state, filename)

In [10]:
save_checkpoint({
    'epoch': epoch + 1,
    'minibatch': minibatch_num,
    'arch': arch,
    'state_dict': net.state_dict(),
    'optimizer' : optimizer.state_dict(),
})

### Resuming

In [3]:
checkpoint1 = torch.load("check1.pth.tar")
# start_epoch1 = checkpoint['epoch']
# net.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])

In [4]:
checkpoint2 = torch.load("check2.pth.tar")
# start_epoch1 = checkpoint['epoch']
# net.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])

In [12]:
for k in checkpoint1.keys():
    if hasattr(checkpoint1[k], "__iter__"):
        for x, y in zip(checkpoint1[k], checkpoint2[k]):
            print("a")
            print(x == y)
    else:
        print("a")
        print(x == y)

a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True
a
True


In [6]:
checkpoint2["state_dict"]

OrderedDict([('conv1.weight', 
              (0 ,0 ,.,.) = 
                0.0459 -0.0957 -0.1438  0.0353  0.1765
               -0.1957 -0.0230 -0.0557  0.1446  0.1391
               -0.1072 -0.1745 -0.1753 -0.0568 -0.0291
               -0.1190 -0.0982  0.1426  0.1327  0.0398
               -0.0962  0.0867  0.0861  0.1489 -0.0354
              
              (1 ,0 ,.,.) = 
               -0.1695 -0.1063 -0.1856  0.1036  0.1507
               -0.1613  0.1193  0.1555 -0.1622  0.1716
                0.1310  0.1708 -0.1047 -0.0360  0.0709
               -0.1911 -0.0087  0.1839  0.1247 -0.0659
               -0.1008  0.0836  0.0143 -0.0502  0.1418
              
              (2 ,0 ,.,.) = 
               -0.1809  0.1400  0.1396 -0.1633  0.0681
                0.0236  0.1100  0.1405  0.0320 -0.1052
                0.1340 -0.2007  0.1560  0.0405 -0.1908
                0.1879 -0.1800  0.0832 -0.0926  0.0919
               -0.1366 -0.1788  0.0962 -0.0862 -0.0380
              
            