In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import numpy as np
import os
import torch.nn.init as init

#FIRST VARIANT OF THE HUMAN VISUAL SYSTEM
class ReferenceVisualNetwork(nn.Module):
    """
    A simple feedforward CNN with two parallel pathways (Ventral and Dorsal)
    to serve as the baseline for the assignment.
    """
    def __init__(self):
        super(ReferenceVisualNetwork, self).__init__()
        
        # --- 1. Initial Feature Extraction (Shared V1/V2 - Early Visual Areas) ---
        # Input size for Fashion MNIST: (1, 28, 28)
        # These layers model the initial processing common to both pathways (V1, V2, Area hOc1, Area hOc2).
        self.v1_v2_conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1) # (1, 28, 28) -> (32, 28, 28)
        self.v1_v2_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)                            # (32, 28, 28) -> (32, 14, 14)

        self.v1_v2_conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1) # (32, 14, 14) -> (64, 14, 14)
        self.v1_v2_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)                            # (64, 14, 14) -> (64, 7, 7)
        
        # --- 2. Pathway Split and Processing ---
        
        # 2a. Ventral Pathway (The 'What' Pathway - hOc4d, hOc5)
        # Characterized by more layers/filters for hierarchical feature extraction.
        self.ventral_conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1) # (64, 7, 7) -> (128, 7, 7)
        self.ventral_pool = nn.MaxPool2d(kernel_size=2, stride=2)                              # (128, 7, 7) -> (128, 3, 3)
        
        # Flattened size: 128 * 3 * 3 = 1152
        self.ventral_fc = nn.Linear(128 * 3 * 3, 256)
        
        # 2b. Dorsal Pathway (The 'Where/How' Pathway - hOc3d)
        # Characterized as being potentially shallower and focused on spatial/motion information.
        self.dorsal_conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)  # (64, 7, 7) -> (64, 7, 7)
        self.dorsal_pool = nn.MaxPool2d(kernel_size=3, stride=3)                                # (64, 7, 7) -> (64, 2, 2)
        
        # Flattened size: 64 * 2 * 2 = 256
        self.dorsal_fc = nn.Linear(64 * 2 * 2, 128)
        
        # --- 3. Final Classification Layer ---
        # Total concatenated features: 256 (Ventral) + 128 (Dorsal) = 384
        self.classifier = nn.Linear(256 + 128, 10) # 10 classes for Fashion MNIST

    def forward(self, x):
        
        # 1. Initial Feature Extraction (Shared V1/V2)
        x = self.v1_v2_pool1(F.relu(self.v1_v2_conv1(x)))
        shared_features = self.v1_v2_pool2(F.relu(self.v1_v2_conv2(x)))
        
        # 2. Pathway Split
        # 2a. Ventral Pathway
        v = self.ventral_pool(F.relu(self.ventral_conv(shared_features)))
        v = v.view(v.size(0), -1) # Flatten
        v_out = F.relu(self.ventral_fc(v))
        
        # 2b. Dorsal Pathway
        d = self.dorsal_pool(F.relu(self.dorsal_conv(shared_features)))
        d = d.view(d.size(0), -1) # Flatten
        d_out = F.relu(self.dorsal_fc(d))
        
        # 3. Concatenation and Classification
        combined = torch.cat([v_out, d_out], dim=1)
        output = self.classifier(combined)
        
        # The feature dictionary is needed later for RSA (Step 6).
        return output, {'v1_v2': shared_features, 'ventral_fc': v_out, 'dorsal_fc': d_out}

# --- Example of creating an instance (for testing/setup) ---
# model = ReferenceVisualNetwork()
# dummy_input = torch.randn(64, 1, 28, 28) # Batch size of 64
# output = model(dummy_input)
# print(f"Output shape: {output.shape}") # Should be (64, 10)

In [14]:
file_paths = {
    'V1_left': 'Human_brain_data/Area hOc1 (V1, 17, CalcS) left - cell density/tabular.csv',
    'V1_right': 'Human_brain_data/Area hOc1 (V1, 17, CalcS) right - cell density/tabular.csv',
    'V2_right': 'Human_brain_data/Area hOc2 (V2, 18) right - cell density/tabular.csv',
    'hOc3d_left': 'Human_brain_data/Area hOc3d (Cuneus) left - cell density/tabular.csv',
    'hOc4d_right': 'Human_brain_data/Area hOc4d (Cuneus) right - cell density/tabular.csv',
    'hOc5_right': 'Human_brain_data/Area hOc5 (LOC) right - cell density/tabular.csv',
}

