In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys, logging
sys.path.append('./src')
sys.path.append('..')

logging.basicConfig(
    level=logging.INFO, 
    stream=sys.stdout,
    force=True
)

In [3]:
from functools import partial

import torch
from torch.utils.data import DataLoader

import pandas as pd
from tqdm.auto import tqdm
from huggingface_hub import HfApi

from src.catsbench import BenchmarkHDG, BenchmarkHDGConfig
from src.catsbench.metrics import ShapeScore, TrendScore, TrajectoryKLDivergence
from src.catsbench.utils import gumbel_sample
from src.catsbench.prior import Prior

from src.utils import CoupleDataset

# Helper functions

In [65]:
def benchmark_baseline(
    bench_names: list[str],
    repo_name: str,
    method: str,
    method_sample_func: callable,
    batch_size: int, 
    num_cond_samples: int, 
    device: str
) -> list[dict]:
    results = []
    for bench_name in bench_names:
        bench = BenchmarkHDG.from_pretrained(
            repo_name, bench_name, init_benchmark=False, device=device 
        )
        # bench.input_dataset, bench.target_dataset = bench.sample_input_target(bench.num_val_samples)
        # bench.push_to_hub('gregkseno/catsbench', subfolder=bench.name, commit_message='Update validation data')

        # ========================== METRICS ==========================
        shape_metric = ShapeScore(dim=bench.dim, num_categories=bench.num_categories).to(device)
        trend_metric = TrendScore(dim=bench.dim, num_categories=bench.num_categories).to(device)
        cond_shape_metric = ShapeScore(
            dim=bench.dim, num_categories=bench.num_categories, conditional=True
        ).to(device)
        cond_trend_metric = TrendScore(
            dim=bench.dim, num_categories=bench.num_categories, conditional=True
        ).to(device)
        forward_kl_div = TrajectoryKLDivergence().to(device)
        reverse_kl_div = TrajectoryKLDivergence().to(device)

        # ========================== DATA ==========================
        dataset = CoupleDataset(
            input_dataset=bench.input_dataset,
            target_dataset=bench.target_dataset
        )
        dataloader = DataLoader(
            dataset, batch_size=batch_size,
        )

        for x_start, x_end in tqdm(dataloader, desc=f'Benchmarking {bench_name}'):
            # ========================== UNCONDITIONAL METRICS ==========================
            pred_x_end, transition_probs = method_sample_func(x_start, bench)
            shape_metric.update(x_end, pred_x_end)
            trend_metric.update(x_end, pred_x_end)

            # ========================== CONDITIONAL METRICS ==========================
            x_start_repeated = x_start[0].unsqueeze(0).expand(num_cond_samples, -1)
            x_end_repeated = bench.sample(x_start_repeated)
            pred_x_end_repeated, _ = method_sample_func(x_start_repeated, bench)
            cond_shape_metric.update(x_end_repeated, pred_x_end_repeated)
            cond_trend_metric.update(x_end_repeated, pred_x_end_repeated)

            # ========================== TRAJECTORY METRICS ==========================
            if transition_probs is None:
                continue
            true_trajectory, true_transition_logits = bench.sample_trajectory(x_start, return_transitions=True)
            pred_traj_list = [x_start]
            pred_logits_list, model_logits_list = [], []
            pred_x_t = x_start
            for t in range(true_transition_logits.shape[0]):
                probs_t = transition_probs[t].unsqueeze(0).expand(pred_x_t.shape[0], -1, -1, -1) # (B, D, C, C)
                
                # gather predicted transition probabilities
                pred_probs_inicies = pred_x_t[:, :, None, None].expand(-1, -1, 1, bench.num_categories) # (B, D, 1, C)
                pred_probs_t = torch.gather(probs_t, dim=2, index=pred_probs_inicies).squeeze(2)  # (B, D, C)
                pred_logits_list.append(torch.log(pred_probs_t + torch.finfo(pred_probs_t.dtype).eps))

                # gather model transition probabilities
                true_x_t = true_trajectory[t]  # (B, D)
                true_probs_inicies = true_x_t.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, bench.num_categories)  # (B, D, 1, C)
                true_probs_t = torch.gather(probs_t, dim=2, index=true_probs_inicies).squeeze(2)  # (B, D, C)
                model_logits_list.append(torch.log(true_probs_t + torch.finfo(true_probs_t.dtype).eps))

                # sample next predicted point
                pred_x_t = torch.multinomial(
                    pred_probs_t.flatten(end_dim=-2), num_samples=1
                ).squeeze(-1).reshape_as(x_start)
                pred_traj_list.append(pred_x_t)

            pred_trajectory = torch.stack(pred_traj_list)
            pred_transition_logits = torch.stack(pred_logits_list)
            model_transition_logits = torch.stack(model_logits_list)
            
            # we need only num_steps + 1 points to compute transitions
            true_trajectory = true_trajectory[:-1]
            pred_trajectory = pred_trajectory[:-1]
            
            timesteps = torch.arange(true_trajectory.shape[0], device=device)
            timesteps = timesteps.repeat_interleave(true_trajectory.shape[1])
            
            true_trajectory = true_trajectory.flatten(end_dim=1)
            pred_trajectory = pred_trajectory.flatten(end_dim=1)
            true_transition_logits = true_transition_logits.flatten(end_dim=1)
            pred_transition_logits = pred_transition_logits.flatten(end_dim=1)

            # the KL div must be computed in cross fashion:
            # forward KL is KL with respect to true trajectory
            # reverse KL is KL with respect to predicted trajectory
            reverse_kl_div.update(
                p=pred_transition_logits, 
                q=bench.get_transition_logits(pred_trajectory, timesteps)
            )
            forward_kl_div.update(
                p=true_transition_logits.reshape_as(model_transition_logits), 
                q=model_transition_logits
            )
        
        results.append({
            'method': method,
            'dim': bench.dim,
            'num_categories': bench.num_categories,
            'alpha': bench.alpha,
            'prior_type': bench.prior_type,
            'shape_score': shape_metric.compute().cpu().item(),
            'trend_score': trend_metric.compute().cpu().item(),
            'cond_shape_score': cond_shape_metric.compute().cpu().item(),
            'cond_trend_score': cond_trend_metric.compute().cpu().item(),
            'forward_kl_div': forward_kl_div.compute().cpu().item(),
            'reverse_kl_div': reverse_kl_div.compute().cpu().item(),
        })
    return results

