In [None]:
!pip install transformers accelerate datasets diffusers Pillow==9.4.0 wandb

In [None]:
from utils import PIL_to_latent, latent_to_PIL, make_grid, encode_prompt, dcae_scalingf

In [None]:
import torch, torch.nn.functional as F, random, wandb, time
import torchvision.transforms as T
from diffusers import AutoencoderDC, SanaTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import Gemma2Model, GemmaTokenizerFast
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

In [None]:
# Hub checkpoint that supplies the autoencoder, text encoder, tokenizer and scheduler.
model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dtype = torch.bfloat16
# torch.backends.mps.is_available() is the long-standing MPS availability check;
# torch.mps.is_available() only exists in recent PyTorch releases and raises
# AttributeError on older ones when CUDA is absent.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# The transformer is built from a local config file rather than from the checkpoint.
# from_config() expects a config dict (passing a path is deprecated), so the file is
# loaded explicitly first. The cast to `dtype` is done explicitly with .to() instead of
# relying on a `torch_dtype` keyword being honoured by from_config().
transformer_config = SanaTransformer2DModel.load_config("transformer_config.json")
transformer = SanaTransformer2DModel.from_config(transformer_config).to(device=device, dtype=dtype)
dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)
text_encoder = Gemma2Model.from_pretrained(model, subfolder="text_encoder", torch_dtype=dtype).to(device)
# A tokenizer holds no tensors, so no dtype is passed here.
tokenizer = GemmaTokenizerFast.from_pretrained(model, subfolder="tokenizer")
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model, subfolder="scheduler")

# Load dataset

In [None]:
# MNIST dataset stored as latents (each row carries a "label" and a "latent" field).
ds = load_dataset("g-ronimo/MNIST-latents_dc-ae-f32c32-sana-1.0")

# Encode the ten digit prompts "0".."9" once up front so the text encoder
# does not have to run for every training sample.
mnist_labels_encoded = dict(
    (digit, encode_prompt(str(digit), tokenizer, text_encoder))
    for digit in range(10)
)

# Display: number of items per cached entry and the shapes of its first two items.
(
    len(mnist_labels_encoded[0]),
    mnist_labels_encoded[0][0].shape,
    mnist_labels_encoded[0][1].shape,
)

# Training "data loader" and helper functions

In [None]:
def get_sample(split="train", sample_no=None):
    """Fetch a single example from `ds[split]`.

    Args:
        split: key of the dataset split to read from.
        sample_no: row index to fetch; when None a uniformly random row is drawn.

    Returns:
        Tuple `(label, latent, prompt_encoded, prompt_atnmask)`: the row's label,
        its latent as a tensor on `device` in `dtype`, and the cached prompt
        encoding / attention mask stored in `mnist_labels_encoded` for that label.
    """
    if sample_no is None:
        row_idx = random.randint(0, len(ds[split]) - 1)
    else:
        row_idx = sample_no
    row = ds[split][row_idx]

    digit = row["label"]
    # The prompt encoding for each label was computed once and is looked up here.
    cached_encoding, cached_mask = mnist_labels_encoded[digit]

    latent_tensor = torch.Tensor(row["latent"])
    latent_tensor = latent_tensor.to(device)
    latent_tensor = latent_tensor.to(dtype)
    return digit, latent_tensor, cached_encoding, cached_mask

# Sanity check: draw one random training example and display its label together
# with shape / device / dtype of the latent, the prompt encoding and the attention mask.
label, latent, prompt_encoded, prompt_atnmask = get_sample("train")
(
    label,
    (latent.shape, latent.device, latent.dtype),
    (prompt_encoded.shape, prompt_encoded.device, prompt_encoded.dtype),
    (prompt_atnmask.shape, prompt_atnmask.device, prompt_atnmask.dtype),
)

In [None]:
def eval_loss(num_samples=100):
    """Average the training objective over the first `num_samples` test examples.

    For each example a random timestep of the scheduler's current schedule is
    picked, the latent is noised with `scheduler.scale_noise`, and the MSE between
    the transformer output and `noise - latent` is recorded.

    Args:
        num_samples: number of test examples to evaluate; clamped to the size of
            the test split.

    Returns:
        Mean of the per-example losses as a Python float.

    Raises:
        ValueError: if no example would be evaluated (`num_samples <= 0` or an
            empty test split).
    """
    # Never index past the end of the test split.
    num_samples = min(num_samples, len(ds["test"]))
    if num_samples <= 0:
        raise ValueError("eval_loss needs at least one test sample")

    # Sample from the schedule the scheduler actually holds instead of relying on
    # the `diffuser_timesteps` global, which is defined in a later cell and can
    # disagree with the scheduler after generate() reconfigured it.
    num_timesteps = len(scheduler.timesteps)

    losses = []
    for i in tqdm(range(num_samples)):
        _, latent, prompt_encoded, prompt_atnmask = get_sample("test", i)
        noise = torch.randn(latent.shape).to(dtype).to(device)
        timestep = scheduler.timesteps[random.randint(0, num_timesteps - 1)].unsqueeze(0).to(device)
        latent_noisy = scheduler.scale_noise(latent, timestep, noise)
        with torch.no_grad():
            noise_pred = transformer(
                latent_noisy,
                encoder_hidden_states=prompt_encoded,
                encoder_attention_mask=prompt_atnmask,
                timestep=timestep,
                return_dict=False,
            )[0]
        # Same regression target as the training loop: noise - latent.
        loss = F.mse_loss(noise_pred, noise - latent)
        losses.append(loss.item())
    return sum(losses) / len(losses)

