In [1]:
epochs = 6
epochs_with_freezed_params = 2
learning_rate = 1e-3

In [2]:
import torch.nn as nn
import torch
import numpy as np
from bokeh.plotting import figure
from bokeh.io import show
from bokeh.models import LinearAxis, Range1d
from itertools import chain

class ConvNet(nn.Module): 
     def __init__(self): 
         super(ConvNet, self).__init__() 
         self.acc_list = []
         self.loss_list = []
         self.layer1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2), 
            nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2))  
         self.drop_out = nn.Dropout() 
         self.fc1 = nn.Linear(14 * 14 * 32, 500) 
         self.fc2 = nn.Linear(500, 10)
         
     def forward(self, x): 
         out = self.layer1(x) 
         out = out.reshape(out.size(0), -1) 
         out = self.drop_out(out) 
         out = self.fc1(out) 
         out = self.fc2(out) 
         return out      
     
     def train_model(self, train_loader, optimizer, loss_fn, freeze_conv_layers = False, need_unfreezing_action = False):
         if freeze_conv_layers:
             self.freeze_conv_layer_params()
             
         self.train()
         for epoch in range(epochs):
            if need_unfreezing_action:
                self.try_to_unfreeze(epoch)
            for index, (images, labels) in enumerate(train_loader):
               outputs = self(images)
               loss = loss_fn(outputs, labels)
      
               optimizer.zero_grad()
               loss.backward()
               optimizer.step()

               total = labels.size(dim=0)
               _, predicted = torch.max(outputs.data, dim=1)
               correct = (predicted == labels).sum().item()

               if index % 100 == 99:
                  print(f"Epoch [{epoch + 1}/{epochs}], Step [{index + 1}/{len(train_loader)}], Loss: {loss.item()}, Accuracy: {(correct / total) * 100}%")  

       
     def test_model(self, test_loader, loss_fn):
         self.eval()
         with torch.no_grad():
            matched_labels_count = 0
            total_labels_count = 0
            for images, labels in test_loader:
                  outputs = self(images)
                  loss = loss_fn(outputs, labels)
                  
                  self.loss_list.append(loss.item())

                  _, prediction_labels = torch.max(outputs.data, dim=1)
                  total_labels_count += labels.size(dim=0)
                  matched_labels_count += (prediction_labels == labels).sum().item()
                  
                  self.acc_list.append(matched_labels_count / total_labels_count)

            print(f"Accuracy: {matched_labels_count / total_labels_count}")
            
     def visualize(self):
         p = figure(y_axis_label='Loss', width=1000, y_range=(0, 1), title='PyTorch ConvNet results')
         p.extra_y_ranges = {'Accuracy': Range1d(start=0, end=100)}
         p.add_layout(LinearAxis(y_range_name='Accuracy', axis_label='Accuracy (%)'), 'right')
         p.line(np.arange(len(self.loss_list)), self.loss_list)
         p.line(np.arange(len(self.loss_list)), np.array(self.acc_list) * 100, y_range_name='Accuracy', color='red')
         show(p)

     def freeze_conv_layer_params(self):
         for param in chain(self.layer1.parameters(), self.layer2.parameters()):
             param.requires_grad = False  

     def unfreeze_conv_layer_params(self):
         for param in chain(self.layer1.parameters(), self.layer2.parameters()):
             param.requires_grad = True 

     def try_to_unfreeze(self, current_epoch):
         if current_epoch > epochs_with_freezed_params:              
            self.unfreeze_conv_layer_params()

In [3]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize

def train_test_split_loaders(train_dataset, test_dataset, batch_size):
    return (
            DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True), 
            DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
        )

trans = Compose([ToTensor()])

train_mnist = MNIST(root="./MNIST", train=True, transform=trans, download=True)
test_mnist = MNIST(root="./MNIST", train=False, transform=trans)

train_fashion_mnist = FashionMNIST(root="./FashionMNIST", train=True, transform=trans, download=True)
test_fashion_mnist = FashionMNIST(root="./FashionMNIST", train=False, transform=trans)

batch_size = 100
train_mnist_loader, test_mnist_loader = train_test_split_loaders(train_mnist, test_mnist, batch_size)
train_fashion_mnist_loader, test_fashion_mnist_loader = train_test_split_loaders(train_fashion_mnist, test_fashion_mnist, batch_size)


