In [6]:
import torch
import gradio as gr
import plotly.graph_objects as go
from nilearn import datasets
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import MultiNiftiMapsMasker
import numpy as np

# Inference runs on CPU. To use a GPU when one is available, swap in:
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Load the TorchScript model. torch.jit.load returns a ScriptModule (never a
# DataParallel wrapper), so there is nothing to unwrap; map_location already
# places its weights on `device`.
try:
    scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location=device)
except Exception as e:
    # SystemExit with a message prints it to stderr and exits with status 1;
    # unlike the interactive-only exit() builtin it is meant for programs.
    raise SystemExit(f"Error loading model: {e}") from e

scripted_model.to(device)
scripted_model.eval()  # inference mode (disables dropout / batch-norm updates)

# Fetch the DiFuMo atlas (64 components at 2 mm). nilearn fetchers may download
# the atlas on first use, so this can fail without network access.
dim = 64
try:
    difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
except Exception as e:
    # Exit with status 1 and the message on stderr; exit() is for interactive use only.
    raise SystemExit(f"Error fetching atlas: {e}") from e

# Path/image of the atlas maps consumed by the masker below.
atlas_filename = difumo.maps

# Region-signal extractor: projects a 4-D fMRI image onto the atlas maps,
# yielding one time series per atlas component (standardize=True normalizes
# the extracted signals per nilearn's convention).
masker = MultiNiftiMapsMasker(
    maps_img=atlas_filename,
    standardize=True,
    n_jobs=-1,
    verbose=0,
)

# Turns one subject's region time series into a correlation matrix, flattened
# to a feature vector with the diagonal left out.
connectome_measure = ConnectivityMeasure(
    kind="correlation",
    vectorize=True,
    discard_diagonal=True,
)

# Feature extraction function
def extract_features_multiple(func_preproc_files):
    """Convert each preprocessed fMRI input into a connectivity feature array.

    The module-level ``masker`` is fitted on the first input only, then every
    input is transformed with it and passed through ``connectome_measure``.
    Progress messages are printed to stdout.

    Args:
        func_preproc_files: sequence of fMRI images/paths accepted by the masker.

    Returns:
        list: one feature array per input, in input order; an empty list when
        no inputs are given.
    """
    if not func_preproc_files:
        return []

    print("Fitting masker on the first subject...")
    masker.fit(func_preproc_files[0])

    features = []
    for subject_number, subject_img in enumerate(func_preproc_files, start=1):
        print(f"Processing subject {subject_number}...")
        region_signals = masker.transform(subject_img)
        # fit_transform takes a list of subjects; wrap this one and take its row.
        connectivity = connectome_measure.fit_transform([region_signals])[0]
        features.append(connectivity)

    print("All subjects processed.")
    return features

# Function to generate a Plotly probability plot
def plot_probability(probability):
    """Build a two-bar Plotly chart of P(No Autism) versus P(Autism).

    Args:
        probability: model probability for the "Autism" class, expected to lie
            in [0, 1] (predict_autism passes a sigmoid output).

    Returns:
        plotly.graph_objects.Figure: black-background bar chart whose bars are
        labelled with their percentages.
    """
    labels = ["No Autism", "Autism"]
    probs = [1 - probability, probability]
    colors = ["#6a0dad", "#d896ff"]  # Dark purple and light purple

    fig = go.Figure(
        go.Bar(
            x=labels,
            y=probs,
            marker=dict(color=colors),
            # `.1%` multiplies by 100 itself, so labels always match bar heights.
            text=[f"{p:.1%}" for p in probs],
            textposition="auto",
        )
    )

    fig.update_layout(
        title="Autism Prediction Probability",
        paper_bgcolor="black",
        plot_bgcolor="black",
        font=dict(color="white"),
        xaxis=dict(title="Diagnosis", showgrid=False),
        # Pin the axis to the full probability scale; with autoscaling the taller
        # bar always reaches the top of the plot regardless of its actual value.
        yaxis=dict(
            title="Probability",
            showgrid=True,
            gridcolor="gray",
            range=[0, 1],
            tickformat=".0%",
        ),
    )

    return fig

# Prediction function
def predict_autism(fmri_files, age, gender):
    try:
        if not fmri_files:
            return "Please upload at least one valid .nii.gz file.", None
        
        features_list = extract_features_multiple(fmri_files)
        if not features_list:
            return "Error: Failed to extract features from the fMRI files.", None
        
        age_tensor = torch.tensor([float(age)], dtype=torch.float32).to(device)
        gender_tensor = torch.tensor([int(gender)], dtype=torch.long).to(device)

        predictions = []
        plots = []

        for features in features_list:
            features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)

            with torch.no_grad():
                prediction = scripted_model(features_tensor, age_tensor, gender_tensor)
                probability = torch.sigmoid(prediction).item()
            
            result = f"Prediction: {'Autism' if probability > 0.5 else 'No Autism'} (Confidence: {probability:.2%})"
            predictions.append(result)

            # Generate Plotly probability plot
            plots.append(plot_probability(probability))
        
        return "\n".join(predictions), plots[0]  # Return text and Plotly figure

    except Exception as e:
        return f"Error: {str(e)}", None

# Gradio interface
iface = gr.Interface(
    fn=predict_autism,
    inputs=[
        gr.File(label="Upload preprocessed fMRI files (.nii.gz)", file_count="multiple"),
        gr.Number(label="Age", minimum=0, maximum=120),
        gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
    ],
    outputs=[
        gr.Text(label="Prediction Result"),
        gr.Plot(label="Prediction Probability Plot"),
    ],
    title="Autism Prediction from fMRI Data",
    description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
    theme="default",
    flagging_mode="never"
)

iface.launch()


* Running on local URL:  http://127.0.0.1:7865

To create a public link, set `share=True` in `launch()`.




Fitting masker on the first subject...
Processing subject 1...
All subjects processed.
