In [1]:
# !pip install transformers accelerate datasets git+https://github.com/huggingface/diffusers Pillow==9.4.0 torchmetrics wandb

In [2]:
from local_secrets import hf_token, wandb_key
from huggingface_hub import login
import wandb

login(token=hf_token)
wandb.login(key=wandb_key)

[34m[1mwandb[0m: Currently logged in as: [33mg-ronimo[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
import torch, torch.nn.functional as F, random, wandb, time
import torchvision.transforms as T
from diffusers import AutoencoderDC, SanaTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import AutoModel, AutoTokenizer, set_seed
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

from utils import PIL_to_latent, latent_to_PIL, make_grid, encode_prompt, dcae_scalingf, pil_clipscore

seed = 42
set_seed(seed)

In [4]:
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

transformer = SanaTransformer2DModel.from_config("transformer_Sana-7L-MBERT_config.json").to(device).to(dtype)
text_encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base", torch_dtype=dtype)

model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model, subfolder="scheduler")

  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


# Load dataset

In [5]:
from utils import fmnist_labels

In [6]:
ds = load_dataset("g-ronimo/FMNIST-latents-64_dc-ae-f32c32-sana-1.0")
labels = fmnist_labels
labels_encoded={k: encode_prompt(labels[k], tokenizer, text_encoder) for k in labels}

len(labels_encoded), len(labels_encoded[0]), labels_encoded[0][0].shape, labels_encoded[0][1].shape

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


(10, 2, torch.Size([1, 300, 768]), torch.Size([1, 300]))

In [7]:
ds.keys()

dict_keys(['train', 'test'])

In [8]:
from torch.utils.data import DataLoader

def collate(items):
    labels = [i["label"] for i in items]
    latents = torch.cat([torch.Tensor(i["latent"]) for i in items]).to(dtype).to(device)
    prompts_encoded = torch.cat([labels_encoded[label][0] for label in labels])
    prompts_atnmask = torch.cat([labels_encoded[label][1] for label in labels])

    return labels, latents, prompts_encoded, prompts_atnmask

dataloader = DataLoader(ds["train"], batch_size=2, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
labels, latents, prompts_encoded, prompts_atnmask = next(iter(dataloader))
len(labels), latents.mean(), latents.shape, prompts_encoded.shape, prompts_atnmask.shape

(2,
 tensor(0.1982, device='cuda:0', dtype=torch.bfloat16),
 torch.Size([2, 32, 2, 2]),
 torch.Size([2, 300, 768]),
 torch.Size([2, 300]))

# Helpers for eval and generate

In [9]:
def generate(prompt, num_timesteps=10, latent_dim=[1, 32, 2, 2], latent_seed=42):
    scheduler.set_timesteps(num_timesteps)
    prompt_encoded, prompt_atnmask = encode_prompt(prompt, tokenizer, text_encoder)
    latents = torch.randn(latent_dim, generator = torch.manual_seed(latent_seed)).to(dtype).to(device)

    for t_idx in range(num_timesteps):
        t = scheduler.timesteps[t_idx].unsqueeze(0).to(device)
        with torch.no_grad():
            noise_pred = transformer(latents, encoder_hidden_states=prompt_encoded, timestep=t, encoder_attention_mask=prompt_atnmask, return_dict=False)[0]
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latent_to_PIL(latents / dcae_scalingf, dcae)

[generate("0")]

[<PIL.Image.Image image mode=RGB size=64x64>]

In [10]:
def eval_loss(data_val, num_samples=10, num_timesteps=10, batch_size=24):
    losses = []
    eval_dataloader = iter(DataLoader(data_val, batch_size=batch_size, shuffle=False, collate_fn=collate))

    for i in tqdm(range(num_samples), "eval_loss"):
        label, latent, prompt_encoded, prompt_atnmask = next(eval_dataloader)
        noise = torch.randn_like(latent)
        timestep = scheduler.timesteps[[random.randint(0, num_timesteps-1) for _ in range(batch_size)]].to(device)
        latent_noisy = scheduler.scale_noise(latent, timestep, noise)
        with torch.no_grad():
            noise_pred = transformer(latent_noisy, encoder_hidden_states = prompt_encoded, encoder_attention_mask = prompt_atnmask, timestep = timestep, return_dict=False)[0]
        loss = F.mse_loss(noise_pred, noise - latent)
        losses.append(loss.item())
    return sum(losses)/len(losses)

eval_loss(ds["train"])

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 89.53it/s]


11.70625

In [11]:
def eval_clipscore(seeds=[1,7,42]):
    prompts = [fmnist_labels[k] for k in fmnist_labels]
    images = [generate(p, latent_seed=seed, num_timesteps=diffuser_timesteps) for seed in tqdm(seeds, "eval_clipscore") for p in prompts]
    return pil_clipscore(images, prompts*len(seeds))

eval_clipscore()

eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.30it/s]
  return self.fget.__get__(instance, owner)()


21.874094009399414

# Train

In [12]:
log_wandb = True
lr = 5e-4
# bs = 128
bs = 896
epochs = 400
diffuser_timesteps = 20
steps_log, steps_eval = 20, 200
# steps_log, steps_eval = 2, 10

splits = list(ds.keys())
data_train, data_val = ds[splits[0]], ds[splits[1]]

steps_epoch = len(data_train) // bs

dataloader = DataLoader(data_train, batch_size=bs, shuffle=True, generator = torch.manual_seed(seed), collate_fn=collate)
optimizer = torch.optim.AdamW(transformer.parameters(), lr=lr)
scheduler.set_timesteps(diffuser_timesteps)

model_size = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"Model parameters: {model_size / 1e6:.2f}M")
print(f"{len(splits)} splits: {splits}", [len(ds[s]) for s in splits])
assert len(splits)==2 

Model parameters: 156.41M
2 splits: ['train', 'test'] [60000, 10000]


In [None]:
if log_wandb: 
    if wandb.run is not None: wandb.finish()
    wandb.init(project="Hana", name=f"Z-{model_size / 1e6:.2f}M_FMNIST_LR-{lr}_BS-{bs}_TS-{diffuser_timesteps}_runpod4090").log_code(".", include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb") or path.endswith(".json"))

step = 0
last_step_time = time.time()

for _ in range(epochs):
    for labels, latents, prompts_encoded, prompts_atnmask in dataloader:        
        noise = torch.randn_like(latents)
        timesteps = scheduler.timesteps[torch.randint(diffuser_timesteps,(latents.shape[0],))].to(device)
        latents_noisy = scheduler.scale_noise(latents, timesteps, noise)        
        noise_pred = transformer(latents_noisy, prompts_encoded, timesteps, prompts_atnmask).sample
    
        loss = F.mse_loss(noise_pred, noise - latents)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(transformer.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()        
        
        if step > 0 and step % steps_log == 0:
            loss_train = loss.item()
            step_time = (time.time() - last_step_time) / steps_log * 1000
            sample_tp = bs * steps_log / (time.time() - last_step_time)
            print(f"step {step}, epoch: {step / steps_epoch:.4f}, train loss: {loss_train:.4f}, grad_norm: {grad_norm:.2f}, {step_time:.2f}ms/step, {sample_tp:.2f}samples/sec")
            if log_wandb: wandb.log({"loss_train": loss_train, "grad_norm": grad_norm, "step_time": step_time, "step": step, "sample_tp": sample_tp, "sample_count": step * bs, "epoch": step / steps_epoch})
            last_step_time = time.time()
    
        if step % steps_eval == 0:
            transformer.eval()
            loss_eval, clipscore, images_eval = eval_loss(data_val, num_timesteps=diffuser_timesteps), eval_clipscore(), make_grid([generate(p, num_timesteps=diffuser_timesteps) for p in tqdm([fmnist_labels[k] for k in fmnist_labels], "eval_images")], 2, 5)
            print(f"step {step}, eval loss: {loss_eval:.4f}, clipscore: {clipscore:.2f}")
            if log_wandb: wandb.log({"loss_eval": loss_eval, "clipscore": clipscore, "images_eval": wandb.Image(images_eval), "step": step, "sample_count": step * bs, "epoch": step / steps_epoch})
            else: display(images_eval.resize((300,150)))
            transformer.train()
        step += 1

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


step 0, epoch: 0.0000, train loss: 11.5625, grad_norm: 23.75, 40.99ms/step, 21859.18samples/sec


eval_loss: 100%|██████████| 10/10 [00:00<00:00, 92.88it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.22it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.04it/s]


step 0, eval loss: 10.9812, clipscore: 21.85
step 20, epoch: 0.3030, train loss: 3.4531, grad_norm: 2.75, 952.31ms/step, 940.87samples/sec
step 40, epoch: 0.6061, train loss: 2.6406, grad_norm: 1.59, 535.29ms/step, 1673.87samples/sec
step 60, epoch: 0.9091, train loss: 2.3438, grad_norm: 1.20, 528.65ms/step, 1694.89samples/sec
step 80, epoch: 1.2121, train loss: 2.1562, grad_norm: 1.43, 537.35ms/step, 1667.46samples/sec
step 100, epoch: 1.5152, train loss: 2.0469, grad_norm: 1.20, 536.75ms/step, 1669.31samples/sec
step 120, epoch: 1.8182, train loss: 1.9766, grad_norm: 1.24, 527.02ms/step, 1700.11samples/sec
step 140, epoch: 2.1212, train loss: 1.9531, grad_norm: 0.95, 546.80ms/step, 1638.63samples/sec
step 160, epoch: 2.4242, train loss: 1.9062, grad_norm: 1.10, 530.18ms/step, 1689.98samples/sec
step 180, epoch: 2.7273, train loss: 1.9453, grad_norm: 0.96, 538.43ms/step, 1664.10samples/sec
step 200, epoch: 3.0303, train loss: 1.8516, grad_norm: 0.95, 537.31ms/step, 1667.56samples/sec


eval_loss: 100%|██████████| 10/10 [00:00<00:00, 90.63it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.24it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.31it/s]


step 200, eval loss: 1.7047, clipscore: 23.23
step 220, epoch: 3.3333, train loss: 1.9062, grad_norm: 0.90, 920.18ms/step, 973.73samples/sec
step 240, epoch: 3.6364, train loss: 1.8906, grad_norm: 0.79, 537.61ms/step, 1666.64samples/sec
step 260, epoch: 3.9394, train loss: 1.8203, grad_norm: 0.71, 528.16ms/step, 1696.46samples/sec
step 280, epoch: 4.2424, train loss: 1.8203, grad_norm: 0.75, 537.85ms/step, 1665.89samples/sec
step 300, epoch: 4.5455, train loss: 1.8359, grad_norm: 0.87, 536.95ms/step, 1668.68samples/sec
step 320, epoch: 4.8485, train loss: 1.8125, grad_norm: 0.68, 536.81ms/step, 1669.13samples/sec
step 340, epoch: 5.1515, train loss: 1.8047, grad_norm: 0.68, 527.50ms/step, 1698.59samples/sec
step 360, epoch: 5.4545, train loss: 1.8438, grad_norm: 0.66, 536.22ms/step, 1670.95samples/sec
step 380, epoch: 5.7576, train loss: 1.7891, grad_norm: 0.70, 527.07ms/step, 1699.96samples/sec
step 400, epoch: 6.0606, train loss: 1.8047, grad_norm: 0.67, 535.92ms/step, 1671.88samples

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 98.04it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.79it/s]


step 400, eval loss: 1.6852, clipscore: 23.16
step 420, epoch: 6.3636, train loss: 1.8516, grad_norm: 0.76, 920.56ms/step, 973.32samples/sec
step 440, epoch: 6.6667, train loss: 1.7578, grad_norm: 0.70, 535.95ms/step, 1671.81samples/sec
step 460, epoch: 6.9697, train loss: 1.8047, grad_norm: 0.66, 526.84ms/step, 1700.70samples/sec
step 480, epoch: 7.2727, train loss: 1.7344, grad_norm: 0.62, 535.03ms/step, 1674.68samples/sec
step 500, epoch: 7.5758, train loss: 1.7969, grad_norm: 0.61, 527.34ms/step, 1699.10samples/sec
step 520, epoch: 7.8788, train loss: 1.8047, grad_norm: 0.66, 536.20ms/step, 1671.03samples/sec
step 540, epoch: 8.1818, train loss: 1.8047, grad_norm: 0.61, 544.54ms/step, 1645.43samples/sec
step 560, epoch: 8.4848, train loss: 1.7812, grad_norm: 0.57, 527.26ms/step, 1699.36samples/sec
step 580, epoch: 8.7879, train loss: 1.7578, grad_norm: 0.59, 535.71ms/step, 1672.54samples/sec
step 600, epoch: 9.0909, train loss: 1.7422, grad_norm: 0.63, 535.91ms/step, 1671.91samples

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 95.47it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.76it/s]


step 600, eval loss: 1.5914, clipscore: 22.96
step 620, epoch: 9.3939, train loss: 1.7656, grad_norm: 0.70, 916.64ms/step, 977.48samples/sec
step 640, epoch: 9.6970, train loss: 1.7188, grad_norm: 0.62, 544.93ms/step, 1644.25samples/sec
step 660, epoch: 10.0000, train loss: 1.7578, grad_norm: 0.59, 526.89ms/step, 1700.54samples/sec
step 680, epoch: 10.3030, train loss: 1.7812, grad_norm: 0.62, 545.87ms/step, 1641.41samples/sec
step 700, epoch: 10.6061, train loss: 1.7266, grad_norm: 0.62, 535.81ms/step, 1672.23samples/sec
step 720, epoch: 10.9091, train loss: 1.7344, grad_norm: 0.60, 527.05ms/step, 1700.03samples/sec
step 740, epoch: 11.2121, train loss: 1.7266, grad_norm: 0.58, 527.63ms/step, 1698.16samples/sec
step 760, epoch: 11.5152, train loss: 1.7656, grad_norm: 0.58, 536.39ms/step, 1670.41samples/sec
step 780, epoch: 11.8182, train loss: 1.7266, grad_norm: 0.55, 527.35ms/step, 1699.05samples/sec
step 800, epoch: 12.1212, train loss: 1.7734, grad_norm: 0.59, 536.23ms/step, 1670.9

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.06it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.76it/s]


step 800, eval loss: 1.6617, clipscore: 23.43
step 820, epoch: 12.4242, train loss: 1.7188, grad_norm: 0.52, 926.33ms/step, 967.26samples/sec
step 840, epoch: 12.7273, train loss: 1.6797, grad_norm: 0.56, 526.95ms/step, 1700.35samples/sec
step 860, epoch: 13.0303, train loss: 1.8125, grad_norm: 0.54, 527.08ms/step, 1699.94samples/sec
step 880, epoch: 13.3333, train loss: 1.7344, grad_norm: 0.58, 546.12ms/step, 1640.66samples/sec
step 900, epoch: 13.6364, train loss: 1.6719, grad_norm: 0.54, 527.18ms/step, 1699.60samples/sec
step 920, epoch: 13.9394, train loss: 1.7266, grad_norm: 0.55, 536.23ms/step, 1670.92samples/sec
step 940, epoch: 14.2424, train loss: 1.7812, grad_norm: 0.51, 537.70ms/step, 1666.35samples/sec
step 960, epoch: 14.5455, train loss: 1.7500, grad_norm: 0.53, 535.66ms/step, 1672.69samples/sec
step 980, epoch: 14.8485, train loss: 1.7188, grad_norm: 0.57, 527.12ms/step, 1699.80samples/sec
step 1000, epoch: 15.1515, train loss: 1.6953, grad_norm: 0.50, 527.26ms/step, 169

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.12it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.69it/s]


