In [1]:
# !pip install transformers accelerate datasets git+https://github.com/huggingface/diffusers Pillow==9.4.0 torchmetrics wandb

In [2]:
from local_secrets import hf_token, wandb_key
from huggingface_hub import login
import wandb

login(token=hf_token)
wandb.login(key=wandb_key)

[34m[1mwandb[0m: Currently logged in as: [33mg-ronimo[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
import torch, torch.nn.functional as F, random, wandb, time
import torchvision.transforms as T
from diffusers import AutoencoderDC, SanaTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import AutoModel, AutoTokenizer, set_seed
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

from utils import PIL_to_latent, latent_to_PIL, make_grid, encode_prompt, dcae_scalingf, pil_clipscore

seed = 42
set_seed(seed)

In [4]:
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

transformer = SanaTransformer2DModel.from_config("transformer_Sana-7L-MBERT_config.json").to(device).to(dtype)
text_encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype)

model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model, subfolder="scheduler")

  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
Process ForkProcess-3:
Process ForkProcess-2:
Process ForkProcess-28:
Process ForkProcess-15:
Process ForkProcess-9:
Process ForkProcess-30:
Process ForkProcess-1:
Process ForkProcess-20:
Process ForkProcess-5:
Process ForkProcess-17:
Process ForkProcess-31:
Process ForkProcess-22:
Process ForkProcess-29:
Process ForkProcess-32:
Process ForkProcess-21:
Process ForkProcess-25:
Process ForkProcess-19:
Process ForkProcess-18:
Process ForkProcess-13:
Process ForkProcess-7:
Process ForkProcess-23:
Process ForkProcess-27:
Process ForkProcess-26:
Process ForkProcess-24:
Process ForkProcess-4:
Traceback (most recent call last):
Process ForkProcess-8:
Process ForkProcess-6:
Traceback (most recent call last):
Process ForkProcess-14:
Process ForkProcess-16:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most

# Load dataset

In [5]:
from utils import fmnist_labels

ds = load_dataset("g-ronimo/FMNIST-latents-64_dc-ae-f32c32-sana-1.0")
labels = fmnist_labels
labels_encoded={k: encode_prompt(labels[k], tokenizer, text_encoder) for k in labels}

len(labels_encoded), len(labels_encoded[0]), labels_encoded[0][0].shape, labels_encoded[0][1].shape

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


(10, 2, torch.Size([1, 300, 768]), torch.Size([1, 300]))

In [6]:
ds.keys()

dict_keys(['train', 'test'])

In [7]:
from torch.utils.data import DataLoader

def collate(items):
    labels = [i["label"] for i in items]
    latents = torch.cat([torch.Tensor(i["latent"]) for i in items]).to(dtype).to(device)
    prompts_encoded = torch.cat([labels_encoded[label][0] for label in labels])
    prompts_atnmask = torch.cat([labels_encoded[label][1] for label in labels])

    return labels, latents, prompts_encoded, prompts_atnmask

dataloader = DataLoader(ds["train"], batch_size=2, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
labels, latents, prompts_encoded, prompts_atnmask = next(iter(dataloader))
len(labels), latents.mean(), latents.shape, prompts_encoded.shape, prompts_atnmask.shape

(2,
 tensor(0.1982, device='cuda:0', dtype=torch.bfloat16),
 torch.Size([2, 32, 2, 2]),
 torch.Size([2, 300, 768]),
 torch.Size([2, 300]))

# Helpers for eval and generate

In [8]:
scheduler.set_timesteps(100)

In [9]:
def generate(prompt, latent_dim=[1, 32, 2, 2], latent_seed=42):
    scheduler.set_timesteps(scheduler.timesteps.size(0))     # reset step index
    prompt_encoded, prompt_atnmask = encode_prompt(prompt, tokenizer, text_encoder)
    latents = torch.randn(latent_dim, generator = torch.manual_seed(latent_seed)).to(dtype).to(device)

    for t_idx in range(scheduler.timesteps.size(0)):
        t = scheduler.timesteps[t_idx].unsqueeze(0).to(device)
        with torch.no_grad():
            noise_pred = transformer(latents, encoder_hidden_states=prompt_encoded, timestep=t, encoder_attention_mask=prompt_atnmask, return_dict=False)[0]
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latent_to_PIL(latents / dcae_scalingf, dcae)

[generate("0")]

[<PIL.Image.Image image mode=RGB size=64x64>]

In [10]:
def eval_loss(data_val, num_samples=10, batch_size=24):
    losses = []
    eval_dataloader = iter(DataLoader(data_val, batch_size=batch_size, shuffle=False, collate_fn=collate))
    
    for i in tqdm(range(num_samples), "eval_loss"):
        label, latent, prompt_encoded, prompt_atnmask = next(eval_dataloader)
        noise = torch.randn_like(latent)
        timestep = scheduler.timesteps[torch.randint(scheduler.timesteps.size(0),(latent.shape[0],))].to(device)
        latent_noisy = scheduler.scale_noise(latent, timestep, noise)
        with torch.no_grad():
            noise_pred = transformer(latent_noisy, encoder_hidden_states = prompt_encoded, encoder_attention_mask = prompt_atnmask, timestep = timestep, return_dict=False)[0]
        loss = F.mse_loss(noise_pred, noise - latent)
        losses.append(loss.item())
    return sum(losses)/len(losses)

eval_loss(ds["train"])

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 30.52it/s]


11.6875

In [18]:
def eval_clipscore(images):
    prompts = [fmnist_labels[k] for k in fmnist_labels]
    return pil_clipscore(images, prompts)

scheduler.set_timesteps(100)
images = [generate(p) for p in tqdm([fmnist_labels[k] for k in fmnist_labels], "eval_images")]
eval_clipscore(images)


eval_images:   0%|          | 0/10 [00:00<?, ?it/s][A
eval_images:  10%|█         | 1/10 [00:00<00:05,  1.60it/s][A
eval_images:  20%|██        | 2/10 [00:01<00:04,  1.65it/s][A
eval_images:  30%|███       | 3/10 [00:01<00:04,  1.66it/s][A
eval_images:  40%|████      | 4/10 [00:02<00:03,  1.67it/s][A
eval_images:  50%|█████     | 5/10 [00:03<00:02,  1.67it/s][A
eval_images:  60%|██████    | 6/10 [00:03<00:02,  1.67it/s][A
eval_images:  70%|███████   | 7/10 [00:04<00:01,  1.67it/s][A
eval_images:  80%|████████  | 8/10 [00:04<00:01,  1.68it/s][A
eval_images:  90%|█████████ | 9/10 [00:05<00:00,  1.68it/s][A
eval_images: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s][A


22.945858001708984

# Train

In [21]:
log_wandb = True
lr = 5e-4
# bs = 128
bs = 896
epochs = 400
timesteps_training = 1000
timesteps_generate = 1000
steps_log, steps_eval = 20, 200
# steps_log, steps_eval = 2, 10

splits = list(ds.keys())
data_train, data_val = ds[splits[0]], ds[splits[1]]

steps_epoch = len(data_train) // bs

dataloader = DataLoader(data_train, batch_size=bs, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
optimizer = torch.optim.AdamW(transformer.parameters(), lr=lr)
scheduler.set_timesteps(timesteps_training)

model_size = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"Model parameters: {model_size / 1e6:.2f}M")
print(f"{len(splits)} splits: {splits}", [len(ds[s]) for s in splits])
assert len(splits)==2 

Model parameters: 156.41M
2 splits: ['train', 'test'] [60000, 10000]


In [22]:
if log_wandb: 
    if wandb.run is not None: wandb.finish()
    wandb.init(project="Hana", name=f"Z-{model_size / 1e6:.2f}M_FMNIST_LR-{lr}_BS-{bs}_TS-{timesteps_training}_runpod4090").log_code(".", include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb") or path.endswith(".json"))

step = 0
last_step_time = time.time()

for _ in range(epochs):
    for labels, latents, prompts_encoded, prompts_atnmask in dataloader:        
        noise = torch.randn_like(latents)
        timesteps = scheduler.timesteps[torch.randint(timesteps_training,(latents.shape[0],))].to(device)
        latents_noisy = scheduler.scale_noise(latents, timesteps, noise)        
        noise_pred = transformer(latents_noisy, prompts_encoded, timesteps, prompts_atnmask).sample
    
        loss = F.mse_loss(noise_pred, noise - latents)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(transformer.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()        
        
        if step > 0 and step % steps_log == 0:
            loss_train = loss.item()
            step_time = (time.time() - last_step_time) / steps_log * 1000
            sample_tp = bs * steps_log / (time.time() - last_step_time)
            print(f"step {step}, epoch: {step / steps_epoch:.4f}, train loss: {loss_train:.4f}, grad_norm: {grad_norm:.2f}, {step_time:.2f}ms/step, {sample_tp:.2f}samples/sec")
            if log_wandb: wandb.log({"loss_train": loss_train, "grad_norm": grad_norm, "step_time": step_time, "step": step, "sample_tp": sample_tp, "sample_count": step * bs, "epoch": step / steps_epoch})
            last_step_time = time.time()
    
        if step % steps_eval == 0:
            transformer.eval()
            loss_eval = eval_loss(data_val)
            scheduler.set_timesteps(timesteps_generate)
            images_eval = [generate(p) for p in tqdm([fmnist_labels[k] for k in fmnist_labels], "eval_images")]
            clipscore = eval_clipscore(images_eval)
            print(f"step {step}, eval loss: {loss_eval:.4f}, clipscore: {clipscore:.2f}")
            if log_wandb: wandb.log({"loss_eval": loss_eval, "clipscore": clipscore, "images_eval": wandb.Image(make_grid(images_eval, 2, 5)), "step": step, "sample_count": step * bs, "epoch": step / steps_epoch})
            transformer.train()
            scheduler.set_timesteps(timesteps_training)
        step += 1

  lambda data: self._console_raw_callback("stderr", data),
eval_images:  20%|██        | 2/10 [04:04<16:18, 122.31s/it]
eval_loss: 100%|██████████| 10/10 [00:00<00:00, 40.42it/s]
eval_images: 100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


step 0, eval loss: 2.9172, clipscore: 23.46
step 20, epoch: 0.3030, train loss: 1.7656, grad_norm: 0.68, 974.31ms/step, 919.63samples/sec


KeyboardInterrupt: 

In [None]:
transformer.push_to_hub(f"g-ronimo/hana-small_alpha8_TS-{timesteps_training}_{epochs}e")

In [None]:
!runpodctl remove pod $RUNPOD_POD_ID