In [None]:
import random
from typing import Dict, Iterable, List, Optional, Tuple

import h5py
import numpy as np
import torch
import torch.nn as nn
import webdataset as wds
from torch.utils.data import Dataset

import re

from __future__ import print_function
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from pixyz.distributions import Normal, Bernoulli, Laplace, ProductOfNormal
from pixyz.models import Model
from pixyz.losses import KullbackLeibler
from pixyz.losses import LogProb
from pixyz.losses import Expectation as E
from pixyz.losses import Parameter
from pixyz.utils import print_latex
from pixyz.utils import epsilon

seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
data_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/data/mindeye2_nsd/"
subj_list = [1, 2, 5, 7] # [1, 2, 3, 4, 5, 6, 7, 8] 
all_subj_num_ssessions_list = [40, 40, 32, 30, 40, 32, 40, 30]
all_subj_lambdas = [1, 1, 1, 1, 1, 1, 1, 1]

batch_size = 128
seed = 42

In [None]:
def my_split_by_node(urls): return urls

def load_all_subj_data(data_path, subj_list, data_range, subj_num_ssessions_list=None):
      all_subj_data = {}

      for s in subj_list:
          if data_range == "no-shared1000":
              data_url = f"{data_path}/wds/subj0{s}/train/" "{0.." + f"{subj_num_ssessions_list[s-1]-1}" + "}.tar"
          elif data_range == "shared1000":
              data_url = f"{data_path}/wds/subj0{s}/new_test/" + "0.tar"
          elif data_range == "all":
              data_url = (
                  f"{data_path}/wds/subj0{s}/train/" "{0.." + f"{subj_num_ssessions_list[s-1]-1}" + "}.tar::"
                  f"{data_path}/wds/subj0{s}/new_test/0.tar"
              )
          else:
              raise ValueError(f"Unsupported data_range: {data_range}")

          subj_iter_data = wds.WebDataset(data_url, resampled=False, shardshuffle=False, nodesplitter=my_split_by_node) \
                              .decode("torch") \
                              .rename(behav="behav.npy",
                                      past_behav="past_behav.npy",
                                      future_behav="future_behav.npy",
                                      olds_behav="olds_behav.npy") \
                              .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])

          # global_trial (behav[0, 5]) で昇順ソート
          subj_data = list(subj_iter_data)
          subj_data = sorted(subj_data, key=lambda sample: sample[0][0, 5].item())
          all_subj_data[f"subj0{s}"] = subj_data

      print("Loaded all subj data\n")
      
      return all_subj_data

def load_all_subj_voxels(data_path, subj_list):
    all_subj_voxels = {}
    all_subj_num_voxels = {}

    for subj in subj_list:
        s = f"{int(subj):02d}"
        f = h5py.File(f"{data_path}/betas_all_subj{s}_fp32_renorm.hdf5", "r")
        betas = torch.Tensor(f["betas"][:]).to("cpu")
        all_subj_voxels[f"subj{s}"] = betas
        all_subj_num_voxels[f"subj{s}"] = betas[0].shape[-1]

    print("Loaded all subj voxels\n")
    
    return all_subj_voxels, all_subj_num_voxels

all_subj_shared1000_data = load_all_subj_data(data_path, subj_list, data_range="shared1000")
# all_subj_all_data = load_all_subj_data(data_path, subj_list, data_range="all", subj_num_ssessions_list=all_subj_num_ssessions_list)

all_subj_voxels, all_subj_num_voxels = load_all_subj_voxels(data_path, subj_list)

NameError: name 'h5py' is not defined

In [None]:
from fmri_reconstruction_with_dmvae.mindeye2_nsd.datasets.align import align_subject_trials

aligned_all_subj_shared1000_data = align_subject_trials(all_subj_shared1000_data, anchor_subject="subj01")
# aligned_all_subj_data = align_subject_trials(all_subj_all_data, anchor_subject="subj01")

