In [98]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import trange
import os
import json
import pandas as pd

# ---- Utilities ----


def one_hot(index, size):
    vec = np.zeros(size)
    vec[index] = 1
    return vec


# ---- Model ----


class A2CLSTM(nn.Module):
    def __init__(self, input_size=4, hidden_size=64, action_size=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.q_head = nn.Linear(hidden_size, action_size)
        self.v_head = nn.Linear(hidden_size, 1)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)
        q_values = self.q_head(out)
        state_values = self.v_head(out)
        return q_values, state_values, hidden


# ---- Discounted return ----


def compute_returns(rewards, gamma):
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return returns


# ---- Run a single session ----


def run_session(model, probs, session_id, gamma=0.8, n_trials=200, train=True):
    if train:
        model.train()
        optimizer = optim.RMSprop(model.parameters(), lr=1e-3, alpha=0.99)

    log_probs, values, rewards, entropies = [], [], [], []
    hidden = None
    prev_action, prev_reward = 0, 0
    trial_data = []

    for trial_id in range(n_trials):
        x_t = np.concatenate([one_hot(prev_action, 2), one_hot(prev_reward, 2)])
        x_t = torch.tensor(x_t, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        q_vals, v_vals, hidden = model(x_t, hidden)
        probs_action = torch.softmax(q_vals[0, 0], dim=-1)
        dist = torch.distributions.Categorical(probs_action)
        action = dist.sample()
        reward = int(np.random.rand() < probs[action.item()])
        prev_action, prev_reward = action.item(), reward

        trial_data.append(
            {
                "session_id": session_id,
                "trial_id": trial_id,
                "chosen_action": action.item(),
                "reward": reward,
                "reward_prob": probs[action.item()],
                "model_prediction": probs_action.detach().numpy().tolist(),
            }
        )

        if train:
            log_probs.append(dist.log_prob(action))
            values.append(v_vals[0, 0])
            rewards.append(reward)
            entropies.append(dist.entropy())

    if train:
        returns = compute_returns(rewards, gamma)
        returns = torch.tensor(returns, dtype=torch.float32)
        values = torch.stack(values)
        log_probs = torch.stack(log_probs)
        entropies = torch.stack(entropies)
        advantages = returns - values.squeeze()

        actor_loss = -(log_probs * advantages.detach()).mean()
        critic_loss = 0.5 * (advantages**2).mean()
        entropy_loss = -0.5 * entropies.mean()
        loss = actor_loss + critic_loss + entropy_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return trial_data


# ---- Training ----


def train_agent(condition, n_sessions=100, n_trials=200, gamma=0.8, save_path=None):
    model = A2CLSTM()
    full_data = []
    probs_set = np.arange(0.1, 1, 0.1)
    for session_id in range(n_sessions):
        if condition == "structured":
            p1 = np.random.choice(probs_set)
            p2 = 1 - p1
            probs = np.array([p1, p2])
        else:
            p1 = np.random.choice(probs_set)
            p2 = np.random.choice(probs_set)
            probs = np.array([p1, p2])

        session_data = run_session(
            model, probs, session_id, gamma, n_trials, train=True
        )
        full_data.extend(session_data)

    if save_path:
        torch.save(model.state_dict(), save_path)
    return model, full_data


# ---- Testing ----


def test_agent(model, condition, n_sessions=20, n_trials=200):
    model.eval()
    full_data = []
    for session_id in range(n_sessions):
        if condition == "structured":
            probs = np.random.uniform(0.5, 1.0, size=2)
        else:
            probs = np.random.uniform(0.5, 1.0, size=2)
        session_data = run_session(
            model, probs, session_id, train=False, n_trials=n_trials
        )
        full_data.extend(session_data)
    return full_data


# ---- Save & Load Model ----


def save_model(model, filepath):
    torch.save(model.state_dict(), filepath)


def load_model(filepath):
    model = A2CLSTM()
    model.load_state_dict(torch.load(filepath))
    model.eval()
    return model


# ---- Main ----


def run_experiment(n_models=3, save_models=True):
    results = {"structured": [], "unstructured": []}
    os.makedirs("models", exist_ok=True)

    for cond in ["structured", "unstructured"]:
        for i in trange(n_models, desc=f"Training {cond}"):
            save_path = f"models/model_{cond}_{i}.pt" if save_models else None
            model, train_data = train_agent(cond, save_path=save_path)
            test_data = test_agent(model, cond)
            results[cond].append({"train": train_data, "test": test_data})
    return results


# ---- Save results to CSV ----


def save_results_to_csv(results, prefix):
    rows = []
    for cond in results:
        for model_idx, model_data in enumerate(results[cond]):
            for phase in ["train", "test"]:
                for trial in model_data[phase]:
                    trial_row = trial.copy()
                    trial_row["model"] = model_idx
                    trial_row["condition"] = cond
                    trial_row["phase"] = phase
                    rows.append(trial_row)
    df = pd.DataFrame(rows)
    df.to_csv(f"{prefix}_results.csv", index=False)


# ---- Example run ----

results = run_experiment(n_models=2)
save_results_to_csv(results, prefix="bandit")

# Save data to disk (optional)
with open("results_structured.json", "w") as f:
    json.dump(results["structured"], f, indent=2)

with open("results_unstructured.json", "w") as f:
    json.dump(results["unstructured"], f, indent=2)

Training structured: 100%|██████████| 2/2 [00:30<00:00, 15.24s/it]
Training unstructured: 100%|██████████| 2/2 [00:30<00:00, 15.33s/it]


In [97]:
results["structured"]

[{'train': [{'session_id': 0,
    'trial_id': 0,
    'chosen_action': 0,
    'reward': 0,
    'reward_prob': np.float64(0.30000000000000004),
    'model_prediction': [0.5254129767417908, 0.47458699345588684]},
   {'session_id': 0,
    'trial_id': 1,
    'chosen_action': 1,
    'reward': 1,
    'reward_prob': np.float64(0.7),
    'model_prediction': [0.5283531546592712, 0.47164681553840637]},
   {'session_id': 0,
    'trial_id': 2,
    'chosen_action': 1,
    'reward': 1,
    'reward_prob': np.float64(0.7),
    'model_prediction': [0.5217441320419312, 0.47825589776039124]},
   {'session_id': 0,
    'trial_id': 3,
    'chosen_action': 0,
    'reward': 0,
    'reward_prob': np.float64(0.30000000000000004),
    'model_prediction': [0.5166049599647522, 0.48339495062828064]},
   {'session_id': 0,
    'trial_id': 4,
    'chosen_action': 0,
    'reward': 1,
    'reward_prob': np.float64(0.30000000000000004),
    'model_prediction': [0.5216341614723206, 0.47836586833000183]},
   {'session_id': 

In [94]:
np.random.uniform(0.5, 1.0, size=2)

array([0.55642157, 0.6310812 ])