# This notebook uses a Convolutional Neural Net (CNN) to predict Grid Cell alignment in real time using VR trajectory as labels.

In [19]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import os
import glob

In [3]:
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold, train_test_split

## 1. Load and preprocess the input NIFTI images

In [20]:
# Set paths and parameters
data_dir = r'C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\preprocessed\s05'
behavioral_dir = r'C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset\s05\BehavioralData_s05'
time_interval = 1.5  # Time interval between images in seconds

# Custom Dataset for loading NIFTI images and labels
class BrainDataset(Dataset):
    def __init__(self, run_dirs, behavioral_dir, time_interval):
        self.images = []
        self.labels = []

        # Loop through each run directory
        for run_dir in run_dirs:
            # Load NIFTI images
            nii_files = sorted(os.listdir(os.path.join(run_dir, 'masked_outputs')))
            run_images = [nib.load(os.path.join(run_dir, 'masked_outputs', f)).get_fdata() for f in nii_files]

            # Load and process behavioral data
            for run_dir in run_dirs:
                # Extract the run identifier (e.g., 'run001')
                run_base = os.path.basename(run_dir).split('_')[0]
                
                # Use glob to search for files containing the run identifier in the behavioral directory
                search_pattern = os.path.join(behavioral_dir, f"*{run_base}*.tsv")
                behavioral_files = glob.glob(search_pattern)
                
                # Check if any matching files were found
                if behavioral_files:
                    behavioral_file = behavioral_files[0]  # Assuming there's only one matching file per run
                    run_behavioral_data = pd.read_csv(behavioral_file, sep='\t')
                    # Continue processing the loaded data...
                    print(f"Loaded behavioral data from: {behavioral_file}")
                else:
                    print(f"No behavioral data file found for {run_base}. Searched pattern: {search_pattern}")
        
            # Extract orientation values and synchronize with images
            orientations = run_behavioral_data['Orientation'].values
            timestamps = run_behavioral_data['Time'].values
            time_points = np.arange(0, time_interval * len(run_images), time_interval)
            labels = np.interp(time_points, timestamps, orientations)

            # Append images and labels
            self.images.extend(run_images)
            self.labels.extend(labels)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.tensor(self.images[idx], dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, label

In [21]:
run_name = 'run001_8'
run_base = run_name.split('_')[0]  # Splits at '_' and takes the first part
print(run_base)


run001


In [22]:
# CNN Model Definition
class GridCellCNN(nn.Module):
    def __init__(self):
        super(GridCellCNN, self).__init__()
        self.conv1 = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 12 * 12 * 12, 128)  # Update based on input size
        self.fc2 = nn.Linear(128, 1)  # Regression for orientation

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 12 * 12 * 12)  # Flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [23]:
# Train and evaluate the CNN
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.view(-1), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')


## 2. Load and Preprocess the VR trajectory data:

In [24]:
# Load and preprocess data for subject
run_dirs = [os.path.join(data_dir, d) for d in os.listdir(data_dir) if d.startswith('run')]
dataset = BrainDataset(run_dirs, behavioral_dir, time_interval)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

Loaded behavioral data from: C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset\s05\BehavioralData_s05\s05_indoor_NoBarrier_run001.tsv
Loaded behavioral data from: C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset\s05\BehavioralData_s05\s05_indoor_NoBarrier_run002.tsv
Loaded behavioral data from: C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset\s05\BehavioralData_s05\s05_indoor_NoBarrier_run003.tsv
Loaded behavioral data from: C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset\s05\BehavioralData_s05\s05_indoor_NoBarrier_run004.tsv
Loaded behavioral data from: C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset\s05\BehavioralData_s05\s05_indoor_NoBarrier_r

## 3. Set up cross-validation: 

## 4. Build and train the CNN:

In [25]:
# Initialize CNN model, loss function, and optimizer
model = GridCellCNN()
criterion = nn.MSELoss()  # Mean squared error for regression
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [26]:
# Train the model
train_model(model, train_loader, criterion, optimizer)

RuntimeError: shape '[-1, 110592]' is invalid for input of size 737280

In [None]:
# External cross-validation (for future use on all subjects)
def cross_validate_subjects(subject_ids, data_dir):
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    for train_idx, test_idx in kf.split(subject_ids):
        train_subjects = [subject_ids[i] for i in train_idx]
        test_subjects = [subject_ids[i] for i in test_idx]
        print(f"Train on {train_subjects}, Test on {test_subjects}")
        # Implement training/testing for each fold