In [63]:
def split_aligned_data(
      aligned_data: List[Dict[str, object]],
      subj_list: Iterable[int],
      train_occurrence_max: int = 2,
  ) -> Tuple[
      List[Tuple[int, Dict[str, Optional[int]]]],
      List[Tuple[int, Dict[str, Optional[int]]]],
  ]:
      """
      aligned_data: align_subject_trials の結果（各要素に cocoidx / occurrence / subjXX フィールドがある）
      subj_list: 例 [1, 2, 3] のような被験者 ID リスト
      戻り値: (train_data, test_data)
        - 各 data[i] は (cocoidx, {"subj01": global_trial or None, ...})
      """
      train_data: List[Tuple[int, Dict[str, Optional[int]]]] = []
      test_data: List[Tuple[int, Dict[str, Optional[int]]]] = []

      for data in aligned_data:
          cocoidx = int(data["cocoidx"])
          global_trials: Dict[str, Optional[int]] = {}

          for subj in subj_list:
              s = f"subj{int(subj):02d}"
              subject_info = data.get(s)
              global_trials[s] = (
                  subject_info["global_trial"] if subject_info is not None else None
              )

          sample = (cocoidx, global_trials)
          occurrence = int(data["occurrence"])

          if occurrence <= train_occurrence_max:
              train_data.append(sample)
          else:
              test_data.append(sample)

      return train_data, test_data

In [64]:
train_data, test_data = split_aligned_data(aligned_all_subj_shared1000_data, subj_list, train_occurrence_max=2)
# train_data, test_data = split_aligned_data(aligned_all_subj_data, subj_list, train_occurrence_max=2)

In [65]:
class StimulusTrialMappingDataset(Dataset):
      """
      (cocoidx, {"subjXX": global_trial or None, ...}) のリストをそのまま扱う Dataset。
      """
      def __init__(self, data: List[Tuple[int, Dict[str, Optional[int]]]]):
          self.data = data

      def __len__(self) -> int:
          return len(self.data)

      def __getitem__(self, idx: int) -> Tuple[int, Dict[str, Optional[int]]]:
          return self.data[idx]

In [66]:
train_dataset = StimulusTrialMappingDataset(train_data)
test_dataset = StimulusTrialMappingDataset(test_data)

train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)

### Model

In [72]:
class Encoder(nn.Module):
    def __init__(self, x_dim, hidden_dim=1024):
        super().__init__()
        self.fc = nn.Linear(x_dim, hidden_dim)

    def forward(self, x):
        return F.relu(self.fc(x))

class Decoder(nn.Module):
    def __init__(self, x_dim, hidden_dim=1024):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, x_dim)

    def forward(self, x):
        return self.fc(x)

In [73]:
class Inference(Normal):
    def __init__(self, enc, var, cond_var, z_dim, hidden_dim=1024):
        super(Inference, self).__init__(var=var, cond_var=cond_var, name="q")

        self.enc = enc
        self.mu = nn.Linear(hidden_dim, z_dim)
        self.logvar = nn.Linear(hidden_dim, z_dim)

    def forward(self, **x):
        x = torch.cat([x[_cond_var] for _cond_var in self.cond_var], dim=1)
        h = self.enc(x)

        return {"loc": self.mu(h), "scale": F.softplus(self.logvar(h)) + epsilon()}

class Generator(Laplace):
    def __init__(self, dec, var, cond_var, zp_dim, zs_dim, hidden_dim=1024):
        super(Generator, self).__init__(var=var, cond_var=cond_var, name="p")

        self.fc = nn.Linear(zp_dim + zs_dim, hidden_dim)
        self.dec = dec

    def forward(self, **z):
        z = torch.cat([z[_cond_var] for _cond_var in self.cond_var], dim=1)
        h = F.relu(self.fc(z))

        return {"loc": self.dec(h), "scale": torch.tensor(1.0).to(z.device)}

In [74]:
hidden_dim = 1024

enc_dict = {}
dec_dict = {}

for subj in subj_list:
    s = f"{int(subj):02d}"
    
    enc_dict[f"subj{s}_zp"] = Encoder(x_dim=all_subj_num_voxels[f"subj{s}"], hidden_dim=hidden_dim)
    enc_dict[f"subj{s}_zs"] = Encoder(x_dim=all_subj_num_voxels[f"subj{s}"], hidden_dim=hidden_dim)
    dec_dict[f"subj{s}"] = Decoder(x_dim=all_subj_num_voxels[f"subj{s}"], hidden_dim=hidden_dim)

In [75]:
zp_dim = 256
zs_dim = 768

