In [None]:
from google.colab import drive
import sys

# Project checkout on Google Drive; its `src/` package is imported in the next cell.
PROJECT_ROOT = '/content/drive/MyDrive/commit_test_folder/EECE491-01-Capstone-Design'

# Mount Google Drive so PROJECT_ROOT (and the dataset / checkpoint files) are reachable.
drive.mount('/content/drive')

# Make `src.*` importable. Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import random
import matplotlib.pyplot as plt

from src.utils.data_utils import get_dataloaders, prepare_dataset
from src.utils.viz_utils import render_tensor
from src.utils.pls_utils import *
from src.models.face_autoencoder import FaceAutoencoder

In [None]:
# Paths handed to prepare_dataset: the source archive on Drive, a local copy
# of that archive, and the directory it is extracted into.
DRIVE_ARCHIVE_PATH = "/content/drive/MyDrive/datasets/cropped_celeba.tar"
LOCAL_ARCHIVE_PATH = "/content/cropped_celeba.tar"
EXTRACT_PATH = "/content/celeba_dataset"

# Copy + untar the archive (the captured output shows this is skipped when the
# data is already extracted) and keep the directory that actually holds the images.
LOCAL_DATA_DIR = prepare_dataset(
    DRIVE_ARCHIVE_PATH,
    LOCAL_ARCHIVE_PATH,
    EXTRACT_PATH,
)


Starting data setup...
Data directory /content/celeba_dataset/content/cropped_celeba already exists. Skipping copy/untar.
Data setup finished in 0.00 seconds.
Successfully found data at: /content/celeba_dataset/content/cropped_celeba


In [None]:
# Data-loading configuration.
DATA_ROOT = LOCAL_DATA_DIR   # directory returned by prepare_dataset in the previous cell
BATCH_SIZE = 256
IMAGE_SIZE = 128             # image size passed to get_dataloaders
RANDOM_SEED = 42             # presumably fixes the train/val/test split — confirm in data_utils

# Build the train / validation / test loaders over the extracted images.
(train_loader,
 val_loader,
 test_loader) = get_dataloaders(
    root_dir=DATA_ROOT,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    random_seed=RANDOM_SEED,
)


Loading dataset from: /content/celeba_dataset/content/cropped_celeba
Searching for '*.jpg' files in: /content/celeba_dataset/content/cropped_celeba
Successfully found 199509 images.
Successfully loaded 199509 total images.
Splitting dataset into:
  Train: 159607 images
  Validation: 19950 images
  Test: 19952 images

DataLoaders created successfully.


In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_DIR = "/content/drive/MyDrive/models"
VISUALIZE_SNR_DB = 5.0
NUM_IMAGES_TO_SHOW = 4

# (latent_dim, checkpoint filename suffix, label used in the figure)
MODEL_CONFIGS = [
    (512, "_512.pth", "MSE"),
    (512, "_512_SSIM.pth", "SSIM"),
    (512, "_512_SSIM_Augmentation.pth", "SSIM Augmentation")
]

# Keys of each `reconstruction_results` entry, in figure column order,
# and the suffix shown for each in the column headers.
RECON_KINDS = ("pristine", "bob", "eve")
RECON_LABELS = {"pristine": "Pristine", "bob": "Bob", "eve": "Eve"}

# -----------------------------------------------
# Pick a reproducible set of validation images.
# -----------------------------------------------
# Fail fast: every later step needs `sample_images`, so merely printing an
# error here would only resurface later as a confusing NameError.
try:
    val_wrapper = val_loader.dataset
except NameError as e:
    raise RuntimeError(
        "val_loader is not defined. Please run the Dataloader setup cell first."
    ) from e

val_img_number = len(val_wrapper)

# Seeded local RNG so the same validation images are shown on every run;
# the global `random` state is not seeded anywhere in this notebook.
_sample_rng = random.Random(RANDOM_SEED)
IMAGE_INDICES_TO_SHOW_LIST = _sample_rng.sample(range(val_img_number), k=NUM_IMAGES_TO_SHOW)
print(f"Selected validation indices: {IMAGE_INDICES_TO_SHOW_LIST}")

image_list = []
for idx in IMAGE_INDICES_TO_SHOW_LIST:
    image_tensor, _ = val_wrapper[idx]  # items are 2-element pairs; only the image is used
    image_list.append(image_tensor)