step 1000, eval loss: 1.5945, clipscore: 23.38
step 1020, epoch: 15.4545, train loss: 1.7500, grad_norm: 0.54, 915.46ms/step, 978.75samples/sec
step 1040, epoch: 15.7576, train loss: 1.6484, grad_norm: 0.47, 526.67ms/step, 1701.24samples/sec
step 1060, epoch: 16.0606, train loss: 1.7422, grad_norm: 0.54, 527.22ms/step, 1699.49samples/sec
step 1080, epoch: 16.3636, train loss: 1.6562, grad_norm: 0.52, 537.23ms/step, 1667.81samples/sec
step 1100, epoch: 16.6667, train loss: 1.7188, grad_norm: 0.49, 527.20ms/step, 1699.55samples/sec
step 1120, epoch: 16.9697, train loss: 1.7266, grad_norm: 0.51, 536.49ms/step, 1670.13samples/sec
step 1140, epoch: 17.2727, train loss: 1.7344, grad_norm: 0.52, 527.84ms/step, 1697.48samples/sec
step 1160, epoch: 17.5758, train loss: 1.7109, grad_norm: 0.50, 545.22ms/step, 1643.39samples/sec
step 1180, epoch: 17.8788, train loss: 1.7188, grad_norm: 0.54, 536.70ms/step, 1669.48samples/sec
step 1200, epoch: 18.1818, train loss: 1.7031, grad_norm: 0.58, 536.59ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.89it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.76it/s]