# The column name from your readme file
DENSITY_COL = 'Segmented cell body density (detected cells / 0.1mm3)'

# Base filter count for the least dense region (F_base)
BASE_FILTER_COUNT = 64
# Scaling factor for shared layers (V1_conv1 is half the size of V1_conv2)
SHARED_CONV1_RATIO = 0.5


#CALCULATE AVERAGE DENSITIES

avg_densities = {}
all_densities = []

for region, path in file_paths.items():
    if os.path.exists(path):
        try:
            df = pd.read_csv(path)
            # Calculate the mean density across the 100 measurements
            mean_density = df[DENSITY_COL].mean()
            avg_densities[region] = mean_density
            all_densities.append(mean_density)
        except Exception as e:
            print(f"Error reading {path}: {e}")
            avg_densities[region] = np.nan
    else:
        print(f"File not found for {region}. Please check path: {path}")
        avg_densities[region] = np.nan
        
# Filter out NaNs if files were missing
valid_densities = [d for d in all_densities if not np.isnan(d)]

if not valid_densities:
    raise ValueError("No valid density data was loaded. Please fix file paths.")

# Find the minimum density (D_min) for scaling
MIN_DENSITY = min(valid_densities)

print(f"Average Densities (cells/0.1mm3): {avg_densities}")
print(f"Minimum Density (D_min): {MIN_DENSITY:.2f}")


#MAP DENSITY TO ARCHITECTURAL COMPONENTS\

# 1. Shared V1/V2 (Averaging V1 left, V1 right, V2 right)
shared_v1_v2_avg_density = np.mean([
    avg_densities['V1_left'], avg_densities['V1_right'], avg_densities['V2_right']
])
# 2. Dorsal Pathway (hOc3d left)
dorsal_avg_density = avg_densities['hOc3d_left']
# 3. Ventral Pathway (Averaging hOc4d right, hOc5 right)
ventral_avg_density = np.mean([
    avg_densities['hOc4d_right'], avg_densities['hOc5_right']
])

# Calculate scaling factors
scaling_factors = {
    'Shared_V1_V2': shared_v1_v2_avg_density / MIN_DENSITY,
    'Dorsal': dorsal_avg_density / MIN_DENSITY,
    'Ventral': ventral_avg_density / MIN_DENSITY,
}

#\CALCULATE FINAL FILTER COUNTS\

# Calculate the raw size for V1_V2_CONV2_CHANNELS before rounding
conv2_size_raw = scaling_factors['Shared_V1_V2'] * BASE_FILTER_COUNT

DENSITY_SCALED_FILTERS = {
    # Shared V1/V2 Layers
    # Uses the raw calculated size
    'V1_V2_CONV2_CHANNELS': int(np.round(conv2_size_raw)),
    # Uses the raw calculated size multiplied by 0.5 ratio
    'V1_V2_CONV1_CHANNELS': int(np.round(conv2_size_raw * SHARED_CONV1_RATIO)),
    
    # Dorsal Pathway
    'DORSAL_CONV_CHANNELS': int(np.round(scaling_factors['Dorsal'] * BASE_FILTER_COUNT)),
    'DORSAL_FC_SIZE': int(np.round(scaling_factors['Dorsal'] * (BASE_FILTER_COUNT * 2))), 
    
    # Ventral Pathway
    'VENTRAL_CONV_CHANNELS': int(np.round(scaling_factors['Ventral'] * BASE_FILTER_COUNT)),
    'VENTRAL_FC_SIZE': int(np.round(scaling_factors['Ventral'] * (BASE_FILTER_COUNT * 4))),
}

# Ensure all channel counts are at least 1 and round to nearest power of 2 for CNN efficiency
def round_to_power_of_2(n):
    if n <= 0: return 2
    return int(2**np.round(np.log2(n)))

for key in DENSITY_SCALED_FILTERS:
    if 'CHANNELS' in key:
        # Round convolutional channel counts to powers of 2 (32, 64, 128, etc.)
        DENSITY_SCALED_FILTERS[key] = round_to_power_of_2(DENSITY_SCALED_FILTERS[key])
    elif 'SIZE' in key:
        # Keep FC sizes as general integers, rounding to nearest 32
        DENSITY_SCALED_FILTERS[key] = max(32, int(np.round(DENSITY_SCALED_FILTERS[key] / 32) * 32))

