In [2]:
import sys
sys.path.append('../')

%load_ext autoreload
%autoreload 2

import os
import custom_datasets
from custom_datasets.rollout_push_any import RolloutPushAnyDataset
from custom_datasets.concat_datasets import ConcatDataset
from custom_datasets.pusht import PushTDataset
from torch.utils.data import DataLoader
import re
from pathlib import Path
import zarr
import numpy as np
import torch
import torch.nn.functional as F
import imageio
from IPython.display import Video

from typing import Dict
import time

import hydra
from hydra import initialize, compose
from omegaconf import OmegaConf

In [3]:
# Compose the training config used throughout this notebook.
# version_base="1.1" pins the defaults Hydra was already falling back to when
# the parameter was omitted, and silences the "version_base parameter is not
# specified" warning.
# Clearing GlobalHydra first makes the cell safe to re-run: initialize() raises
# if Hydra has already been initialized in this kernel.
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
initialize(config_path="../configs", job_name="notebook_job", version_base="1.1")
cfg = compose(config_name="train_pusht_rollout.yaml", overrides=['subset_fraction=0.01'])

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path="../configs", job_name="notebook_job")


## Create Dataset

In [6]:
# Build the dataset described by cfg.env.dataset; the annotation records that
# the instantiated object is expected to be a ConcatDataset.
dataset: ConcatDataset = hydra.utils.instantiate(cfg.env.dataset)

In [29]:
def split_and_slice_dataset(dataset, config=None):
    """Split ``dataset`` into train/validation parts and slice them into windows.

    Thin wrapper around ``custom_datasets.core.get_train_val_sliced`` that reads
    every slicing hyper-parameter from a Hydra config.

    Args:
        dataset: The dataset to split (in this notebook, the object built from
            ``cfg.env.dataset``).
        config: Config providing ``train_fraction``, ``seed``, ``window_size``,
            ``goal_conditional``, ``min_future_sep``, ``goal_seq_len`` and
            ``num_extra_predicted_actions``. Defaults to the notebook-global
            ``cfg`` so existing calls keep working; passing it explicitly avoids
            depending on hidden kernel state.

    Returns:
        Whatever ``get_train_val_sliced`` returns (this notebook unpacks it as a
        pair of datasets).
    """
    if config is None:
        # Backward-compatible fallback to the global composed in the config cell.
        config = cfg
    kwargs = {
        "train_fraction": config.train_fraction,
        "random_seed": config.seed,
        "window_size": config.window_size,
        # Only the literal string "future" enables future-conditioned slicing.
        "future_conditional": (config.goal_conditional == "future"),
        "min_future_sep": config.min_future_sep,
        "future_seq_len": config.goal_seq_len,
        "num_extra_predicted_actions": config.num_extra_predicted_actions,
    }
    return custom_datasets.core.get_train_val_sliced(dataset, **kwargs)

In [30]:
# Split into train / held-out windowed datasets.
# NOTE(review): the underlying helper is `get_train_val_sliced`, so the second
# element is presumably a validation split even though it is bound to
# `test_set` here — confirm before using it as a true test set.
train_set, test_set = split_and_slice_dataset(dataset)

In [31]:
# Unshuffled loader of small batches for inspecting samples; pinned host memory
# allows faster copies to the GPU in later cells.
train_loader = DataLoader(
    train_set,
    batch_size=8,
    shuffle=False,
    pin_memory=True,
)
# Explicit iterator so later cells can pull batches one at a time with next().
iterator = iter(train_loader)

In [32]:
# Pull one batch; only the observations are kept, the other two elements of the
# batch tuple are discarded.
obs, _, _ = next(iterator)

In [33]:
# Display the observation batch shape (a torch.Size); `.shape` is an alias of
# `.size()`. The recorded output is 6-D — axis meaning assumed to be
# (batch, window, views, C, H, W); TODO confirm against the dataset.
obs.shape

torch.Size([8, 5, 1, 3, 224, 224])

## Create Models

In [38]:
# Encoder on the GPU plus a dedicated AdamW optimizer using the SSL
# learning rate, weight decay and betas from the composed config.
encoder = hydra.utils.instantiate(cfg.encoder).to('cuda')
encoder_optim = torch.optim.AdamW(
    encoder.parameters(),
    lr=cfg.ssl_lr,
    betas=tuple(cfg.betas),
    weight_decay=cfg.ssl_weight_decay,
)

In [39]:
# DINO head (from cfg.ssl.dino_head) on the GPU with its own AdamW optimizer,
# configured with the same SSL hyper-parameters as the encoder.
dino_head = hydra.utils.instantiate(cfg.ssl.dino_head).to('cuda')
dino_head_optim = torch.optim.AdamW(
    dino_head.parameters(),
    lr=cfg.ssl_lr,
    betas=tuple(cfg.betas),
    weight_decay=cfg.ssl_weight_decay,
)

