<a href="https://colab.research.google.com/github/diegomrodrigues/generative_models_experiments/blob/main/Full%20VAE%20Training%20on%20CIFAR-10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers huggingface_hub tensorboardX --upgrade --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m54.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m417.5/417.5 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from huggingface_hub import HfApi, notebook_login
from huggingface_hub import HFSummaryWriter
import os
from transformers import Trainer, TrainingArguments
import io

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 1024
latent_dim = 256
learning_rate = 1e-3
num_epochs = 50

repo_name = "diegomrodrigues/vae-cifar10-experiment"  # Replace with your Hugging Face username

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Wrapper for CIFAR-10 dataset to make it compatible with Hugging Face Trainer
class CIFAR10Wrapper(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        image, _ = self.dataset[index]
        return {"pixel_values": image}

    def __len__(self):
        return len(self.dataset)

train_dataset = CIFAR10Wrapper(datasets.CIFAR10(root='./data', train=True, download=True, transform=transform))

# VAE model
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 128 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.decoder_input(z)
        h = h.view(-1, 128, 4, 4)
        return self.decoder(h)

    def forward(self, pixel_values):
        mu, logvar = self.encode(pixel_values)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Initialize model
model = VAE(latent_dim).to(device)

# Loss function
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD, BCE, KLD

# Set up Hugging Face SummaryWriter
hf_writer = HFSummaryWriter(repo_id=repo_name)

# Custom Trainer class for VAE
class VAETrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hf_writer = hf_writer

    def compute_loss(self, model, inputs, return_outputs=False):
        pixel_values = inputs['pixel_values']
        recon_batch, mu, logvar = model(pixel_values)
        loss, bce, kld = loss_function(recon_batch, pixel_values, mu, logvar)

        # Log metrics
        self.log_metrics(loss, bce, kld)

        return (loss, (loss, recon_batch)) if return_outputs else loss

    def log_metrics(self, loss, bce, kld):
        if self.state.global_step % 100 == 0:
            self.hf_writer.add_scalar('Loss/total', loss.item(), self.state.global_step)
            self.hf_writer.add_scalar('Loss/BCE', bce.item(), self.state.global_step)
            self.hf_writer.add_scalar('Loss/KLD', kld.item(), self.state.global_step)
            self.hf_writer.add_scalar('Loss/ELBO', -loss.item(), self.state.global_step)

    def log(self, logs):
        super().log(logs)
        if self.state.global_step % 100 == 0:
            self.save_samples(self.state.global_step)

    def save_samples(self, step):
        model.eval()
        with torch.no_grad():
            sample = torch.randn(64, latent_dim).to(device)
            sample = model.decode(sample).cpu()
            img_grid = make_grid(sample, normalize=True, nrow=8)

            # Save image to BytesIO object
            buf = io.BytesIO()
            plt.figure(figsize=(10, 10))
            plt.imshow(np.transpose(img_grid, (1, 2, 0)))
            plt.axis('off')
            plt.title(f'Generated samples - Step {step}')
            plt.savefig(buf, format='png')
            plt.close()
            buf.seek(0)

            # Log image to Hugging Face
            self.hf_writer.add_image('Generated samples', img_grid, step, dataformats='CHW')
            self.hf_writer.add_figure('Generated samples (matplotlib)', plt.gcf(), step)

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    logging_dir='./logs',
    logging_steps=100,
    save_steps=1000,
    learning_rate=learning_rate,
    remove_unused_columns=False,
)

# Initialize trainer
trainer = VAETrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Create samples directory
os.makedirs('samples', exist_ok=True)

# Train the model
trainer.train()

print("Training complete!")

# Close Hugging Face SummaryWriter
hf_writer.close()

# Push the model to the Hugging Face Hub
trainer.push_to_hub(commit_message="Upload VAE model")

Using device: cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 42215256.21it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Step,Training Loss
100,518288.88
200,366205.12
300,311372.96
400,278783.44
500,263293.58
600,252716.2
700,243776.64
800,237938.54
900,234418.48
1000,232251.92


Training complete!


AttributeError: 'VAE' object has no attribute 'config'

In [8]:
!pip install pytorch-msssim torchmetrics[image] torch-fidelity --quiet

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from huggingface_hub import HfApi, notebook_login
from huggingface_hub import HFSummaryWriter
import os
from transformers import Trainer, TrainingArguments
import io
from pytorch_msssim import ssim
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 128
latent_dim = 256
learning_rate = 1e-3
num_epochs = 50
warmup_steps = 1000
max_grad_norm = 1.0

