In [1]:
%env CUDA_VISIBLE_DEVICES=2
from pathlib import Path
from pprint import pprint

import numpy as np
import pyarrow.dataset as ds
import torch
from config import Config
from model import Model
from trak import TRAKer

# Build the run configuration and echo it so the cell output records the
# exact settings this notebook was executed with.
cfg = Config()
# For a GPU-less debugging session, set cfg.device = "cpu" here before use.
pprint(cfg)

# Everything in this cell is sized from the first configured encoder.
encoder_cfg = cfg.encoders[0]

# Parquet directory layout: <output_dir>/<encoder name>/<OOD dataset name>.
input_path = str(
    Path(cfg.output_dir, encoder_cfg.name, encoder_cfg.ood_dataset_name)
)
dataset = ds.dataset(input_path, format="parquet")

# The row count of the parquet dataset is what TRAKer is told the train set
# size is.
train_set_size = dataset.count_rows()

# The project wrapper is instantiated with "cpu"; only the first element
# returned by create_model_and_transforms() is kept as the model handed to
# TRAKer, the other three return values are discarded.
model_wrapper = Model(encoder_cfg, "cpu")
model, _, _, _ = model_wrapper.create_model_and_transforms()

# TRAKer reads/writes its store under cfg.save_dir and runs on cfg.device.
traker = TRAKer(
    task="clip",
    model=model,
    device=cfg.device,
    save_dir=cfg.save_dir,
    proj_dim=cfg.proj_dim,
    train_set_size=train_set_size,
    use_half_precision=True,
)

env: CUDA_VISIBLE_DEVICES=2
Config(device='cuda',
       worker_id=0,
       worker_total=20,
       dry_run=False,
       debug=False,
       output_dir='/raid/pdpl/trak/grads/',
       save_dir='/raid/pdpl/trak/trak_results/',
       write_chunks=1000,
       seed=42,
       proj_dim=2048,
       num_contrastive_samples=50000,
       datasets={'commonpool': DatasetConfig(uri='/datasets/datacomp/shards/{00000000..00001287}.tar',
                                             uris=None,
                                             size=None,
                                             num_workers=16,
                                             num_samples=10367394),
                 'fairvision/AMD': DatasetConfig(uri='/datasets/fairvision/AMD/shards/amd-train-{000000..000005}.tar',
                                                 uris=None,
                                                 size=None,
                                                 num_workers=16,
                     

INFO:TRAK:Using ChunkedCudaProjector with2 chunks of sizes[125979393, 25297920].
INFO:STORE:Existing model IDs in /raid/pdpl/trak/trak_results: [0, 1, 2]
INFO:STORE:No model IDs in /raid/pdpl/trak/trak_results have been finalized.
INFO:STORE:No existing TRAK scores in /raid/pdpl/trak/trak_results.


In [2]:
# Finalize the features of every model ID present in the TRAKer store.
# In the recorded run this processed 3 model IDs and took ~1h25m in total.
# NOTE(review): this cell sits above the cell that writes grads/out_to_loss
# into the store; presumably the store must already be populated before this
# is called - confirm the intended execution order for a fresh run.
traker.finalize_features()

Processing blocks: 100%|██████████| 349/349 [10:55<00:00,  1.88s/it]<?, ?it/s]
Processing blocks: 100%|██████████| 349/349 [12:01<00:00,  2.07s/it]<50:10, 1505.41s/it]
Processing blocks: 100%|██████████| 349/349 [00:40<00:00,  8.67it/s]<30:03, 1803.31s/it]
Finalizing features for all model IDs..: 100%|██████████| 3/3 [1:24:57<00:00, 1699.22s/it]


In [None]:
from pathlib import Path

from rich.console import Console
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn

console = Console()

