In [1]:
# !pip install transformers accelerate datasets diffusers Pillow==9.4.0 wandb

In [2]:
# from local_secrets import hf_token, wandb_key
# from huggingface_hub import login
# import wandb

# login(token=hf_token)
# wandb.login(key=wandb_key)

In [3]:
import torch, torch.nn.functional as F, random, wandb, time
import torchvision.transforms as T
from diffusers import AutoencoderDC, SanaTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import Gemma2Model, GemmaTokenizerFast, AutoModel, AutoTokenizer
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

from utils import PIL_to_latent, latent_to_PIL, make_grid, encode_prompt, dcae_scalingf

KeyboardInterrupt: 

In [None]:
# Hub checkpoint that supplies the DC-AE autoencoder and the scheduler config below.
model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dtype = torch.bfloat16

# Pick the accelerator. torch.backends.mps.is_available() is the documented MPS check;
# torch.mps.is_available() is not present in every PyTorch release, so the original
# expression could raise AttributeError on machines without CUDA.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# The transformer is instantiated from a local config file (from_config), not from the
# pretrained checkpoint -- presumably with freshly initialised weights, to be trained below.
transformer = SanaTransformer2DModel.from_config("transformer_Sana-7L-MBERT_config.json", torch_dtype=dtype).to(device)
dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)

# ModernBERT is used as the text encoder; the Gemma2 encoder shipped with the Sana
# checkpoint is kept below as a commented-out alternative.
text_encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype)
# text_encoder = Gemma2Model.from_pretrained(model, subfolder="text_encoder", torch_dtype=dtype).to(device)
# tokenizer = GemmaTokenizerFast.from_pretrained(model, subfolder="tokenizer", torch_dtype=dtype)

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model, subfolder="scheduler")

# Load dataset

In [None]:
# Dataset of MNIST samples; each row is read below via its "label" and "latent" fields.
ds = load_dataset("g-ronimo/MNIST-latents_dc-ae-f32c32-sana-1.0")

# Encode the ten digit prompts "0".."9" once, so batches can look them up by label
# instead of running the text encoder on every step.
mnist_labels_encoded = {}
for digit in range(10):
    mnist_labels_encoded[digit] = encode_prompt(str(digit), tokenizer, text_encoder)

# Quick inspection of one entry: number of elements and the shapes of the first two.
len(mnist_labels_encoded[0]), mnist_labels_encoded[0][0].shape, mnist_labels_encoded[0][1].shape

# Train data loaders and helper functions

In [None]:
def get_sample(split="train", sample_no=None, bs=1):
    """Fetch a batch of latents together with their pre-encoded label prompts.

    Args:
        split: Key of the split in ``ds`` to read from.
        sample_no: Index of one specific sample to return. When ``None``,
            ``bs`` indices are drawn uniformly at random (with replacement).
        bs: Number of random samples to draw. Must not exceed 1 when
            ``sample_no`` is given.

    Returns:
        Tuple ``(labels, latents, prompts_encoded, prompts_atnmask)``:
        ``labels`` is a list of the rows' ``"label"`` values, ``latents`` is the
        concatenation of the rows' ``"latent"`` tensors cast to ``dtype`` and
        moved to ``device``, and the last two are the concatenated elements 0
        and 1 of the matching ``mnist_labels_encoded`` entries.

    Raises:
        AssertionError: if ``sample_no`` is given together with ``bs > 1``.
    """
    assert not (bs>1 and sample_no is not None), "Can't have fixed sample with BS>1"

    split_data = ds[split]
    if sample_no is not None:
        # Honour the requested index; previously this argument was ignored and a
        # random sample was returned even when a fixed one was asked for.
        idcs = [sample_no]
    else:
        idcs = [random.randint(0, len(split_data) - 1) for _ in range(bs)]

    # Read each row once and reuse it for both the label and the latent.
    rows = [split_data[idx] for idx in idcs]
    labels = [row["label"] for row in rows]
    latents = torch.cat([torch.Tensor(row["latent"]) for row in rows])

    encoded = [mnist_labels_encoded[label] for label in labels]
    prompts_encoded = torch.cat([enc[0] for enc in encoded])
    prompts_atnmask = torch.cat([enc[1] for enc in encoded])

    return labels, latents.to(dtype).to(device), prompts_encoded, prompts_atnmask

