In [1]:
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchsummary import summary
from torchvision import transforms

In [2]:
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # bare last expression so the notebook displays the selected device

  return torch._C._cuda_getDeviceCount() > 0


device(type='cpu')

In [3]:
#### FUNCTIONS DECLARATION ###

In [4]:
class Network(nn.Module):
    """Small CNN: two 5x5 conv + 2x2 max-pool stages followed by three linear layers.

    The flatten size 12 * 4 * 4 corresponds to a 1-channel 28x28 input
    (28 -> 24 -> 12 -> 8 -> 4 through the conv/pool stages).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (out_channels = number of filters,
        # kernel_size = filter size).
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        # Fully connected head ending in 10 outputs.
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        """Return the 10 raw output scores for each input in the batch ``t``."""
        # Hidden conv layers: conv -> ReLU activation -> 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

        # Hidden linear layers: flatten to 12 * 4 * 4 features per sample first.
        t = t.reshape(-1, 12 * 4 * 4)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))

        # Output layer: no activation is applied here.
        return self.out(t)

In [5]:
def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose arg-max along dim 1 equals the matching label.

    The count is returned as a plain Python number via ``.item()``.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes.eq(labels)
    return matches.sum().item()

In [6]:
def training(network, loader, optimizer, num_epochs):
    """Train ``network`` on ``loader`` for ``num_epochs`` epochs with cross-entropy loss.

    After every epoch, prints the epoch index, the number of correct
    predictions and the sum of the per-batch losses accumulated in that epoch.
    """
    for epoch in range(num_epochs):
        total_loss, total_correct = 0, 0

        # Iterating the loader visits every batch, i.e. the complete dataset.
        for images, labels in loader:
            preds = network(images)
            loss = F.cross_entropy(preds, labels)

            # Gradients are added to the previous ones unless reset each step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

        print(f'epoch: {epoch}, total_correct: {total_correct}, loss: {total_loss}')

In [7]:
def testing(network, dataset, loader):
    """Evaluate ``network`` on ``loader`` and return a summary string.

    Args:
        network: the model to evaluate; called on each batch of images.
        dataset: the dataset behind ``loader``; only ``len(dataset)`` is used,
            as the denominator of the accuracy.
        loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        A string of the form
        ``'total correct: <k> / <n>. Accuracy: <percentage>'``.
        For an empty dataset the accuracy is reported as 0.0 instead of
        raising ZeroDivisionError.
    """
    total_correct = 0

    # Evaluate in eval mode and restore the previous mode afterwards, so layers
    # such as dropout/batch-norm (if a model has them) behave deterministically
    # without affecting later training.
    was_training = network.training
    network.eval()
    try:
        # No gradients are needed for evaluation; skipping autograd bookkeeping
        # saves memory and time.
        with torch.no_grad():
            for images, labels in loader:
                predictions = network(images)
                total_correct += get_num_correct(predictions, labels)
    finally:
        network.train(was_training)

    total = len(dataset)
    accuracy = (total_correct / total) * 100 if total else 0.0
    return f'total correct: {total_correct} / {total}. Accuracy: {accuracy}'

In [8]:
def find_indices(idx):
    """Return the positions at which the 1-D mask ``idx`` is true.

    Args:
        idx: a 1-D boolean mask (tensor or anything ``torch.as_tensor`` accepts).

    Returns:
        A list of Python ints, in ascending order, one per true element.
        An empty mask yields an empty list.
    """
    # Vectorised replacement for a per-element Python loop calling `.item()`;
    # the `== True` comparison is kept so non-boolean inputs match only on 1.
    mask = torch.as_tensor(idx) == True  # noqa: E712
    return torch.nonzero(mask).flatten().tolist()


def split_dataset(dataset, num_classes=10):
    """Split ``dataset`` into one ``Subset`` per class label.

    Args:
        dataset: a dataset exposing a ``targets`` attribute with one integer
            label per sample (tensor or list).
        num_classes: number of labels to split on; labels ``0 .. num_classes-1``
            are used. Defaults to 10.

    Returns:
        A list of ``num_classes`` ``torch.utils.data.Subset`` objects; entry
        ``i`` holds the samples of ``dataset`` whose target equals ``i``
        (possibly empty).
    """
    # Read the labels from the dataset that was passed in, not from a notebook
    # global; as_tensor also accepts datasets that store targets as a list.
    targets = torch.as_tensor(dataset.targets)

    subsets = []
    for label in range(num_classes):
        indices = torch.nonzero(targets == label).flatten().tolist()
        subsets.append(torch.utils.data.Subset(dataset, indices))
    return subsets

In [9]:
def example_replay(N, network, memory_dataset, train_dataset, memory_loader, train_loader,
                   class_subsets=None, num_epochs=10, lr=0.01, batch_size=100):
    """Train ``network`` on ``train_dataset`` mixed with ``N`` replayed samples per class.

    Args:
        N: number of samples drawn at random from each per-class subset.
        network: the model to train; it is updated in place.
        memory_dataset: dataset used to report accuracy on the replayed task.
        train_dataset: dataset of the new task; it is appended in full.
        memory_loader: loader over ``memory_dataset`` used for evaluation.
        train_loader: loader over ``train_dataset`` used for evaluation.
        class_subsets: iterable of per-class datasets to sample from. When
            omitted, the notebook-level ``digits`` variable is used, which
            must then already be defined.
        num_epochs: training epochs on the mixed dataset (default 10).
        lr: Adam learning rate (default 0.01).
        batch_size: batch size of the mixed loader (default 100).

    Raises:
        ValueError: from ``random.sample`` if ``N`` exceeds the size of a
            per-class subset.
    """
    if class_subsets is None:
        # Backward-compatible fallback for callers that rely on the global.
        class_subsets = digits

    crumbs = []
    for class_subset in class_subsets:
        # Sample over every valid index; starting the range at 1 would make
        # the first element of each class impossible to draw.
        indices = random.sample(range(len(class_subset)), N)
        crumbs.append(torch.utils.data.Subset(class_subset, indices))
    crumbs.append(train_dataset)

    dirty_dataset = torch.utils.data.ConcatDataset(crumbs)
    dirty_loader = torch.utils.data.DataLoader(dirty_dataset, batch_size=batch_size, shuffle=True)
    print(f'Sto rinfrescando la memoria con {N} elementi da mnist per ogni classe. dirty: {len(dirty_dataset)}')

    # A fresh optimizer for every replay run.
    opt_replay = optim.Adam(network.parameters(), lr=lr)
    training(network, dirty_loader, opt_replay, num_epochs)

    print(f'Results for mnist: {testing(network, memory_dataset, memory_loader)}')
    print(f'Results for training-dataset: {testing(network, train_dataset, train_loader)}')
    print('    ')
In [10]:
#### DATASETS ###

In [11]:
# Preprocessing for USPS images: resize to 28x28 (the input size that
# Network's 12 * 4 * 4 flatten is built for), then convert to a tensor.
USPS_transform = transforms.Compose(
    [transforms.Resize((28, 28)), transforms.ToTensor()]
)

In [12]:
# Training splits of the three datasets, downloaded if not already on disk.
mnist = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

fashion = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# USPS uses the resizing transform so its images match the 28x28 inputs.
usps = torchvision.datasets.USPS(
    root="./data",
    train=True,
    download=True,
    transform=USPS_transform,
)


In [13]:
# Shuffled loaders with 100 samples per batch, one per dataset.
mnist_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=100, shuffle=True)
fashion_loader = torch.utils.data.DataLoader(dataset=fashion, batch_size=100, shuffle=True)
usps_loader = torch.utils.data.DataLoader(dataset=usps, batch_size=100, shuffle=True)

In [14]:
######## CREATE OUR NETWORK #####

network = Network()
optimizer = optim.Adam(params=network.parameters(), lr=0.01)

In [15]:
######## TRAIN ON MNIST #######

training(network=network, loader=mnist_loader, optimizer=optimizer, num_epochs=10)

epoch: 0, total_correct: 56063, loss: 126.12270255759358
epoch: 1, total_correct: 58590, loss: 48.19147628429346
epoch: 2, total_correct: 58742, loss: 42.75087843975052
epoch: 3, total_correct: 58888, loss: 38.25904893380357
epoch: 4, total_correct: 58951, loss: 38.96840668073855
epoch: 5, total_correct: 59012, loss: 36.07952815084718
epoch: 6, total_correct: 59060, loss: 35.85998096416006
epoch: 7, total_correct: 59093, loss: 34.70628324430436
epoch: 8, total_correct: 59124, loss: 33.69260239775758
epoch: 9, total_correct: 59090, loss: 34.65646258946799


In [16]:
#### SAVE THIS NETWORK #####

# torch.save does not create missing directories, so make sure PATHS/ exists;
# otherwise this cell fails on a fresh checkout / Restart & Run All.
os.makedirs('PATHS', exist_ok=True)
torch.save(network.state_dict(), 'PATHS/mnist_trained.pth')

In [17]:
#### TEST ON MNIST/USPS/FASHION

# Accuracy of the MNIST-trained network on each dataset.
print('mnist: ', testing(network, mnist, mnist_loader))
print('usps: ', testing(network, usps, usps_loader))
print('fashion: ', testing(network, fashion, fashion_loader))

mnist:  total correct: 59317 / 60000. Accuracy: 98.86166666666666
usps:  total correct: 3260 / 7291. Accuracy: 44.712659443149086
fashion:  total correct: 6610 / 60000. Accuracy: 11.016666666666666


In [21]:
#### TRAIN ON USPS WITHOUT MEMORY ###

# Continues from the MNIST-trained network, reusing the same optimizer object.
training(network=network, loader=usps_loader, optimizer=optimizer, num_epochs=10)

epoch: 0, total_correct: 6610, loss: 24.15700989216566
epoch: 1, total_correct: 6999, loss: 10.273860404267907
epoch: 2, total_correct: 7097, loss: 6.829998582135886
epoch: 3, total_correct: 7115, loss: 6.304347493685782
epoch: 4, total_correct: 7174, loss: 4.059767778497189
epoch: 5, total_correct: 7168, loss: 4.262478807242587
epoch: 6, total_correct: 7181, loss: 3.609099432011135
epoch: 7, total_correct: 7221, loss: 2.2873658523894846
epoch: 8, total_correct: 7239, loss: 1.772407375217881
epoch: 9, total_correct: 7241, loss: 1.6851980885548983


In [22]:
### TEST ON MNIST/USPS ###
print('mnist: ', testing(network, mnist, mnist_loader))
print('usps: ', testing(network, usps, usps_loader))

mnist:  total correct: 51287 / 60000. Accuracy: 85.47833333333334
usps:  total correct: 7251 / 7291. Accuracy: 99.45137841174049


In [24]:
### RELOAD THE NETWORK SAVED AFTER MNIST-ONLY TRAINING AND TRAIN ON USPS WITH A MEMORY OF MNIST ###
### our goal is to improve on the 85% MNIST accuracy ###
digits = split_dataset(mnist)
network_mnist = Network()
for N in (1, 2, 5, 10, 50, 100, 500, 1000, 2000, 5000):
    # Every run restarts from the MNIST-only checkpoint.
    network_mnist.load_state_dict(torch.load('PATHS/mnist_trained.pth'))
    example_replay(
        N,
        network=network_mnist,
        memory_dataset=mnist,
        train_dataset=usps,
        memory_loader=mnist_loader,
        train_loader=usps_loader,
    )

Sto rinfrescando la memoria con 1 elementi da mnist per ogni classe. dirty: 7301
epoch: 0, total_correct: 6572, loss: 24.672026626765614
epoch: 1, total_correct: 7026, loss: 9.001401002286002
epoch: 2, total_correct: 7097, loss: 7.0303090792149305
epoch: 3, total_correct: 7135, loss: 5.849549734033644
epoch: 4, total_correct: 7094, loss: 7.842577479314059
epoch: 5, total_correct: 7185, loss: 3.8203589333425043
epoch: 6, total_correct: 7186, loss: 4.045394099961413
epoch: 7, total_correct: 7182, loss: 4.410895082866773
epoch: 8, total_correct: 7204, loss: 2.733721412077955
epoch: 9, total_correct: 7231, loss: 2.408134952536784
Results for mnist: total correct: 49979 / 60000. Accuracy: 83.29833333333333
Results for training-dataset: total correct: 7227 / 7291. Accuracy: 99.1222054587848
    
Sto rinfrescando la memoria con 2 elementi da mnist per ogni classe. dirty: 7311
epoch: 0, total_correct: 6621, loss: 23.82147867232561
epoch: 1, total_correct: 7025, loss: 10.230797497555614
epoch: 

In [25]:
########################################################################################################

In [26]:
# LET'S TRY THE SAME THING WITH FASHION-MNIST ####

In [28]:
# Restart from the MNIST-only checkpoint and train on Fashion-MNIST with a
# fresh Adam optimizer.
network_mnist.load_state_dict(torch.load('PATHS/mnist_trained.pth'))
optimiz = optim.Adam(params=network_mnist.parameters(), lr=0.01)
training(network=network_mnist, loader=fashion_loader, optimizer=optimiz, num_epochs=10)

epoch: 0, total_correct: 46028, loss: 390.22389698028564
epoch: 1, total_correct: 50604, loss: 256.0635912567377
epoch: 2, total_correct: 51372, loss: 232.65366527438164
epoch: 3, total_correct: 51786, loss: 222.0508035570383
epoch: 4, total_correct: 51973, loss: 215.52699786424637
epoch: 5, total_correct: 52230, loss: 208.8766010850668
epoch: 6, total_correct: 52391, loss: 205.15099634230137
epoch: 7, total_correct: 52530, loss: 201.1554946154356
epoch: 8, total_correct: 52626, loss: 200.67799077928066
epoch: 9, total_correct: 52670, loss: 196.334243029356


In [30]:
# Accuracy on both datasets after training on Fashion-MNIST.
print('mnist: ', testing(network_mnist, mnist, mnist_loader))
print('fashion: ', testing(network_mnist, fashion, fashion_loader))

mnist:  total correct: 7154 / 60000. Accuracy: 11.923333333333334
fashion:  total correct: 51958 / 60000. Accuracy: 86.59666666666666


In [31]:
# Here it is much clearer that MNIST accuracy has dropped drastically! How much will example replay improve it?

In [32]:
# Example replay on Fashion-MNIST for increasing numbers of MNIST samples per class.
for N in (1, 2, 5, 10, 50, 100, 500, 1000, 2000, 5000):
    # Every run restarts from the MNIST-only checkpoint.
    network_mnist.load_state_dict(torch.load('PATHS/mnist_trained.pth'))
    example_replay(
        N,
        network=network_mnist,
        memory_dataset=mnist,
        train_dataset=fashion,
        memory_loader=mnist_loader,
        train_loader=fashion_loader,
    )

Sto rinfrescando la memoria con 1 elementi da mnist per ogni classe. dirty: 60010
epoch: 0, total_correct: 45884, loss: 394.37612199783325
epoch: 1, total_correct: 50583, loss: 257.1860605329275
epoch: 2, total_correct: 51301, loss: 236.10310129541904
epoch: 3, total_correct: 51649, loss: 224.20018675923347
epoch: 4, total_correct: 51997, loss: 216.77026434242725
epoch: 5, total_correct: 52226, loss: 211.31244352459908
epoch: 6, total_correct: 52321, loss: 208.45472189038992
epoch: 7, total_correct: 52436, loss: 204.71128568053246
epoch: 8, total_correct: 52604, loss: 199.52684369683266
epoch: 9, total_correct: 52669, loss: 197.20265114307404
Results for mnist: total correct: 11966 / 60000. Accuracy: 19.94333333333333
Results for training-dataset: total correct: 52224 / 60000. Accuracy: 87.03999999999999
    
Sto rinfrescando la memoria con 2 elementi da mnist per ogni classe. dirty: 60020
epoch: 0, total_correct: 45535, loss: 397.87415140867233
epoch: 1, total_correct: 50512, loss: 25