In [1]:
from torchinfo import summary
from model_utils import *
import numpy as np
from dataset import *
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
import networkx as nx
#import torch_geometric
import torch
import torch.nn as nn
from model import *

from torch.utils.data import DataLoader, SubsetRandomSampler

import pickle as pkl

In [2]:
# Dataset
# Seed the global NumPy and PyTorch generators before any shuffling/splitting so
# that a "Restart & Run All" reproduces the same train/validation partition.
# NOTE(review): assumes split_and_shuffle_data draws from one of these two
# global generators — confirm against its implementation.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

train_set = CustomizableMNIST(root='./data', train=True, download=True)

val_set_ratio = 0.2  # ratio handed to split_and_shuffle_data for the validation split
# NOTE(review): `shuffle` is assigned but never passed to split_and_shuffle_data
# in this cell — confirm whether the helper is supposed to receive it.
shuffle=True
batch_size = 32

train_loader, valid_loader = split_and_shuffle_data(train_set, val_set_ratio, batch_size)

Initializing CustomizableMNIST...
Training set
Init done.



In [14]:
# Graph / model dimensions.
# The captured output shows inputs shaped (batch, 784, 1) and outputs shaped
# (batch, 10), i.e. 784 nodes carrying one feature each, mapped to 10 scores.
nb_nodes = 784
nb_features = 1
out_features = 1  # not used by this cell; kept in case other cells read it
n_classes = 10    # size of the model output (one score per class)

# Use the named constant instead of a bare 10 so the output size is defined once.
GNN_test = simpliest_GNN(nb_nodes, nb_features, n_classes)

In [15]:
# First training sample; im_test and label are not used by the cells visible here.
im_test, label = train_set[0]

# Adjacency matrix built from sample 0 only. The training cell passes this single
# matrix to the trainer, which forwards it to the model for every batch
# (the original note says one matrix "is fine for now").
im_test_np = train_set.get_item_numpy(0)
adj_mat = compute_adj_mat(im_test_np)
# NOTE(review): norm_adj_mat is computed here, but the training cell passes
# adj_mat (not the normalised matrix) to trainer.train — confirm which is intended.
norm_adj_mat = norm_adjacency(adj_mat)

