# Evaluation of a Model

In [1]:
import os

import e2cnn

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import PIL

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import configparser as ConfigParser

import utils
# Import various network architectures
from networks import AGRadGalNet, DNSteerableLeNet, DNSteerableAGRadGalNet #e2cnn module only works in python3.7+
# Import various data classes
from datasets import FRDEEPF
from datasets import MiraBest_full, MBFRConfident, MBFRUncertain, MBHybrid
from datasets import MingoLoTSS, MLFR, MLFRTest

from sklearn.metrics import classification_report, roc_curve, auc

# Seed the NumPy and PyTorch RNGs with the same fixed value so that
# repeated runs of the notebook are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Disabled single-config loader: wrapped in a bare string literal so that
# none of it executes (the notebook merely echoes the literal as cell output).
"""# Read in config file
config_name = ["bowles2021mirabest.cfg", 
               "bowles2021mingo.cfg", 
               "scaife2021mirabestDN.cfg", 
               "scaife2021mingo.cfg", 
               "e2attentionmirabest.cfg"
              ]
config_name = "configs/"+config_name[-1]
config = ConfigParser.ConfigParser(allow_no_value=True)
config.read(config_name)

# Load network architecture (with random weights)
print(f"Loading in {config['model']['base']}")
net = locals()[config['model']['base']](**config['model']).to(device)
"""

'# Read in config file\nconfig_name = ["bowles2021mirabest.cfg", \n               "bowles2021mingo.cfg", \n               "scaife2021mirabestDN.cfg", \n               "scaife2021mingo.cfg", \n               "e2attentionmirabest.cfg"\n              ]\nconfig_name = "configs/"+config_name[-1]\nconfig = ConfigParser.ConfigParser(allow_no_value=True)\nconfig.read(config_name)\n\n# Load network architecture (with random weights)\nprint(f"Loading in {config[\'model\'][\'base\']}")\nnet = locals()[config[\'model\'][\'base\']](**config[\'model\']).to(device)\n'

In [2]:
# Re-seed both RNGs and re-select the compute device at the start of this cell.
np.random.seed(42)
torch.manual_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model config files (looked up under 'configs/') whose trained models are
# evaluated by the loop below. Commented-out entries are kept as a record of
# earlier runs; only the uncommented names are processed.
configs = [
    #"scaife2021mirabest.cfg", # Fully Evaluated
    #"scaife2021mirabest-RandAug.cfg", # Fully Evaluated
    #"scaife2021mirabest-RestrictedAug.cfg", # Fully Evaluated
    #"scaife2021mingo.cfg", # Fully Evaluated
    #"scaife2021mingo-RandAug.cfg", # Fully Evaluated
    #"scaife2021mingo-RestrictedAug.cfg", # Fully Evaluated
    
    #"bowles2021mirabest.cfg", # Fully Evaluated
    #"bowles2021mirabest-RandAug.cfg", # Fully Evaluated
    #"bowles2021mirabest-RestrictedAug.cfg", # Fully Evaluated
    #"bowles2021mingo.cfg", # Fully Evaluated
    #"bowles2021mingo-RandAug.cfg", # Fully Evaluated
    #"bowles2021mingo-RestrictedAug.cfg", # Fully Evaluated
    
    #"e2attentionmirabest.cfg", # Fully Evaluated
    #"e2attentionmirabest-RandAug.cfg", # Fully Evaluated
    #"e2attentionmirabest-RestrictedAug.cfg", # Fully Evaluated
    #"e2attentionmingo.cfg", # Fully Evaluated
    #"e2attentionmingo-RandAug.cfg", # Fully Evaluated
    #"e2attentionmingo-RestrictedAug.cfg", # Fully Evaluated
    
    #"5kernel_e2attentionmirabest.cfg",
    #"5kernel_e2attentionmirabest-RandAug.cfg", # Fully Evaluated
    #"5kernel_e2attentionmirabest-RestrictedAug.cfg",
    #"5kernel_e2attentionmingo.cfg",
    #"5kernel_e2attentionmingo-RandAug.cfg", # Fully Evaluated
    #"5kernel_e2attentionmingo-RestrictedAug.cfg",
    
    #"7kernel_e2attentionmirabest.cfg",
    #"7kernel_e2attentionmirabest-RandAug.cfg", # Fully Evaluated
    #"7kernel_e2attentionmirabest-RestrictedAug.cfg",
    #"7kernel_e2attentionmingo.cfg",
    #"7kernel_e2attentionmingo-RandAug.cfg", # Fully Evaluated
    #"7kernel_e2attentionmingo-RestrictedAug.cfg",
    
    "C4_attention_mirabest.cfg", # Fully Evaluated to: '/raid/scratch/mbowles/EquivariantSelfAttention/models/e2attention/mirabest/fisher/'
    "C8_attention_mirabest.cfg",
    "C16_attention_mirabest.cfg",
    "D4_attention_mirabest.cfg", # Fully Evaluated to: '/raid/scratch/mbowles/EquivariantSelfAttention/models/e2attention/mirabest/fisher/'
    "D8_attention_mirabest.cfg",
    #"D16_attention_mirabest.cfg",
]

