In [1]:
import sys
from omegaconf import OmegaConf
import torch
from tqdm import tqdm
from util import *

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
class RunningStatsTensors(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = 0  # Count of data points seen so far
        # Initialize mean and variance tensors as parameters to enable device transfer
        self.mean = torch.nn.Parameter(torch.zeros(1))
        self.S = torch.nn.Parameter(torch.zeros(1))

    def update(self, x):
        batch_size = x.numel()  # Get the number of elements in the new tensor

        if self.n == 0:
            # First batch initialization
            self.mean.data = x.mean()  # Initialize the mean with the batch mean
            # Initialize variance with the sum of squared differences from the batch mean
            self.S.data = ((x - self.mean) ** 2).sum()
        else:
            # Calculate the new total number of elements
            new_n = self.n + batch_size
            delta = x.mean() - self.mean
            new_mean = self.mean + delta * batch_size / new_n

            # Update the variance (S) using Welford's method
            self.S.data += ((x - self.mean)**2).sum()

            # Update the running mean
            self.mean.data = new_mean

        # Update the count of elements
        self.n += batch_size

    def get_mean(self):
        return self.mean
    
    def get_std(self):
        # Unbiased estimate: divide by n - 1
        if self.n > 1:
            return torch.sqrt(self.S / (self.n - 1))
        else:
            return torch.tensor(0.0, device=self.mean.device)

In [3]:
running_stats = RunningStatsTensors()
running_stats2 = RunningStatsTensors()

In [4]:
ckpt = "logs_tk/2025-05-20T16-27-30_vq_IF_dino2_e16_DLC11518833_dec/checkpoints/last.ckpt"
cfg = "logs_tk/2025-05-20T16-27-30_vq_IF_dino2_e16_DLC11518833_dec/config.yaml"

In [5]:
config = OmegaConf.load(cfg)
model = instantiate_from_config(config.model)

model.load_state_dict(torch.load(ckpt)["state_dict"], strict=True)
model = model.cuda()
_ = model.eval()


data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()
train_loader = data.train_dataloader()

VQLPIPSWithDiscriminator initialized with hinge loss.


[2025-06-28 13:10:02] [INFO] Loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
[2025-06-28 13:10:03] [INFO] Loading pretrained weights from Hugging Face hub (timm/vit_base_patch14_dinov2.lvd142m)
[2025-06-28 13:10:03] [INFO] [timm/vit_base_patch14_dinov2.lvd142m] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
[2025-06-28 13:10:03] [INFO] Resized position embedding: (37, 37) to (16, 16).
[2025-06-28 13:10:05] [INFO] Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.mae)
[2025-06-28 13:10:05] [INFO] [timm/vit_base_patch16_224.mae] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
[2025-06-28 13:10:05] [INFO] Resized position embedding: (14, 14) to (16, 16).
  model.load_state_dict(torch.load(ckpt)["state_dict"], strict=True)
[2025-06-28 13:10:09] [INFO] Loading dataset: train
[2025-06-28 13:10:39] [INFO] <data.cu

Total dataset length: 492929
Total dataset length: 5258


In [None]:
print("Calculating dataset statistics...")


for idx, batch in tqdm(enumerate(train_loader), desc="Processing Batches"):
    # Compute the statistics for the current batch
    with torch.no_grad():
        # Forward pass through the model
        

        x = batch # ['image']
        # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        x = x.cuda()
        x= x.float()



        h = model.encode(x)['continuous']
        if isinstance(h, tuple):
            h1 = h[0]
            h2 = h[1]
            

            h1 = h1.view(-1)
            h2 = h2.view(-1)
            # Update the running statistics
            running_stats.update(h1)
            running_stats2.update(h2)

            if idx %500==0 and idx > 0:
                mean = running_stats.get_mean()
                std = running_stats.get_std()
                print(f"1: Latent mean: {mean.mean()}, Latent std: {std.mean()}")
                mean2 = running_stats2.get_mean()
                std2 = running_stats2.get_std()
                print(f"2: Latent mean: {mean2.mean()}, Latent std: {std2.mean()}")
                print(f'enc_scale: {1/std.mean()},  enc_scale_dino: {1/std2.mean()}')
        else:
            h = h.view(-1)
            # Update the running statistics
            running_stats.update(h)

            if idx %500==0 and idx > 0:
                mean = running_stats.get_mean()
                std = running_stats.get_std()
                print(f"Latent mean: {mean.mean()}, Latent std: {std.mean()}")
                print(f'enc_scale: {1/std.mean()}')


Calculating dataset statistics...


Processing Batches: 500it [02:44,  4.87it/s]

1: Latent mean: 0.03317122906446457, Latent std: 0.35201364755630493
2: Latent mean: 0.013119198381900787, Latent std: 0.3533176779747009
enc_scale: 2.840798854827881,  enc_scale_dino:2.8303141593933105


Processing Batches: 1000it [05:03,  6.01it/s]

1: Latent mean: 0.03351333364844322, Latent std: 0.35197195410728455
2: Latent mean: 0.013172510080039501, Latent std: 0.3533124327659607
enc_scale: 2.8411355018615723,  enc_scale_dino:2.8303561210632324


Processing Batches: 1499it [07:11,  4.43it/s]

1: Latent mean: 0.03354831784963608, Latent std: 0.3519652485847473
2: Latent mean: 0.013123990967869759, Latent std: 0.3533129394054413
enc_scale: 2.8411896228790283,  enc_scale_dino:2.8303520679473877


Processing Batches: 1542it [07:19,  8.09it/s]