dist_dict = {}
for subj in subj_list:
    s = f"{int(subj):02d}"
    
    # q_φ(zp_subj | x_subj): q_φ1(zp1 | x1), q_φ2(zp2 | x2), ...
    dist_dict[f"q_zp{s}__x{s}"] = Inference(enc=enc_dict[f"subj{s}_zp"], var=[f"zp{s}"], cond_var=[f"x{s}"], z_dim=zp_dim, hidden_dim=hidden_dim).to(device)
    
    # q_φ(zs_subj | x_subj): q_φ1(zs1 | x1), q_φ2(zs2 | x2), ...
    dist_dict[f"q_zs__x{s}"] = Inference(enc=enc_dict[f"subj{s}_zs"], var=["zs"], cond_var=[f"x{s}"], z_dim=zs_dim, hidden_dim=hidden_dim).to(device)

    # prior(zp_subj): prior(zp1), prior(zp2), ...
    dist_dict[f"prior_zp{s}"] = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=[f"zp{s}"], features_shape=[zp_dim], name="p_{prior}").to(device)

    # p_θ(x_subj | zp_subj, zs): p_θ1(x1 | zp1, zs), p_θ2(x2 | zp2, zs), ...
    dist_dict[f"p_x{s}__zp{s}_zs"] = Generator(dec=dec_dict[f"subj{s}"], var=[f"x{s}"], cond_var=[f"zp{s}", "zs"], zp_dim=zp_dim, zs_dim=zs_dim, hidden_dim=hidden_dim).to(device)


# q_φ(zs | x)
dist_dict["q_zs__x"] = ProductOfNormal([dist_dict[f"q_zs__x{int(subj):02d}"] for subj in subj_list], name="q").to(device)

# prior(zs)
dist_dict["prior_zs"] = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["zs"], features_shape=[zs_dim], name="p_{prior}").to(device)

In [76]:
loss = 0

for subj_target in subj_list:
    loss_subj_target = 0

    s_t = f"{int(subj_target):02d}"

    lambda_subj = Parameter(f"lambda_{s_t}")

    joint_recon_loss = - lambda_subj * E(dist_dict[f"q_zp{s_t}__x{s_t}"], E(dist_dict["q_zs__x"], LogProb(dist_dict[f"p_x{s_t}__zp{s_t}_zs"])))
    joint_recon_kl = KullbackLeibler(dist_dict[f"q_zp{s_t}__x{s_t}"], dist_dict[f"prior_zp{s_t}"]) + KullbackLeibler(dist_dict[f"q_zs__x"], dist_dict[f"prior_zs"])
    loss_subj_target += joint_recon_loss + joint_recon_kl

    for subj_source in subj_list:
        s_s = f"{int(subj_source):02d}"

        if s_t == s_s:
            self_recon_loss = - lambda_subj * E(dist_dict[f"q_zp{s_t}__x{s_t}"], E(dist_dict[f"q_zs__x{s_s}"], LogProb(dist_dict[f"p_x{s_t}__zp{s_t}_zs"])))
            self_recon_kl = KullbackLeibler(dist_dict[f"q_zp{s_t}__x{s_t}"], dist_dict[f"prior_zp{s_t}"]) + KullbackLeibler(dist_dict[f"q_zs__x{s_s}"], dist_dict["prior_zs"])
            loss_subj_target += self_recon_loss + self_recon_kl
        else:
            cross_recon_loss = - lambda_subj * E(dist_dict[f"q_zp{s_t}__x{s_t}"], E(dist_dict[f"q_zs__x{s_s}"], LogProb(dist_dict[f"p_x{s_t}__zp{s_t}_zs"])))
            cross_recon_kl = KullbackLeibler(dist_dict[f"q_zp{s_t}__x{s_t}"], dist_dict[f"prior_zp{s_t}"]) + KullbackLeibler(dist_dict[f"q_zs__x{s_s}"], dist_dict["prior_zs"])
            loss_subj_target += cross_recon_loss + cross_recon_kl
            
    loss += loss_subj_target.mean()

In [77]:
model = Model(loss=loss, distributions=list(dist_dict.values()), optimizer=optim.Adam, optimizer_params={"lr": 1e-3})

### Training