sample_images = torch.stack(image_list).to(device)
print(f"Successfully loaded {len(sample_images)} specific images from VAL_DATASET.")

# -----------------------------------------------
# Reconstruct the samples with every model through the AN-PLS channel.
# -----------------------------------------------
reconstruction_results = {}

for latent_dim, suffix, display_name in MODEL_CONFIGS:
    model_name = f"face_autoencoder{suffix}"
    MODEL_PATH = os.path.join(SAVE_DIR, model_name)
    print(f"\n--- Processing Model: {model_name} ---")

    model = FaceAutoencoder(latent_dim=latent_dim).to(device)
    try:
        state_dict = torch.load(MODEL_PATH, map_location=device)
    except FileNotFoundError:
        # A missing checkpoint is not fatal: its columns are drawn as "N/A".
        print(f"[ERROR] Model file not found at {MODEL_PATH}. Skipping.")
        reconstruction_results[display_name] = dict.fromkeys(RECON_KINDS)
        continue
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully.")

    with torch.no_grad():
        latent_original = model.encode(sample_images)

        # The channel returns both Bob's (legitimate receiver) and Eve's
        # (eavesdropper) received latents. Other channel parameters use
        # an_pls_channel's defaults (originally noted as alpha=0.5, Nt=4, Ne=1
        # — confirm in pls_utils).
        latent_bob, latent_eve = an_pls_channel(
            latent_original,
            snr_db=VISUALIZE_SNR_DB
        )

        # Three reconstructions per model: noiseless latent, Bob's, and Eve's.
        reconstruction_results[display_name] = {
            'pristine': model.decode(latent_original),
            'bob': model.decode(latent_bob),
            'eve': model.decode(latent_eve),
        }

# -----------------------------------------------
# Plotting: one row per sample image; columns are the original followed by the
# (Pristine, Bob, Eve) reconstructions of each model.
# -----------------------------------------------
num_cols_per_image = 1 + len(MODEL_CONFIGS) * len(RECON_KINDS)
num_rows = NUM_IMAGES_TO_SHOW

# squeeze=False keeps `axes` 2-D even for a single row, so `axes[row, col]` always works.
fig, axes = plt.subplots(num_rows, num_cols_per_image,
                         figsize=(num_cols_per_image * 2.5, num_rows * 2.5),
                         squeeze=False)
fig.suptitle(f"Model Comparison (PLS-AN Channel) at SNR_DB = {VISUALIZE_SNR_DB}", fontsize=16)

# Column headers, shown only on the first row.
column_titles = ["Original"]
for _, _, display_name in MODEL_CONFIGS:
    for kind in RECON_KINDS:
        column_titles.append(f"{display_name} ({RECON_LABELS[kind]})")


def _render_or_blank(ax, image, title):
    """Draw `image` on `ax` with render_tensor; if `image` is None, blank the panel.

    A blank panel gets an "N/A" title only when `title` is not None (i.e. on the
    header row), mirroring how real panels only get a title on that row.
    """
    if image is None:
        if title is not None:
            ax.set_title("N/A", fontsize=10)
        ax.axis('off')
    else:
        render_tensor(ax, image, title=title)


for row_idx, current_val_index in enumerate(IMAGE_INDICES_TO_SHOW_LIST):
    is_header_row = row_idx == 0
    if is_header_row:
        title_to_show = f"Original (Idx: {current_val_index})"
    else:
        title_to_show = None
        # NOTE(review): if render_tensor turns the axis off, this label will be
        # hidden — confirm in viz_utils.
        axes[row_idx, 0].set_ylabel(f"Val Idx: {current_val_index}", fontsize=10, labelpad=20)

    render_tensor(axes[row_idx, 0], sample_images[row_idx], title=title_to_show)

    col_offset = 1
    for _, _, display_name in MODEL_CONFIGS:
        recons = reconstruction_results.get(display_name, dict.fromkeys(RECON_KINDS))
        for kind in RECON_KINDS:
            batch = recons[kind]
            _render_or_blank(
                axes[row_idx, col_offset],
                None if batch is None else batch[row_idx],
                column_titles[col_offset] if is_header_row else None,
            )
            col_offset += 1

plt.tight_layout(rect=[0, 0.03, 1, 0.96])
plt.show()

Output hidden; open in https://colab.research.google.com to view.