In [1]:
# !pip install -U transformers accelerate datasets git+https://github.com/huggingface/diffusers Pillow==9.4.0 torchmetrics wandb

In [2]:
# from local_secrets import hf_token, wandb_key
# from huggingface_hub import login
# import wandb

# login(token=hf_token)
# wandb.login(key=wandb_key)

In [3]:
import torch, torch.nn.functional as F, random, wandb, time
import torchvision.transforms as T
from torchvision import transforms
from diffusers import AutoencoderDC, SanaTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import AutoModel, AutoTokenizer, set_seed
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm
from torch.utils.data import DataLoader

from utils import PIL_to_latent, latent_to_PIL, make_grid, encode_prompt, dcae_scalingf, pil_clipscore
from utils import cifar10_labels

seed = 42
set_seed(seed)

In [4]:
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

transformer = SanaTransformer2DModel.from_config("transformer_Sana-7L-MBERT_config.json").to(device).to(dtype)
text_encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype)

model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model, subfolder="scheduler")

  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


# Load dataset

In [5]:
ds = load_dataset("g-ronimo/CIFAR10-64-latents_dc-ae-f32c32-sana-1.0")
ds["train"]

Dataset({
    features: ['label', 'latent'],
    num_rows: 50000
})

In [6]:
labels = cifar10_labels
labels_encoded={k: encode_prompt(labels[k], tokenizer, text_encoder) for k in labels}

len(labels_encoded), len(labels_encoded[0]), labels_encoded[0][0].shape, labels_encoded[0][1].shape

(10, 2, torch.Size([1, 300, 768]), torch.Size([1, 300]))

In [7]:
from torch.utils.data import DataLoader

def collate(items):
    labels = [i["label"] for i in items]
    latents = torch.cat([torch.Tensor(i["latent"]) for i in items]).to(dtype).to(device)
    prompts_encoded = torch.cat([labels_encoded[label][0] for label in labels])
    prompts_atnmask = torch.cat([labels_encoded[label][1] for label in labels])

    return labels, latents, prompts_encoded, prompts_atnmask

dataloader = DataLoader(ds["train"], batch_size=2, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
labels, latents, prompts_encoded, prompts_atnmask = next(iter(dataloader))
len(labels), latents.mean(), latents.shape, prompts_encoded.shape, prompts_atnmask.shape

(2,
 tensor(-0.0366, device='cuda:0', dtype=torch.bfloat16),
 torch.Size([2, 32, 2, 2]),
 torch.Size([2, 300, 768]),
 torch.Size([2, 300]))

# Helpers for eval and generate

In [8]:
scheduler.set_timesteps(100)

In [9]:
def generate(prompt, latent_dim=[1, 32, 2, 2], latent_seed=42):
    scheduler.set_timesteps(scheduler.timesteps.size(0))     # reset step index
    prompt_encoded, prompt_atnmask = encode_prompt(prompt, tokenizer, text_encoder)
    latents = torch.randn(latent_dim, generator = torch.manual_seed(latent_seed)).to(dtype).to(device)

    for t_idx in range(scheduler.timesteps.size(0)):
        t = scheduler.timesteps[t_idx].unsqueeze(0).to(device)
        with torch.no_grad():
            noise_pred = transformer(latents, encoder_hidden_states=prompt_encoded, timestep=t, encoder_attention_mask=prompt_atnmask, return_dict=False)[0]
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latent_to_PIL(latents / dcae_scalingf, dcae)

[generate("0")]

[<PIL.Image.Image image mode=RGB size=64x64>]

In [10]:
def eval_loss(data_val, num_samples=10, batch_size=24):
    losses = []
    eval_dataloader = iter(DataLoader(data_val, batch_size=batch_size, shuffle=False, collate_fn=collate))
    
    for i in tqdm(range(num_samples), "eval_loss"):
        label, latent, prompt_encoded, prompt_atnmask = next(eval_dataloader)
        noise = torch.randn_like(latent)
        timestep = scheduler.timesteps[torch.randint(scheduler.timesteps.size(0),(latent.shape[0],))].to(device)
        latent_noisy = scheduler.scale_noise(latent, timestep, noise)
        with torch.no_grad():
            noise_pred = transformer(latent_noisy, encoder_hidden_states = prompt_encoded, encoder_attention_mask = prompt_atnmask, timestep = timestep, return_dict=False)[0]
        loss = F.mse_loss(noise_pred, noise - latent)
        losses.append(loss.item())
    return sum(losses)/len(losses)

eval_loss(ds["train"])

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 53.96it/s]


6.659375

In [11]:
def eval_clipscore(images):
    prompts = [cifar10_labels[k] for k in cifar10_labels]
    return pil_clipscore(images, prompts)

scheduler.set_timesteps(100)
images = [generate(p) for p in tqdm([cifar10_labels[k] for k in cifar10_labels], "eval_images")]
eval_clipscore(images)

eval_images: 100%|██████████████████████████████| 10/10 [00:08<00:00,  1.15it/s]


21.688941955566406

# Train

In [12]:
log_wandb = True
lr = 5e-4
# bs = 64
bs = 896
epochs = 500
# timesteps_training = 10
# timesteps_generate = 10
timesteps_training = 1000
timesteps_generate = 1000
steps_log, steps_eval = 10, 200
# steps_log, steps_eval = 10, 20

splits = list(ds.keys())
data_train, data_val = ds[splits[0]], ds[splits[1]]

steps_epoch = len(data_train) // bs

dataloader = DataLoader(data_train, batch_size=bs, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
optimizer = torch.optim.AdamW(transformer.parameters(), lr=lr)
scheduler.set_timesteps(timesteps_training)

model_size = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"Model parameters: {model_size / 1e6:.2f}M")
print(f"{len(splits)} splits: {splits}", [len(ds[s]) for s in splits])
assert len(splits)==2 

Model parameters: 156.41M
2 splits: ['train', 'test'] [50000, 10000]


