In [1]:
# !pip install transformers accelerate datasets diffusers Pillow==9.4.0 wandb torchmetrics

In [2]:
# from local_secrets import hf_token, wandb_key
# from huggingface_hub import login
# import wandb

# login(token=hf_token)
# wandb.login(key=wandb_key)

In [3]:
import torch, torch.nn.functional as F, random, wandb, time
import torchvision.transforms as T
from diffusers import AutoencoderDC, SanaTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import AutoModel, AutoTokenizer, set_seed
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

from utils import PIL_to_latent, latent_to_PIL, make_grid, encode_prompt, dcae_scalingf, pil_clipscore

seed = 42
set_seed(seed)

In [4]:
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

transformer = SanaTransformer2DModel.from_config("transformer_Sana-7L-MBERT_config.json").to(device).to(dtype)
text_encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype)

model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model, subfolder="scheduler")

  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


# Load dataset

In [5]:
ds = load_dataset("g-ronimo/MNIST-latents_dc-ae-f32c32-sana-1.0")
labels = list(range(10))
labels_encoded={i: encode_prompt(str(i), tokenizer, text_encoder) for i in labels}

len(labels_encoded[0]), labels_encoded[0][0].shape, labels_encoded[0][1].shape

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


(2, torch.Size([1, 300, 768]), torch.Size([1, 300]))

In [6]:
from torch.utils.data import DataLoader

def collate(items):
    labels = [i["label"] for i in items]
    latents = torch.cat([torch.Tensor(i["latent"]) for i in items]).to(dtype).to(device)
    prompts_encoded = torch.cat([labels_encoded[label][0] for label in labels])
    prompts_atnmask = torch.cat([labels_encoded[label][1] for label in labels])

    return labels, latents, prompts_encoded, prompts_atnmask

dataloader = DataLoader(ds["train"], batch_size=2, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
labels, latents, prompts_encoded, prompts_atnmask = next(iter(dataloader))
len(labels), latents.mean(), latents.shape, prompts_encoded.shape, prompts_atnmask.shape

(2,
 tensor(0.4844, device='cuda:0', dtype=torch.bfloat16),
 torch.Size([2, 32, 8, 8]),
 torch.Size([2, 300, 768]),
 torch.Size([2, 300]))

# Helpers for eval and generate

In [7]:
def generate(prompt, num_timesteps=10, latent_dim=[1, 32, 8, 8], latent_seed=42):
    scheduler.set_timesteps(num_timesteps)
    prompt_encoded, prompt_atnmask = encode_prompt(prompt, tokenizer, text_encoder)
    latents = torch.randn(latent_dim, generator = torch.manual_seed(latent_seed)).to(dtype).to(device)

    for t_idx in range(num_timesteps):
        t = scheduler.timesteps[t_idx].unsqueeze(0).to(device)
        with torch.no_grad():
            noise_pred = transformer(latents, encoder_hidden_states=prompt_encoded, timestep=t, encoder_attention_mask=prompt_atnmask, return_dict=False)[0]
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latent_to_PIL(latents / dcae_scalingf, dcae)

[generate("0")]

[<PIL.Image.Image image mode=RGB size=256x256>]

In [8]:
def eval_loss(data_val, num_samples=10, num_timesteps=10, batch_size=24):
    losses = []
    eval_dataloader = iter(DataLoader(data_val, batch_size=batch_size, shuffle=False, collate_fn=collate))

    for i in tqdm(range(num_samples), "eval_loss"):
        label, latent, prompt_encoded, prompt_atnmask = next(eval_dataloader)
        noise = torch.randn_like(latent)
        timestep = scheduler.timesteps[[random.randint(0, num_timesteps-1) for _ in range(batch_size)]].to(device)
        latent_noisy = scheduler.scale_noise(latent, timestep, noise)
        with torch.no_grad():
            noise_pred = transformer(latent_noisy, encoder_hidden_states = prompt_encoded, encoder_attention_mask = prompt_atnmask, timestep = timestep, return_dict=False)[0]
        loss = F.mse_loss(noise_pred, noise - latent)
        losses.append(loss.item())
    return sum(losses)/len(losses)

eval_loss(ds["train"])

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 12.23it/s]


11.5625

In [9]:
def eval_clipscore(seeds=[1,7,42]):
    prompts = [f"handwritten digit {digit}" for digit in range(10)]
    images = [generate(p, latent_seed=seed) for seed in tqdm(seeds, "eval_clipscore") for p in prompts]
    return pil_clipscore(images, prompts*len(seeds))

