In [2]:
import pickle
import numpy as np

In [23]:
import numpy as np
import pickle

# Set random seed for reproducibility
np.random.seed(1337)

# Define paths
image_path = '/home/maria/Documents/HarvardData/Images'
session_ims = pickle.load(open('/home/maria/Documents/HarvardData/processed_sessions_v3/Bo220226/session_images.p','rb'))

# Construct full image paths
image_paths = np.array([f"{image_path}/{im.split('/')[2]}" for im in session_ims])

# Total number of images
n_total = len(session_ims)
print(f"Total number of images: {n_total}")

# Define the number of training samples
n_train = 1000

# Ensure that n_train does not exceed n_total
if n_train > n_total:
    raise ValueError("Number of training samples exceeds the total number of available images.")

# Randomly select unique training indices without replacement
training_path_inds = np.random.choice(n_total, size=n_train, replace=False)
training_paths = image_paths[training_path_inds]

# Determine test indices as those not in training_path_inds
test_inds = np.setdiff1d(np.arange(n_total), training_path_inds)
test_paths = image_paths[test_inds]

# Print shapes to verify
print(f"Training indices shape: {training_path_inds.shape}")  # Should be (1000,)
print(f"Number of test samples: {len(test_paths)}")           # Should be n_total - 1000

# Optional: Verify no overlap between training and test sets
overlap = np.intersect1d(training_paths, test_paths)
print(f"Number of overlapping images between training and test sets: {len(overlap)}")  # Should be 0


Total number of images: 1250
Training indices shape: (1000,)
Number of test samples: 250
Number of overlapping images between training and test sets: 0


In [26]:
import os
import numpy as np
import pickle
from PIL import Image
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from torchvision import transforms
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# ================================
# 1. Setup and Image Path Handling
# ================================

# Set random seed for reproducibility
np.random.seed(1337)
torch.manual_seed(1337)

# Define base image directory and possible extensions
base_image_dir = '/home/maria/Documents/HarvardData/Images'
possible_extensions = ['.jpg', '.JPG', '.png', '.PNG']

# Load session image identifiers
session_ims = pickle.load(open('/home/maria/Documents/HarvardData/processed_sessions_v3/Bo220226/session_images.p','rb'))

# Function to find the correct image path with existing extension
def find_image_path(session_ims, base_image_dir, possible_extensions):
    image_paths = []
    for p in session_ims:
        # Extract the filename after 'OOD_monkey_data/Images/'
        if 'OOD_monkey_data/Images/' in p:
            filename = p.split('OOD_monkey_data/Images/')[-1]
        else:
            filename = os.path.basename(p)  # Fallback to basename
        
        base_name, ext = os.path.splitext(filename)
        
        # Try each possible extension until a file is found
        found_path = None
        for ext_candidate in possible_extensions:
            candidate_path = os.path.join(base_image_dir, base_name + ext_candidate)
            if os.path.exists(candidate_path):
                found_path = candidate_path
                break
        
        if found_path is None:
            # If no matching file is found, warn and skip
            print(f"Warning: No matching file found for base name: {base_name}")
            # Optionally, append a placeholder or handle as needed
            # image_paths.append('/path/to/placeholder.jpg') # Uncomment if using placeholders
        else:
            image_paths.append(found_path)
    
    return np.array(image_paths)

# Get the array of valid image paths
image_paths = find_image_path(session_ims, base_image_dir, possible_extensions)
print(f"Total valid images found: {len(image_paths)}")

# ================================
# 2. Splitting the Dataset
# ================================

# Define the number of training samples
n_train = 1000

# Ensure that n_train does not exceed the total number of images
if n_train > len(image_paths):
    raise ValueError("Number of training samples exceeds the total number of available images.")

# Generate unique training indices without replacement
training_path_inds = np.random.choice(len(image_paths), size=n_train, replace=False)
training_paths = image_paths[training_path_inds]

# Determine test indices as the complement of training indices
test_inds = np.setdiff1d(np.arange(len(image_paths)), training_path_inds)
test_paths = image_paths[test_inds]

print(f"Number of training samples: {len(training_paths)}")  # Should be 1000
print(f"Number of test samples: {len(test_paths)}")          # Should be len(image_paths) - 1000

# Optional: Verify no overlap
overlap = np.intersect1d(training_paths, test_paths)
print(f"Number of overlapping images between training and test sets: {len(overlap)}")  # Should be 0