In [13]:
if log_wandb: 
    if wandb.run is not None: wandb.finish()
    wandb.init(project="Hana", name=f"Z-{model_size / 1e6:.2f}M_CIFAR10_LR-{lr}_BS-{bs}_TS-{timesteps_training}_my3090").log_code(".", include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb") or path.endswith(".json"))

step = 0
last_step_time = time.time()

for _ in range(epochs):
    for labels, latents, prompts_encoded, prompts_atnmask in dataloader:        
        noise = torch.randn_like(latents)
        timesteps = scheduler.timesteps[torch.randint(timesteps_training,(latents.shape[0],))].to(device)
        latents_noisy = scheduler.scale_noise(latents, timesteps, noise)        
        noise_pred = transformer(latents_noisy, prompts_encoded, timesteps, prompts_atnmask).sample
    
        loss = F.mse_loss(noise_pred, noise - latents)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(transformer.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()        
        
        if step > 0 and step % steps_log == 0:
            loss_train = loss.item()
            step_time = (time.time() - last_step_time) / steps_log * 1000
            sample_tp = bs * steps_log / (time.time() - last_step_time)
            print(f"step {step}, epoch: {step / steps_epoch:.4f}, train loss: {loss_train:.4f}, grad_norm: {grad_norm:.2f}, {step_time:.2f}ms/step, {sample_tp:.2f}samples/sec")
            if log_wandb: wandb.log({"loss_train": loss_train, "grad_norm": grad_norm, "step_time": step_time, "step": step, "sample_tp": sample_tp, "sample_count": step * bs, "epoch": step / steps_epoch})
            last_step_time = time.time()
    
        if step % steps_eval == 0:
            transformer.eval()
            loss_eval = eval_loss(data_val)
            scheduler.set_timesteps(timesteps_generate)
            images_eval = [generate(p) for p in tqdm([cifar10_labels[k] for k in cifar10_labels], "eval_images")]
            clipscore = eval_clipscore(images_eval)
            print(f"step {step}, eval loss: {loss_eval:.4f}, clipscore: {clipscore:.2f}")
            if log_wandb: wandb.log({"loss_eval": loss_eval, "clipscore": clipscore, "images_eval": wandb.Image(make_grid(images_eval, 2, 5)), "step": step, "sample_count": step * bs, "epoch": step / steps_epoch})
            transformer.train()
            scheduler.set_timesteps(timesteps_training)
        step += 1

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mg-ronimo[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 57.58it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.46s/it]


step 0, eval loss: 11.0687, clipscore: 22.85
step 10, epoch: 0.1818, train loss: 5.2188, grad_norm: 3.45, 10063.24ms/step, 89.04samples/sec
step 20, epoch: 0.3636, train loss: 4.1562, grad_norm: 2.94, 1065.73ms/step, 840.74samples/sec
step 30, epoch: 0.5455, train loss: 3.5469, grad_norm: 2.23, 1087.95ms/step, 823.57samples/sec
step 40, epoch: 0.7273, train loss: 3.3281, grad_norm: 2.41, 1036.00ms/step, 864.87samples/sec
step 50, epoch: 0.9091, train loss: 3.3750, grad_norm: 2.23, 1067.88ms/step, 839.04samples/sec
step 60, epoch: 1.0909, train loss: 3.2188, grad_norm: 1.59, 1017.76ms/step, 880.36samples/sec
step 70, epoch: 1.2727, train loss: 3.1250, grad_norm: 1.51, 1034.63ms/step, 866.01samples/sec
step 80, epoch: 1.4545, train loss: 3.1250, grad_norm: 1.41, 1070.67ms/step, 836.86samples/sec
step 90, epoch: 1.6364, train loss: 3.0781, grad_norm: 1.30, 1029.84ms/step, 870.04samples/sec
step 100, epoch: 1.8182, train loss: 3.1094, grad_norm: 1.26, 1070.81ms/step, 836.75samples/sec
step

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 58.75it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.45s/it]


step 200, eval loss: 3.0406, clipscore: 22.72
step 210, epoch: 3.8182, train loss: 2.8906, grad_norm: 0.99, 9953.49ms/step, 90.02samples/sec
step 220, epoch: 4.0000, train loss: 2.9219, grad_norm: 0.93, 1045.79ms/step, 856.77samples/sec
step 230, epoch: 4.1818, train loss: 2.9844, grad_norm: 0.86, 1052.12ms/step, 851.61samples/sec
step 240, epoch: 4.3636, train loss: 2.9844, grad_norm: 0.86, 1034.02ms/step, 866.52samples/sec
step 250, epoch: 4.5455, train loss: 3.0000, grad_norm: 0.73, 1033.18ms/step, 867.22samples/sec
step 260, epoch: 4.7273, train loss: 2.9844, grad_norm: 0.65, 1065.32ms/step, 841.07samples/sec
step 270, epoch: 4.9091, train loss: 2.9531, grad_norm: 0.89, 1028.46ms/step, 871.21samples/sec
step 280, epoch: 5.0909, train loss: 2.9531, grad_norm: 0.80, 1046.80ms/step, 855.95samples/sec
step 290, epoch: 5.2727, train loss: 2.9062, grad_norm: 0.75, 1065.39ms/step, 841.01samples/sec
step 300, epoch: 5.4545, train loss: 2.9844, grad_norm: 0.77, 1036.11ms/step, 864.77samples

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.01it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:22<00:00,  8.27s/it]


step 400, eval loss: 2.8609, clipscore: 23.07
step 410, epoch: 7.4545, train loss: 2.8281, grad_norm: 0.75, 9611.19ms/step, 93.22samples/sec
step 420, epoch: 7.6364, train loss: 2.8594, grad_norm: 0.65, 1027.16ms/step, 872.31samples/sec
step 430, epoch: 7.8182, train loss: 2.9531, grad_norm: 0.59, 1025.72ms/step, 873.53samples/sec
step 440, epoch: 8.0000, train loss: 2.8594, grad_norm: 0.58, 1063.22ms/step, 842.73samples/sec
step 450, epoch: 8.1818, train loss: 2.9375, grad_norm: 0.69, 1010.48ms/step, 886.71samples/sec
step 460, epoch: 8.3636, train loss: 2.9531, grad_norm: 0.67, 1027.74ms/step, 871.82samples/sec
step 470, epoch: 8.5455, train loss: 2.9062, grad_norm: 0.64, 1060.72ms/step, 844.71samples/sec
step 480, epoch: 8.7273, train loss: 2.9375, grad_norm: 0.57, 1030.86ms/step, 869.17samples/sec
step 490, epoch: 8.9091, train loss: 2.9062, grad_norm: 0.61, 1059.90ms/step, 845.36samples/sec
step 500, epoch: 9.0909, train loss: 2.9531, grad_norm: 0.59, 1030.33ms/step, 869.62samples

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 57.55it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:26<00:00,  8.61s/it]


