In [1]:

import os
import torch
from torch.utils.data import Dataset, DataLoader
import pydicom
import numpy as np
from torchvision import transforms
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
import torch.nn as nn
from torch.utils.data import random_split
from model import MedicalImageCNN, CNNToRNA,CNNClassifier, train_model
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F

from helper import collate_fn
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
from sklearn.metrics import roc_auc_score
import numpy as np
import torch
import torch.nn.functional as F

def evaluate_model_auc(model, dataloader, device):
    """Compute the ROC AUC of a classifier over every batch in ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    Its outputs are turned into probabilities with a softmax over dim 1 and
    scored against the integer labels with ``roc_auc_score``.

    Args:
        model: Callable module; ``model(images)`` is expected to return
            logits shaped [B, num_classes].
        dataloader: Iterable yielding 4-tuples ``(images, _, labels, _)``;
            the second and fourth items are ignored.
        device: Device the images are moved to before the forward pass.
            The model itself is NOT moved here; the caller must do that.

    Returns:
        The AUC as returned by ``roc_auc_score``, or ``None`` when
        ``roc_auc_score`` raises ``ValueError`` (the message is printed).

    Raises:
        ValueError: If the dataloader yields no batches, if the stacked
            probabilities are not 2D, or if the squeezed targets are not 1D.
    """
    model.eval()
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for images, _, labels, _ in dataloader:
            outputs = model(images.to(device))  # expected shape: [B, num_classes]
            all_probs.append(F.softmax(outputs, dim=1).cpu())
            # Labels are only consumed on the CPU by the metric, so they are
            # cast to class indices without a round trip through `device`.
            all_targets.append(labels.long().cpu())

    if not all_probs:
        # torch.cat on an empty list fails with an opaque RuntimeError.
        raise ValueError("Cannot compute AUC: the dataloader yielded no batches.")

    probs = torch.cat(all_probs).numpy()
    targets = torch.cat(all_targets).numpy().squeeze()
    if targets.ndim == 0:
        # squeeze() collapses a single-sample evaluation to a 0-d array;
        # restore the sample axis so the 1D check below reflects real input.
        targets = targets.reshape(1)

    if len(probs.shape) != 2:
        raise ValueError(f"Predicted probabilities must be a 2D array. Got shape: {probs.shape}")
    if len(targets.shape) != 1:
        raise ValueError(f"Targets must be a 1D array. Got shape: {targets.shape}")

    try:
        if probs.shape[1] == 2:
            # Binary case: score with the probability of class 1.
            auc = roc_auc_score(targets, probs[:, 1])
        else:
            # Any other class count: one-vs-one, macro-averaged.
            auc = roc_auc_score(targets, probs, multi_class='ovo', average='macro')
    except ValueError as e:
        print(f"Error calculating AUC: {e}")
        auc = None
    return auc

In [3]:
print("Top 10 genes:")
# NOTE(review): the header says "Top 10", but the loop below evaluates and
# prints every gene stored in the file — confirm which is intended.

# Load the per-gene results from the saved file.
if not os.path.exists('gene_performance.pth'):
    raise FileNotFoundError("The file 'gene_performance.pth' does not exist.")
# SECURITY: weights_only=False unpickles arbitrary Python objects (the entries
# are used below as whole model / dataset objects), so only load trusted files.
# map_location='cpu' lets a file whose tensors were saved from a GPU be read on
# a CPU-only machine; each model is moved to `device` right before evaluation.
gene_performance = torch.load('gene_performance.pth', weights_only=False, map_location='cpu')

# Order the genes by total_lost, smallest first (ascending).
gene_performance = dict(sorted(gene_performance.items(), key=lambda item: item[1]['total_lost']))

for gene, record in gene_performance.items():
    # Re-evaluate this gene's model on its own held-out test set.
    model = record['model']
    model.to(device)
    test_loader = DataLoader(record['test_dataset'], batch_size=1, shuffle=False, collate_fn=collate_fn)
    # The key name 'test_accuracy' is kept as-is, but the value stored here
    # is whatever evaluate_model_auc returns (a ROC AUC, or None).
    record['test_accuracy'] = evaluate_model_auc(model, test_loader, device)
    print(gene, record['total_lost'], record['test_accuracy'])


Top 10 genes:
ORAI2 0.2984073036595395 0.7777777777777779
DNAL1 0.3143503289473684 0.9411764705882353
CHD7 0.4294240851151316 0.96875
SEC22A 0.4331247430098684 1.0
RCOR1 0.4368832236842105 0.703125
CASP2 0.45602256373355265 0.9215686274509803
SAV1 0.4601187455026727 0.4166666666666667
RAB11FIP4 0.49955588892886515 0.6470588235294118
PGS1 0.5249730160361842 0.5119047619047619
ATAD3C 0.5324211120605469 0.8333333333333333
SNORA54 0.5655593872070312 0.5499999999999999
ANO9 0.5947779605263158 0.6309523809523809
DNM1L 0.6033260947779605 0.5066666666666667
ZNF121 0.6186009457236842 0.5714285714285714
INTS4 0.622789884868421 0.5625
ZNF485 0.6252797444661459 0.6458333333333333
CDKL5 0.6349198190789473 0.35714285714285715
CDK5RAP1 0.6377981085526315 0.6703296703296703
CDON 0.6459282769097222 0.6309523809523809
SETD3 0.6466899671052632 0.5494505494505494
ZNF606 0.6623920641447368 0.3626373626373626
TSNARE1 0.6647178248355263 0.40659340659340654
TRNP1 0.67218017578125 0.45454545454545453
MYO1D 0.6

In [4]:
# Save the updated gene_performance dictionary back to the file it was loaded from.
# The data is written to a temporary sibling first and then swapped in with
# os.replace, so an interrupted save cannot leave a truncated
# 'gene_performance.pth' behind.
_tmp_save_path = 'gene_performance.pth.tmp'
torch.save(gene_performance, _tmp_save_path)
os.replace(_tmp_save_path, 'gene_performance.pth')