# Recalculate V1_V2_CONV1 based on the FINAL, rounded V1_V2_CONV2 value
DENSITY_SCALED_FILTERS['V1_V2_CONV1_CHANNELS'] = round_to_power_of_2(
    DENSITY_SCALED_FILTERS['V1_V2_CONV2_CHANNELS'] // 2
)


print("\n--- FINAL SCALED FILTER COUNTS FOR CONSTRAINED MODEL ---")
print(DENSITY_SCALED_FILTERS)
print("----------------------------------------------------------")

Average Densities (cells/0.1mm3): {'V1_left': 90.42985758514597, 'V1_right': 90.42985758514597, 'V2_right': 76.51516244077028, 'hOc3d_left': 72.07978610755333, 'hOc4d_right': 61.15705207380484, 'hOc5_right': 63.61690940972587}
Minimum Density (D_min): 61.16

--- FINAL SCALED FILTER COUNTS FOR CONSTRAINED MODEL ---
{'V1_V2_CONV2_CHANNELS': 64, 'V1_V2_CONV1_CHANNELS': 32, 'DORSAL_CONV_CHANNELS': 64, 'DORSAL_FC_SIZE': 160, 'VENTRAL_CONV_CHANNELS': 64, 'VENTRAL_FC_SIZE': 256}
----------------------------------------------------------


In [15]:
V1_CONV1_C = 32    # V1/V2 CONV1 channels
V1_CONV2_C = 64    # V1/V2 CONV2 channels

D_CONV_C = 64       # Dorsal Conv channels (hOc3d)
D_FC_S = 160       # Dorsal FC size

V_CONV_C = 64     # Ventral Conv channels (hOc4d/hOc5)
V_FC_S = 256      # Ventral FC size

#SECOND VARIANT OF THE HUMAN VISUAL SYSTEM
class ConstrainedVisualNetwork(nn.Module):
    """
    A network constrained by human neuroanatomical data (cell density and SC).
    """
    def __init__(self):
        super(ConstrainedVisualNetwork, self).__init__()
        
        # --- 1. Shared V1/V2 (Size constrained by V1/V2 density) ---
        self.v1_v2_conv1 = nn.Conv2d(1, V1_CONV1_C, kernel_size=3, padding=1)
        self.v1_v2_pool1 = nn.MaxPool2d(2, 2)

        self.v1_v2_conv2 = nn.Conv2d(V1_CONV1_C, V1_CONV2_C, kernel_size=3, padding=1)
        self.v1_v2_pool2 = nn.MaxPool2d(2, 2) # Output size: (V1_CONV2_C, 7, 7)
        
        # --- 2. Ventral Pathway (hOc4d, hOc5 - Size constrained by density) ---
        self.ventral_conv = nn.Conv2d(V1_CONV2_C, V_CONV_C, kernel_size=3, padding=1)
        self.ventral_pool = nn.MaxPool2d(2, 2)
        
        # Flattened size: V_CONV_C * 3 * 3
        VENTRAL_FLATTEN_SIZE = V_CONV_C * 3 * 3
        self.ventral_fc = nn.Linear(VENTRAL_FLATTEN_SIZE, V_FC_S)
        
        # --- 3. Dorsal Pathway (hOc3d - Size constrained by density) ---
        self.dorsal_conv = nn.Conv2d(V1_CONV2_C, D_CONV_C, kernel_size=3, padding=1)
        self.dorsal_pool = nn.MaxPool2d(3, 3)
        
        # Flattened size: D_CONV_C * 2 * 2
        DORSAL_FLATTEN_SIZE = D_CONV_C * 2 * 2
        self.dorsal_fc = nn.Linear(DORSAL_FLATTEN_SIZE, D_FC_S)
        
        # --- 4. Structural Connectivity Module (SC) ---
        # This layer transforms V1_CONV1_C features (from V1_conv1 output) to match 
        # the spatial/channel dimensions of the Dorsal input (V1_CONV2_C, 7, 7).
        # Adjust the input/output channels and stride based on your SC data (e.g., V1->hOc3d).
        self.v1_to_dorsal_skip = nn.Conv2d(V1_CONV1_C, V1_CONV2_C, kernel_size=1, stride=4, bias=False) 
        
        # 5. Final Classifier
        self.classifier = nn.Linear(V_FC_S + D_FC_S, 10) 
        
        self._initialize_weights()

    def _initialize_weights(self):
        # Initializes weights for stability, matching the template provided earlier
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, 0, 0.01)
                init.constant_(m.bias, 0)

    def forward(self, x):
        
        # 1. Initial V1 processing (Source for SC skip connection)
        v1_features = F.relu(self.v1_v2_conv1(x)) # (B, V1_CONV1_C, 28, 28)
        x = self.v1_v2_pool1(v1_features)         # (B, V1_CONV1_C, 14, 14)
        
        # 2. V2 Processing (Shared features)
        shared_features = self.v1_v2_pool2(F.relu(self.v1_v2_conv2(x))) # (B, V1_CONV2_C, 7, 7)
        
        # 3. Ventral Pathway
        v = self.ventral_pool(F.relu(self.ventral_conv(shared_features)))
        ventral_conv_out = v 
        
        # --- CRITICAL: STRUCTURAL CONNECTIVITY (SC) IMPLEMENTATION ---
        
        # Prepare SC skip: Transforms V1 features (V1_CONV1_C, 28, 28) 
        # to match the shape of the dorsal input (V1_CONV2_C, 7, 7).
        v1_skip = self.v1_to_dorsal_skip(v1_features) 

        # 4. Dorsal Pathway with SC Constraint
        # Input to dorsal pathway = Shared V2 features + Direct V1 Skip (Residual connection)
        # 
        # !!! CHECK YOUR CONNECTIVITY DATA AND ADJUST THE ADDITION BELOW !!!
        # If your data shows a strong V1 -> hOc3d connection, keep this:
        d_input = shared_features + v1_skip
        
        d = self.dorsal_pool(F.relu(self.dorsal_conv(d_input))) 
        
        # Flatten and FC for both pathways
        v_fc_in = ventral_conv_out.view(ventral_conv_out.size(0), -1)
        v_out = F.relu(self.ventral_fc(v_fc_in))
        
        d_fc_in = d.view(d.size(0), -1)
        d_out = F.relu(self.dorsal_fc(d_fc_in))
        
        # 5. Concatenation and Classification
        combined = torch.cat([v_out, d_out], dim=1)
        output = self.classifier(combined)
        
        return output, {'v1_v2': shared_features, 'ventral_fc': v_out, 'dorsal_fc': d_out}