step 600, eval loss: 2.9562, clipscore: 23.08
step 610, epoch: 11.0909, train loss: 2.8438, grad_norm: 0.64, 10013.09ms/step, 89.48samples/sec
step 620, epoch: 11.2727, train loss: 2.9531, grad_norm: 0.56, 1075.62ms/step, 833.01samples/sec
step 630, epoch: 11.4545, train loss: 2.8906, grad_norm: 0.60, 1058.25ms/step, 846.68samples/sec
step 640, epoch: 11.6364, train loss: 2.9219, grad_norm: 0.59, 1060.51ms/step, 844.87samples/sec
step 650, epoch: 11.8182, train loss: 2.9062, grad_norm: 0.58, 1103.55ms/step, 811.92samples/sec
step 660, epoch: 12.0000, train loss: 2.9375, grad_norm: 0.54, 1061.64ms/step, 843.98samples/sec
step 670, epoch: 12.1818, train loss: 2.8594, grad_norm: 0.56, 1107.52ms/step, 809.02samples/sec
step 680, epoch: 12.3636, train loss: 2.9062, grad_norm: 0.62, 1043.98ms/step, 858.26samples/sec
step 690, epoch: 12.5455, train loss: 2.8594, grad_norm: 0.53, 1073.14ms/step, 834.93samples/sec
step 700, epoch: 12.7273, train loss: 2.8594, grad_norm: 0.47, 1111.69ms/step, 80

eval_images: 100%|██████████████████████████████| 10/10 [01:22<00:00,  8.27s/it]


step 800, eval loss: 2.8406, clipscore: 22.59
step 810, epoch: 14.7273, train loss: 2.7812, grad_norm: 0.52, 9602.72ms/step, 93.31samples/sec
step 820, epoch: 14.9091, train loss: 2.7969, grad_norm: 0.54, 1028.32ms/step, 871.32samples/sec
step 830, epoch: 15.0909, train loss: 2.8750, grad_norm: 0.56, 1060.27ms/step, 845.06samples/sec
step 840, epoch: 15.2727, train loss: 2.8594, grad_norm: 0.48, 1011.95ms/step, 885.42samples/sec
step 850, epoch: 15.4545, train loss: 2.9219, grad_norm: 0.59, 1064.89ms/step, 841.40samples/sec
step 860, epoch: 15.6364, train loss: 2.9062, grad_norm: 0.52, 1032.06ms/step, 868.16samples/sec
step 870, epoch: 15.8182, train loss: 2.8125, grad_norm: 0.57, 1067.62ms/step, 839.25samples/sec
step 880, epoch: 16.0000, train loss: 2.8594, grad_norm: 0.54, 1071.01ms/step, 836.59samples/sec
step 890, epoch: 16.1818, train loss: 2.8438, grad_norm: 0.52, 1032.46ms/step, 867.83samples/sec
step 900, epoch: 16.3636, train loss: 2.8750, grad_norm: 0.54, 1046.21ms/step, 856

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 58.41it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:25<00:00,  8.52s/it]


step 1000, eval loss: 2.9609, clipscore: 23.70
step 1010, epoch: 18.3636, train loss: 2.8750, grad_norm: 0.49, 9869.08ms/step, 90.79samples/sec
step 1020, epoch: 18.5455, train loss: 2.9219, grad_norm: 0.52, 1058.27ms/step, 846.67samples/sec
step 1030, epoch: 18.7273, train loss: 2.8750, grad_norm: 0.53, 1102.00ms/step, 813.06samples/sec
step 1040, epoch: 18.9091, train loss: 2.8750, grad_norm: 0.50, 1058.24ms/step, 846.68samples/sec
step 1050, epoch: 19.0909, train loss: 2.8438, grad_norm: 0.55, 1100.07ms/step, 814.49samples/sec
step 1060, epoch: 19.2727, train loss: 2.9062, grad_norm: 0.47, 1057.94ms/step, 846.93samples/sec
step 1070, epoch: 19.4545, train loss: 2.9375, grad_norm: 0.57, 1044.77ms/step, 857.60samples/sec
step 1080, epoch: 19.6364, train loss: 2.8594, grad_norm: 0.49, 1109.84ms/step, 807.33samples/sec
step 1090, epoch: 19.8182, train loss: 2.8281, grad_norm: 0.44, 1066.91ms/step, 839.81samples/sec
step 1100, epoch: 20.0000, train loss: 2.8438, grad_norm: 0.50, 1109.94m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 58.39it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.35s/it]


step 1200, eval loss: 2.8109, clipscore: 22.93
step 1210, epoch: 22.0000, train loss: 2.7812, grad_norm: 0.53, 9804.40ms/step, 91.39samples/sec
step 1220, epoch: 22.1818, train loss: 2.8125, grad_norm: 0.48, 1084.98ms/step, 825.82samples/sec
step 1230, epoch: 22.3636, train loss: 2.8125, grad_norm: 0.47, 1103.29ms/step, 812.11samples/sec
step 1240, epoch: 22.5455, train loss: 2.8750, grad_norm: 0.50, 1033.14ms/step, 867.26samples/sec
step 1250, epoch: 22.7273, train loss: 2.8281, grad_norm: 0.51, 1062.59ms/step, 843.23samples/sec
step 1260, epoch: 22.9091, train loss: 2.9062, grad_norm: 0.57, 1108.42ms/step, 808.36samples/sec
step 1270, epoch: 23.0909, train loss: 2.8125, grad_norm: 0.51, 1033.16ms/step, 867.24samples/sec
step 1280, epoch: 23.2727, train loss: 2.8594, grad_norm: 0.48, 1032.59ms/step, 867.72samples/sec
step 1290, epoch: 23.4545, train loss: 2.8125, grad_norm: 0.47, 1044.52ms/step, 857.81samples/sec
step 1300, epoch: 23.6364, train loss: 2.8438, grad_norm: 0.47, 1031.03m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 57.83it/s]
eval_images:   0%|                                       | 0/10 [00:00<?, ?it/s]