In [None]:
def generate(prompt, steps=10, latent_dim=[1, 32, 8, 8], latent_seed=42):
    """Sample an image for `prompt` by iterating the scheduler from random latents.

    Note: this reconfigures the shared `scheduler` via `set_timesteps(steps)`.

    Args:
        prompt: text prompt, encoded with `encode_prompt`.
        steps: number of scheduler steps to run.
        latent_dim: size of the initial random latent passed to `torch.randn`.
        latent_seed: seed for the initial latent; None draws from the global RNG.
            Any integer (including 0) gives a reproducible result.

    Returns:
        The result of `latent_to_PIL` applied to the final latents divided by
        `dcae_scalingf`.
    """
    scheduler.set_timesteps(steps)
    prompt_encoded, prompt_atnmask = encode_prompt(prompt, tokenizer, text_encoder)

    # Use a private generator: torch.manual_seed() would reseed the *global* RNG on
    # every call, making the training noise drawn after each evaluation repeat.
    # `is not None` (rather than truthiness) keeps seed 0 a valid seed.
    generator = None
    if latent_seed is not None:
        generator = torch.Generator().manual_seed(latent_seed)
    latents = torch.randn(latent_dim, generator=generator).to(dtype).to(device)

    for t_idx in tqdm(range(steps)):
        t = scheduler.timesteps[t_idx].unsqueeze(0).to(device)
        with torch.no_grad():
            noise_pred = transformer(
                latents,
                encoder_hidden_states=prompt_encoded,
                timestep=t,
                encoder_attention_mask=prompt_atnmask,
                return_dict=False,
            )[0]
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latent_to_PIL(latents / dcae_scalingf, dcae)

# Train

In [None]:
# --- logging / evaluation cadence ---
log_wandb = True
steps_log = 20      # log the training loss every `steps_log` steps
steps_eval = 200    # run evaluation every `steps_eval` steps

# --- optimisation ---
lr = 1e-4
epochs = 5
steps = len(ds["train"])   # steps per epoch = number of training rows
bs = 1
diffuser_timesteps = 10    # length of the scheduler's timestep schedule

optimizer = torch.optim.AdamW(transformer.parameters(), lr=lr)
scheduler.set_timesteps(diffuser_timesteps)

# Cast the model to the working dtype and switch it to training mode.
transformer = transformer.to(dtype)
transformer = transformer.train()

# Count trainable parameters only.
model_size = sum(
    param.numel()
    for param in transformer.parameters()
    if param.requires_grad
)
print(f"Number of parameters: {model_size / 1e6:.2f}M")

In [None]:
# Start the W&B run; the run name records model size, learning rate, batch size and
# the number of scheduler timesteps (taken from `diffuser_timesteps`, not hard-coded).
if log_wandb:
    wandb.init(
        project="Hana",
        name=f"Z-{model_size / 1e6:.2f}M_MNIST_LR-{lr}_BS-{bs}_{diffuser_timesteps}-TS_imageEval-RUNPOD3090",
    ).log_code(".")
losses = []
t_start, last_step_time = time.time(), time.time()

for epoch in range(epochs):
    for step in range(steps):
        # Evaluation below switches to eval mode, so re-enable training mode each step.
        transformer.train()

        # One random training sample, noised at a random timestep of the schedule.
        label, latent, prompt_encoded, prompt_atnmask = get_sample()
        noise = torch.randn(latent.shape).to(dtype).to(device)
        timestep = scheduler.timesteps[random.randint(0, diffuser_timesteps-1)].unsqueeze(0).to(device)
        latent_noisy = scheduler.scale_noise(latent, timestep, noise)

        noise_pred = transformer(
            latent_noisy,
            encoder_hidden_states = prompt_encoded,
            encoder_attention_mask = prompt_atnmask,
            timestep = timestep,
            return_dict=False
        )[0]

        # Regression target is noise - latent.
        loss = F.mse_loss(noise_pred, noise - latent)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.item())

        # `epoch` is 0-based, so fractional progress is epoch + step/steps.
        # (The previous (epoch-1) offset made the logged epoch axis start at -1.)
        epoch_progress = epoch + step / steps

        if step > 0 and step % steps_log == 0:
            loss_train = sum(losses[-steps_log:]) / steps_log
            step_time = (time.time() - last_step_time) / steps_log * 1000
            print(f"step {step}, train loss {loss_train:.4f}, {step_time:.2f}ms/step")
            last_step_time = time.time()
            if log_wandb: wandb.log({"loss_train": loss_train, "step_time": step_time, "step": step, "epoch": epoch_progress})

        if step > 0 and step % steps_eval == 0:
            transformer.eval()
            loss_eval = eval_loss()
            images_eval = make_grid([generate(str(p)) for p in range(10)], 2, 5)
            # generate() calls scheduler.set_timesteps() with its own step count;
            # restore the training schedule so the timestep indexing above stays valid.
            scheduler.set_timesteps(diffuser_timesteps)
            print(f"step {step}, eval loss {loss_eval:.4f}")
            if log_wandb: wandb.log({"loss_eval": loss_eval, "images_eval": wandb.Image(images_eval), "step": step, "epoch": epoch_progress})
            # Do not count evaluation time towards the next ms/step measurement.
            last_step_time = time.time()


In [None]:
# Upload the trained transformer to the Hub; the repo name records the epoch count.
hub_repo_id = f"g-ronimo/hana-small_MNIST-{epochs}e"
transformer.push_to_hub(hub_repo_id)