# Eval

In [60]:
repo_name = 'gregkseno/catsbench'
batch_size = 128
num_cond_samples = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [61]:
api = HfApi()
tree = api.list_repo_tree(repo_id=repo_name)
bench_names = [tree_item.path for tree_item in tree if tree_item.path.startswith('hdg_')]
# bench_names

## Independent

In [62]:
def independent_sample_func(
    x_start: torch.Tensor, 
    benchmark: BenchmarkHDG
):
    return benchmark.sample_target(x_start.shape[0]), None

results = benchmark_baseline(
    bench_names,
    repo_name,
    'independent',
    method_sample_func=independent_sample_func,
    batch_size=batch_size,
    num_cond_samples=num_cond_samples,
    device=device
)

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d16_s50_prior_gaussian_a0.02:   0%|          | 0/157 [00:00<?, ?it/s]



INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d16_s50_prior_gaussian_a0.05:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d16_s50_prior_uniform_a0.005:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d16_s50_prior_uniform_a0.01:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d2_s50_prior_gaussian_a0.02:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d2_s50_prior_gaussian_a0.05:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d2_s50_prior_uniform_a0.005:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d2_s50_prior_uniform_a0.01:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d64_s50_prior_gaussian_a0.02:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d64_s50_prior_gaussian_a0.05:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d64_s50_prior_uniform_a0.005:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

Benchmarking hdg_d64_s50_prior_uniform_a0.01:   0%|          | 0/157 [00:00<?, ?it/s]

In [63]:
independent = pd.DataFrame(results)

