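git checkout -b dummy <commit-hash>   # create a throwaway branch at the commit to evaluate
git branch -D dummy                   # delete the throwaway branch when finished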

In [1]:
from __future__ import annotations
import os
import ssl
# NOTE(review): this swaps the stdlib's default HTTPS context factory for the unverified one,
# which turns off TLS certificate verification for stdlib HTTPS requests made from this kernel.
# Presumably a workaround for a download that fails certificate checks — confirm it is still
# needed, and prefer fixing the CA bundle over keeping this process-wide override.
ssl._create_default_https_context = ssl._create_unverified_context
import time
from tqdm import tqdm

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from app.vjepa.transforms import make_transforms
import root.args as parser #import _parse_args, _resolve_sampling_kwargs
from root.models.model import _build_model
import root.utils as utils
import root.dataset as dataset
from root.ddp import _wrap_ddp, _ddp_mean, _is_distributed
from src.datasets.video_dataset import make_videodataset
from src.utils.distributed import init_distributed
from src.utils.logging import AverageMeter, get_logger
import json
import argparse
from train import run_validation

logger = get_logger(__name__, force=True)

In [6]:
# import subprocess
# import os

# commit_hash = "a1b2c3d"
# worktree_path = os.path.abspath(f"../temp_{commit_hash}") # Define the temporary folder path
# subprocess.run(["git", "worktree", "add", "--detach", worktree_path, commit_hash]) # 1. Create worktree (detached state)

# # # 2. Run the eval script inside that folder
# # # 'cwd' sets the working directory for this command only
# # subprocess.run(["python", "eval.py", "--checkpoint", ckpt_path], cwd=worktree_path)

# # 3. Force remove the worktree
# subprocess.run(["git", "worktree", "remove", "--force", worktree_path])


# Run to evaluate: the checkpoint directory lives under output_dir and holds the config.json
# that is loaded below into the args namespace.
checkpoint_name = 'LoadCachedVitS_NoReset_GLU>FFN>Lin_dcfv33tc'
output_dir = "/data3/mgaur/vjepa2"

# Read the saved run configuration; the context manager closes the file handle deterministically
# instead of leaving it to the garbage collector.
config_path = os.path.join(output_dir, checkpoint_name, "config.json")
with open(config_path, "r") as config_file:
    config = json.load(config_file)

# Convert the plain dict into an attribute-style namespace (args.crop_size, args.dist_port, ...).
args = utils.dict_to_namespace(config)


In [4]:
# torchrun compatibility: RANK / WORLD_SIZE are read from the environment when present
# (None otherwise) and handed to the repo's init helper, which is always called.
env_rank = utils._env_int("RANK", None)
env_world = utils._env_int("WORLD_SIZE", None)
world_size, rank = init_distributed(
    port=args.dist_port,
    rank_and_world_size=(env_rank, env_world),
)

device = utils._get_device()
on_cuda = device.type == "cuda"
if on_cuda:
    # Turn on the cuDNN benchmark flag only when running on a CUDA device.
    torch.backends.cudnn.benchmark = True

# Rank 0 is treated as the master process.
is_master = (rank == 0)
logger.info(f"Initialized device={device}, rank/world={rank}/{world_size}")

# --- data
# With cache_dino_feats set, the train split also gets the "eval" transform.
# NOTE(review): the original remark says transforms are not called at all when cached
# features are loaded — confirm against dataset.get_loaders.
mode = "eval" if args.cache_dino_feats else "train"
train_transform = make_transforms(mode=mode, crop_size=args.crop_size)
eval_transform = make_transforms(mode="eval", crop_size=args.crop_size)
sampling_kwargs = parser._resolve_sampling_kwargs(args)

# Build datasets, loaders and samplers for both splits in a single call.
(
    train_ds, train_loader, train_sampler,
    val_ds, val_loader, val_sampler,
) = dataset.get_loaders(
    args,
    train_transform,
    eval_transform,
    sampling_kwargs,
    rank,
    world_size,
    is_master,
)

# --- model
# Build the network, pass it through the repo's DDP wrapper, then restore the best checkpoint.
# (No optimizer is created here: this notebook only evaluates.)
model = _build_model(args, device)
model = _wrap_ddp(model, device, world_size)

checkpoint_path = os.path.join(output_dir, checkpoint_name, "best.pt")
# map_location="cpu" makes deserialization independent of the GPU index the checkpoint was
# saved from; load_state_dict then copies the values into the model's existing parameters.
# NOTE(review): weights_only=False runs the full unpickler, which can execute arbitrary code —
# only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Kept as the cell's last expression so the key-matching report is displayed.
model.load_state_dict(checkpoint["model"])

[INFO    ][2026-01-27 12:55:42][root                ][init_distributed         ] SLURM vars not set (distributed training not available)
[INFO    ][2026-01-27 12:55:42][__main__            ][<module>                 ] Initialized device=cuda:0, rank/world=0/1
[INFO    ][2026-01-27 12:55:43][root                ][make_videodataset        ] VideoDataset dataset created
[INFO    ][2026-01-27 12:55:43][root                ][make_videodataset        ] VideoDataset unsupervised data loader created
[INFO    ][2026-01-27 12:55:43][root                ][make_videodataset        ] VideoDataset dataset created
[INFO    ][2026-01-27 12:55:43][root                ][make_videodataset        ] VideoDataset unsupervised data loader created


<All keys matched successfully>

In [None]:
# Evaluate the restored checkpoint on the validation loader with gradients disabled.
model.eval()
with torch.no_grad():
    loss, acc = run_validation(
        model,
        val_loader,
        device,
        world_size,
        epoch=0,
        step=0,
        is_master=is_master,
    )

# Tear down the process group when the run is distributed
# (presumably world_size > 1 — verify against root.ddp._is_distributed).
if _is_distributed(world_size):
    dist.destroy_process_group()

# Report the validation metrics returned above.
print(f"acc: {acc}")
print(f"loss: {loss}")

Validating epoch 0: 100%|██████████| 194/194 [00:02<00:00, 70.29it/s]

acc: 0.38798078894615173
loss: 2.7454991340637207



