In [52]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image


class NeuronDataset(Dataset):
    """Pair per-trial firing-rate arrays with an image and a class label.

    Each item is ``(neuro_tensor, img_tensor, image_label)``:

    * ``neuro_tensor`` -- the trial's 2-D firing-rate array as float32,
      transposed from (n_neurons, n_timebins) to (n_timebins, n_neurons);
    * ``img_tensor`` -- the image converted to RGB, resized to
      ``image_size x image_size``, in CHW layout and scaled to [-1, 1];
    * ``image_label`` -- the label stored at the same index.

    A missing or unreadable image is replaced by an all-zero tensor and a
    warning is printed, so one bad file does not abort an epoch.
    """

    def __init__(self, firing_rate_data, image_labels, image_paths, image_size=256):
        """
        Args:
            firing_rate_data: sequence indexed by trial; each element is a
                2-D array of shape (n_neurons, n_timebins).
            image_labels: sequence of image class labels, indexed like
                ``firing_rate_data``.
            image_paths: sequence of image file paths, indexed like
                ``firing_rate_data``.
            image_size: side length in pixels that images are resized to
                (default 256).

        Raises:
            ValueError: if ``image_labels`` or ``image_paths`` is shorter than
                ``firing_rate_data`` (some items would otherwise fail with an
                IndexError partway through an epoch).
        """
        n_trials = len(firing_rate_data)
        if len(image_labels) < n_trials or len(image_paths) < n_trials:
            raise ValueError(
                f"need at least {n_trials} image labels and paths, got "
                f"{len(image_labels)} labels and {len(image_paths)} paths"
            )
        self.firing_rate_data = firing_rate_data
        self.image_labels = image_labels
        self.image_paths = image_paths
        self.image_size = image_size

    def __len__(self):
        return len(self.firing_rate_data)

    def __getitem__(self, idx):
        # Firing rates: (n_neurons, n_timebins) -> (n_timebins, n_neurons)
        neuro_tensor = torch.tensor(self.firing_rate_data[idx], dtype=torch.float32).T

        image_label = self.image_labels[idx]
        image_path = self.image_paths[idx]
        size = self.image_size

        try:
            # The context manager releases the file handle; convert() returns a
            # new, fully loaded image, so `img` stays usable after the block.
            with Image.open(image_path) as raw:
                img = raw.convert('RGB').resize((size, size))
        except (FileNotFoundError, OSError) as e:
            # Fallback: all-zero tensor (mid-grey in the [-1, 1] scale below).
            img_tensor = torch.zeros(3, size, size)
            print(f"Warning: {image_path} not found or corrupted: {e}")
        else:
            img_array = np.array(img, dtype=np.float32) / 255.0  # scale to [0, 1]
            img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)  # HWC -> CHW
            img_tensor = img_tensor * 2 - 1  # scale to [-1, 1]

        return neuro_tensor, img_tensor, image_label


def load_firing_rate_data(npz_path, train=True, n_timebins=10, check_rows=100):
    """Load per-trial firing-rate arrays for one data split from an ``.npz`` file.

    Only keys ending in ``'__firing_rate'`` are considered, in sorted key
    order. Of those, a key is selected when it contains the split string
    (``'train'`` or ``'test'``). A selected trial is kept only if the sum
    of its first ``check_rows`` rows is non-zero, so all-silent trials are
    dropped.

    Args:
        npz_path: path to the ``.npz`` archive.
        train: if True select ``'train'`` keys, otherwise ``'test'`` keys.
        n_timebins: number of leading columns kept from each array (default 10).
        check_rows: number of leading rows summed for the silent-trial check
            (default 100).

    Returns:
        ``(firing_rate_list, label)`` contains two lists:
        * the arrays ``temp[:, :n_timebins]``;
        * the matching labels, i.e. each key's prefix before the first ``'__'``.
    """
    split = 'train' if train else 'test'
    firing_rate_list = []
    label = []

    # allow_pickle=True is kept for compatibility with archives holding object
    # arrays; unpickling can execute arbitrary code, so only load trusted files.
    # The `with` block closes the archive's file handle once arrays are read.
    with np.load(npz_path, allow_pickle=True) as data:
        firing_rate_keys = sorted(k for k in data.files if k.endswith('__firing_rate'))
        for key in firing_rate_keys:
            # NOTE(review): plain substring test on the whole key. A label that
            # itself contains 'train'/'test' would match both splits -- confirm
            # the key format.
            if split not in key:
                continue
            temp = data[key]  # materialised in memory; safe to slice after close
            if np.sum(temp[:check_rows, :]) != 0:
                label.append(key.split("__")[0])
                firing_rate_list.append(temp[:, :n_timebins])

    return firing_rate_list, label


def load_image_data(csv_path):
    """Read the image table from a CSV file.

    Args:
        csv_path: path to a CSV that has ``class`` and ``local_path`` columns.

    Returns:
        ``(image_labels, image_paths)``: the ``class`` and ``local_path``
        columns as NumPy arrays, row-aligned with the CSV.
    """
    table = pd.read_csv(csv_path)
    labels, paths = (table[column].values for column in ('class', 'local_path'))

    print(f"Loaded {len(labels)} images from {csv_path}")
    return labels, paths


def create_datasets(npz_path, csv_path, test_size=0.2, random_state=42):
    """Build train and validation ``NeuronDataset`` objects.

    Loads the ``'train'`` firing-rate trials from ``npz_path`` and the image
    table from ``csv_path``. It splits the trial indices with
    ``train_test_split`` (``test_size``, ``random_state``). Each dataset then
    receives the firing-rate arrays, CSV labels and CSV paths at its indices.

    NOTE(review): trial ``i`` is paired with CSV row ``i`` purely by position.
    The npz loader sorts its keys and drops all-silent trials. The labels it
    returns are only used for their count here. So nothing checks that CSV
    row ``i`` is the image shown on trial ``i``. Confirm the CSV ordering
    matches, or align on the returned labels instead.

    Returns:
        ``(train_dataset, val_dataset)``
    """
    rates, trial_labels = load_firing_rate_data(npz_path)
    csv_labels, csv_paths = load_image_data(csv_path)

    def _subset(indices):
        # Gather the three parallel sequences at the given positions.
        return NeuronDataset(
            [rates[i] for i in indices],
            [csv_labels[i] for i in indices],
            [csv_paths[i] for i in indices],
        )

    train_idx, val_idx = train_test_split(
        range(len(trial_labels)), test_size=test_size, random_state=random_state
    )
    return _subset(train_idx), _subset(val_idx)



In [None]:
npz_path = "/media/ubuntu/sda/Monkey/sorted_result/20240112/Block_1/sort/firing_rate_dict_B1_V1_instant_20ms.npz"
csv_path = "/media/ubuntu/sda/Monkey/scripts/train_image.csv"

# Build the train/validation datasets and report their sizes.
train_dataset, val_dataset = create_datasets(npz_path, csv_path)

print(f"Train set: {len(train_dataset)}")
print(f"Val set: {len(val_dataset)}")

# Small batches with in-process loading; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)

# Sanity check: inspect the first training batch only (nothing if empty).
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    neuro_tensor, img, label = first_batch
    print(f"Neuro tensor shape: {neuro_tensor.shape}")
    print(f"Image shape: {img.shape}")
    print(f"Labels: {label}")
    

Loaded 22248 images from /media/ubuntu/sda/Monkey/scripts/train_image.csv
Train set: 251
Val set: 63
Neuro tensor shape: torch.Size([2, 10, 204])
Image shape: torch.Size([2, 3, 256, 256])
Labels: ('album', 'airplane')