step 1200, eval loss: 1.6789, clipscore: 23.55
step 1220, epoch: 18.4848, train loss: 1.7031, grad_norm: 0.58, 912.32ms/step, 982.11samples/sec
step 1240, epoch: 18.7879, train loss: 1.6406, grad_norm: 0.50, 535.76ms/step, 1672.38samples/sec
step 1260, epoch: 19.0909, train loss: 1.7188, grad_norm: 0.51, 535.87ms/step, 1672.05samples/sec
step 1280, epoch: 19.3939, train loss: 1.7188, grad_norm: 0.44, 527.47ms/step, 1698.69samples/sec
step 1300, epoch: 19.6970, train loss: 1.6719, grad_norm: 0.54, 536.38ms/step, 1670.44samples/sec
step 1320, epoch: 20.0000, train loss: 1.7266, grad_norm: 0.54, 527.31ms/step, 1699.20samples/sec
step 1340, epoch: 20.3030, train loss: 1.7188, grad_norm: 0.51, 537.24ms/step, 1667.80samples/sec
step 1360, epoch: 20.6061, train loss: 1.7188, grad_norm: 0.51, 527.23ms/step, 1699.46samples/sec
step 1380, epoch: 20.9091, train loss: 1.6875, grad_norm: 0.52, 536.31ms/step, 1670.66samples/sec
step 1400, epoch: 21.2121, train loss: 1.6875, grad_norm: 0.52, 536.83ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.86it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.28it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.40it/s]