# ================================
# 3. Data Augmentation
# ================================

# Define your data augmentation pipeline
augmentation_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)  # Adjust based on image channels
])

# Function to load and augment images
def load_and_augment(image_paths, num_augmentations=5):
    augmented_X = []
    for path in image_paths:
        try:
            image = Image.open(path).convert('RGB')  # Ensure 3 channels
            image_np = np.array(image)
            for _ in range(num_augmentations):
                augmented_image = augmentation_transform(image_np)
                augmented_X.append(augmented_image.numpy())
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            # Optionally, handle corrupted images or skip
    return np.array(augmented_X)

# Apply augmentations to training images
num_augmentations = 5  # Number of augmentations per image
augmented_X = load_and_augment(training_paths, num_augmentations=num_augmentations)
print(f"Number of augmented images: {augmented_X.shape[0]}")  # Should be 1000 * 5 = 5000

# ================================
# 4. Preparing Neural Response Data
# ================================

# Load neural response data
# Replace the following lines with your actual data loading mechanism
# For example, if neural data is stored in a pickle file:
# neural_data = pickle.load(open('/path/to/neural_data.p','rb'))

# Placeholder neural data
# Assuming y_train has shape [n_train, n_neurons]
# Replace this with your actual neural data
n_neurons = 64  # Example number of neurons
y_train = np.random.rand(n_train, n_neurons)  # Replace with actual data

# Similarly, load neural responses for test set
# y_test = ...  # Shape: [n_test, n_neurons]
y_test = np.random.rand(len(test_paths), n_neurons)  # Replace with actual data

# For augmented data, assuming neural responses are invariant
augmented_y = np.repeat(y_train, repeats=num_augmentations, axis=0)  # Shape: [n_train * num_augmentations, n_neurons]
print(f"Shape of augmented_y: {augmented_y.shape}")  # Should be [5000, 64]

# ================================
# 5. Combining Original and Augmented Data
# ================================

# Convert original training images to tensors and normalize similarly
def preprocess_image_original(path):
    try:
        image = Image.open(path).convert('RGB')
        image_np = np.array(image)
        image_tensor = augmentation_transform(image_np)
        return image_tensor.numpy()
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return np.zeros((3, 224, 224))  # Example placeholder, adjust as needed

# Load and preprocess original training images
original_X = []
for path in training_paths:
    original_X.append(preprocess_image_original(path))
original_X = np.array(original_X)
print(f"Number of original training images after preprocessing: {original_X.shape[0]}")  # Should be 1000

# Combine original and augmented data
X_combined = np.vstack((original_X, augmented_X))  # Shape: [6000, C, H, W]
y_combined = np.vstack((y_train, augmented_y))    # Shape: [6000, n_neurons]

print(f"Combined training data shape: {X_combined.shape}, Combined labels shape: {y_combined.shape}")

# ================================
# 6. Creating PyTorch Datasets and DataLoaders
# ================================

# Convert combined data to PyTorch tensors
X_combined_tensor = torch.tensor(X_combined, dtype=torch.float32)
y_combined_tensor = torch.tensor(y_combined, dtype=torch.float32)

# Create a TensorDataset
dataset = TensorDataset(X_combined_tensor, y_combined_tensor)

# Assign higher weights to original data to prioritize labeled samples
weights = torch.ones(len(dataset))
weights[:n_train] = 2.0  # Higher weight for original data

# Create a WeightedRandomSampler
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# Define DataLoader
batch_size = 64
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# ================================
# 7. Defining the Mixture of Experts (MoE) Model
# ================================

class MixtureOfExperts(nn.Module):
    def __init__(self, input_size, num_experts, hidden_size, output_size, dropout_rate=0.3):
        super(MixtureOfExperts, self).__init__()
        self.num_experts = num_experts
        # Define experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_size, output_size)
            )
            for _ in range(num_experts)
        ])
        # Define gating network
        self.gate = nn.Sequential(
            nn.Linear(input_size, num_experts),
            nn.Softmax(dim=1)  # Outputs weights for each expert
        )

    def forward(self, x):
        # Flatten images if necessary (assuming x has shape [batch_size, C, H, W])
        x = x.view(x.size(0), -1)
        # Get gating weights
        gating_weights = self.gate(x)  # Shape: [batch_size, num_experts]
        # Get expert outputs
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # Shape: [batch_size, num_experts, output_size]
        # Weighted sum of expert outputs
        out = torch.sum(gating_weights.unsqueeze(2) * expert_outputs, dim=1)  # Shape: [batch_size, output_size]
        return out

