In [27]:
import torch
import torch.nn as nn
from models import LeNet
from models_structured import PrunedLeNet, PrunedCancerNet_fc
import numpy as np
import torch 
import torchvision.transforms as transforms
from torchvision import transforms, datasets
from models import CancerNet_fc, LeNet, ResNet
from datasets_custom import CancerDataset, RotatedMNIST
from torch.nn.utils import prune
import torchvision.transforms as transforms
from torch.functional import F
from laplace import Laplace 
from laplace import KronLaplace, DiagLaplace
import torch
import time 
from utils import evaluate_classification
from marglikopt import marglik_optimization
import os
import matplotlib.pyplot as plt
import wandb
from utils import evaluate_classification
from fvcore.nn import FlopCountAnalysis
from tueplots import bundles



In [28]:
def get_probs(model, data_loader, device):
    all_logits = torch.tensor([]).to(device)
    all_targets = torch.tensor([], dtype=torch.long).to(device)

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            all_logits = torch.cat((all_logits, logits), dim=0)
            all_targets = torch.cat((all_targets, labels), dim=0)
            
    probs = F.softmax(all_logits, dim=1)
    return probs, all_targets



def brier_score_ours(predictions, targets, num_classes):
    
    targets_one_hot = F.one_hot(targets, num_classes).to(dtype=predictions.dtype)
    

    squared_differences = (predictions - targets_one_hot) ** 2
    

    mean_squared_difference = squared_differences.mean()
    
    return mean_squared_difference


def ece_score(probs, labels, n_bins=10):
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = torch.zeros(1, device=probs.device)
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        # Calculate ECE for each class
        for i in range(probs.shape[1]):
            in_bin = (probs[:, i] >= bin_lower) & (probs[:, i] < bin_upper)
            prop_in_bin = in_bin.float().mean()

            if prop_in_bin.item() > 0:
                avg_prob_in_bin = probs[in_bin, i].mean()
                avg_acc_in_bin = (labels[in_bin] == i).float().mean()

                ece += torch.abs(avg_acc_in_bin - avg_prob_in_bin) * prop_in_bin

    return (ece / probs.shape[1]).item() 


In [29]:
train_dataset = CancerDataset(train=True)
valid_dataset = CancerDataset(train=False)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, shuffle=False)


In [30]:
EXPORTED_DIR = "/xxxxxxxx/pattern/CancerNet_fc_breast_cancer_KronLaplace_unitwise_50_wp_struct/removed_structure_new"

In [31]:
model_data = []

In [32]:
baseline_model = CancerNet_fc(30, 100, 2)
baseline_model.load_state_dict(torch.load("/xxxxxxxx/pattern/CancerNet_fc_breast_cancer_KronLaplace_unitwise_50_wp_struct/CancerNet_fc_breast_cancer_KronLaplace_unitwise_50_wp_baseline_acc0.9649122953414917_marg_0.6714989542961121.pt"))
baseline_model.to("cuda")
baseline_model.eval()
flops = FlopCountAnalysis(baseline_model, torch.randn(1, 30).to("cuda"))

probs, targets = get_probs(baseline_model, valid_loader, "cuda")

brier = brier_score_ours(probs, targets, probs.shape[1])
print(brier)

ece = ece_score(probs, targets)
baseline_size = os.path.getsize("/xxxxxxxx/pattern/CancerNet_fc_breast_cancer_KronLaplace_unitwise_50_wp_struct/CancerNet_fc_breast_cancer_KronLaplace_unitwise_50_wp_baseline_acc0.9649122953414917_marg_0.6714989542961121.pt")

model_data.append({"sparsity":0, "accuracy": 96.49122953414917,  "flops": flops.total(),"marglik": 0.6714989542961121,"ece":ece, "brier":brier.item(), "size":baseline_size})


tensor(0.0316, device='cuda:0')


In [33]:



for model_name in os.listdir(EXPORTED_DIR):
    if model_name.endswith('.pt'):
        print(model_name)
        sp = int(model_name.split("_")[4])
        acc = float(model_name.split("_")[6])
        marg = float(model_name.split("_")[8].strip(".pt"))
        
        model_meta = torch.load(os.path.join(EXPORTED_DIR, model_name))["model_meta"]

        # Create and load model (assuming PrunedCancerNet_fc is defined elsewhere)
        new_model = PrunedCancerNet_fc(
            input_size=model_meta["input_size"],
            output_size=model_meta["output_size"],
            hidden_size1=model_meta["hidden_size1"],
            hidden_size2=model_meta["hidden_size2"],
        )
        new_model.load_state_dict(torch.load(os.path.join(EXPORTED_DIR, model_name))["model_state_dict"])
        new_model.to("cuda")
        new_model.eval()
        probs, targets = get_probs(new_model, valid_loader, "cuda")
        brier = brier_score_ours(probs, targets, probs.shape[1]).item()
        ece = ece_score(probs, targets,10)
        disk_size = os.path.getsize(os.path.join(EXPORTED_DIR, model_name))
        
        print(brier)
        print(ece)
        
        # Calculate FLOPs
        flops = FlopCountAnalysis(new_model, torch.randn(1, 30).to("cuda")).total()

        # Collect data
        model_data.append({'sparsity': sp, 'accuracy': acc, 'flops': flops, "marglik": marg , "ece":ece, "brier":brier,"size":disk_size})