step 1400, eval loss: 1.4891, clipscore: 23.71
step 1420, epoch: 21.5152, train loss: 1.6641, grad_norm: 0.51, 918.30ms/step, 975.72samples/sec
step 1440, epoch: 21.8182, train loss: 1.6172, grad_norm: 0.46, 536.09ms/step, 1671.37samples/sec
step 1460, epoch: 22.1212, train loss: 1.6797, grad_norm: 0.58, 536.84ms/step, 1669.03samples/sec
step 1480, epoch: 22.4242, train loss: 1.6719, grad_norm: 0.53, 527.38ms/step, 1698.96samples/sec
step 1640, epoch: 24.8485, train loss: 1.6328, grad_norm: 0.47, 535.94ms/step, 1671.82samples/sec
step 1660, epoch: 25.1515, train loss: 1.6484, grad_norm: 0.55, 527.19ms/step, 1699.58samples/sec
step 1680, epoch: 25.4545, train loss: 1.6797, grad_norm: 0.44, 544.54ms/step, 1645.44samples/sec
step 1700, epoch: 25.7576, train loss: 1.6484, grad_norm: 0.48, 526.82ms/step, 1700.76samples/sec
step 1720, epoch: 26.0606, train loss: 1.6953, grad_norm: 0.47, 527.19ms/step, 1699.58samples/sec
step 1740, epoch: 26.3636, train loss: 1.6719, grad_norm: 0.55, 536.38ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.40it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.75it/s]


step 1800, eval loss: 1.5711, clipscore: 23.35
step 1820, epoch: 27.5758, train loss: 1.6797, grad_norm: 0.48, 915.90ms/step, 978.28samples/sec
step 1840, epoch: 27.8788, train loss: 1.6328, grad_norm: 0.50, 526.70ms/step, 1701.17samples/sec
step 1860, epoch: 28.1818, train loss: 1.6797, grad_norm: 0.48, 535.40ms/step, 1673.53samples/sec
step 1880, epoch: 28.4848, train loss: 1.6641, grad_norm: 0.49, 536.86ms/step, 1668.97samples/sec
step 1900, epoch: 28.7879, train loss: 1.6562, grad_norm: 0.44, 527.21ms/step, 1699.50samples/sec
step 1920, epoch: 29.0909, train loss: 1.6406, grad_norm: 0.43, 536.20ms/step, 1671.02samples/sec
step 1940, epoch: 29.3939, train loss: 1.6875, grad_norm: 0.47, 527.00ms/step, 1700.18samples/sec
step 1960, epoch: 29.6970, train loss: 1.6406, grad_norm: 0.45, 544.10ms/step, 1646.76samples/sec
step 1980, epoch: 30.0000, train loss: 1.6484, grad_norm: 0.45, 536.48ms/step, 1670.14samples/sec
step 2000, epoch: 30.3030, train loss: 1.6719, grad_norm: 0.43, 536.60ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.66it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.30it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.49it/s]


step 2000, eval loss: 1.6031, clipscore: 23.45
step 2020, epoch: 30.6061, train loss: 1.6562, grad_norm: 0.48, 908.76ms/step, 985.96samples/sec
step 2040, epoch: 30.9091, train loss: 1.6406, grad_norm: 0.41, 544.55ms/step, 1645.39samples/sec
step 2060, epoch: 31.2121, train loss: 1.6562, grad_norm: 0.46, 527.18ms/step, 1699.60samples/sec
step 2080, epoch: 31.5152, train loss: 1.6016, grad_norm: 0.50, 535.37ms/step, 1673.62samples/sec
step 2100, epoch: 31.8182, train loss: 1.5938, grad_norm: 0.42, 535.71ms/step, 1672.56samples/sec
step 2120, epoch: 32.1212, train loss: 1.6641, grad_norm: 0.46, 527.17ms/step, 1699.65samples/sec
step 2140, epoch: 32.4242, train loss: 1.6875, grad_norm: 0.47, 527.18ms/step, 1699.61samples/sec
step 2160, epoch: 32.7273, train loss: 1.6172, grad_norm: 0.44, 535.62ms/step, 1672.83samples/sec
step 2180, epoch: 33.0303, train loss: 1.6875, grad_norm: 0.46, 526.63ms/step, 1701.39samples/sec
step 2200, epoch: 33.3333, train loss: 1.6875, grad_norm: 0.48, 535.70ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 96.99it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.76it/s]


step 2200, eval loss: 1.5406, clipscore: 23.09
step 2220, epoch: 33.6364, train loss: 1.6797, grad_norm: 0.47, 914.32ms/step, 979.96samples/sec
step 2240, epoch: 33.9394, train loss: 1.6328, grad_norm: 0.43, 535.35ms/step, 1673.67samples/sec
step 2260, epoch: 34.2424, train loss: 1.6797, grad_norm: 0.46, 526.87ms/step, 1700.60samples/sec
step 2280, epoch: 34.5455, train loss: 1.6562, grad_norm: 0.48, 526.20ms/step, 1702.76samples/sec
step 2300, epoch: 34.8485, train loss: 1.6328, grad_norm: 0.46, 536.17ms/step, 1671.11samples/sec
step 2320, epoch: 35.1515, train loss: 1.6406, grad_norm: 0.49, 527.15ms/step, 1699.72samples/sec
step 2340, epoch: 35.4545, train loss: 1.6797, grad_norm: 0.43, 536.51ms/step, 1670.05samples/sec
step 2360, epoch: 35.7576, train loss: 1.5938, grad_norm: 0.44, 534.55ms/step, 1676.17samples/sec
step 2380, epoch: 36.0606, train loss: 1.6953, grad_norm: 0.44, 527.00ms/step, 1700.18samples/sec
step 2400, epoch: 36.3636, train loss: 1.7266, grad_norm: 0.48, 535.76ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.60it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.30it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.63it/s]


step 2400, eval loss: 1.6180, clipscore: 23.30
step 2420, epoch: 36.6667, train loss: 1.6172, grad_norm: 0.50, 907.83ms/step, 986.97samples/sec
step 2440, epoch: 36.9697, train loss: 1.6016, grad_norm: 0.42, 526.73ms/step, 1701.07samples/sec
step 2460, epoch: 37.2727, train loss: 1.6406, grad_norm: 0.44, 527.09ms/step, 1699.91samples/sec
step 2480, epoch: 37.5758, train loss: 1.5859, grad_norm: 0.44, 527.93ms/step, 1697.20samples/sec
step 2500, epoch: 37.8788, train loss: 1.5859, grad_norm: 0.42, 544.85ms/step, 1644.49samples/sec
step 2520, epoch: 38.1818, train loss: 1.6172, grad_norm: 0.49, 535.92ms/step, 1671.90samples/sec
step 2540, epoch: 38.4848, train loss: 1.6562, grad_norm: 0.54, 536.83ms/step, 1669.07samples/sec
step 2560, epoch: 38.7879, train loss: 1.6406, grad_norm: 0.45, 527.78ms/step, 1697.66samples/sec
step 2580, epoch: 39.0909, train loss: 1.6562, grad_norm: 0.46, 536.29ms/step, 1670.74samples/sec
step 2600, epoch: 39.3939, train loss: 1.6250, grad_norm: 0.43, 527.02ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.10it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.67it/s]


