In [1]:
from torch.optim import Adam
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import LongTensor

In [2]:
torch.manual_seed(1000)
# torch.backends.cudnn.enabled = False # disables CUDNN non-determinism
# torch.cuda.manual_seed(1000) # if we are training on gpu

<torch._C.Generator at 0x7f6e901dff48>

In [3]:
BATCH_SIZE = 10
NUM_CLASSES = 10

In [4]:
arch = "CPU"

In [5]:
class Net(nn.Module):

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, NUM_CLASSES)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Net (
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear (400 -> 120)
  (fc2): Linear (120 -> 84)
  (fc3): Linear (84 -> 10)
)


In [6]:
optimizer = Adam(net.parameters())
crit = nn.CrossEntropyLoss()

In [7]:
for epoch in range(10):
    for minibatch_num in range(10):
        batch_x = Variable(torch.randn(BATCH_SIZE, 1, 32, 32))
        batch_y = Variable(LongTensor(BATCH_SIZE).random_(NUM_CLASSES))
        optimizer.zero_grad()
        out = net(batch_x)
        loss = crit(out, batch_y)
        loss.backward()
        optimizer.step()

In [8]:
CHECKPOINT_NAME = "check2.pth.tar"

In [9]:
def save_checkpoint(state, filename=CHECKPOINT_NAME):
    torch.save(state, filename)

save_checkpoint({
    'epoch': epoch + 1,
    'minibatch': minibatch_num,
    'arch': arch,
    'state_dict': net.state_dict(),
    'optimizer' : optimizer.state_dict(),
})

### Resuming

In [3]:
checkpoint1 = torch.load("check1.pth.tar")
# start_epoch1 = checkpoint['epoch']
# net.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])

In [4]:
checkpoint2 = torch.load("check2.pth.tar")
# start_epoch1 = checkpoint['epoch']
# net.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])

In [13]:
for k in checkpoint1.keys():
    if hasattr(checkpoint1[k], "__iter__"):
        for x, y in zip(checkpoint1[k], checkpoint2[k]):
            print(x == y)
    else:
        print(x == y)

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
