<a href="https://colab.research.google.com/github/koleshwargajanan/NeurosymbolicAIForFlowStateOfTheBrain/blob/main/NeurosymbolicAIForFlowStateOfTheBrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim

# -------------------------------
# Configuration
# -------------------------------
SEED = 0               # seeds dummy data + weight init so printed losses are reproducible
NUM_SUBJECTS = 32
NUM_ROIS = 68          # 68 cortical ROIs = 34 per hemisphere (Desikan-Killiany layout; the DKT atlas has 62)
TIME_STEPS = 120       # fMRI time points per ROI
IMG_SIZE = (64, 64, 64)  # spatial size of each structural volume
NUM_CLASSES = 2
LAMBDA_SYM = 0.2       # weight of the hemispheric-symmetry penalty in the total loss
ROI_EMBED_DIM = 16     # size of each per-ROI embedding produced by the fMRI encoder

# -------------------------------
# Dummy HCP-Style Inputs
# -------------------------------
# Random placeholders with the shapes the model expects; replace with real data.
torch.manual_seed(SEED)
mri_data = torch.randn(NUM_SUBJECTS, 1, *IMG_SIZE)            # (subjects, 1 channel, *IMG_SIZE)
fmri_data = torch.randn(NUM_SUBJECTS, NUM_ROIS, TIME_STEPS)  # (subjects, ROIs, time)
labels = torch.randint(0, NUM_CLASSES, (NUM_SUBJECTS,))      # one class index per subject

# =========================================================
# CNN for Volumetric MRI (Structural Representation)
# =========================================================
class MRI_CNN(nn.Module):
    """Small 3-D CNN that maps a volumetric MRI scan to a feature vector.

    Two convolution blocks (the first followed by 2x max-pooling) and a global
    adaptive average pool collapse the spatial dimensions, so the feature size
    is ``out_channels`` regardless of the input volume size.

    Args:
        in_channels: Channels of the input volume (1 for a single contrast).
        hidden_channels: Channels produced by the first convolution.
        out_channels: Channels of the second convolution, i.e. the feature size.
    """

    def __init__(self, in_channels=1, hidden_channels=8, out_channels=16):
        super().__init__()
        # Public so downstream heads can size themselves instead of hard-coding 16.
        self.out_dim = out_channels
        self.net = nn.Sequential(
            nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(hidden_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),
        )

    def forward(self, x):
        """Return features of shape (batch, out_dim) for x of shape (batch, C, D, H, W)."""
        # AdaptiveAvgPool3d(1) leaves (batch, out_dim, 1, 1, 1); drop the singleton dims.
        return torch.flatten(self.net(x), start_dim=1)

# =========================================================
# ROI-wise RNN for fMRI Time-Series
# =========================================================
class ROI_RNN(nn.Module):
    """Encode every ROI's fMRI time series into a fixed-size embedding.

    A single GRU, with weights shared across ROIs, runs over the *time* axis of
    each ROI independently. Its final hidden state is that ROI's embedding.

    Recurring over time rather than over the ROI axis has two consequences:
    an ROI's embedding depends only on its own signal, not on the ROIs that
    precede it in index order, and series of any length are accepted.

    Args:
        embed_dim: Size of each ROI embedding (the GRU hidden size). Defaults
            to ROI_EMBED_DIM when omitted.
    """

    def __init__(self, embed_dim=None):
        super().__init__()
        self.embed_dim = ROI_EMBED_DIM if embed_dim is None else embed_dim
        # One scalar sample per time step: the GRU's sequence axis is time.
        self.rnn = nn.GRU(
            input_size=1,
            hidden_size=self.embed_dim,
            batch_first=True,
        )

    def forward(self, x):
        """Map x of shape (subjects, ROIs, time) to (subjects, ROIs, embed_dim)."""
        subjects, rois, steps = x.shape
        # Treat each (subject, ROI) pair as an independent sequence of scalars.
        sequences = x.reshape(subjects * rois, steps, 1)
        # h_n is (num_layers=1, subjects*rois, embed_dim) regardless of batch_first.
        _, h_n = self.rnn(sequences)
        return h_n[-1].reshape(subjects, rois, self.embed_dim)

