In [None]:
# One-time setup: create and activate a virtual environment, then register it as a Jupyter kernel (install the project requirements into it as well)
#!python3 -m venv ../brain_model_env
#!source ../brain_model_env/bin/activate
#!python3 -m ipykernel install --user --name=brain_model_env --display-name "Python (brain_model_env)"
# Remember to switch the notebook kernel to "Python (brain_model_env)" before running the cells below

In [2]:
from dp_model.model_files.sfcn import SFCN
from dp_model import dp_loss as dpl
from dp_model import dp_utils as dpu
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import nibabel as nib
from tqdm import tqdm
from sklearn.metrics import f1_score, roc_auc_score

# Example on Single Subject

In [3]:
# Single-subject demo: load the SFCN sex classifier, run it on one T1w volume
# and print the predicted class probabilities.

# Use the GPU when one is available; otherwise fall back to the CPU so the
# cell still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SFCN(output_dim=2, channel_number=[28, 58, 128, 256, 256, 64])
# The checkpoint is loaded into the DataParallel-wrapped model, so the wrapper
# must be applied before load_state_dict.
model = torch.nn.DataParallel(model)
fp_ = './sex_prediction/run_20191008_00_epoch_last.p'
# map_location remaps the stored tensors onto the selected device, so loading
# does not depend on the device the checkpoint was saved from.
model.load_state_dict(torch.load(fp_, map_location=device, weights_only=True))
model.to(device)

# Example data: some random brain in the MNI152 1mm std space
data_path = '../ABIDE_Dataset/data/JustBrain/ABIDEI/sub-0051038/anat/sub-0051038_T1w.nii.gz'
data = nib.load(data_path).get_fdata()
# Target label used only for the loss below (0=Female, 1=Male).
# NOTE(review): the original comment read "Assuming Sex is Male", but the value
# 0 encodes Female under this mapping -- confirm which one is intended.
y = torch.tensor([0])

# Preprocessing: scale by the volume mean, then crop the centre to the
# (160, 192, 160) shape fed to the network.
data = data/data.mean()
data = dpu.crop_center(data, (160, 192, 160))

# Move the data from numpy to a torch tensor on the selected device,
# adding leading batch and channel axes: (1, 1) + volume shape.
sp = (1,1)+data.shape
data = data.reshape(sp)
input_data = torch.tensor(data, dtype=torch.float32).to(device)
print(f'Input data shape: {input_data.shape}')
print(f'dtype: {input_data.dtype}')

# Evaluation
model.eval() # Don't forget this. BatchNorm will be affected if not in eval mode.
with torch.no_grad():
    output = model(input_data)

# Output and loss. The first element of the model output is flattened to
# (1, n_classes); it is exponentiated below, i.e. treated as log-probabilities.
x = output[0].cpu().reshape([1, -1])
# The loss is computed against the assumed label `y` but is not displayed.
loss = F.nll_loss(x, y)

# Prediction, Visualisation and Summary
x = np.exp(x.numpy().reshape(-1))

print('\nPredicted probability: \nFemale\t%.2f%%,\nMale\t%.2f%%'%(x[0]*100, x[1]*100))

Input data shape: torch.Size([1, 1, 160, 192, 160])
dtype: torch.float32

Predicted probability: 
Female	33.69%,
Male	66.31%


# Sex Prediction Test on ABIDEI

In [6]:
# Batch sex prediction on ABIDEI: run the SFCN classifier on every subject's
# T1w volume, save per-subject predictions to CSV and report F1, ROC AUC and
# accuracy against the labels in participants.tsv.

# -----------------------
# Settings
# -----------------------
input_root = '../ABIDE_Dataset/data/JustBrain/ABIDEI'  # Root folder with sub-xxxx/anat/*.nii.gz
participants_path = '../ABIDE_Dataset/data/ABIDEI/participants.tsv'
model_weights_path = './sex_prediction/run_20191008_00_epoch_last.p'
output_csv_path = '../ABIDE_Dataset/outputs/ABIDEI/sfcn_sex_predictions.csv'
# Ground-truth column; its values are compared directly with the argmax index
# of the model output (index 0 is stored as prob_female, index 1 as prob_male).
label_column = 'sex'
crop_shape = (160, 192, 160)  # Centre-crop size applied to every volume

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -----------------------
# Load model
# -----------------------
model = SFCN(output_dim=2, channel_number=[28, 58, 128, 256, 256, 64])
# The checkpoint is loaded into the DataParallel-wrapped model.
model = torch.nn.DataParallel(model)
# map_location remaps the stored tensors onto the selected device.
model.load_state_dict(torch.load(model_weights_path, map_location=device, weights_only=True))
model.to(device)
model.eval()  # Inference mode (affects layers such as BatchNorm)

