# In [None]:
import os
import torch
from datetime import datetime
import numpy as np
from multiprocessing import Pool, cpu_count
from bandit_task import TwoArmedBandit
from model import RNNActorCritic
from episode import collect_bandit_trajectory
from reptile_bandit_training import *

# Candidate roots for the shared LearningToLearn storage, one per OS-specific
# mount point (as the variable names indicate); the Windows one is a
# UNC-style "//server/share" path.
base_unix = "/storage1/fs1/shinung/Active/jackosvky/LearningToLearn"
base_windows = "//storage1.ris.wustl.edu/shinung/Active/jackosvky/LearningToLearn"
base_mac = "/Volumes/shinung/Active/jackosvky/LearningToLearn"

# Pick the first reachable root, probing in priority order (unix, windows,
# mac); the generator stops at the first path that exists.
base = next(
    (root for root in (base_unix, base_windows, base_mac) if os.path.exists(root)),
    None,
)
if base is None:
    # No shared storage is visible from this machine: fall back to the
    # current working directory.
    base = os.getcwd()
    print(f"None of the base paths are available. Creating folder in the current directory: {base}")

# Give this run its own timestamped folder (second resolution; a second run
# started within the same second reuses it because of exist_ok=True), then
# switch into it so files written afterwards with relative paths land there.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_folder = os.path.join(base, current_time)
os.makedirs(new_folder, exist_ok=True)
os.chdir(new_folder)

# In [None]:
# Instantiate the recurrent actor-critic and meta-train it with Reptile.
# input_size=3 follows the original note: 2 (action) + 1 (reward).
model = RNNActorCritic(input_size=3, hidden_size=32, action_size=2)

# Choose training method: parallel or sequential
use_parallel = True  # Set to False for sequential training

# Hyperparameters shared by both training entry points, kept in one place so
# the parallel and sequential runs (and the evaluation below) cannot drift apart.
shared_train_kwargs = dict(
    n_outer_iters=10000,
    meta_batch_size=10,
    inner_steps=5,
    n_rounds=50,
    inner_lr=0.02,
    outer_lr=0.1,
)

if use_parallel:
    # n_workers=None = auto-detect CPUs (per the original call-site note).
    reptile_bandit_train_parallel(model, n_workers=None, **shared_train_kwargs)
else:
    reptile_bandit_train(model, **shared_train_kwargs)

# Checkpoint the weights to a relative path, i.e. into the current working
# directory at the time this cell runs.
torch.save(model.state_dict(), "reptile_bandit_model.pth")

# Evaluate the trained model; presumably each p is a bandit arm's reward
# probability — TODO confirm against test_model_performance.
test_p_values = [0.1, 0.3, 0.5, 0.7, 0.9]
test_results = test_model_performance(
    model,
    p_values=test_p_values,
    n_test_episodes=10,
    n_rounds=shared_train_kwargs["n_rounds"],
    inner_steps=0,  # Test zero-shot performance (set to 5 for few-shot adaptation)
    inner_lr=shared_train_kwargs["inner_lr"],
)