step 1420, epoch: 25.8182, train loss: 2.7344, grad_norm: 0.42, 1040.20ms/step, 861.37samples/sec
step 1430, epoch: 26.0000, train loss: 2.8125, grad_norm: 0.44, 1028.97ms/step, 870.78samples/sec
step 1440, epoch: 26.1818, train loss: 2.7969, grad_norm: 0.48, 1061.23ms/step, 844.30samples/sec
step 1450, epoch: 26.3636, train loss: 2.9062, grad_norm: 0.47, 1028.02ms/step, 871.58samples/sec
step 1460, epoch: 26.5455, train loss: 2.8438, grad_norm: 0.46, 1047.39ms/step, 855.46samples/sec
step 1470, epoch: 26.7273, train loss: 2.7969, grad_norm: 0.52, 1031.83ms/step, 868.36samples/sec
step 1480, epoch: 26.9091, train loss: 2.8125, grad_norm: 0.46, 1044.65ms/step, 857.70samples/sec
step 1490, epoch: 27.0909, train loss: 2.7969, grad_norm: 0.43, 1106.25ms/step, 809.94samples/sec
step 1500, epoch: 27.2727, train loss: 2.8594, grad_norm: 0.47, 1059.94ms/step, 845.33samples/sec
step 1510, epoch: 27.4545, train loss: 2.8594, grad_norm: 0.51, 1105.60ms/step, 810.42samples/sec
step 1520, epoch: 27

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.70it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.35s/it]


step 1600, eval loss: 2.9516, clipscore: 23.34
step 1610, epoch: 29.2727, train loss: 2.7031, grad_norm: 0.53, 9736.19ms/step, 92.03samples/sec
step 1620, epoch: 29.4545, train loss: 2.7344, grad_norm: 0.42, 1096.03ms/step, 817.50samples/sec
step 1630, epoch: 29.6364, train loss: 2.8281, grad_norm: 0.47, 1037.84ms/step, 863.33samples/sec
step 1640, epoch: 29.8182, train loss: 2.8125, grad_norm: 0.46, 1102.72ms/step, 812.54samples/sec
step 1650, epoch: 30.0000, train loss: 2.8281, grad_norm: 0.53, 1068.16ms/step, 838.83samples/sec
step 1660, epoch: 30.1818, train loss: 2.8125, grad_norm: 0.50, 1040.95ms/step, 860.75samples/sec
step 1670, epoch: 30.3636, train loss: 2.7812, grad_norm: 0.55, 1066.63ms/step, 840.03samples/sec
step 1680, epoch: 30.5455, train loss: 2.7812, grad_norm: 0.48, 1009.51ms/step, 887.56samples/sec
step 1690, epoch: 30.7273, train loss: 2.7500, grad_norm: 0.47, 1040.92ms/step, 860.77samples/sec
step 1700, epoch: 30.9091, train loss: 2.8125, grad_norm: 0.45, 1060.20m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 58.34it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.47s/it]


step 1800, eval loss: 2.8250, clipscore: 22.84
step 1810, epoch: 32.9091, train loss: 2.6875, grad_norm: 0.55, 9840.91ms/step, 91.05samples/sec
step 1820, epoch: 33.0909, train loss: 2.7031, grad_norm: 0.45, 1060.79ms/step, 844.66samples/sec
step 1830, epoch: 33.2727, train loss: 2.8281, grad_norm: 0.43, 1027.47ms/step, 872.04samples/sec
step 1840, epoch: 33.4545, train loss: 2.7188, grad_norm: 0.45, 1029.03ms/step, 870.72samples/sec
step 1850, epoch: 33.6364, train loss: 2.7969, grad_norm: 0.48, 1041.23ms/step, 860.52samples/sec
step 1860, epoch: 33.8182, train loss: 2.7969, grad_norm: 0.48, 1030.47ms/step, 869.51samples/sec
step 1870, epoch: 34.0000, train loss: 2.7500, grad_norm: 0.57, 1061.84ms/step, 843.82samples/sec
step 1880, epoch: 34.1818, train loss: 2.7969, grad_norm: 0.45, 1030.06ms/step, 869.85samples/sec
step 1890, epoch: 34.3636, train loss: 2.7812, grad_norm: 0.49, 1029.22ms/step, 870.56samples/sec
step 1900, epoch: 34.5455, train loss: 2.8125, grad_norm: 0.45, 1063.89m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 61.08it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:25<00:00,  8.52s/it]


step 2000, eval loss: 2.9594, clipscore: 23.09
step 2010, epoch: 36.5455, train loss: 2.7031, grad_norm: 0.52, 9857.00ms/step, 90.90samples/sec
step 2020, epoch: 36.7273, train loss: 2.8125, grad_norm: 0.45, 1038.49ms/step, 862.79samples/sec
step 2030, epoch: 36.9091, train loss: 2.7500, grad_norm: 0.46, 1026.93ms/step, 872.50samples/sec
step 2040, epoch: 37.0909, train loss: 2.7812, grad_norm: 0.50, 1025.89ms/step, 873.38samples/sec
step 2050, epoch: 37.2727, train loss: 2.7812, grad_norm: 0.49, 1062.21ms/step, 843.52samples/sec
step 2060, epoch: 37.4545, train loss: 2.7812, grad_norm: 0.47, 1026.34ms/step, 873.01samples/sec
step 2070, epoch: 37.6364, train loss: 2.7188, grad_norm: 0.50, 1065.44ms/step, 840.96samples/sec
step 2080, epoch: 37.8182, train loss: 2.7656, grad_norm: 0.50, 1005.18ms/step, 891.39samples/sec
step 2090, epoch: 38.0000, train loss: 2.7344, grad_norm: 0.51, 1048.32ms/step, 854.70samples/sec
step 2100, epoch: 38.1818, train loss: 2.7188, grad_norm: 0.48, 1107.74m

eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.46s/it]