independent["Prior"] = independent.apply(
    lambda r: f"D={r.dim} | {r.prior_type} | α={r.alpha}",
    axis=1
)
metrics = [
    'shape_score', 'trend_score', 'cond_shape_score', 'cond_trend_score'
]
for metric in metrics:
    table = independent.pivot_table(
        index="method", columns="Prior", values=metric,
    ).round(3)
    display(table)

Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
independent,0.987,0.98,0.981,0.983,0.982,0.983,0.98,0.985,0.985,0.981,0.981,0.981


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
independent,0.961,0.955,0.944,0.947,0.966,0.966,0.961,0.966,0.948,0.923,0.92,0.92


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
independent,0.699,0.742,0.498,0.57,0.511,0.831,0.632,0.643,0.653,0.658,0.543,0.61


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
independent,0.568,0.676,0.369,0.457,0.474,0.782,0.52,0.551,0.478,0.508,0.352,0.431


## Prior

In [66]:
def prior_sample_func(
    x_start: torch.Tensor, 
    benchmark: BenchmarkHDG
):
    log_probs = benchmark.prior.extract_last_cum_matrix(x_start)
    transition_probs = (benchmark
        .prior
        .log_p_cum[1:]
        .unsqueeze(1)
        .expand(-1, benchmark.dim, -1, -1)
    ).exp()
    return gumbel_sample(log_probs), transition_probs

results = benchmark_baseline(
    bench_names,
    repo_name,
    'prior',
    method_sample_func=prior_sample_func,
    batch_size=batch_size,
    num_cond_samples=num_cond_samples,
    device=device
)

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Benchmarking hdg_d16_s50_prior_gaussian_a0.02:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d16_s50_prior_gaussian_a0.05/model.s(…):   0%|          | 0.00/5.13M [00:00<?, ?B/s]

Benchmarking hdg_d16_s50_prior_gaussian_a0.05:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d16_s50_prior_uniform_a0.005/model.s(…):   0%|          | 0.00/5.13M [00:00<?, ?B/s]

Benchmarking hdg_d16_s50_prior_uniform_a0.005:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d16_s50_prior_uniform_a0.01/model.sa(…):   0%|          | 0.00/5.13M [00:00<?, ?B/s]

Benchmarking hdg_d16_s50_prior_uniform_a0.01:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d2_s50_prior_gaussian_a0.02/model.sa(…):   0%|          | 0.00/642k [00:00<?, ?B/s]

Benchmarking hdg_d2_s50_prior_gaussian_a0.02:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d2_s50_prior_gaussian_a0.05/model.sa(…):   0%|          | 0.00/642k [00:00<?, ?B/s]

Benchmarking hdg_d2_s50_prior_gaussian_a0.05:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d2_s50_prior_uniform_a0.005/model.sa(…):   0%|          | 0.00/642k [00:00<?, ?B/s]

Benchmarking hdg_d2_s50_prior_uniform_a0.005:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d2_s50_prior_uniform_a0.01/model.saf(…):   0%|          | 0.00/642k [00:00<?, ?B/s]

Benchmarking hdg_d2_s50_prior_uniform_a0.01:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d64_s50_prior_gaussian_a0.02/model.s(…):   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Benchmarking hdg_d64_s50_prior_gaussian_a0.02:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d64_s50_prior_gaussian_a0.05/model.s(…):   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Benchmarking hdg_d64_s50_prior_gaussian_a0.05:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d64_s50_prior_uniform_a0.005/model.s(…):   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Benchmarking hdg_d64_s50_prior_uniform_a0.005:   0%|          | 0/157 [00:00<?, ?it/s]

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


hdg_d64_s50_prior_uniform_a0.01/model.sa(…):   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Benchmarking hdg_d64_s50_prior_uniform_a0.01:   0%|          | 0/157 [00:00<?, ?it/s]

In [67]:
prior = pd.DataFrame(results)

