In [None]:
# Activate Virtual Environment and Install Requirements
#!python3 -m venv ../brain_model_env
#!source ../brain_model_env/bin/activate
#!python3 -m ipykernel install --user --name=brain_model_env --display-name "Python (brain_model_env)"
#remember to switch to notebook/virtual environment kernel

In [1]:
from dp_model.model_files.sfcn import SFCN
from dp_model import dp_loss as dpl
from dp_model import dp_utils as dpu
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import nibabel as nib
from tqdm import tqdm
from sklearn.metrics import f1_score, roc_auc_score

# ASD Prediction on ABIDEI Test Set

In [2]:
# -----------------------
# Settings
# -----------------------
input_root = '../ABIDE_Dataset/data/JustBrain/ABIDEI'
participants_path = './ABIDEI/participants.tsv'
model_weights_path = './ABIDEI/finetuned_sfcn_best.pth'
label_column = 'label'

# -----------------------
# Load model
# -----------------------
model = SFCN(output_dim=2, channel_number=[28, 58, 128, 256, 256, 64])
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(model_weights_path, weights_only=True, map_location=torch.device('cpu')))
#model.cuda()
model.eval()

# -----------------------
# Load labels
# -----------------------
df = pd.read_csv(participants_path, sep='\t')
df = df[df.dataset == 'test'] #only test set
df['participant_id'] = df['participant_id'].str.strip()
df = df.set_index('participant_id')

# -----------------------
# Inference loop
# -----------------------
records = []
correct = 0
total = 0

# Loop through subject subfolders
for subject_id in tqdm(sorted(os.listdir(input_root))):
    if not subject_id.startswith('sub-'):
        continue

    # Full path to T1w file
    anat_dir = os.path.join(input_root, subject_id, 'anat')
    if not os.path.isdir(anat_dir):
        continue

    # Find the T1w NIfTI file (assuming one per subject)
    nii_files = [f for f in os.listdir(anat_dir) if f.endswith('.nii.gz')]
    t1w_file = None
    for f in nii_files:
        if subject_id in f and 'T1w' in f:
            t1w_file = f
            break

    if t1w_file is None:
        print(f"No T1w file found for {subject_id}")
        continue

    full_path = os.path.join(anat_dir, t1w_file)

    if subject_id not in df.index: # not in test set
        #print(f"{subject_id} not in participants.tsv") 
        continue

    true_label = df.loc[subject_id, label_column]

    # Load and normalize
    data = nib.load(full_path).get_fdata()
    data = data / data.mean()
    data = dpu.crop_center(data, (160, 192, 160))

    # Prepare input tensor (1, 1, D, H, W)
    input_tensor = torch.tensor(data.reshape((1,1) + data.shape), dtype=torch.float32)#.cuda()

    # Run model
    with torch.no_grad():
        output = model(input_tensor)
        probs = output[0].cpu().reshape([1, -1])
        probs = np.exp(probs.numpy().reshape(-1))

        
    pred_label = np.argmax(probs)
    is_correct = int(pred_label == true_label)

    correct += is_correct
    total += 1

    records.append({
        'subject_id': subject_id,
        'true_label': true_label,
        'predicted_label': pred_label,
        'prob_nonASD': probs[0],
        'prob_ASD': probs[1],
        'correct': is_correct
    })

# -----------------------
# Save and report
# -----------------------
results_df = pd.DataFrame(records)
results_df.to_csv('./ABIDEI/sfcn_asd_predictions.csv', index=False)

#F1 and AUC_ROC
print(f'F1-score: {f1_score(results_df.true_label, results_df.predicted_label)}')
print(f'ROC_AUC-score: {roc_auc_score(results_df.true_label, results_df.prob_ASD)}')


#Accuracy
accuracy = correct / total if total > 0 else 0
print(f"\nFinished inference on {total} subjects")
print(f"Accuracy: {accuracy:.2%}")


100%|██████████| 985/985 [02:16<00:00,  7.20it/s]  


F1-score: 0.4264705882352941
ROC_AUC-score: 0.5367575462512171

Finished inference on 183 subjects
Accuracy: 57.38%


# ASD Prediction on ABIDEII OOD Dataset

In [4]:
# -----------------------
# Settings
# -----------------------
input_root = '../ABIDE_Dataset/data/JustBrain/ABIDEII'
participants_path = './ABIDEII/participants.tsv'
model_weights_path = './ABIDEI/finetuned_sfcn_best.pth'
label_column = 'label'