# Config files (also under 'configs/') that are handed to utils.data.load to
# select the test dataset each model is evaluated on.
data_configs = [
    "e2attentionmirabest.cfg", # Mirabest Dataset - MBFR
    "e2attentionmingo.cfg" # Mingo Dataset - MLFR
]
# Test-time augmentation names passed to utils.data.load as `augmentation`;
# every model/dataset pair is evaluated once per uncommented entry.
augmentations = [
    #"rotation and flipping",
    "random rotation",
    #"restricted random rotation"
]

# Directory passed as PATH to utils.evaluation.save_evaluation for every run.
# NOTE(review): absolute, machine-specific location that always points at the
# mirabest/fisher folder regardless of the config being evaluated -- confirm
# this is intended before running for other model families.
EVALUATION_PATH = '/raid/scratch/mbowles/EquivariantSelfAttention/models/e2attention/mirabest/fisher/'


def _read_config(path):
    """Parse the config file at *path* into a fresh ConfigParser.

    ``ConfigParser.read`` silently skips files it cannot open and merges new
    values into whatever the parser already holds. Building a new parser per
    file prevents options of a previously read config from leaking into the
    next one, and an unreadable file is reported immediately instead of
    surfacing later as an unrelated ``KeyError``.

    Args:
        path: Location of the ``.cfg`` file to parse.

    Returns:
        A ConfigParser (``allow_no_value=True``) holding only that file.

    Raises:
        FileNotFoundError: If the file could not be read.
    """
    parser = ConfigParser.ConfigParser(allow_no_value=True)
    if not parser.read(path):
        raise FileNotFoundError(f"Could not read config file: {path}")
    return parser


