In [54]:
# Sys import and directory configure
import sys
import os
import json
import shutil

project_root = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..', '..'))
code_root = os.path.join(project_root, 'code')
sys.path.append(code_root)

# DL related import
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torcheval.metrics import PeakSignalNoiseRatio, StructuralSimilarity
from torch.amp.grad_scaler import GradScaler
from torch.amp.autocast_mode import autocast
from torch.optim import lr_scheduler
from tqdm import tqdm

# Local imports
import importlib
import model
import dataset
import utils
importlib.reload(model)
importlib.reload(dataset)
importlib.reload(utils)
from model import VAE
from dataset import get_tiny_imagenet_datasets
from utils import save_best_model, EarlyStopManager

# Seed
torch.manual_seed(42)
np.random.seed(42)

# Devices
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(torch.cuda.get_device_name(torch.cuda.current_device()))
torch.set_default_device(device)
g = torch.Generator(device=device)
    
# Follow the torch recommendation to use TF32 precision for matmul
torch.set_float32_matmul_precision("high")


NVIDIA GeForce RTX 4070


In [55]:
def train(model, criterion, optimizer, train_set, test_set, batch_size, num_epochs, experiment_name, lr_scheduler, lr_decay_patience=5, early_stopping_patience=10, verbose=False):
    torch.cuda.empty_cache()
    
    psnr = PeakSignalNoiseRatio()
    # ssim = StructuralSimilarity(channel_axis=1)
    # ssim = StructuralSimilarity()

    early_stopper = EarlyStopManager(early_stopping_patience, 'min')

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=g)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, generator=g)

    prev_lr = optimizer.param_groups[0]['lr']
    
    scaler = GradScaler()
    
    if verbose and isinstance(lr_scheduler, torch.optim.lr_scheduler.CyclicLR):
        print(f"Using CyclicLR, update LR ber batch")
    
    # Lists to store metrics
    train_loss, test_loss, train_psnr, test_psnr, train_ssim, test_ssim = [], [], [], [], [], []
    
    for epoch in tqdm(range(num_epochs)):    
        # ========== Training Loop ==========
        model.train()
        
        train_loss_sum, train_psnr_sum, train_ssim_sum = 0, 0, 0
        
        for img, tgt in train_loader:
            img = img.to(device)
            tgt = tgt.to(device)
            
            with autocast(device_type='cuda',dtype=torch.float32):
                out, mu, log_var = model(img)
                loss = criterion(out, tgt, mu, log_var)
            
            # Scales loss, calls backward() to create scaled gradients
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            
            if isinstance(lr_scheduler, torch.optim.lr_scheduler.CyclicLR):
                lr_scheduler.step()
                
            train_loss_sum += loss.item()
            psnr.update(out, tgt)
            # ssim.update(out, tgt)
        
        train_loss.append(train_loss_sum / len(train_loader))
        train_psnr.append(psnr.compute().item())
        train_ssim.append(0)
        
        psnr.reset()
        # ssim.reset()

        # ========== Testing Loop ==========
        model.eval()
        
        test_loss_sum, test_psnr_sum, test_ssim_sum = 0, 0, 0
        
        with torch.no_grad():
            for img, tgt in test_loader:
                img = img.to(device)
                tgt = tgt.to(device)
                
                out, mu, log_var = model(img)
                loss = criterion(out, tgt, mu, log_var)
                
                test_loss_sum += loss.item()
                psnr.update(out, tgt)
                # ssim.update(out, tgt)
        
        test_loss.append(test_loss_sum / len(test_loader))
        test_psnr.append(psnr.compute().item())
        test_ssim.append(0)
        
        psnr.reset()
        # ssim.reset()
    
        # ========== Update scheduler (Per Epoch) ==========
        if not isinstance(lr_scheduler, torch.optim.lr_scheduler.CyclicLR):
            lr_scheduler.step(train_loss[-1])
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr != prev_lr:
                print(f"Epoch {epoch}: Learning rate changed to {current_lr}")
                prev_lr = current_lr
        
        if verbose:
            print(f"Learning rate {optimizer.param_groups[0]['lr']}")
            
        # Save model parameters
        curr_psnr = test_psnr[-1]
        # curr_ssim = test_ssim[-1]
        save_best_model(curr_psnr, torch.max(torch.tensor(test_psnr)).item(), model, experiment_name, epoch, metric_name="psnr", mode="max")
        # save_best_model(curr_ssim, torch.max(torch.tensor(test_ssim)).item(), model, experiment_name, epoch, metric_name="ssim", mode="max")

        # Early stopping
        if early_stopper.update(test_loss[-1]):
            break
        
        train_text = f"Train [Loss, PSNR, SSIM]: {train_loss[-1]:.4f} {train_psnr[-1]:.4f} {train_ssim[-1]:.4f}"
        test_text = f"Test [Loss, PSNR, SSIM]: {test_loss[-1]:.4f} {test_psnr[-1]:.4f} {test_ssim[-1]:.4f}"
        print(f"{train_text} || {test_text}")
        print('===========================')

    train_history = {
        'train_loss': np.array(train_loss),
        'train_psnr': np.array(train_psnr),
        'train_ssim': np.array(train_ssim),
        
        'test_loss': np.array(test_loss),
        'test_psnr': np.array(test_psnr),
        'test_ssim': np.array(test_ssim),
    }
    
    np.save(f'experiments/{experiment_name}/training_history.npy', train_history)
    
    return train_history