In [78]:
def get_recon_dict_batch(input_dict):
    z_dict = {}
    recon_dict_batch = {}

    with torch.no_grad():
        for subj in subj_list:
            s = f"{int(subj):02d}"
            z_dict[f"zp{s}"] = dist_dict[f"q_zp{s}__x{s}"].sample(input_dict, return_all=False) 
            z_dict[f"zs__x{s}"] = dist_dict[f"q_zs__x{s}"].sample(input_dict, return_all=False)
        
        z_dict["zs__x"] = dist_dict["q_zs__x"].sample(input_dict, return_all=False)

        for subj_target in subj_list:
            s_t = f"{int(subj_target):02d}"

            recon_dict_batch[f"joint_recon_x{s_t}"] = dist_dict[f"p_x{s_t}__zp{s_t}_zs"].sample_mean(z_dict[f"zp{s_t}"] | z_dict[f"zs__x"]).cpu()
            
            for subj_resource in subj_list:
                s_s = f"{int(subj_resource):02d}"

                if s_t == s_s:
                    recon_dict_batch[f"self_recon_x{s_t}"] = dist_dict[f"p_x{s_t}__zp{s_t}_zs"].sample_mean(z_dict[f"zp{s_t}"] | z_dict[f"zs__x{s_s}"]).cpu()
                else:
                    recon_dict_batch[f"cross_recon_x{s_t}__x{s_s}"] = dist_dict[f"p_x{s_t}__zp{s_t}_zs"].sample_mean(z_dict[f"zp{s_t}"] | z_dict[f"zs__x{s_s}"]).cpu()

    return recon_dict_batch

def calc_cosine_dict_batch(input_dict, recon_dict_batch):
    cosine_dict_batch = {}

    for key in recon_dict_batch:
        s_t = re.search(r'recon_x(\d{2})', key).group(1)
        cosine = F.cosine_similarity(input_dict[f"x{s_t}"].cpu(), recon_dict_batch[key])
        cosine_dict_batch[key] = cosine

    return cosine_dict_batch

def calc_pearson_dict_batch(input_dict, recon_dict_batch):
    pearson_dict_batch = {}

    for key in recon_dict_batch:
        s_t = re.search(r'recon_x(\d{2})', key).group(1)
        input_centered = input_dict[f"x{s_t}"] - input_dict[f"x{s_t}"].mean(dim=1, keepdim=True)
        recon_centered = recon_dict_batch[key] - recon_dict_batch[key].mean(dim=1, keepdim=True)
        pearson = F.cosine_similarity(input_centered.cpu(), recon_centered, dim=1)
        pearson_dict_batch[key] = pearson
    
    return pearson_dict_batch

def update_metrics(metrics_dict, input_dict, recon_dict_batch, n_batches):
      cosine_dict = calc_cosine_dict_batch(input_dict, recon_dict_batch)
      pearson_dict = calc_pearson_dict_batch(input_dict, recon_dict_batch)
      
      for key in recon_dict_batch:
          entry = metrics_dict.setdefault(key, {"cosine_mean": 0.0, "pearson_mean": 0.0})
          entry["cosine_mean"] += cosine_dict[key].mean() / n_batches
          entry["pearson_mean"] += pearson_dict[key].mean() / n_batches

      return metrics_dict

In [79]:
def train(epoch):
    train_loss = 0
    n_batches = len(train_dl)
    train_metrics_dict = {}

    for data in tqdm(train_dl):
        input_dict = {}
        lambda_dict = {}
        
        for subj in subj_list:
            s = f"{int(subj):02d}"

            input_dict[f"x{s}"] = all_subj_voxels[f"subj{s}"][data[1][f"subj{s}"]].to(device)
            lambda_dict[f"lambda_{s}"] = all_subj_lambdas[int(s)-1]

        loss = model.train(input_dict | lambda_dict)
        train_loss += loss

        if epoch == epochs:
            recon_dict_batch = get_recon_dict_batch(input_dict)
            train_metrics_dict = update_metrics(train_metrics_dict, input_dict, recon_dict_batch, n_batches)

    train_loss = train_loss / n_batches
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss, train_metrics_dict

def test(epoch):
    test_loss = 0
    n_batches = len(test_dl)
    test_metrics_dict = {}

    with torch.no_grad():
        for data in tqdm(test_dl):
            input_dict = {}
            lambda_dict = {}
            
            for subj in subj_list:
                s = f"{int(subj):02d}"
                input_dict[f"x{s}"] = all_subj_voxels[f"subj{s}"][data[1][f"subj{s}"]].to(device)
                lambda_dict[f"lambda_{s}"] = all_subj_lambdas[int(s)-1]

            loss = model.test(input_dict | lambda_dict)
            test_loss += loss

            if epoch == epochs:
                recon_dict_batch = get_recon_dict_batch(input_dict)
                test_metrics_dict = update_metrics(test_metrics_dict, input_dict, recon_dict_batch, n_batches)

    test_loss = test_loss / n_batches
    print('Epoch: {} Test loss: {:.4f}'.format(epoch, test_loss))
    return test_loss, test_metrics_dict