# Evaluate every selected model on every dataset / augmentation combination
# and store the results through utils.evaluation.save_evaluation.
for cfg in configs:
    print(cfg)
    config = _read_config('configs/'+cfg)

    # Training log written for this model; its last row flagged by
    # 'validation_update' supplies the epoch number reported below
    # (presumably the epoch at which the 'best' checkpoint was saved).
    csv_path = config['output']['directory'] +'/'+ config['data']['augment'] +'/'+ config['output']['training_evaluation']
    df = pd.read_csv(csv_path)
    best = df.iloc[list(df['validation_update'])].iloc[-1]

    # Extract models kernel size: explicit config value first, otherwise
    # 5 for LeNet-style bases and 3 for everything else.
    if config.has_option('model', 'kernel_size'):
        kernel_size = config.getint('model', 'kernel_size')
    elif "LeNet" in config['model']['base']:
        kernel_size = 5
    else:
        kernel_size = 3

    # NOTE(review): `net` is not used again inside this loop (the evaluated
    # model comes from utils.utils.load_model). It is kept because later
    # cells may rely on it and because dropping it could change how much of
    # the seeded RNG stream is consumed -- confirm before removing.
    net = locals()[config['model']['base']](**config['model']).to(device)

    # Sub-directory of the output directory holding this model's checkpoints.
    path_supliment = config['data']['augment']+'/'

    for d_cfg in data_configs:
        # Fresh parser per dataset config, read once per dataset rather than
        # once per augmentation.
        data_config = _read_config('configs/'+d_cfg)

        for augmentation in augmentations:
            # NOTE(review): the checkpoint is reloaded for every dataset /
            # augmentation pair, as in the original. It looks loop-invariant,
            # but hoisting it may alter RNG consumption and therefore the
            # random augmentations -- confirm before moving it.
            model = utils.utils.load_model(config, load_model='best', device=device, path_supliment=path_supliment)
            print(f"Evaluating {cfg}: {config['output']['directory']}/{config['data']['augment']}\t{data_config['data']['dataset']}\t{augmentation}")
            data  = utils.data.load(
                data_config,
                train=False,
                augmentation=augmentation,
                data_loader=True
            )

            y_pred, y_labels = utils.evaluation.predict(
                model,
                data,
                augmentation_loops=100,
                raw_predictions=True
            )

            utils.evaluation.save_evaluation(
                y_pred,
                y_labels,
                model_name=config['model']['base'],
                kernel_size=kernel_size,
                train_data=config['data']['dataset'],
                train_augmentation=config['data']['augment'],
                test_data=data_config['data']['dataset'],
                test_augmentation=augmentation,
                epoch=int(best.name),  # index label of the selected log row
                best=True,
                raw=False,
                PATH=EVALUATION_PATH
            )

C4_attention_mirabest.cfg


  sampled_basis = sampled_basis[mask, ...]


Evaluating C4_attention_mirabest.cfg: models/e2attention/mirabest/fisher/C4/random rotation	MBFRUncertain	random rotation
Files already downloaded and verified




Evaluating C4_attention_mirabest.cfg: models/e2attention/mirabest/fisher/C4/random rotation	MLFR	random rotation




Files already downloaded and verified




D8_attention_mirabest.cfg


  sampled_basis = sampled_basis[mask, ...]


RuntimeError: Error(s) in loading state_dict for DNSteerableAGRadGalNet:
	Missing key(s) in state_dict: "conv1a._basisexpansion.block_expansion_('irrep_0,0', 'regular').sampled_basis", "bnorm1a.indices_16", "bnorm1a.batch_norm_[16].weight", "bnorm1a.batch_norm_[16].bias", "bnorm1a.batch_norm_[16].running_mean", "bnorm1a.batch_norm_[16].running_var", "conv1b._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm1b.indices_16", "bnorm1b.batch_norm_[16].weight", "bnorm1b.batch_norm_[16].bias", "bnorm1b.batch_norm_[16].running_mean", "bnorm1b.batch_norm_[16].running_var", "conv1c._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm1c.indices_16", "bnorm1c.batch_norm_[16].weight", "bnorm1c.batch_norm_[16].bias", "bnorm1c.batch_norm_[16].running_mean", "bnorm1c.batch_norm_[16].running_var", "gpool1.in_indices_16", "gpool1.out_indices_16", "conv2a._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm2a.indices_16", "bnorm2a.batch_norm_[16].weight", "bnorm2a.batch_norm_[16].bias", "bnorm2a.batch_norm_[16].running_mean", "bnorm2a.batch_norm_[16].running_var", "conv2b._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm2b.indices_16", "bnorm2b.batch_norm_[16].weight", "bnorm2b.batch_norm_[16].bias", "bnorm2b.batch_norm_[16].running_mean", "bnorm2b.batch_norm_[16].running_var", "conv2c._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm2c.indices_16", "bnorm2c.batch_norm_[16].weight", "bnorm2c.batch_norm_[16].bias", "bnorm2c.batch_norm_[16].running_mean", "bnorm2c.batch_norm_[16].running_var", "gpool2.in_indices_16", "gpool2.out_indices_16", "conv3a._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm3a.indices_16", "bnorm3a.batch_norm_[16].weight", "bnorm3a.batch_norm_[16].bias", "bnorm3a.batch_norm_[16].running_mean", "bnorm3a.batch_norm_[16].running_var", "conv3b._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm3b.indices_16", "bnorm3b.batch_norm_[16].weight", 