step 2200, eval loss: 2.8578, clipscore: 23.05
step 2210, epoch: 40.1818, train loss: 2.6406, grad_norm: 0.46, 9790.59ms/step, 91.52samples/sec
step 2220, epoch: 40.3636, train loss: 2.6875, grad_norm: 0.47, 1028.89ms/step, 870.84samples/sec
step 2230, epoch: 40.5455, train loss: 2.7500, grad_norm: 0.54, 1064.35ms/step, 841.83samples/sec
step 2240, epoch: 40.7273, train loss: 2.7188, grad_norm: 0.46, 1004.98ms/step, 891.56samples/sec
step 2250, epoch: 40.9091, train loss: 2.7656, grad_norm: 0.52, 1029.50ms/step, 870.33samples/sec
step 2260, epoch: 41.0909, train loss: 2.7500, grad_norm: 0.47, 1066.33ms/step, 840.27samples/sec
step 2270, epoch: 41.2727, train loss: 2.6875, grad_norm: 0.51, 1030.22ms/step, 869.71samples/sec
step 2280, epoch: 41.4545, train loss: 2.7344, grad_norm: 0.50, 1034.17ms/step, 866.39samples/sec
step 2290, epoch: 41.6364, train loss: 2.7031, grad_norm: 0.46, 1030.76ms/step, 869.26samples/sec
step 2300, epoch: 41.8182, train loss: 2.7344, grad_norm: 0.49, 1011.13m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 20.19it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.38s/it]


step 2400, eval loss: 2.9969, clipscore: 22.78
step 2410, epoch: 43.8182, train loss: 2.7656, grad_norm: 0.55, 9705.35ms/step, 92.32samples/sec
step 2420, epoch: 44.0000, train loss: 2.7812, grad_norm: 0.45, 1025.48ms/step, 873.74samples/sec
step 2430, epoch: 44.1818, train loss: 2.7500, grad_norm: 0.47, 1061.21ms/step, 844.32samples/sec
step 2440, epoch: 44.3636, train loss: 2.7344, grad_norm: 0.46, 1029.53ms/step, 870.30samples/sec
step 2450, epoch: 44.5455, train loss: 2.7031, grad_norm: 0.48, 1065.88ms/step, 840.62samples/sec
step 2460, epoch: 44.7273, train loss: 2.7656, grad_norm: 0.48, 1029.51ms/step, 870.32samples/sec
step 2470, epoch: 44.9091, train loss: 2.8125, grad_norm: 0.49, 1011.59ms/step, 885.74samples/sec
step 2480, epoch: 45.0909, train loss: 2.7188, grad_norm: 0.50, 1097.00ms/step, 816.78samples/sec
step 2490, epoch: 45.2727, train loss: 2.7031, grad_norm: 0.46, 1062.22ms/step, 843.52samples/sec
step 2500, epoch: 45.4545, train loss: 2.7031, grad_norm: 0.49, 1063.69m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 58.50it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.37s/it]


step 2600, eval loss: 2.8484, clipscore: 23.04
step 2610, epoch: 47.4545, train loss: 2.6562, grad_norm: 0.52, 9740.29ms/step, 91.99samples/sec
step 2620, epoch: 47.6364, train loss: 2.6719, grad_norm: 0.43, 1029.29ms/step, 870.51samples/sec
step 2630, epoch: 47.8182, train loss: 2.6875, grad_norm: 0.49, 1028.64ms/step, 871.05samples/sec
step 2640, epoch: 48.0000, train loss: 2.7344, grad_norm: 0.50, 1041.91ms/step, 859.95samples/sec
step 2650, epoch: 48.1818, train loss: 2.6875, grad_norm: 0.48, 1031.57ms/step, 868.58samples/sec
step 2660, epoch: 48.3636, train loss: 2.7344, grad_norm: 0.52, 1067.61ms/step, 839.26samples/sec
step 2670, epoch: 48.5455, train loss: 2.6875, grad_norm: 0.50, 1032.99ms/step, 867.39samples/sec
step 2680, epoch: 48.7273, train loss: 2.7188, grad_norm: 0.46, 1033.67ms/step, 866.81samples/sec
step 2690, epoch: 48.9091, train loss: 2.6719, grad_norm: 0.50, 1047.20ms/step, 855.62samples/sec
step 2700, epoch: 49.0909, train loss: 2.7188, grad_norm: 0.45, 1034.15m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 60.63it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.36s/it]


step 2800, eval loss: 2.8734, clipscore: 23.66
step 2810, epoch: 51.0909, train loss: 2.5625, grad_norm: 0.48, 9818.20ms/step, 91.26samples/sec
step 2820, epoch: 51.2727, train loss: 2.5938, grad_norm: 0.43, 1032.61ms/step, 867.70samples/sec
step 2830, epoch: 51.4545, train loss: 2.7031, grad_norm: 0.50, 1027.78ms/step, 871.78samples/sec
step 2840, epoch: 51.6364, train loss: 2.6719, grad_norm: 0.49, 1029.94ms/step, 869.95samples/sec
step 2850, epoch: 51.8182, train loss: 2.7969, grad_norm: 0.47, 1028.77ms/step, 870.95samples/sec
step 2940, epoch: 53.4545, train loss: 2.6406, grad_norm: 0.54, 1063.02ms/step, 842.88samples/sec
step 2950, epoch: 53.6364, train loss: 2.6250, grad_norm: 0.48, 1033.15ms/step, 867.25samples/sec
step 2960, epoch: 53.8182, train loss: 2.6719, grad_norm: 0.48, 1064.79ms/step, 841.48samples/sec
step 2970, epoch: 54.0000, train loss: 2.5938, grad_norm: 0.49, 1012.10ms/step, 885.28samples/sec
step 2980, epoch: 54.1818, train loss: 2.6562, grad_norm: 0.52, 1033.93m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 60.54it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.44s/it]


step 3000, eval loss: 3.0500, clipscore: 22.93
step 3010, epoch: 54.7273, train loss: 2.5625, grad_norm: 0.47, 9807.15ms/step, 91.36samples/sec
step 3020, epoch: 54.9091, train loss: 2.6094, grad_norm: 0.50, 1062.65ms/step, 843.17samples/sec
step 3030, epoch: 55.0909, train loss: 2.6719, grad_norm: 0.55, 1010.06ms/step, 887.07samples/sec
step 3040, epoch: 55.2727, train loss: 2.6562, grad_norm: 0.49, 1063.04ms/step, 842.87samples/sec
step 3050, epoch: 55.4545, train loss: 2.6875, grad_norm: 0.50, 1032.44ms/step, 867.85samples/sec
step 3060, epoch: 55.6364, train loss: 2.6562, grad_norm: 0.52, 1028.82ms/step, 870.90samples/sec
step 3070, epoch: 55.8182, train loss: 2.6250, grad_norm: 0.59, 1065.85ms/step, 840.64samples/sec
step 3080, epoch: 56.0000, train loss: 2.6562, grad_norm: 0.53, 1012.43ms/step, 885.00samples/sec
step 3090, epoch: 56.1818, train loss: 2.6094, grad_norm: 0.55, 1067.78ms/step, 839.12samples/sec
step 3100, epoch: 56.3636, train loss: 2.6719, grad_norm: 0.51, 1032.61m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.69it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.42s/it]