In [80]:
epochs = 20

for epoch in range(1, epochs + 1):
    train_loss, train_metrics_dict = train(epoch)
    test_loss, test_metrics_dict = test(epoch)

100%|██████████| 15/15 [00:01<00:00,  9.55it/s]


Epoch: 1 Train loss: 419556.8125


100%|██████████| 7/7 [00:00<00:00, 15.61it/s]


Epoch: 1 Test loss: 419191.8125


100%|██████████| 15/15 [00:01<00:00,  9.71it/s]


Epoch: 2 Train loss: 409474.9062


100%|██████████| 7/7 [00:00<00:00, 15.71it/s]


Epoch: 2 Test loss: 407633.9062


100%|██████████| 15/15 [00:01<00:00,  9.67it/s]


Epoch: 3 Train loss: 401342.7188


100%|██████████| 7/7 [00:00<00:00, 16.02it/s]


Epoch: 3 Test loss: 401230.9688


100%|██████████| 15/15 [00:01<00:00,  9.69it/s]


Epoch: 4 Train loss: 395673.8750


100%|██████████| 7/7 [00:00<00:00, 15.99it/s]


Epoch: 4 Test loss: 396717.6250


100%|██████████| 15/15 [00:01<00:00,  9.66it/s]


Epoch: 5 Train loss: 391783.7188


100%|██████████| 7/7 [00:00<00:00, 15.80it/s]


Epoch: 5 Test loss: 393597.1875


100%|██████████| 15/15 [00:01<00:00,  9.41it/s]


Epoch: 6 Train loss: 387763.0000


100%|██████████| 7/7 [00:00<00:00, 15.73it/s]


Epoch: 6 Test loss: 391010.4062


100%|██████████| 15/15 [00:01<00:00,  9.48it/s]


Epoch: 7 Train loss: 384698.5000


100%|██████████| 7/7 [00:00<00:00, 15.76it/s]


Epoch: 7 Test loss: 388814.3750


100%|██████████| 15/15 [00:01<00:00,  9.71it/s]


Epoch: 8 Train loss: 381465.0625


100%|██████████| 7/7 [00:00<00:00, 15.92it/s]


Epoch: 8 Test loss: 386842.3750


100%|██████████| 15/15 [00:01<00:00,  9.67it/s]


Epoch: 9 Train loss: 378511.0625


100%|██████████| 7/7 [00:00<00:00, 15.92it/s]


Epoch: 9 Test loss: 385386.3438


100%|██████████| 15/15 [00:01<00:00,  9.66it/s]


Epoch: 10 Train loss: 375759.8438


100%|██████████| 7/7 [00:00<00:00, 16.01it/s]


Epoch: 10 Test loss: 383991.0625


100%|██████████| 15/15 [00:01<00:00,  9.71it/s]


Epoch: 11 Train loss: 372976.7812


100%|██████████| 7/7 [00:00<00:00, 15.56it/s]


Epoch: 11 Test loss: 383115.9062


100%|██████████| 15/15 [00:01<00:00,  9.68it/s]


Epoch: 12 Train loss: 370451.4062


100%|██████████| 7/7 [00:00<00:00, 15.92it/s]


Epoch: 12 Test loss: 382224.7188


100%|██████████| 15/15 [00:01<00:00,  9.70it/s]


Epoch: 13 Train loss: 368262.3438


100%|██████████| 7/7 [00:00<00:00, 15.89it/s]


Epoch: 13 Test loss: 381596.4688


100%|██████████| 15/15 [00:01<00:00,  9.59it/s]


Epoch: 14 Train loss: 366185.6875


100%|██████████| 7/7 [00:00<00:00, 15.74it/s]


Epoch: 14 Test loss: 381224.7188


100%|██████████| 15/15 [00:01<00:00,  9.67it/s]


Epoch: 15 Train loss: 364111.9375


100%|██████████| 7/7 [00:00<00:00, 15.73it/s]


Epoch: 15 Test loss: 380674.5312


100%|██████████| 15/15 [00:01<00:00,  9.33it/s]