# =========================================================
# Neuro-Symbolic Model
# =========================================================
class NeuroSymbolicModel(nn.Module):
    """Fuse structural (MRI) and functional (fMRI) features for classification.

    ``forward`` returns the class logits together with the per-ROI functional
    embeddings. The embeddings are exposed so that a symbolic penalty can be
    applied to them outside the model.
    """

    def __init__(self):
        super().__init__()
        self.mri_model = MRI_CNN()
        self.roi_model = ROI_RNN()
        mri_width = 16  # feature width emitted by MRI_CNN's last conv
        self.classifier = nn.Linear(mri_width + ROI_EMBED_DIM, NUM_CLASSES)

    def forward(self, mri, fmri):
        """Return ``(logits, roi_feat)``.

        ``logits`` has one row per subject and NUM_CLASSES columns.
        ``roi_feat`` is the unpooled per-ROI output of the fMRI encoder.
        """
        structural = self.mri_model(mri)
        per_roi = self.roi_model(fmri)
        # Average across the ROI axis to get one functional vector per subject.
        functional = torch.mean(per_roi, dim=1)
        fused = torch.cat((structural, functional), dim=1)
        return self.classifier(fused), per_roi

# =========================================================
# Symbolic Constraint: Hemispheric Symmetry
# =========================================================
# Hemisphere split of the ROI axis.
# NOTE(review): assumes ROIs 0-33 are left hemisphere and 34-67 are right
# hemisphere. Confirm this against the atlas ordering used to build the fMRI
# data, and keep it consistent with NUM_ROIS.
LEFT_ROIS = list(range(0, 34))
RIGHT_ROIS = list(range(34, 68))


def symbolic_constraint(roi_features, left_rois=None, right_rois=None):
    """Penalise asymmetry between the mean left and mean right ROI embeddings.

    Each hemisphere's embeddings are averaged over its ROIs. The squared
    difference between the two averages is then averaged over subjects and
    embedding dimensions.

    Args:
        roi_features: Tensor of shape (subjects, ROIs, embedding_dim).
        left_rois: ROI indices of the left hemisphere; defaults to LEFT_ROIS.
        right_rois: ROI indices of the right hemisphere; defaults to RIGHT_ROIS.

    Returns:
        A differentiable scalar tensor (the mean squared difference).

    Raises:
        ValueError: if either index set is empty. Averaging over zero ROIs
            would otherwise yield a silent NaN loss.
    """
    left = LEFT_ROIS if left_rois is None else left_rois
    right = RIGHT_ROIS if right_rois is None else right_rois
    for label, indices in (("left_rois", left), ("right_rois", right)):
        if len(indices) == 0:
            raise ValueError(f"{label} must contain at least one ROI index")
    left_repr = roi_features[:, left, :].mean(dim=1)
    right_repr = roi_features[:, right, :].mean(dim=1)
    return torch.mean((left_repr - right_repr) ** 2)

# =========================================================
# Model, Loss, Optimizer
# =========================================================
model = NeuroSymbolicModel()
task_loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


# =========================================================
# Training Step (Single Iteration)
# =========================================================
def _training_step(net, mri, fmri, targets):
    """Run one optimiser update on ``net`` and return everything it computed.

    The objective is the task cross-entropy plus ``LAMBDA_SYM`` times the
    hemispheric-symmetry penalty on the ROI embeddings.

    Returns:
        ``(logits, roi_embeddings, task_loss, symbolic_loss, total_loss)``,
        all from the forward pass taken before the parameter update.
    """
    net.train()
    optimizer.zero_grad()
    logits, roi_embeddings = net(mri, fmri)
    ce = task_loss_fn(logits, targets)
    sym = symbolic_constraint(roi_embeddings)
    objective = ce + LAMBDA_SYM * sym
    objective.backward()
    optimizer.step()
    return logits, roi_embeddings, ce, sym, objective


outputs, roi_features, task_loss, symbolic_loss, total_loss = _training_step(
    model, mri_data, fmri_data, labels
)

# =========================================================
# Outputs
# =========================================================
print(f"Task Loss: {task_loss.item():.4f}")
print(f"Symbolic Loss: {symbolic_loss.item():.4f}")
print(f"Total Loss: {total_loss.item():.4f}")

# Flow-state probability example.
# These probabilities come from the logits of the training forward pass, i.e.
# from the parameters as they were *before* optimizer.step().
# Detach so the printed tensor is a plain value rather than a node of the
# already back-propagated autograd graph (no grad_fn in the printout).
# NOTE(review): which class index means "flow" is not defined here — confirm.
flow_probabilities = torch.softmax(outputs.detach(), dim=1)
print("Flow-state probability (first subject):", flow_probabilities[0])






Task Loss: 0.6835
Symbolic Loss: 0.0442
Total Loss: 0.6924
Flow-state probability (first subject): tensor([0.5744, 0.4256], grad_fn=<SelectBackward0>)
