In [1]:
import torch
import random
import yaml
import pprint
import copy

from dataclasses import dataclass, asdict

from src.models import modules
from src.models.modules import text_encoder_model, x_t2i_module, vit_predictor, SelfThenCrossBlock

# Device on which every module in this notebook is placed.
DEVICE_0 = "cpu"

# Read the experiment configuration, then echo it once the file is closed.
# NOTE(review): yaml.safe_load would be the more conservative loader if this
# config only holds plain scalars / lists / dicts — confirm before switching.
with open("configs/in1k_vith14_ep300.yaml", "r") as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)

pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)


{   'data': {   'batch_size': 16,
                'color_jitter_strength': 0.0,
                'crop_scale': [0.3, 1.0],
                'crop_size': 224,
                'image_folder': '.',
                'num_workers': 10,
                'pin_mem': True,
                'root_path': 'src/datasets/',
                'use_color_distortion': False,
                'use_gaussian_blur': False,
                'use_horizontal_flip': False},
    'logging': {'folder': 'src/loggings', 'write_tag': 'jepa'},
    'mask': {   'allow_overlap': False,
                'aspect_ratio': [0.75, 1.5],
                'enc_mask_scale': [0.85, 1.0],
                'min_keep': 10,
                'num_enc_masks': 1,
                'num_pred_masks': 4,
                'patch_size': 14,
                'pred_mask_scale': [0.15, 0.2]},
    'meta': {   'copy_data': False,
                'load_checkpoint': False,
                'model_name': 'vit_huge',
                'pred_depth': 12,
                '

In [2]:

@dataclass
class ModelConfig:
    """Hyperparameters used to build the modules in this cell.

    Defaults that read from ``params`` are evaluated once, when this class
    statement executes, so ``params`` must already be loaded from the YAML
    config. Later changes to ``params`` do not affect these defaults.
    """
    # Not passed to the x_t2i_module call in this cell.
    # presumably image side length and patch side length — TODO confirm
    SIZE: int = 224
    PATCH_SIZE: int = params['mask']['patch_size']

    # V/T/H are passed to x_t2i_module as vision_embed_dim / text_embed_dim /
    # hidden_dim respectively; PRED_EMBED_DIM is not used in this cell.
    V_EMBED_DIM: int = 1280
    T_EMBED_DIM: int = 768
    H_EMBED_DIM: int = 768
    PRED_EMBED_DIM: int = params['meta']['pred_emb_dim']

    # Passed to x_t2i_module as drop_rate / attn_drop_rate / mlp_ratio.
    DROP_RATE: float = 0.15
    ATTN_DROP_RATE: float = 0.15
    MLP_RATIO: float = 4.0

    # CROSS_ATTN_DEPTH is passed as depth; PRED_ATTN_DEPTH is not used in this cell.
    PRED_ATTN_DEPTH: int = params['meta']['pred_depth']
    CROSS_ATTN_DEPTH: int = 4

    # CROSS_NUM_HEADS is passed as num_heads; PRED_NUM_HEADS is not used in this cell.
    PRED_NUM_HEADS: int = 12
    CROSS_NUM_HEADS: int = 8

# Materialise the configuration with every field left at its default.
MODEL_CONFIG = ModelConfig()

print("Init models...")

# Target T2I module: construct first, then move to the configured device.
crosser = x_t2i_module(
    text_embed_dim=MODEL_CONFIG.T_EMBED_DIM,
    vision_embed_dim=MODEL_CONFIG.V_EMBED_DIM,
    hidden_dim=MODEL_CONFIG.H_EMBED_DIM,
    depth=MODEL_CONFIG.CROSS_ATTN_DEPTH,
    num_heads=MODEL_CONFIG.CROSS_NUM_HEADS,
    mlp_ratio=MODEL_CONFIG.MLP_RATIO,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=MODEL_CONFIG.DROP_RATE,
    attn_drop_rate=MODEL_CONFIG.ATTN_DROP_RATE,
    cross_block=SelfThenCrossBlock,
)
crosser = crosser.to(DEVICE_0)

# Total element count over everything returned by crosser.parameters().
crosser_total_params = sum(weight.numel() for weight in crosser.parameters())
print(f"crosser_total_params={crosser_total_params!r}")

print("\n\nDone init models\n\n")


Init models...
crosser_total_params=38792448


Done init models




In [3]:

@dataclass
class ModelConfig:
    """Hyperparameters used to build the modules in this cell.

    Executing this class statement rebinds the name ``ModelConfig``, replacing
    any definition of the same name already present in the kernel. Defaults
    that read from ``params`` are evaluated once, when the class statement
    executes, so ``params`` must already be loaded.
    """
    # Not passed to the x_t2i_module call in this cell.
    # presumably image side length and patch side length — TODO confirm
    SIZE: int = 224
    PATCH_SIZE: int = params['mask']['patch_size']

    # V/T/H are passed to x_t2i_module as vision_embed_dim / text_embed_dim /
    # hidden_dim respectively; PRED_EMBED_DIM is not used in this cell.
    V_EMBED_DIM: int = 1280
    T_EMBED_DIM: int = 768
    H_EMBED_DIM: int = 768
    PRED_EMBED_DIM: int = params['meta']['pred_emb_dim']

    # Passed to x_t2i_module as drop_rate / attn_drop_rate / mlp_ratio.
    DROP_RATE: float = 0.15
    ATTN_DROP_RATE: float = 0.15
    MLP_RATIO: float = 4.0

    # CROSS_ATTN_DEPTH is passed as depth; PRED_ATTN_DEPTH is not used in this cell.
    PRED_ATTN_DEPTH: int = params['meta']['pred_depth']
    CROSS_ATTN_DEPTH: int = 6

    # CROSS_NUM_HEADS is passed as num_heads; PRED_NUM_HEADS is not used in this cell.
    # NOTE(review): H_EMBED_DIM (768) is not a multiple of 10. If x_t2i_module
    # splits hidden_dim evenly across heads, the forward pass may fail or
    # silently truncate — confirm against the module implementation.
    PRED_NUM_HEADS: int = 12
    CROSS_NUM_HEADS: int = 10

# Materialise the configuration with every field left at its default.
MODEL_CONFIG = ModelConfig()

print("Init models...")

# Target T2I module: construct first, then move to the configured device.
crosser = x_t2i_module(
    text_embed_dim=MODEL_CONFIG.T_EMBED_DIM,
    vision_embed_dim=MODEL_CONFIG.V_EMBED_DIM,
    hidden_dim=MODEL_CONFIG.H_EMBED_DIM,
    depth=MODEL_CONFIG.CROSS_ATTN_DEPTH,
    num_heads=MODEL_CONFIG.CROSS_NUM_HEADS,
    mlp_ratio=MODEL_CONFIG.MLP_RATIO,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=MODEL_CONFIG.DROP_RATE,
    attn_drop_rate=MODEL_CONFIG.ATTN_DROP_RATE,
    cross_block=SelfThenCrossBlock,
)
crosser = crosser.to(DEVICE_0)

# Total element count over everything returned by crosser.parameters().
crosser_total_params = sum(weight.numel() for weight in crosser.parameters())
print(f"crosser_total_params={crosser_total_params!r}")

print("\n\nDone init models\n\n")


Init models...
crosser_total_params=57696000


Done init models




In [4]:

@dataclass
class ModelConfig:
    """Hyperparameters used to build the modules in this cell.

    Executing this class statement rebinds the name ``ModelConfig``, replacing
    any definition of the same name already present in the kernel. Defaults
    that read from ``params`` are evaluated once, when the class statement
    executes, so ``params`` must already be loaded.
    """
    # Not passed to the x_t2i_module call in this cell.
    # presumably image side length and patch side length — TODO confirm
    SIZE: int = 224
    PATCH_SIZE: int = params['mask']['patch_size']

    # V/T/H are passed to x_t2i_module as vision_embed_dim / text_embed_dim /
    # hidden_dim respectively; PRED_EMBED_DIM is not used in this cell.
    V_EMBED_DIM: int = 1280
    T_EMBED_DIM: int = 768
    H_EMBED_DIM: int = 1024
    PRED_EMBED_DIM: int = params['meta']['pred_emb_dim']

    # Passed to x_t2i_module as drop_rate / attn_drop_rate / mlp_ratio.
    DROP_RATE: float = 0.15
    ATTN_DROP_RATE: float = 0.15
    MLP_RATIO: float = 4.0

    # CROSS_ATTN_DEPTH is passed as depth; PRED_ATTN_DEPTH is not used in this cell.
    PRED_ATTN_DEPTH: int = params['meta']['pred_depth']
    CROSS_ATTN_DEPTH: int = 8

    # CROSS_NUM_HEADS is passed as num_heads; PRED_NUM_HEADS is not used in this cell.
    # NOTE(review): H_EMBED_DIM (1024) is not a multiple of 12. If x_t2i_module
    # splits hidden_dim evenly across heads, the forward pass may fail or
    # silently truncate — confirm against the module implementation.
    PRED_NUM_HEADS: int = 12
    CROSS_NUM_HEADS: int = 12

# Materialise the configuration with every field left at its default.
MODEL_CONFIG = ModelConfig()

print("Init models...")

# Target T2I module: construct first, then move to the configured device.
crosser = x_t2i_module(
    text_embed_dim=MODEL_CONFIG.T_EMBED_DIM,
    vision_embed_dim=MODEL_CONFIG.V_EMBED_DIM,
    hidden_dim=MODEL_CONFIG.H_EMBED_DIM,
    depth=MODEL_CONFIG.CROSS_ATTN_DEPTH,
    num_heads=MODEL_CONFIG.CROSS_NUM_HEADS,
    mlp_ratio=MODEL_CONFIG.MLP_RATIO,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=MODEL_CONFIG.DROP_RATE,
    attn_drop_rate=MODEL_CONFIG.ATTN_DROP_RATE,
    cross_block=SelfThenCrossBlock,
)
crosser = crosser.to(DEVICE_0)

# Total element count over everything returned by crosser.parameters().
crosser_total_params = sum(weight.numel() for weight in crosser.parameters())
print(f"crosser_total_params={crosser_total_params!r}")

print("\n\nDone init models\n\n")


Init models...
crosser_total_params=131492864


Done init models




In [5]:

@dataclass
class ModelConfig:
    """Hyperparameters used to build the modules in this cell.

    Executing this class statement rebinds the name ``ModelConfig``, replacing
    any definition of the same name already present in the kernel. Defaults
    that read from ``params`` are evaluated once, when the class statement
    executes, so ``params`` must already be loaded.
    """
    # Hard-coded here rather than read from params; neither value is passed to
    # the x_t2i_module call in this cell.
    # presumably image side length and patch side length — TODO confirm
    SIZE: int = 448
    PATCH_SIZE: int = 16

    # V/T/H are passed to x_t2i_module as vision_embed_dim / text_embed_dim /
    # hidden_dim respectively; PRED_EMBED_DIM is not used in this cell.
    V_EMBED_DIM: int = 1280
    T_EMBED_DIM: int = 768
    H_EMBED_DIM: int = 768
    PRED_EMBED_DIM: int = params['meta']['pred_emb_dim']

    # Passed to x_t2i_module as drop_rate / attn_drop_rate / mlp_ratio.
    DROP_RATE: float = 0.15
    ATTN_DROP_RATE: float = 0.15
    MLP_RATIO: float = 4.0

    # CROSS_ATTN_DEPTH is passed as depth; PRED_ATTN_DEPTH is not used in this cell.
    PRED_ATTN_DEPTH: int = params['meta']['pred_depth']
    CROSS_ATTN_DEPTH: int = 4

    # CROSS_NUM_HEADS is passed as num_heads; PRED_NUM_HEADS is not used in this cell.
    PRED_NUM_HEADS: int = 12
    CROSS_NUM_HEADS: int = 8

# Materialise the configuration with every field left at its default.
MODEL_CONFIG = ModelConfig()

print("Init models...")

# Target T2I module: construct first, then move to the configured device.
# No cross_block is passed here, so x_t2i_module falls back to its own default.
crosser = x_t2i_module(
    text_embed_dim=MODEL_CONFIG.T_EMBED_DIM,
    vision_embed_dim=MODEL_CONFIG.V_EMBED_DIM,
    hidden_dim=MODEL_CONFIG.H_EMBED_DIM,
    depth=MODEL_CONFIG.CROSS_ATTN_DEPTH,
    num_heads=MODEL_CONFIG.CROSS_NUM_HEADS,
    mlp_ratio=MODEL_CONFIG.MLP_RATIO,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=MODEL_CONFIG.DROP_RATE,
    attn_drop_rate=MODEL_CONFIG.ATTN_DROP_RATE,
)
crosser = crosser.to(DEVICE_0)

# Total element count over everything returned by crosser.parameters().
crosser_total_params = sum(weight.numel() for weight in crosser.parameters())
print(f"crosser_total_params={crosser_total_params!r}")

print("\n\nDone init models\n\n")


Init models...
crosser_total_params=29336832


Done init models




In [2]:
import torch
# Load the saved checkpoint file, remapping its tensors to the CPU.
# NOTE(review): unless weights_only=True is in effect, torch.load unpickles the
# file and can execute arbitrary code — only load checkpoints from a trusted source.
checkpoint = torch.load("trains/VQA-1732161891/epoch-30.pt", map_location='cpu')


In [4]:
# Show the top-level entries stored in the loaded checkpoint.
checkpoint.keys()

dict_keys(['crosser', 'mlp_head', 'opt', 'scaler', 'epoch', 'loss', 'val_metrics'])

In [11]:
# List every entry name under the 'crosser' section, then under 'mlp_head'.
for section in ('crosser', 'mlp_head'):
    for name in checkpoint[section].keys():
        print(name)

vision_proj.weight
vision_proj.bias
vision_norm.weight
vision_norm.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.cross_attn.query.weight
blocks.0.cross_attn.query.bias
blocks.0.cross_attn.key.weight
blocks.0.cross_attn.key.bias
blocks.0.cross_attn.value.weight
blocks.0.cross_attn.value.bias
blocks.0.cross_attn.proj.weight
blocks.0.cross_attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.cross_attn.query.weight
blocks.1.cross_attn.query.bias
blocks.1.cross_attn.key.weight
blocks.1.cross_attn.key.bias
blocks.1.cross_attn.value.weight
blocks.1.cross_attn.value.bias
blocks.1.cross_attn.proj.weight
blocks.1.cross_attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.cross_attn.query.weight
b

In [12]:
# Drop the notebook's reference to the loaded checkpoint so its memory can be
# reclaimed once nothing else refers to it.
del checkpoint