In [1]:
# Libraries
import glob
import random
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader

# Modules
from discriminator import Discriminator
from generator import UNet
# NOTE(review): star import — presumably provides ModelInitializer, PerceptualLoss,
# select_images, ColorizationDataset and GANDriver; prefer explicit names.
from gan_utils import *

### Check GPU Status

In [2]:
# Choose the compute device once; every later cell reads `device`.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if cuda_ok else "cpu")
print("CUDA is available. Running a test on the GPU." if cuda_ok
      else "CUDA is not available. Running on CPU.")

CUDA is not available. Running on CPU.


### Parameters (models, losses, optimizers, data)

In [3]:
# Seed Python's and PyTorch's RNGs before the models are built so weight
# initialization (and any selection below that draws from these RNGs) is
# repeatable across Restart & Run All.
# NOTE(review): if select_images / ColorizationDataset draw from NumPy's RNG,
# seed it as well — confirm in gan_utils.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

# Models (imported from generator.py and discriminator.py)
generator = UNet()
discriminator = Discriminator()

# Initialize the weights. `device` is chosen in the "Check GPU Status" cell,
# so it is not recomputed here.
initializer = ModelInitializer(device, init_type='norm', gain=0.2)
generator = initializer.init_model(generator)
discriminator = initializer.init_model(discriminator)
generator.to(device)
discriminator.to(device)

# Loss functions
# BCEWithLogitsLoss applies the sigmoid itself, so the discriminator should emit raw logits.
adversarial_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
l2_loss = nn.MSELoss()
perceptual_loss = PerceptualLoss(layer=8)
# If PerceptualLoss wraps a network, it must sit on the same device as the
# generator output. The isinstance guard keeps this safe if it is a plain callable.
if isinstance(perceptual_loss, nn.Module):
    perceptual_loss.to(device)

# Loss weight.
# NOTE(review): GANDriver receives this as `lambda_l1`, but the content loss
# passed to it is the perceptual loss — confirm which term it scales.
lambda_l1 = 100

# Optimizers (Adam with beta1=0.5, a common choice for GAN training)
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# Training params
epochs = 10
batch_size = 4

# Data: every .jpg in the fastai COCO sample under the user's home directory.
# The listing is sorted so it does not depend on filesystem order.
coco_path = str(Path.home() / ".fastai" / "data" / "coco_sample" / "train_sample")
paths = sorted(glob.glob(coco_path + "/*.jpg"))
assert paths, f"No .jpg images found in {coco_path}"
num_imgs = 100   # number of images to draw from `paths`
split = 0.8      # fraction of the selected images used for training
size = 256       # presumably the square image side length (batches are checked as 256x256 below)



Model initialized with norm initialization
Model initialized with norm initialization




### Select Images and Split into Train/Validation

In [4]:
# Select `num_imgs` of the listed images and split them into training and
# validation sets, with `split` as the training fraction.
train_paths, val_paths = select_images(paths, num_imgs, split)
# Fail fast on an empty split, which would yield an empty DataLoader, or on
# overlap, which would leak training images into validation. len() is used
# instead of truthiness so this also works if NumPy arrays are returned.
assert len(train_paths) > 0 and len(val_paths) > 0, "train or validation split is empty"
assert not set(train_paths) & set(val_paths), "train and validation splits overlap"
print(f"Training set: {len(train_paths)} images")
print(f"Validation set: {len(val_paths)} images")

Training set: 80 images
Validation set: 20 images


### Set Up Dataset

In [5]:
# One dataset per split.
# NOTE(review): `split` presumably toggles train-only behaviour (e.g. augmentation)
# inside ColorizationDataset — confirm in gan_utils.
train_ds = ColorizationDataset(size, paths=train_paths, split="train")
val_ds = ColorizationDataset(size, paths=val_paths, split="val")

### Set Up DataLoaders

In [6]:
# Reshuffle the training set every epoch so the GAN does not see the same
# batch order each time. Keep the validation order fixed so per-epoch
# validation results are comparable.
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size, shuffle=False)

**Check Tensor Size**

In [7]:
# Sanity-check one training batch: 'L' should have 1 channel and 'ab' 2 channels,
# both at `size` x `size`. The size is checked against the configured value
# rather than a hardcoded 256.
data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
assert Ls.shape == torch.Size([batch_size, 1, size, size]), f"unexpected L batch shape {tuple(Ls.shape)}"
assert abs_.shape == torch.Size([batch_size, 2, size, size]), f"unexpected ab batch shape {tuple(abs_.shape)}"
print(Ls.shape, abs_.shape)
print(len(train_dl), len(val_dl))  # batches per epoch: train, val

torch.Size([4, 1, 256, 256]) torch.Size([4, 2, 256, 256])
20 5


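### Training

> **Known issue:** the last run of the cell below failed at the first training step with `RuntimeError: Given transposed=1, weight of size [512, 256, 2, 2], expected input[4, 1024, 16, 16] to have 512 channels, but got 1024 channels instead`. A transposed convolution built for 512 input channels is receiving a 1024-channel 16×16 feature map. This is presumably in the UNet decoder, so the channel counts in `generator.py` need fixing before training can run.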

In [8]:
# Wire models, data, optimization settings and losses into the training driver.
# NOTE(review): `lambda_l1` weights whatever GANDriver uses as its content term,
# which here is the perceptual loss rather than L1 — confirm in gan_utils.
# NOTE(review): the last run failed at the first step with a ConvTranspose2d
# channel mismatch (built for 512 input channels, received 1024 at 16x16),
# presumably in the UNet decoder — fix generator.py before re-running.
driver = GANDriver(
    # models
    generator=generator,
    discriminator=discriminator,
    # data
    train_dl=train_dl,
    val_dl=val_dl,
    # optimization
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    epochs=epochs,
    device=device,
    # losses
    adversarial_loss=adversarial_loss,
    content_loss=perceptual_loss,
    lambda_l1=lambda_l1,
)

# Start training (reports per-epoch progress).
driver.run()




Training Epoch 1/10:   0%|          | 0/20 [00:00<?, ?it/s]


RuntimeError: Given transposed=1, weight of size [512, 256, 2, 2], expected input[4, 1024, 16, 16] to have 512 channels, but got 1024 channels instead