In [40]:
# Projector on the GPU. _recursive_=False makes Hydra pass nested config nodes
# through un-instantiated; the projector then builds its own optimizer.
projector = hydra.utils.instantiate(cfg.projector, _recursive_=False).to('cuda')
projector_optim = projector.configure_optimizers(
    betas=tuple(cfg.betas),
    lr=cfg.ssl_lr,
    weight_decay=cfg.ssl_weight_decay,
)

In [41]:
# Assemble the SSL model described by cfg.ssl. The keyword arguments override
# the matching entries of cfg.ssl, so `ssl` receives these exact
# encoder / dino_head / projector instances rather than freshly built ones.
ssl = hydra.utils.instantiate(
    cfg.ssl,
    encoder=encoder,
    dino_head=dino_head,
    projector=projector,
).to('cuda')

## Forward pass

In [34]:
# Fetch a batch for the forward-pass check below.
# NOTE(review): this cell ran as In [34], before the models were built
# (In [38]-[41]); under Restart & Run All it takes the batch after the one
# drawn in the Create Dataset section.
obs, _, _ = next(iterator)

In [63]:
# One SSL forward pass on the GPU. Returns encodings, projections, the loss
# that the next cell backpropagates, and (presumably) a per-term breakdown of
# that loss.
# NOTE(review): if `ssl` is a torch.nn.Module, calling `ssl(...)` instead of
# `ssl.forward(...)` would also run module hooks — confirm its type first.
obs_enc, obs_proj, ssl_loss, ssl_loss_components = ssl.forward(obs.to('cuda'))

In [64]:
# Backpropagate the SSL loss so the trainable parameters receive gradients
# before ssl.step() is called.
ssl_loss.backward()

In [65]:
# Sanity check that ssl.step() updates the forward-dynamics model. Snapshot its
# first parameter (detached, so no autograd graph is kept), take a step, then
# compare. The cell displays False when the parameter changed, i.e. the update
# happened.
org_p = next(ssl.forward_dynamics.parameters()).detach().clone()
ssl.step()
torch.equal(org_p, next(ssl.forward_dynamics.parameters()))

False

In [60]:
# Inspect the gradient of the projector's first parameter. Nothing is displayed
# when .grad is None, e.g. when no backward pass has reached it yet.
next(projector.parameters()).grad

In [62]:
# Take an optimizer step through the SSL wrapper.
# NOTE(review): under Restart & Run All this is a second ssl.step() after the
# one in the parameter-change check above — confirm that is intended.
ssl.step()

## Load Checkpoint

In [1]:
# Run directory of the training job whose snapshots are compared below.
# Set DYNAMO_CHECKPOINT_ROOT to point at another run (or another machine)
# instead of editing this machine-specific absolute path; the default is the
# original path, so behaviour is unchanged when the variable is unset.
import os

checkpoint_root = os.environ.get(
    'DYNAMO_CHECKPOINT_ROOT',
    '/home/sm/PycharmProjects/dynamo_ssl/exp_local/2024.12.15/185030_train_pusht_rollout_dynamo',
)

In [4]:
# Build a fresh (CPU) encoder from the config.
# NOTE(review): this rebinds the global `encoder`, replacing the CUDA encoder
# from the Create Models section; `encoder_optim` and `ssl` keep referencing the
# old instance. The new object is not used by any later cell shown here —
# consider a distinct name, or remove it, so Restart & Run All is not surprising.
encoder = hydra.utils.instantiate(cfg.encoder)



In [24]:
# Load two training snapshots and pick out the same sub-module from each, to
# compare which of its parameters changed between them.
# NOTE: `ckpt_2` holds snapshot 3 (not 2).
# weights_only=False is required because the entries are whole module objects
# (the next cell calls .named_parameters() on them), not plain tensors; recent
# PyTorch versions default to weights_only=True and would refuse to load them.
# This unpickles arbitrary objects, so only load checkpoints you trust.
import os

ckpt_1 = torch.load(os.path.join(checkpoint_root, 'snapshot_1.pt'), weights_only=False)
ckpt_2 = torch.load(os.path.join(checkpoint_root, 'snapshot_3.pt'), weights_only=False)

model_key = 'dino_head'

model1 = ckpt_1[model_key]
model2 = ckpt_2[model_key]

In [26]:
# Map parameter name -> tensor for each snapshot's module. Only parameters are
# included (named_parameters() skips buffers), despite the `_sd` suffix.
model1_sd = dict(model1.named_parameters())
model2_sd = dict(model2.named_parameters())

In [27]:
# Print the name of every parameter whose tensor is identical in both
# snapshots, i.e. parameters that did not change between them. A parameter
# missing from the second snapshot raises KeyError.
for k, v in model1_sd.items():
    if not torch.equal(v, model2_sd[k]):
        continue
    print(k)

last_layer.weight_g