In [12]:
class Trainer:
    """
        Small training / validation loop for models called as ``model(images, adj_mat)``.

        After ``train`` returns, the per-epoch metrics are available in
        ``self.results`` (list of ``(train_loss, valid_loss, accuracy)`` tuples) and
        the best accuracy-like score seen in ``self.best_accuracy`` (``None`` when no
        epoch was run).
    """

    def __init__(self, device):
        """
            Args:
                device: device (string or ``torch.device``) on which the model and
                        the batches are placed.
        """
        print("Starting trainer init...")
        self.device = device
        # Filled in by train(); defined here so they always exist.
        self.results = []
        self.best_accuracy = None
        print("Trainer init done.\n")


    def _to_device(self, adj_mat):
        """Move adj_mat to the trainer device when it is a tensor; return anything else untouched."""
        if isinstance(adj_mat, torch.Tensor):
            return adj_mat.to(self.device)
        return adj_mat


    def train_step(self, model, adj_mat,
                   dataloader, 
                   optimizer, criterion):
        """
            Run one optimisation pass over ``dataloader``.

            Args:
                model:      module called as ``model(images, adj_mat)``.
                adj_mat:    adjacency matrix forwarded unchanged to the model
                            (moved to the trainer device if it is a tensor).
                dataloader: iterable of ``(images, targets)`` batches.
                optimizer:  optimizer stepped once per batch.
                criterion:  loss called as ``criterion(predictions, targets)``.

            Returns:
                Mean of the per-batch losses.

            Raises:
                ValueError: if ``dataloader`` yields no batch.
        """
        running_loss = 0.0
        n_batches    = 0
        adj_mat      = self._to_device(adj_mat)
        model.train()

        for images, targets in dataloader:
            images, targets = images.to(self.device), targets.to(self.device)
            optimizer.zero_grad()

            # Predictions
            targets_hat = model(images, adj_mat)
            # Loss + backprop
            loss = criterion(targets_hat, targets)
            loss.backward()
            optimizer.step()

            # item() returns a plain Python number, detached from the graph
            running_loss += loss.item()
            n_batches    += 1

        # Counting batches (instead of len(dataloader)) also supports iterables
        # without __len__ and lets us report an empty loader explicitly.
        if n_batches == 0:
            raise ValueError("train_step received an empty dataloader.")
        return running_loss / n_batches
    

    def valid_step(self, model, adj_mat, valid_loader, criterion):
        """
            Evaluate ``model`` on ``valid_loader`` without tracking gradients.

            Args:
                model:        module called as ``model(images, adj_mat)``.
                adj_mat:      adjacency matrix forwarded unchanged to the model
                              (moved to the trainer device if it is a tensor).
                valid_loader: iterable of ``(images, targets)`` batches.
                criterion:    loss called as ``criterion(predictions, targets)``.

            Returns:
                ``(mean per-batch loss, accuracy in percent)``.

            Raises:
                ValueError: if ``valid_loader`` yields no batch or no sample.
        """
        running_loss  = 0.0
        n_correct     = 0
        n_predictions = 0
        n_batches     = 0
        adj_mat       = self._to_device(adj_mat)

        model.eval()
        with torch.no_grad():
            for images, targets in valid_loader:
                images, targets = images.to(self.device), targets.to(self.device)

                # Inferences
                targets_hat = model(images, adj_mat)

                # loss
                running_loss += criterion(targets_hat, targets).item()
                n_batches    += 1

                # accuracy (winner takes all): index of the largest score along dim 1
                predictions    = targets_hat.argmax(dim=1)
                n_correct     += (predictions == targets).sum().item()
                n_predictions += predictions.size(0)

        if n_batches == 0 or n_predictions == 0:
            raise ValueError("valid_step received an empty dataloader.")
        return running_loss / n_batches, 100 * (n_correct / n_predictions)


    def train(self, n_epochs, adj_mat,
              model, optimizer, criterion,
              train_dataloader, valid_dataloader,
              file_path_save_trained_model, file_path_save_best_acc_model, results_file_path,
              train_loss_name, valid_loss_name, accuracy_name,
              best_accuracy_is_max=True):
        """
            Training main entry point.

            Args:
                n_epochs:         number of epochs to run.
                adj_mat:          adjacency matrix forwarded to the model at every step.
                model:            module to train; moved to the trainer device.
                optimizer:        optimizer bound to the model parameters.
                criterion:        loss function.
                train_dataloader: batches used by ``train_step``.
                valid_dataloader: batches used by ``valid_step``.
                file_path_save_trained_model, file_path_save_best_acc_model,
                results_file_path: currently unused, the saving code is disabled.
                train_loss_name, valid_loss_name, accuracy_name:
                                  labels used in the per-epoch log line.
                best_accuracy_is_max: True when a larger score is better, False
                                  when a smaller score is better.

            Returns:
                ``(model, optimizer)``. Metrics are kept on ``self.results`` and
                ``self.best_accuracy``.
        """
        print("Starting training...\n")

        # sending to device
        # .to() is called unconditionally: the previous check compared a
        # torch.device with whatever was passed as `device` (often a plain string),
        # which is not a dependable equality test.
        model = model.to(self.device)
        # next(..., None) avoids StopIteration on a model without parameters.
        first_param    = next(model.parameters(), None)
        running_device = first_param.device if first_param is not None else self.device
        print("The model will be running on", running_device, "device.\n")

        results = []
        # None means "no score yet". Starting from 0.0 made the lower-is-better
        # mode unable to ever record a best value for a non-negative metric.
        best_accuracy = None

        for epoch in range(1, n_epochs + 1):

            train_epoch_loss                 = self.train_step(model, adj_mat, train_dataloader, optimizer, criterion)
            valid_epoch_loss, epoch_accuracy = self.valid_step(model, adj_mat, valid_dataloader, criterion)
            
            print(f'Epoch: {epoch}/{n_epochs}, {train_loss_name}: {train_epoch_loss:.4f}, {valid_loss_name}: {valid_epoch_loss:.4f}, {accuracy_name}: {epoch_accuracy:.2f}%')
            
            # tracking reached best model
            if best_accuracy is None:
                is_better = True
            elif best_accuracy_is_max:
                is_better = epoch_accuracy > best_accuracy
            else:
                is_better = epoch_accuracy < best_accuracy

            if is_better:
                #save_checkpoint(model, optimizer, epoch, file_path_save_best_acc_model)
                best_accuracy = epoch_accuracy

            results.append((train_epoch_loss, valid_epoch_loss, epoch_accuracy))  

        # Keep the metrics reachable: saving to disk is disabled below and the
        # return value only carries the model and the optimizer.
        self.results       = results
        self.best_accuracy = best_accuracy

        # Saving the model
        #print('Saving the model...\n')
        #model = model.to('cpu')
        #save_model(model, file_path_save_trained_model)

        # Saving the performances
        #with open(results_file_path, 'wb') as f:
        #    pkl.dump(results, f) 

        print("Training finish.\n") 

        return model, optimizer
    


