In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [None]:
# Example label sequence: the integers 0 through 5.
labels = torch.arange(0, 6)
labels

In [None]:
# Shift every element one slot to the left; the first element wraps to the end.
labels = torch.roll(labels, shifts=-1)
labels

In [None]:
# Overwrite the last element in place with -100.
# NOTE(review): presumably -100 marks a position the loss should ignore (it
# matches PyTorch's default ignore_index for cross-entropy) — confirm.
labels[-1] = -100
labels

In [None]:
@dataclass
class VLMConfig:
    """Bundle of hyperparameters with defaults, grouped by field prefix.

    The prefixes are ``vit_``, ``lm_``, ``mp_`` and ``vlm_``. Every field has
    a default, so ``VLMConfig()`` can be built without arguments.
    """

    # ---- vit_* settings ----
    vit_hidden_dim: int = 768
    # Evaluated once while the class body executes, so the default is fixed at
    # 4 * 768 = 3072 and does NOT follow a vit_hidden_dim passed to __init__.
    vit_inter_dim: int = 4 * vit_hidden_dim
    vit_patch_size: int = 16
    vit_img_size: int = 512
    vit_n_heads: int = 12
    vit_dropout: float = 0.0
    vit_n_blocks: int = 12
    vit_ln_eps: float = 1e-6
    vit_cls_flag: bool = False
    # Alternative noted by the author: 'google/siglip-base-patch16-224'
    vit_model_type: str = "google/siglip-base-patch16-512"

    # ---- lm_* settings ----
    lm_hidden_dim: int = 576
    lm_inter_dim: int = 1536
    lm_rms_eps: float = 1e-5
    lm_re_base: int = 100000
    lm_max_position_embeddings: int = 8192
    lm_vocab_size: int = 49280
    lm_n_heads: int = 9
    lm_n_kv_heads: int = 3
    lm_dropout: float = 0.0
    lm_n_blocks: int = 30
    lm_attn_scaling: float = 1.0
    # Deduct the image token length to achieve a 'nice number'.
    lm_max_length: int = 256 - 64
    # Decide if the LM expects tokens or embeddings as input
    # (if using as a backbone for the VLM, set to False).
    lm_use_tokens: bool = False
    # Decide if you want to tie the LM head weight to the token embedding weights.
    lm_tie_weights: bool = False
    lm_model_type: str = "HuggingFaceTB/SmolLM2-135M"
    lm_tokenizer: str = "HuggingFaceTB/cosmo2-tokenizer"
    lm_eos_token_id: int = 0

    # ---- mp_* settings ----
    mp_pixel_shuffle_factor: int = 4

    # ---- vlm_* settings ----
    vlm_load_backbone_weights: bool = True
    vlm_checkpoint_path: str = "vlm_model_0502_smolvlm.pth"


In [None]:
# Build a config instance using every default value.
cfg = VLMConfig()

In [None]:
# Display the two config values used to derive `dim` in the following cell.
cfg.lm_hidden_dim, cfg.lm_n_heads

In [None]:
# Hidden size integer-divided by the number of heads (width of one head).
dim = cfg.lm_hidden_dim // cfg.lm_n_heads
dim

In [None]:
# Base of the exponentiation used when building inv_freq.
base = cfg.lm_re_base
base

In [None]:
# Take the maximum sequence length from the config's position-embedding limit.
max_seq_len = cfg.lm_max_position_embeddings
max_seq_len

In [None]:
# One inverse frequency per even index 0, 2, ..., dim - 2:
#   inv_freq[i] = 1 / base ** (2 * i / dim)
inv_freq = 1.0 / base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
inv_freq

In [None]:
# Read the same config field as `max_seq_len`, stored under a second name.
original_max_seq_len = cfg.lm_max_position_embeddings
original_max_seq_len

In [None]:
# Positions 0..3 with a leading axis of size 1, giving shape (1, 4).
position_ids = torch.arange(4)[None, :].expand(1, -1)
position_ids.shape

In [None]:
# Unpack the two axes of position_ids into named sizes.
batch_size, seq_len = position_ids.size()

In [None]:
# Collapse position_ids to one dimension and convert to float32 so it can be
# multiplied with the float inv_freq.
flat_position_ids = position_ids.flatten().to(torch.float32)
flat_position_ids

In [None]:
# Show the shape of inv_freq.
inv_freq.shape

In [None]:
# Outer product via broadcasting: column of positions times row of inverse
# frequencies, so freqs[p, i] = position[p] * inv_freq[i].
freqs = flat_position_ids[:, None] * inv_freq[None, :]
freqs.shape

In [None]:
# Restore the (batch, sequence) layout of position_ids on the leading axes.
# Uses batch_size / seq_len (unpacked from position_ids.shape) instead of the
# literals 1 and 4, so the cell keeps working if position_ids changes.
freqs = freqs.reshape(batch_size, seq_len, -1)
freqs.shape

In [None]:
# Inspect the first two values of freqs at row 0, position index 1.
freqs[0, 1, 0], freqs[0, 1, 1]

In [None]:
# Two copies of freqs placed side by side along the last axis.
emb = torch.cat((freqs,) * 2, dim=-1)
emb.shape

In [None]:
# Inspect the first three values of emb at row 0, position index 1.
emb[0, 1, 0], emb[0, 1, 1], emb[0, 1, 2]

In [None]:
# Offset at which the second copy of freqs starts inside emb
# (emb is cat([freqs, freqs]) along the last axis, so that is half its width).
# Kept in its own name instead of reassigning `dim`: `dim` feeds the inv_freq
# computation, and overwriting it would change inv_freq if that cell is re-run.
half_dim = emb.shape[-1] // 2
emb[0, 1, 0 + half_dim], emb[0, 1, 1 + half_dim], emb[0, 1, 2 + half_dim]

In [None]:
# Element-wise cosine and sine of emb.
cos, sin = emb.cos(), emb.sin()
cos.shape, sin.shape

In [None]:
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last axis of ``x`` is split into two chunks ``(first, second)`` and
    the result is ``(-second, first)`` concatenated along that same axis.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)

In [None]:
# Insert a singleton axis at position 1 so cos/sin broadcast against a tensor
# with one extra axis there (e.g. the 4-D q built in a later cell).
# Recomputed from emb rather than from cos/sin themselves, so re-running this
# cell does not keep stacking additional singleton axes.
cos = torch.cos(emb).unsqueeze(1)
sin = torch.sin(emb).unsqueeze(1)
cos.shape, sin.shape

In [None]:
# Integers 0..255 laid out as a (1, 1, 4, 64) tensor, so every value is
# easy to trace through rotate_half.
q = torch.arange(4 * 64).reshape(1, 1, 4, 64)
q

In [None]:
# Apply rotate_half to q: the last axis is split in two halves and returned
# as (-second_half, first_half).
q_rotated = rotate_half(q)
q_rotated

In [None]:
# Integers 0..255 arranged as 4 rows of 64.
a = torch.arange(4 * 64).view(4, 64)
a.shape

In [None]:
# Collapse a into a single dimension.
b = a.flatten()
b.shape

In [None]:
# Integers 0..11 arranged as 3 rows of 4 (rebinds `a`).
a = torch.arange(12).view(3, 4)
a.shape

In [None]:
# Repeat each element of a twice, back to back, along axis 1.
b = torch.repeat_interleave(a, 2, dim=1)

In [None]:
# Display b, the result of repeat_interleave.
b

In [None]:
# Display a for comparison with b.
a