In [None]:
def eval_loss(num_samples=100):
    """Average the training objective over ``num_samples`` test-split samples.

    For each ``i`` in ``range(num_samples)`` this calls ``get_sample("test", i)``
    (``i`` is passed as ``sample_no``), noises the latent at one randomly chosen
    scheduler timestep, runs ``transformer`` without gradients and takes the MSE
    between its output and ``noise - latent``.

    Args:
        num_samples: Number of samples to evaluate.

    Returns:
        Mean of the per-sample losses as a Python float.

    Raises:
        ZeroDivisionError: if ``num_samples`` is 0.
    """
    # Bound the random index by the schedule currently set on the scheduler
    # instead of the global ``diffuser_timesteps``: generate() calls
    # scheduler.set_timesteps() with its own step count, and a mismatch would
    # either raise IndexError or silently skip part of the schedule.
    num_timesteps = len(scheduler.timesteps)

    losses = []
    for i in tqdm(range(num_samples)):
        _, latent, prompt_encoded, prompt_atnmask = get_sample("test", i)
        noise = torch.randn(latent.shape).to(dtype).to(device)
        timestep = scheduler.timesteps[random.randint(0, num_timesteps - 1)].unsqueeze(0).to(device)
        latent_noisy = scheduler.scale_noise(latent, timestep, noise)
        with torch.no_grad():
            noise_pred = transformer(
                latent_noisy,
                encoder_hidden_states=prompt_encoded,
                encoder_attention_mask=prompt_atnmask,
                timestep=timestep,
                return_dict=False,
            )[0]
            # Same target as the training loop: noise - latent.
            loss = F.mse_loss(noise_pred, noise - latent)
        losses.append(loss.item())
    return sum(losses)/len(losses)

In [None]:
def generate(prompt, steps=10, latent_dim=[1, 32, 8, 8], latent_seed=42):
    """Sample from ``transformer`` for a text prompt and decode the result.

    Note: this calls ``scheduler.set_timesteps(steps)`` on the shared scheduler,
    so the scheduler's timestep schedule is changed for later users as well.

    Args:
        prompt: Text passed to ``encode_prompt``.
        steps: Number of scheduler steps to run.
        latent_dim: Shape of the initial random latent tensor.
        latent_seed: Seed for the initial latents. ``None`` draws them from the
            global RNG; any integer (including 0) gives reproducible latents.

    Returns:
        Whatever ``latent_to_PIL(latents / dcae_scalingf, dcae)`` returns
        (presumably a PIL image -- defined in ``utils``).
    """
    scheduler.set_timesteps(steps)
    prompt_encoded, prompt_atnmask = encode_prompt(prompt, tokenizer, text_encoder)

    # Use a private generator: torch.manual_seed() would reseed the *global* RNG,
    # which made the training noise restart from the same seed after every
    # evaluation. ``is not None`` also lets seed 0 count as a real seed.
    generator = None if latent_seed is None else torch.Generator().manual_seed(latent_seed)
    latents = torch.randn(latent_dim, generator=generator).to(dtype).to(device)

    for t_idx in tqdm(range(steps)):
        t = scheduler.timesteps[t_idx].unsqueeze(0).to(device)
        with torch.no_grad():
            noise_pred = transformer(
                latents,
                encoder_hidden_states=prompt_encoded,
                timestep=t,
                encoder_attention_mask=prompt_atnmask,
                return_dict=False,
            )[0]
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latent_to_PIL(latents / dcae_scalingf, dcae)


# Train

In [None]:
# --- Run configuration -------------------------------------------------------
log_wandb = True                 # log metrics and images to Weights & Biases
lr, bs, epochs = 1e-4, 128, 10   # learning rate, batch size, number of epochs
diffuser_timesteps = 10          # length of the timestep schedule sampled during training