# -----------------------
# Load model
# -----------------------
model = SFCN(output_dim=2, channel_number=[28, 58, 128, 256, 256, 64])
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(model_weights_path, weights_only=True, map_location=torch.device('cpu')))
#model.cuda()
model.eval()

# -----------------------
# Load labels
# -----------------------
df = pd.read_csv(participants_path, sep='\t')
df['participant_id'] = df['participant_id'].str.strip()
df = df.set_index('participant_id')

# -----------------------
# Inference loop
# -----------------------
records = []
correct = 0
total = 0

# Loop through subject subfolders
for subject_id in tqdm(sorted(os.listdir(input_root))):
    if not subject_id.startswith('sub-'):
        continue

    # Full path to T1w file
    anat_dir = os.path.join(input_root, subject_id, 'anat')
    if not os.path.isdir(anat_dir):
        continue

    # Find the T1w NIfTI file (assuming one per subject)
    nii_files = [f for f in os.listdir(anat_dir) if f.endswith('.nii.gz')]
    t1w_file = None
    for f in nii_files:
        if subject_id in f and 'T1w' in f:
            t1w_file = f
            break

    if t1w_file is None:
        print(f"No T1w file found for {subject_id}")
        continue

    full_path = os.path.join(anat_dir, t1w_file)

    if subject_id not in df.index:
        print(f"{subject_id} not in participants.tsv")
        continue

    true_label = df.loc[subject_id, label_column]

    # Load and normalize
    data = nib.load(full_path).get_fdata()
    data = data / data.mean()
    data = dpu.crop_center(data, (160, 192, 160))

    # Prepare input tensor (1, 1, D, H, W)
    input_tensor = torch.tensor(data.reshape((1,1) + data.shape), dtype=torch.float32)#.cuda()

    # Run model
    with torch.no_grad():
        output = model(input_tensor)
        probs = output[0].cpu().reshape([1, -1])
        probs = np.exp(probs.numpy().reshape(-1))

        
    pred_label = np.argmax(probs)
    is_correct = int(pred_label == true_label)

    correct += is_correct
    total += 1

    records.append({
        'subject_id': subject_id,
        'true_label': true_label,
        'predicted_label': pred_label,
        'prob_nonASD': probs[0],
        'prob_ASD': probs[1],
        'correct': is_correct
    })

# -----------------------
# Save and report
# -----------------------
results_df = pd.DataFrame(records)
results_df.to_csv('./ABIDEII/sfcn_asd_predictions.csv', index=False)

#F1 and AUC_ROC
print(f'F1-score: {f1_score(results_df.true_label, results_df.predicted_label)}')
print(f'ROC_AUC-score: {roc_auc_score(results_df.true_label, results_df.prob_ASD)}')


#Accuracy
accuracy = correct / total if total > 0 else 0
print(f"\nFinished inference on {total} subjects")
print(f"Accuracy: {accuracy:.2%}")


 32%|███▏      | 310/961 [02:29<05:21,  2.02it/s]

sub-29057 not in participants.tsv
sub-29058 not in participants.tsv
sub-29059 not in participants.tsv
sub-29060 not in participants.tsv
sub-29062 not in participants.tsv
sub-29063 not in participants.tsv
sub-29064 not in participants.tsv
sub-29065 not in participants.tsv
sub-29066 not in participants.tsv
sub-29067 not in participants.tsv
sub-29068 not in participants.tsv
sub-29069 not in participants.tsv
sub-29070 not in participants.tsv
sub-29071 not in participants.tsv
sub-29072 not in participants.tsv
sub-29073 not in participants.tsv
sub-29074 not in participants.tsv
sub-29075 not in participants.tsv
sub-29076 not in participants.tsv
sub-29077 not in participants.tsv
sub-29078 not in participants.tsv
sub-29079 not in participants.tsv
sub-29080 not in participants.tsv
sub-29081 not in participants.tsv
sub-29082 not in participants.tsv
sub-29083 not in participants.tsv
sub-29085 not in participants.tsv
sub-29086 not in participants.tsv
sub-29087 not in participants.tsv
sub-29088 not 

100%|██████████| 961/961 [07:52<00:00,  2.03it/s]

F1-score: 0.3916913946587537
ROC_AUC-score: 0.5682258732309545

Finished inference on 924 subjects
Accuracy: 55.63%