# -----------------------
# Load labels
# -----------------------
df = pd.read_csv(participants_path, sep='\t')
df['participant_id'] = df['participant_id'].str.strip()
df = df.set_index('participant_id')

# -----------------------
# Inference loop
# -----------------------
records = []
correct = 0
total = 0

# Loop through subject subfolders
for subject_id in tqdm(sorted(os.listdir(input_root))):
    if not subject_id.startswith('sub-'):
        continue

    anat_dir = os.path.join(input_root, subject_id, 'anat')
    if not os.path.isdir(anat_dir):
        continue

    # Pick the T1w NIfTI file for this subject. Candidates are sorted so the
    # choice is deterministic if a folder ever holds more than one match.
    t1w_file = next(
        (
            f for f in sorted(os.listdir(anat_dir))
            if f.endswith('.nii.gz') and subject_id in f and 'T1w' in f
        ),
        None,
    )

    # tqdm.write prints messages without breaking the progress bar.
    if t1w_file is None:
        tqdm.write(f"No T1w file found for {subject_id}")
        continue

    full_path = os.path.join(anat_dir, t1w_file)

    if subject_id not in df.index:
        tqdm.write(f"{subject_id} not in participants.tsv")
        continue

    true_label = df.loc[subject_id, label_column]

    # Load, scale by the volume mean and centre-crop
    data = nib.load(full_path).get_fdata()
    data = data / data.mean()
    data = dpu.crop_center(data, crop_shape)

    # Prepare input tensor with leading batch and channel axes: (1, 1) + volume shape
    input_tensor = torch.tensor(data.reshape((1, 1) + data.shape), dtype=torch.float32).to(device)

    # Run model; the first output element is exponentiated, i.e. treated as
    # log-probabilities, to obtain per-class probabilities.
    with torch.no_grad():
        output = model(input_tensor)
        probs = np.exp(output[0].cpu().numpy().reshape(-1))

    pred_label = int(np.argmax(probs))
    is_correct = int(pred_label == true_label)

    correct += is_correct
    total += 1

    records.append({
        'subject_id': subject_id,
        'true_label': true_label,
        'predicted_label': pred_label,
        'prob_female': probs[0],
        'prob_male': probs[1],
        'correct': is_correct
    })

# -----------------------
# Save and report
# -----------------------
results_df = pd.DataFrame(records)
# Create the output folder first so a missing directory cannot discard the
# results of the whole inference run.
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
results_df.to_csv(output_csv_path, index=False)

# F1 and ROC AUC. Both need at least one processed subject; ROC AUC also
# needs both classes to be present in the ground truth.
if results_df.empty:
    print('No subjects were processed; skipping F1 and ROC AUC.')
else:
    print(f'F1-score: {f1_score(results_df.true_label, results_df.predicted_label)}')
    if results_df.true_label.nunique() < 2:
        print('ROC_AUC-score: undefined (only one class present in true_label)')
    else:
        print(f'ROC_AUC-score: {roc_auc_score(results_df.true_label, results_df.prob_male)}')

# Accuracy
accuracy = correct / total if total > 0 else 0
print(f"\nFinished inference on {total} subjects")
print(f"Accuracy: {accuracy:.2%}")


100%|██████████| 985/985 [03:22<00:00,  4.86it/s]

F1-score: 0.8689320388349514
ROC_AUC-score: 0.7125050261359067

Finished inference on 985 subjects
Accuracy: 78.07%





# Sex Prediction Test on ABIDEII

In [7]:
# Batch sex prediction on ABIDEII: run the SFCN classifier on every subject's
# T1w volume, save per-subject predictions to CSV and report F1, ROC AUC and
# accuracy against the labels in participants.tsv.

# -----------------------
# Settings
# -----------------------
input_root = '../ABIDE_Dataset/data/JustBrain/ABIDEII'  # Root folder with sub-xxxx/anat/*.nii.gz
participants_path = '../ABIDE_Dataset/data/ABIDEII/participants.tsv'
model_weights_path = './sex_prediction/run_20191008_00_epoch_last.p'
output_csv_path = '../ABIDE_Dataset/outputs/ABIDEII/sfcn_sex_predictions.csv'
# Ground-truth column; its values are compared directly with the argmax index
# of the model output (index 0 is stored as prob_female, index 1 as prob_male).
label_column = 'sex'
crop_shape = (160, 192, 160)  # Centre-crop size applied to every volume

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -----------------------
# Load model
# -----------------------
model = SFCN(output_dim=2, channel_number=[28, 58, 128, 256, 256, 64])
# The checkpoint is loaded into the DataParallel-wrapped model.
model = torch.nn.DataParallel(model)
# map_location remaps the stored tensors onto the selected device.
model.load_state_dict(torch.load(model_weights_path, map_location=device, weights_only=True))
model.to(device)
model.eval()  # Inference mode (affects layers such as BatchNorm)