CancerNet_fc_breast_cancerreducued_40_acc_94.73684430122375_marg_0.42974528670310974.pt
0.048851966857910156
0.1089942455291748
CancerNet_fc_breast_cancerreducued_85_acc_90.35087823867798_marg_0.6733307242393494.pt
0.20159350335597992
0.3433155119419098
CancerNet_fc_breast_cancerreducued_80_acc_92.10526943206787_marg_0.5269500613212585.pt
0.11393766850233078
0.2190008908510208
CancerNet_fc_breast_cancerreducued_99_acc_37.71929740905762_marg_1.0673283338546753.pt
0.40188300609588623
0.4086129665374756
CancerNet_fc_breast_cancerreducued_95_acc_91.22806787490845_marg_0.5478118062019348.pt
0.1507350206375122
0.32013222575187683
CancerNet_fc_breast_cancerreducued_70_acc_92.98245906829834_marg_0.5295510292053223.pt
0.11371316015720367
0.22209279239177704
CancerNet_fc_breast_cancerreducued_60_acc_93.85964870452881_marg_0.4583999812602997.pt
0.0634060651063919
0.12445912510156631
CancerNet_fc_breast_cancerreducued_20_acc_94.73684430122375_marg_0.48285186290740967.pt
0.051847945898771286
0.1050

In [34]:
unstructured_perf = [ 
     96.19883040935673, 95.76023391812866, 95.32163742690058, 95.6140350877193, 
        95.32163742690061, 95.90643274853801, 95.02923976608186, 94.29824561403511, 
        67.54385964912281, 45.90643274853801
]
flops_unstructured = [ 13200 ] * len(unstructured_perf)

In [35]:
# Separate the data for plotting
sparsities = [d['sparsity'] for d in model_data]
accuracies = [d['accuracy'] for d in model_data]
flops = [d['flops'] for d in model_data]
marg = [d['marglik'] for d in model_data]
brier = [d['brier'] for d in model_data] 
ece =  [d['ece'] for d in model_data] 
disk_size = [d['size'] for d in model_data]
# Sort the data by flops
sorted_data = sorted(zip(flops, sparsities, accuracies, marg, brier, ece, disk_size))

# Unzip the sorted data
sorted_flops, sorted_sparsities, sorted_accuracies, sorted_marg, sorted_brier, sorted_ece, disk_size = zip(*sorted_data)

In [23]:

# Plot Sparsity vs FLOPs

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(sorted_sparsities, sorted_flops)

# add text for each point
for i, txt in enumerate(sorted_flops):
    # reverse the text for better readability
    plt.annotate(txt, (sorted_sparsities[i], sorted_flops[i]), rotation=90)

plt.xlabel('Model Sparsity (%)')
plt.ylabel('FLOPs ')
plt.grid(True)
plt.title('Model Sparsity vs FLOPs')
plt.yscale('log')

# Plot FLOPs vs Accuracy
plt.subplot(1, 2, 2)
plt.plot(sorted_flops, sorted_accuracies, color='blue', marker='o', label='Structured')
plt.plot(flops_unstructured, unstructured_perf, color='orange', marker='x', label='Unstructured')
plt.legend()

# add text for each point
for i, txt in enumerate(sorted_flops):
    plt.annotate(txt, (sorted_flops[i], sorted_accuracies[i]), rotation=90)

plt.xlabel('FLOPs')
plt.ylabel('Validation Accuracy ')
plt.grid(True)
plt.title('FLOPs vs Accuracy/MargLik')
ax3 = plt.twinx()
ax3.scatter(sorted_flops, sorted_marg, color='red')
ax3.set_ylabel('Marglik ', color='red')
ax3.tick_params(axis='y', colors='red')
ax3.set_ylim([0, max(sorted_marg) + 0.1])
plt.xscale('log')
#plt.tight_layout()
plt.savefig("CancerNet_flops_vs_acc_tue.pdf")
plt.show()