step 3200, eval loss: 2.8891, clipscore: 23.25
step 3210, epoch: 58.3636, train loss: 2.5156, grad_norm: 0.62, 9759.93ms/step, 91.80samples/sec
step 3220, epoch: 58.5455, train loss: 2.5469, grad_norm: 0.61, 1059.53ms/step, 845.66samples/sec
step 3230, epoch: 58.7273, train loss: 2.6719, grad_norm: 0.54, 1029.54ms/step, 870.29samples/sec
step 3240, epoch: 58.9091, train loss: 2.5781, grad_norm: 0.51, 1059.94ms/step, 845.33samples/sec
step 3250, epoch: 59.0909, train loss: 2.6406, grad_norm: 0.55, 1006.90ms/step, 889.86samples/sec
step 3260, epoch: 59.2727, train loss: 2.6250, grad_norm: 0.59, 1029.35ms/step, 870.45samples/sec
step 3270, epoch: 59.4545, train loss: 2.5781, grad_norm: 0.60, 1063.36ms/step, 842.61samples/sec
step 3280, epoch: 59.6364, train loss: 2.6250, grad_norm: 0.48, 1029.71ms/step, 870.15samples/sec
step 3290, epoch: 59.8182, train loss: 2.6250, grad_norm: 0.54, 1027.51ms/step, 872.01samples/sec
step 3300, epoch: 60.0000, train loss: 2.6719, grad_norm: 0.50, 1062.59m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 60.33it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.39s/it]


step 3400, eval loss: 3.0734, clipscore: 23.19
step 3410, epoch: 62.0000, train loss: 2.5156, grad_norm: 0.54, 9723.10ms/step, 92.15samples/sec
step 3420, epoch: 62.1818, train loss: 2.6406, grad_norm: 0.50, 1079.78ms/step, 829.80samples/sec
step 3430, epoch: 62.3636, train loss: 2.5781, grad_norm: 0.58, 1057.96ms/step, 846.91samples/sec
step 3440, epoch: 62.5455, train loss: 2.6094, grad_norm: 0.57, 1058.95ms/step, 846.12samples/sec
step 3450, epoch: 62.7273, train loss: 2.6094, grad_norm: 0.54, 1101.51ms/step, 813.43samples/sec
step 3460, epoch: 62.9091, train loss: 2.5938, grad_norm: 0.52, 1058.16ms/step, 846.75samples/sec
step 3470, epoch: 63.0909, train loss: 2.5625, grad_norm: 0.59, 1101.50ms/step, 813.44samples/sec
step 3480, epoch: 63.2727, train loss: 2.5938, grad_norm: 0.60, 1038.38ms/step, 862.89samples/sec
step 3490, epoch: 63.4545, train loss: 2.5625, grad_norm: 0.54, 1060.13ms/step, 845.18samples/sec
step 3500, epoch: 63.6364, train loss: 2.5625, grad_norm: 0.61, 1104.24m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 58.16it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:21<00:00,  8.19s/it]


step 3600, eval loss: 2.9516, clipscore: 23.24
step 3610, epoch: 65.6364, train loss: 2.4688, grad_norm: 0.55, 9498.78ms/step, 94.33samples/sec
step 3620, epoch: 65.8182, train loss: 2.5156, grad_norm: 0.54, 1026.78ms/step, 872.63samples/sec
step 3630, epoch: 66.0000, train loss: 2.5938, grad_norm: 0.60, 1063.12ms/step, 842.80samples/sec
step 3640, epoch: 66.1818, train loss: 2.5625, grad_norm: 0.57, 1006.46ms/step, 890.25samples/sec
step 3650, epoch: 66.3636, train loss: 2.5938, grad_norm: 0.63, 1063.71ms/step, 842.34samples/sec
step 3660, epoch: 66.5455, train loss: 2.5625, grad_norm: 0.61, 1027.87ms/step, 871.71samples/sec
step 3670, epoch: 66.7273, train loss: 2.5000, grad_norm: 0.53, 1026.29ms/step, 873.05samples/sec
step 3680, epoch: 66.9091, train loss: 2.5781, grad_norm: 0.62, 1063.67ms/step, 842.37samples/sec
step 3690, epoch: 67.0909, train loss: 2.5469, grad_norm: 0.54, 1028.56ms/step, 871.12samples/sec
step 3700, epoch: 67.2727, train loss: 2.5469, grad_norm: 0.56, 1044.00m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 61.88it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.44s/it]


step 3800, eval loss: 3.1250, clipscore: 22.48
step 3810, epoch: 69.2727, train loss: 2.6094, grad_norm: 0.59, 9823.44ms/step, 91.21samples/sec
step 3820, epoch: 69.4545, train loss: 2.6250, grad_norm: 0.58, 1059.91ms/step, 845.35samples/sec
step 3830, epoch: 69.6364, train loss: 2.5781, grad_norm: 0.57, 1081.61ms/step, 828.40samples/sec
step 3840, epoch: 69.8182, train loss: 2.5625, grad_norm: 0.59, 1025.92ms/step, 873.36samples/sec
step 3850, epoch: 70.0000, train loss: 2.5625, grad_norm: 0.62, 1028.02ms/step, 871.58samples/sec
step 3860, epoch: 70.1818, train loss: 2.5781, grad_norm: 0.56, 1058.75ms/step, 846.28samples/sec
step 3870, epoch: 70.3636, train loss: 2.6406, grad_norm: 0.54, 1024.24ms/step, 874.80samples/sec
step 3880, epoch: 70.5455, train loss: 2.5625, grad_norm: 0.57, 1064.62ms/step, 841.61samples/sec
step 3890, epoch: 70.7273, train loss: 2.5469, grad_norm: 0.53, 1030.97ms/step, 869.09samples/sec
step 3900, epoch: 70.9091, train loss: 2.5312, grad_norm: 0.54, 1029.12m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.62it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.38s/it]


