In [None]:
import torch
from tqdm.notebook import tqdm
import numpy as np
from transformer.reimei import ReiMei, ReiMeiParameters
import matplotlib.pyplot as plt
from config import AE_CHANNELS, DIT_S as DIT, MODELS_DIR_BASE, SIGLIP_HF_NAME, BERT_HF_NAME, SIGLIP_EMBED_DIM, BERT_EMBED_DIM
from torch.amp import autocast
from transformers import SiglipTokenizer, SiglipTextModel, AutoTokenizer, ModernBertModel
from diffusers import AutoencoderDC

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Show which device was selected as the cell output.
device

In [4]:
# Disabled width-scaling experiment, kept for reference:
# base_dim = 1024
# base_heads = 32
# m_d = float(embed_dim) / float(base_dim)
# assert (embed_dim // num_heads) == (base_dim // base_heads)

# Numeric precision reused by the later cells.
DTYPE = torch.bfloat16

# --- Backbone ---
input_dim = AE_CHANNELS  # passed to the model as `channels`
embed_dim = 1152
num_layers = 24
num_heads = embed_dim // 32  # 1152 // 32 = 36 heads
mlp_dim = embed_dim
token_mixer_layers = 2
dropout = 0.1

# --- Expert routing ---
num_experts = 64
active_experts = 2.0
shared_experts = 1
image_text_expert_ratio = 16

# Bundle everything into the parameter object the model constructor consumes.
params = ReiMeiParameters(
    channels=input_dim,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    token_mixer_layers=token_mixer_layers,
    dropout=dropout,
    siglip_dim=SIGLIP_EMBED_DIM,
    bert_dim=BERT_EMBED_DIM,
    num_experts=num_experts,
    active_experts=active_experts,
    shared_experts=shared_experts,
    image_text_expert_ratio=image_text_expert_ratio,
    # m_d=m_d,
)

In [5]:
# Instantiate the ReiMei model from `params`.
# NOTE(review): the checkpoint load below is commented out, so as written the model keeps
# whatever initial weights ReiMei(params) produces — confirm this is intended before sampling.
model = ReiMei(params)
# model.load_state_dict(torch.load("models/reimei_model_and_optimizer_epoch_40_f32.pt")['ema_model_state_dict'])

In [None]:
# Sum the element counts of every parameter tensor and report the total.
model_param_count = sum(weight.numel() for weight in model.parameters())
print("Number of parameters in the model: ", model_param_count)


In [7]:
# SigLIP tokenizer and text encoder, both cached under MODELS_DIR_BASE/siglip.
siglip_tokenizer = SiglipTokenizer.from_pretrained(SIGLIP_HF_NAME, cache_dir=f"{MODELS_DIR_BASE}/siglip")
siglip_model = SiglipTextModel.from_pretrained(
    SIGLIP_HF_NAME,
    cache_dir=f"{MODELS_DIR_BASE}/siglip",
)
siglip_model = siglip_model.to(device)

In [8]:
# ModernBERT tokenizer and encoder, both cached under MODELS_DIR_BASE/modernbert.
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_HF_NAME, cache_dir=f"{MODELS_DIR_BASE}/modernbert")
bert_model = ModernBertModel.from_pretrained(
    BERT_HF_NAME,
    cache_dir=f"{MODELS_DIR_BASE}/modernbert",
)
bert_model = bert_model.to(device)

In [12]:
# Target token count used when padding/truncating tokenized prompts.
padded_len = 64


In [13]:
# Text prompts to generate from; the list multiplier repeats the prompt to set the batch size.
prompts = [
    "A group of people in traditional Indian attire gather around a table laden with food, blending traditional and modern elements in a casual and relaxed atmosphere." 
] * 1

