In [None]:
# Import the necessary libraries
import torch
import random
import warnings
import numpy as np
from train import train_cgan
warnings.filterwarnings("ignore")
from dataset import ColorizationDataset
from torch.utils.data import Subset, DataLoader
from models import UNetGenerator, PatchDiscriminator

In [None]:
# Seed every random number generator the notebook relies on so runs are repeatable.
# Order matters only for readability: Python stdlib, NumPy, PyTorch (CPU), PyTorch (all CUDA devices).
SEED = 27
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

In [None]:
# Initialize the data paths
l_path = "Data/gray_scale.npy"
ab_paths = ["Data/ab/ab1.npy", "Data/ab/ab2.npy", "Data/ab/ab3.npy"]

# Load the L (lightness) channel as a memory-mapped array (read on demand from disk)
l_data = np.load(l_path, mmap_mode="r")     # Expected shape: (25000, 224, 224)

# Memory-map each ab (red–green, blue–yellow) shard, then join them along the sample axis.
# NOTE: np.concatenate copies its inputs into a brand-new in-memory ndarray, so unlike
# l_data the resulting ab_data is fully resident in RAM, not memory-mapped.
ab_shards = [np.load(path, mmap_mode="r") for path in ab_paths]
ab_data = np.concatenate(ab_shards, axis=0)   # Expected shape: (25000, 224, 224, 2)

# Fail fast if the two arrays do not describe the same number of images; the later
# split indexes both with indices drawn from l_data.shape[0].
if l_data.shape[0] != ab_data.shape[0]:
    raise ValueError(
        f"Sample count mismatch: L has {l_data.shape[0]} images, ab has {ab_data.shape[0]}"
    )

In [None]:
# TODO: implement a simple custom augmentation transform for the training set.
# Until then neither split applies any transform.
train_transform, test_transform = None, None

# Build two dataset objects over the same L/ab arrays — one per transform — so the
# train and test subsets can later be given different preprocessing.
full_train_ds, full_test_ds = (
    ColorizationDataset(l_data, ab_data, transform=split_transform)
    for split_transform in (train_transform, test_transform)
)

In [None]:
# Split the datasets into training (80%) and test (20%) sets
N = l_data.shape[0]
train_size = int(0.8 * N)

# Shuffle with a dedicated, locally seeded generator instead of the global NumPy RNG.
# This makes the cell idempotent: re-running it always yields the same split, so test
# samples can never drift into the training set between executions.
# (RandomState(SEED) replays the stream np.random.seed(SEED) starts, so on a fresh
# kernel the split should match the previous behaviour — assuming nothing else drew
# from np.random before this cell.)
split_rng = np.random.RandomState(SEED)
idxs = split_rng.permutation(N)

# First 80% of the shuffled indices train the model, the remainder is held out.
train_idxs, test_idxs = idxs[:train_size], idxs[train_size:]

# Each subset wraps the dataset object configured with the matching transform.
train_dataset = Subset(full_train_ds, train_idxs)
test_dataset  = Subset(full_test_ds,  test_idxs)

In [None]:
# Create the dataloaders for each of the two datasets
BATCH_SIZE = 32
NUM_WORKERS = 2

# Pinned host memory only speeds up host-to-CUDA copies; requesting it on machines
# without CUDA (MPS / CPU) brings no benefit, so enable it only when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()

# Arguments shared by both loaders, kept in one place so they cannot drift apart.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Training batches are reshuffled every epoch; evaluation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Determine which device to use: CUDA GPU first, then Apple MPS, else CPU.
# torch.backends.mps is missing on PyTorch builds that lack the MPS backend, so look it
# up defensively instead of letting an AttributeError abort the notebook.
mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

In [None]:
# Instantiate both networks with their default constructor arguments
# (generator first, then discriminator).
generator, discriminator = UNetGenerator(), PatchDiscriminator()

# Hand the models and the training loader to the project's train_cgan routine,
# targeting the selected device, with epochs=1.
train_cgan(
    generator,
    discriminator,
    train_loader,
    device=device,
    epochs=1,
)

In [None]:
# Print the generator's string representation to inspect the model
# (presumably an nn.Module, in which case this lists its layer hierarchy).
print(generator)