In [1]:
from google.colab import drive
import sys

PROJECT_ROOT = '/content/drive/MyDrive/commit_test_folder/EECE491-01-Capstone-Design'

drive.mount('/content/drive')

# Make the project's `src` package importable. The membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install piqa



In [3]:
import os
import time
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from piqa.ssim import SSIM

from src.utils.data_utils import get_dataloaders, prepare_dataset
from src.utils.channels import awgn_channel
from src.utils.viz_utils import plot_loss
from src.models.face_autoencoder import FaceAutoencoder


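The next cell prepares the Colab environment: it copies the dataset archive from Google Drive to the VM's fast local disk and extracts it there, so training reads images locally instead of through the Drive mount. If the extracted data directory already exists, the copy and extraction are skipped.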

In [4]:
# Archive stored on Google Drive, plus where it is copied to and extracted
# on the Colab VM's local disk.
DRIVE_ARCHIVE_PATH = os.path.join("/content/drive/MyDrive/datasets", "cropped_celeba.tar")
LOCAL_ARCHIVE_PATH = os.path.join("/content", "cropped_celeba.tar")
EXTRACT_PATH = os.path.join("/content", "celeba_dataset")

# The returned directory is used as the dataset root in the next cell.
# Judging by its log output, prepare_dataset skips the copy/untar when the data is already present.
LOCAL_DATA_DIR = prepare_dataset(
    DRIVE_ARCHIVE_PATH,
    LOCAL_ARCHIVE_PATH,
    EXTRACT_PATH,
)


Starting data setup...
Data directory /content/celeba_dataset/content/cropped_celeba already exists. Skipping copy/untar.
Data setup finished in 0.00 seconds.
Successfully found data at: /content/celeba_dataset/content/cropped_celeba


In [5]:
DATA_ROOT = LOCAL_DATA_DIR
BATCH_SIZE = 256
IMAGE_SIZE = 128
RANDOM_SEED = 42

# Get dataloaders (RANDOM_SEED is passed through for the train/val/test split)
train_loader, val_loader, test_loader = get_dataloaders(
    root_dir=DATA_ROOT,
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    random_seed=RANDOM_SEED
)

# Seed the remaining RNGs too. Training draws a per-batch SNR with `random.uniform`,
# and torch drives weight initialisation, loader shuffling and (presumably) the
# AWGN noise in awgn_channel — TODO confirm. Seeding happens after the split so
# the split itself is unaffected. torch.manual_seed also seeds all CUDA devices.
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


Loading dataset from: /content/celeba_dataset/content/cropped_celeba
Searching for '*.jpg' files in: /content/celeba_dataset/content/cropped_celeba
Successfully found 199509 images.
Successfully loaded 199509 total images.
Splitting dataset into:
  Train: 159607 images
  Validation: 19950 images
  Test: 19952 images

DataLoaders created successfully.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = FaceAutoencoder(latent_dim=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# SSIM is a similarity score (higher is better), so the training loss is 1 - SSIM.
criterion = SSIM(n_channels=3).to(device)  # n_channels=3: RGB images

# Training parameters
num_epochs = 50
MIN_SNR_DB = 0.0   # each training batch uses an SNR drawn uniformly from [MIN_SNR_DB, MAX_SNR_DB]
MAX_SNR_DB = 20.0
SNR_POINTS_FOR_VAL = [0.0, 5.0, 10.0, 15.0, 20.0]  # fixed SNRs averaged for model selection
NUM_VAL_POINTS = len(SNR_POINTS_FOR_VAL)

# Save directory
SAVE_DIR = "/content/drive/MyDrive/models"
MODEL_PATH = os.path.join(SAVE_DIR, "face_autoencoder_512_SSIM_Augmentation.pth")
os.makedirs(SAVE_DIR, exist_ok=True)


def to_unit_range(images):
    """Rescale images from [-1, 1] into [0, 1], the range SSIM is computed on here."""
    return (images + 1.0) / 2.0


def ssim_loss(reconstructed, target):
    """Return 1 - SSIM between two image batches given in [-1, 1]."""
    return 1.0 - criterion(to_unit_range(reconstructed), to_unit_range(target))


def train_one_epoch():
    """Run one pass over train_loader, with a random channel SNR per batch.

    Returns:
        The per-image mean training loss, or NaN if the loader yielded no images.
    """
    model.train()
    loss_sum, n_images = 0.0, 0
    for images, _ in train_loader:
        images = images.to(device)

        latent_vector = model.encode(images)
        current_snr_db = random.uniform(MIN_SNR_DB, MAX_SNR_DB)
        noisy_vector = awgn_channel(latent_vector, snr_db=current_snr_db)
        loss = ssim_loss(model.decode(noisy_vector), images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The loss is averaged over the batch (piqa's default reduction is 'mean').
        # Weighting by batch size keeps the smaller final batch from being
        # over-counted in the epoch average.
        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        n_images += batch_size
    return loss_sum / n_images if n_images else float("nan")


def evaluate(loader, snr_points):
    """Return the per-image mean SSIM loss on `loader`, averaged over `snr_points` (dB).

    Returns NaN if the loader yielded no images.
    """
    model.eval()
    loss_sum, n_images = 0.0, 0
    with torch.no_grad():
        for val_images, _ in loader:
            val_images = val_images.to(device)
            # The encoder output does not depend on the SNR, so encode once per batch
            # and re-run only the channel and the decoder for each SNR point.
            # NOTE(review): assumes awgn_channel returns a new tensor rather than
            # modifying latent_vector in place — TODO confirm.
            latent_vector = model.encode(val_images)
            snr_loss_sum = 0.0
            for fixed_snr_db in snr_points:
                noisy_vector = awgn_channel(latent_vector, snr_db=fixed_snr_db)
                snr_loss_sum += ssim_loss(model.decode(noisy_vector), val_images).item()
            batch_size = val_images.size(0)
            loss_sum += (snr_loss_sum / len(snr_points)) * batch_size
            n_images += batch_size
    return loss_sum / n_images if n_images else float("nan")


best_val_loss = float('inf')  # initial value is infinity
train_loss_history = []
val_loss_history = []

print("Start training...")
for epoch in range(num_epochs):
    avg_train_loss = train_one_epoch()
    avg_val_loss = evaluate(val_loader, SNR_POINTS_FOR_VAL)

    train_loss_history.append(avg_train_loss)
    val_loss_history.append(avg_val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

    # Save best model (criterion: validation loss averaged over SNR_POINTS_FOR_VAL)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        print(" -> New best validation loss! Saving model")
        torch.save(model.state_dict(), MODEL_PATH)

print("--- Training finished. ---")
print(f"Best validation loss achieved: {best_val_loss:.6f}")
print(f"Best model saved to {MODEL_PATH}")

Using device: cuda
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 115MB/s]


Start training...
Epoch [1/50], Train Loss: 0.543080, Val Loss: 0.563743
 -> New best validation loss! Saving model
Epoch [2/50], Train Loss: 0.450088, Val Loss: 0.422934
 -> New best validation loss! Saving model
Epoch [3/50], Train Loss: 0.345119, Val Loss: 0.382675
 -> New best validation loss! Saving model
Epoch [4/50], Train Loss: 0.305554, Val Loss: 0.352284
 -> New best validation loss! Saving model
Epoch [5/50], Train Loss: 0.279394, Val Loss: 0.338721
 -> New best validation loss! Saving model
Epoch [6/50], Train Loss: 0.263946, Val Loss: 0.330226
 -> New best validation loss! Saving model
Epoch [7/50], Train Loss: 0.252755, Val Loss: 0.316607
 -> New best validation loss! Saving model
Epoch [8/50], Train Loss: 0.243325, Val Loss: 0.310479
 -> New best validation loss! Saving model
Epoch [9/50], Train Loss: 0.236628, Val Loss: 0.334966
Epoch [10/50], Train Loss: 0.233136, Val Loss: 0.295813
 -> New best validation loss! Saving model
Epoch [11/50], Train Loss: 0.227956, Val Los

In [None]:
# Pass the number of epochs that actually completed rather than `num_epochs`.
# If the training cell is interrupted early, the histories are shorter than
# `num_epochs` and the epoch axis would not match the data.
# NOTE(review): a "NameError: name 'plt' is not defined" from this call is raised
# inside plot_loss itself, because `plt` is imported in this notebook.
# src/utils/viz_utils.py presumably needs `import matplotlib.pyplot as plt`.
completed_epochs = len(train_loss_history)
plot_loss(completed_epochs, train_loss_history, val_loss_history)

NameError: name 'plt' is not defined