In [4]:
loss_fn = nn.CrossEntropyLoss()
mnist_model = ConvNet()
mnist_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=learning_rate)

In [5]:
fashion_mnist_model = ConvNet()
fashion_mnist_optimizer = torch.optim.Adam(fashion_mnist_model.parameters(), lr=learning_rate)

In [6]:

mnist_model.train_model(train_mnist_loader, mnist_optimizer, loss_fn)
mnist_model.test_model(test_mnist_loader, loss_fn)
torch.save(mnist_model.state_dict(), "./MNIST_model")
mnist_model.visualize()

Epoch [1/6], Step [100/600], Loss: 0.3662891387939453, Accuracy: 88.0%
Epoch [1/6], Step [200/600], Loss: 0.17334100604057312, Accuracy: 97.0%


KeyboardInterrupt: 

In [None]:
fashion_mnist_model.train_model(train_fashion_mnist_loader, fashion_mnist_optimizer, loss_fn)
fashion_mnist_model.test_model(test_fashion_mnist_loader, loss_fn)
torch.save(fashion_mnist_model.state_dict(), "./FashionMNIST_model")
fashion_mnist_model.visualize()

Epoch [1/6], Step [100/600], Loss: 0.649351954460144, Accuracy: 75.0%
Epoch [1/6], Step [200/600], Loss: 0.5036018490791321, Accuracy: 80.0%
Epoch [1/6], Step [300/600], Loss: 0.5208143591880798, Accuracy: 77.0%
Epoch [1/6], Step [400/600], Loss: 0.5240964293479919, Accuracy: 81.0%
Epoch [1/6], Step [500/600], Loss: 0.4630342125892639, Accuracy: 79.0%
Epoch [1/6], Step [600/600], Loss: 0.3398512601852417, Accuracy: 84.0%
Epoch [2/6], Step [100/600], Loss: 0.3794878423213959, Accuracy: 83.0%
Epoch [2/6], Step [200/600], Loss: 0.46919986605644226, Accuracy: 86.0%
Epoch [2/6], Step [300/600], Loss: 0.6601216793060303, Accuracy: 81.0%
Epoch [2/6], Step [400/600], Loss: 0.5186206698417664, Accuracy: 82.0%
Epoch [2/6], Step [500/600], Loss: 0.2892613410949707, Accuracy: 92.0%
Epoch [2/6], Step [600/600], Loss: 0.48812881112098694, Accuracy: 83.0%
Epoch [3/6], Step [100/600], Loss: 0.2348342388868332, Accuracy: 91.0%
Epoch [3/6], Step [200/600], Loss: 0.2664889395236969, Accuracy: 89.0%
Epoch

In [None]:
fashion_mnist_model = ConvNet()
fashion_mnist_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=learning_rate)
fashion_mnist_model.train_model(train_fashion_mnist_loader, fashion_mnist_optimizer, loss_fn)
fashion_mnist_model.test_model(test_fashion_mnist_loader, loss_fn)
torch.save(fashion_mnist_model.state_dict(), "./FashionMNIST_model_with_mnist_params")
fashion_mnist_model.visualize()

Epoch [1/6], Step [100/600], Loss: 2.308201313018799, Accuracy: 6.0%
Epoch [1/6], Step [200/600], Loss: 2.2895514965057373, Accuracy: 12.0%
Epoch [1/6], Step [300/600], Loss: 2.317331075668335, Accuracy: 9.0%
Epoch [1/6], Step [400/600], Loss: 2.3062806129455566, Accuracy: 8.0%
Epoch [1/6], Step [500/600], Loss: 2.3023929595947266, Accuracy: 12.0%
Epoch [1/6], Step [600/600], Loss: 2.3020949363708496, Accuracy: 13.0%
Epoch [2/6], Step [100/600], Loss: 2.3016440868377686, Accuracy: 9.0%
Epoch [2/6], Step [200/600], Loss: 2.306906223297119, Accuracy: 11.0%
Epoch [2/6], Step [300/600], Loss: 2.3007352352142334, Accuracy: 7.000000000000001%
Epoch [2/6], Step [400/600], Loss: 2.3010094165802, Accuracy: 13.0%
Epoch [2/6], Step [500/600], Loss: 2.311950922012329, Accuracy: 8.0%
Epoch [2/6], Step [600/600], Loss: 2.31272292137146, Accuracy: 11.0%
Epoch [3/6], Step [100/600], Loss: 2.308354616165161, Accuracy: 9.0%
Epoch [3/6], Step [200/600], Loss: 2.288602352142334, Accuracy: 11.0%
Epoch [3/6

In [None]:
fashion_mnist_model = ConvNet()
fashion_mnist_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=learning_rate)
fashion_mnist_model.train_model(train_fashion_mnist_loader, fashion_mnist_optimizer, loss_fn, freeze_conv_layers=True)
fashion_mnist_model.test_model(test_fashion_mnist_loader, loss_fn)
torch.save(fashion_mnist_model.state_dict(), "./FashionMNIST_model_with_mnist_params_freezed")
fashion_mnist_model.visualize()