prior["Prior"] = prior.apply(
    lambda r: f"D={r.dim} | {r.prior_type} | α={r.alpha}",
    axis=1
)
metrics = [
    'shape_score', 'trend_score', 
    'cond_shape_score', 'cond_trend_score', 
    'forward_kl_div', 'reverse_kl_div'
]
for metric in metrics:
    table = prior.pivot_table(
        index="method", columns="Prior", values=metric,
    ).round(3)
    display(table)

Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.405,0.338,0.411,0.421,0.322,0.479,0.41,0.433,0.56,0.491,0.476,0.467


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.193,0.119,0.136,0.138,0.171,0.304,0.251,0.295,0.369,0.233,0.232,0.204


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.294,0.291,0.297,0.329,0.167,0.454,0.394,0.42,0.406,0.371,0.358,0.348


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.111,0.082,0.09,0.095,0.085,0.277,0.231,0.273,0.195,0.121,0.136,0.11


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.526,0.666,0.301,0.527,0.501,0.637,0.272,0.509,0.502,0.637,0.278,0.498


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,2.225,4.623,1.391,1.528,2.118,3.045,0.806,1.336,1.417,3.544,1.027,1.244


## Feature-wise EOT (D-IMF)

In [None]:
def get_transition_matrix(
    fb: str, 
    coupling: torch.Tensor,
    prior: Prior, 
    t: int
) -> torch.Tensor:
    num_categories = prior.num_categories

    # Calculate posterior probs
    x_0 = torch.arange(num_categories, dtype=torch.long, device=coupling.device)
    x_tn = torch.arange(num_categories, dtype=torch.long, device=coupling.device)
    x_0, x_tn =  torch.meshgrid(x_0, x_tn)
    x_0, x_tn = x_0.flatten(), x_tn.flatten()
    t_n = torch.full(size=(x_tn.shape[0],), fill_value=t, dtype=torch.long, device=coupling.device)
    log_posterior = prior.posterior_logits(x_start=x_0, x_t=x_tn, t=t_n).log_softmax(dim=-1)
    log_posterior = log_posterior.reshape(3 * [num_categories])[:, None, :, :] # 0, 1, t_n, t_nm1

    # Calculate bridge probs
    x_0 = torch.arange(num_categories, dtype=torch.long, device=coupling.device)
    x_1 = torch.arange(num_categories, dtype=torch.long, device=coupling.device)
    x_0, x_1 =  torch.meshgrid(x_0, x_1)
    x_0, x_1 = x_0.flatten(), x_1.flatten()
    log_bridge = prior.bridge_logits(x_start=x_0, x_end=x_1, t=t_n).log_softmax(dim=-1)
    log_bridge = log_bridge.reshape(3 * [num_categories])[:, :, :, None] # 0, 1, t_n, t_nm1

    probs = (log_posterior + log_bridge).exp() * coupling[:, :, None, None]
    joint_distribution = probs.sum(dim=[0, 1])
    eps = torch.finfo(joint_distribution.dtype).eps
    if fb == 'forward':
        transition_matrix = (joint_distribution / (joint_distribution.sum(dim=0, keepdim=True) + eps)).T
    else:
        transition_matrix = (joint_distribution / (joint_distribution.sum(dim=1, keepdim=True) + eps))
    return transition_matrix

def categorical_d_imf(
    x_start: torch.Tensor,
    num_imf_iterations: int, 
    p_0: torch.Tensor, 
    p_1: torch.Tensor,
    prior: Prior,
) -> torch.Tensor:
    num_categories = prior.num_categories
    num_timesteps = prior.num_timesteps

    coupling = p_0[:, None] * p_1[None, :]
    for _ in range(num_imf_iterations):
        forward_transition_probs = torch.eye(num_categories, dtype=torch.float, device=p_0.device)
        for t_n in range(1, num_timesteps + 2):
            forward_transition_probs @= get_transition_matrix('forward', coupling, prior, t_n)
        coupling = p_0[:, None] * forward_transition_probs
        
        backward_transition_probs = torch.eye(num_categories, dtype=torch.float, device=p_1.device)
        for t_n in reversed(range(1, num_timesteps + 2)):
            backward_transition_probs @= get_transition_matrix('backward', coupling, prior, t_n)
        coupling = backward_transition_probs.T * p_1[None, :]
    return coupling

def compute_marginal(samples: torch.Tensor, num_categories: int):
    offset = torch.arange(samples.shape[1], device=samples.device) * num_categories
    flat_indices = (samples + offset).view(-1)
    counts = torch.bincount(flat_indices, minlength=samples.shape[1]*num_categories).float()
    return counts.view(samples.shape[1], num_categories) / samples.shape[0]