"bnorm3b.batch_norm_[16].bias", "bnorm3b.batch_norm_[16].running_mean", "bnorm3b.batch_norm_[16].running_var", "conv3c._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm3c.indices_16", "bnorm3c.batch_norm_[16].weight", "bnorm3c.batch_norm_[16].bias", "bnorm3c.batch_norm_[16].running_mean", "bnorm3c.batch_norm_[16].running_var", "gpool3.in_indices_16", "gpool3.out_indices_16", "conv4a._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm4a.indices_16", "bnorm4a.batch_norm_[16].weight", "bnorm4a.batch_norm_[16].bias", "bnorm4a.batch_norm_[16].running_mean", "bnorm4a.batch_norm_[16].running_var", "conv4b._basisexpansion.block_expansion_('regular', 'regular').sampled_basis", "bnorm4b.indices_16", "bnorm4b.batch_norm_[16].weight", "bnorm4b.batch_norm_[16].bias", "bnorm4b.batch_norm_[16].running_mean", "bnorm4b.batch_norm_[16].running_var", "gpool4.in_indices_16", "gpool4.out_indices_16". 
	Unexpected key(s) in state_dict: "conv1a._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm1a.indices_1", "bnorm1a.batch_norm_[1].weight", "bnorm1a.batch_norm_[1].bias", "bnorm1a.batch_norm_[1].running_mean", "bnorm1a.batch_norm_[1].running_var", "bnorm1a.batch_norm_[1].num_batches_tracked", "conv1b._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm1b.indices_1", "bnorm1b.batch_norm_[1].weight", "bnorm1b.batch_norm_[1].bias", "bnorm1b.batch_norm_[1].running_mean", "bnorm1b.batch_norm_[1].running_var", "bnorm1b.batch_norm_[1].num_batches_tracked", "conv1c._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm1c.indices_1", "bnorm1c.batch_norm_[1].weight", "bnorm1c.batch_norm_[1].bias", "bnorm1c.batch_norm_[1].running_mean", "bnorm1c.batch_norm_[1].running_var", "bnorm1c.batch_norm_[1].num_batches_tracked", "conv2a._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm2a.indices_1", "bnorm2a.batch_norm_[1].weight", "bnorm2a.batch_norm_[1].bias", "bnorm2a.batch_norm_[1].running_mean", "bnorm2a.batch_norm_[1].running_var", "bnorm2a.batch_norm_[1].num_batches_tracked", "conv2b._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm2b.indices_1", "bnorm2b.batch_norm_[1].weight", "bnorm2b.batch_norm_[1].bias", "bnorm2b.batch_norm_[1].running_mean", "bnorm2b.batch_norm_[1].running_var", "bnorm2b.batch_norm_[1].num_batches_tracked", "conv2c._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm2c.indices_1", "bnorm2c.batch_norm_[1].weight", "bnorm2c.batch_norm_[1].bias", "bnorm2c.batch_norm_[1].running_mean", "bnorm2c.batch_norm_[1].running_var", "bnorm2c.batch_norm_[1].num_batches_tracked", "conv3a._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm3a.indices_1", "bnorm3a.batch_norm_[1].weight", "bnorm3a.batch_norm_[1].bias", "bnorm3a.batch_norm_[1].running_mean", 
