In [1]:
!pip install torch torchvision tensorboard tqdm matplotlib numpy



In [2]:
from utils.train import train_model
from data.datasets import get_cifar10_dataloaders
from models.RealNVP import RealNVP, RealNVPLoss
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from typing import Callable, Optional, Dict
from enum import IntEnum

# ---------------------------------------------------------------------------
# Configuration: every value a reader might want to tune.
# ---------------------------------------------------------------------------
batch_size = 64
num_epochs = 20
learning_rate = 0.001
num_blocks = 8
num_scales = 2
in_channels = 3
mid_channels = 64
model_name = 'RealNVP_CIFAR10'
seed = 42

# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs (slowly) instead of failing on the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f'Using device: {device}')

# Seed the RNGs before any data loading or parameter initialisation so that a
# Restart & Run All starts from the same random state.
torch.manual_seed(seed)
np.random.seed(seed)

# Data loading and preprocessing.
# `denorm_params` is whatever the loader reports about its own normalisation;
# it is forwarded unchanged to `train_model` below so that saved images are
# de-normalised consistently with how the data was actually preprocessed.
train_loader, val_loader, test_loader, denorm_params = get_cifar10_dataloaders(batch_size=batch_size)

# Model, loss function, optimizer.
model = RealNVP(
    num_scales=num_scales,
    in_channels=in_channels,
    mid_channels=mid_channels,
    num_blocks=num_blocks,
).to(device)
# NOTE(review): `criterion` is not passed to `train_model` in this cell;
# presumably the training utility computes its own loss -- confirm, or remove.
criterion = RealNVPLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# TensorBoard writer.
writer = SummaryWriter(log_dir=f'logs/{model_name}')

# Directory to save generated images.
image_save_dir = f'images/{model_name}'
os.makedirs(image_save_dir, exist_ok=True)

# Train the model using the original training pipeline.
# NOTE(review): in the recorded run of this cell the validation loss swings by
# several orders of magnitude between epochs while the training loss decreases
# steadily. The cause is not visible from this cell -- investigate before
# relying on the `save_best` checkpoints.
try:
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        num_epochs=num_epochs,
        device=device,
        model_name=model_name,
        logger=writer,
        save_best=True,
        denorm_params=denorm_params,
        model_update_fn=None,  # Define if any model-specific updates are needed
        generate_images_flag=True,
    )
finally:
    # Flush and release the event file even if training is interrupted or raises.
    writer.close()



Using device: cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:01<00:00, 88.9MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


  WeightNorm.apply(module, name, dim)
Training Epoch 1/20: 100%|██████████| 782/782 [02:13<00:00,  5.85it/s, Loss=9.41e+3]


Epoch [1/20], Training Loss: 11196.250477, Validation Loss: 320317.555209
Best model saved with validation loss: 320317.555209
Reconstructed images saved to images/RealNVP_CIFAR10/reconstructions_epoch_1.png
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_1.png


Training Epoch 2/20: 100%|██████████| 782/782 [02:11<00:00,  5.94it/s, Loss=9.81e+3]


Epoch [2/20], Training Loss: 9523.634831, Validation Loss: 16357.942047
Best model saved with validation loss: 16357.942047
Reconstructed images saved to images/RealNVP_CIFAR10/reconstructions_epoch_2.png
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_2.png


Training Epoch 3/20: 100%|██████████| 782/782 [02:14<00:00,  5.81it/s, Loss=9.44e+3]


Epoch [3/20], Training Loss: 9245.084893, Validation Loss: 88228.515428
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_3.png


Training Epoch 4/20: 100%|██████████| 782/782 [02:18<00:00,  5.65it/s, Loss=9.21e+3]


Epoch [4/20], Training Loss: 8976.499253, Validation Loss: 61281.382659
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_4.png


Training Epoch 5/20: 100%|██████████| 782/782 [02:13<00:00,  5.88it/s, Loss=1.14e+4]


Epoch [5/20], Training Loss: 9250.583839, Validation Loss: 11169.557625
Best model saved with validation loss: 11169.557625
Reconstructed images saved to images/RealNVP_CIFAR10/reconstructions_epoch_5.png
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_5.png