Epoch [1/6], Step [100/600], Loss: 2.3373935222625732, Accuracy: 10.0%
Epoch [1/6], Step [200/600], Loss: 2.320798873901367, Accuracy: 6.0%
Epoch [1/6], Step [300/600], Loss: 2.341214179992676, Accuracy: 5.0%
Epoch [1/6], Step [400/600], Loss: 2.3510124683380127, Accuracy: 8.0%
Epoch [1/6], Step [500/600], Loss: 2.355327844619751, Accuracy: 7.000000000000001%
Epoch [1/6], Step [600/600], Loss: 2.3380839824676514, Accuracy: 15.0%
Epoch [2/6], Step [100/600], Loss: 2.32938289642334, Accuracy: 12.0%
Epoch [2/6], Step [200/600], Loss: 2.325033187866211, Accuracy: 10.0%
Epoch [2/6], Step [300/600], Loss: 2.303967237472534, Accuracy: 17.0%
Epoch [2/6], Step [400/600], Loss: 2.290449619293213, Accuracy: 15.0%
Epoch [2/6], Step [500/600], Loss: 2.293976068496704, Accuracy: 11.0%
Epoch [2/6], Step [600/600], Loss: 2.334038257598877, Accuracy: 8.0%
Epoch [3/6], Step [100/600], Loss: 2.3173398971557617, Accuracy: 11.0%
Epoch [3/6], Step [200/600], Loss: 2.3274481296539307, Accuracy: 7.00000000000

In [None]:
fashion_mnist_model = ConvNet()
fashion_mnist_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=learning_rate)
fashion_mnist_model.train_model(train_fashion_mnist_loader, fashion_mnist_optimizer, loss_fn, freeze_conv_layers=True, \
                                need_unfreezing_action=True)
fashion_mnist_model.test_model(test_fashion_mnist_loader, loss_fn)
torch.save(fashion_mnist_model.state_dict(), "./FashionMNIST_model_with_mnist_params_unfreezed")
fashion_mnist_model.visualize()

Epoch [1/6], Step [100/600], Loss: 2.3158626556396484, Accuracy: 7.000000000000001%
Epoch [1/6], Step [200/600], Loss: 2.312411069869995, Accuracy: 6.0%
Epoch [1/6], Step [300/600], Loss: 2.302726984024048, Accuracy: 8.0%
Epoch [1/6], Step [400/600], Loss: 2.3084301948547363, Accuracy: 6.0%
Epoch [1/6], Step [500/600], Loss: 2.295276403427124, Accuracy: 11.0%
Epoch [1/6], Step [600/600], Loss: 2.296238422393799, Accuracy: 12.0%
Epoch [2/6], Step [100/600], Loss: 2.307976722717285, Accuracy: 7.000000000000001%
Epoch [2/6], Step [200/600], Loss: 2.308424472808838, Accuracy: 6.0%
Epoch [2/6], Step [300/600], Loss: 2.3003296852111816, Accuracy: 13.0%
Epoch [2/6], Step [400/600], Loss: 2.295710563659668, Accuracy: 9.0%
Epoch [2/6], Step [500/600], Loss: 2.329542636871338, Accuracy: 2.0%
Epoch [2/6], Step [600/600], Loss: 2.3218984603881836, Accuracy: 6.0%
Epoch [3/6], Step [100/600], Loss: 2.314357042312622, Accuracy: 7.000000000000001%
Epoch [3/6], Step [200/600], Loss: 2.314119577407837, 