eval_clipscore()

eval_clipscore: 100%|██████████| 3/3 [00:04<00:00,  1.35s/it]
  return self.fget.__get__(instance, owner)()


22.915660858154297

# Train

In [10]:
log_wandb = True

lr = 5e-4
# bs = 128
bs = 256
epochs = 3
diffuser_timesteps = 10
steps_log, steps_eval = 20, 100

data_train, data_val = ds["train"], ds["test"]

steps_epoch = len(data_train)
steps_total = epochs * (steps_epoch // bs)

dataloader = DataLoader(data_train, batch_size=bs, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
optimizer = torch.optim.AdamW(transformer.parameters(), lr=lr)
scheduler.set_timesteps(diffuser_timesteps)

model_size = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"Number of parameters: {model_size / 1e6:.2f}M")

Number of parameters: 156.41M


In [None]:
if log_wandb: wandb.init(project="Hana", name=f"Z-{model_size / 1e6:.2f}M_MNIST_LR-{lr}_BS-{bs}_10-TS_CLIPSCORE_DATAL").log_code(".", include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb") or path.endswith(".json"))

t_start, last_step_time = time.time(), time.time()
step, losses = 0, []

for _ in range(epochs):
    for batch in dataloader:        
        transformer.train()
        labels, latents, prompts_encoded, prompts_atnmask = batch
        noise = torch.randn_like(latents)
        timesteps = scheduler.timesteps[torch.randint(diffuser_timesteps,(latents.shape[0],))].to(device)
        latents_noisy = scheduler.scale_noise(latents, timesteps, noise)
        
        noise_pred = transformer(
            latents_noisy, 
            encoder_hidden_states = prompts_encoded, 
            encoder_attention_mask = prompts_atnmask, 
            timestep = timesteps, 
            return_dict=False
        )[0]
    
        loss = F.mse_loss(noise_pred, noise - latents)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(transformer.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        
        losses.append(loss.item())
        step += 1
        sample_count, epoch = step * bs, step * bs / steps_epoch 
        
        if step % steps_log == 0:
            loss_train = sum(losses)/len(losses)
            step_time = (time.time() - last_step_time) / steps_log * 1000
            sample_tp = bs * steps_log / (time.time() - last_step_time)
            print(f"step {step}, epoch: {epoch:.4f}, train loss: {loss_train:.4f}, grad_norm: {grad_norm:.2f}, {step_time:.2f}ms/step, {sample_tp:.2f}samples/sec")
            if log_wandb: wandb.log({"loss_train": loss_train, "grad_norm": grad_norm, "step_time": step_time, "step": step, "epoch": epoch, "sample_tp": sample_tp, "sample_count": sample_count})
            last_step_time, losses = time.time(), []
    
        if step % steps_eval == 0:
            transformer.eval()
            loss_eval, clipscore, images_eval = eval_loss(data_val), eval_clipscore(), make_grid([generate(str(p)) for p in tqdm(range(10), "images_eval")], 2, 5)
            print(f"step {step}, eval loss: {loss_eval:.4f}, clipscore: {clipscore:.2f}")
            if not log_wandb: display(images_eval.resize((300,150)))
            if log_wandb: wandb.log({"loss_eval": loss_eval, "clipscore": clipscore, "images_eval": wandb.Image(images_eval), "step": step, "epoch": epoch, "sample_count": sample_count})


[34m[1mwandb[0m: Currently logged in as: [33mg-ronimo[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


step 20, epoch: 0.0853, train loss: 4.3305, grad_norm: 2.38, 852.66ms/step, 300.24samples/sec
step 40, epoch: 0.1707, train loss: 2.1742, grad_norm: 1.22, 839.10ms/step, 305.09samples/sec
step 60, epoch: 0.2560, train loss: 1.9945, grad_norm: 1.38, 838.31ms/step, 305.37samples/sec
step 80, epoch: 0.3413, train loss: 1.8062, grad_norm: 2.31, 858.82ms/step, 298.08samples/sec
step 100, epoch: 0.4267, train loss: 1.5805, grad_norm: 2.94, 845.06ms/step, 302.94samples/sec


eval_loss: 100%|██████████| 10/10 [00:00<00:00, 16.18it/s]
eval_clipscore: 100%|██████████| 3/3 [00:04<00:00,  1.45s/it]
images_eval: 100%|██████████| 10/10 [00:01<00:00,  6.89it/s]


step 100, eval loss: 1.4625, clipscore: 27.35
step 120, epoch: 0.5120, train loss: 1.4301, grad_norm: 1.47, 1589.31ms/step, 161.08samples/sec
step 140, epoch: 0.5973, train loss: 1.3367, grad_norm: 1.96, 834.96ms/step, 306.60samples/sec
step 160, epoch: 0.6827, train loss: 1.2996, grad_norm: 1.20, 841.12ms/step, 304.35samples/sec
step 180, epoch: 0.7680, train loss: 1.2402, grad_norm: 1.45, 839.64ms/step, 304.89samples/sec
step 200, epoch: 0.8533, train loss: 1.1941, grad_norm: 1.02, 856.41ms/step, 298.92samples/sec


eval_loss: 100%|██████████| 10/10 [00:00<00:00, 16.45it/s]
eval_clipscore: 100%|██████████| 3/3 [00:04<00:00,  1.42s/it]
images_eval: 100%|██████████| 10/10 [00:01<00:00,  6.64it/s]


step 200, eval loss: 1.1379, clipscore: 26.86
step 220, epoch: 0.9387, train loss: 1.1863, grad_norm: 0.87, 1598.79ms/step, 160.12samples/sec
step 240, epoch: 1.0240, train loss: 1.1609, grad_norm: 0.84, 805.90ms/step, 317.66samples/sec
step 260, epoch: 1.1093, train loss: 1.1410, grad_norm: 0.75, 832.39ms/step, 307.55samples/sec
step 280, epoch: 1.1947, train loss: 1.1031, grad_norm: 0.58, 835.87ms/step, 306.27samples/sec
step 300, epoch: 1.2800, train loss: 1.0828, grad_norm: 0.64, 836.32ms/step, 306.10samples/sec


eval_loss: 100%|██████████| 10/10 [00:00<00:00, 15.89it/s]
eval_clipscore: 100%|██████████| 3/3 [00:04<00:00,  1.39s/it]
images_eval: 100%|██████████| 10/10 [00:01<00:00,  6.92it/s]


step 300, eval loss: 1.0465, clipscore: 27.33
step 320, epoch: 1.3653, train loss: 1.0879, grad_norm: 0.75, 1623.53ms/step, 157.68samples/sec
step 340, epoch: 1.4507, train loss: 1.0746, grad_norm: 0.55, 833.69ms/step, 307.07samples/sec
step 360, epoch: 1.5360, train loss: 1.0906, grad_norm: 0.64, 830.97ms/step, 308.07samples/sec
step 380, epoch: 1.6213, train loss: 1.0559, grad_norm: 0.42, 832.45ms/step, 307.53samples/sec
step 400, epoch: 1.7067, train loss: 1.0512, grad_norm: 0.61, 832.06ms/step, 307.67samples/sec


eval_loss: 100%|██████████| 10/10 [00:00<00:00, 15.62it/s]
eval_clipscore: 100%|██████████| 3/3 [00:04<00:00,  1.44s/it]
images_eval: 100%|██████████| 10/10 [00:01<00:00,  6.88it/s]


step 400, eval loss: 0.9516, clipscore: 27.25
step 420, epoch: 1.7920, train loss: 1.0574, grad_norm: 0.78, 1624.57ms/step, 157.58samples/sec
step 440, epoch: 1.8773, train loss: 1.0428, grad_norm: 0.48, 826.60ms/step, 309.70samples/sec
step 460, epoch: 1.9627, train loss: 1.0523, grad_norm: 0.60, 825.55ms/step, 310.10samples/sec
step 480, epoch: 2.0480, train loss: 1.0201, grad_norm: 0.50, 802.15ms/step, 319.14samples/sec
step 500, epoch: 2.1333, train loss: 1.0355, grad_norm: 0.41, 830.77ms/step, 308.15samples/sec


eval_loss: 100%|██████████| 10/10 [00:00<00:00, 15.93it/s]
eval_clipscore: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it]
images_eval: 100%|██████████| 10/10 [00:01<00:00,  7.02it/s]


step 500, eval loss: 1.0203, clipscore: 27.75
step 520, epoch: 2.2187, train loss: 1.0270, grad_norm: 0.60, 1578.87ms/step, 162.14samples/sec
step 540, epoch: 2.3040, train loss: 1.0078, grad_norm: 0.53, 824.60ms/step, 310.45samples/sec
step 560, epoch: 2.3893, train loss: 1.0297, grad_norm: 0.54, 832.96ms/step, 307.34samples/sec


In [None]:
transformer.push_to_hub(f"g-ronimo/hana-small_alpha7")

In [None]:
!runpodctl remove pod $RUNPOD_POD_ID