In [1]:
#Saving and loading models
#Import the packages
import torch
from torch import nn
from torchvision import transforms, datasets
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

In [2]:
#Import the dataset
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.5,),(0.5,))])

trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data', train=False, transform=transform, download=True)
testloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

In [3]:
image, label = next(iter(trainloader))
image.shape

torch.Size([64, 1, 28, 28])

In [4]:
#Building a model
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 10)
        
        self.dropout = nn.Dropout(p=0.2)
        
    def forward(self, x):
        x = x.view(x.shape[0], -1)
        
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))        
        x = self.dropout(F.relu(self.fc3(x)))        
        x = self.dropout(F.relu(self.fc4(x)))        
    
        x = F.log_softmax(self.fc5(x), dim=1)
        
        return x

In [6]:
#Training the model
from torch import optim

model = Classifier()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.003)

epochs = 5
for e in range(epochs):
    running_loss = 0
    print("Pass : ", e+1, " ...")
    for images, labels in trainloader:
        optimizer.zero_grad()
        
        logps = model(images)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    else:
        validation_loss = 0
        accuracy = 0
        with torch.no_grad():
            model.eval()
            for images, labels in testloader:
                logps = model(images)
                loss = criterion(logps, labels)
                ps = torch.exp(logps)
                
                top_p, top_class = ps.topk(1, dim=1)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += torch.mean(equals.type(torch.FloatTensor))
                
                validation_loss += loss.item()
        print("Training loss : ", running_loss/len(trainloader))
        print("Test loss : ", validation_loss/len(testloader))
        print("Accuracy : ", accuracy/len(testloader))
        
        model.train()

Pass :  1  ...
Training loss :  0.6561651755688287
Test loss :  0.4650468892856702
Accuracy :  tensor(0.8310)
Pass :  2  ...
Training loss :  0.508288616326445
Test loss :  0.40141186132423406
Accuracy :  tensor(0.8572)
Pass :  3  ...
Training loss :  0.4795622202093159
Test loss :  0.3940083250752898
Accuracy :  tensor(0.8580)
Pass :  4  ...
Training loss :  0.4560041187891065
Test loss :  0.3765881781670839
Accuracy :  tensor(0.8678)
Pass :  5  ...
Training loss :  0.44298516380697933
Test loss :  0.3492640993837863
Accuracy :  tensor(0.8749)


In [10]:
#Print the model parameters
print(model)
print(model.state_dict().keys())

Classifier(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=10, bias=True)
  (dropout): Dropout(p=0.2)
)
odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias', 'fc4.weight', 'fc4.bias', 'fc5.weight', 'fc5.bias'])


In [12]:
#Saving the state of the model
torch.save(model.state_dict(), 'checkpoint.pth')

In [15]:
#Loading the state of the model
state_dict = torch.load('checkpoint.pth')
print(state_dict.keys())

odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias', 'fc4.weight', 'fc4.bias', 'fc5.weight', 'fc5.bias'])


In [29]:
#Load state dict to the network
model.load_state_dict(state_dict, strict=False)
print(model.state_dict().keys())

odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias', 'fc4.weight', 'fc4.bias', 'fc5.weight', 'fc5.bias'])