def d_imf_sample_func(
    x_start: torch.Tensor, 
    benchmark: BenchmarkHDG,
    num_iters: int,
    num_samples_for_probs: int,
):
    input_samples, target_samples = benchmark.sample_input_target(num_samples_for_probs)
    input_distribution = compute_marginal(input_samples, benchmark.num_categories)
    target_distribution = compute_marginal(target_samples, benchmark.num_categories)

    probs = torch.empty(*x_start.shape, benchmark.num_categories, device=x_start.device)
    transitions = torch.empty(
        benchmark.prior.num_timesteps+1, benchmark.dim, benchmark.num_categories, benchmark.num_categories,
        device=x_start.device
    )
    for d in range(benchmark.dim):
        coupling = categorical_d_imf(
            x_start=x_start[:, d],
            num_imf_iterations=num_iters,
            p_0=input_distribution[d],
            p_1=target_distribution[d],
            prior=benchmark.prior,
        )
        forward_transition_probs_list = []
        forward_transition_probs = torch.eye(benchmark.num_categories, dtype=torch.float, device=device)
        for t_n in range(1, benchmark.prior.num_timesteps + 2):
            transition_probs = get_transition_matrix('forward', coupling, benchmark.prior, t_n)
            forward_transition_probs @= transition_probs
            forward_transition_probs_list.append(transition_probs)

        probs[:, d, :] = forward_transition_probs[x_start[:, d]]
        transitions[:, d, :, :] = torch.stack(forward_transition_probs_list)
        
    samples = torch.multinomial(probs.flatten(end_dim=-2), num_samples=1).reshape(x_start.shape)
    return samples, transitions


func = partial(
    d_imf_sample_func, num_iters=10, 
    num_samples_for_probs=60_000
)
results = benchmark_baseline(
    bench_names,
    repo_name,
    'd_imf',
    method_sample_func=func,
    batch_size=batch_size,
    num_cond_samples=num_cond_samples,
    device=device
)

INFO:catsbench:[Rank 0] Skipping parameters initialization!
INFO:catsbench:[Rank 0] Initializing prior...
INFO:catsbench:[Rank 0] Skipping dataset initialization!


Benchmarking hdg_d16_s50_prior_gaussian_a0.02:   0%|          | 0/157 [00:00<?, ?it/s]

true_trajectory shape: torch.Size([17, 128, 16])
true_transition_logits shape: torch.Size([16, 128, 16, 50])
transition_probs shape: torch.Size([16, 16, 50, 50])


KeyboardInterrupt: 

In [None]:
fwot = pd.DataFrame(results)

fwot["Prior"] = fwot.apply(
    lambda r: f"D={r.dim} | {r.prior_type} | α={r.alpha}",
    axis=1
)
metrics = [
    'shape_score', 'trend_score', 
    'cond_shape_score', 'cond_trend_score', 
    'forward_kl_div', 'reverse_kl_div'
]
for metric in metrics:
    table = fwot.pivot_table(
        index="method", columns="Prior", values=metric,
    ).round(3)
    display(table)

Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.6,0.403,0.439,0.447,0.485,0.475,0.473,0.406,0.729,0.543,0.579,0.544


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.336,0.133,0.15,0.142,0.317,0.299,0.303,0.27,0.572,0.246,0.278,0.231


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.227,0.258,0.299,0.32,0.236,0.426,0.439,0.388,0.364,0.309,0.358,0.365


Prior,D=16 | gaussian | α=0.02,D=16 | gaussian | α=0.05,D=16 | uniform | α=0.005,D=16 | uniform | α=0.01,D=2 | gaussian | α=0.02,D=2 | gaussian | α=0.05,D=2 | uniform | α=0.005,D=2 | uniform | α=0.01,D=64 | gaussian | α=0.02,D=64 | gaussian | α=0.05,D=64 | uniform | α=0.005,D=64 | uniform | α=0.01
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
prior,0.071,0.07,0.094,0.091,0.145,0.262,0.272,0.247,0.158,0.095,0.135,0.117