In [56]:
train_set, test_set = get_tiny_imagenet_datasets(normalize='MinMax', reconstruction=True, debug=True)

min = np.inf
max = -np.inf
mean = 0

for img in train_set:
    print(np.min(img), np.max(img), np.mean(img))



Tiny ImageNet dataset already downloaded and extracted.
0.007843138 1.0 0.5258923
0.0 1.0 0.23986417
0.0 1.0 0.5375686
0.0 1.0 0.45978987
0.0 1.0 0.20645969
0.0 1.0 0.46610785
0.0 1.0 0.5116141
0.0 1.0 0.15009224
0.0 1.0 0.35220397
0.0 1.0 0.53236926
0.0 1.0 0.42268243
0.0 1.0 0.53440565
0.0 1.0 0.40184078
0.0 1.0 0.3150506
0.0 1.0 0.36460254
0.011764706 1.0 0.44155562
0.0 1.0 0.563916
0.0 0.972549 0.10138825
0.0 1.0 0.17823191
0.0 1.0 0.4744552
0.0 1.0 0.38870254
0.0 1.0 0.34743825
0.0 0.8901961 0.08651674
0.0 1.0 0.36481953
0.0 1.0 0.4082555
0.0 1.0 0.4631236
0.0 1.0 0.5345869
0.050980393 0.8862745 0.50261825
0.0 1.0 0.4097586
0.0 1.0 0.47310495
0.0 1.0 0.3092713
0.0 1.0 0.1951236
0.0 1.0 0.4099619
0.0 1.0 0.28381076
0.0 1.0 0.48263666
0.0 0.83137256 0.53799057
0.0 1.0 0.3551611
0.015686275 0.88235295 0.4436814
0.0 1.0 0.2597181
0.0 1.0 0.22006486
0.0 1.0 0.3913485
0.0 1.0 0.2771022
0.0 1.0 0.24301696
0.0 1.0 0.47187662
0.0 1.0 0.33712885
0.0 1.0 0.3580346
0.0 1.0 0.34128848
0.0 0.82

In [57]:
DEBUG_MODE = True

train_set, test_set = get_tiny_imagenet_datasets(normalize=True, reconstruction=True, debug=DEBUG_MODE)

# ==================== Tune Here ====================
model = VAE(latent_dim=32).to(device)
epochs = 10 if DEBUG_MODE else 300
batch_size = 4096
initial_lr = 1e-3
lr_decay = 5
early_stopping = 20
lr_scheduler_name = "ReduceLROnPlateau"
experiment_name = 'VAE'
experiment_description = "latent_dim=32"
# ==================== Tune Here ====================

