In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchsummary import summary

from colorama import Fore, Style

import time
import random

import numpy as np
import matplotlib.pyplot as plt

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [23]:
#### FUNCTIONS DECLARATION ###

In [24]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6,
                               kernel_size=5)  # kernel = filter size. #out = number of filters
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        # hidden conv layers
        t = self.conv1(t)
        t = F.relu(t)  # activation function
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # hidden linear layers.
        t = t.reshape(-1, 12 * 4 * 4)
        t = self.fc1(t)
        t = F.relu(t)

        t = self.fc2(t)
        t = F.relu(t)

        # output layer
        t = self.out(t)

        return t

In [25]:
def get_num_correct(preds, labels): 
    return preds.argmax(dim=1).eq(labels).sum().item()

In [26]:
def training(network, dataset, loader, lr, num_epochs, test=None):
    optimizer = optim.Adam(network.parameters(), lr=lr)
    for epoch in range(num_epochs):

        total_loss = 0
        total_correct = 0

        for batch in loader:
            images, labels = batch

            preds = network(images)
            loss = F.cross_entropy(preds, labels)

            optimizer.zero_grad()  # gradient must be reset every time, otherwise it's added to the previous one
            loss.backward(retain_graph=True)
            optimizer.step()

            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

        accuracy = (total_correct / len(dataset)) * 100
        print(f'epoch: {epoch}, loss: {total_loss}, total_correct: {total_correct} / {len(dataset)}, --> {Fore.LIGHTCYAN_EX}Accuracy: {accuracy}{Style.RESET_ALL}')
        if test is not None:
            for t in test:
                print(f"\t\t\t\t {Fore.LIGHTGREEN_EX}Testing back... {testing(network, t[0], t[1])}{Style.RESET_ALL}")


In [27]:
def testing(network, dataset, loader):
    total_correct = 0
    with torch.no_grad():
        for batch in loader:
            images, labels = batch
            predictions = network(images)
            correct = get_num_correct(predictions, labels)
            total_correct += correct
        return (
            f'total correct: {total_correct} / {len(dataset)}. {Fore.LIGHTMAGENTA_EX}Accuracy: {(total_correct / len(dataset)) * 100}{Style.RESET_ALL}')

In [28]:
def find_indices(idx):
    indices = []
    for i in range(len(idx)):
        if idx[i].item() == True:
            indices.append(i)
    return indices


def split_dataset(dataset):
    subsets = []
    for i in range(10):
        idx = mnist.targets==i
        indices = find_indices(idx)
        subset = torch.utils.data.Subset(dataset, indices)
        #print('subset:', i, 'len: ', len(subset))
        subsets.append(subset)
    return subsets

In [37]:
def example_replay(N, digits, network, memory_dataset, train_dataset, memory_loader, train_loader):
    crumbs = []
    for digit in digits:
        l = len(digit)
        indices = random.sample(range(1,l), N)
        crumb = torch.utils.data.Subset(digit, indices)
        crumbs.append(crumb)
    crumbs.append(train_dataset)
    dirty_dataset = torch.utils.data.ConcatDataset(crumbs)
    dirty_loader = torch.utils.data.DataLoader(dirty_dataset, batch_size=100, shuffle=True)
    print(f'Sto rinfrescando la memoria con {N} elementi da mnist per ogni classe. dirty: {len(dirty_dataset)}')
    training(network, dirty_dataset, dirty_loader, lr=1e-3, num_epochs=10)
    print(f'Results for mnist: {testing(network, mnist, mnist_loader)}') 
    print(f'Results for usps: {testing(network, usps, usps_loader)}')
    print(f'Results for svhn: {testing(network, svhn, svhn_loader)}')
    print('    ')

In [30]:
#### DATASETS ###

In [31]:
USPS_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])


SVHN_transform = transforms.Compose([transforms.Resize((28, 28)),
                                    transforms.ToTensor(),
                                    transforms.Grayscale(num_output_channels=1) ])

In [32]:
mnist = torchvision.datasets.MNIST(
                        root='./data'
                       ,train=True
                       ,download=True
                       ,transform = transforms.Compose([transforms.ToTensor()])
                        )

usps = torchvision.datasets.USPS("./data"
                     , train=True
                     , download=True
                     , transform = USPS_transform
                    )

svhn = torchvision.datasets.SVHN(root='./data' ,
                                 split='train' ,
                                 transform=SVHN_transform,
                                 download=True)


Using downloaded and verified file: ./data\train_32x32.mat


In [33]:
mnist_loader = torch.utils.data.DataLoader(mnist, batch_size = 100, shuffle=True)
usps_loader = torch.utils.data.DataLoader(usps, batch_size = 100, shuffle=True)
svhn_loader = torch.utils.data.DataLoader(svhn, batch_size=100, shuffle=True)

