In [1]:
from hydra import initialize, compose
import dotenv
import os
import pathlib

from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.models import ModelFactory


In [2]:
def get_mod(
    run_id: str,
    device,
    submodule: str = "encoder.layers.encoder_layer_11.mlp.0",
    checkpoint_date: str = "20230601",
):
    """Load the best checkpoint of a run and return one submodule of the model.

    The Hydra config is composed with the ViT / ImageNet overrides, the model
    is built through ``ModelFactory`` and its weights are restored from
    ``../artifacts/checkpoints/<checkpoint_date>_<run_id>``.

    Args:
        run_id: Experiment run id; used both as a config override and as the
            suffix of the checkpoint directory name.
        device: Device the model is moved to before the weights are loaded
            (anything accepted by ``model.to``).
        submodule: Dotted path handed to ``model.get_submodule``. Defaults to
            the layer this notebook inspects.
        checkpoint_date: Prefix of the checkpoint directory name.

    Returns:
        The submodule of the restored model found at ``submodule``.

    Raises:
        KeyError: If ``IMAGE_NET_PATH`` is not defined after ``../.env`` has
            been loaded.
    """
    with initialize("../configs", version_base="1.2.0"):
        cfg = compose(
            "config.yaml",
            overrides=[
                "compute.distributed=False",
                "dataset=imagenet",
                "model=vit",
                f"experiment.run_id={run_id}",
                "training.batch_size=2",
            ],
        )

    dotenv.load_dotenv("../.env", override=True)
    # Fail fast, and visibly, when the dataset location is not configured
    # (previously a bare ``os.environ[...]`` lookup did this implicitly).
    if "IMAGE_NET_PATH" not in os.environ:
        raise KeyError(
            "IMAGE_NET_PATH is not set; define it in ../.env or the environment"
        )

    checkpoint_dir = (
        pathlib.Path("../artifacts/checkpoints") / f"{checkpoint_date}_{run_id}"
    )
    checkpoint = Checkpoint.load_best_checkpoint(checkpoint_dir=checkpoint_dir)

    model = ModelFactory.load_model(
        model=cfg.model.name, dataset=cfg.dataset.name, diet=cfg.rigl.diet
    )
    model.to(device)

    try:
        model.load_state_dict(checkpoint.model)
    except RuntimeError:
        # The stored state did not load as-is; retry with the state converted
        # by the checkpoint helper (presumably the checkpoint was written by a
        # distributed run -- confirm against Checkpoint).
        model.load_state_dict(
            checkpoint.get_single_process_model_state_from_distributed_state()
        )
    return model.get_submodule(submodule)


# Lookup table of run ids. NOTE(review): the integer key presumably denotes the
# sparsity level of the run -- confirm.
__RUN_IDS = {90: "nrblbn15"}

# Extract the layer under study from the selected run, kept on the CPU.
t_fc = get_mod(run_id=__RUN_IDS[90], device="cpu")



In [11]:
# Display the input width of the extracted layer (shown as the cell output).
t_fc.in_features

768

In [15]:
import torch

import jax
from typing import Any, Callable, Sequence, Optional, Tuple, Union
from jax import random, vmap, numpy as jnp
import flax
from flax import linen as nn
import numpy as np
from functools import partial

# Parity check: the extracted PyTorch layer and a Flax ``nn.Dense`` fed with
# the same parameters must produce the same output for the same input.
kernel = t_fc.weight.detach().cpu().numpy()
print(kernel.shape)
bias = t_fc.bias.detach().cpu().numpy()

# Derive the layer dimensions from the weights instead of hard-coding them.
out_features, in_features = kernel.shape

# [outC, inC] -> [inC, outC]
kernel = jnp.transpose(kernel, (1, 0))

key = random.key(0)
x = random.normal(key, (64, in_features))

variables = {'params': {'kernel': kernel, 'bias': bias}}
j_fc = nn.Dense(features=out_features)
j_out = j_fc.apply(variables, x)

t_x = torch.from_numpy(np.array(x))
# Only the PyTorch forward pass needs autograd switched off.
with torch.no_grad():
    t_out = t_fc(t_x)
t_out = t_out.detach().cpu().numpy()

# Same absolute tolerance as ``assert_almost_equal(..., decimal=3)``.
np.testing.assert_allclose(j_out, t_out, rtol=0, atol=1.5e-3)
    

(3072, 768)
