In [None]:
# train.py
import os
import multiprocessing
import torch
from torch.utils import data
from pix2pix_model import Generator, Discriminator
from engine import train_GAN
from helper_functions import visualize_dataset, PairedImageDatasetLoader, save_model, save_history, plot_history
from inference_functions import ssim_psnr_results, plot_result_distribution, predict_few_samples

# ----- Config -----
# Dataset folders (absolute paths on the author's machine).
TRAIN_DIR = "/Users/debanjan_5402/Desktop/MyCodes/GAN/maps/train"
VAL_DIR = "/Users/debanjan_5402/Desktop/MyCodes/GAN/maps/val"

# Output folder for checkpoints and the training history; created up front
# so the save calls at the end of the run have somewhere to write.
SAVE_DIR = "./checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

# Forwarded to PairedImageDatasetLoader (second and third positional args).
# Presumably the target image size and the x-offset at which each paired
# image is split in two -- confirm against helper_functions.
SPLIT_WIDTH = 600
RESIZE = 256

# Training schedule and Adam optimizer settings.
EPOCHS = 150
BATCH_SIZE = 1
BETAS = (0.5, 0.999)
LR = 2e-4

# Use the "spawn" start method (safe cross-platform). The __main__ guard is
# required with spawn: worker processes re-import this module and must not
# re-run the training pipeline themselves.
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    ctx = multiprocessing.get_context("spawn")

    # Device: Apple Metal (MPS) backend when available, otherwise CPU.
    DEVICE = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    print("Device:", DEVICE)

    # Workers — tune for your machine. os.cpu_count() is a starting point.
    # os.cpu_count() returns None when the core count cannot be determined;
    # fall back to 1 so num_workers is always a positive int, which the
    # multiprocessing_context / persistent_workers / prefetch_factor options
    # used below require.
    NUM_WORKERS = os.cpu_count() or 1
    print("CPU cores:", NUM_WORKERS)

    # Create datasets
    train_dataset = PairedImageDatasetLoader(TRAIN_DIR, RESIZE, SPLIT_WIDTH)
    val_dataset = PairedImageDatasetLoader(VAL_DIR, RESIZE, SPLIT_WIDTH)

    # DataLoader performance flags, shared by the train and validation loaders
    # - multiprocessing_context=ctx ensures child processes use spawn
    # - persistent_workers=True keeps workers alive between epochs (fast)
    # - prefetch_factor controls how many samples each worker preloads (default 2)
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        multiprocessing_context=ctx,
        persistent_workers=True,
        pin_memory=False,
        prefetch_factor=4,
    )
    # Only the training loader shuffles; validation keeps dataset order.
    train_loader = data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Optional quick visual check (small number, will use workers)
    visualize_dataset(train_loader, num_samples=16, cols=8, subplot_width=3, title="Training Data")
    visualize_dataset(val_loader, num_samples=16, cols=8, subplot_width=3, title="Validation Data")

    # Models, losses, optimizers
    generator_model = Generator(in_channels=3, out_channels=3, num_filter=32).to(DEVICE)
    discriminator_model = Discriminator(in_channels=3, num_filter=32).to(DEVICE)

    # NOTE(review): BCELoss expects probabilities in [0, 1]; this assumes the
    # Discriminator ends in a sigmoid -- confirm in pix2pix_model.
    loss_fn_adversarial = torch.nn.BCELoss()
    loss_fn_reconstruction = torch.nn.L1Loss()

    generator_optimizer = torch.optim.Adam(generator_model.parameters(), lr=LR, betas=BETAS)
    discriminator_optimizer = torch.optim.Adam(discriminator_model.parameters(), lr=LR, betas=BETAS)

    # Training run.
    # lambda_gen is presumably the weight of the reconstruction (L1) term in
    # the generator loss -- confirm in engine.train_GAN.
    history = train_GAN(generator=generator_model, discriminator=discriminator_model,
                        train_dataloader=train_loader, val_dataloader=val_loader,
                        loss_fn_adv=loss_fn_adversarial, loss_fn_recon=loss_fn_reconstruction,
                        optimizer_G=generator_optimizer, optimizer_D=discriminator_optimizer,
                        lambda_gen=100, epochs=EPOCHS, device=DEVICE)

    # Persist both networks (with their optimizers) and the training history.
    save_model(model=generator_model, optimizer=generator_optimizer, epoch=EPOCHS, save_path=os.path.join(SAVE_DIR, "generator_model.pth"))
    save_model(model=discriminator_model, optimizer=discriminator_optimizer, epoch=EPOCHS, save_path=os.path.join(SAVE_DIR, "discriminator_model.pth"))

    save_history(history=history, save_path=os.path.join(SAVE_DIR, "history.json"))

    plot_history(history=history)

    # Evaluate the trained generator on the validation loader and plot results.
    results = ssim_psnr_results(model=generator_model, dataloader=val_loader, device=DEVICE)

    plot_result_distribution(results=results)

    predict_few_samples(dataloader=val_loader, model=generator_model, num_samples=6, device=DEVICE, title="Few predictions")