In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import embedders

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from tqdm.notebook import tqdm
# from tqdm import tqdm

# Filter out warnings raised when sampling Wishart distribution in Gaussian mixtures
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

if device != torch.device("cuda"):
    sample_device = torch.device("cpu")
else:
    sample_device = device

print(f"Device: {device}, Sample Device: {sample_device}")

Device: cuda, Sample Device: cuda


In [8]:
results = []

# Signatures - using non-Gu approach for now
SIGNATURES = [
    [(-1, 2), (-1, 2)],  # HH
    [(-1, 2), (0, 2)],  # HE
    [(-1, 2), (1, 2)],  # HS
    [(1, 2), (1, 2)],  # SS
    [(1, 2), (0, 2)],  # SE
    [(-1, 4)],  # H
    [(0, 4)],  # E
    [(1, 4)],  # S
]

SIGNATURES_STR = ["HH", "HE", "HS", "SS", "SE", "H", "E", "S"]

DIM = 4
N_SAMPLES = 10
N_POINTS = 1_000
# N_POINTS = 100
N_CLASSES = 8
N_CLUSTERS = 32
# MAX_DEPTH = None
COV_SCALE_MEANS = 1.0
COV_SCALE_POINTS = 1.0

# TASK = "regression"
TASK = "classification"
# RESAMPLE_SCALES = False
# N_FEATURES = "d_choose_2"

# SCORE = "f1-micro" if TASK == "classification" else "rmse"
SCORE = ["f1-micro", "accuracy"] if TASK == "classification" else ["rmse"]

my_tqdm = tqdm(total=len(SIGNATURES) * N_SAMPLES)
for i, (sig, sigstr) in enumerate(zip(SIGNATURES[-1:], SIGNATURES_STR[-1:])):
    for seed in range(N_SAMPLES)[-1:]:
        print(sig, sigstr, seed)
        # Ensure unique seed per trial
        seed = seed + N_SAMPLES * i
        pm = embedders.manifolds.ProductManifold(signature=sig, device=sample_device)

        # Get X, y
        X, y = embedders.gaussian_mixture.gaussian_mixture(
            pm=pm,
            seed=seed,
            num_points=N_POINTS,
            num_classes=N_CLASSES,
            num_clusters=N_CLUSTERS,
            cov_scale_means=COV_SCALE_MEANS / DIM,
            cov_scale_points=COV_SCALE_POINTS / DIM,
            task=TASK,
        )
        X = X.to(device)
        y = y.to(device)
        pm = pm.to(device)

        # if RESAMPLE_SCALES:
        #     scale = 0.5 - np.random.rand() * 20
        #     pm.P[0].scale = torch.exp(torch.tensor(scale)).item()
        #     pm.P[0].manifold._log_scale = torch.nn.Parameter(torch.tensor(scale))

        # Benchmarks are now handled by the benchmark function
        model_results = embedders.benchmarks.benchmark(
            # X, y, pm, max_depth=MAX_DEPTH, task=TASK, score=SCORE, seed=seed, n_features=N_FEATURES, device=device
            X, y, pm, task=TASK, score=SCORE, seed=seed, device=device
        )
        
        # # Create a flat dictionary for this run
        # run_results = {"signature": sigstr, "seed": seed}

        # # Flatten the nested model results
        # for model, metrics in model_results.items():
        #     for metric, value in metrics.items():
        #         run_results[f"{model}_{metric}"] = value

        # results.append(run_results)
        model_results["signature"] = sigstr
        model_results["seed"] = seed
        results.append(model_results)
        my_tqdm.update(1)


results = pd.DataFrame(results)

  0%|          | 0/80 [00:00<?, ?it/s]

[(1, 4)] S 9


  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

In [9]:
results.to_csv(f"embedders/data/results_icml/{TASK}_signature2.tsv", index=False, sep="\t")

In [4]:
device = torch.device("cuda", 1)

results = []

# Signatures - using non-Gu approach for now
SIGNATURES = [
    [(-1, 2), (-1, 2)],  # HH
    [(-1, 2), (0, 2)],  # HE
    [(-1, 2), (1, 2)],  # HS
    [(1, 2), (1, 2)],  # SS
    [(1, 2), (0, 2)],  # SE
    [(-1, 4)],  # H
    [(0, 4)],  # E
    [(1, 4)],  # S
]

SIGNATURES_STR = ["HH", "HE", "HS", "SS", "SE", "H", "E", "S"]

DIM = 4
N_SAMPLES = 10
N_POINTS = 1_000
# N_POINTS = 100
N_CLASSES = 8
N_CLUSTERS = 32
# MAX_DEPTH = None
COV_SCALE_MEANS = 1.0
COV_SCALE_POINTS = 1.0

TASK = "regression"
# TASK = "classification"
# RESAMPLE_SCALES = False
# N_FEATURES = "d_choose_2"

# SCORE = "f1-micro" if TASK == "classification" else "rmse"
SCORE = ["f1-micro", "accuracy"] if TASK == "classification" else ["rmse"]

my_tqdm = tqdm(total=len(SIGNATURES) * N_SAMPLES)
for i, (sig, sigstr) in enumerate(zip(SIGNATURES, SIGNATURES_STR)):
    for seed in range(N_SAMPLES):
        # Ensure unique seed per trial
        seed = seed + N_SAMPLES * i
        pm = embedders.manifolds.ProductManifold(signature=sig, device=sample_device)

        # Get X, y
        X, y = embedders.gaussian_mixture.gaussian_mixture(
            pm=pm,
            seed=seed,
            num_points=N_POINTS,
            num_classes=N_CLASSES,
            num_clusters=N_CLUSTERS,
            cov_scale_means=COV_SCALE_MEANS / DIM,
            cov_scale_points=COV_SCALE_POINTS / DIM,
            task=TASK,
        )
        X = X.to(device)
        y = y.to(device)
        pm = pm.to(device)

        # if RESAMPLE_SCALES:
        #     scale = 0.5 - np.random.rand() * 20
        #     pm.P[0].scale = torch.exp(torch.tensor(scale)).item()
        #     pm.P[0].manifold._log_scale = torch.nn.Parameter(torch.tensor(scale))

        # Benchmarks are now handled by the benchmark function
        model_results = embedders.benchmarks.benchmark(
            # X, y, pm, max_depth=MAX_DEPTH, task=TASK, score=SCORE, seed=seed, n_features=N_FEATURES, device=device
            X, y, pm, task=TASK, score=SCORE, seed=seed, device=device
        )
        
        # # Create a flat dictionary for this run
        # run_results = {"signature": sigstr, "seed": seed}

        # # Flatten the nested model results
        # for model, metrics in model_results.items():
        #     for metric, value in metrics.items():
        #         run_results[f"{model}_{metric}"] = value

        # results.append(run_results)
        model_results["signature"] = sigstr
        model_results["seed"] = seed
        results.append(model_results)
        my_tqdm.update(1)


results = pd.DataFrame(results)

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

KeyboardInterrupt: 

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

In [5]:
results.to_csv(f"../data/results/{TASK}_signature_md{MAX_DEPTH}_ICML.tsv", index=False, sep="\t")