In [14]:
# Tokenize every prompt for SigLIP and force each id sequence to exactly `padded_len` tokens.
siglip_inputs = []
unpadded_lens = []  # token count of each prompt before padding/truncation
for prompt in prompts:
    # encode(..., return_tensors="pt") yields a (1, seq_len) tensor; squeeze(0) drops only the
    # batch axis, so a one-token prompt still stays 1-D. Tensors follow `device` rather than a
    # hard-coded "cuda" so the cell also runs on CPU-only machines.
    token_ids = siglip_tokenizer.encode(prompt, return_tensors="pt", padding=True).to(device).squeeze(0)
    unpadded_lens.append(token_ids.shape[-1])
    if token_ids.shape[0] < padded_len:
        # Right-pad with id 0.
        # NOTE(review): confirm 0 is the padding id the model was trained with.
        pad_ids = torch.zeros(padded_len - token_ids.shape[0], dtype=token_ids.dtype, device=token_ids.device)
        token_ids = torch.cat([token_ids, pad_ids])
    else:
        # Too long (or exactly right): keep the first `padded_len` ids.
        token_ids = token_ids[:padded_len]
    siglip_inputs.append(token_ids)

In [15]:
# Tokenize every prompt for ModernBERT. The sequence is one token longer than `padded_len`
# because position 0 is later split off as the pooled vector.
bert_padded_len = padded_len + 1
bert_inputs = []
unpadded_lens = []  # token count of each prompt before padding/truncation (overwrites the SigLIP list)
for prompt in prompts:
    # squeeze(0) removes only the batch axis returned by encode(..., return_tensors="pt").
    # Tensors follow `device` rather than a hard-coded "cuda".
    token_ids = bert_tokenizer.encode("[CLS]"+prompt, return_tensors="pt", padding=True).to(device).squeeze(0)
    unpadded_lens.append(token_ids.shape[-1])
    if token_ids.shape[0] < bert_padded_len:
        # Right-pad with id 1.
        # NOTE(review): confirm 1 is the padding id the model was trained with.
        pad_ids = torch.ones(bert_padded_len - token_ids.shape[0], dtype=token_ids.dtype, device=token_ids.device)
        token_ids = torch.cat([token_ids, pad_ids])
    else:
        token_ids = token_ids[:bert_padded_len]
    bert_inputs.append(token_ids)

In [16]:
# Collapse the per-prompt id tensors into one (num_prompts, seq_len) batch.
# Rebinding the same name means this cell should run once per tokenization pass.
siglip_inputs = torch.stack(siglip_inputs, dim=0)

In [17]:
# Collapse the per-prompt id tensors into one (num_prompts, seq_len) batch.
# Rebinding the same name means this cell should run once per tokenization pass.
bert_inputs = torch.stack(bert_inputs, dim=0)

In [18]:
# Encode the prompts with the SigLIP text model.
# Inference only: no_grad stops autograd from recording a graph, which would otherwise keep
# intermediate activations (and references to the encoder weights) alive through these outputs.
with torch.no_grad():
    siglip_outputs = siglip_model(siglip_inputs, output_hidden_states=True)
siglip_embeddings = siglip_outputs.hidden_states[-1]  # per-token features: last entry of hidden_states
siglip_vec = siglip_outputs.pooler_output  # pooled, one vector per prompt

In [19]:
# Encode the prompts with ModernBERT and split the first position from the remaining tokens.
# Inference only: no_grad stops autograd from recording a graph that would keep activations
# (and references to the encoder weights) alive through these outputs.
with torch.no_grad():
    # (bs, seq, hidden) where seq = 1 leading token + the padded prompt tokens.
    # NOTE(review): hidden is presumably BERT_EMBED_DIM — the previous comments disagreed (768 vs 1024).
    bert_outputs = bert_model(bert_inputs, output_hidden_states=True).hidden_states[-1]
bert_vec = bert_outputs[:, 0, :]  # (bs, hidden): feature at position 0, used as the pooled vector
bert_embeddings = bert_outputs[:, 1:, :]  # (bs, seq - 1, hidden): the remaining per-token features

In [None]:
# Sanity check: display the shape of the per-token BERT features.
bert_embeddings.shape

