This notebook demonstrates how to use the Fisher information to calculate generalisability and trainability metrics.

In [1]:
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import PIL

from torchsummary import summary

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import configparser as ConfigParser


import utils
# Import various network architectures
from networks import AGRadGalNet, VanillaLeNet, testNet, DNSteerableLeNet, DNSteerableAGRadGalNet
# Import various data classes
from datasets import FRDEEPF
from datasets import MiraBest_full, MBFRConfident, MBFRUncertain, MBHybrid
from datasets import MingoLoTSS, MLFR, MLFRTest

Load the dataset and the relevant configuration.

In [2]:

# Read the experiment configuration. ConfigParser.read silently skips files it
# cannot open (it returns the list of files it did read), so fail loudly here
# instead of with a confusing KeyError in a later cell.
PATH = "configs/"
cfg_base = "C4_attention_mirabest.cfg"
config = ConfigParser.ConfigParser(allow_no_value=True)
if not config.read(PATH + cfg_base):
    raise FileNotFoundError(f"Could not read config file {PATH + cfg_base}")

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the training and validation loaders described by the config.
train_loader, valid_loader = utils.data.load(
    config,
    train=True,
    augmentation='config',
    data_loader=True,
)

Files already downloaded and verified




In [3]:
# Instantiate a fresh (untrained) network of the architecture named in the
# config; the class must be one of the names imported in the first cell.
print(f"Loading in {config['model']['base']}")
net = globals()[config['model']['base']](**config['model']).to(device)

quiet = config.getboolean('DEFAULT', 'quiet')
early_stopping = config.getboolean('training', 'early_stopping')

# Read / create the folder where output data is saved.
root = config['data']['directory']
os.makedirs(root, exist_ok=True)

if not quiet:
    # torchsummary is skipped for architectures whose name contains 'DN'
    # (the steerable variants imported above).
    if 'DN' not in config['model']['base']:
        summary(net, (1, 150, 150))
    print(device)
    # `device` may be a plain string such as "cuda"; comparing a str with a
    # torch.device always evaluates to False, so normalise it first.
    if torch.device(device).type == 'cuda':
        print(torch.cuda.get_device_name(device=device))

Loading in DNSteerableAGRadGalNet


  sampled_basis = sampled_basis[mask, ...]


cuda


Attempt to load the best saved model.

In [4]:
# Load the best checkpoint saved for this configuration. The augmentation
# setting from the config is handed to the loader as an extra path piece
# (presumably a sub-directory of the checkpoint folder — confirm in utils).
path_supliment = f"{config['data']['augment']}/"
model = utils.utils.load_model(
    config,
    load_model='best',
    device=device,
    path_supliment=path_supliment,
)

Save the trained model weights, copy them into the freshly built network and freeze them, then train the last layer so that the gradient information is retained.

In [5]:
# Keep references to the trained weight tensors of the loaded `model`: the
# convolution layers and the psi / theta / phi layers of each attention module.
# They are transplanted into the freshly built `net` in the next cell.
Conv1a, Conv1b, Conv1c = (
    model.conv1a.weights, model.conv1b.weights, model.conv1c.weights
)
Conv2a, Conv2b, Conv2c = (
    model.conv2a.weights, model.conv2b.weights, model.conv2c.weights
)
Conv3a, Conv3b, Conv3c = (
    model.conv3a.weights, model.conv3b.weights, model.conv3c.weights
)
Conv4a, Conv4b = model.conv4a.weights, model.conv4b.weights

Psi1, Psi2, Psi3 = (
    model.attention1.psi.weight, model.attention2.psi.weight, model.attention3.psi.weight
)
Theta1, Theta2, Theta3 = (
    model.attention1.theta.weight, model.attention2.theta.weight, model.attention3.theta.weight
)
Phi1, Phi2, Phi3 = (
    model.attention1.phi.weight, model.attention2.phi.weight, model.attention3.phi.weight
)

In [7]:
# Transplant the saved weights into `net` and freeze them, so the optimiser
# built later (which only receives parameters with requires_grad=True) leaves
# them untouched.
# NOTE(review): only these weight tensors are copied and frozen; biases and any
# other parameters of `net` keep their fresh initialisation and remain
# trainable — confirm this is intended.
transplanted = [
    (net.conv1a, 'weights', Conv1a),
    (net.conv1b, 'weights', Conv1b),
    (net.conv1c, 'weights', Conv1c),
    (net.conv2a, 'weights', Conv2a),
    (net.conv2b, 'weights', Conv2b),
    (net.conv2c, 'weights', Conv2c),
    (net.conv3a, 'weights', Conv3a),
    (net.conv3b, 'weights', Conv3b),
    (net.conv3c, 'weights', Conv3c),
    (net.conv4a, 'weights', Conv4a),
    (net.conv4b, 'weights', Conv4b),
    (net.attention1.psi, 'weight', Psi1),
    (net.attention2.psi, 'weight', Psi2),
    (net.attention3.psi, 'weight', Psi3),
    (net.attention1.theta, 'weight', Theta1),
    (net.attention2.theta, 'weight', Theta2),
    (net.attention3.theta, 'weight', Theta3),
    (net.attention1.phi, 'weight', Phi1),
    (net.attention2.phi, 'weight', Phi2),
    (net.attention3.phi, 'weight', Phi3),
]