Training Epoch 6/20: 100%|██████████| 782/782 [02:13<00:00,  5.85it/s, Loss=9.1e+3]


Epoch [6/20], Training Loss: 9430.764499, Validation Loss: 21813513.937444
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_6.png


Training Epoch 7/20: 100%|██████████| 782/782 [02:13<00:00,  5.86it/s, Loss=9.59e+3]


Epoch [7/20], Training Loss: 9108.500514, Validation Loss: 417609.332850
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_7.png


Training Epoch 8/20: 100%|██████████| 782/782 [02:12<00:00,  5.90it/s, Loss=8.6e+3]


Epoch [8/20], Training Loss: 8924.986573, Validation Loss: 134834.523248
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_8.png


Training Epoch 9/20: 100%|██████████| 782/782 [02:12<00:00,  5.90it/s, Loss=1.03e+4]


Epoch [9/20], Training Loss: 8694.000083, Validation Loss: 8869.712297
Best model saved with validation loss: 8869.712297
Reconstructed images saved to images/RealNVP_CIFAR10/reconstructions_epoch_9.png
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_9.png


Training Epoch 10/20: 100%|██████████| 782/782 [02:12<00:00,  5.92it/s, Loss=8.91e+3]


Epoch [10/20], Training Loss: 8708.070231, Validation Loss: 172671.712578
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_10.png


Training Epoch 11/20: 100%|██████████| 782/782 [02:13<00:00,  5.88it/s, Loss=9.71e+3]


Epoch [11/20], Training Loss: 8483.931152, Validation Loss: 101339.472872
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_11.png


Training Epoch 12/20: 100%|██████████| 782/782 [02:11<00:00,  5.93it/s, Loss=9.33e+3]


Epoch [12/20], Training Loss: 8594.412249, Validation Loss: 3650861407.853700
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_12.png


Training Epoch 13/20: 100%|██████████| 782/782 [02:12<00:00,  5.89it/s, Loss=8.68e+3]


Epoch [13/20], Training Loss: 8385.314159, Validation Loss: 56705.065300
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_13.png


Training Epoch 14/20: 100%|██████████| 782/782 [02:11<00:00,  5.93it/s, Loss=8.68e+3]


Epoch [14/20], Training Loss: 8310.035315, Validation Loss: 16573.405445
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_14.png


Training Epoch 15/20: 100%|██████████| 782/782 [02:13<00:00,  5.85it/s, Loss=8.48e+3]


Epoch [15/20], Training Loss: 8592.746955, Validation Loss: 8477.790119
Best model saved with validation loss: 8477.790119
Reconstructed images saved to images/RealNVP_CIFAR10/reconstructions_epoch_15.png
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_15.png


Training Epoch 16/20: 100%|██████████| 782/782 [02:13<00:00,  5.84it/s, Loss=9.48e+3]


Epoch [16/20], Training Loss: 8409.794547, Validation Loss: 8511.496597
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_16.png


Training Epoch 17/20: 100%|██████████| 782/782 [02:13<00:00,  5.88it/s, Loss=7.9e+3]


Epoch [17/20], Training Loss: 8215.337149, Validation Loss: 11251.482625
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_17.png


Training Epoch 18/20: 100%|██████████| 782/782 [02:13<00:00,  5.86it/s, Loss=9.02e+3]


Epoch [18/20], Training Loss: 8155.633393, Validation Loss: 8131.734456
Best model saved with validation loss: 8131.734456
Reconstructed images saved to images/RealNVP_CIFAR10/reconstructions_epoch_18.png
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_18.png


Training Epoch 19/20: 100%|██████████| 782/782 [02:13<00:00,  5.87it/s, Loss=9.83e+3]


Epoch [19/20], Training Loss: 8327.916289, Validation Loss: 8809.931844
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_19.png


Training Epoch 20/20: 100%|██████████| 782/782 [02:16<00:00,  5.71it/s, Loss=8.14e+3]


Epoch [20/20], Training Loss: 8173.797073, Validation Loss: 13156.217434
Generated images saved to images/RealNVP_CIFAR10/generated_epoch_20.png
