In [47]:
import os
import re

tensor_dir_path = "./../tensor_dir"
tensor_list = sorted(os.listdir(tensor_dir_path))

# トライアルごとに辞書を初期化
tensor_dict = {}

for tensor in tensor_list:
    # ファイル名からトライアル番号を抽出
    match = re.search(r'trial(\d+)', tensor)
    if match:
        trial_number = f"trial{match.group(1)}"
        
        # トライアル番号ごとに辞書がなければ作成
        if trial_number not in tensor_dict:
            tensor_dict[trial_number] = {"eval": None, "bool": None, "mean": None, "std": None}
        
        # ファイルの種類に応じて適切なキーにパスを格納
        if "mean" in tensor:
            tensor_dict[trial_number]["mean"] = os.path.join(tensor_dir_path, tensor)
        elif "std" in tensor:
            tensor_dict[trial_number]["std"] = os.path.join(tensor_dir_path, tensor)
        elif "bool" in tensor:
            tensor_dict[trial_number]["bool"] = os.path.join(tensor_dir_path, tensor)
        elif "eval" in tensor:
            tensor_dict[trial_number]["eval"] = os.path.join(tensor_dir_path, tensor)

tensor_dict

{'trial1': {'eval': './../tensor_dir/tensor_eval_trial1.npy',
  'bool': './../tensor_dir/tensor_eval_bool_trial1.npy',
  'mean': './../tensor_dir/mean_tensor_trial1.npy',
  'std': './../tensor_dir/std_tensor_trial1.npy'},
 'trial10': {'eval': './../tensor_dir/tensor_eval_trial10.npy',
  'bool': './../tensor_dir/tensor_eval_bool_trial10.npy',
  'mean': './../tensor_dir/mean_tensor_trial10.npy',
  'std': './../tensor_dir/std_tensor_trial10.npy'},
 'trial11': {'eval': './../tensor_dir/tensor_eval_trial11.npy',
  'bool': './../tensor_dir/tensor_eval_bool_trial11.npy',
  'mean': './../tensor_dir/mean_tensor_trial11.npy',
  'std': './../tensor_dir/std_tensor_trial11.npy'},
 'trial12': {'eval': './../tensor_dir/tensor_eval_trial12.npy',
  'bool': './../tensor_dir/tensor_eval_bool_trial12.npy',
  'mean': './../tensor_dir/mean_tensor_trial12.npy',
  'std': './../tensor_dir/std_tensor_trial12.npy'},
 'trial13': {'eval': './../tensor_dir/tensor_eval_trial13.npy',
  'bool': './../tensor_dir/tensor

In [48]:
import numpy as np

# 指定されたトライアルの4つのテンソルを読み込む関数
def load_tensors(tensor_dict, trial_key):
    """
    指定されたトライアルの eval, bool, mean, std の4つのテンソルを読み込む。

    Parameters:
    - tensor_dict (dict): トライアルごとのテンソルファイルパスが格納された辞書。
    - trial_key (str): 読み込みたいトライアルのキー (例: "trial1")

    Returns:
    - dict: 各テンソルが格納された辞書。キーは "eval", "bool", "mean", "std"。
    """
    if trial_key not in tensor_dict:
        raise ValueError(f"{trial_key} is not in tensor_dict.")
    
    # 各テンソルを読み込み、結果を辞書に格納
    tensors = {}
    for tensor_type in ["eval", "bool", "mean", "std"]:
        file_path = tensor_dict[trial_key][tensor_type]
        if file_path is None:
            raise ValueError(f"File for {tensor_type} in {trial_key} is missing.")
        
        tensors[tensor_type] = np.load(file_path, allow_pickle=True)  # allow_pickle=True を追加

    return tensors

# 使用例
# 指定トライアルのテンソルを読み込む
trial_key = "trial30"  # 読み込みたいトライアル番号
tensors = load_tensors(tensor_dict, trial_key)

# 読み込んだテンソルを表示
print("Eval Tensor:", tensors["eval"])
print("Bool Tensor:", tensors["bool"])
print("Mean Tensor:", tensors["mean"])
print("Std Tensor:", tensors["std"])

Eval Tensor: [[[[[        nan         nan         nan ...         nan         nan
             nan]
    [        nan         nan         nan ...         nan         nan
             nan]
    [        nan         nan         nan ...         nan         nan
             nan]
    ...
    [        nan         nan         nan ...         nan         nan
             nan]
    [        nan         nan         nan ...         nan         nan
             nan]
    [        nan         nan         nan ...         nan         nan
             nan]]

   [[        nan         nan         nan ...         nan         nan
             nan]
    [        nan         nan         nan ...         nan         nan
             nan]
    [        nan         nan         nan ...         nan         nan
             nan]
    ...
    [        nan         nan         nan ...         nan         nan
             nan]
    [        nan         nan         nan ...         nan         nan
             nan]
    [       

In [49]:
arr_eval = tensors["eval"]
display(arr_eval[~np.isnan(arr_eval)])
display(np.where(~np.isnan(arr_eval)))

array([ 7.84505249,  8.21789316, 11.8231656 ,  8.21789316,  6.31556639,
       10.27575727,  8.87475176,  8.39224449,  7.32461272,  8.39224449,
        9.63468429,  8.5594659 ,  8.94993407,  9.37428782,  8.30599919,
        4.21454286,  7.9415236 ,  9.5065458 ,  9.63468429,  8.47670863,
        6.97617905,  7.74622166,  5.85555296,  8.87475176,  8.5594659 ,
       11.18921788,  9.99673089, 11.10228943, 10.32934359,  9.87956326,
       10.68586306])

(array([ 0,  0,  2,  2,  2,  3,  3,  4,  5,  5,  5,  6,  6,  6,  6,  6,  6,
         6,  7,  7,  7,  7,  8,  8,  8,  8,  9,  9,  9, 10, 10]),
 array([ 4,  7,  1,  2,  7,  1,  3,  2,  3,  7,  9,  0,  1,  3,  6,  6,  6,
         9,  2,  5,  8,  9,  4,  8,  9, 10,  0,  5, 10,  3,  9]),
 array([ 4,  4, 10,  8,  3, 10, 10,  2,  8,  0,  2,  7,  6,  7,  2,  5,  6,
         2,  5,  2,  2,  5,  5,  1,  7, 10,  8, 10,  5,  2,  1]),
 array([ 7,  7, 10,  3,  6,  9,  8,  8,  8,  7,  7,  5,  0, 10,  2,  7,  3,
         6,  1, 10,  5,  4,  6,  5,  6,  2,  8,  1,  9,  1,  1]),
 array([ 5,  6,  0,  3,  5,  7,  6,  8,  3,  7,  0,  2,  6,  1,  1,  4,  0,
        10, 10,  5,  6,  8,  7,  8,  2,  9,  6,  0,  2,  7,  5]))

In [60]:
arr_eval[:, :, 4, 7, 5]

array([[       nan,        nan,        nan,        nan, 7.84505249,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,   

In [64]:
arr_std[:, :, 4, 7, 5]

array([[0.00408022, 0.00498716, 0.00434289, 0.00980046, 0.        ,
        0.01105436, 0.00392453, 0.00810743, 0.00310013, 0.01132657,
        0.01304129],
       [0.00390129, 0.00939088, 0.00702954, 0.01725062, 0.00868866,
        0.01993789, 0.00379478, 0.01644814, 0.0077972 , 0.02056172,
        0.02575646],
       [0.00622941, 0.00811971, 0.00862352, 0.01397781, 0.01257248,
        0.01204641, 0.00274109, 0.01097583, 0.00760585, 0.01006236,
        0.01942717],
       [0.00154688, 0.00887483, 0.00520561, 0.01193964, 0.00265268,
        0.00392566, 0.00365556, 0.00820683, 0.00794882, 0.00391521,
        0.00807614],
       [0.0049933 , 0.0078546 , 0.0061345 , 0.01188528, 0.00788773,
        0.0058992 , 0.00513172, 0.00908781, 0.0061843 , 0.00454342,
        0.0134193 ],
       [0.00824074, 0.0076824 , 0.01175858, 0.01540356, 0.01677175,
        0.016777  , 0.00421373, 0.01303576, 0.00574362, 0.01283498,
        0.02162047],
       [0.00272487, 0.00595689, 0.00381843, 0.01136266, 0.

In [50]:
arr_bool = tensors["bool"]
display(np.where(arr_bool))

(array([ 0,  0,  2,  2,  2,  3,  3,  4,  5,  5,  5,  6,  6,  6,  6,  6,  6,
         6,  7,  7,  7,  7,  8,  8,  8,  8,  9,  9,  9, 10, 10]),
 array([ 4,  7,  1,  2,  7,  1,  3,  2,  3,  7,  9,  0,  1,  3,  6,  6,  6,
         9,  2,  5,  8,  9,  4,  8,  9, 10,  0,  5, 10,  3,  9]),
 array([ 4,  4, 10,  8,  3, 10, 10,  2,  8,  0,  2,  7,  6,  7,  2,  5,  6,
         2,  5,  2,  2,  5,  5,  1,  7, 10,  8, 10,  5,  2,  1]),
 array([ 7,  7, 10,  3,  6,  9,  8,  8,  8,  7,  7,  5,  0, 10,  2,  7,  3,
         6,  1, 10,  5,  4,  6,  5,  6,  2,  8,  1,  9,  1,  1]),
 array([ 5,  6,  0,  3,  5,  7,  6,  8,  3,  7,  0,  2,  6,  1,  1,  4,  0,
        10, 10,  5,  6,  8,  7,  8,  2,  9,  6,  0,  2,  7,  5]))

In [62]:
arr_eval[:, :, 10, 8, 6]

array([[       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan, 8.87475176,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan,
               nan],
       [       nan,        nan,        nan,        nan,   

In [63]:
arr_std[:, :, 10, 8, 6]

array([[0.06298864, 0.11082878, 0.01766333, 0.15967201, 0.06435323,
        0.04918616, 0.00475629, 0.01522923, 0.11325   , 0.06011446,
        0.0699129 ],
       [0.06619207, 0.13419746, 0.01853944, 0.14958668, 0.06639509,
        0.06207913, 0.05262279, 0.03794816, 0.10911789, 0.07477624,
        0.07725937],
       [0.10106552, 0.19383721, 0.0100654 , 0.18748276, 0.09649893,
        0.09510563, 0.04863693, 0.04107262, 0.13431478, 0.09497372,
        0.10203685],
       [0.1500764 , 0.23628091, 0.02849347, 0.34565897, 0.15390733,
        0.11009627, 0.04400175, 0.02873953, 0.24897613, 0.1596702 ,
        0.15101395],
       [0.05905988, 0.07630622, 0.00440832, 0.11071927, 0.06352508,
        0.04590398, 0.02133998, 0.0168188 , 0.08011724, 0.06839509,
        0.05137032],
       [0.04391019, 0.09751091, 0.02974516, 0.14115465, 0.03386644,
        0.02775651, 0.01267581, 0.03645724, 0.10163507, 0.03550529,
        0.0521643 ],
       [0.03855169, 0.08680222, 0.01053036, 0.09942673, 0.

In [54]:
arr_mean

array([[[[[-5.09363354e-04, -4.47420488e-03, -6.58485167e-03, ...,
           -6.24379504e-03,  3.56457483e-04, -4.59425061e-03],
          [ 2.99937739e-03, -4.72907230e-03,  3.37757104e-03, ...,
            4.23826284e-03,  7.31329708e-03,  3.07592774e-03],
          [ 7.43286338e-03, -7.87907542e-03,  1.04810616e-03, ...,
            7.88763299e-03,  1.83469044e-02,  5.49880265e-03],
          ...,
          [-5.70936532e-03,  5.42880912e-03, -4.92941091e-03, ...,
           -1.72051782e-02, -2.30794768e-02, -8.44528087e-03],
          [-5.60241961e-03,  4.27884326e-03, -3.66093647e-03, ...,
           -8.37250171e-03, -1.57344500e-02, -6.60263204e-03],
          [-1.53827214e-04,  4.53962034e-04, -3.20195404e-03, ...,
           -3.48688771e-03,  3.48256895e-03, -7.53023316e-04]],

         [[ 2.05060591e-02,  1.15828013e-02,  8.00322633e-03, ...,
            1.92182360e-02, -8.84574229e-03,  1.70509992e-02],
          [-5.03606846e-03, -1.93482247e-03,  1.10867269e-03, ...,
      

In [51]:
arr_mean = tensors["mean"]
arr_mean[arr_bool]

array([ 7.84505249,  8.21789316, 11.8231656 ,  8.21789316,  6.31556639,
       10.27575727, -0.16470445,  8.39224449,  7.32461272,  8.39224449,
        9.63468429,  8.5594659 ,  8.94993407,  9.37428782,  8.30599919,
        4.21454286,  7.9415236 ,  9.5065458 ,  9.63468429,  8.47670863,
        6.97617905,  7.74622166,  5.85555296,  8.87475176,  8.5594659 ,
       11.18921788,  9.99673089, 11.10228943, 10.32934359,  9.87956326,
       10.68586306])

In [52]:
arr_std = tensors["std"]
arr_std[arr_bool]

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.34565897, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        ])

In [61]:
arr_std[:, :, 10, 8, 6]

array([[0.06298864, 0.11082878, 0.01766333, 0.15967201, 0.06435323,
        0.04918616, 0.00475629, 0.01522923, 0.11325   , 0.06011446,
        0.0699129 ],
       [0.06619207, 0.13419746, 0.01853944, 0.14958668, 0.06639509,
        0.06207913, 0.05262279, 0.03794816, 0.10911789, 0.07477624,
        0.07725937],
       [0.10106552, 0.19383721, 0.0100654 , 0.18748276, 0.09649893,
        0.09510563, 0.04863693, 0.04107262, 0.13431478, 0.09497372,
        0.10203685],
       [0.1500764 , 0.23628091, 0.02849347, 0.34565897, 0.15390733,
        0.11009627, 0.04400175, 0.02873953, 0.24897613, 0.1596702 ,
        0.15101395],
       [0.05905988, 0.07630622, 0.00440832, 0.11071927, 0.06352508,
        0.04590398, 0.02133998, 0.0168188 , 0.08011724, 0.06839509,
        0.05137032],
       [0.04391019, 0.09751091, 0.02974516, 0.14115465, 0.03386644,
        0.02775651, 0.01267581, 0.03645724, 0.10163507, 0.03550529,
        0.0521643 ],
       [0.03855169, 0.08680222, 0.01053036, 0.09942673, 0.

In [53]:
arr_eval[~np.isnan(arr_eval)] == arr_mean[arr_bool]

array([ True,  True,  True,  True,  True,  True, False,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True])