step 2600, eval loss: 1.6500, clipscore: 23.37
step 2620, epoch: 39.6970, train loss: 1.6328, grad_norm: 0.44, 913.12ms/step, 981.25samples/sec
step 2640, epoch: 40.0000, train loss: 1.5859, grad_norm: 0.46, 536.20ms/step, 1671.02samples/sec
step 2660, epoch: 40.3030, train loss: 1.6406, grad_norm: 0.41, 527.55ms/step, 1698.42samples/sec
step 2680, epoch: 40.6061, train loss: 1.6641, grad_norm: 0.44, 537.49ms/step, 1667.01samples/sec
step 2700, epoch: 40.9091, train loss: 1.6094, grad_norm: 0.45, 536.12ms/step, 1671.26samples/sec
step 2720, epoch: 41.2121, train loss: 1.6172, grad_norm: 0.42, 527.06ms/step, 1700.00samples/sec
step 2740, epoch: 41.5152, train loss: 1.6562, grad_norm: 0.45, 526.90ms/step, 1700.50samples/sec
step 2760, epoch: 41.8182, train loss: 1.6094, grad_norm: 0.40, 545.73ms/step, 1641.83samples/sec
step 2780, epoch: 42.1212, train loss: 1.6484, grad_norm: 0.45, 526.84ms/step, 1700.70samples/sec
step 2800, epoch: 42.4242, train loss: 1.6562, grad_norm: 0.43, 526.62ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.33it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.37it/s]


step 2800, eval loss: 1.6203, clipscore: 23.34
step 2820, epoch: 42.7273, train loss: 1.6328, grad_norm: 0.43, 926.69ms/step, 966.88samples/sec
step 2840, epoch: 43.0303, train loss: 1.5859, grad_norm: 0.49, 544.85ms/step, 1644.48samples/sec
step 2860, epoch: 43.3333, train loss: 1.5938, grad_norm: 0.46, 536.30ms/step, 1670.69samples/sec
step 2880, epoch: 43.6364, train loss: 1.6172, grad_norm: 0.47, 535.85ms/step, 1672.10samples/sec
step 2900, epoch: 43.9394, train loss: 1.5859, grad_norm: 0.44, 526.80ms/step, 1700.83samples/sec
step 2920, epoch: 44.2424, train loss: 1.6172, grad_norm: 0.47, 536.53ms/step, 1669.99samples/sec
step 2940, epoch: 44.5455, train loss: 1.6484, grad_norm: 0.45, 527.05ms/step, 1700.03samples/sec
step 2960, epoch: 44.8485, train loss: 1.6484, grad_norm: 0.44, 544.32ms/step, 1646.08samples/sec
step 2980, epoch: 45.1515, train loss: 1.6328, grad_norm: 0.47, 535.90ms/step, 1671.95samples/sec
step 3000, epoch: 45.4545, train loss: 1.6250, grad_norm: 0.45, 526.92ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.40it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.73it/s]


step 3000, eval loss: 1.5953, clipscore: 23.54
step 3020, epoch: 45.7576, train loss: 1.6250, grad_norm: 0.48, 888.94ms/step, 1007.94samples/sec
step 3040, epoch: 46.0606, train loss: 1.6328, grad_norm: 0.48, 535.99ms/step, 1671.67samples/sec
step 3060, epoch: 46.3636, train loss: 1.6250, grad_norm: 0.45, 536.41ms/step, 1670.35samples/sec
step 3080, epoch: 46.6667, train loss: 1.6094, grad_norm: 0.41, 526.95ms/step, 1700.36samples/sec
step 3100, epoch: 46.9697, train loss: 1.6328, grad_norm: 0.41, 535.21ms/step, 1674.12samples/sec
step 3120, epoch: 47.2727, train loss: 1.6172, grad_norm: 0.46, 526.67ms/step, 1701.25samples/sec
step 3140, epoch: 47.5758, train loss: 1.6328, grad_norm: 0.44, 536.23ms/step, 1670.91samples/sec
step 3160, epoch: 47.8788, train loss: 1.6250, grad_norm: 0.45, 527.61ms/step, 1698.24samples/sec
step 3180, epoch: 48.1818, train loss: 1.6250, grad_norm: 0.43, 536.29ms/step, 1670.75samples/sec
step 3200, epoch: 48.4848, train loss: 1.6094, grad_norm: 0.41, 527.08m

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.33it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.36it/s]


step 3200, eval loss: 1.6766, clipscore: 23.45
step 3220, epoch: 48.7879, train loss: 1.6094, grad_norm: 0.45, 982.15ms/step, 912.28samples/sec
step 3240, epoch: 49.0909, train loss: 1.6016, grad_norm: 0.48, 536.03ms/step, 1671.54samples/sec
step 3260, epoch: 49.3939, train loss: 1.6406, grad_norm: 0.39, 526.91ms/step, 1700.47samples/sec
step 3280, epoch: 49.6970, train loss: 1.6250, grad_norm: 0.44, 526.52ms/step, 1701.73samples/sec
step 3300, epoch: 50.0000, train loss: 1.5625, grad_norm: 0.39, 542.50ms/step, 1651.60samples/sec
step 3320, epoch: 50.3030, train loss: 1.6328, grad_norm: 0.45, 529.40ms/step, 1692.48samples/sec
step 3340, epoch: 50.6061, train loss: 1.6094, grad_norm: 0.40, 537.97ms/step, 1665.51samples/sec
step 3360, epoch: 50.9091, train loss: 1.6484, grad_norm: 0.44, 537.62ms/step, 1666.61samples/sec
step 3380, epoch: 51.2121, train loss: 1.6328, grad_norm: 0.42, 537.28ms/step, 1667.66samples/sec
step 3400, epoch: 51.5152, train loss: 1.6328, grad_norm: 0.39, 536.92ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.43it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.27it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.05it/s]