Epoch: 16 Train loss: 362168.2812


100%|██████████| 7/7 [00:00<00:00, 15.84it/s]


Epoch: 16 Test loss: 380833.5938


100%|██████████| 15/15 [00:01<00:00,  9.59it/s]


Epoch: 17 Train loss: 360406.5000


100%|██████████| 7/7 [00:00<00:00, 15.53it/s]


Epoch: 17 Test loss: 380644.3438


100%|██████████| 15/15 [00:01<00:00,  9.22it/s]


Epoch: 18 Train loss: 358579.9375


100%|██████████| 7/7 [00:00<00:00, 15.78it/s]


Epoch: 18 Test loss: 381936.4062


100%|██████████| 15/15 [00:01<00:00,  9.68it/s]


Epoch: 19 Train loss: 357381.1562


100%|██████████| 7/7 [00:00<00:00, 15.82it/s]


Epoch: 19 Test loss: 381923.0938


100%|██████████| 15/15 [00:02<00:00,  5.60it/s]


Epoch: 20 Train loss: 356072.4375


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]

Epoch: 20 Test loss: 382432.3125





In [81]:
train_metrics_dict

{'joint_recon_x01': {'cosine_mean': tensor(0.6697),
  'pearson_mean': tensor(0.6453)},
 'self_recon_x01': {'cosine_mean': tensor(0.6645),
  'pearson_mean': tensor(0.6397)},
 'cross_recon_x01__x02': {'cosine_mean': tensor(0.6599),
  'pearson_mean': tensor(0.6350)},
 'cross_recon_x01__x05': {'cosine_mean': tensor(0.6644),
  'pearson_mean': tensor(0.6396)},
 'cross_recon_x01__x07': {'cosine_mean': tensor(0.6651),
  'pearson_mean': tensor(0.6403)},
 'joint_recon_x02': {'cosine_mean': tensor(0.6841),
  'pearson_mean': tensor(0.6673)},
 'cross_recon_x02__x01': {'cosine_mean': tensor(0.6793),
  'pearson_mean': tensor(0.6623)},
 'self_recon_x02': {'cosine_mean': tensor(0.6756),
  'pearson_mean': tensor(0.6581)},
 'cross_recon_x02__x05': {'cosine_mean': tensor(0.6793),
  'pearson_mean': tensor(0.6623)},
 'cross_recon_x02__x07': {'cosine_mean': tensor(0.6799),
  'pearson_mean': tensor(0.6629)},
 'joint_recon_x05': {'cosine_mean': tensor(0.7236),
  'pearson_mean': tensor(0.7034)},
 'cross_recon_x

In [82]:
test_metrics_dict

{'joint_recon_x01': {'cosine_mean': tensor(0.5428),
  'pearson_mean': tensor(0.5063)},
 'self_recon_x01': {'cosine_mean': tensor(0.5389),
  'pearson_mean': tensor(0.5020)},
 'cross_recon_x01__x02': {'cosine_mean': tensor(0.5340),
  'pearson_mean': tensor(0.4976)},
 'cross_recon_x01__x05': {'cosine_mean': tensor(0.5379),
  'pearson_mean': tensor(0.5009)},
 'cross_recon_x01__x07': {'cosine_mean': tensor(0.5383),
  'pearson_mean': tensor(0.5016)},
 'joint_recon_x02': {'cosine_mean': tensor(0.5585),
  'pearson_mean': tensor(0.5250)},
 'cross_recon_x02__x01': {'cosine_mean': tensor(0.5543),
  'pearson_mean': tensor(0.5205)},
 'self_recon_x02': {'cosine_mean': tensor(0.5517),
  'pearson_mean': tensor(0.5179)},
 'cross_recon_x02__x05': {'cosine_mean': tensor(0.5538),
  'pearson_mean': tensor(0.5201)},
 'cross_recon_x02__x07': {'cosine_mean': tensor(0.5543),
  'pearson_mean': tensor(0.5207)},
 'joint_recon_x05': {'cosine_mean': tensor(0.6109),
  'pearson_mean': tensor(0.5771)},
 'cross_recon_x

### Notes

In [83]:
# データを合体させて、同じ刺激に対する複数被験者の脳活動をまとめて扱えるようにする（被験者1と被験者2なら問題ない）
# 同じ数の刺激がある被験者同士ならそのままの順番で合体させたらいいけど、被験者ごとにtestデータの数も異なるので、共通のところを抽出する必要がある
# 使うのはbehavの、0であるcoco_idx（画像の指定に使う）と、5であるglobal_trial（fmriのbetaの特定に使う）

# Technical notesが参考になるかも https://cvnlab.slite.page/p/h_T_2Djeid/Technical-notes
# shared1000ですら欠損があるので、この時点で弱教師あり学習の手法が必要になる可能性

# sub1とsub2だけでいいからDMVAEの実装をやってみてその妥当性を最低限検証する方向性がいいかも

# その後、複数被験者だったり、欠損対応するためにどうしたらいいかを考えていく
# （DMVAEだけでも欠損対応はおそらく可能だが、欠損の方が多いデータにうまく対応するようにはデザインされてないはずなので、やはり、弱教師あり学習の手法が必要）

In [84]:
# behavior = {
# - "cocoidx": int(behav.iloc[jj]['73KID'])-1, #0
# - "subject": subject,                        #1
# - "session": int(behav.iloc[jj]['SESSION']), #2
# - "run": int(behav.iloc[jj]['RUN']),         #3
# - "trial": int(behav.iloc[jj]['TRIAL']),     #4
# - "global_trial": int(i * (tar + 1)),        #5
# - "time": int(behav.iloc[jj]['TIME']),       #6
# - "isold": int(behav.iloc[jj]['ISOLD']),     #7
# - "iscorrect": iscorrect,                    #8
# - "rt": rt, # 0 = no RT                      #9
# - "changemind": changemind,                  #10
# - "isoldcurrent": isoldcurrent,              #11
# - "iscorrectcurrent": iscorrectcurrent,      #12
# - "total1": total1,                          #13
# - "total2": total2,                          #14
# - "button": button,                          #15
# - "shared1000": is_shared1000,               #16
# }

# 0 = COCO IDX (73K) (used to index coco_images_224_float16.hdf5)
# 1 = SUBJECT
# 2 = SESSION
# 3 = RUN
# 4 = TRIAL
# 5 = GLOBAL TRIAL (used to index betas_all_subj_fp32_renorm.hdf5)
# 6 = TIME
# 7 = ISOLD
# 8 = ISCORRECT
# 9 = RT
# 10 = CHANGEMIND
# 11 = ISOLDCURRENT
# 12 = ISCORRECTCURRENT
# 13 = TOTAL1
# 14 = TOTAL2
# 15 = BUTTON
# 16 = IS_SHARED1000

In [85]:
# def my_split_by_node(urls): return urls

# train_url = f"{data_path}/wds/subj01/new_test/" + "0.tar"

# train_data = wds.WebDataset(train_url, resampled=False, nodesplitter=my_split_by_node) \
#                                 .decode("torch") \
#                                 .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy") \
#                                 .to_tuple("behav", "past_behav", "future_behav", "olds_behav")

# samples = list(train_data)

# # !tar -tf /home/acg17270jl/projects/MindEyeV2/dataset/wds/subj01/new_test/0.tar | head -n 10

# import tarfile
# import io

# tar_path = "/home/acg17270jl/projects/MindEyeV2/dataset/wds/subj01/new_test/0.tar"

# target_index = "000000000" # 0から始まる9桁の数字
# target_index_int = int(target_index)
# target = f"sample{target_index}.behav.npy"  # ファイル名は tar -tf で確認したものに置き換え

# with tarfile.open(tar_path, "r") as tar:
#     fileobj = tar.extractfile(target)
#     data = io.BytesIO(fileobj.read())
#     arr = np.load(data, allow_pickle=False)
#     print(f"tarの{target_index_int}番目のサンプルのcocoidx: {arr[0][0]}")
#     print(f"train_dataの{target_index_int}番目のサンプルのcocoidx: {samples[target_index_int][0][0][0]}")
#     if arr[0][0] == samples[target_index_int][0][0][0]:
#         print("tarとtrain_dataの順番は一致しています")
#     else:
#         print("tarとtrain_dataの順番は一致していません")

# ローカルに保存されている.tarファイルの中身の順番（stimulus index）と、WebDatasetで読み込んだデータの順番（stimulus index）は同じ
# tarの0番目のサンプルのcocoidx: 46002.0
# train_dataの0番目のサンプルのcocoidx: 46002.0
# tarとtrain_dataの順番は一致しています