# First install every saved tensor...
for layer, attr_name, saved_tensor in transplanted:
    setattr(layer, attr_name, saved_tensor)

# ...then switch off gradient tracking on each installed tensor.
for layer, attr_name, saved_tensor in transplanted:
    getattr(layer, attr_name).requires_grad = False

# The loaded model is no longer needed; its weights now live in `net`.
del model

Test a Simplified Version of the Training Cycle

In [8]:
def train(net,
          device,
          config,
          Epoch,
          train_loader,
          valid_loader,
          optimizer,
          root_out_directory_addition='',
          scheduler=None,
          save_validation_updates=True,
          class_splitting_index=1,
          loss_function=nn.CrossEntropyLoss(),
          output_model=True,
          early_stopping=True,
          output_best_validation=False,
          stop_after_epochs_without_update=2000
          ):
    """Very simple version of the train.py training loop, used to populate classifier gradients.

    Trains ``net`` for ``Epoch`` passes over ``train_loader`` and prints the
    mean per-sample training loss after every epoch. Only the parameters held
    by ``optimizer`` are updated, so freezing layers beforehand restricts
    training to the rest of the network.

    Args:
        net: model to train; it is put into train mode.
        device: device every batch is moved to.
        config: unused here.
        Epoch: number of passes over ``train_loader``.
        train_loader: iterable of ``(data, labels)`` batches exposing ``.dataset``.
        valid_loader: unused here.
        optimizer: optimiser stepped after every batch.
        scheduler: optional scheduler, stepped once per epoch with the mean
            epoch training loss (the ``ReduceLROnPlateau.step(metrics)``
            calling convention).
        class_splitting_index: labels below this value become class 0, all
            other labels become class 1.
        loss_function: criterion applied to ``(net(data), binary_labels)``.

    The remaining keyword arguments are accepted for compatibility with the
    full training loop and are currently unused.

    Returns:
        None. The mean training loss of each epoch is printed.
    """
    # Seed for reproducibility of the run.
    torch.manual_seed(42)
    np.random.seed(42)

    n_samples = len(train_loader.dataset)
    for _ in range(Epoch):
        net.train()
        train_loss = 0.

        for data, labels in train_loader:
            data = data.to(device)
            labels = labels.to(device)

            # Collapse morphological sub-classes (e.g. MiraBest) into a binary
            # target. Built directly in torch as int64, which CrossEntropyLoss
            # requires on every platform.
            binary_labels = (labels >= class_splitting_index).long()

            optimizer.zero_grad()
            pred = net(data)
            loss = loss_function(pred, binary_labels)
            loss.backward()
            # Always step the optimiser; previously a supplied scheduler
            # silently disabled all parameter updates.
            optimizer.step()
            train_loss += loss.item() * data.size(0)

        epoch_loss = train_loss / n_samples
        if scheduler is not None:
            # Step once per epoch on the mean loss rather than on a running
            # sum that only grows within the epoch.
            scheduler.step(epoch_loss)
        print(epoch_loss)

In [9]:
# Adam over the still-trainable parameters only: anything with
# requires_grad=False (e.g. the transplanted weights) is excluded.
optimizer = optim.Adam([param for param in net.parameters() if param.requires_grad], lr=1e-5)

In [10]:
# Run the simplified training loop for 10 epochs with its default settings.
train(
    net,
    device,
    config,
    Epoch=10,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
)



0.5352245236144346
0.5429738830117619
0.5884865347076865
0.5343444137012258
0.5093576487372903
0.5588863211519578
0.5623948048142826
0.5237291630576638
0.5422553209697499
0.5166515392415664
0.5582884795525495
0.5428332090377808
0.5369892085299772
0.5197013301007888
0.5242063368068022
0.5304787264150732


KeyboardInterrupt: 

In [11]:
# Inspect the gradient left on the classifier weights by the last backward pass.
classifier_grad = net.classifier.weight.grad
print(classifier_grad)

tensor([[-9.1327e-04,  1.4837e-32, -4.2458e-04, -5.1269e-04, -5.6729e-03,
          0.0000e+00],
        [ 9.1319e-04,  0.0000e+00,  4.2454e-04,  5.1264e-04,  5.6725e-03,
          0.0000e+00]], device='cuda:0')


Next: Fisher information and Jacobian imports.