step 4000, eval loss: 2.9547, clipscore: 23.23
step 4010, epoch: 72.9091, train loss: 2.4688, grad_norm: 0.53, 9736.25ms/step, 92.03samples/sec
step 4020, epoch: 73.0909, train loss: 2.5156, grad_norm: 0.55, 1039.46ms/step, 861.99samples/sec
step 4030, epoch: 73.2727, train loss: 2.5469, grad_norm: 0.54, 1110.11ms/step, 807.13samples/sec
step 4130, epoch: 75.0909, train loss: 2.5469, grad_norm: 0.57, 1029.49ms/step, 870.33samples/sec
step 4140, epoch: 75.2727, train loss: 2.5469, grad_norm: 0.66, 1062.36ms/step, 843.40samples/sec
step 4150, epoch: 75.4545, train loss: 2.4844, grad_norm: 0.56, 1013.82ms/step, 883.79samples/sec
step 4160, epoch: 75.6364, train loss: 2.5000, grad_norm: 0.54, 1085.86ms/step, 825.15samples/sec
step 4170, epoch: 75.8182, train loss: 2.4375, grad_norm: 0.58, 1060.38ms/step, 844.98samples/sec
step 4180, epoch: 76.0000, train loss: 2.5312, grad_norm: 0.67, 1039.16ms/step, 862.24samples/sec
step 4190, epoch: 76.1818, train loss: 2.5781, grad_norm: 0.59, 1061.76m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.00it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.35s/it]


step 4200, eval loss: 2.9719, clipscore: 23.71
step 4210, epoch: 76.5455, train loss: 2.3750, grad_norm: 0.58, 9795.07ms/step, 91.47samples/sec
step 4220, epoch: 76.7273, train loss: 2.4375, grad_norm: 0.52, 1050.62ms/step, 852.83samples/sec
step 4230, epoch: 76.9091, train loss: 2.5312, grad_norm: 0.56, 1053.14ms/step, 850.79samples/sec
step 4240, epoch: 77.0909, train loss: 2.5312, grad_norm: 0.56, 1100.19ms/step, 814.41samples/sec
step 4250, epoch: 77.2727, train loss: 2.6250, grad_norm: 0.57, 1056.43ms/step, 848.14samples/sec
step 4260, epoch: 77.4545, train loss: 2.5156, grad_norm: 0.53, 1082.10ms/step, 828.02samples/sec
step 4270, epoch: 77.6364, train loss: 2.5156, grad_norm: 0.57, 1061.93ms/step, 843.74samples/sec
step 4280, epoch: 77.8182, train loss: 2.5156, grad_norm: 0.54, 1057.90ms/step, 846.96samples/sec
step 4290, epoch: 78.0000, train loss: 2.5312, grad_norm: 0.53, 1066.67ms/step, 840.00samples/sec
step 4300, epoch: 78.1818, train loss: 2.5781, grad_norm: 0.55, 1055.58m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.81it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.32s/it]


step 4400, eval loss: 3.1844, clipscore: 23.54
step 4410, epoch: 80.1818, train loss: 2.3750, grad_norm: 0.59, 9744.73ms/step, 91.95samples/sec
step 4420, epoch: 80.3636, train loss: 2.4062, grad_norm: 0.56, 1087.10ms/step, 824.21samples/sec
step 4430, epoch: 80.5455, train loss: 2.5000, grad_norm: 0.63, 1034.11ms/step, 866.45samples/sec
step 4440, epoch: 80.7273, train loss: 2.4688, grad_norm: 0.58, 1102.99ms/step, 812.34samples/sec
step 4450, epoch: 80.9091, train loss: 2.4688, grad_norm: 0.64, 1055.81ms/step, 848.64samples/sec
step 4460, epoch: 81.0909, train loss: 2.4531, grad_norm: 0.64, 1057.11ms/step, 847.59samples/sec
step 4470, epoch: 81.2727, train loss: 2.4531, grad_norm: 0.60, 1098.51ms/step, 815.65samples/sec
step 4480, epoch: 81.4545, train loss: 2.4688, grad_norm: 0.62, 1035.96ms/step, 864.90samples/sec
step 4490, epoch: 81.6364, train loss: 2.4219, grad_norm: 0.58, 1101.38ms/step, 813.53samples/sec
step 4500, epoch: 81.8182, train loss: 2.4844, grad_norm: 0.58, 1059.56m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 60.00it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:23<00:00,  8.30s/it]


step 4600, eval loss: 3.0375, clipscore: 23.41
step 4610, epoch: 83.8182, train loss: 2.3281, grad_norm: 0.70, 9686.75ms/step, 92.50samples/sec
step 4620, epoch: 84.0000, train loss: 2.3594, grad_norm: 0.68, 1088.39ms/step, 823.23samples/sec
step 4630, epoch: 84.1818, train loss: 2.4688, grad_norm: 0.62, 1026.74ms/step, 872.67samples/sec
step 4640, epoch: 84.3636, train loss: 2.4062, grad_norm: 0.59, 1061.44ms/step, 844.13samples/sec
step 4650, epoch: 84.5455, train loss: 2.4375, grad_norm: 0.63, 1005.96ms/step, 890.69samples/sec
step 4660, epoch: 84.7273, train loss: 2.4219, grad_norm: 0.63, 1029.76ms/step, 870.11samples/sec
step 4670, epoch: 84.9091, train loss: 2.3750, grad_norm: 0.72, 1061.77ms/step, 843.87samples/sec
step 4680, epoch: 85.0909, train loss: 2.4219, grad_norm: 0.64, 1029.17ms/step, 870.61samples/sec
step 4690, epoch: 85.2727, train loss: 2.4531, grad_norm: 0.61, 1025.60ms/step, 873.64samples/sec
step 4700, epoch: 85.4545, train loss: 2.5000, grad_norm: 0.61, 1060.67m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 60.50it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.44s/it]


