In [1]:
%env CUDA_VISIBLE_DEVICES=7
from pathlib import Path
from pprint import pprint

import numpy as np
import pyarrow.dataset as ds
import torch as ch
import zarr
from torch import Tensor
from tqdm.rich import tqdm, trange

from config.config import Config, ExperimentConfig

# Instantiate the project-wide configuration (device, output/save dirs, projection dim,
# dataset definitions, ...) and print it so the run parameters are recorded in the output.
cfg = Config()
# cfg.device="cpu"  # NOTE(review): several later cells hard-code "cuda"; confirm this toggle is honored
pprint(cfg)


env: CUDA_VISIBLE_DEVICES=7
Config(device='cuda',
       worker_id=0,
       worker_total=1,
       dry_run=False,
       debug=False,
       output_dir='/raid/pdpl/trak/grads/',
       save_dir='/raid/pdpl/trak/trak_results/',
       s3_endpoint_url='https://s3.fraunhofer.de',
       write_chunks=1000,
       seed=42,
       proj_dim=2048,
       num_contrastive_samples=50000,
       datasets={'CIFAR100': DatasetConfig(uri='/datasets/cifar100/shards/cifar100-train-{000000..000049}.tar',
                                           uris=None,
                                           size=None,
                                           num_workers=16,
                                           splittable=True,
                                           custom=True),
                 'Food101': DatasetConfig(uri='/datasets/food101/shards/food101-train-{000000..000075}.tar',
                                          uris=None,
                                          size=None,
        

In [2]:
# Restrict the analysis to a single experiment named "raw". Its encoder list comes from
# ExperimentConfig's defaults (assumed — ExperimentConfig is defined in config/config.py; confirm).
cfg.experiments = [ExperimentConfig(name="raw")]

In [4]:
from training.data import uid_int_to_str

In [5]:
# When True, load_ood_grads / load_dataset_size work on only the first 1000 OOD rows,
# which makes a quick end-to-end run possible. Keep False for the real analysis.
DEBUG = False

In [6]:
_DEBUG_ROWS = 1000  # number of OOD rows read when DEBUG is True


def _ood_zarr_path(experiment_cfg, encoder_cfg) -> str:
    """Path of the zarr store holding one encoder's OOD gradients."""
    return str(
        Path(cfg.output_dir)
        / experiment_cfg.name
        / encoder_cfg.name
        / experiment_cfg.ood_dataset_name
        / "data.zarr"
    )


def load_ood_grads(experiment_cfg, encoder_cfg):
    """Load one encoder's OOD gradients, with every array sorted by sample uid.

    Reads the ``uid``, ``grads`` and ``loss_grads`` arrays from the encoder's
    ``data.zarr`` store (only the first ``_DEBUG_ROWS`` rows when ``DEBUG`` is set)
    and reorders all three by ascending uid. Presumably this is so that rows line up
    across encoders, because the scoring cell sums them element-wise.

    Returns:
        uids: numpy array of uids as produced by ``uid_int_to_str``, sorted ascending.
        g: pinned CPU tensor of gradients, one row per uid.
        out_to_loss: pinned CPU tensor of ``loss_grads``, aligned with ``uids``.
    """
    dataset = zarr.open(_ood_zarr_path(experiment_cfg, encoder_cfg))
    # `[:None]` reads the full array; in DEBUG mode only a prefix is read.
    stop = _DEBUG_ROWS if DEBUG else None
    uids = np.asarray(uid_int_to_str(dataset["uid"][:stop]))
    g = dataset["grads"][:stop]
    out_to_loss = dataset["loss_grads"][:stop]

    # Sort with a permutation rather than packing everything into a structured
    # array. This avoids an extra full copy of the (large) gradient matrix and works
    # for any trailing shape of `loss_grads`.
    order = np.argsort(uids, kind="stable")
    uids = uids[order]
    g = ch.from_numpy(np.ascontiguousarray(g[order])).pin_memory()
    out_to_loss = ch.from_numpy(
        np.ascontiguousarray(out_to_loss[order])
    ).pin_memory()
    return uids, g, out_to_loss


def load_dataset_size(experiment_cfg, encoder_cfg):
    """Number of OOD rows that ``load_ood_grads`` returns for this encoder.

    In DEBUG mode this is capped at ``_DEBUG_ROWS``. It never exceeds the number of
    rows the store actually holds, so it always matches the tensors from
    ``load_ood_grads``.
    """
    dataset = zarr.open(_ood_zarr_path(experiment_cfg, encoder_cfg))
    n_rows = dataset["grads"].shape[0]
    return min(n_rows, _DEBUG_ROWS) if DEBUG else n_rows

In [7]:
def get_xtx(
    grads: Tensor,
    batch_size=20_000,
    progress=None,
    device="cuda",
    dtype=None,
) -> Tensor:
    """Compute the Gram matrix ``grads.T @ grads`` block by block on ``device``.

    Args:
        grads: 2-D tensor ``(n_rows, proj_dim)``, typically kept on the CPU.
        batch_size: number of rows moved to ``device`` per step.
        progress: optional rich ``Progress``; when given, the block loop is tracked.
        device: where the accumulation happens (default ``"cuda"``).
        dtype: accumulation and output dtype, defaulting to ``grads.dtype``. Passing
            ``ch.float32`` avoids overflow when summing many float16 rows.

    Returns:
        ``(proj_dim, proj_dim)`` tensor on ``device``.
    """
    dtype = grads.dtype if dtype is None else dtype
    proj_dim = grads.shape[1]
    result = ch.zeros(proj_dim, proj_dim, dtype=dtype, device=device)
    blocks = ch.split(grads, split_size_or_sections=batch_size, dim=0)

    iterator = (
        progress.track(blocks, description="Computing XTX")
        if progress is not None
        else blocks
    )
    for block in iterator:
        # Transfer each block once; both operands of the product share it.
        block = block.to(device=device, dtype=dtype)
        result += block.T @ block

    return result


In [8]:
def get_x_xtx_inv(
    grads: Tensor,
    xtx: Tensor,
    lambda_reg=0.0,
    batch_size=20_000,
    progress=None,
    device="cuda",
) -> Tensor:
    """Compute ``grads @ (xtx + lambda_reg * I)^-1`` block by block, up to a global scale.

    The regularized inverse is computed in float32. It is then divided by the mean
    absolute value of its entries and cast to ``grads.dtype``. That positive global
    rescaling keeps values O(1) for low-precision dtypes, and it does not change the
    relative ordering of the outputs.

    Args:
        grads: 2-D tensor ``(n_rows, proj_dim)``, typically on the CPU.
        xtx: ``(proj_dim, proj_dim)`` matrix, e.g. from ``get_xtx``.
        lambda_reg: ridge term added to the diagonal before inversion.
        batch_size: rows of ``grads`` multiplied on ``device`` per step.
        progress: optional rich ``Progress`` used to track the block loop.
        device: where the multiplications run (default ``"cuda"``).

    Returns:
        CPU tensor with the same shape and dtype as ``grads``.
    """
    eye = ch.eye(xtx.size(dim=0), device=xtx.device, dtype=xtx.dtype)
    xtx_inv = ch.linalg.inv((xtx + lambda_reg * eye).to(ch.float32))
    # Rescale before casting (possibly to float16) to avoid overflow/underflow.
    xtx_inv = (xtx_inv / xtx_inv.abs().mean()).to(device=device, dtype=grads.dtype)

    grads_blocks = ch.split(grads, split_size_or_sections=batch_size, dim=0)
    iterator = (
        progress.track(grads_blocks, description="Processing blocks")
        if progress is not None
        else grads_blocks
    )
    # Results are brought back to the CPU block by block to bound device memory.
    result_blocks = [(block.to(device) @ xtx_inv).cpu() for block in iterator]
    return ch.cat(result_blocks)

In [9]:
def get_indices(
    target, id: bool = True, path: str = "/raid/pdpl/id_downstream_idx.zarr"
):
    """Return one of the stored index arrays for ``target`` from a zarr index store.

    Args:
        target: group name inside the store, e.g. ``"fairvision/amd"``.
        id: if True return ``<target>/id_indices``; otherwise return
            ``<target>/downstream_indices``. The name shadows the builtin ``id``
            but is kept because callers pass it by keyword.
        path: location of the zarr store, opened read-only. The default preserves
            the previously hard-coded path.

    Returns:
        The zarr array (not yet read into memory). Callers in this notebook use it to
        index uid-sorted target gradients; presumably the indices refer to that
        ordering — confirm against the code that wrote the store.
    """
    group = zarr.open(path, mode="r")[target]
    return group["id_indices" if id else "downstream_indices"]

In [10]:
# Number of OOD gradient rows for the first encoder; used below to size the score
# accumulators. Assumes every encoder stores the same number of rows — TODO confirm.
train_dataset_size = load_dataset_size(
    cfg.experiments[0], cfg.experiments[0].encoders[0]
)

In [11]:
# Downstream targets to score the OOD pool against; uncomment entries to include them.
# Each name is used as a sub-directory under <output_dir>/<experiment>/<encoder>/ and as
# a group name in the index store read by get_indices.
targets = [
    # "fitzpatrick17k",
    # "fairvision/dr",
    "fairvision/amd",
    # "fairvision/glaucoma",
    # "pcam",
    # "food101",
    # "cifar100",
    # "stl10",
]

In [None]:
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)

# Running sums over encoders, divided by the encoder count at the end:
# - avg_out_to_loss: the per-sample `loss_grads` weights.
# - avg_scores[target]: each OOD sample's mean score against the target's downstream
#   gradients.
avg_out_to_loss = ch.zeros(train_dataset_size, device="cpu")
avg_scores = {k: ch.zeros(train_dataset_size, device="cpu") for k in targets}

experiment_cfg = cfg.experiments[0]
READ_BATCH_SIZE = 16384  # rows per parquet record batch when reading target grads
SCORE_BATCH_SIZE = 8192 * 2  # OOD feature rows multiplied on the GPU per step


def _num_chunks(n_rows, chunk_size):
    """Ceil-divide: number of ``chunk_size`` chunks needed to cover ``n_rows`` rows."""
    return -(-n_rows // chunk_size)


def _load_downstream_grads(input_path, target, progress):
    """Read a target's gradients from parquet and return only its downstream rows.

    Rows are ordered by uid before the downstream indices from ``get_indices`` are
    applied, which matches how those indices were used before. Returns a CPU tensor.
    """
    dataset_target = ds.dataset(input_path, format="parquet")
    scanner = dataset_target.scanner(
        columns=["grads", "uid"], batch_size=READ_BATCH_SIZE
    )
    batch_task = progress.add_task(
        f"Loading batches for {target}...",
        total=_num_chunks(dataset_target.count_rows(), READ_BATCH_SIZE),
    )
    grads_list, uids_list = [], []
    for batch in scanner.to_batches():
        grads_list.extend(batch.column("grads").to_numpy(zero_copy_only=False))
        uids_list.extend(batch.column("uid").to_numpy(zero_copy_only=False))
        progress.advance(batch_task)
    progress.remove_task(batch_task)

    g_target = np.stack(grads_list)
    # order[k] is the row with the k-th smallest uid. So g_target[order[idx]] is
    # "sort by uid, then take rows idx", without materializing the sorted copy.
    order = np.argsort(np.stack(uids_list), kind="stable")
    downstream_idx = get_indices(target, id=False)[:]
    return ch.from_numpy(np.ascontiguousarray(g_target[order[downstream_idx]]))


def _mean_target_scores(features, target_grads, target, progress):
    """Mean inner product of each row of ``features`` with all rows of ``target_grads``.

    By linearity, mean_j <f_i, t_j> == <f_i, mean_j t_j>. So only the float32 mean
    target gradient is sent to the GPU (once), instead of uploading and multiplying
    the whole target matrix for every batch.
    """
    target_mean = target_grads.to(ch.float32).mean(dim=0).cuda()
    score_task = progress.add_task(
        f"Computing scores for {target}...",
        total=_num_chunks(len(features), SCORE_BATCH_SIZE),
    )
    scores = []
    for start in range(0, len(features), SCORE_BATCH_SIZE):
        batch = features[start : start + SCORE_BATCH_SIZE].to(
            device="cuda", dtype=ch.float32
        )
        scores.append((batch @ target_mean).cpu())
        progress.advance(score_task)
    progress.remove_task(score_task)
    return ch.cat(scores)


with Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    TimeElapsedColumn(),
) as progress:
    encoder_task = progress.add_task(
        "Processing encoders...", total=len(experiment_cfg.encoders)
    )

    for encoder_cfg in experiment_cfg.encoders:
        uids, g, out_to_loss = load_ood_grads(experiment_cfg, encoder_cfg)
        avg_out_to_loss += out_to_loss
        # `g` is already a pinned CPU tensor. Wrapping it in ch.tensor(...) again
        # would make extra full copies of the gradient matrix (and emit a warning).
        xtx = get_xtx(g, progress=progress)
        features = get_x_xtx_inv(g, xtx, progress=progress).pin_memory()

        target_task = progress.add_task(
            f"Processing targets for {encoder_cfg.name}...", total=len(targets)
        )
        for target in targets:
            input_path = str(
                Path(cfg.output_dir)
                / experiment_cfg.name
                / encoder_cfg.name
                / target
            )
            target_grads = _load_downstream_grads(input_path, target, progress)
            avg_scores[target] += _mean_target_scores(
                features, target_grads, target, progress
            )
            progress.advance(target_task)

        progress.remove_task(target_task)
        progress.advance(encoder_task)

n_encoders = len(experiment_cfg.encoders)
avg_out_to_loss /= n_encoders
avg_scores = {k: v / n_encoders for k, v in avg_scores.items()}
# Final score per target: encoder-averaged score weighted by the averaged loss_grads.
final_scores = {k: v * avg_out_to_loss for k, v in avg_scores.items()}


In [None]:
# Display the final (encoder-averaged, loss_grads-weighted) scores for one target.
final_scores["fairvision/amd"]

In [None]:
# Per-target scores are cached in a zarr store under the "trak" group.
# Targets already in the cache are loaded from it; the rest are computed and saved.
# NOTE(review): computation uses `features`, `out_to_loss` and `encoder_cfg` left over
# from the last encoder of the scoring loop (single-encoder scores, not the
# encoder-averaged ones) — confirm this is intended.
scores_zarr = zarr.open(
    "/datasets/datacomp/nearest_neighbor_scores.zarr", mode="a"
)
if "trak" not in scores_zarr:
    scores_zarr.create_group("trak")
scores_zarr = scores_zarr["trak"]

# Initialize here so the cell runs on a fresh kernel (no earlier cell defines it).
all_scores = {}
for target in [
    "fitzpatrick17k",
    "fairvision/dr",
    "fairvision/amd",
    "fairvision/glaucoma",
    "pcam",
    "food101",
    "cifar100",
    "stl10",
]:
    if target in scores_zarr and "id_scores" in scores_zarr[target]:
        # Convert to a tensor so later cells can call .mean()/.std()/.numpy().
        all_scores[target] = ch.from_numpy(
            np.asarray(scores_zarr[target]["id_scores"][:])
        )
        continue

    # Same <output_dir>/<experiment>/<encoder>/<target> layout as the scoring cell.
    input_path = str(
        Path(cfg.output_dir) / cfg.experiments[0].name / encoder_cfg.name / target
    )
    dataset_target = ds.dataset(input_path, format="parquet")
    read_batch_size = 16384
    scanner = dataset_target.scanner(
        columns=["grads", "uid"], batch_size=read_batch_size
    )
    n_read_batches = -(-dataset_target.count_rows() // read_batch_size)  # ceil
    grads_list, uids_list = [], []
    for batch in tqdm(scanner.to_batches(), total=n_read_batches):
        grads_list.extend(batch.column("grads").to_numpy(zero_copy_only=False))
        uids_list.extend(batch.column("uid").to_numpy(zero_copy_only=False))
    g_target = np.stack(grads_list)
    # The downstream indices are applied to the uid-sorted row order.
    order = np.argsort(np.stack(uids_list), kind="stable")
    downstream_idx = get_indices(target, id=False)[:]  # downstream, despite "id_scores"
    g_target_pt = ch.from_numpy(np.ascontiguousarray(g_target[order[downstream_idx]]))
    # mean_j <f_i, t_j> == <f_i, mean_j t_j>: upload only the mean target gradient, once.
    target_mean = g_target_pt.to(ch.float32).mean(dim=0).cuda()

    score_batch_size = 8192 * 2
    scores = []
    for i in trange(0, len(features), score_batch_size):
        batch = features[i : i + score_batch_size].to(device="cuda", dtype=ch.float32)
        scores.append((batch @ target_mean).cpu())
    scores = ch.cat(scores) * out_to_loss

    if target not in scores_zarr:
        target_group = scores_zarr.create_group(target)
    else:
        target_group = scores_zarr[target]
    target_group.array(
        "id_scores", np.array(scores.cpu()), dtype=np.float32, overwrite=True
    )
    all_scores[target] = scores.cpu()


In [None]:
# Targets that have scores available (computed above or loaded from the zarr cache).
all_scores.keys()

In [None]:
import pandas as pd

def _score_summary(name, target_scores):
    """Mean and (unbiased, ddof=1 — matching torch's default .std()) std of one target's scores.

    ``np.asarray`` accepts CPU torch tensors as well as zarr arrays.
    """
    values = np.asarray(target_scores, dtype=np.float64)
    return {"dataset": name, "mean": float(values.mean()), "std": float(values.std(ddof=1))}


summary_records = [_score_summary(k, v) for k, v in all_scores.items()]
# Explicit columns keep sort_values working even when all_scores is empty.
df = pd.DataFrame(summary_records, columns=["dataset", "mean", "std"]).sort_values(
    "std", ascending=False
)
print(df.to_string(float_format=lambda x: "{:.4f}".format(x)))


In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns

# Set the style before creating the figure so it applies to these axes.
sns.set_style("whitegrid", {"grid.alpha": 0.3})
fig, ax = plt.subplots(figsize=(10, 6))

# One color per dataset from a continuous palette.
num_datasets = len(all_scores)
colors = plt.cm.viridis(np.linspace(0, 1, num_datasets))


def compute_histogram(data, bins=50):
    """Histogram of ``data``: returns ``(counts, bin_edges)`` with ``len(bin_edges) == len(counts) + 1``."""
    return np.histogram(np.asarray(data), bins=bins)


# Plot in the order of `df` (descending std), so colors follow that ranking.
for color, key in zip(colors, df["dataset"]):
    hist, bin_edges = compute_histogram(all_scores[key])
    # Plot each count at its bin center, not at the bin's left edge.
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    sns.lineplot(x=bin_centers, y=hist, alpha=0.5, label=key, color=color, ax=ax)

ax.set_yscale("log")
ax.set(
    title="Distribution of Scores Across Datasets",
    xlabel="Score",
    ylabel="Count (log scale)",
)
ax.legend()
plt.show()

In [40]:
# from trak.utils import get_matrix_mult_blockwise

# full_scores = get_matrix_mult_blockwise(
#     features, ch.tensor(g_target, device="cpu"), ch.float16, bs=2048
# )
