In [1]:
"""
While the activations were being computed, they were saved in chunks corresponding to batches.
This script combines these batch-wise saved activations into a single tensor per block for easier analysis.
WARNING: This script is only for activation preparation and is meant to be run only once.
"""



In [2]:
# Bootstrap cell: make the project checkout importable, then pull in dependencies.
import os
import sys
# "../.." is resolved against the current working directory by os.path.abspath,
# so this cell depends on where the notebook kernel was started.
PROJECT_ROOT = os.path.abspath("../..")
# Prepend (index 0) so the project root takes precedence over other sys.path entries.
# Presumably this is what lets the `models` / `memorization` imports below resolve,
# so it has to run before them — TODO confirm those packages live under PROJECT_ROOT.
sys.path.insert(0, PROJECT_ROOT)

# Echo the resolved root so a wrong working directory is noticed immediately.
print("Project root:", PROJECT_ROOT)

import torch
# NOTE(review): the names imported below are not used in the visible cells of this
# notebook; they may be needed elsewhere, so they are left in place.
from models import VQVAE, build_vae_var
from memorization.data_prep.subset_imagenet import get_balanced_imagenet_dataset
from pathlib import Path
import shutil

Project root: /BS/scene_repre/work/VAR


In [4]:
import os
import torch
from glob import glob

def combine_ranks_for_block(
    main_dir,
    layer,
    block_name,
    expected_num_images=12800,
):
    """
    Combine activation files from all ranks for one block and one layer.

    For every entry ``<rank_dir>`` of ``main_dir`` the directory
    ``<main_dir>/<rank_dir>/<layer>/<block_name>`` is scanned for ``batch*.pt``
    files. Entries without that sub-directory are skipped. Each file is
    expected to hold a dict with the keys ``"indices"`` and ``"activations"``
    whose first dimensions agree (per the original notes: ``(B,)`` and
    ``(B, S, C)``).

    Args:
        main_dir: Directory containing one sub-directory per rank.
        layer: Name of the layer sub-directory (e.g. ``"fc1_act"``).
        block_name: Name of the block sub-directory (e.g. ``"block_0"``).
        expected_num_images: Total number of entries that must be found.
            Pass ``None`` to skip the count check.

    Returns:
        Tuple ``(sorted_activations, sorted_indices)``: the concatenated
        activations reordered so that row ``i`` belongs to
        ``sorted_indices[i]``, and the ascending dataset indices.

    Raises:
        FileNotFoundError: If no batch file exists for this layer/block.
        AssertionError: If a batch file has mismatching index/activation
            counts, if an index occurs more than once, or if the total count
            differs from ``expected_num_images``.
    """

    all_indices = []
    all_activations = []

    # Iterate over all rank directories (rank0, rank1, ...). The listing order
    # is irrelevant for the result because everything is sorted by index below.
    for rank_dir in sorted(os.listdir(main_dir)):
        rank_path = os.path.join(main_dir, rank_dir, layer, block_name)

        if not os.path.isdir(rank_path):
            continue

        for bf in sorted(glob(os.path.join(rank_path, "batch*.pt"))):
            # NOTE(review): torch.load unpickles its input; only point this at
            # files produced by our own activation dump.
            data = torch.load(bf, map_location="cpu")

            batch_indices = data["indices"]
            batch_activations = data["activations"]

            # Validate per file: a check on the concatenated totals alone can
            # be satisfied by two files whose errors cancel out, which would
            # silently misalign indices and activations.
            if batch_indices.shape[0] != batch_activations.shape[0]:
                raise AssertionError(
                    "Mismatch between indices and activations! "
                    f"({bf}: {batch_indices.shape[0]} indices vs "
                    f"{batch_activations.shape[0]} activations)"
                )

            all_indices.append(batch_indices)
            all_activations.append(batch_activations)

    # torch.cat on an empty list fails with an unhelpful message; report the
    # actual problem (nothing was found for this layer/block) instead.
    if not all_indices:
        raise FileNotFoundError(
            f"No batch*.pt files found for {layer}/{block_name} under {main_dir}"
        )

    # Concatenate everything
    all_indices = torch.cat(all_indices, dim=0)          # (N,)
    all_activations = torch.cat(all_activations, dim=0)  # (N, ...)

    # Sort by dataset index
    sorted_indices, sort_order = torch.sort(all_indices)
    sorted_activations = all_activations[sort_order]

    # Sanity checks. Raised explicitly (instead of `assert`) so they still run
    # when Python is started with -O; the exception type is unchanged.
    if sorted_indices.unique().numel() != sorted_indices.numel():
        values, counts = torch.unique(sorted_indices, return_counts=True)
        duplicated = values[counts > 1][:10].tolist()
        raise AssertionError(
            f"Duplicate indices detected! (first duplicates: {duplicated})"
        )
    if expected_num_images is not None and sorted_indices.numel() != expected_num_images:
        raise AssertionError(
            f"Expected {expected_num_images} images, got {sorted_indices.numel()}"
        )

    return sorted_activations, sorted_indices


In [5]:
# --- Configuration -----------------------------------------------------------
# Root of the per-rank activation dumps and the layer whose blocks are merged.
main_dir = "/scratch/inf0/user/hpetekka/var_mem/output_activations_corrected_test"
layer = "fc1_act"
NUM_BLOCKS = 16               # blocks are named block_0 ... block_15
EXPECTED_NUM_IMAGES = 12800   # total number of images across all ranks

# Merged tensors are written to <main_dir>/combined/<layer>/<block_name>.pt
final_dir = os.path.join(main_dir, "combined", layer)
os.makedirs(final_dir, exist_ok=True)

# --- Merge and save every block ----------------------------------------------
for block_id in range(NUM_BLOCKS):
    block_name = "block_" + str(block_id)

    # Gather this block's batches from all ranks, ordered by dataset index.
    activations, indices = combine_ranks_for_block(
        main_dir, layer, block_name, EXPECTED_NUM_IMAGES
    )

    # For this run, activations has shape (12800, 10, 4096).
    # The indices are stored next to the activations so that rows can be
    # traced back to dataset entries later.
    combined_payload = {"activations": activations, "indices": indices}
    combined_path = os.path.join(final_dir, f"{block_name}.pt")
    torch.save(combined_payload, combined_path)

    print(f"Saved combined {layer}/{block_name}: {activations.shape}")


Saved combined fc1_act/block_0: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_1: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_2: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_3: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_4: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_5: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_6: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_7: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_8: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_9: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_10: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_11: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_12: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_13: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_14: torch.Size([12800, 10, 4096])
Saved combined fc1_act/block_15: torch.Size([12800, 10, 4096])