step 4800, eval loss: 3.2422, clipscore: 23.51
step 4810, epoch: 87.4545, train loss: 2.3125, grad_norm: 0.66, 9875.58ms/step, 90.73samples/sec
step 4820, epoch: 87.6364, train loss: 2.4531, grad_norm: 0.64, 1037.81ms/step, 863.35samples/sec
step 4830, epoch: 87.8182, train loss: 2.3750, grad_norm: 0.61, 1026.32ms/step, 873.02samples/sec
step 4840, epoch: 88.0000, train loss: 2.3906, grad_norm: 0.65, 1023.93ms/step, 875.06samples/sec
step 4850, epoch: 88.1818, train loss: 2.3750, grad_norm: 0.59, 1058.32ms/step, 846.62samples/sec
step 4860, epoch: 88.3636, train loss: 2.4062, grad_norm: 0.67, 1026.34ms/step, 873.00samples/sec
step 4870, epoch: 88.5455, train loss: 2.3750, grad_norm: 0.60, 1061.94ms/step, 843.73samples/sec
step 4880, epoch: 88.7273, train loss: 2.4062, grad_norm: 0.66, 1004.62ms/step, 891.88samples/sec
step 4890, epoch: 88.9091, train loss: 2.3594, grad_norm: 0.61, 1026.88ms/step, 872.55samples/sec
step 4900, epoch: 89.0909, train loss: 2.3750, grad_norm: 0.62, 1059.80m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.46it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.49s/it]


step 5000, eval loss: 3.0688, clipscore: 23.24
step 5010, epoch: 91.0909, train loss: 2.2969, grad_norm: 0.63, 9876.78ms/step, 90.72samples/sec
step 5020, epoch: 91.2727, train loss: 2.3438, grad_norm: 0.62, 1023.37ms/step, 875.54samples/sec
step 5030, epoch: 91.4545, train loss: 2.4219, grad_norm: 0.68, 1058.35ms/step, 846.60samples/sec
step 5040, epoch: 91.6364, train loss: 2.3750, grad_norm: 0.66, 1006.60ms/step, 890.13samples/sec
step 5050, epoch: 91.8182, train loss: 2.4062, grad_norm: 0.64, 1062.32ms/step, 843.44samples/sec
step 5060, epoch: 92.0000, train loss: 2.3906, grad_norm: 0.60, 1053.26ms/step, 850.69samples/sec
step 5070, epoch: 92.1818, train loss: 2.3281, grad_norm: 0.61, 1060.74ms/step, 844.69samples/sec
step 5080, epoch: 92.3636, train loss: 2.4062, grad_norm: 0.66, 1062.18ms/step, 843.55samples/sec
step 5090, epoch: 92.5455, train loss: 2.3750, grad_norm: 0.63, 1027.46ms/step, 872.06samples/sec
step 5100, epoch: 92.7273, train loss: 2.3438, grad_norm: 0.60, 1040.39m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 61.31it/s]
eval_images: 100%|██████████████████████████████| 10/10 [01:24<00:00,  8.47s/it]


step 5200, eval loss: 3.2828, clipscore: 23.43
step 5210, epoch: 94.7273, train loss: 2.4062, grad_norm: 0.72, 9968.25ms/step, 89.89samples/sec
step 5220, epoch: 94.9091, train loss: 2.4375, grad_norm: 0.66, 1054.36ms/step, 849.80samples/sec
step 5230, epoch: 95.0909, train loss: 2.3906, grad_norm: 0.64, 1100.11ms/step, 814.46samples/sec
step 5240, epoch: 95.2727, train loss: 2.3906, grad_norm: 0.70, 1054.12ms/step, 850.00samples/sec
step 5250, epoch: 95.4545, train loss: 2.3750, grad_norm: 0.64, 1099.27ms/step, 815.09samples/sec
step 5260, epoch: 95.6364, train loss: 2.3750, grad_norm: 0.63, 1057.50ms/step, 847.28samples/sec
step 5270, epoch: 95.8182, train loss: 2.4531, grad_norm: 0.62, 1038.00ms/step, 863.20samples/sec
step 5280, epoch: 96.0000, train loss: 2.3750, grad_norm: 0.62, 1098.24ms/step, 815.85samples/sec
step 5290, epoch: 96.1818, train loss: 2.3594, grad_norm: 0.62, 1059.06ms/step, 846.03samples/sec
step 5300, epoch: 96.3636, train loss: 2.3438, grad_norm: 0.57, 1100.48m

eval_loss: 100%|████████████████████████████████| 10/10 [00:00<00:00, 59.70it/s]
eval_images:  80%|████████████████████████▊      | 8/10 [01:07<00:16,  8.38s/it]

step 5500, epoch: 100.0000, train loss: 2.3281, grad_norm: 0.60, 1026.71ms/step, 872.69samples/sec
step 5510, epoch: 100.1818, train loss: 2.2344, grad_norm: 0.63, 1063.01ms/step, 842.89samples/sec
step 5520, epoch: 100.3636, train loss: 2.3281, grad_norm: 0.66, 1028.23ms/step, 871.40samples/sec
step 5530, epoch: 100.5455, train loss: 2.3438, grad_norm: 0.66, 1026.66ms/step, 872.73samples/sec
step 5540, epoch: 100.7273, train loss: 2.3594, grad_norm: 0.70, 1060.18ms/step, 845.14samples/sec
step 5550, epoch: 100.9091, train loss: 2.3281, grad_norm: 0.63, 1010.83ms/step, 886.40samples/sec
step 5560, epoch: 101.0909, train loss: 2.3125, grad_norm: 0.64, 1062.99ms/step, 842.91samples/sec
step 5570, epoch: 101.2727, train loss: 2.2500, grad_norm: 0.62, 1028.57ms/step, 871.11samples/sec
step 5580, epoch: 101.4545, train loss: 2.3281, grad_norm: 0.67, 1026.52ms/step, 872.85samples/sec
step 5590, epoch: 101.6364, train loss: 2.3906, grad_norm: 0.65, 1061.20ms/step, 844.33samples/sec


In [14]:
transformer.push_to_hub(f"g-ronimo/hana-alpha11_cifar10_TS-{timesteps_training}_{epochs}e")

diffusion_pytorch_model.safetensors:   0%|          | 0.00/313M [00:00<?, ?B/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [15]:
# !runpodctl remove pod $RUNPOD_POD_ID