repo_name = "diegomrodrigues/vae-cifar10-experiment"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class CIFAR10Wrapper(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        image, _ = self.dataset[index]
        return {"pixel_values": image}

    def __len__(self):
        return len(self.dataset)

train_dataset = CIFAR10Wrapper(datasets.CIFAR10(root='./data', train=True, download=True, transform=transform))

class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 2 * 2, 512),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 256 * 2 * 2)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.decoder_input(z)
        h = h.view(-1, 256, 2, 2)
        return self.decoder(h)

    def forward(self, pixel_values):
        mu, logvar = self.encode(pixel_values)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

model = VAE(latent_dim).to(device)

def loss_function(recon_x, x, mu, logvar, kl_weight):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + kl_weight * KLD, BCE, KLD

hf_writer = HFSummaryWriter(repo_id=repo_name)

class VAETrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hf_writer = hf_writer
        self.kl_weight = 0.0

    def compute_loss(self, model, inputs, return_outputs=False):
        pixel_values = inputs['pixel_values']
        recon_batch, mu, logvar = model(pixel_values)
        loss, bce, kld = loss_function(recon_batch, pixel_values, mu, logvar, self.kl_weight)

        # KL annealing
        self.kl_weight = min(self.kl_weight + 1 / (10 * len(self.train_dataset)), 1.0)

        # Log metrics
        self.log_metrics(loss, bce, kld, recon_batch, pixel_values)

        return (loss, (loss, recon_batch)) if return_outputs else loss

    def log_metrics(self, loss, bce, kld, recon_batch, pixel_values):
        if self.state.global_step % 100 == 0:
            self.hf_writer.add_scalar('Loss/total', loss.item(), self.state.global_step)
            self.hf_writer.add_scalar('Loss/BCE', bce.item(), self.state.global_step)
            self.hf_writer.add_scalar('Loss/KLD', kld.item(), self.state.global_step)
            self.hf_writer.add_scalar('Loss/ELBO', -loss.item(), self.state.global_step)
            self.hf_writer.add_scalar('KL_weight', self.kl_weight, self.state.global_step)

            # Calculate SSIM
            ssim_val = ssim(recon_batch, pixel_values, data_range=2.0, size_average=True)
            self.hf_writer.add_scalar('Metrics/SSIM', ssim_val.item(), self.state.global_step)

    def log(self, logs):
        super().log(logs)
        if self.state.global_step % 1000 == 0:
            self.save_samples(self.state.global_step)

    def save_samples(self, step):
        model.eval()
        with torch.no_grad():
            sample = torch.randn(64, latent_dim).to(device)
            sample = model.decode(sample).cpu()
            img_grid = make_grid(sample, normalize=True, nrow=8)

            buf = io.BytesIO()
            plt.figure(figsize=(10, 10))
            plt.imshow(np.transpose(img_grid, (1, 2, 0)))
            plt.axis('off')
            plt.title(f'Generated samples - Step {step}')
            plt.savefig(buf, format='png')
            plt.close()
            buf.seek(0)

            self.hf_writer.add_image('Generated samples', img_grid, step, dataformats='CHW')
            self.hf_writer.add_figure('Generated samples (matplotlib)', plt.gcf(), step)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    logging_dir='./logs',
    logging_steps=100,
    save_steps=1000,
    learning_rate=learning_rate,
    remove_unused_columns=False,
    fp16=True,
    gradient_accumulation_steps=2,
    warmup_steps=warmup_steps,
    max_grad_norm=max_grad_norm,
)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataset) * num_epochs // (batch_size * training_args.gradient_accumulation_steps)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

trainer = VAETrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    optimizers=(optimizer, scheduler),
)

os.makedirs('samples', exist_ok=True)

trainer.train()

print("Training complete!")

hf_writer.close()

trainer.push_to_hub(commit_message="Upload improved VAE model")

Using device: cuda
Files already downloaded and verified


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss
100,145900.04
200,66875.19
300,40583.895
400,33212.0325
500,27485.3325
600,24392.0875
700,22648.275
800,21270.095
900,20294.495
1000,19484.7513


KeyboardInterrupt: 

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from huggingface_hub import HfApi, notebook_login
from huggingface_hub import HFSummaryWriter
import os
import io
from pytorch_msssim import ssim
import math
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 256
latent_dim = 256
learning_rate = 1e-3
num_epochs = 50
warmup_steps = 1000
max_grad_norm = 1.0

repo_name = "diegomrodrigues/vae-cifar10-experiment"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class CIFAR10Wrapper(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        image, _ = self.dataset[index]
        return {"pixel_values": image}

    def __len__(self):
        return len(self.dataset)

train_dataset = CIFAR10Wrapper(datasets.CIFAR10(root='./data', train=True, download=True, transform=transform))

# Use DataLoader with num_workers and pin_memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 2 * 2, 512),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 256 * 2 * 2)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.decoder_input(z)
        h = h.view(-1, 256, 2, 2)
        return self.decoder(h)

    def forward(self, pixel_values):
        mu, logvar = self.encode(pixel_values)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

model = VAE(latent_dim).to(device)

def loss_function(recon_x, x, mu, logvar, kl_weight):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + kl_weight * KLD, BCE, KLD

hf_writer = HFSummaryWriter(repo_id=repo_name)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_loader) * num_epochs
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
scaler = GradScaler()

