In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import os

# HDF5 file holding the 'eeg_train', 'fmri_train', 'eeg_test' and 'fmri_test'
# arrays read below. Override with the EEG_FMRI_DATA_PATH environment variable
# when the dataset lives elsewhere.
# NOTE(review): the default filename contains a space ("01 eeg_fmri_data.h5")
# and the saved run could not open it -- verify the real filename.
data_path = os.environ.get(
    'EEG_FMRI_DATA_PATH',
    '/home/aca10131kr/eeg-to-fmri/Datasets/01 eeg_fmri_data.h5',
)

# Define a simple CNN model for EEG to fMRI conversion
class SimpleCNN(nn.Module):
    """Two-stage 3-D CNN mapping an EEG volume to an fMRI-shaped output.

    Layout: [Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d(2)] x 2, flatten per
    sample, Linear(512) -> ReLU -> Dropout, Linear, reshape to ``out_shape``.

    Args:
        in_channels: number of input channels (dim 1 of the input). Default 64.
        flat_features: features per sample after the second pooling stage,
            i.e. 16 * D' * H' * W'. Each MaxPool3d(2, stride 2) floor-halves
            every spatial axis, so with the default 16 * 67 * 2 * 2 the input
            must satisfy D // 4 == 67 and H // 4 == W // 4 == 2.
        out_shape: per-sample output shape. Default (64, 64, 30).
        verbose: if True, print the tensor shape after every stage (debugging).
    """

    def __init__(self, in_channels=64, flat_features=16 * 67 * 2 * 2,
                 out_shape=(64, 64, 30), verbose=False):
        super().__init__()
        self.out_shape = tuple(out_shape)
        self.verbose = verbose
        # Submodules are registered in the same order as before, so
        # parameters() order (and init RNG consumption) is unchanged.
        self.conv1 = nn.Conv3d(in_channels, 32, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0)
        self.conv2 = nn.Conv3d(32, 16, kernel_size=(3, 3, 3), stride=1, padding=1)
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, torch.Size(self.out_shape).numel())
        self.dropout = nn.Dropout(p=0.5)
        self.batch_norm1 = nn.BatchNorm3d(32)
        self.batch_norm2 = nn.BatchNorm3d(16)

    def _show(self, label, x):
        """Print ``label: shape`` when verbose shape tracing is enabled."""
        if self.verbose:
            print(f'{label}: {x.shape}')

    def forward(self, x):
        """Map a batch of inputs to outputs of shape (batch, *out_shape).

        Args:
            x: 5-D tensor (batch, in_channels, D, H, W); BatchNorm3d rejects
                any other rank.

        Raises:
            ValueError: if the spatial size of ``x`` does not reduce to
                ``flat_features`` features per sample after both poolings.
        """
        self._show('Input shape', x)
        x = self.pool(F.relu(self.batch_norm1(self.conv1(x))))
        self._show('After conv1 and pool', x)
        x = self.pool(F.relu(self.batch_norm2(self.conv2(x))))
        self._show('After conv2 and pool', x)
        # Flatten per sample, keeping the batch axis. A view(-1, N) here would
        # silently fold surplus features into the batch dimension when the
        # input size is off, instead of reporting the mismatch.
        x = torch.flatten(x, start_dim=1)
        if x.shape[1] != self.fc1.in_features:
            raise ValueError(
                f'Expected {self.fc1.in_features} features per sample after '
                f'pooling, got {x.shape[1]}; adjust flat_features to the input size'
            )
        self._show('After flattening', x)
        x = self.dropout(F.relu(self.fc1(x)))
        self._show('After fc1', x)
        x = self.fc2(x)
        x = x.view(x.size(0), *self.out_shape)
        self._show('Output shape', x)
        return x

# Load the full train/test splits from the HDF5 file into memory.
try:
    h5_file = h5py.File(data_path, 'r')
except FileNotFoundError as err:
    # h5py only reports that the open failed; list the directory contents so
    # a misspelled or moved file is easy to spot.
    data_dir = os.path.dirname(data_path) or '.'
    try:
        available = sorted(os.listdir(data_dir))
    except OSError:
        available = '<directory missing or unreadable>'
    raise FileNotFoundError(
        f'HDF5 dataset not found: {data_path!r}. Contents of {data_dir!r}: {available}'
    ) from err

