In [1]:
%env CUDA_VISIBLE_DEVICES=4

env: CUDA_VISIBLE_DEVICES=4


In [2]:
# we now need to read the grads back into the proper TRAKer and continue

from pprint import pprint

import numpy as np
import pyarrow.dataset as ds
import torch
from config import Config
from model import Model
from trak import TRAKer

In [3]:
cfg = Config()
# cfg.device="cpu"
pprint(cfg)

Config(device='cuda',
       worker_id=0,
       worker_total=20,
       dry_run=False,
       debug=False,
       output_dir='/raid/pdpl/trak/grads/',
       save_dir='/raid/pdpl/trak/trak_results/',
       write_chunks=1000,
       seed=42,
       proj_dim=2048,
       num_contrastive_samples=50000,
       datasets={'commonpool': DatasetConfig(uri='/datasets/datacomp/shards/{00000000..00001287}.tar',
                                             uris=None,
                                             size=None,
                                             num_workers=16,
                                             num_samples=10367394),
                 'fairvision/AMD': DatasetConfig(uri='/datasets/fairvision/AMD/shards/amd-train-{000000..000005}.tar',
                                                 uris=None,
                                                 size=None,
                                                 num_workers=16,
                                                 

In [4]:
encoder_cfg = cfg.encoders[1]
from pathlib import Path

input_path = str(
    Path(cfg.output_dir) / encoder_cfg.name / encoder_cfg.ood_dataset_name
)
dataset = ds.dataset(input_path, format="parquet")
train_set_size = dataset.count_rows()

In [5]:
model = Model(encoder_cfg, "cpu")
model, _, _, _ = model.create_model_and_transforms()
traker = TRAKer(
    save_dir=cfg.save_dir,
    model=model,
    task="clip",
    train_set_size=train_set_size,
    device=cfg.device,
    proj_dim=cfg.proj_dim,
    use_half_precision=True,
)

INFO:TRAK:Using ChunkedCudaProjector with2 chunks of sizes[125979393, 25297920].
INFO:STORE:Existing model IDs in /raid/pdpl/trak/trak_results: [0, 1]
INFO:STORE:Model IDs that have been finalized: [0]
INFO:STORE:No existing TRAK scores in /raid/pdpl/trak/trak_results.


In [6]:
traker.saver.load_current_store(encoder_cfg.model_id)

In [8]:
import torch as ch

grads = traker.saver.current_store["grads"]
proj_dim = grads.shape[1]
# result = ch.zeros(proj_dim, proj_dim, dtype=ch.float16, device="cuda")
blocks = ch.split(ch.as_tensor(grads), split_size_or_sections=20_000, dim=0)


In [9]:
xtx = traker.score_computer.get_xtx(ch.as_tensor(grads)).cpu()

In [10]:
from tqdm import tqdm

In [11]:
lambda_reg = 0.0  # Default regularization term
dtype = ch.float16  # Using float16 as default
CUDA_MAX_DIM_SIZE = 100_000
grads = ch.as_tensor(traker.saver.current_store["grads"])
blocks = ch.split(grads, split_size_or_sections=CUDA_MAX_DIM_SIZE, dim=0)
xtx_reg = xtx + lambda_reg * torch.eye(
    xtx.size(dim=0), device=xtx.device, dtype=xtx.dtype
)
xtx_inv = ch.linalg.inv(xtx_reg.to(ch.float32))
xtx_inv /= xtx_inv.abs().mean()

xtx_inv = xtx_inv.to(dtype)

In [11]:
# result = ch.empty(grads.shape[0], xtx_inv.shape[1], dtype=dtype, device="cpu")

# for i, block in enumerate(tqdm(blocks, desc="Computing X^TX inverse")):
#     start = i * CUDA_MAX_DIM_SIZE
#     end = min(grads.shape[0], (i + 1) * CUDA_MAX_DIM_SIZE)
#     result[start:end] = block @ xtx_inv
#     break

In [12]:
import jax
import jax.numpy as jnp


In [13]:
grads_jax = jnp.array(grads, device=jax.devices("cpu")[0])

In [14]:
xtx_inv_jax = jnp.array(xtx_inv, device=jax.devices("cpu")[0])

In [15]:
# xtx_inv_jax = process_large_grads(grads_jax, xtx_inv_jax, CUDA_MAX_DIM_SIZE)

In [15]:
# Find factors of grads_jax.shape[0]
n = grads_jax.shape[0]
factors = [i for i in range(1, n + 1) if n % i == 0]
smallest_factor_over_100 = next(f for f in sorted(factors) if f > 100)
print(f"Smallest factor of {n} over 100: {smallest_factor_over_100}")

Smallest factor of 10367394 over 100: 349


In [16]:
grads_blocks = jnp.split(grads_jax, smallest_factor_over_100)

In [17]:
xtx_inv_gpu = jax.device_put(xtx_inv_jax, jax.devices("cuda")[0])


In [18]:
xtx_inv_blocks = [
    jax.device_get(jax.device_put(block, jax.devices("cuda")[0]) @ xtx_inv_gpu)
    for block in tqdm(grads_blocks)
]
with jax.default_device(jax.devices("cpu")[0]):
    xtx_inv_jax = jnp.concatenate(xtx_inv_blocks)


100%|██████████| 349/349 [00:28<00:00, 12.16it/s]


In [19]:
traker.saver.current_store["features"][:] = ch.as_tensor(
    np.asarray(xtx_inv_jax), device="cpu"
)
traker.saver.model_ids[encoder_cfg.model_id]["is_finalized"] = 1
traker.saver.serialize_current_model_id_metadata()

  traker.saver.current_store["features"][:] = ch.as_tensor(


# end

In [6]:
# traker.finalize_features(model_ids=[0, 1])

Finalizing features for all model IDs..:   0%|          | 0/2 [00:00<?, ?it/s]

Output()

Output()

In [None]:
# from pathlib import Path

# from rich.console import Console
# from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn

# console = Console()

# for c in cfg.encoders[2:]:
#     console.rule(f"[bold red]Processing encoder {c.name}")

#     with Progress(
#         TextColumn("[progress.description]{task.description}"),
#         BarColumn(),
#         TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
#         TimeElapsedColumn(),
#         # TimeRemainingColumn(),
#         console=console,
#     ) as progress:
#         # Create task for each operation
#         checkpoint_task = progress.add_task("Loading checkpoint...", total=1)
#         traker_task = progress.add_task(
#             "Loading checkpoint into TRAKer...", total=1, visible=False
#         )
#         path_task = progress.add_task(
#             "Setting up input path...", total=1, visible=False
#         )
#         dataset_task = progress.add_task(
#             "Loading dataset...", total=1, visible=False
#         )
#         table_task = progress.add_task(
#             "Converting to table...", total=1, visible=False
#         )
#         grads_task = progress.add_task(
#             "Stacking gradients...", total=1, visible=False
#         )
#         loss_grads_task = progress.add_task(
#             "Stacking loss gradients...", total=1, visible=False
#         )
#         store_grads_task = progress.add_task(
#             "Storing gradients...", total=1, visible=False
#         )
#         store_loss_task = progress.add_task(
#             "Storing loss gradients...", total=1, visible=False
#         )
#         flag_task = progress.add_task(
#             "Setting featurization flag...", total=1, visible=False
#         )
#         meta_task = progress.add_task(
#             "Serializing metadata...", total=1, visible=False
#         )

#         # Load checkpoint
#         checkpoint = torch.load(c.path, map_location="cpu")
#         progress.update(checkpoint_task, advance=1)

#         # Load into TRAKer
#         progress.update(traker_task, visible=True)
#         traker.load_checkpoint(checkpoint, c.model_id)
#         progress.update(traker_task, advance=1)

#         # Setup input path
#         progress.update(path_task, visible=True)
#         input_path = str(Path(cfg.output_dir) / c.name / c.ood_dataset_name)
#         progress.update(path_task, advance=1)

#         # Load dataset
#         progress.update(dataset_task, visible=True)
#         dataset = ds.dataset(input_path, format="parquet")
#         progress.update(dataset_task, advance=1)

#         # Convert to table
#         progress.update(table_task, visible=True)
#         table = dataset.to_table(columns=["grads", "loss_grads"])
#         progress.update(table_task, advance=1)

#         # Stack gradients
#         progress.update(grads_task, visible=True)
#         grads = np.stack(table["grads"].to_numpy())
#         progress.update(grads_task, advance=1)

#         # Stack loss gradients
#         progress.update(loss_grads_task, visible=True)
#         loss_grads = np.stack(table["loss_grads"].to_numpy())
#         progress.update(loss_grads_task, advance=1)

#         # Store gradients
#         progress.update(store_grads_task, visible=True)
#         traker.saver.current_store["grads"][:] = grads
#         progress.update(store_grads_task, advance=1)

#         # Store loss gradients
#         progress.update(store_loss_task, visible=True)
#         traker.saver.current_store["out_to_loss"][:] = loss_grads[
#             :, np.newaxis
#         ]
#         progress.update(store_loss_task, advance=1)

#         # Set featurization flag
#         progress.update(flag_task, visible=True)
#         traker.saver.current_store["is_featurized"][:] = 1
#         progress.update(flag_task, advance=1)

#         # Serialize metadata
#         progress.update(meta_task, visible=True)
#         traker.saver.serialize_current_model_id_metadata()
#         progress.update(meta_task, advance=1)

#     console.print(f"[bold green]✓ Finished processing encoder {c.name}\n")