In [None]:
#mount Google Drive into the Colab filesystem
from google.colab import drive

#mount point; other cells refer to this location by its literal path
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

In [None]:
#cell for data loading
import pandas as pd
import os

csv_path = '/content/structural-survey/CRAVE_scores.csv'
structural_dir = '/content/structural-survey/Structural'

#score column -> cut-off; a score >= cut-off is labelled 1 in '<column>_binary'
SCORE_THRESHOLDS = {'CAMS': 26, 'QIDS': 11, 'GAD': 11}


def subject_id_from_filename(filename):
    """Return the subject id embedded in a scan filename.

    A leading 'wc0c' prefix is dropped when present; the id is everything
    before the first underscore of what remains.
    """
    stem = filename[len('wc0c'):] if filename.startswith('wc0c') else filename
    return stem.split('_')[0]


def add_binary_targets(frame, thresholds=None):
    """Return a copy of `frame` with numeric scores and '<score>_binary' labels.

    Parameters
    ----------
    frame : pandas.DataFrame
        Table that may contain the score columns named in `thresholds`.
    thresholds : dict or None
        Maps a score column to its cut-off. Defaults to SCORE_THRESHOLDS.

    Returns
    -------
    pandas.DataFrame
        Copy of `frame`. Every score column that exists is coerced to numeric
        (unparseable values become NaN) and gets a 0/1 '<score>_binary' column.
        Score columns that are absent are reported and skipped instead of
        raising KeyError.
    """
    cutoffs = SCORE_THRESHOLDS if thresholds is None else thresholds
    labelled = frame.copy()
    for score_col, cutoff in cutoffs.items():
        if score_col not in labelled.columns:
            print(f"Warning: column '{score_col}' not found; '{score_col}_binary' not created.")
            continue
        labelled[score_col] = pd.to_numeric(labelled[score_col], errors='coerce')
        #NaN >= cutoff evaluates to False, so missing/unparseable scores are labelled 0
        labelled[f'{score_col}_binary'] = (labelled[score_col] >= cutoff).astype(int)
    return labelled


if not os.path.exists(csv_path):
    print(f"Data not loaded: CSV not found at {csv_path}")
elif not os.path.exists(structural_dir):
    print(f"Data not loaded: scan directory not found at {structural_dir}")
else:
    df = pd.read_csv(csv_path)

    #get all .nii files
    nii_files = [f for f in os.listdir(structural_dir) if f.endswith('.nii')]

    #map subject ids to file paths (a later file with the same id overwrites an earlier one)
    subject_to_path = {
        subject_id_from_filename(filename): os.path.join(structural_dir, filename)
        for filename in nii_files
    }

    #find column corresponding to subj ids: first column sharing at least one value with the ids
    matched_col = None
    extracted_ids_set = set(subject_to_path)

    for col in df.columns:
        if extracted_ids_set.intersection(df[col].astype(str)):
            matched_col = col
            break

    #compare against None so a falsy column label (e.g. 0 or '') still counts as a match
    if matched_col is None:
        print("Data not loaded: no CSV column contains any subject id found in the .nii filenames.")
    else:
        df = df[df[matched_col].astype(str).isin(extracted_ids_set)].copy()
        df['file_path'] = df[matched_col].astype(str).map(subject_to_path)
        df = add_binary_targets(df)

        print("Data loaded and targets created successfully. (Prints hidden)")

In [None]:
## ----------***This is not used****-----------
#pca analysis
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import recall_score # Ensure recall_score is imported

#X, df_filtered, targets and classifiers are created by the feature extraction /
#training cell, which sits below this one; check for them so a top-to-bottom run
#reports the situation instead of stopping on a NameError.
pca_missing_inputs = [name for name in ('X', 'df_filtered', 'targets', 'classifiers')
                      if name not in globals()]

if pca_missing_inputs:
    print(f"Skipping PCA analysis, run the feature extraction cell first (missing: {pca_missing_inputs})")