# Import precomputed features from parquet into the TRAKer store, one encoder
# at a time (first two configured encoders only).
#
# For each encoder the loop:
#   1. loads its checkpoint and registers it with TRAKer under c.model_id,
#   2. reads the "uid", "grads" and "loss_grads" columns from
#      <cfg.output_dir>/<c.name>/<c.ood_dataset_name>,
#   3. sorts the rows by "uid",
#   4. writes the stacked arrays into the current store's "grads" and
#      "out_to_loss" entries, marks every row as featurized, and
#   5. serializes the metadata for the current model ID.
#
# Each stage has its own progress row. Rows start hidden (except the first)
# and are revealed right before their stage runs, then advanced to 100% when
# the stage finishes.
for c in cfg.encoders[:2]:
    console.rule(f"[bold red]Processing encoder {c.name}")

    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        # Register one single-step task per stage, in execution order.
        checkpoint_task = progress.add_task("Loading checkpoint...", total=1)
        traker_task = progress.add_task(
            "Loading checkpoint into TRAKer...", total=1, visible=False
        )
        path_task = progress.add_task(
            "Setting up input path...", total=1, visible=False
        )
        dataset_task = progress.add_task(
            "Loading dataset...", total=1, visible=False
        )
        table_task = progress.add_task(
            "Converting to table...", total=1, visible=False
        )
        sort_task = progress.add_task(
            "Sorting by uid...", total=1, visible=False
        )
        grads_task = progress.add_task(
            "Stacking gradients...", total=1, visible=False
        )
        loss_grads_task = progress.add_task(
            "Stacking loss gradients...", total=1, visible=False
        )
        store_grads_task = progress.add_task(
            "Storing gradients...", total=1, visible=False
        )
        store_loss_task = progress.add_task(
            "Storing loss gradients...", total=1, visible=False
        )
        flag_task = progress.add_task(
            "Setting featurization flag...", total=1, visible=False
        )
        meta_task = progress.add_task(
            "Serializing metadata...", total=1, visible=False
        )

        # Load checkpoint (kept on CPU; TRAKer was given its own device).
        checkpoint = torch.load(c.path, map_location="cpu")
        progress.update(checkpoint_task, advance=1)

        # Load into TRAKer; this selects the store that the writes below
        # (traker.saver.current_store) go to.
        progress.update(traker_task, visible=True)
        traker.load_checkpoint(checkpoint, c.model_id)
        progress.update(traker_task, advance=1)

        # Setup input path
        progress.update(path_task, visible=True)
        input_path = str(Path(cfg.output_dir) / c.name / c.ood_dataset_name)
        progress.update(path_task, advance=1)

        # Load dataset
        progress.update(dataset_task, visible=True)
        dataset = ds.dataset(input_path, format="parquet")
        progress.update(dataset_task, advance=1)

        # Convert to table, reading only the columns used below.
        progress.update(table_task, visible=True)
        table = dataset.to_table(columns=["uid", "grads", "loss_grads"])
        progress.update(table_task, advance=1)

        # Sort by uid.
        # NOTE(review): presumably so row order matches the sample index the
        # store expects - verify against the code that produced the parquet.
        progress.update(sort_task, visible=True)
        table = table.sort_by("uid")
        progress.update(sort_task, advance=1)

        # Stack gradients.
        # Fix: reveal grads_task here. The original revealed sort_task a
        # second time, so the "Stacking gradients..." row never appeared.
        progress.update(grads_task, visible=True)
        grads = np.stack(table["grads"].to_numpy())
        progress.update(grads_task, advance=1)

        # Stack loss gradients
        progress.update(loss_grads_task, visible=True)
        loss_grads = np.stack(table["loss_grads"].to_numpy())
        progress.update(loss_grads_task, advance=1)

        # Store gradients (slice assignment writes into the existing buffer).
        progress.update(store_grads_task, visible=True)
        traker.saver.current_store["grads"][:] = grads
        progress.update(store_grads_task, advance=1)

        # Store loss gradients with an extra axis inserted at position 1.
        # Assumes loss_grads is 1-D so this yields one column per row -
        # TODO confirm against the "out_to_loss" store shape.
        progress.update(store_loss_task, visible=True)
        traker.saver.current_store["out_to_loss"][:] = loss_grads[
            :, np.newaxis
        ]
        progress.update(store_loss_task, advance=1)

        # Set featurization flag for every row of the current store.
        progress.update(flag_task, visible=True)
        traker.saver.current_store["is_featurized"][:] = 1
        progress.update(flag_task, advance=1)

        # Serialize metadata
        progress.update(meta_task, visible=True)
        traker.saver.serialize_current_model_id_metadata()
        progress.update(meta_task, advance=1)

    console.print(f"[bold green]✓ Finished processing encoder {c.name}\n")