step 3400, eval loss: 1.6297, clipscore: 23.59
step 3420, epoch: 51.8182, train loss: 1.6250, grad_norm: 0.43, 923.97ms/step, 969.73samples/sec
step 3440, epoch: 52.1212, train loss: 1.5625, grad_norm: 0.45, 527.51ms/step, 1698.55samples/sec
step 3460, epoch: 52.4242, train loss: 1.6250, grad_norm: 0.46, 527.32ms/step, 1699.15samples/sec
step 3480, epoch: 52.7273, train loss: 1.6172, grad_norm: 0.42, 527.39ms/step, 1698.93samples/sec
step 3500, epoch: 53.0303, train loss: 1.5703, grad_norm: 0.41, 544.33ms/step, 1646.06samples/sec
step 3520, epoch: 53.3333, train loss: 1.6172, grad_norm: 0.42, 536.29ms/step, 1670.74samples/sec
step 3540, epoch: 53.6364, train loss: 1.6094, grad_norm: 0.44, 527.19ms/step, 1699.58samples/sec
step 3560, epoch: 53.9394, train loss: 1.6172, grad_norm: 0.46, 534.66ms/step, 1675.82samples/sec
step 3580, epoch: 54.2424, train loss: 1.6406, grad_norm: 0.42, 536.24ms/step, 1670.91samples/sec
step 3600, epoch: 54.5455, train loss: 1.6328, grad_norm: 0.43, 527.39ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.03it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.75it/s]


step 3600, eval loss: 1.5922, clipscore: 23.38
step 3620, epoch: 54.8485, train loss: 1.6094, grad_norm: 0.44, 913.31ms/step, 981.04samples/sec
step 3640, epoch: 55.1515, train loss: 1.5625, grad_norm: 0.47, 529.04ms/step, 1693.64samples/sec
step 3660, epoch: 55.4545, train loss: 1.6172, grad_norm: 0.43, 527.63ms/step, 1698.15samples/sec
step 3680, epoch: 55.7576, train loss: 1.6328, grad_norm: 0.44, 538.09ms/step, 1665.14samples/sec
step 3700, epoch: 56.0606, train loss: 1.6094, grad_norm: 0.44, 539.09ms/step, 1662.05samples/sec
step 3720, epoch: 56.3636, train loss: 1.5859, grad_norm: 0.46, 530.86ms/step, 1687.84samples/sec
step 3740, epoch: 56.6667, train loss: 1.5938, grad_norm: 0.43, 538.89ms/step, 1662.69samples/sec
step 3760, epoch: 56.9697, train loss: 1.6719, grad_norm: 0.42, 548.27ms/step, 1634.24samples/sec
step 3780, epoch: 57.2727, train loss: 1.6016, grad_norm: 0.43, 548.82ms/step, 1632.60samples/sec
step 3800, epoch: 57.5758, train loss: 1.6250, grad_norm: 0.43, 554.47ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 95.35it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.28it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.48it/s]


step 3800, eval loss: 1.5219, clipscore: 23.41
step 3820, epoch: 57.8788, train loss: 1.6406, grad_norm: 0.43, 931.40ms/step, 961.99samples/sec
step 3840, epoch: 58.1818, train loss: 1.5781, grad_norm: 0.44, 539.56ms/step, 1660.60samples/sec
step 3860, epoch: 58.4848, train loss: 1.6016, grad_norm: 0.42, 546.06ms/step, 1640.84samples/sec
step 3880, epoch: 58.7879, train loss: 1.6250, grad_norm: 0.86, 535.27ms/step, 1673.92samples/sec
step 3900, epoch: 59.0909, train loss: 1.5625, grad_norm: 0.46, 534.98ms/step, 1674.84samples/sec
step 3920, epoch: 59.3939, train loss: 1.5938, grad_norm: 0.42, 557.08ms/step, 1608.38samples/sec
step 3940, epoch: 59.6970, train loss: 1.6328, grad_norm: 0.45, 535.71ms/step, 1672.55samples/sec
step 3960, epoch: 60.0000, train loss: 1.5859, grad_norm: 0.46, 543.91ms/step, 1647.33samples/sec
step 3980, epoch: 60.3030, train loss: 1.6016, grad_norm: 0.45, 545.03ms/step, 1643.94samples/sec
step 4000, epoch: 60.6061, train loss: 1.6016, grad_norm: 0.45, 540.44ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 96.83it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.27it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 11.46it/s]


step 4000, eval loss: 1.5938, clipscore: 23.41
step 4020, epoch: 60.9091, train loss: 1.5938, grad_norm: 0.40, 943.05ms/step, 950.11samples/sec
step 4040, epoch: 61.2121, train loss: 1.5547, grad_norm: 0.49, 559.26ms/step, 1602.12samples/sec
step 4060, epoch: 61.5152, train loss: 1.5469, grad_norm: 0.40, 533.22ms/step, 1680.35samples/sec
step 4080, epoch: 61.8182, train loss: 1.6250, grad_norm: 0.42, 541.25ms/step, 1655.43samples/sec
step 4100, epoch: 62.1212, train loss: 1.5312, grad_norm: 0.53, 543.77ms/step, 1647.75samples/sec
step 4120, epoch: 62.4242, train loss: 1.5547, grad_norm: 0.44, 535.44ms/step, 1673.39samples/sec
step 4140, epoch: 62.7273, train loss: 1.6406, grad_norm: 0.40, 540.21ms/step, 1658.61samples/sec
step 4160, epoch: 63.0303, train loss: 1.5859, grad_norm: 0.40, 545.06ms/step, 1643.85samples/sec
step 4180, epoch: 63.3333, train loss: 1.6172, grad_norm: 0.43, 545.67ms/step, 1642.01samples/sec
step 4200, epoch: 63.6364, train loss: 1.6094, grad_norm: 0.38, 537.37ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 95.04it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.24it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 11.86it/s]