else:
    #standardize (all samples; used only for the descriptive variance curve below)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # no truncation to see variance curve
    pca_full = PCA()
    pca_full.fit(X_scaled)

    #visualize explained variance
    plt.figure(figsize=(10, 5))
    plt.plot(np.cumsum(pca_full.explained_variance_ratio_), marker='o', linestyle='--')
    plt.xlabel('Number of Components')
    plt.ylabel('Cumulative Explained Variance')
    plt.title('PCA: Explained Variance by Components')
    plt.grid(True)
    plt.show()

    #apply pca w 95% of variance (again on all samples, for the shape report only)
    pca = PCA(n_components=0.95)
    X_pca = pca.fit_transform(X_scaled)

    print(f"Original Feature Shape: {X.shape}")
    print(f"Reduced PCA Shape: {X_pca.shape}")
    print(f"Kept {X_pca.shape[1]} components to explain 95% of variance.")

    #re-train classifiers on pca data
    print("\n--- Retraining on PCA Features ---")

    results_store_pca = {}

    for target in targets:
        print(f"\n=== Target: {target} (PCA) ===")
        results_store_pca[target] = {}

        if target not in df_filtered.columns:
            continue

        y = df_filtered[target].values

        for name, clf in classifiers.items():
            cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            #Scaling and PCA are part of the cross-validated estimator, so each fold
            #fits them on its training split only. Fitting them on all of X first
            #would let the held-out samples shape the features they are scored on.
            pca_clf = make_pipeline(StandardScaler(), PCA(n_components=0.95), clf)
            try:
                y_pred = cross_val_predict(pca_clf, X, y, cv=cv)

                acc = accuracy_score(y, y_pred)
                f1 = f1_score(y, y_pred, zero_division=0)
                recall = recall_score(y, y_pred, zero_division=0)

                #same record layout as results_store in the training cell
                results_store_pca[target][name] = {'accuracy': acc, 'f1_score': f1, 'recall': recall}
                print(f"{name}: Accuracy={acc:.4f}, F1={f1:.4f}, Recall={recall:.4f}")

            except Exception as e:
                print(f"Error {name}: {e}")

In [None]:
#load model and extract features
import sys
import os
import subprocess
import pkg_resources
import zipfile
import pandas as pd
import numpy as np
import torch
import nibabel as nib
from scipy import ndimage
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix
from google.colab import drive

#must install alive_progress dependency for this to work
import importlib.util

#importable module name -> distribution name handed to pip.
#The check looks for the importable module directly. The previous comparison against
#pkg_resources keys never matched, because those keys are spelled with '-' where the
#name has '_' ('alive-progress'), so pip was invoked on every run.
required = {'alive_progress': 'alive-progress'}
missing = sorted(
    pip_name
    for module_name, pip_name in required.items()
    if importlib.util.find_spec(module_name) is None
)
if missing:
    print("Installing missing packages:", missing)
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', *missing])
    #let the import system see packages installed while this kernel is running
    importlib.invalidate_caches()

base_search_path = '/content/'
zip_path = '/content/gdrive/MyDrive/structural-survey.zip'

#debugging
key_file = 'AD_pretrained_utilities.py'


def _first_dir_containing(filename, search_root):
    """Walk `search_root` top-down; return the first directory that holds `filename`, else None."""
    hits = (folder for folder, _subdirs, names in os.walk(search_root) if filename in names)
    return next(hits, None)


#is the dataset (identified by its utilities module) already somewhere under /content/?
file_found = _first_dir_containing(key_file, base_search_path) is not None

if not file_found:
    print("Dataset files not found. Attempting to restore...")
    if not os.path.exists('/content/gdrive'):
        print("Mounting Google Drive...")
        drive.mount('/content/gdrive')

    if not os.path.exists(zip_path):
        print(f"Warning: Zip file not found at {zip_path}.")
    else:
        print(f"Unzipping '{zip_path}'...")
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(base_search_path)
        print("Unzip complete.")

#model utilities
#search again, since the restore step above may just have created the files
model_dir = _first_dir_containing(key_file, base_search_path)

