In [None]:
from sample_complexity_traj_utils import evaluate_one_estimator_torch
from concurrent.futures import ProcessPoolExecutor
from tqdm.notebook import tqdm
import numpy as np
import pickle
import matplotlib.pyplot as plt

def run_sample_complexity_study_torch_parallel(
        true_theta_np,
        d=5,
        num_init_trajs=10,
        jump_size=2,
        num_increments=10,
        num_repeat=20,
        learning_rate=0.001,
        num_epochs=1000,
        batch_size=32,
        max_workers=4  # you control # of parallel jobs here
    ):
    traj_sizes = [num_init_trajs * (jump_size ** i) for i in range(num_increments)]
    l2_errors, max_errors, log_likelihoods, estimators = [], [], [], []

    for k in tqdm(traj_sizes, desc="Trajectory sizes"):
        kwargs_list = [
            {
                "k": k,
                "d": d,
                "true_theta_np": true_theta_np,
                "learning_rate": learning_rate,
                "num_epochs": num_epochs,
                "batch_size": batch_size,
                "seed": np.random.randint(0, 1e6)
            }
            for _ in range(num_repeat)
        ]

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(evaluate_one_estimator_torch, kwargs_list))

        l2s, maxs, lls, theta_hats = zip(*results)
        l2_errors.append((np.mean(l2s), np.std(l2s)))
        max_errors.append((np.mean(maxs), np.std(maxs)))
        log_likelihoods.append((np.mean(lls), np.std(lls)))
        estimators.append(theta_hats)

    return traj_sizes, l2_errors, max_errors, log_likelihoods, estimators

def plot_with_ci(x, means_and_stds, ylabel, title, save=False):
    means = np.array([mean for mean, std in means_and_stds])
    stds = np.array([std for mean, std in means_and_stds])

    lower = means - 1.96 * stds
    upper = means + 1.96 * stds

    plt.figure(figsize=(6, 4))
    plt.plot(x, means, label="Mean", color="blue")
    plt.fill_between(x, lower, upper, color="blue", alpha=0.2, label="95% CI")
    plt.xlabel("Number of Trajectories")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"figs/{ylabel}_TB.png", dpi=300) if save else None
    plt.show()


In [7]:
d = 3
true_theta_np = np.random.uniform(0.2, 0.8, size=(1,d))

traj_sizes, l2s, maxs, lls, estimators = run_sample_complexity_study_torch_parallel(
    true_theta_np=true_theta_np,
    d=d,
    num_init_trajs=100,
    jump_size=2,
    num_increments=1,
    num_repeat=20,
    learning_rate=0.001,
    num_epochs=1000,
    batch_size=8,
    max_workers=12,
)

Trajectory sizes:   0%|          | 0/1 [00:00<?, ?it/s]

In [9]:
np.array(estimators).mean(axis=1)

array([[0.5617205 , 0.62166464, 0.4842712 ]], dtype=float32)

In [8]:
true_theta_np, estimators

(array([[0.35825943, 0.31447919, 0.73706058]]),
 [(array([0.9885555 , 0.35790345, 0.5317727 ], dtype=float32),
   array([0.5617071 , 0.50835997, 0.8113076 ], dtype=float32),
   array([0.85460883, 0.32857186, 0.3119584 ], dtype=float32),
   array([1.0460751 , 0.26730463, 0.2355254 ], dtype=float32),
   array([0.61465585, 0.3912766 , 0.50851035], dtype=float32),
   array([0.48299775, 0.95463413, 0.05600949], dtype=float32),
   array([0.3960647 , 0.9533036 , 0.23909846], dtype=float32),
   array([0.5215588 , 0.8190409 , 0.59503984], dtype=float32),
   array([0.5523347 , 0.53843904, 0.35275748], dtype=float32),
   array([0.6148491 , 0.57228684, 0.29582858], dtype=float32),
   array([0.34707072, 0.71496296, 0.59568214], dtype=float32),
   array([0.23477848, 0.71755755, 0.8187558 ], dtype=float32),
   array([0.5253758 , 1.0030268 , 0.13051999], dtype=float32),
   array([0.5581951 , 0.74052614, 0.7254748 ], dtype=float32),
   array([0.803447  , 0.46652877, 0.6129582 ], dtype=float32),
   arra

In [None]:
# with open("sample_complexity_results_TB.pkl", "wb") as f:
#     pickle.dump((traj_sizes, l2s, maxs, lls), f)
    
# # load results
# with open("sample_complexity_results_TB.pkl", "rb") as f:
#     traj_sizes, l2s, maxs, lls = pickle.load(f)

In [None]:
plot_with_ci(traj_sizes, l2s, "L2 Norm Error", "Sample Complexity: L2 Distance", save=True)
plot_with_ci(traj_sizes, maxs, "Max Norm Error", "Sample Complexity: Max Norm", save=True)
plot_with_ci(traj_sizes, lls, "Avg Log-Likelihood", "Sample Complexity: Log-Likelihood", save=True)
