In [91]:
# reload meta_env_agent to update changes
import sys
from tqdm import tqdm
import numpy as np
import importlib
importlib.reload(sys.modules['meta_env_agent'])
from meta_env_agent import A2CAgent, DynamicBandit,  BanditEnv, Config, collect_episode_data, fit_pca_on_task_difference
from sklearn.decomposition import PCA



In [92]:
config = Config()

a2c_agent = A2CAgent(config.INPUT_SIZE, config.HIDDEN_SIZE, config.OUTPUT_SIZE,
                        lr=config.LEARNING_RATE,
                        gamma=config.GAMMA,
                        value_loss_coef=config.VALUE_LOSS_COEF,
                        entropy_coef=config.ENTROPY_COEF)

In [93]:
a2c_agent.load("a2c_agent_10000epi.pth")  # 모델 로드
a2c_agent.eval()  # 평가 모드로 설정


  """


A2CAgent(
  (lstm_cell): LSTMCell(2, 32)
  (action_head): Linear(in_features=32, out_features=2, bias=True)
  (value_head): Linear(in_features=32, out_features=1, bias=True)
)

In [94]:

env = DynamicBandit(training=False)


In [None]:
NUM_TEST_EPISODES = 500 # 분석의 신뢰도를 높이기 위해 여러 에피소드 실행
all_episode_data = []

# 분석할 특정 p0 값들을 번갈아 가며 환경 생성
# p0_pairs_to_test = [[0.05, 0.95], [0.1, 0.9], [0.15, 0.85], [0.2, 0.8],
#                     [0.25, 0.75], [0.3, 0.7], [0.35, 0.65], [0.4, 0.6],
#                     [0.45, 0.55], [0.5, 0.5], [0.55, 0.45], [0.6, 0.4],
#                     [0.65, 0.35], [0.7, 0.3], [0.75, 0.25], [0.8, 0.2],
#                     [0.85, 0.15], [0.9, 0.1], [0.95, 0.05]]
p0_pairs_to_test =[[0.4 0.6],[0.6,0.4]]
print(f"\n{NUM_TEST_EPISODES}개의 테스트 에피소드에서 데이터 수집을 시작합니다...")
env = DynamicBandit(training=False)

for i in tqdm(range(NUM_TEST_EPISODES), desc="Data Collection"):
    p0_pair = p0_pairs_to_test[i % len(p0_pairs_to_test)] # p0 쌍을 번갈아 선택
    env.reset(p0_pair=p0_pair)
    print(p0_pair)
    data = collect_episode_data(a2c_agent, env,  episode_length=config.EPISODE_LENGTH)
    all_episode_data.append(data)

# 3. PCA 모델 훈련
# 모든 hidden state를 합쳐서 PCA 모델을 학습시킴
all_hidden_states = np.vstack([d['hidden_states'] for d in all_episode_data])
pca = PCA(n_components=config.PCA_COMPONENTS)
pca.fit(all_hidden_states)
print(f"\nPCA 분석 완료. 설명된 분산: {pca.explained_variance_ratio_}")



500개의 테스트 에피소드에서 데이터 수집을 시작합니다...


Data Collection:   1%|          | 3/500 [00:00<00:21, 23.37it/s]

[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]


Data Collection:   2%|▏         | 9/500 [00:00<00:20, 24.37it/s]

[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:   2%|▏         | 12/500 [00:00<00:20, 23.60it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]


Data Collection:   4%|▎         | 18/500 [00:00<00:20, 23.93it/s]

[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]


Data Collection:   5%|▍         | 24/500 [00:00<00:19, 24.47it/s]

[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]


Data Collection:   5%|▌         | 27/500 [00:01<00:19, 24.50it/s]

[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]


Data Collection:   7%|▋         | 33/500 [00:01<00:18, 24.69it/s]

[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]


Data Collection:   8%|▊         | 39/500 [00:01<00:18, 25.02it/s]

[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]


Data Collection:   9%|▉         | 45/500 [00:01<00:18, 25.12it/s]

[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  10%|█         | 51/500 [00:02<00:18, 24.85it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]


Data Collection:  11%|█▏        | 57/500 [00:02<00:18, 24.53it/s]

[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]


Data Collection:  12%|█▏        | 60/500 [00:02<00:17, 24.71it/s]

[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]


Data Collection:  13%|█▎        | 66/500 [00:02<00:17, 24.41it/s]

[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]


Data Collection:  14%|█▍        | 72/500 [00:02<00:17, 24.69it/s]

[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]


Data Collection:  16%|█▌        | 78/500 [00:03<00:16, 24.93it/s]

[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]


Data Collection:  17%|█▋        | 84/500 [00:03<00:16, 24.59it/s]

[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]


Data Collection:  18%|█▊        | 90/500 [00:03<00:16, 24.89it/s]

[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]


Data Collection:  19%|█▊        | 93/500 [00:03<00:16, 24.77it/s]

[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]


Data Collection:  20%|█▉        | 99/500 [00:04<00:16, 24.87it/s]

[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]


Data Collection:  21%|██        | 105/500 [00:04<00:15, 24.85it/s]

[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]


Data Collection:  22%|██▏       | 111/500 [00:04<00:15, 24.90it/s]

[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]


Data Collection:  23%|██▎       | 117/500 [00:04<00:15, 24.22it/s]

[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]


Data Collection:  25%|██▍       | 123/500 [00:05<00:15, 24.53it/s]

[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  26%|██▌       | 129/500 [00:05<00:14, 24.84it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]


Data Collection:  27%|██▋       | 135/500 [00:05<00:14, 24.92it/s]

[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]


Data Collection:  28%|██▊       | 141/500 [00:05<00:14, 25.23it/s]

[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]


Data Collection:  29%|██▉       | 147/500 [00:05<00:13, 25.29it/s]

[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]


Data Collection:  31%|███       | 153/500 [00:06<00:13, 25.10it/s]

[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]


Data Collection:  31%|███       | 156/500 [00:06<00:13, 24.67it/s]

[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]


Data Collection:  32%|███▏      | 162/500 [00:06<00:13, 24.86it/s]

[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]


Data Collection:  34%|███▎      | 168/500 [00:06<00:13, 24.97it/s]

[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]


Data Collection:  35%|███▍      | 174/500 [00:07<00:13, 24.46it/s]

[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]


Data Collection:  36%|███▌      | 180/500 [00:07<00:12, 24.69it/s]

[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  37%|███▋      | 186/500 [00:07<00:12, 24.93it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]


Data Collection:  38%|███▊      | 192/500 [00:07<00:12, 25.18it/s]

[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]


Data Collection:  39%|███▉      | 195/500 [00:07<00:12, 24.67it/s]

[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]


Data Collection:  40%|████      | 201/500 [00:08<00:12, 24.54it/s]

[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]


Data Collection:  41%|████▏     | 207/500 [00:08<00:11, 24.82it/s]

[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]


Data Collection:  43%|████▎     | 213/500 [00:08<00:11, 24.57it/s]

[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]


Data Collection:  43%|████▎     | 216/500 [00:08<00:11, 24.42it/s]

[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  44%|████▍     | 222/500 [00:08<00:11, 24.68it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]


Data Collection:  46%|████▌     | 228/500 [00:09<00:11, 24.12it/s]

[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]


Data Collection:  47%|████▋     | 234/500 [00:09<00:10, 24.57it/s]

[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]


Data Collection:  48%|████▊     | 240/500 [00:09<00:10, 24.86it/s]

[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]


Data Collection:  49%|████▉     | 246/500 [00:09<00:10, 25.14it/s]

[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]


Data Collection:  50%|█████     | 252/500 [00:10<00:09, 25.20it/s]

[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]


Data Collection:  52%|█████▏    | 258/500 [00:10<00:09, 25.31it/s]

[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]


Data Collection:  53%|█████▎    | 264/500 [00:10<00:09, 25.10it/s]

[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]


Data Collection:  54%|█████▍    | 270/500 [00:10<00:09, 25.18it/s]

[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]


Data Collection:  55%|█████▌    | 276/500 [00:11<00:08, 25.20it/s]

[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]


Data Collection:  56%|█████▋    | 282/500 [00:11<00:08, 25.11it/s]

[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]


Data Collection:  57%|█████▋    | 285/500 [00:11<00:08, 24.23it/s]

[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]


Data Collection:  58%|█████▊    | 291/500 [00:11<00:08, 24.56it/s]

[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]


Data Collection:  59%|█████▉    | 297/500 [00:12<00:08, 24.50it/s]

[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]


Data Collection:  61%|██████    | 303/500 [00:12<00:07, 24.70it/s]

[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]


Data Collection:  61%|██████    | 306/500 [00:12<00:07, 24.80it/s]

[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]


Data Collection:  62%|██████▏   | 312/500 [00:12<00:07, 24.69it/s]

[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  64%|██████▎   | 318/500 [00:12<00:07, 24.25it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]


Data Collection:  64%|██████▍   | 321/500 [00:12<00:07, 24.06it/s]

[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]


Data Collection:  65%|██████▌   | 327/500 [00:13<00:07, 23.93it/s]

[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]


Data Collection:  67%|██████▋   | 333/500 [00:13<00:07, 23.82it/s]

[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]


Data Collection:  68%|██████▊   | 339/500 [00:13<00:06, 24.38it/s]

[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]


Data Collection:  69%|██████▉   | 345/500 [00:13<00:06, 24.78it/s]

[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]


Data Collection:  70%|███████   | 351/500 [00:14<00:05, 24.89it/s]

[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  71%|███████▏  | 357/500 [00:14<00:05, 24.99it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]


Data Collection:  73%|███████▎  | 363/500 [00:14<00:05, 25.27it/s]

[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]


Data Collection:  74%|███████▍  | 369/500 [00:14<00:05, 25.34it/s]

[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]


Data Collection:  74%|███████▍  | 372/500 [00:15<00:05, 25.03it/s]

[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]


Data Collection:  76%|███████▌  | 378/500 [00:15<00:04, 24.93it/s]

[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]


Data Collection:  77%|███████▋  | 384/500 [00:15<00:04, 25.17it/s]

[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]


Data Collection:  78%|███████▊  | 390/500 [00:15<00:04, 24.50it/s]

[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]


Data Collection:  79%|███████▉  | 396/500 [00:16<00:04, 24.89it/s]

[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]


Data Collection:  80%|████████  | 402/500 [00:16<00:03, 24.78it/s]

[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]


Data Collection:  82%|████████▏ | 408/500 [00:16<00:03, 25.10it/s]

[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  82%|████████▏ | 411/500 [00:16<00:03, 25.04it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]


Data Collection:  83%|████████▎ | 417/500 [00:16<00:03, 24.91it/s]

[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]


Data Collection:  85%|████████▍ | 423/500 [00:17<00:03, 25.00it/s]

[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]


Data Collection:  86%|████████▌ | 429/500 [00:17<00:02, 24.47it/s]

[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]


Data Collection:  87%|████████▋ | 435/500 [00:17<00:02, 24.13it/s]

[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]


Data Collection:  88%|████████▊ | 438/500 [00:17<00:02, 24.21it/s]

[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]


Data Collection:  89%|████████▉ | 444/500 [00:17<00:02, 23.80it/s]

[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]


Data Collection:  90%|█████████ | 450/500 [00:18<00:02, 24.22it/s]

[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]


Data Collection:  91%|█████████ | 456/500 [00:18<00:01, 24.14it/s]

[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]


Data Collection:  92%|█████████▏| 459/500 [00:18<00:01, 24.16it/s]

[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]


Data Collection:  93%|█████████▎| 465/500 [00:18<00:01, 24.07it/s]

[0.35, 0.65]
[0.4, 0.6]
[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]


Data Collection:  94%|█████████▍| 471/500 [00:19<00:01, 24.15it/s]

[0.6, 0.4]
[0.65, 0.35]
[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]


Data Collection:  95%|█████████▍| 474/500 [00:19<00:01, 24.26it/s]

[0.85, 0.15]
[0.9, 0.1]
[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]


Data Collection:  96%|█████████▌| 480/500 [00:19<00:00, 24.50it/s]

[0.15, 0.85]
[0.2, 0.8]
[0.25, 0.75]
[0.3, 0.7]
[0.35, 0.65]
[0.4, 0.6]


Data Collection:  97%|█████████▋| 486/500 [00:19<00:00, 24.58it/s]

[0.45, 0.55]
[0.5, 0.5]
[0.55, 0.45]
[0.6, 0.4]
[0.65, 0.35]


Data Collection:  98%|█████████▊| 492/500 [00:19<00:00, 23.69it/s]

[0.7, 0.3]
[0.75, 0.25]
[0.8, 0.2]
[0.85, 0.15]
[0.9, 0.1]


Data Collection:  99%|█████████▉| 495/500 [00:20<00:00, 23.50it/s]

[0.95, 0.05]
[0.05, 0.95]
[0.1, 0.9]
[0.15, 0.85]
[0.2, 0.8]


Data Collection: 100%|██████████| 500/500 [00:20<00:00, 24.63it/s]

[0.25, 0.75]
[0.3, 0.7]

PCA 분석 완료. 설명된 분산: [9.9986601e-01 4.0269802e-05]





In [96]:
pca = fit_pca_on_task_difference(all_episode_data, p0_pairs_to_test)



디버깅 정보:
Task 1 (p0=[0.05, 0.95])에 대해 수집된 에피소드 수: 0
Task 2 (p0=[0.1, 0.9])에 대해 수집된 에피소드 수: 0
실제로 수집된 p0 값들 (처음 10개): ['[0.05, 0.45]', '[0.1, 0.4]', '[0.15, 0.35]', '[0.2, 0.3]', '[0.25, 0.25]', '[0.3, 0.2]', '[0.35, 0.15000000000000002]', '[0.4, 0.09999999999999998]', '[0.45, 0.04999999999999999]', '[0.5, 0.0]']


ValueError: PCA 분석에 필요한 한 개 또는 두 개 과제에 대한 데이터가 수집되지 않았습니다. 데이터 수집 과정이나 p0 값 비교 로직을 확인해주세요.