if model_dir is None:
    raise RuntimeError(f"Could not locate {key_file} in {base_search_path}")

print(f"Found {key_file} in {model_dir}")
#make the utilities module importable by the feature extraction step
if model_dir not in sys.path:
    sys.path.append(model_dir)
    print(f"Added {model_dir} to sys.path")

#Build X (one feature row per subject) and df_filtered (the matching table rows)
#unless both already exist in the kernel.
if 'X' not in globals() or 'df_filtered' not in globals():
    print("Variables 'X' or 'df_filtered' missing. Regenerating...")
    csv_path = None
    structural_dir = None

    #locate the score table and the scan folder anywhere under base_search_path
    #(no break: if several copies exist, the last one visited is used)
    for root, dirs, files in os.walk(base_search_path):
        if 'CRAVE_scores.csv' in files:
            csv_path = os.path.join(root, 'CRAVE_scores.csv')
        if 'Structural' in dirs:
            structural_dir = os.path.join(root, 'Structural')

    if not (csv_path and structural_dir):
        raise RuntimeError("CSV or Structural dir not found")

    df = pd.read_csv(csv_path)

    #subject id = filename without a leading 'wc0c', up to the first underscore
    nii_files = [f for f in os.listdir(structural_dir) if f.endswith('.nii')]
    subject_to_path = {}
    for fname in nii_files:
        sid = fname.replace('wc0c', '', 1).split('_')[0] if fname.startswith('wc0c') else fname.split('_')[0]
        subject_to_path[sid] = os.path.join(structural_dir, fname)

    #filter: the id column is the first column sharing at least one value with the ids
    extracted_ids = set(subject_to_path.keys())
    matched_col = None
    for col in df.columns:
        if len(set(df[col].astype(str)).intersection(extracted_ids)) > 0:
            matched_col = col
            break

    #compare against None so a falsy column label (e.g. 0) still counts as a match
    if matched_col is None:
        raise RuntimeError("No matching subject IDs")

    df = df[df[matched_col].astype(str).isin(extracted_ids)].copy()
    df['file_path'] = df[matched_col].astype(str).map(subject_to_path)

    for c, thresh in [('CAMS', 26), ('QIDS', 11), ('GAD', 11)]:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors='coerce')
            #NaN >= thresh is False, so missing/unparseable scores are labelled 0
            df[f'{c}_binary'] = (df[c] >= thresh).astype(int)

    #run model
    try:
        from AD_pretrained_utilities import CNN, CNN_8CL_B
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        weights_path = os.path.join(model_dir, 'AD_pretrained_weights.pt')

        if not os.path.exists(weights_path):
            raise RuntimeError("Weights not found")

        param = CNN_8CL_B()
        model = CNN(param)
        #NOTE(review): torch.load unpickles the file; only load weights from a trusted source
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.to(device)
        model.eval()

        #the hook stores the tensor that is fed INTO the last module of model.f;
        #that input, flattened, is the feature vector for a subject
        features_list = []
        def hook(module, input, output):
            features_list.append(input[0].detach().cpu().numpy())
        hook_handle = None
        if hasattr(model, 'f') and len(model.f) > 0:
            hook_handle = model.f[-1].register_forward_hook(hook)

        X_list = []
        valid_indices = []
        skipped = []  #(row index, reason) for every subject that produced no features
        target_shape = (73, 96, 96)

        print(f"Processing {len(df)} subjects...")
        try:
            for idx, row in df.iterrows():
                try:
                    img = nib.load(row['file_path'])
                    data = img.get_fdata()
                    data = np.nan_to_num(data)
                    #resample volumes of a different size to target_shape (linear interpolation)
                    if data.shape != target_shape:
                        zoom = [t/s for t, s in zip(target_shape, data.shape)]
                        data = ndimage.zoom(data, zoom, order=1)
                    #min-max scale to [0, 1]; constant volumes are left unchanged
                    if data.max() > data.min():
                        data = (data - data.min()) / (data.max() - data.min())

                    #add two leading axes of size 1 in front of the volume
                    inp = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0).to(device)
                    features_list.clear()
                    with torch.no_grad():
                        model(inp)
                    if features_list:
                        X_list.append(features_list[0].flatten())
                        valid_indices.append(idx)
                    else:
                        skipped.append((idx, 'forward hook captured nothing'))
                except Exception as subject_err:
                    #keep going with the other subjects, but remember what failed and why
                    skipped.append((idx, repr(subject_err)))
        finally:
            #detach the hook so the model is left without a listener after extraction
            if hook_handle is not None:
                hook_handle.remove()

        if skipped:
            print(f"Skipped {len(skipped)} of {len(df)} subjects:")
            for skipped_idx, reason in skipped[:10]:
                print(f"  row {skipped_idx}: {reason}")
            if len(skipped) > 10:
                print(f"  ... and {len(skipped) - 10} more")

        if not X_list:
            raise RuntimeError("No features extracted")

        X = np.vstack(X_list)
        df_filtered = df.loc[valid_indices].copy()
        print(f"X shape: {X.shape}")
    except Exception as e:
        #chain the cause so the original traceback stays visible
        raise RuntimeError(f"Feature extraction failed: {e}") from e