with h5_file as f:
    # Slicing an h5py dataset with [:] already returns a fresh NumPy array;
    # wrapping it in np.array() would make a second full copy.
    eeg_train = f['eeg_train'][:]
    fmri_train = f['fmri_train'][:]
    eeg_test = f['eeg_test'][:]
    fmri_test = f['fmri_test'][:]

# Define the dataset
class EEGfMRIDataset(Dataset):
    """Map-style dataset pairing the idx-th EEG sample with the idx-th fMRI sample.

    Args:
        eeg_data: indexable, sized container of EEG samples.
        fmri_data: indexable, sized container of fMRI samples, aligned by
            position with ``eeg_data``.

    Raises:
        ValueError: if the two containers hold a different number of samples.
    """

    def __init__(self, eeg_data, fmri_data):
        # Samples are paired purely by index. A length mismatch would otherwise
        # silently drop trailing fMRI samples, or raise IndexError mid-epoch.
        if len(eeg_data) != len(fmri_data):
            raise ValueError(
                f'eeg_data and fmri_data must have the same length, '
                f'got {len(eeg_data)} and {len(fmri_data)}'
            )
        self.eeg_data = eeg_data
        self.fmri_data = fmri_data

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        """Return the ``(eeg, fmri)`` pair at position ``idx``."""
        return self.eeg_data[idx], self.fmri_data[idx]

# Create the datasets and data loaders.
# torch.as_tensor reuses the NumPy buffer when an array is already float32 on
# the CPU, whereas torch.tensor() always makes an extra full copy. Other dtypes
# are converted, which copies either way.
batch_size = 1
train_dataset = EEGfMRIDataset(
    torch.as_tensor(eeg_train, dtype=torch.float32),
    torch.as_tensor(fmri_train, dtype=torch.float32),
)
test_dataset = EEGfMRIDataset(
    torch.as_tensor(eeg_test, dtype=torch.float32),
    torch.as_tensor(fmri_test, dtype=torch.float32),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Pick the compute device (GPU when available), then build the network,
# the mean-squared-error objective and an Adam optimiser over its parameters.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = SimpleCNN()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
log_interval = 10  # number of batches averaged into each progress line
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    window = 0  # batches accumulated in running_loss since the last report
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        # squeeze(-1) drops the last axis only when it has size 1; otherwise no-op.
        labels = labels.squeeze(-1)
        optimizer.zero_grad()
        outputs = model(inputs)
        # nn.MSELoss only warns and then broadcasts when shapes differ but are
        # broadcast-compatible, which yields a meaningless loss; fail fast.
        if outputs.shape != labels.shape:
            raise ValueError(
                f'Model output shape {tuple(outputs.shape)} does not match '
                f'target shape {tuple(labels.shape)}'
            )
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        window += 1
        if window == log_interval:
            print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / window:.3f}')
            running_loss = 0.0
            window = 0
    # Report the tail of the epoch that did not fill a whole logging window;
    # otherwise those batches would never show up in the log.
    if window:
        print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / window:.3f}')

print('Finished Training')

# Testing loop (optional): evaluation mode (dropout off, BatchNorm running
# stats) with gradient tracking disabled.
model.eval()
test_loss = 0.0  # sum over samples of each sample's mean squared error
n_samples = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        labels = labels.squeeze(-1)  # drop a trailing size-1 axis, if present
        outputs = model(inputs)
        if outputs.shape != labels.shape:
            raise ValueError(
                f'Model output shape {tuple(outputs.shape)} does not match '
                f'target shape {tuple(labels.shape)}'
            )
        loss = criterion(outputs, labels)
        # Weight each batch mean by its size so a smaller final batch does not
        # skew the average (identical to the plain mean when batch_size == 1).
        batch = inputs.size(0)
        test_loss += loss.item() * batch
        n_samples += batch

if n_samples:
    print(f'Test Loss: {test_loss / n_samples:.3f}')
else:
    # Avoid ZeroDivisionError when the test split is empty.
    print('Test Loss: n/a (empty test set)')


FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '/home/aca10131kr/eeg-to-fmri/Datasets/01 eeg_fmri_data.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)