step 4200, eval loss: 1.5766, clipscore: 23.40
step 4220, epoch: 63.9394, train loss: 1.6641, grad_norm: 0.42, 958.42ms/step, 934.87samples/sec
step 4240, epoch: 64.2424, train loss: 1.5547, grad_norm: 0.43, 538.69ms/step, 1663.31samples/sec
step 4260, epoch: 64.5455, train loss: 1.5703, grad_norm: 0.43, 528.33ms/step, 1695.90samples/sec
step 4280, epoch: 64.8485, train loss: 1.5859, grad_norm: 0.44, 527.58ms/step, 1698.32samples/sec
step 4300, epoch: 65.1515, train loss: 1.5703, grad_norm: 0.47, 545.99ms/step, 1641.06samples/sec
step 4320, epoch: 65.4545, train loss: 1.5703, grad_norm: 0.44, 537.03ms/step, 1668.45samples/sec
step 4340, epoch: 65.7576, train loss: 1.6172, grad_norm: 0.44, 527.21ms/step, 1699.53samples/sec
step 4360, epoch: 66.0606, train loss: 1.5859, grad_norm: 0.39, 535.08ms/step, 1674.52samples/sec
step 4380, epoch: 66.3636, train loss: 1.5625, grad_norm: 0.39, 536.37ms/step, 1670.48samples/sec
step 4400, epoch: 66.6667, train loss: 1.5938, grad_norm: 0.41, 528.29ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.30it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.28it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.35it/s]


step 4400, eval loss: 1.6203, clipscore: 23.24
step 4420, epoch: 66.9697, train loss: 1.6328, grad_norm: 0.43, 922.07ms/step, 971.72samples/sec
step 4440, epoch: 67.2727, train loss: 1.5469, grad_norm: 0.46, 535.77ms/step, 1672.37samples/sec
step 4460, epoch: 67.5758, train loss: 1.6094, grad_norm: 0.46, 526.74ms/step, 1701.04samples/sec
step 4480, epoch: 67.8788, train loss: 1.6250, grad_norm: 0.40, 537.05ms/step, 1668.37samples/sec
step 4500, epoch: 68.1818, train loss: 1.5781, grad_norm: 0.44, 534.85ms/step, 1675.23samples/sec
step 4520, epoch: 68.4848, train loss: 1.6016, grad_norm: 0.45, 528.97ms/step, 1693.84samples/sec
step 4540, epoch: 68.7879, train loss: 1.6094, grad_norm: 0.42, 535.94ms/step, 1671.82samples/sec
step 4560, epoch: 69.0909, train loss: 1.5781, grad_norm: 0.41, 544.67ms/step, 1645.05samples/sec
step 4580, epoch: 69.3939, train loss: 1.6094, grad_norm: 0.44, 527.42ms/step, 1698.82samples/sec
step 4600, epoch: 69.6970, train loss: 1.5781, grad_norm: 0.43, 536.37ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 98.30it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.03it/s]


step 4600, eval loss: 1.5703, clipscore: 23.37
step 4620, epoch: 70.0000, train loss: 1.6094, grad_norm: 0.47, 914.30ms/step, 979.99samples/sec
step 4640, epoch: 70.3030, train loss: 1.5547, grad_norm: 0.39, 536.68ms/step, 1669.53samples/sec
step 4660, epoch: 70.6061, train loss: 1.5703, grad_norm: 0.44, 536.20ms/step, 1671.02samples/sec
step 4680, epoch: 70.9091, train loss: 1.5703, grad_norm: 0.38, 527.22ms/step, 1699.48samples/sec
step 4700, epoch: 71.2121, train loss: 1.5625, grad_norm: 0.43, 537.14ms/step, 1668.08samples/sec
step 4720, epoch: 71.5152, train loss: 1.5859, grad_norm: 0.42, 536.84ms/step, 1669.04samples/sec
step 4740, epoch: 71.8182, train loss: 1.6016, grad_norm: 0.42, 527.15ms/step, 1699.69samples/sec
step 4760, epoch: 72.1212, train loss: 1.5625, grad_norm: 0.42, 528.56ms/step, 1695.18samples/sec
step 4780, epoch: 72.4242, train loss: 1.6016, grad_norm: 0.43, 536.61ms/step, 1669.73samples/sec
step 4800, epoch: 72.7273, train loss: 1.5859, grad_norm: 0.44, 526.95ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.38it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.79it/s]


step 4800, eval loss: 1.5227, clipscore: 23.32
step 4820, epoch: 73.0303, train loss: 1.6406, grad_norm: 0.46, 906.81ms/step, 988.08samples/sec
step 4840, epoch: 73.3333, train loss: 1.5781, grad_norm: 0.42, 535.44ms/step, 1673.40samples/sec
step 4860, epoch: 73.6364, train loss: 1.5547, grad_norm: 0.40, 536.45ms/step, 1670.24samples/sec
step 4880, epoch: 73.9394, train loss: 1.5703, grad_norm: 0.38, 527.35ms/step, 1699.07samples/sec
step 4900, epoch: 74.2424, train loss: 1.5781, grad_norm: 0.42, 545.99ms/step, 1641.04samples/sec
step 4920, epoch: 74.5455, train loss: 1.5859, grad_norm: 0.44, 526.62ms/step, 1701.42samples/sec
step 4940, epoch: 74.8485, train loss: 1.5703, grad_norm: 0.41, 535.50ms/step, 1673.21samples/sec
step 4960, epoch: 75.1515, train loss: 1.5547, grad_norm: 0.42, 537.57ms/step, 1666.76samples/sec
step 4980, epoch: 75.4545, train loss: 1.5625, grad_norm: 0.42, 535.79ms/step, 1672.31samples/sec
step 5000, epoch: 75.7576, train loss: 1.6172, grad_norm: 0.41, 527.19ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 97.32it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.28it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.39it/s]


step 5000, eval loss: 1.6313, clipscore: 23.25
step 5020, epoch: 76.0606, train loss: 1.6172, grad_norm: 0.44, 907.72ms/step, 987.09samples/sec
step 5040, epoch: 76.3636, train loss: 1.4844, grad_norm: 0.43, 526.53ms/step, 1701.71samples/sec
step 5060, epoch: 76.6667, train loss: 1.5703, grad_norm: 0.41, 536.34ms/step, 1670.57samples/sec
step 5080, epoch: 76.9697, train loss: 1.6094, grad_norm: 0.41, 527.21ms/step, 1699.52samples/sec
step 5100, epoch: 77.2727, train loss: 1.5625, grad_norm: 0.41, 537.19ms/step, 1667.94samples/sec
step 5120, epoch: 77.5758, train loss: 1.5781, grad_norm: 0.40, 527.47ms/step, 1698.67samples/sec
step 5140, epoch: 77.8788, train loss: 1.6094, grad_norm: 0.42, 537.42ms/step, 1667.24samples/sec
step 5160, epoch: 78.1818, train loss: 1.5781, grad_norm: 0.40, 528.09ms/step, 1696.68samples/sec
step 5180, epoch: 78.4848, train loss: 1.5859, grad_norm: 0.47, 547.40ms/step, 1636.81samples/sec
step 5200, epoch: 78.7879, train loss: 1.5859, grad_norm: 0.44, 536.98ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 95.78it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.25it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.83it/s]