#training
#5-fold stratified cross-validation of every classifier on every binary target
print("\nTraining...")
classifiers = {
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
    'SVM': SVC(kernel='linear', random_state=42),
    'MLP': MLPClassifier(hidden_layer_sizes=(64,), max_iter=1000, random_state=42)
}
results_store = {}
targets = ['CAMS_binary', 'QIDS_binary', 'GAD_binary']

if not ('X' in globals() and 'df_filtered' in globals()):
    print("Data missing")
else:
    for target in targets:
        print(f"\nTarget: {target}")
        results_store[target] = {}

        #targets whose column was never created are skipped without a message
        if target not in df_filtered.columns:
            continue

        y = df_filtered[target].values
        if len(y) < 5:
            print("Not enough data")
            continue

        for name, clf in classifiers.items():
            cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            try:
                #out-of-fold prediction for every sample
                y_pred = cross_val_predict(clf, X, y, cv=cv)
                acc = accuracy_score(y, y_pred)
                f1 = f1_score(y, y_pred, zero_division=0)
                recall = recall_score(y, y_pred, zero_division=0)
                cm = confusion_matrix(y, y_pred)
                results_store[target][name] = dict(accuracy=acc, f1_score=f1, recall=recall)
                print(f"{name}: Acc={acc:.4f}, F1={f1:.4f}, Recall={recall:.4f}")
            except Exception as e:
                print(f"Error {name}: {e}")

In [None]:
#feature selection
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.pipeline import make_pipeline

#rf, logreg, mlp
req_models = {
    'Statistical (LogReg)': LogisticRegression(max_iter=1000, random_state=42),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
    'MLP': MLPClassifier(hidden_layer_sizes=(64,), max_iter=1000, random_state=42)
}

print("\n--- Running Feature Importance Selection ---\n")

results_store_selection = {}

#number of features kept, both for the printed ranking and inside each CV fold
top_k = 5