In [16]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import time

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
LEARNING_RATE = 0.001
N_EPOCHS = 10 

def load_fashion_mnist():
    """Loads and prepares the Fashion MNIST dataset."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Download and load training dataset
    train_dataset = torchvision.datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=transform
    )

    # Split training set into training and validation
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_data, val_data = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
    
    # Download and load test dataset
    test_dataset = torchvision.datasets.FashionMNIST(
        root='./data', train=False, download=True, transform=transform
    )
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, val_loader, test_loader

def train_model(model, train_loader, val_loader, epochs=N_EPOCHS, lr=LEARNING_RATE, device=DEVICE):
    """
    Trains the given model and tracks loss/accuracy.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    
    print(f"Starting training for {type(model).__name__} on {device}...")
    start_time = time.time()
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            # The forward pass now returns output AND features
            outputs, _ = model(inputs) 
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
        avg_train_loss = running_loss / len(train_loader)
        
        # Validation
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f"Epoch {epoch+1}/{epochs}: Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    end_time = time.time()
    print(f"Training finished in {(end_time - start_time):.2f} seconds.")
    return model, history

def evaluate_model(model, data_loader, criterion, device):
    """Evaluates the model on the given data loader."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(data_loader)
    accuracy = correct / total
    return avg_loss, accuracy

print("Training utility functions defined.")

Training utility functions defined.


In [17]:

# 1. Setup and Data Loading
train_loader, val_loader, test_loader = load_fashion_mnist() # This function was provided earlier

# 2. Train the Reference Model
print("--- Training Reference Network ---")
ref_model = ReferenceVisualNetwork()
ref_model_trained, ref_history = train_model(ref_model, train_loader, val_loader)
torch.save(ref_model_trained.state_dict(), 'reference_model.pth') # Save the model

# 3. Train the Constrained Model
print("\n--- Training Constrained Network ---")
constrained_model = ConstrainedVisualNetwork() 
constrained_model_trained, const_history = train_model(constrained_model, train_loader, val_loader)
torch.save(constrained_model_trained.state_dict(), 'constrained_model.pth') # Save the model

print("\nModels trained and saved. Ready for Step 6: RSA.")

100%|██████████| 26.4M/26.4M [00:05<00:00, 5.10MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 1.10MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.59MB/s]
100%|██████████| 5.15k/5.15k [00:00<?, ?B/s]


--- Training Reference Network ---
Starting training for ReferenceVisualNetwork on cpu...


ValueError: too many values to unpack (expected 2)