In [13]:
# Model
n_classes = 10

# Optimisation set-up: loss, learning rate, epoch budget and optimizer
criterion = torch.nn.CrossEntropyLoss()
lr        = 1e-4
n_epochs  = 15
optimizer = torch.optim.Adam(GNN_test.parameters(), lr=lr)

# Output locations and the labels shown in the per-epoch log line
file_path_save_trained_model  = "./savings/models/simpliest_GNN_test"
file_path_save_best_acc_model = "./savings/models/simpliest_GNN_test_accuracy.pt"
results_file_path             = "./savings/results/results_simpliest_GNN_test.pkl"
train_loss_name               = "Cross Entropy Loss"
valid_loss_name               = "Cross Entropy Loss"
accuracy_name                 = "Accuracy"

# Training runs on the CPU; every argument is passed by keyword so the long
# call is readable without the signature at hand.
trainer = Trainer("cpu")

model, optimizer = trainer.train(
    n_epochs=n_epochs,
    adj_mat=adj_mat,
    model=GNN_test,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    file_path_save_trained_model=file_path_save_trained_model,
    file_path_save_best_acc_model=file_path_save_best_acc_model,
    results_file_path=results_file_path,
    train_loss_name=train_loss_name,
    valid_loss_name=valid_loss_name,
    accuracy_name=accuracy_name,
    best_accuracy_is_max=True,
)

Starting trainer init...
Trainer init done.

Starting training...

The model will be running on cpu device.

batched_tens_adj_mat torch.Size([32, 784, 784])
image torch.Size([32, 784, 1])
AX torch.Size([32, 784, 1])
Y torch.Size([32, 784, 1])
Y torch.Size([32, 784])
Y torch.Size([32, 784])
output torch.Size([32, 10])
prob torch.Size([32, 10])
batched_tens_adj_mat torch.Size([32, 784, 784])
image torch.Size([32, 784, 1])
AX torch.Size([32, 784, 1])
Y torch.Size([32, 784, 1])
Y torch.Size([32, 784])
Y torch.Size([32, 784])
output torch.Size([32, 10])
prob torch.Size([32, 10])
batched_tens_adj_mat torch.Size([32, 784, 784])
image torch.Size([32, 784, 1])
AX torch.Size([32, 784, 1])
Y torch.Size([32, 784, 1])
Y torch.Size([32, 784])
Y torch.Size([32, 784])
output torch.Size([32, 10])
prob torch.Size([32, 10])
batched_tens_adj_mat torch.Size([32, 784, 784])
image torch.Size([32, 784, 1])
AX torch.Size([32, 784, 1])
Y torch.Size([32, 784, 1])
Y torch.Size([32, 784])
Y torch.Size([32, 784])
o

KeyboardInterrupt: 