In [14]:
######## CREIAMO IL NOSTRO NETWORK #####

network = Network()

In [15]:
######## FACCIAMO IL TRAINING SU MNIST #######

training(network, mnist, mnist_loader, lr=0.01, num_epochs=10)

epoch: 0, loss: 121.96751784067601, total_correct: 56114 / 60000, --> [96mAccuracy: 93.52333333333334[0m
epoch: 1, loss: 46.97961293812841, total_correct: 58612 / 60000, --> [96mAccuracy: 97.68666666666667[0m
epoch: 2, loss: 39.94442648696713, total_correct: 58832 / 60000, --> [96mAccuracy: 98.05333333333334[0m
epoch: 3, loss: 38.81191668700194, total_correct: 58916 / 60000, --> [96mAccuracy: 98.19333333333333[0m
epoch: 4, loss: 35.137164528656285, total_correct: 59031 / 60000, --> [96mAccuracy: 98.385[0m
epoch: 5, loss: 35.94039195960795, total_correct: 59064 / 60000, --> [96mAccuracy: 98.44000000000001[0m
epoch: 6, loss: 34.47021762432996, total_correct: 59099 / 60000, --> [96mAccuracy: 98.49833333333333[0m
epoch: 7, loss: 35.01557728933403, total_correct: 59065 / 60000, --> [96mAccuracy: 98.44166666666668[0m
epoch: 8, loss: 34.733114582893904, total_correct: 59128 / 60000, --> [96mAccuracy: 98.54666666666667[0m
epoch: 9, loss: 34.38552636926761, total_correct: 5919

In [16]:
#### SALVIAMO QUESTA RETE #####

torch.save(network.state_dict(), 'PATHS/mnist_trained.pth')

In [17]:
#### FACCIAMO IL TEST SU MNIST/USPS/FASHION

print('mnist: ', (testing(network, mnist, mnist_loader)))
print('usps: ', (testing(network, usps, usps_loader))) #forward training

mnist:  total correct: 59264 / 60000. [95mAccuracy: 98.77333333333334[0m
usps:  total correct: 3180 / 7291. [95mAccuracy: 43.61541626663009[0m


In [18]:
#### FACCIAMO IL FINE-TUNING SU USPS SENZA MEMORIA ###

training(network, usps, usps_loader, lr=1e-3, num_epochs=10)

epoch: 0, loss: 55.63860809803009, total_correct: 5629 / 7291, --> [96mAccuracy: 77.20477300781786[0m
epoch: 1, loss: 20.49654772132635, total_correct: 6716 / 7291, --> [96mAccuracy: 92.11356466876973[0m
epoch: 2, loss: 14.65783029422164, total_correct: 6882 / 7291, --> [96mAccuracy: 94.39034426004663[0m
epoch: 3, loss: 11.495309740304947, total_correct: 6970 / 7291, --> [96mAccuracy: 95.59731175421753[0m
epoch: 4, loss: 9.48992214165628, total_correct: 7010 / 7291, --> [96mAccuracy: 96.14593334247704[0m
epoch: 5, loss: 7.984057238325477, total_correct: 7052 / 7291, --> [96mAccuracy: 96.7219860101495[0m
epoch: 6, loss: 6.71717047598213, total_correct: 7097 / 7291, --> [96mAccuracy: 97.33918529694144[0m
epoch: 7, loss: 5.713647751137614, total_correct: 7129 / 7291, --> [96mAccuracy: 97.77808256754903[0m
epoch: 8, loss: 4.912942534312606, total_correct: 7153 / 7291, --> [96mAccuracy: 98.10725552050474[0m
epoch: 9, loss: 4.240828321315348, total_correct: 7171 / 7291, -->

In [19]:
### FACCIAMO IL TEST SU MNIST/USPS ###
print('mnist: ', (testing(network, mnist, mnist_loader))) #forgetting
print('usps: ', (testing(network, usps, usps_loader)))

mnist:  total correct: 55870 / 60000. [95mAccuracy: 93.11666666666667[0m
usps:  total correct: 7191 / 7291. [95mAccuracy: 98.62844602935125[0m


In [20]:
### RIPRENDIAMO LA RETE MEMORIZZATA AL SOLO MNIST TRAINED E PROVIAMO A FARE IL TRAINING SU USPS CON MEMORIA DI MNIST ###
### il nostro obiettivo sarà quello di migliorare l'accuratezza del mnist dell' 85% ###
digits = split_dataset(mnist)
network_mnist = Network()
for N in (1,2,5,10,50,100,500,1000,2000,5000):
    network_mnist.load_state_dict(torch.load('PATHS/mnist_trained.pth'))
    example_replay(N, network_mnist, mnist, usps, mnist_loader, usps_loader)

Sto rinfrescando la memoria con 1 elementi da mnist per ogni classe. dirty: 7301
epoch: 0, loss: 55.07654559612274, total_correct: 5689 / 7301, --> [96mAccuracy: 77.92083276263526[0m
epoch: 1, loss: 20.639145158228985, total_correct: 6718 / 7301, --> [96mAccuracy: 92.01479249417888[0m
epoch: 2, loss: 14.579185705620148, total_correct: 6905 / 7301, --> [96mAccuracy: 94.57608546774414[0m
epoch: 3, loss: 11.440479651093483, total_correct: 6973 / 7301, --> [96mAccuracy: 95.50746473085879[0m
epoch: 4, loss: 9.402980236336582, total_correct: 7025 / 7301, --> [96mAccuracy: 96.2196959320641[0m
epoch: 5, loss: 7.881992489099503, total_correct: 7069 / 7301, --> [96mAccuracy: 96.82235310231475[0m
epoch: 6, loss: 6.535078125074506, total_correct: 7105 / 7301, --> [96mAccuracy: 97.31543624161074[0m
epoch: 7, loss: 5.4815196506640405, total_correct: 7138 / 7301, --> [96mAccuracy: 97.76742911929873[0m
epoch: 8, loss: 4.730538428528234, total_correct: 7163 / 7301, --> [96mAccuracy: 98

In [None]:
##### EXAMPLE REPLAY SU 3 TASK ####

In [39]:
network2 = Network()
combination = []
combination.append(mnist)
combination.append(usps)
joint = torch.utils.data.ConcatDataset(combination)
joint_loader = torch.utils.data.DataLoader(joint, batch_size=100, shuffle=True)

In [40]:
#### TRAINING SU MNIST E USPS INSIEME ####
training(network2, joint, joint_loader, lr=0.01, num_epochs=10)
torch.save(network2.state_dict(), 'PATHS/combined_mnist+usps.pth')

epoch: 0, loss: 132.13256027456373, total_correct: 63147 / 67291, --> [96mAccuracy: 93.84167273483824[0m
epoch: 1, loss: 58.16816308256239, total_correct: 65603 / 67291, --> [96mAccuracy: 97.49149217577387[0m
epoch: 2, loss: 55.49535472341813, total_correct: 65774 / 67291, --> [96mAccuracy: 97.74561234043185[0m
epoch: 3, loss: 46.918476193910465, total_correct: 66027 / 67291, --> [96mAccuracy: 98.12159129749892[0m
epoch: 4, loss: 48.134905279031955, total_correct: 65987 / 67291, --> [96mAccuracy: 98.06214798412863[0m
epoch: 5, loss: 46.86761846009176, total_correct: 66041 / 67291, --> [96mAccuracy: 98.14239645717852[0m
epoch: 6, loss: 41.7804994264734, total_correct: 66188 / 67291, --> [96mAccuracy: 98.36085063381434[0m
epoch: 7, loss: 43.846964185882825, total_correct: 66132 / 67291, --> [96mAccuracy: 98.27762999509592[0m
epoch: 8, loss: 47.02036430872977, total_correct: 66113 / 67291, --> [96mAccuracy: 98.24939442124504[0m
epoch: 9, loss: 40.95900429273024, total_co

In [41]:
digits2 = split_dataset(joint)
for N in (1,2,5,10,50,100,500,1000,2000,5000):
    network2.load_state_dict(torch.load('PATHS/mnist_trained.pth'))
    example_replay(N, digits2, network2, joint, svhn, joint_loader, svhn_loader)

Sto rinfrescando la memoria con 1 elementi da mnist per ogni classe. dirty: 73267
epoch: 0, loss: 1326.5553991794586, total_correct: 28482 / 73267, --> [96mAccuracy: 38.87425443924277[0m
epoch: 1, loss: 804.1493139266968, total_correct: 47763 / 73267, --> [96mAccuracy: 65.19033125417991[0m
epoch: 2, loss: 628.4870836734772, total_correct: 53784 / 73267, --> [96mAccuracy: 73.40821925286964[0m
epoch: 3, loss: 545.6923441588879, total_correct: 56426 / 73267, --> [96mAccuracy: 77.01420830660462[0m
epoch: 4, loss: 493.09509029984474, total_correct: 58012 / 73267, --> [96mAccuracy: 79.17889363560676[0m
epoch: 5, loss: 457.678987711668, total_correct: 59053 / 73267, --> [96mAccuracy: 80.5997242960678[0m
epoch: 6, loss: 432.8102174401283, total_correct: 59856 / 73267, --> [96mAccuracy: 81.69571567008339[0m
epoch: 7, loss: 412.7911439239979, total_correct: 60450 / 73267, --> [96mAccuracy: 82.50644901524561[0m
epoch: 8, loss: 396.05594316124916, total_correct: 60986 / 73267, --> 