# --- Step bookkeeping --------------------------------------------------------
steps_epoch = len(ds["train"])   # note: this is the number of training *samples*
steps_total = (steps_epoch // bs) * epochs
steps_log, steps_eval = 20, 200  # log every 20 steps, evaluate every 200

# --- Optimizer, schedule and model mode --------------------------------------
optimizer = torch.optim.AdamW(transformer.parameters(), lr=lr)
scheduler.set_timesteps(diffuser_timesteps)

transformer = transformer.to(dtype)
transformer.train()

# Count trainable parameters only.
model_size = sum(param.numel() for param in transformer.parameters() if param.requires_grad)
print(f"Number of parameters: {model_size / 1e6:.2f}M")

In [None]:
# Start the W&B run and snapshot the code (.py / .ipynb / .json files) next to it.
if log_wandb:
    run_name = f"Z-{model_size / 1e6:.2f}M_MNIST_LR-{lr}_BS-{bs}_10-TS_MBERT-runpod4090"
    wandb.init(project="Hana", name=run_name).log_code(
        ".", include_fn=lambda path: path.endswith((".py", ".ipynb", ".json"))
    )

t_start, last_step_time = time.time(), time.time()
sample_count, losses = 0, []

for step in range(1, steps_total + 1):
    # eval below switches to eval mode, so re-enter train mode every step.
    transformer.train()
    labels, latents, prompts_encoded, prompts_atnmask = get_sample(bs=bs)
    noise = torch.randn_like(latents)
    # One randomly chosen timestep of the training schedule per sample in the batch.
    timesteps = scheduler.timesteps[[random.randint(0, diffuser_timesteps-1) for _ in range(bs)]].to(device)
    latents_noisy = scheduler.scale_noise(latents, timesteps, noise)

    noise_pred = transformer(
        latents_noisy,
        encoder_hidden_states = prompts_encoded,
        encoder_attention_mask = prompts_atnmask,
        timestep = timesteps,
        return_dict=False
    )[0]

    # Regression target is noise - latents.
    loss = F.mse_loss(noise_pred, noise - latents)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    losses.append(loss.item())
    sample_count = step * bs
    epoch = sample_count / steps_epoch

    # step starts at 1, so no extra "step > 0" guard is needed.
    if step % steps_log == 0:
        elapsed = time.time() - last_step_time
        loss_train = sum(losses[-steps_log:])/steps_log
        step_time = elapsed / steps_log * 1000
        sample_tp = bs * steps_log / elapsed
        print(f"step {step}, epoch: {epoch:.4f}, train loss {loss_train:.4f}, {step_time:.2f}ms/step, {sample_tp:.2f}samples/sec")
        if log_wandb: wandb.log({"loss_train": loss_train, "step_time": step_time, "step": step, "epoch": epoch, "sample_tp": sample_tp, "sample_count": sample_count})
        last_step_time = time.time()

    if step % steps_eval == 0:
        transformer.eval()
        loss_eval = eval_loss()
        images_eval = make_grid([generate(str(p)) for p in range(10)], 2, 5)
        print(f"step {step}, eval loss {loss_eval:.4f}")
        if log_wandb: wandb.log({"loss_eval": loss_eval, "images_eval": wandb.Image(images_eval), "step": step, "epoch": epoch, "sample_count": sample_count})
        # generate() calls scheduler.set_timesteps(steps); re-apply the training
        # schedule so timestep sampling stays valid if the two step counts ever differ.
        scheduler.set_timesteps(diffuser_timesteps)
        # Restart the timing window so evaluation and image generation are not
        # counted in the next logged step_time / sample_tp.
        last_step_time = time.time()


In [None]:
# Upload the trained transformer to the Hugging Face Hub under this repository id.
hub_repo_id = "g-ronimo/hana-small_MNIST-MODERNBERT-3e"
transformer.push_to_hub(hub_repo_id)

In [None]:
# !runpodctl remove pod $RUNPOD_POD_ID