In [1]:

import glob
import os

import torch
from torch.utils.data import DataLoader


### Check GPU Status

In [None]:
# Quick GPU smoke test: when CUDA is present, put a small random tensor on the
# GPU and run one elementwise op on it; otherwise just report the CPU fallback.
if not torch.cuda.is_available():
    print("CUDA is not available. Running on CPU.")
else:
    print("CUDA is available. Running a test on the GPU.")

    device = torch.device("cuda")
    x = torch.rand(3, 3).to(device)
    print(f"Tensor on GPU: \n{x}")

    y = 2 * x
    print(f"Result of operation on GPU: \n{y}")

### Import Project Modules

In [3]:
from discriminator import Discriminator
from generator import UNet
from gan_utils import *

### Locate the Training Images

In [None]:
# Folder of sample COCO JPEGs. Set the COCO_PATH environment variable to point
# elsewhere; the default expands `~` rather than hard-coding one user's home
# directory, so the notebook runs on other machines unchanged.
coco_path = os.environ.get(
    "COCO_PATH",
    os.path.expanduser("~/.fastai/data/coco_sample/train_sample"),
)

# glob returns files in arbitrary, filesystem-dependent order; sorting makes
# the list (and any selection drawn from it) reproducible across machines.
paths = sorted(glob.glob(os.path.join(coco_path, "*.jpg")))
assert paths, f"No .jpg images found under {coco_path!r}; set COCO_PATH."

# Split into train/val subsets via gan_utils.select_images.
# NOTE(review): confirm whether 1000 is the total number of selected images or
# the training-set size — the split ratio is decided inside select_images.
train_paths, val_paths = select_images(paths, 1000)
print(f"Training set: {len(train_paths)} images")
print(f"Validation set: {len(val_paths)} images")

### Set Up Datasets

In [5]:
# Image size handed to ColorizationDataset as its first argument.
size = 256

# One dataset per split, built in train -> val order from the path lists above.
# NOTE(review): `split` presumably switches train-only behaviour (e.g.
# augmentation) inside ColorizationDataset — confirm in gan_utils.
train_ds, val_ds = (
    ColorizationDataset(size, paths=split_paths, split=split_name)
    for split_paths, split_name in ((train_paths, "train"), (val_paths, "val"))
)

### Set Up DataLoaders

In [6]:
# Shuffle the training set so each epoch sees batches in a fresh order; without
# shuffle=True every epoch replays exactly the same batches. Validation order
# stays fixed so evaluation is comparable across epochs.
# NOTE(review): the params cell below defines batch_size = 16, but these
# loaders use 4 — decide which is intended.
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=4, shuffle=False)

### Check Batch Tensor Shapes

In [None]:
# Pull a single batch to eyeball the tensor shapes the models will receive.
# Presumably 'L' holds the lightness channel and 'ab' the colour channels of
# a Lab image — confirm against ColorizationDataset.
data = next(iter(train_dl))
Ls, abs_ = (data[key] for key in ('L', 'ab'))
print(*(batch.shape for batch in (Ls, abs_)))

# Number of batches per epoch for each split.
print(*map(len, (train_dl, val_dl)))

### Models, Losses, and Optimizers

In [8]:
# Fix the RNG so weight initialisation (and later DataLoader shuffling) is
# reproducible across runs.
torch.manual_seed(42)

# Generator (UNet, from generator.py) and discriminator (from discriminator.py).
generator = UNet()
discriminator = Discriminator()

# Use CUDA if available. This also defines `device` on CPU-only machines,
# where the GPU-check cell above leaves it unset.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# Loss functions. Referenced via `torch.nn` explicitly instead of relying on
# `nn` leaking in through `from gan_utils import *`.
# NOTE(review): BCELoss expects probabilities in [0, 1]; if Discriminator ends
# without a sigmoid (raw logits), switch to torch.nn.BCEWithLogitsLoss.
adversarial_loss = torch.nn.BCELoss()
l1_loss = torch.nn.L1Loss()
l2_loss = torch.nn.MSELoss()  # currently not passed to GANDriver
# NOTE(review): this is the PerceptualLoss class itself (not an instance) and
# it is not passed to GANDriver below, so it has no effect on training yet.
perceptual_loss = PerceptualLoss
perceptual_loos = perceptual_loss  # old misspelled name, kept as an alias

# Optimizers: Adam with the same hyperparameters for both networks.
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
lambda_l1 = 100  # weight GANDriver receives for the L1 term

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# training params
epochs = 20
# NOTE(review): batch_size is not used by the DataLoaders (built with
# batch_size=4 earlier) nor passed to GANDriver; rebuild the loaders if 16 is
# the intended value.
batch_size = 16


### Training

In [None]:

# Wire the networks, data, optimizers and losses into the training driver,
# grouped by role, then start training.
driver = GANDriver(
    generator=generator, discriminator=discriminator,          # networks
    train_dl=train_dl, val_dl=val_dl,                          # data
    optimizer_G=optimizer_G, optimizer_D=optimizer_D,          # optimisation
    adversarial_loss=adversarial_loss, l1_loss=l1_loss,        # objectives
    lambda_l1=lambda_l1,
    device=device, epochs=epochs,                              # runtime
)

driver.run()