# -----------------------
# Load labels
# -----------------------
df = pd.read_csv(participants_path, sep='\t')
df['participant_id'] = df['participant_id'].str.strip()
df = df.set_index('participant_id')

# -----------------------
# Inference loop
# -----------------------
records = []
correct = 0
total = 0

# Loop through subject subfolders
for subject_id in tqdm(sorted(os.listdir(input_root))):
    if not subject_id.startswith('sub-'):
        continue

    anat_dir = os.path.join(input_root, subject_id, 'anat')
    if not os.path.isdir(anat_dir):
        continue

    # Pick the T1w NIfTI file for this subject. Candidates are sorted so the
    # choice is deterministic if a folder ever holds more than one match.
    t1w_file = next(
        (
            f for f in sorted(os.listdir(anat_dir))
            if f.endswith('.nii.gz') and subject_id in f and 'T1w' in f
        ),
        None,
    )

    # tqdm.write prints messages without breaking the progress bar.
    if t1w_file is None:
        tqdm.write(f"No T1w file found for {subject_id}")
        continue

    full_path = os.path.join(anat_dir, t1w_file)

    if subject_id not in df.index:
        tqdm.write(f"{subject_id} not in participants.tsv")
        continue

    true_label = df.loc[subject_id, label_column]

    # Load, scale by the volume mean and centre-crop
    data = nib.load(full_path).get_fdata()
    data = data / data.mean()
    data = dpu.crop_center(data, crop_shape)

    # Prepare input tensor with leading batch and channel axes: (1, 1) + volume shape
    input_tensor = torch.tensor(data.reshape((1, 1) + data.shape), dtype=torch.float32).to(device)

    # Run model; the first output element is exponentiated, i.e. treated as
    # log-probabilities, to obtain per-class probabilities.
    with torch.no_grad():
        output = model(input_tensor)
        probs = np.exp(output[0].cpu().numpy().reshape(-1))

    pred_label = int(np.argmax(probs))
    is_correct = int(pred_label == true_label)

    correct += is_correct
    total += 1

    records.append({
        'subject_id': subject_id,
        'true_label': true_label,
        'predicted_label': pred_label,
        'prob_female': probs[0],
        'prob_male': probs[1],
        'correct': is_correct
    })

# -----------------------
# Save and report
# -----------------------
results_df = pd.DataFrame(records)
# Create the output folder first so a missing directory cannot discard the
# results of the whole inference run.
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
results_df.to_csv(output_csv_path, index=False)

# F1 and ROC AUC. Both need at least one processed subject; ROC AUC also
# needs both classes to be present in the ground truth.
if results_df.empty:
    print('No subjects were processed; skipping F1 and ROC AUC.')
else:
    print(f'F1-score: {f1_score(results_df.true_label, results_df.predicted_label)}')
    if results_df.true_label.nunique() < 2:
        print('ROC_AUC-score: undefined (only one class present in true_label)')
    else:
        print(f'ROC_AUC-score: {roc_auc_score(results_df.true_label, results_df.prob_male)}')

# Accuracy
accuracy = correct / total if total > 0 else 0
print(f"\nFinished inference on {total} subjects")
print(f"Accuracy: {accuracy:.2%}")


 32%|███▏      | 310/961 [01:01<02:20,  4.64it/s]

sub-29057 not in participants.tsv
sub-29058 not in participants.tsv
sub-29059 not in participants.tsv
sub-29060 not in participants.tsv
sub-29062 not in participants.tsv
sub-29063 not in participants.tsv
sub-29064 not in participants.tsv
sub-29065 not in participants.tsv
sub-29066 not in participants.tsv
sub-29067 not in participants.tsv
sub-29068 not in participants.tsv
sub-29069 not in participants.tsv
sub-29070 not in participants.tsv
sub-29071 not in participants.tsv
sub-29072 not in participants.tsv
sub-29073 not in participants.tsv
sub-29074 not in participants.tsv
sub-29075 not in participants.tsv
sub-29076 not in participants.tsv
sub-29077 not in participants.tsv
sub-29078 not in participants.tsv
sub-29079 not in participants.tsv
sub-29080 not in participants.tsv
sub-29081 not in participants.tsv
sub-29082 not in participants.tsv
sub-29083 not in participants.tsv
sub-29085 not in participants.tsv
sub-29086 not in participants.tsv
sub-29087 not in participants.tsv
sub-29088 not 

100%|██████████| 961/961 [03:07<00:00,  5.14it/s]


F1-score: 0.7582804792107117
ROC_AUC-score: 0.5678879310344828

Finished inference on 924 subjects
Accuracy: 62.88%
