In [1]:
!pip install kagglehub


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
# Download the CelebA dataset from Kaggle; the returned value is the local
# directory that holds the extracted files.
import kagglehub

jessicali9530_celeba_dataset_path = kagglehub.dataset_download(
    'jessicali9530/celeba-dataset'
)

Downloading from https://www.kaggle.com/api/v1/datasets/download/jessicali9530/celeba-dataset?dataset_version_number=2...


100% 1.33G/1.33G [01:49<00:00, 13.1MB/s]

Extracting files...





In [19]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.utils as vutils
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms

In [6]:
# Peek at the top-level contents of the downloaded dataset directory.
os.listdir(path=jessicali9530_celeba_dataset_path)

['img_align_celeba',
 'list_landmarks_align_celeba.csv',
 'list_bbox_celeba.csv',
 'list_eval_partition.csv',
 'list_attr_celeba.csv']

In [7]:
# Dataset locations. The aligned images sit in a nested
# img_align_celeba/img_align_celeba directory inside the download.
DATA_ROOT = jessicali9530_celeba_dataset_path
IMG_DIR   = os.path.join(DATA_ROOT, "img_align_celeba", "img_align_celeba")
CSV_PATH  = os.path.join(DATA_ROOT, "list_attr_celeba.csv")

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [None]:
# ---- Training hyper-parameters ----
BATCH_SIZE = 256
LATENT_DIM = 256
NUM_EPOCHS = 500
WARMUP_EPOCHS = 5          # length of the linear LR warmup, in epochs
IMG_SIZE = 64              # square training resolution

# Attribute columns of list_attr_celeba.csv that the CVAE is conditioned on.
ATTRS = ["Eyeglasses", "Smiling", "Mustache"]
# Derived from ATTRS so the two can never drift out of sync.
NUM_ATTRS = len(ATTRS)