def train_epoch(model, train_loader, optimizer, scheduler, scaler, epoch):
    model.train()
    train_loss = 0
    kl_weight = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, batch in enumerate(progress_bar):
        pixel_values = batch['pixel_values'].to(device)
        optimizer.zero_grad()

        with autocast():
            recon_batch, mu, logvar = model(pixel_values)
            loss, bce, kld = loss_function(recon_batch, pixel_values, mu, logvar, kl_weight)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        train_loss += loss.item()
        kl_weight = min(kl_weight + 1 / (10 * len(train_loader)), 1.0)

        if batch_idx % 100 == 0:
            hf_writer.add_scalar('Loss/total', loss.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('Loss/BCE', bce.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('Loss/KLD', kld.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('Loss/ELBO', -loss.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('KL_weight', kl_weight, epoch * len(train_loader) + batch_idx)

            with torch.no_grad():
                ssim_val = ssim(recon_batch.float(), pixel_values.float(), data_range=2.0, size_average=True)
            hf_writer.add_scalar('Metrics/SSIM', ssim_val.item(), epoch * len(train_loader) + batch_idx)

        progress_bar.set_postfix({'loss': loss.item()})

    return train_loss / len(train_loader)

def save_samples(model, epoch):
    model.eval()
    with torch.no_grad():
        sample = torch.randn(64, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        img_grid = make_grid(sample, normalize=True, nrow=8)

        buf = io.BytesIO()
        plt.figure(figsize=(10, 10))
        plt.imshow(np.transpose(img_grid, (1, 2, 0)))
        plt.axis('off')
        plt.title(f'Generated samples - Epoch {epoch+1}')
        plt.savefig(buf, format='png')
        plt.close()
        buf.seek(0)

        hf_writer.add_image('Generated samples', img_grid, epoch, dataformats='CHW')
        hf_writer.add_figure('Generated samples (matplotlib)', plt.gcf(), epoch)

os.makedirs('samples', exist_ok=True)

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, scaler, epoch)
    print(f'Epoch: {epoch+1}, Average loss: {train_loss:.4f}')

    if (epoch + 1) % 5 == 0:
        save_samples(model, epoch)

print("Training complete!")

hf_writer.close()

# Save the model
torch.save(model.state_dict(), 'vae_model.pth')

# Push the model to the Hugging Face Hub
api = HfApi()
api.upload_file(
    path_or_fileobj="vae_model.pth",
    path_in_repo="vae_model.pth",
    repo_id=repo_name,
    commit_message=f"Upload VAE model after {num_epochs} epochs"
)

Using device: cuda
Files already downloaded and verified


  scaler = GradScaler()
  with autocast():
Epoch 1/50: 100%|██████████| 196/196 [00:16<00:00, 11.66it/s, loss=2.97e+4]


Epoch: 1, Average loss: 237052.2893


Epoch 2/50: 100%|██████████| 196/196 [00:15<00:00, 12.83it/s, loss=1.95e+4]


Epoch: 2, Average loss: 74693.8434


Epoch 3/50: 100%|██████████| 196/196 [00:15<00:00, 12.67it/s, loss=1.63e+4]


Epoch: 3, Average loss: 53958.9713


Epoch 4/50: 100%|██████████| 196/196 [00:15<00:00, 12.88it/s, loss=1.58e+4]


Epoch: 4, Average loss: 46977.7288


Epoch 5/50: 100%|██████████| 196/196 [00:15<00:00, 12.60it/s, loss=1.43e+4]


Epoch: 5, Average loss: 43244.9694


Epoch 6/50: 100%|██████████| 196/196 [00:15<00:00, 12.86it/s, loss=1.26e+4]


Epoch: 6, Average loss: 40435.9524


Epoch 7/50: 100%|██████████| 196/196 [00:14<00:00, 13.15it/s, loss=1.28e+4]


Epoch: 7, Average loss: 37882.2492


Epoch 8/50: 100%|██████████| 196/196 [00:15<00:00, 12.76it/s, loss=1.21e+4]


Epoch: 8, Average loss: 36229.2162


Epoch 9/50: 100%|██████████| 196/196 [00:15<00:00, 13.03it/s, loss=1.17e+4]


Epoch: 9, Average loss: 34921.7306


Epoch 10/50: 100%|██████████| 196/196 [00:15<00:00, 12.59it/s, loss=1.43e+4]


Epoch: 10, Average loss: 33868.5252


Epoch 11/50: 100%|██████████| 196/196 [00:15<00:00, 12.99it/s, loss=1.29e+4]


Epoch: 11, Average loss: 33129.9906


Epoch 12/50: 100%|██████████| 196/196 [00:15<00:00, 12.67it/s, loss=11472.5]


Epoch: 12, Average loss: 32480.8475


Epoch 13/50: 100%|██████████| 196/196 [00:14<00:00, 13.14it/s, loss=1.07e+4]


Epoch: 13, Average loss: 31864.9152


Epoch 14/50: 100%|██████████| 196/196 [00:15<00:00, 12.76it/s, loss=1.09e+4]


Epoch: 14, Average loss: 31349.5100


Epoch 15/50: 100%|██████████| 196/196 [00:15<00:00, 12.36it/s, loss=1.03e+4]


Epoch: 15, Average loss: 30835.5512


Epoch 16/50: 100%|██████████| 196/196 [00:15<00:00, 12.41it/s, loss=1.05e+4]


Epoch: 16, Average loss: 30366.6921


Epoch 17/50: 100%|██████████| 196/196 [00:15<00:00, 12.67it/s, loss=1.06e+4]


Epoch: 17, Average loss: 29944.0989


Epoch 18/50: 100%|██████████| 196/196 [00:15<00:00, 12.62it/s, loss=1.03e+4]


Epoch: 18, Average loss: 29699.3932


Epoch 19/50: 100%|██████████| 196/196 [00:15<00:00, 12.92it/s, loss=9.92e+3]


Epoch: 19, Average loss: 29369.1174


Epoch 20/50: 100%|██████████| 196/196 [00:15<00:00, 12.67it/s, loss=9.14e+3]


Epoch: 20, Average loss: 28935.8821


Epoch 21/50: 100%|██████████| 196/196 [00:15<00:00, 12.86it/s, loss=1.02e+4]


Epoch: 21, Average loss: 28577.7518


Epoch 22/50: 100%|██████████| 196/196 [00:15<00:00, 12.79it/s, loss=9.85e+3]


Epoch: 22, Average loss: 28279.8664


Epoch 23/50: 100%|██████████| 196/196 [00:15<00:00, 12.89it/s, loss=9.39e+3]


Epoch: 23, Average loss: 28023.1245


Epoch 24/50: 100%|██████████| 196/196 [00:15<00:00, 12.79it/s, loss=1.09e+4]


Epoch: 24, Average loss: 27677.0682


Epoch 25/50: 100%|██████████| 196/196 [00:15<00:00, 12.90it/s, loss=9.87e+3]


Epoch: 25, Average loss: 27507.1102


Epoch 26/50: 100%|██████████| 196/196 [00:15<00:00, 12.77it/s, loss=9.36e+3]


Epoch: 26, Average loss: 27239.8372


Epoch 27/50: 100%|██████████| 196/196 [00:15<00:00, 12.88it/s, loss=9.55e+3]


Epoch: 27, Average loss: 27017.3638


Epoch 28/50: 100%|██████████| 196/196 [00:15<00:00, 12.77it/s, loss=9.82e+3]


Epoch: 28, Average loss: 26796.5733


Epoch 29/50: 100%|██████████| 196/196 [00:15<00:00, 12.82it/s, loss=8.89e+3]


Epoch: 29, Average loss: 26649.4190


Epoch 30/50: 100%|██████████| 196/196 [00:15<00:00, 12.87it/s, loss=9.77e+3]


Epoch: 30, Average loss: 26405.4328


Epoch 31/50: 100%|██████████| 196/196 [00:15<00:00, 12.79it/s, loss=8.9e+3]


Epoch: 31, Average loss: 26323.7616


Epoch 32/50: 100%|██████████| 196/196 [00:15<00:00, 12.84it/s, loss=9.07e+3]


Epoch: 32, Average loss: 26094.1561


Epoch 33/50: 100%|██████████| 196/196 [00:15<00:00, 13.05it/s, loss=1.02e+4]


Epoch: 33, Average loss: 25954.3329


Epoch 34/50: 100%|██████████| 196/196 [00:15<00:00, 12.76it/s, loss=8.85e+3]


Epoch: 34, Average loss: 25774.4378


Epoch 35/50: 100%|██████████| 196/196 [00:16<00:00, 12.05it/s, loss=9.7e+3]


Epoch: 35, Average loss: 25648.4119


Epoch 36/50: 100%|██████████| 196/196 [00:16<00:00, 12.23it/s, loss=9.6e+3]


Epoch: 36, Average loss: 25554.9504


Epoch 37/50: 100%|██████████| 196/196 [00:16<00:00, 12.24it/s, loss=9.3e+3]


Epoch: 37, Average loss: 25417.0569


Epoch 38/50: 100%|██████████| 196/196 [00:16<00:00, 12.10it/s, loss=9.57e+3]


Epoch: 38, Average loss: 25335.4056


Epoch 39/50: 100%|██████████| 196/196 [00:16<00:00, 12.14it/s, loss=9.33e+3]


Epoch: 39, Average loss: 25216.7052


Epoch 40/50: 100%|██████████| 196/196 [00:15<00:00, 12.26it/s, loss=9.29e+3]


Epoch: 40, Average loss: 25034.9291


Epoch 41/50: 100%|██████████| 196/196 [00:15<00:00, 12.51it/s, loss=8.91e+3]


Epoch: 41, Average loss: 24974.2792


Epoch 42/50: 100%|██████████| 196/196 [00:15<00:00, 12.57it/s, loss=1.08e+4]


Epoch: 42, Average loss: 24896.7857


Epoch 43/50:  31%|███       | 61/196 [00:04<00:09, 13.74it/s, loss=2.42e+4]


KeyboardInterrupt: 

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from huggingface_hub import HfApi, notebook_login
from huggingface_hub import HFSummaryWriter
import os
import io
from pytorch_msssim import ssim
import math
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 256
latent_dim = 256
learning_rate = 1e-3
num_epochs = 50
warmup_steps = 1000
max_grad_norm = 1.0

repo_name = "diegomrodrigues/vae-cifar10-experiment"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class CIFAR10Wrapper(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        image, _ = self.dataset[index]
        return {"pixel_values": image}

    def __len__(self):
        return len(self.dataset)

train_dataset = CIFAR10Wrapper(datasets.CIFAR10(root='./data', train=True, download=True, transform=transform))
test_dataset = CIFAR10Wrapper(datasets.CIFAR10(root='./data', train=False, download=True, transform=transform))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * 2 * 2, 512),
            nn.LeakyReLU(0.2)
        )
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 256 * 2 * 2)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.decoder_input(z)
        h = h.view(-1, 256, 2, 2)
        return self.decoder(h)

    def forward(self, pixel_values):
        mu, logvar = self.encode(pixel_values)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

model = VAE(latent_dim).to(device)

def loss_function(recon_x, x, mu, logvar, kl_weight):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + kl_weight * KLD, BCE, KLD

hf_writer = HFSummaryWriter(repo_id=repo_name)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_loader) * num_epochs
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
scaler = GradScaler()

def train_epoch(model, train_loader, optimizer, scheduler, scaler, epoch):
    model.train()
    train_loss = 0
    kl_weight = 0.0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, batch in enumerate(progress_bar):
        pixel_values = batch['pixel_values'].to(device)
        optimizer.zero_grad()

        with autocast():
            recon_batch, mu, logvar = model(pixel_values)
            loss, bce, kld = loss_function(recon_batch, pixel_values, mu, logvar, kl_weight)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        train_loss += loss.item()
        kl_weight = min(kl_weight + 1 / (10 * len(train_loader)), 1.0)

        if batch_idx % 100 == 0:
            hf_writer.add_scalar('Loss/total', loss.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('Loss/BCE', bce.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('Loss/KLD', kld.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('Loss/ELBO', -loss.item(), epoch * len(train_loader) + batch_idx)
            hf_writer.add_scalar('KL_weight', kl_weight, epoch * len(train_loader) + batch_idx)

            with torch.no_grad():
                ssim_val = ssim(recon_batch.float(), pixel_values.float(), data_range=2.0, size_average=True)
            hf_writer.add_scalar('Metrics/SSIM', ssim_val.item(), epoch * len(train_loader) + batch_idx)

        progress_bar.set_postfix({'loss': loss.item()})

    return train_loss / len(train_loader)

def evaluate(model, test_loader):
    model.eval()
    test_loss = 0
    ssim_sum = 0

    with torch.no_grad():
        for batch in test_loader:
            pixel_values = batch['pixel_values'].to(device)
            recon_batch, mu, logvar = model(pixel_values)
            loss, _, _ = loss_function(recon_batch, pixel_values, mu, logvar, 1.0)
            test_loss += loss.item()

            ssim_val = ssim(recon_batch.float(), pixel_values.float(), data_range=2.0, size_average=True)
            ssim_sum += ssim_val.item()

    test_loss /= len(test_loader)
    ssim_avg = ssim_sum / len(test_loader)

    return test_loss, ssim_avg

def save_samples(model, epoch):
    model.eval()
    with torch.no_grad():
        sample = torch.randn(64, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        img_grid = make_grid(sample, normalize=True, nrow=8)

        buf = io.BytesIO()
        plt.figure(figsize=(10, 10))
        plt.imshow(np.transpose(img_grid, (1, 2, 0)))
        plt.axis('off')
        plt.title(f'Generated samples - Epoch {epoch+1}')
        plt.savefig(buf, format='png')
        plt.close()
        buf.seek(0)

        hf_writer.add_image('Generated samples', img_grid, epoch, dataformats='CHW')
        hf_writer.add_figure('Generated samples (matplotlib)', plt.gcf(), epoch)

os.makedirs('samples', exist_ok=True)

best_ssim = 0
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, scaler, epoch)
    test_loss, ssim_avg = evaluate(model, test_loader)

    print(f'Epoch: {epoch+1}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, SSIM: {ssim_avg:.4f}')

    hf_writer.add_scalar('Epoch/Train_Loss', train_loss, epoch)
    hf_writer.add_scalar('Epoch/Test_Loss', test_loss, epoch)
    hf_writer.add_scalar('Epoch/SSIM', ssim_avg, epoch)

    if ssim_avg > best_ssim:
        best_ssim = ssim_avg
        torch.save(model.state_dict(), 'best_vae_model.pth')

    if (epoch + 1) % 5 == 0:
        save_samples(model, epoch)

print("Training complete!")

hf_writer.close()

# Save the final model
torch.save(model.state_dict(), 'final_vae_model.pth')

# Push the models to the Hugging Face Hub
api = HfApi()
api.upload_file(
    path_or_fileobj="best_vae_model.pth",
    path_in_repo="best_vae_model.pth",
    repo_id=repo_name,
    commit_message=f"Upload best VAE model (SSIM: {best_ssim:.4f})"
)
api.upload_file(
    path_or_fileobj="final_vae_model.pth",
    path_in_repo="final_vae_model.pth",
    repo_id=repo_name,
    commit_message=f"Upload final VAE model after {num_epochs} epochs"
)

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified


  scaler = GradScaler()
  with autocast():
Epoch 1/50: 100%|██████████| 196/196 [00:20<00:00,  9.54it/s, loss=2.53e+4]


Epoch: 1, Train loss: 208892.1422, Test loss: 110205.3153, SSIM: 0.1691


Epoch 2/50: 100%|██████████| 196/196 [00:17<00:00, 11.47it/s, loss=1.91e+4]


Epoch: 2, Train loss: 65264.6842, Test loss: 97657.4616, SSIM: 0.3244


Epoch 3/50: 100%|██████████| 196/196 [00:15<00:00, 12.43it/s, loss=1.51e+4]


Epoch: 3, Train loss: 48583.6182, Test loss: 92422.0516, SSIM: 0.3876


Epoch 4/50: 100%|██████████| 196/196 [00:15<00:00, 12.51it/s, loss=1.31e+4]


Epoch: 4, Train loss: 42488.2415, Test loss: 87274.1672, SSIM: 0.4380


Epoch 5/50: 100%|██████████| 196/196 [00:15<00:00, 12.42it/s, loss=1.35e+4]


Epoch: 5, Train loss: 38184.9651, Test loss: 82771.1915, SSIM: 0.4580


Epoch 6/50: 100%|██████████| 196/196 [00:15<00:00, 12.68it/s, loss=1.16e+4]


Epoch: 6, Train loss: 35488.7022, Test loss: 79972.7266, SSIM: 0.4857


Epoch 7/50: 100%|██████████| 196/196 [00:15<00:00, 12.44it/s, loss=1.11e+4]


Epoch: 7, Train loss: 33283.9070, Test loss: 76920.8605, SSIM: 0.5120


Epoch 8/50: 100%|██████████| 196/196 [00:15<00:00, 12.72it/s, loss=1.2e+4]


Epoch: 8, Train loss: 31932.1242, Test loss: 75324.1613, SSIM: 0.5296


Epoch 9/50: 100%|██████████| 196/196 [00:15<00:00, 12.49it/s, loss=1.12e+4]


Epoch: 9, Train loss: 30746.0305, Test loss: 73753.6191, SSIM: 0.5263


Epoch 10/50: 100%|██████████| 196/196 [00:15<00:00, 12.58it/s, loss=1.05e+4]


Epoch: 10, Train loss: 29943.6700, Test loss: 72658.5635, SSIM: 0.5476


Epoch 11/50: 100%|██████████| 196/196 [00:15<00:00, 12.69it/s, loss=1.02e+4]


Epoch: 11, Train loss: 29120.5237, Test loss: 72937.5181, SSIM: 0.5593


Epoch 12/50: 100%|██████████| 196/196 [00:15<00:00, 12.51it/s, loss=9.89e+3]


Epoch: 12, Train loss: 28382.5634, Test loss: 72781.7167, SSIM: 0.5684


Epoch 13/50: 100%|██████████| 196/196 [00:16<00:00, 12.19it/s, loss=9.74e+3]


Epoch: 13, Train loss: 27767.7339, Test loss: 71200.8816, SSIM: 0.5822


Epoch 14/50: 100%|██████████| 196/196 [00:16<00:00, 11.86it/s, loss=1.01e+4]


Epoch: 14, Train loss: 27210.9630, Test loss: 71599.8082, SSIM: 0.5836


Epoch 15/50: 100%|██████████| 196/196 [00:16<00:00, 12.22it/s, loss=9.28e+3]


Epoch: 15, Train loss: 26884.6750, Test loss: 70247.0707, SSIM: 0.5863


Epoch 16/50: 100%|██████████| 196/196 [00:15<00:00, 12.31it/s, loss=8.8e+3]


Epoch: 16, Train loss: 26388.8745, Test loss: 69577.8418, SSIM: 0.5956


Epoch 17/50: 100%|██████████| 196/196 [00:16<00:00, 12.08it/s, loss=9.47e+3]


Epoch: 17, Train loss: 26052.3059, Test loss: 70125.4628, SSIM: 0.6040


Epoch 18/50: 100%|██████████| 196/196 [00:15<00:00, 12.42it/s, loss=8.85e+3]


Epoch: 18, Train loss: 25641.2128, Test loss: 69425.6457, SSIM: 0.6078


Epoch 19/50: 100%|██████████| 196/196 [00:16<00:00, 11.76it/s, loss=8.43e+3]


Epoch: 19, Train loss: 25333.5794, Test loss: 69110.2100, SSIM: 0.6113


Epoch 20/50: 100%|██████████| 196/196 [00:16<00:00, 11.81it/s, loss=9.95e+3]


Epoch: 20, Train loss: 25079.9877, Test loss: 69677.1392, SSIM: 0.6119


Epoch 21/50: 100%|██████████| 196/196 [00:16<00:00, 12.03it/s, loss=9.22e+3]


Epoch: 21, Train loss: 24820.7493, Test loss: 69233.6344, SSIM: 0.6154


Epoch 22/50: 100%|██████████| 196/196 [00:16<00:00, 12.07it/s, loss=8.32e+3]


Epoch: 22, Train loss: 24591.2612, Test loss: 68801.5620, SSIM: 0.6205


Epoch 23/50: 100%|██████████| 196/196 [00:15<00:00, 12.25it/s, loss=9.1e+3]


Epoch: 23, Train loss: 24340.7906, Test loss: 68743.1661, SSIM: 0.6220


Epoch 24/50: 100%|██████████| 196/196 [00:16<00:00, 12.18it/s, loss=9.16e+3]


Epoch: 24, Train loss: 24215.7332, Test loss: 68695.2406, SSIM: 0.6303


Epoch 25/50: 100%|██████████| 196/196 [00:16<00:00, 11.96it/s, loss=9.48e+3]


Epoch: 25, Train loss: 23958.2109, Test loss: 69333.0205, SSIM: 0.6287


Epoch 26/50: 100%|██████████| 196/196 [00:15<00:00, 12.53it/s, loss=8.53e+3]


Epoch: 26, Train loss: 23790.5694, Test loss: 69043.6849, SSIM: 0.6365


Epoch 27/50: 100%|██████████| 196/196 [00:15<00:00, 12.25it/s, loss=8.8e+3]


Epoch: 27, Train loss: 23541.2709, Test loss: 67993.8880, SSIM: 0.6334


Epoch 28/50: 100%|██████████| 196/196 [00:15<00:00, 12.26it/s, loss=8.21e+3]


Epoch: 28, Train loss: 23412.5023, Test loss: 68344.8258, SSIM: 0.6370


Epoch 29/50: 100%|██████████| 196/196 [00:16<00:00, 12.18it/s, loss=8.54e+3]


Epoch: 29, Train loss: 23245.4453, Test loss: 68300.6609, SSIM: 0.6412


Epoch 30/50: 100%|██████████| 196/196 [00:16<00:00, 11.85it/s, loss=8.77e+3]


Epoch: 30, Train loss: 23067.4808, Test loss: 67598.4343, SSIM: 0.6429


Epoch 31/50: 100%|██████████| 196/196 [00:16<00:00, 12.06it/s, loss=8.77e+3]


Epoch: 31, Train loss: 22922.2498, Test loss: 68173.3482, SSIM: 0.6432


Epoch 32/50: 100%|██████████| 196/196 [00:15<00:00, 12.27it/s, loss=8.26e+3]


Epoch: 32, Train loss: 22779.0818, Test loss: 68761.6873, SSIM: 0.6470


Epoch 33/50: 100%|██████████| 196/196 [00:15<00:00, 12.33it/s, loss=8.74e+3]


Epoch: 33, Train loss: 22661.8198, Test loss: 67973.7948, SSIM: 0.6521


Epoch 34/50: 100%|██████████| 196/196 [00:16<00:00, 12.13it/s, loss=8.35e+3]


Epoch: 34, Train loss: 22526.9342, Test loss: 67804.2335, SSIM: 0.6536


Epoch 35/50: 100%|██████████| 196/196 [00:16<00:00, 12.00it/s, loss=7.6e+3]


Epoch: 35, Train loss: 22452.0948, Test loss: 67468.7375, SSIM: 0.6557


Epoch 36/50: 100%|██████████| 196/196 [00:16<00:00, 12.25it/s, loss=8.05e+3]


Epoch: 36, Train loss: 22288.4453, Test loss: 68193.4418, SSIM: 0.6514


Epoch 37/50: 100%|██████████| 196/196 [00:15<00:00, 12.25it/s, loss=8.57e+3]


Epoch: 37, Train loss: 22192.7940, Test loss: 67725.7502, SSIM: 0.6566


Epoch 38/50: 100%|██████████| 196/196 [00:16<00:00, 11.92it/s, loss=8.44e+3]


Epoch: 38, Train loss: 22171.8646, Test loss: 67955.8831, SSIM: 0.6569


Epoch 39/50: 100%|██████████| 196/196 [00:16<00:00, 12.05it/s, loss=8.18e+3]


Epoch: 39, Train loss: 22056.9778, Test loss: 67609.4201, SSIM: 0.6613


Epoch 40/50: 100%|██████████| 196/196 [00:16<00:00, 12.15it/s, loss=7.94e+3]


Epoch: 40, Train loss: 21946.2968, Test loss: 67798.1544, SSIM: 0.6635


Epoch 41/50: 100%|██████████| 196/196 [00:16<00:00, 12.18it/s, loss=8.09e+3]


Epoch: 41, Train loss: 21917.5504, Test loss: 68006.4895, SSIM: 0.6651


Epoch 42/50: 100%|██████████| 196/196 [00:15<00:00, 12.61it/s, loss=8.62e+3]


Epoch: 42, Train loss: 21835.1846, Test loss: 67808.3132, SSIM: 0.6662


Epoch 43/50: 100%|██████████| 196/196 [00:15<00:00, 12.92it/s, loss=7.44e+3]


Epoch: 43, Train loss: 21767.6850, Test loss: 68056.9548, SSIM: 0.6655


Epoch 44/50: 100%|██████████| 196/196 [00:15<00:00, 12.32it/s, loss=7.71e+3]


Epoch: 44, Train loss: 21747.4370, Test loss: 68200.0296, SSIM: 0.6676


Epoch 45/50: 100%|██████████| 196/196 [00:16<00:00, 12.07it/s, loss=8.25e+3]


Epoch: 45, Train loss: 21732.5660, Test loss: 68548.3939, SSIM: 0.6695


Epoch 46/50: 100%|██████████| 196/196 [00:15<00:00, 12.61it/s, loss=8.75e+3]


Epoch: 46, Train loss: 21674.7728, Test loss: 69014.6037, SSIM: 0.6713


Epoch 47/50: 100%|██████████| 196/196 [00:15<00:00, 12.28it/s, loss=7.82e+3]


Epoch: 47, Train loss: 21702.0868, Test loss: 70882.9418, SSIM: 0.6742


Epoch 48/50: 100%|██████████| 196/196 [00:16<00:00, 12.20it/s, loss=1e+4]


Epoch: 48, Train loss: 21679.6187, Test loss: 74035.5319, SSIM: 0.6790


Epoch 49/50: 100%|██████████| 196/196 [00:16<00:00, 12.05it/s, loss=7.94e+3]


Epoch: 49, Train loss: 21716.9324, Test loss: 75984.0952, SSIM: 0.6837


Epoch 50/50: 100%|██████████| 196/196 [00:16<00:00, 12.06it/s, loss=8.26e+3]


Epoch: 50, Train loss: 21627.8107, Test loss: 76373.3919, SSIM: 0.6832
Training complete!


best_vae_model.pth:   0%|          | 0.00/9.75M [00:00<?, ?B/s]

final_vae_model.pth:   0%|          | 0.00/9.75M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/diegomrodrigues/vae-cifar10-experiment/commit/437caa1c467be29d584ceaaee0c9c8326bd8c394', commit_message='Upload final VAE model after 50 epochs', commit_description='', oid='437caa1c467be29d584ceaaee0c9c8326bd8c394', pr_url=None, pr_revision=None, pr_num=None)