for target in targets:
    print(f"\n=== Target: {target} ===")
    results_store_selection[target] = {}

    if target not in df_filtered.columns:
        continue

    y = df_filtered[target].values

    #find feature importance using rf
    #This forest is fitted on ALL samples and is used only to print a ranking.
    #It is not used to pick the features the models below are scored on.
    selector_clf = RandomForestClassifier(n_estimators=100, random_state=42)
    selector_clf.fit(X, y)

    #importance scores, sorted descending
    importances = selector_clf.feature_importances_
    indices = np.argsort(importances)[::-1]

    #top 5 features for visualization
    top_indices = indices[:top_k]
    print(f"Top {top_k} Features selected: {top_indices}")

    #train models on the selected features
    for name, clf in req_models.items():
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        #Selection is a pipeline step, so every fold ranks features with a forest
        #fitted on its training split only. Choosing the columns from a forest that
        #has already seen the held-out labels inflates the cross-validated scores.
        #threshold=-np.inf makes max_features the only selection criterion.
        fold_selector = SelectFromModel(
            RandomForestClassifier(n_estimators=100, random_state=42),
            max_features=min(top_k, X.shape[1]),
            threshold=-np.inf,
        )
        selection_clf = make_pipeline(fold_selector, clf)
        try:
            y_pred = cross_val_predict(selection_clf, X, y, cv=cv)

            #metrics
            acc = accuracy_score(y, y_pred)
            f1 = f1_score(y, y_pred, zero_division=0)
            recall = recall_score(y, y_pred, zero_division=0)

            #same record layout as results_store in the training cell
            results_store_selection[target][name] = {'accuracy': acc, 'f1_score': f1, 'recall': recall}
            print(f"{name}: Accuracy={acc:.4f}, F1={f1:.4f}, Recall={recall:.4f}")

        except Exception as e:
            print(f"Error {name}: {e}")

In [None]:
#important
#Cross-validates every classifier on every target, draws one confusion matrix per
#(target, classifier) pair in a 3x3 grid and shows a table of the metrics.
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, recall_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier

classifiers = {
    'Logistic Regression': LogisticRegression(max_iter=1000, random_state=42),
    'SVM': SVC(kernel='linear', random_state=42),
    'MLP': MLPClassifier(hidden_layer_sizes=(64,), max_iter=1000, random_state=42)
}

targets = ['CAMS_binary', 'QIDS_binary', 'GAD_binary']
metrics_list = []

#The F1 bar chart was only used for early visualization and stays off by default.
#It used to be disabled with a triple-quoted string at column 0 placed between the
#if-body and `else:`, which ended the if-block and made `else:` a SyntaxError.
show_f1_barplot = False

if 'X' in globals() and 'df_filtered' in globals():
    #3x3 confusion matrix: one row per target, one column per classifier
    fig, axes = plt.subplots(3, 3, figsize=(18, 15))
    plt.subplots_adjust(hspace=0.4, wspace=0.4)

    for i, target in enumerate(targets):
        if target not in df_filtered.columns:
            print(f"Target {target} missing.")
            continue

        y = df_filtered[target].values

        #debugging - skip if not enough samples
        if len(y) < 5:
            print(f"Not enough samples for {target}")
            continue

        for j, (clf_name, clf) in enumerate(classifiers.items()):
            ax = axes[i, j]

            #cross-validation predictions
            cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            try:
                y_pred = cross_val_predict(clf, X, y, cv=cv)

                #metrics
                acc = accuracy_score(y, y_pred)
                f1 = f1_score(y, y_pred, zero_division=0)
                rec = recall_score(y, y_pred, zero_division=0)

                metrics_list.append({
                    'Target': target,
                    'Model': clf_name,
                    'Accuracy': acc,
                    'F1 Score': f1,
                    'Recall': rec
                })

                #confusion matrix
                cm = confusion_matrix(y, y_pred)
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax, cbar=False,
                            xticklabels=['Low', 'High'], yticklabels=['Low', 'High'])
                ax.set_title(f'{target}\n{clf_name}')
                ax.set_ylabel('True Label')
                ax.set_xlabel('Predicted Label')
            except Exception as e:
                print(f"Error processing {clf_name} for {target}: {e}")
                ax.text(0.5, 0.5, "Error", ha='center', va='center')

    plt.tight_layout()
    plt.show()

    #metrics df
    metrics_df = pd.DataFrame(metrics_list)
    print("\nPerformance Metrics Summary:")
    display(metrics_df)

    #f1 score bar chart
    if show_f1_barplot and not metrics_df.empty:
        plt.figure(figsize=(12, 6))
        sns.barplot(data=metrics_df, x='Target', y='F1 Score', hue='Model', palette='viridis')
        plt.title('Comparison of F1 Scores across Targets and Models')
        plt.ylim(0, 1.05)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.show()
else:
    print("Data (X or df_filtered) is missing. Please run the previous feature extraction step.") #debugging