"bnorm3a.batch_norm_[1].running_var", "bnorm3a.batch_norm_[1].num_batches_tracked", "conv3b._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm3b.indices_1", "bnorm3b.batch_norm_[1].weight", "bnorm3b.batch_norm_[1].bias", "bnorm3b.batch_norm_[1].running_mean", "bnorm3b.batch_norm_[1].running_var", "bnorm3b.batch_norm_[1].num_batches_tracked", "conv3c._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm3c.indices_1", "bnorm3c.batch_norm_[1].weight", "bnorm3c.batch_norm_[1].bias", "bnorm3c.batch_norm_[1].running_mean", "bnorm3c.batch_norm_[1].running_var", "bnorm3c.batch_norm_[1].num_batches_tracked", "conv4a._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm4a.indices_1", "bnorm4a.batch_norm_[1].weight", "bnorm4a.batch_norm_[1].bias", "bnorm4a.batch_norm_[1].running_mean", "bnorm4a.batch_norm_[1].running_var", "bnorm4a.batch_norm_[1].num_batches_tracked", "conv4b._basisexpansion.block_expansion_('irrep_0,0', 'irrep_0,0').sampled_basis", "bnorm4b.indices_1", "bnorm4b.batch_norm_[1].weight", "bnorm4b.batch_norm_[1].bias", "bnorm4b.batch_norm_[1].running_mean", "bnorm4b.batch_norm_[1].running_var", "bnorm4b.batch_norm_[1].num_batches_tracked". 
	size mismatch for conv1a.weights: copying a param with shape torch.Size([18]) from checkpoint, the shape in current model is torch.Size([66]).
	size mismatch for conv1a.filter: copying a param with shape torch.Size([6, 1, 5, 5]) from checkpoint, the shape in current model is torch.Size([96, 1, 5, 5]).
	size mismatch for conv1b.weights: copying a param with shape torch.Size([108]) from checkpoint, the shape in current model is torch.Size([6336]).
	size mismatch for conv1b.filter: copying a param with shape torch.Size([6, 6, 5, 5]) from checkpoint, the shape in current model is torch.Size([96, 96, 5, 5]).
	size mismatch for conv1c.weights: copying a param with shape torch.Size([108]) from checkpoint, the shape in current model is torch.Size([6336]).
	size mismatch for conv1c.filter: copying a param with shape torch.Size([6, 6, 5, 5]) from checkpoint, the shape in current model is torch.Size([96, 96, 5, 5]).
	size mismatch for conv2a.weights: copying a param with shape torch.Size([288]) from checkpoint, the shape in current model is torch.Size([16896]).
	size mismatch for conv2a.filter: copying a param with shape torch.Size([16, 6, 5, 5]) from checkpoint, the shape in current model is torch.Size([256, 96, 5, 5]).
	size mismatch for conv2b.weights: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([45056]).
	size mismatch for conv2b.filter: copying a param with shape torch.Size([16, 16, 5, 5]) from checkpoint, the shape in current model is torch.Size([256, 256, 5, 5]).
	size mismatch for conv2c.weights: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([45056]).
	size mismatch for conv2c.filter: copying a param with shape torch.Size([16, 16, 5, 5]) from checkpoint, the shape in current model is torch.Size([256, 256, 5, 5]).
	size mismatch for conv3a.weights: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([90112]).
	size mismatch for conv3a.filter: copying a param with shape torch.Size([32, 16, 5, 5]) from checkpoint, the shape in current model is torch.Size([512, 256, 5, 5]).
	size mismatch for conv3b.weights: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([180224]).
	size mismatch for conv3b.filter: copying a param with shape torch.Size([32, 32, 5, 5]) from checkpoint, the shape in current model is torch.Size([512, 512, 5, 5]).
	size mismatch for conv3c.weights: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([180224]).
	size mismatch for conv3c.filter: copying a param with shape torch.Size([32, 32, 5, 5]) from checkpoint, the shape in current model is torch.Size([512, 512, 5, 5]).
	size mismatch for conv4a.weights: copying a param with shape torch.Size([6144]) from checkpoint, the shape in current model is torch.Size([360448]).
	size mismatch for conv4a.filter: copying a param with shape torch.Size([64, 32, 5, 5]) from checkpoint, the shape in current model is torch.Size([1024, 512, 5, 5]).
	size mismatch for conv4b.weights: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([720896]).
	size mismatch for conv4b.filter: copying a param with shape torch.Size([64, 64, 5, 5]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 5, 5]).