print(model.count_parameters())
torch.cuda.empty_cache()

criterion = model.loss_function
optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=1e-4)
# optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=1e-4)

if lr_scheduler_name == "CyclicLR":
    iterations_per_epoch = len(train_set) / batch_size
    epochs_per_cycle = 4
    step_size_up = int(iterations_per_epoch * epochs_per_cycle)
    scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-2, step_size_up=step_size_up, mode='triangular2')
elif lr_scheduler_name == "ReduceLROnPlateau":
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=lr_decay, factor=0.5)
elif lr_scheduler_name == "None":
    scheduler = lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)  # Returns 1.0 multiplier, effectively doing nothing
else:
    raise Exception(f"Invalid learning rate scheduler: {lr_scheduler_name}")
        
if DEBUG_MODE:   # Overwrite debug folder
    experiment_name = "DEBUG"
    if not os.path.exists(f'experiments/DEBUG'):
        os.makedirs(f'experiments/DEBUG')
else:
    if os.path.exists(f'experiments/{experiment_name}'):
        user_input = input(f"Experiment '{experiment_name}' already exists. Overwrite? (y/n): ")
        if user_input.lower() != 'y':
            raise Exception(f"Experiment {experiment_name} already exists. Please use a different name.")
        else:
            shutil.rmtree(f'experiments/{experiment_name}')
            os.makedirs(f'experiments/{experiment_name}')
    else:
        os.makedirs(f'experiments/{experiment_name}')
            
hyperparameters = {
    "epochs": epochs,
    "batch_size": batch_size,
    "initial_lr": initial_lr,
    "lr_scheduler": lr_scheduler_name,
    "lr_decay_patient": lr_decay if lr_scheduler_name == "ReduceLROnPlateau" else None,
    "early_stopping_patient": early_stopping,
    "description": experiment_description
}

with open('hyperparameters.json', 'w') as f:
    json.dump(hyperparameters, open(f'experiments/{experiment_name}/hyperparameters.json', 'w'), indent=4)

# Train the model
train_hist = train(
    model,
    criterion,
    optimizer,
    train_set,
    test_set,
    batch_size,
    epochs,
    experiment_name,
    scheduler,
    lr_decay,
    early_stopping,
    verbose=False
)

print('resulting metrics:')
for experiment_name in train_hist:
    print(experiment_name)

Tiny ImageNet dataset already downloaded and extracted.
4348547


  0%|          | 0/10 [00:00<?, ?it/s]

MSE: 977.333251953125, KLD: 0.006952584721148014
MSE: 960.7467041015625, KLD: 0.018803998827934265
MSE: 973.0464477539062, KLD: 0.002269275486469269


 10%|█         | 1/10 [00:07<01:05,  7.29s/it]

MSE: 957.0118408203125, KLD: 0.003352272091433406
Train [Loss, PSNR, SSIM]: 970.3848 11.0282 0.0000 || Test [Loss, PSNR, SSIM]: 957.0152 11.0856 0.0000
MSE: 955.38916015625, KLD: 0.003351835533976555
MSE: 948.962646484375, KLD: 0.0035424595698714256
MSE: 942.9411010742188, KLD: 0.005006738938391209


 20%|██        | 2/10 [00:14<00:56,  7.00s/it]

MSE: 939.8788452148438, KLD: 0.02564064972102642
Train [Loss, PSNR, SSIM]: 949.1016 11.1153 0.0000 || Test [Loss, PSNR, SSIM]: 939.9045 11.1641 0.0000
MSE: 944.9606323242188, KLD: 0.025604424998164177
MSE: 927.9654541015625, KLD: 0.05689358711242676
MSE: 930.5839233398438, KLD: 0.23394589126110077


 30%|███       | 3/10 [00:20<00:48,  6.90s/it]