step 5200, eval loss: 1.5430, clipscore: 23.36
step 5220, epoch: 79.0909, train loss: 1.6250, grad_norm: 0.43, 915.12ms/step, 979.10samples/sec
step 5240, epoch: 79.3939, train loss: 1.5391, grad_norm: 0.46, 535.89ms/step, 1671.97samples/sec
step 5260, epoch: 79.6970, train loss: 1.5703, grad_norm: 0.37, 526.90ms/step, 1700.50samples/sec
step 5280, epoch: 80.0000, train loss: 1.5938, grad_norm: 0.40, 536.38ms/step, 1670.46samples/sec
step 5300, epoch: 80.3030, train loss: 1.5156, grad_norm: 0.45, 527.92ms/step, 1697.24samples/sec
step 5320, epoch: 80.6061, train loss: 1.5547, grad_norm: 0.40, 536.53ms/step, 1669.99samples/sec
step 5340, epoch: 80.9091, train loss: 1.6016, grad_norm: 0.43, 527.18ms/step, 1699.60samples/sec
step 5360, epoch: 81.2121, train loss: 1.5938, grad_norm: 0.40, 537.35ms/step, 1667.44samples/sec
step 5380, epoch: 81.5152, train loss: 1.5469, grad_norm: 0.48, 527.04ms/step, 1700.07samples/sec
step 5400, epoch: 81.8182, train loss: 1.6094, grad_norm: 0.42, 536.43ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 96.75it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.28it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.76it/s]


step 5400, eval loss: 1.6539, clipscore: 23.32
step 5420, epoch: 82.1212, train loss: 1.5625, grad_norm: 0.39, 913.73ms/step, 980.60samples/sec
step 5440, epoch: 82.4242, train loss: 1.5469, grad_norm: 0.42, 528.09ms/step, 1696.68samples/sec
step 5460, epoch: 82.7273, train loss: 1.5391, grad_norm: 0.39, 527.24ms/step, 1699.42samples/sec
step 5480, epoch: 83.0303, train loss: 1.5938, grad_norm: 0.50, 536.66ms/step, 1669.60samples/sec
step 5500, epoch: 83.3333, train loss: 1.5000, grad_norm: 0.50, 528.03ms/step, 1696.86samples/sec
step 5520, epoch: 83.6364, train loss: 1.5469, grad_norm: 0.43, 536.97ms/step, 1668.62samples/sec
step 5540, epoch: 83.9394, train loss: 1.5781, grad_norm: 0.40, 527.42ms/step, 1698.84samples/sec
step 5560, epoch: 84.2424, train loss: 1.5938, grad_norm: 0.42, 535.95ms/step, 1671.78samples/sec
step 5580, epoch: 84.5455, train loss: 1.5547, grad_norm: 0.41, 536.18ms/step, 1671.09samples/sec
step 5600, epoch: 84.8485, train loss: 1.5312, grad_norm: 0.45, 527.40ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 95.76it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.28it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 12.10it/s]


step 5600, eval loss: 1.6898, clipscore: 23.27
step 5620, epoch: 85.1515, train loss: 1.6016, grad_norm: 0.40, 916.58ms/step, 977.55samples/sec
step 5640, epoch: 85.4545, train loss: 1.5000, grad_norm: 0.39, 537.80ms/step, 1666.05samples/sec
step 5660, epoch: 85.7576, train loss: 1.5469, grad_norm: 0.42, 536.76ms/step, 1669.28samples/sec
step 5680, epoch: 86.0606, train loss: 1.5547, grad_norm: 0.41, 537.42ms/step, 1667.23samples/sec
step 5700, epoch: 86.3636, train loss: 1.5391, grad_norm: 0.40, 535.83ms/step, 1672.18samples/sec
step 5720, epoch: 86.6667, train loss: 1.5391, grad_norm: 0.43, 536.99ms/step, 1668.56samples/sec
step 5740, epoch: 86.9697, train loss: 1.5859, grad_norm: 0.40, 527.25ms/step, 1699.39samples/sec
step 5760, epoch: 87.2727, train loss: 1.6016, grad_norm: 0.44, 527.81ms/step, 1697.59samples/sec
step 5780, epoch: 87.5758, train loss: 1.5469, grad_norm: 0.42, 536.07ms/step, 1671.41samples/sec
step 5800, epoch: 87.8788, train loss: 1.5703, grad_norm: 0.42, 527.59ms

eval_loss: 100%|██████████| 10/10 [00:00<00:00, 96.61it/s]
eval_clipscore: 100%|██████████| 3/3 [00:02<00:00,  1.27it/s]
eval_images: 100%|██████████| 10/10 [00:00<00:00, 11.87it/s]


step 5800, eval loss: 1.6156, clipscore: 23.13
step 5820, epoch: 88.1818, train loss: 1.5234, grad_norm: 0.41, 912.58ms/step, 981.83samples/sec
step 5840, epoch: 88.4848, train loss: 1.4844, grad_norm: 0.46, 535.88ms/step, 1672.01samples/sec
step 5860, epoch: 88.7879, train loss: 1.5469, grad_norm: 0.41, 527.48ms/step, 1698.63samples/sec
step 5880, epoch: 89.0909, train loss: 1.5312, grad_norm: 0.42, 527.50ms/step, 1698.58samples/sec
step 5900, epoch: 89.3939, train loss: 1.5078, grad_norm: 0.43, 545.01ms/step, 1643.99samples/sec
step 5920, epoch: 89.6970, train loss: 1.4922, grad_norm: 0.44, 527.68ms/step, 1698.00samples/sec


In [None]:
transformer.push_to_hub(f"g-ronimo/hana-small_alpha8-400e")

In [None]:
# !runpodctl remove pod $RUNPOD_POD_ID