In [None]:
# Sanity check: display the shape of the per-token SigLIP features
# (the trailing comma makes the cell output a 1-tuple).
siglip_embeddings.shape, 

In [22]:
# The embeddings are computed, so unbind the text encoders and their tokenizers.
del siglip_model, bert_model
del siglip_tokenizer, bert_tokenizer

In [23]:
# Starting noise for the sampler: one latent per prompt so the batch matches the text embeddings.
# Channels follow AE_CHANNELS (the value the model's `channels` is built from) and the dtype
# follows DTYPE instead of repeating the literals. The 32x32 spatial size is kept as-is.
noise = torch.randn(len(prompts), AE_CHANNELS, 32, 32).to(device=device, dtype=DTYPE)

In [24]:
# Move the model to the target device/precision and run the sampler.
model = model.eval().to(device, dtype=DTYPE)
# no_grad: sampling is inference only, so skip autograd bookkeeping.
# autocast follows `device`/DTYPE so the cell is not tied to CUDA.
with torch.no_grad(), autocast(device, dtype=DTYPE):
    pred = model.sample(noise, siglip_embeddings, siglip_vec, bert_embeddings, bert_vec, sample_steps=50, cfg=7.0)
# Unbind the diffusion model before the VAE is loaded; re-run the model-construction cell to sample again.
del model

In [None]:
# Load the DC-AE autoencoder in DTYPE, then move it to the target device in eval mode.
vae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-in-1.0-diffusers",
    torch_dtype=DTYPE,
    cache_dir=f"{MODELS_DIR_BASE}/dc_ae",
    revision="main",
)
vae = vae.to(device).eval()
# Report the autoencoder's total parameter count.
vae_param_count = sum(weight.numel() for weight in vae.parameters())
print(vae_param_count)

In [26]:
# Decode the sampled latents to image space and min-max rescale the batch to [-1, 1].
# NOTE: `pred` is rebound from latents to images, so re-running this cell requires re-sampling first.
# NOTE(review): the latents are decoded without any scaling factor — confirm this matches training.
with torch.no_grad(), autocast(device, dtype=DTYPE):
    pred = vae.decode(pred).sample

    # Min/max are taken over the whole batch, not per image.
    min_val = pred.min()
    max_val = pred.max()
    # Guard against a constant output, where max == min would divide by zero and yield NaNs.
    value_range = (max_val - min_val).clamp_min(1e-8)

    pred = (pred - min_val) / value_range  # -> [0, 1]
    pred = 2 * pred - 1  # -> [-1, 1]

In [None]:
# Convert the decoded batch to uint8 HWC images and show them in a grid sized to the batch.
with torch.inference_mode():
    pred_cpu = pred.cpu().to(torch.float32)
    pred_np = pred_cpu.permute(0, 2, 3, 1).numpy()  # (N, C, H, W) -> (N, H, W, C) for imshow
    pred_np = (pred_np + 1) / 2  # [-1, 1] -> [0, 1]
    # Clip before the cast so rounding error outside [0, 1] cannot wrap around in uint8.
    pred_np = (np.clip(pred_np, 0.0, 1.0) * 255).astype(np.uint8)

    # Up to 3 columns and as many rows as needed, so no image is dropped and a single image
    # is not lost in a mostly empty 3x3 grid.
    num_images = pred_np.shape[0]
    num_cols = min(3, max(1, num_images))
    num_rows = max(1, (num_images + num_cols - 1) // num_cols)
    # squeeze=False keeps axes_pred 2-D even for a 1x1 grid, so .flatten() always works.
    fig_pred, axes_pred = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 4 * num_rows), squeeze=False)

    for i, ax in enumerate(axes_pred.flatten()):
        if i < num_images:
            ax.imshow(pred_np[i])
            # ax.set_title(prompts[i])  # Add this line to set the title
        # Hide ticks and frames on used and unused cells alike.
        ax.axis('off')

    plt.tight_layout()
    plt.show()