In [None]:
import torch
import random
import numpy as np

# Seed every random number generator used in this notebook from one value so
# that a fresh "Restart & Run All" starts from the same RNG state each time.
SEED = 0

# Order matters only for readability: Python, NumPy, PyTorch (CPU), PyTorch (CUDA).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)

# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
import torchvision.datasets
from torchvision import transforms

def _perturb_test_image(image):
    """Return ``image`` plus a random perturbation scaled by 0.6.

    NOTE(review): ``torch.normal(image)`` samples with *mean equal to the
    image itself* (the std argument is left at its default), so the result is
    not plain zero-mean additive noise. If pure Gaussian noise was intended,
    ``torch.randn_like(image)`` would be the usual choice -- confirm before
    changing, since it alters every test metric.
    """
    return image + torch.normal(image) * 0.6


# Training images are only converted to tensors; no Normalize step is applied.
train_transform = transforms.Compose([transforms.ToTensor()])

# Test images are converted to tensors and then randomly perturbed, so the
# model is evaluated on corrupted inputs it never saw during training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_perturb_test_image),
])

batch_size = 100

# MNIST is downloaded into the current directory on first use.
train_dataset = torchvision.datasets.MNIST(
    root='./', train=True, download=True, transform=train_transform)
test_dataset = torchvision.datasets.MNIST(
    root='./', train=False, download=True, transform=test_transform)

# Shuffle only the training stream; keep the test order fixed.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import torchvision

    
# Show the first batch produced by the test loader as an image grid with
# 10 images per row. Only one batch is pulled from the loader.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    batch_images, _batch_labels = first_batch
    image_grid = torchvision.utils.make_grid(batch_images, nrow=10)

    plt.figure(figsize=(10, 10))
    # The grid is channel-first; imshow expects the channel axis last.
    plt.imshow(image_grid.permute(1, 2, 0).numpy())
    plt.show()

In [None]:
class LeNet5(torch.nn.Module):
    """LeNet-5 style convolutional classifier with Tanh activations.

    Two conv -> Tanh -> average-pool stages feed three fully connected
    layers. The first linear layer expects 16 feature maps of 5x5, which is
    what a single-channel 28x28 input produces after the two stages. The
    forward pass returns the raw 10-way output of the last linear layer
    (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Layers are registered in the same order as they are applied, which
        # also fixes the order of the seeded weight initialisation.

        # Stage 1: 1 -> 6 channels; padding=2 keeps the spatial size, pooling halves it.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.act1 = nn.Tanh()
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)

        # Stage 2: 6 -> 16 channels; unpadded 5x5 conv trims 4 pixels, pooling halves again.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0)
        self.act2 = nn.Tanh()
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)

        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(5 * 5 * 16, 120)
        self.act3 = nn.Tanh()
        self.fc2 = nn.Linear(120, 84)
        self.act4 = nn.Tanh()
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a 4-D batch and return one 10-value row per sample."""
        feature_extractor = (
            self.conv1, self.act1, self.pool1,
            self.conv2, self.act2, self.pool2,
        )
        for layer in feature_extractor:
            x = layer(x)

        # Collapse (channels, height, width) into one feature axis per sample.
        batch, channels, height, width = x.shape
        x = x.view(batch, channels * height * width)

        hidden = self.act3(self.fc1(x))
        hidden = self.act4(self.fc2(hidden))
        return self.fc3(hidden)
    
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

# Build the network and move its parameters to the selected device before
# the optimizer captures references to them.
model = LeNet5().to(device)

# Cross-entropy on the raw network outputs, optimised with Adam.
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1.0e-3)

In [None]:
import time
import copy