In [None]:
class CelebADataset(Dataset):
    """CelebA images paired with binary attribute labels.

    The attribute CSV is read once. -1 labels are mapped to 0, and the image
    filenames and selected attribute columns are cached as a list / float32
    array, so ``__getitem__`` does not slice a pandas row on every sample.

    Args:
        csv_path: attribute CSV; column 0 must hold the image filename.
        img_dir: directory containing the image files.
        transform: optional callable applied to the cropped PIL image.
        attrs: attribute column names to return. Defaults to the notebook-level
            ``ATTRS``.
        crop_size: side of the centred square crop taken before ``transform``.
            The default of 178 is the previous hard-coded value.
    """

    def __init__(self, csv_path, img_dir, transform=None, attrs=None, crop_size=178):
        self.df = pd.read_csv(csv_path).replace(-1, 0)
        self.img_dir = img_dir
        self.transform = transform
        self.attrs = list(ATTRS if attrs is None else attrs)
        self.crop_size = crop_size
        # Precomputed once; per-item pandas indexing was the hot path in workers.
        self.img_names = self.df.iloc[:, 0].tolist()
        self.attr_values = self.df[self.attrs].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, attrs)``.

        ``image`` is the transformed crop, or a PIL image if ``transform`` is
        None. ``attrs`` is a float32 tensor of 0/1 labels.
        """
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        # Context manager closes the file handle; convert() returns a copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        w, h = image.size
        left = (w - self.crop_size) // 2
        top = (h - self.crop_size) // 2
        image = image.crop((left, top, left + self.crop_size, top + self.crop_size))
        if self.transform:
            image = self.transform(image)
        # copy() so the returned tensor does not alias the cached array.
        return image, torch.from_numpy(self.attr_values[idx].copy())

# Resize the square crop produced by CelebADataset to IMG_SIZE and scale pixels
# to [-1, 1] (mean = std = 0.5), the same range as the decoder's Tanh output.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = CelebADataset(CSV_PATH, IMG_DIR, transform=transform)
NUM_WORKERS = 4
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True,
    # Keep worker processes alive between epochs instead of re-spawning them
    # for each of the NUM_EPOCHS passes (only valid when workers are used).
    persistent_workers=NUM_WORKERS > 0,
)

In [None]:
class CVAE(nn.Module):
    """Convolutional conditional VAE for square images.

    Encoder: the attribute vector is broadcast to per-pixel channels and
    concatenated with the image. Four stride-2 convs follow, then linear heads
    for ``mu`` and ``logvar``.
    Decoder: ``[z, attrs]`` goes through a linear layer, four stride-2
    transposed convs, and a Tanh.

    With the default arguments the architecture and ``state_dict`` keys match
    the previous hard-coded version (3 attributes, 128 base channels, 64x64
    images), so existing checkpoints still load.

    Args:
        num_attrs: length of the conditioning vector.
        latent_dim: size of ``z``. Defaults to the notebook-level ``LATENT_DIM``.
        base_channels: width of the first conv. Deeper layers use 2x/4x/8x.
        img_size: input/output side length. Must be divisible by 16
            (four stride-2 stages).
        img_channels: number of image channels.
    """

    def __init__(self, num_attrs=3, latent_dim=None, base_channels=128,
                 img_size=64, img_channels=3):
        super().__init__()
        if latent_dim is None:
            latent_dim = LATENT_DIM
        if img_size % 16 != 0:
            raise ValueError(f"img_size must be divisible by 16, got {img_size}")
        c = base_channels
        self.num_attrs = num_attrs
        self.latent_dim = latent_dim
        self.feat_ch = c * 8              # channels at the bottleneck
        self.feat_hw = img_size // 16     # spatial side at the bottleneck
        flat = self.feat_ch * self.feat_hw * self.feat_hw

        self.enc = nn.Sequential(
            nn.Conv2d(img_channels + num_attrs, c, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(c,    c*2,  4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(c*2,  c*4,  4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(c*4,  c*8,  4, 2, 1), nn.LeakyReLU(0.2),
        )
        self.fc_mu = nn.Linear(flat, latent_dim)
        self.fc_logvar = nn.Linear(flat, latent_dim)
        self.fc_dec = nn.Linear(latent_dim + num_attrs, flat)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(self.feat_ch, c*4, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(c*4,  c*2, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(c*2,  c,   4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(c,    img_channels, 4, 2, 1),
            nn.Tanh()
        )

    def encode(self, x, c):
        """Return ``(mu, logvar)`` for images ``x`` conditioned on attributes ``c``."""
        c = c.to(x.dtype)  # integer attribute tensors would otherwise rely on cat promotion
        # expand() broadcasts without allocating a copy (repeat() would).
        cond_maps = c.reshape(-1, self.num_attrs, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        h = self.enc(torch.cat([x, cond_maps], dim=1)).flatten(1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Decode latents ``z`` under attributes ``c`` to images in [-1, 1]."""
        c = c.to(z.dtype)
        h = self.fc_dec(torch.cat([z, c], dim=1))
        h = h.view(-1, self.feat_ch, self.feat_hw, self.feat_hw)
        return self.dec(h)

    def forward(self, x, c):
        """Return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar

# Build the model on the chosen device and an AdamW optimizer over all its
# parameters (peak LR 6e-4, light weight decay).
model = CVAE().to(DEVICE)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,
    weight_decay=1e-5,
)

In [None]:
class SafeWarmupCosine(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine decay to ``min_lr``, stepped per batch.

    * Steps ``0 .. warmup_steps-1``: the LR ramps linearly from
      ``base_lr / warmup_steps`` up to ``base_lr``.
    * Afterwards: the LR follows half a cosine from ``base_lr`` down to
      ``min_lr``, reaching it at ``total_steps``. It stays at ``min_lr`` if
      stepped further (no restart).

    ``base_lr`` is each param group's initial LR, so the peak follows the
    optimizer's configured ``lr`` instead of a hard-coded constant.

    Args:
        optimizer: wrapped optimizer.
        warmup_steps: number of warmup steps (0 disables warmup).
        total_steps: step index at which the schedule reaches ``min_lr``.
        min_lr: floor of the cosine phase. The default 1e-5 is the previous
            hard-coded value.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=1e-5):
        # Set before super().__init__, which already calls get_lr() once.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            return [base_lr * (step + 1) / self.warmup_steps for base_lr in self.base_lrs]
        # max(1, ...) avoids ZeroDivisionError when total_steps == warmup_steps.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp so stepping past total_steps does not make the cosine rise again.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cos = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cos for base_lr in self.base_lrs]

# The schedule is stepped once per batch, so both lengths are in batches.
scheduler = SafeWarmupCosine(
    optimizer,
    warmup_steps=WARMUP_EPOCHS * len(dataloader),
    total_steps=NUM_EPOCHS * len(dataloader),
)

def loss_function(recon, x, mu, logvar, beta=4.0):
    """Beta-weighted VAE loss, averaged over the batch dimension.

    Returns ``(total, recon_term, kl_term)``:
      * ``recon_term``: squared error summed over all elements, divided by
        the batch size.
      * ``kl_term``: closed-form KL divergence of N(mu, exp(logvar)) from
        N(0, I), summed and divided by the batch size.
      * ``total``: ``recon_term + beta * kl_term``.
    """
    batch_size = x.size(0)
    recon_term = F.mse_loss(recon, x, reduction='sum') / batch_size
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elementwise) / batch_size
    total = recon_term + beta * kl_term
    return total, recon_term, kl_term

In [None]:
# Train the CVAE. Periodically decode a fixed batch of latents under several
# attribute combinations so sample quality can be compared across epochs.
SAMPLE_DIR = "samples_v3"
SAMPLE_EVERY = 25
CHECKPOINT_PATH = "cvae_eyeglasses_smiling_mustache.pth"

# Attribute vectors (in ATTRS order) rendered at every sample epoch.
SAMPLE_CONDITIONS = {
    "neutral":         [0, 0, 0],
    "eyeglasses_only": [1, 0, 0],
    "smiling_only":    [0, 1, 0],
    "mustache_only":   [0, 0, 1],
    "all_three":       [1, 1, 1],
}

# save_image does not create missing directories; without this a fresh run
# crashes at the first sample epoch.
os.makedirs(SAMPLE_DIR, exist_ok=True)


def save_sample_grids(epoch):
    """Decode `fixed_z` under every SAMPLE_CONDITIONS entry and save 8-per-row PNG grids."""
    model.eval()
    with torch.no_grad():
        for name, attr in SAMPLE_CONDITIONS.items():
            # float32 to match z; a list of Python ints would build an int64 tensor.
            cond = torch.tensor([attr] * fixed_z.size(0), dtype=torch.float32, device=DEVICE)
            samples = model.decode(fixed_z, cond)
            vutils.save_image(samples, os.path.join(SAMPLE_DIR, f"epoch{epoch}_{name}.png"),
                              nrow=8, normalize=True, value_range=(-1, 1))
    model.train()
    print(f"Samples saved {SAMPLE_DIR}/epoch{epoch}_*.png")


model.train()
fixed_z = torch.randn(64, LATENT_DIM, device=DEVICE)

for epoch in range(1, NUM_EPOCHS + 1):
    total_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}", mininterval=1.0)

    for imgs, attrs in pbar:
        # non_blocking pairs with the DataLoader's pin_memory=True.
        imgs = imgs.to(DEVICE, non_blocking=True)
        attrs = attrs.to(DEVICE, non_blocking=True)

        optimizer.zero_grad()
        recon, mu, logvar = model(imgs, attrs)
        loss, mse, kld = loss_function(recon, imgs, mu, logvar)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

        optimizer.step()
        scheduler.step()

        loss_val = loss.item()  # read once instead of calling .item() twice
        total_loss += loss_val
        # refresh=False lets tqdm's mininterval throttle redraws. Forcing a redraw
        # every batch likely contributed to the IOPub rate-limit warnings.
        pbar.set_postfix({
            'loss': f'{loss_val:.1f}',
            'mse': f'{mse.item():.1f}',
            'lr': f'{scheduler.get_last_lr()[0]:.1e}'
        }, refresh=False)

    avg_loss = total_loss / len(dataloader)
    print(f"\n=== Epoch {epoch} | Avg Loss: {avg_loss:.2f} | LR: {scheduler.get_last_lr()[0]:.1e} ===\n")

    if epoch % SAMPLE_EVERY == 0 or epoch == 1:
        save_sample_grids(epoch)
        # Intermediate checkpoint so a multi-hour run is not lost if interrupted.
        torch.save(model.state_dict(), CHECKPOINT_PATH)

torch.save(model.state_dict(), CHECKPOINT_PATH)

Epoch 1/500: 100% 791/791 [00:29<00:00, 27.20it/s, loss=942.0, mse=717.8, lr=1.2e-04]  



=== Epoch 1 | Avg Loss: 1723.62 | LR: 1.2e-04 ===

Samples saved samples_3attr/epoch1_*.png


Epoch 2/500: 100% 791/791 [00:29<00:00, 27.07it/s, loss=812.3, mse=580.1, lr=2.4e-04]



=== Epoch 2 | Avg Loss: 848.25 | LR: 2.4e-04 ===



Epoch 3/500: 100% 791/791 [00:28<00:00, 27.64it/s, loss=788.1, mse=542.2, lr=3.6e-04]



=== Epoch 3 | Avg Loss: 783.03 | LR: 3.6e-04 ===



Epoch 4/500: 100% 791/791 [00:29<00:00, 26.87it/s, loss=752.5, mse=500.9, lr=4.8e-04]



=== Epoch 4 | Avg Loss: 753.15 | LR: 4.8e-04 ===



Epoch 5/500: 100% 791/791 [00:28<00:00, 28.03it/s, loss=750.3, mse=488.6, lr=6.0e-04]



=== Epoch 5 | Avg Loss: 738.15 | LR: 6.0e-04 ===



Epoch 6/500: 100% 791/791 [00:29<00:00, 26.86it/s, loss=712.7, mse=463.6, lr=6.0e-04]



=== Epoch 6 | Avg Loss: 724.84 | LR: 6.0e-04 ===



Epoch 7/500: 100% 791/791 [00:28<00:00, 27.48it/s, loss=721.6, mse=468.7, lr=6.0e-04]



=== Epoch 7 | Avg Loss: 716.43 | LR: 6.0e-04 ===



Epoch 8/500: 100% 791/791 [00:28<00:00, 27.39it/s, loss=693.4, mse=444.9, lr=6.0e-04]



=== Epoch 8 | Avg Loss: 708.03 | LR: 6.0e-04 ===



Epoch 9/500: 100% 791/791 [00:29<00:00, 27.15it/s, loss=705.8, mse=457.6, lr=6.0e-04]



=== Epoch 9 | Avg Loss: 702.86 | LR: 6.0e-04 ===



Epoch 10/500: 100% 791/791 [00:29<00:00, 27.07it/s, loss=689.1, mse=433.3, lr=6.0e-04]



=== Epoch 10 | Avg Loss: 699.94 | LR: 6.0e-04 ===



Epoch 11/500: 100% 791/791 [00:29<00:00, 26.89it/s, loss=689.2, mse=439.5, lr=6.0e-04]



=== Epoch 11 | Avg Loss: 698.06 | LR: 6.0e-04 ===



Epoch 12/500: 100% 791/791 [00:28<00:00, 27.32it/s, loss=702.0, mse=454.7, lr=6.0e-04]



=== Epoch 12 | Avg Loss: 694.42 | LR: 6.0e-04 ===



Epoch 13/500: 100% 791/791 [00:28<00:00, 27.86it/s, loss=717.9, mse=476.1, lr=6.0e-04]



=== Epoch 13 | Avg Loss: 693.69 | LR: 6.0e-04 ===



Epoch 14/500: 100% 791/791 [00:29<00:00, 27.11it/s, loss=663.1, mse=415.1, lr=6.0e-04]



=== Epoch 14 | Avg Loss: 691.49 | LR: 6.0e-04 ===



Epoch 15/500: 100% 791/791 [00:29<00:00, 27.19it/s, loss=705.4, mse=449.7, lr=6.0e-04]



=== Epoch 15 | Avg Loss: 689.30 | LR: 6.0e-04 ===



Epoch 16/500: 100% 791/791 [00:31<00:00, 25.41it/s, loss=683.0, mse=431.9, lr=6.0e-04]



=== Epoch 16 | Avg Loss: 687.97 | LR: 6.0e-04 ===



Epoch 17/500: 100% 791/791 [00:28<00:00, 27.35it/s, loss=696.9, mse=440.2, lr=6.0e-04]



=== Epoch 17 | Avg Loss: 685.79 | LR: 6.0e-04 ===



Epoch 18/500: 100% 791/791 [00:29<00:00, 27.27it/s, loss=701.2, mse=449.2, lr=6.0e-04]



=== Epoch 18 | Avg Loss: 685.49 | LR: 6.0e-04 ===



Epoch 22/500: 100% 791/791 [00:28<00:00, 27.36it/s, loss=692.3, mse=449.9, lr=6.0e-04]



=== Epoch 22 | Avg Loss: 680.90 | LR: 6.0e-04 ===



Epoch 23/500: 100% 791/791 [00:28<00:00, 27.64it/s, loss=691.0, mse=438.1, lr=6.0e-04]



=== Epoch 23 | Avg Loss: 679.43 | LR: 6.0e-04 ===



Epoch 24/500: 100% 791/791 [00:29<00:00, 27.02it/s, loss=660.4, mse=406.4, lr=6.0e-04]



=== Epoch 24 | Avg Loss: 678.22 | LR: 6.0e-04 ===



Epoch 25/500:  76% 602/791 [00:22<00:07, 26.83it/s, loss=678.0, mse=432.9, lr=6.0e-04]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 78/500: 100% 791/791 [00:29<00:00, 27.19it/s, loss=643.0, mse=398.2, lr=5.7e-04]



=== Epoch 78 | Avg Loss: 651.58 | LR: 5.7e-04 ===



Epoch 79/500: 100% 791/791 [00:29<00:00, 26.77it/s, loss=661.2, mse=413.7, lr=5.7e-04]



=== Epoch 79 | Avg Loss: 651.58 | LR: 5.7e-04 ===



Epoch 80/500: 100% 791/791 [00:28<00:00, 27.82it/s, loss=650.8, mse=406.0, lr=5.7e-04]



=== Epoch 80 | Avg Loss: 651.28 | LR: 5.7e-04 ===



Epoch 81/500: 100% 791/791 [00:29<00:00, 27.23it/s, loss=647.8, mse=399.8, lr=5.7e-04]



=== Epoch 81 | Avg Loss: 651.52 | LR: 5.7e-04 ===



Epoch 82/500: 100% 791/791 [00:28<00:00, 27.32it/s, loss=652.4, mse=402.7, lr=5.7e-04]



=== Epoch 82 | Avg Loss: 651.05 | LR: 5.7e-04 ===



Epoch 83/500: 100% 791/791 [00:28<00:00, 27.56it/s, loss=657.7, mse=415.9, lr=5.6e-04]



=== Epoch 83 | Avg Loss: 651.90 | LR: 5.6e-04 ===



Epoch 84/500: 100% 791/791 [00:29<00:00, 26.89it/s, loss=654.0, mse=413.2, lr=5.6e-04]



=== Epoch 84 | Avg Loss: 651.15 | LR: 5.6e-04 ===



Epoch 85/500: 100% 791/791 [00:28<00:00, 27.72it/s, loss=648.6, mse=403.7, lr=5.6e-04]



=== Epoch 85 | Avg Loss: 650.00 | LR: 5.6e-04 ===



Epoch 86/500:  90% 713/791 [00:25<00:02, 29.68it/s, loss=653.7, mse=404.5, lr=5.6e-04]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 217/500: 100% 791/791 [00:29<00:00, 27.11it/s, loss=618.9, mse=376.7, lr=3.7e-04]



=== Epoch 217 | Avg Loss: 629.35 | LR: 3.7e-04 ===



Epoch 218/500: 100% 791/791 [00:28<00:00, 27.78it/s, loss=635.9, mse=393.2, lr=3.7e-04]



=== Epoch 218 | Avg Loss: 629.78 | LR: 3.7e-04 ===



Epoch 219/500: 100% 791/791 [00:29<00:00, 26.70it/s, loss=620.2, mse=380.3, lr=3.7e-04]



=== Epoch 219 | Avg Loss: 629.52 | LR: 3.7e-04 ===



Epoch 220/500:  40% 319/791 [00:12<00:16, 28.78it/s, loss=615.5, mse=374.2, lr=3.7e-04]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 223/500: 100% 791/791 [00:30<00:00, 25.69it/s, loss=646.4, mse=399.9, lr=3.6e-04]



=== Epoch 223 | Avg Loss: 629.00 | LR: 3.6e-04 ===



Epoch 224/500: 100% 791/791 [00:29<00:00, 27.13it/s, loss=629.1, mse=384.5, lr=3.6e-04]



=== Epoch 224 | Avg Loss: 628.49 | LR: 3.6e-04 ===



Epoch 225/500: 100% 791/791 [00:29<00:00, 26.91it/s, loss=608.8, mse=370.2, lr=3.6e-04]



=== Epoch 225 | Avg Loss: 628.54 | LR: 3.6e-04 ===

Samples saved samples_3attr/epoch225_*.png


Epoch 226/500: 100% 791/791 [00:28<00:00, 27.50it/s, loss=652.3, mse=404.0, lr=3.5e-04]



=== Epoch 226 | Avg Loss: 628.45 | LR: 3.5e-04 ===



Epoch 227/500: 100% 791/791 [00:29<00:00, 27.22it/s, loss=621.4, mse=385.2, lr=3.5e-04]



=== Epoch 227 | Avg Loss: 628.24 | LR: 3.5e-04 ===



Epoch 228/500: 100% 791/791 [00:28<00:00, 27.77it/s, loss=628.2, mse=384.9, lr=3.5e-04]



=== Epoch 228 | Avg Loss: 628.24 | LR: 3.5e-04 ===



Epoch 229/500: 100% 791/791 [00:28<00:00, 27.33it/s, loss=639.1, mse=389.9, lr=3.5e-04]



=== Epoch 229 | Avg Loss: 628.55 | LR: 3.5e-04 ===



Epoch 230/500: 100% 791/791 [00:28<00:00, 27.88it/s, loss=617.4, mse=378.4, lr=3.5e-04]



=== Epoch 230 | Avg Loss: 627.84 | LR: 3.5e-04 ===



Epoch 231/500: 100% 791/791 [00:28<00:00, 27.68it/s, loss=632.7, mse=384.9, lr=3.5e-04]



=== Epoch 231 | Avg Loss: 627.41 | LR: 3.5e-04 ===



Epoch 232/500: 100% 791/791 [00:28<00:00, 27.30it/s, loss=636.7, mse=394.7, lr=3.4e-04]



=== Epoch 232 | Avg Loss: 627.65 | LR: 3.4e-04 ===



Epoch 233/500: 100% 791/791 [00:29<00:00, 27.04it/s, loss=632.9, mse=385.0, lr=3.4e-04]



=== Epoch 233 | Avg Loss: 627.22 | LR: 3.4e-04 ===



Epoch 234/500: 100% 791/791 [00:29<00:00, 27.11it/s, loss=637.3, mse=392.2, lr=3.4e-04]



=== Epoch 234 | Avg Loss: 626.99 | LR: 3.4e-04 ===



Epoch 235/500: 100% 791/791 [00:28<00:00, 27.41it/s, loss=629.6, mse=386.0, lr=3.4e-04]



=== Epoch 235 | Avg Loss: 626.92 | LR: 3.4e-04 ===



Epoch 236/500: 100% 791/791 [00:29<00:00, 27.15it/s, loss=608.0, mse=375.4, lr=3.4e-04]



=== Epoch 236 | Avg Loss: 626.95 | LR: 3.4e-04 ===



Epoch 237/500: 100% 791/791 [00:29<00:00, 27.06it/s, loss=648.4, mse=395.3, lr=3.3e-04]



=== Epoch 237 | Avg Loss: 626.83 | LR: 3.3e-04 ===



Epoch 238/500: 100% 791/791 [00:28<00:00, 27.80it/s, loss=627.0, mse=386.1, lr=3.3e-04]



=== Epoch 238 | Avg Loss: 627.42 | LR: 3.3e-04 ===



Epoch 239/500: 100% 791/791 [00:28<00:00, 27.50it/s, loss=636.8, mse=384.5, lr=3.3e-04]



=== Epoch 239 | Avg Loss: 626.50 | LR: 3.3e-04 ===



Epoch 240/500: 100% 791/791 [00:28<00:00, 27.35it/s, loss=628.3, mse=390.2, lr=3.3e-04]



=== Epoch 240 | Avg Loss: 626.55 | LR: 3.3e-04 ===



Epoch 241/500: 100% 791/791 [00:28<00:00, 27.58it/s, loss=624.3, mse=379.0, lr=3.3e-04]



=== Epoch 241 | Avg Loss: 626.55 | LR: 3.3e-04 ===



Epoch 242/500: 100% 791/791 [00:29<00:00, 27.03it/s, loss=607.3, mse=368.6, lr=3.2e-04]



=== Epoch 242 | Avg Loss: 625.72 | LR: 3.2e-04 ===



Epoch 243/500: 100% 791/791 [00:29<00:00, 27.22it/s, loss=634.3, mse=389.5, lr=3.2e-04]



=== Epoch 243 | Avg Loss: 625.81 | LR: 3.2e-04 ===



Epoch 244/500: 100% 791/791 [00:29<00:00, 26.93it/s, loss=615.8, mse=380.3, lr=3.2e-04]



=== Epoch 244 | Avg Loss: 625.64 | LR: 3.2e-04 ===



Epoch 245/500: 100% 791/791 [00:28<00:00, 27.69it/s, loss=620.0, mse=370.9, lr=3.2e-04]



=== Epoch 245 | Avg Loss: 625.67 | LR: 3.2e-04 ===



Epoch 246/500: 100% 791/791 [00:29<00:00, 26.62it/s, loss=613.0, mse=366.9, lr=3.2e-04]



=== Epoch 246 | Avg Loss: 625.75 | LR: 3.2e-04 ===



Epoch 247/500: 100% 791/791 [00:29<00:00, 27.09it/s, loss=633.2, mse=382.9, lr=3.2e-04]



=== Epoch 247 | Avg Loss: 625.31 | LR: 3.2e-04 ===



Epoch 248/500: 100% 791/791 [00:29<00:00, 26.77it/s, loss=635.6, mse=391.2, lr=3.1e-04]



=== Epoch 248 | Avg Loss: 625.23 | LR: 3.1e-04 ===



Epoch 249/500: 100% 791/791 [00:29<00:00, 26.84it/s, loss=634.4, mse=384.6, lr=3.1e-04]



=== Epoch 249 | Avg Loss: 624.85 | LR: 3.1e-04 ===



Epoch 250/500: 100% 791/791 [00:28<00:00, 27.34it/s, loss=629.2, mse=381.8, lr=3.1e-04]



=== Epoch 250 | Avg Loss: 624.85 | LR: 3.1e-04 ===

Samples saved samples_3attr/epoch250_*.png


Epoch 251/500:  54% 430/791 [00:16<00:12, 29.54it/s, loss=616.3, mse=378.5, lr=3.1e-04]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 389/500: 100% 791/791 [00:27<00:00, 28.52it/s, loss=627.7, mse=377.9, lr=8.0e-05]



=== Epoch 389 | Avg Loss: 608.55 | LR: 8.0e-05 ===



Epoch 390/500: 100% 791/791 [00:28<00:00, 27.82it/s, loss=602.4, mse=359.5, lr=7.9e-05]



=== Epoch 390 | Avg Loss: 608.38 | LR: 7.9e-05 ===



Epoch 391/500: 100% 791/791 [00:27<00:00, 28.42it/s, loss=609.5, mse=370.3, lr=7.8e-05]



=== Epoch 391 | Avg Loss: 608.34 | LR: 7.8e-05 ===



Epoch 392/500:  13% 101/791 [00:04<00:25, 26.69it/s, loss=595.4, mse=354.9, lr=7.8e-05]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