MSE: 922.3190307617188, KLD: 1.517554521560669
Train [Loss, PSNR, SSIM]: 934.6088 11.1848 0.0000 || Test [Loss, PSNR, SSIM]: 923.8366 11.2460 0.0000
MSE: 917.2058715820312, KLD: 1.5158416032791138
MSE: 915.3544921875, KLD: 1.9301165342330933
MSE: 885.607177734375, KLD: 6.210729122161865


 40%|████      | 4/10 [00:27<00:41,  6.87s/it]

MSE: 880.265869140625, KLD: 3.606468439102173
Train [Loss, PSNR, SSIM]: 909.2748 11.3009 0.0000 || Test [Loss, PSNR, SSIM]: 883.8723 11.4487 0.0000
MSE: 882.3873291015625, KLD: 3.3364691734313965
MSE: 875.5623779296875, KLD: 11.520893096923828
MSE: 886.8939208984375, KLD: 3.9098308086395264


 50%|█████     | 5/10 [00:34<00:34,  6.84s/it]

MSE: 885.1023559570312, KLD: 4.527029037475586
Train [Loss, PSNR, SSIM]: 887.8703 11.4480 0.0000 || Test [Loss, PSNR, SSIM]: 889.6294 11.4249 0.0000
MSE: 892.724365234375, KLD: 4.002034664154053
MSE: 865.71044921875, KLD: 4.957139492034912
MSE: 880.1057739257812, KLD: 10.88840103149414


 60%|██████    | 6/10 [00:41<00:27,  6.91s/it]

MSE: 847.966796875, KLD: 6.324524879455566
Train [Loss, PSNR, SSIM]: 886.1294 11.4531 0.0000 || Test [Loss, PSNR, SSIM]: 854.2913 11.6110 0.0000
MSE: 851.4790649414062, KLD: 5.516298770904541
MSE: 860.6264038085938, KLD: 4.22671365737915
MSE: 869.4114990234375, KLD: 4.170715808868408


 70%|███████   | 7/10 [00:48<00:20,  6.93s/it]

MSE: 845.3602905273438, KLD: 4.943011283874512
Train [Loss, PSNR, SSIM]: 865.1436 11.5576 0.0000 || Test [Loss, PSNR, SSIM]: 850.3033 11.6244 0.0000
MSE: 856.2481689453125, KLD: 4.624711036682129
MSE: 847.17822265625, KLD: 6.9860029220581055
MSE: 837.6652221679688, KLD: 7.114307403564453


 80%|████████  | 8/10 [00:55<00:13,  6.92s/it]

MSE: 834.6582641601562, KLD: 5.8092217445373535
Train [Loss, PSNR, SSIM]: 853.2722 11.6048 0.0000 || Test [Loss, PSNR, SSIM]: 840.4675 11.6797 0.0000
MSE: 831.2393188476562, KLD: 5.334207057952881
MSE: 846.9814453125, KLD: 5.24751091003418
MSE: 849.353271484375, KLD: 5.784879207611084


 90%|█████████ | 9/10 [01:02<00:06,  6.91s/it]

MSE: 818.134765625, KLD: 8.365280151367188
Train [Loss, PSNR, SSIM]: 847.9802 11.6470 0.0000 || Test [Loss, PSNR, SSIM]: 826.5001 11.7666 0.0000
MSE: 819.69677734375, KLD: 8.21243667602539
MSE: 828.0662841796875, KLD: 10.26241683959961
MSE: 820.2732543945312, KLD: 9.029970169067383


100%|██████████| 10/10 [01:09<00:00,  6.93s/it]

MSE: 810.324951171875, KLD: 9.422391891479492
Train [Loss, PSNR, SSIM]: 831.8470 11.7396 0.0000 || Test [Loss, PSNR, SSIM]: 819.7473 11.8082 0.0000
resulting metrics:
train_loss
train_psnr
train_ssim
test_loss
test_psnr
test_ssim