RuntimeError: Failed to process string with tex because latex could not be found

Error in callback <function _draw_all_if_interactive at 0x7fa1a55b6950> (for post_execute):


RuntimeError: Failed to process string with tex because latex could not be found

RuntimeError: Failed to process string with tex because latex could not be found

<Figure size 1000x500 with 3 Axes>

In [36]:
import pandas as pd
df = pd.DataFrame(model_data)
df.to_csv("cancer_flops_analysis_make_tue.csv")

In [37]:
df.head()

Unnamed: 0,sparsity,accuracy,flops,marglik,ece,brier,size
0,0,96.49123,13200,0.671499,0.075785,0.031623,57279
1,40,94.736844,1740,0.429745,0.108994,0.048852,10495
2,85,90.350878,255,0.673331,0.343316,0.201594,4407
3,80,92.105269,340,0.52695,0.219001,0.113938,4727
4,99,37.719297,33,1.067328,0.408613,0.401883,3511


In [43]:
EXPORTED_DIR ="/xxxxxxxx/inference/CancerNet_fc_breast_cancer_DiagLaplace_diagonal_50_wp/laplacekron_b"

model_data = []

baseline_model = CancerNet_fc(30, 100, 2)

baseline_model.load_state_dict(torch.load("/xxxxxxxx/pattern/CancerNet_fc_breast_cancer_KronLaplace_unitwise_50_wp_struct/CancerNet_fc_breast_cancer_KronLaplace_unitwise_50_wp_baseline_acc0.9649122953414917_marg_0.6714989542961121.pt"))
baseline_model.to("cuda")
baseline_model.eval()

probs, targets = get_probs(baseline_model, valid_loader, "cuda")

brier = brier_score_ours(probs, targets, probs.shape[1])

ece = ece_score(probs, targets)
# skip disk and flops 
model_data.append({"sparsity":0, "accuracy": 96.49122953414917, "marglik": 0.6714989542961121,"ece":ece, "brier":brier.item()})


for model_name in os.listdir(EXPORTED_DIR):
    if model_name.endswith('.pt'):
        print(model_name)
        sp = int(model_name.split("_")[3])
        acc = float(model_name.split("_")[5])
        marg = float(model_name.split("_")[7].strip(".pt"))
        model_meta = torch.load(os.path.join(EXPORTED_DIR, model_name))

        # Create and load model (assuming PrunedCancerNet_fc is defined elsewhere)
        new_model = CancerNet_fc(30, 100, 2)
        new_model.load_state_dict(torch.load(os.path.join(EXPORTED_DIR, model_name)))
        new_model.to("cuda")
        new_model.eval()
        probs, targets = get_probs(new_model, valid_loader, "cuda")
        brier = brier_score_ours(probs, targets,probs.shape[1]).item()
        ece = ece_score(probs, targets,10)
        print(brier)
        print(ece)
        
        # Calculate FLOPs
        flops = 0
    

        # Collect data
        model_data.append({'sparsity': sp, 'accuracy': acc, 'flops': flops, "marglik":marg, "by_opres":0, "ece":ece, "brier":brier})


CancerNet_fc_sp_99_acc_37.719298245614034_marg_1.0538321733474731.pt
0.4657992720603943
0.5302312970161438
CancerNet_fc_sp_60_acc_94.73684210526316_marg_3.2067408561706543.pt
0.03843589499592781
0.05452299118041992
CancerNet_fc_sp_75_acc_94.73684210526316_marg_4.09218692779541.pt
0.03389754146337509
0.06002712622284889
CancerNet_fc_sp_95_acc_87.71929824561404_marg_2.9200398921966553.pt
0.10536357760429382
0.10901208221912384
CancerNet_fc_sp_20_acc_95.6140350877193_marg_3.3435819149017334.pt
0.031446345150470734
0.05950549617409706
CancerNet_fc_sp_40_acc_92.98245614035088_marg_3.859661340713501.pt
0.03710773587226868
0.05946286767721176
CancerNet_fc_sp_90_acc_92.10526315789474_marg_3.674030303955078.pt
0.046089787036180496
0.07847480475902557
CancerNet_fc_sp_80_acc_94.73684210526316_marg_3.955268144607544.pt
0.03115650825202465
0.05664987117052078
CancerNet_fc_sp_70_acc_94.73684210526316_marg_3.219619035720825.pt
0.03767303749918938
0.055590055882930756
CancerNet_fc_sp_85_acc_93.8596491

In [44]:
df = pd.DataFrame(model_data)
df.to_csv("unstructured_cancer_analysis.csv")