def train_model(model, loss, optimizer, scheduler=None, epoch_number=10,
                dataloaders=None, device=None):
    """Train ``model`` and evaluate it on the test loader after every epoch.

    Args:
        model: ``torch.nn.Module`` to optimise in place.
        loss: callable ``loss(preds, labels)`` returning a scalar tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        scheduler: optional LR scheduler. It is stepped after every
            optimizer step (i.e. once per training batch), not once per epoch.
        epoch_number: number of train+test passes to run.
        dataloaders: optional mapping with ``'train'`` and ``'test'`` loaders.
            When omitted, the notebook-level ``train_dataloader`` and
            ``test_dataloader`` are used, as before.
        device: optional device the batches are moved to. When omitted, the
            device of the model's first parameter is used (CPU if the model
            has no parameters).

    Returns:
        ``(model, best_model_state, history)`` where ``model`` carries the
        weights of the *last* epoch, ``best_model_state`` is a deep copy of
        the ``state_dict`` from the epoch with the highest test accuracy
        (the initial weights if accuracy never rose above 0), and
        ``history`` maps ``'Train'``/``'Test'`` to per-epoch ``'Loss'`` and
        ``'Acc'`` lists of plain Python floats.

    Note:
        Loss and accuracy are averaged over batches, so a smaller final batch
        is weighted the same as a full one.
    """
    if dataloaders is None:
        # Backward-compatible default: fall back to the notebook globals.
        dataloaders = {'train': train_dataloader, 'test': test_dataloader}
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    history = {'Test' : {'Loss' : [], 'Acc' : []}, 'Train' : {'Loss' : [], 'Acc' : []}}
    best_acc = 0.
    best_epoch = 0
    best_model_state = copy.deepcopy(model.state_dict())

    since = time.time()

    for epoch in range(epoch_number):
        print('Epoch  {}  of  {} :'.format(epoch + 1, epoch_number))

        # Each epoch has a training pass followed by an evaluation pass.
        for phase in ['train', 'test']:
            is_train = phase == 'train'
            dataloader = dataloaders[phase]
            if is_train:
                model.train()  # enable training-mode layer behaviour
            else:
                model.eval()   # switch layers to inference behaviour

            running_loss = 0.
            running_acc = 0.

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(is_train):
                    if is_train:
                        optimizer.zero_grad()

                    preds = model(inputs)
                    loss_value = loss(preds, labels)
                    preds_class = preds.argmax(dim=1)

                    if is_train:
                        loss_value.backward()
                        optimizer.step()
                        if scheduler is not None:
                            scheduler.step()

                # Accumulate plain floats: keeping tensors here would leave
                # (possibly CUDA) tensors in `history`, which cannot be
                # plotted directly and hold device memory alive.
                running_loss += loss_value.item()
                running_acc += (preds_class == labels).float().mean().item()

            epoch_loss = running_loss / len(dataloader)
            epoch_acc = running_acc / len(dataloader)

            if is_train:
                train_loss = epoch_loss
                train_acc = epoch_acc
            else:
                test_loss = epoch_loss
                test_acc = epoch_acc
                # Snapshot the weights whenever test accuracy improves.
                if test_acc > best_acc:
                    best_acc = test_acc
                    best_epoch = epoch + 1
                    best_model_state = copy.deepcopy(model.state_dict())

        print('Train:   Loss:  {:4f}   Acc:  {:4f}'.format(train_loss, train_acc))
        print('Test:    Loss:  {:4f}   Acc:  {:4f}'.format(test_loss, test_acc))
        print()
        print('-' * 25)
        print()

        history['Test']['Loss'].append(test_loss)
        history['Test']['Acc'].append(test_acc)
        history['Train']['Loss'].append(train_loss)
        history['Train']['Acc'].append(train_acc)

    time_elapsed = time.time() - since
    print('\n')
    print('Training time:  {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best test Acc:   {:4f}'.format(best_acc))
    print('Best epoch:     {:d}'.format(best_epoch))

    return model, best_model_state, history

In [None]:
# Train for 5 epochs without an LR scheduler. `model` is rebound to the
# returned (same) module; the best test-accuracy weights come back separately.
model, best_model_state, history = train_model(
    model=model,
    loss=loss,
    optimizer=optimizer,
    scheduler=None,
    epoch_number=5,
)

In [None]:
import matplotlib.pyplot as plt

# Plot per-epoch loss and accuracy side by side.
# The previous layout drew into the last two cells of a 2x3 grid, which left
# small panels and empty space, and the curve labels were never shown because
# no legend was requested.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

for metric, ax in (('Loss', loss_ax), ('Acc', acc_ax)):
    for split in ('Train', 'Test'):
        # float() accepts both plain numbers and 0-dim tensors (including
        # CUDA tensors), which matplotlib cannot convert on its own.
        values = [float(v) for v in history[split][metric]]
        # Epochs are reported 1-based elsewhere, so plot them 1-based too.
        epochs = range(1, len(values) + 1)
        suffix = 'loss' if metric == 'Loss' else 'accuracy'
        ax.plot(epochs, values, label='{}_{}'.format(split.lower(), suffix))
    ax.set_xlabel('Epoch number')
    ax.set_ylabel('Loss' if metric == 'Loss' else 'Accuracy')
    ax.set_title('Loss' if metric == 'Loss' else 'Accuracy')
    ax.legend()

fig.tight_layout()
plt.show()