# ================================
# 8. Training the Mixture of Experts (Student) Model
# ================================

# Define model parameters
input_size = X_combined_tensor.shape[1] * X_combined_tensor.shape[2] * X_combined_tensor.shape[3]  # C * H * W
num_experts = 3
hidden_size = 128
output_size = y_combined_tensor.shape[1]  # Number of neurons
dropout_rate = 0.3
learning_rate = 1e-3
num_epochs = 100

# Initialize the Mixture of Experts model
model = MixtureOfExperts(input_size, num_experts, hidden_size, output_size, dropout_rate)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

# Early Stopping parameters
best_val_loss = float('inf')
patience = 10
counter = 0

# Prepare validation data (e.g., 10% of the training set)
val_size = int(0.1 * n_train)
X_val_paths = training_paths[-val_size:]
y_val = y_train[-val_size:]
# Preprocess validation images
def preprocess_image_val(path):
    try:
        image = Image.open(path).convert('RGB')
        image_np = np.array(image)
        image_tensor = augmentation_transform(image_np)
        return image_tensor.numpy()
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return np.zeros((3, 224, 224))  # Example placeholder, adjust as needed

# Load and preprocess validation images
X_val = []
for path in X_val_paths:
    X_val.append(preprocess_image_val(path))
X_val = np.array(X_val)
y_val = y_train[-val_size:]

# Convert validation data to tensors
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.float32)

# Create validation DataLoader
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
X_val_tensor = X_val_tensor.to(device)
y_val_tensor = y_val_tensor.to(device)

# Adjust DataLoader to move batches to the appropriate device
def get_device_loader(loader, device):
    for batch in loader:
        yield [item.to(device) for item in batch]

# Training loop with Early Stopping
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        
        # Forward pass
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Validation
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for X_val_batch, y_val_batch in val_loader:
            X_val_batch = X_val_batch.to(device)
            y_val_batch = y_val_batch.to(device)
            
            val_outputs = model(X_val_batch)
            loss = criterion(val_outputs, y_val_batch)
            val_loss += loss.item()
        val_loss /= len(val_loader)
    
    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_moe_student_model.pth')  # Save the best model
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
    
    # Print progress every 10 epochs
    if (epoch + 1) % 10 == 0 or epoch == 0:
        avg_epoch_loss = epoch_loss / len(loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_epoch_loss:.4f}, Validation Loss: {val_loss:.4f}")

# Load the best model
model.load_state_dict(torch.load('best_moe_student_model.pth'))
model.to('cpu')  # Move back to CPU for evaluation if needed

# ================================
# 9. Evaluating the Student Model on the Test Set
# ================================

# Preprocess test images
def preprocess_image_test(path):
    try:
        image = Image.open(path).convert('RGB')
        image_np = np.array(image)
        image_tensor = augmentation_transform(image_np)
        return image_tensor.numpy()
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return np.zeros((3, 224, 224))  # Example placeholder, adjust as needed

# Load and preprocess test images
X_test = []
for path in test_paths:
    X_test.append(preprocess_image_test(path))
X_test = np.array(X_test)
print(f"Number of test images after preprocessing: {X_test.shape[0]}")

# Convert test data to tensors
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

# Move test data to device
X_test_tensor = X_test_tensor.to(device)
y_test_tensor = y_test_tensor.to(device)

# Evaluate the model on the test set
model.eval()
with torch.no_grad():
    y_test_pred = model(X_test_tensor).cpu().numpy()
    y_test_true = y_test_tensor.cpu().numpy()

# Calculate R² scores for each neuron
variance_explained_test = r2_score(y_test_true, y_test_pred, multioutput="raw_values")
mean_r2_test = np.mean(variance_explained_test)

print(f"Mean R² on Test Set: {mean_r2_test:.4f}")

# Optionally, analyze individual neuron performance
for idx, r2 in enumerate(variance_explained_test):
    print(f"Neuron {idx+1}: R² = {r2:.4f}")


Total valid images found: 1250
Number of training samples: 1000
Number of test samples: 250
Number of overlapping images between training and test sets: 0


  return np.array(augmented_X)


ValueError: could not broadcast input array from shape (3,256,256) into shape (3,)