In [1]:
#!/usr/bin/python3

import glob
import yaml
import torch
import logging
import os
import pandas as pd
import torch.optim as optim
from torch.utils.data import DataLoader
from discriminator import *
from generator import UNet
from gan_utils_new import *

import warnings

# Suppress specific warning related to CIE-LAB conversion
warnings.filterwarnings("ignore", message=".*negative Z values that have been clipped to zero.*")


Expected output: (4, 1, 13, 13)
Actual output: torch.Size([4, 1, 11, 11])


In [2]:
# pip install fastai==2.4
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

In [3]:
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a U-Net generator whose encoder is a pretrained ResNet-18.

    Args:
        n_input: number of input channels the encoder accepts. The default (1)
            matches the single-channel ``L`` batches used in this notebook.
        n_output: number of channels the decoder produces. The default (2)
            matches the two-channel ``ab`` targets used in this notebook.
        size: image side length; ``(size, size)`` is passed to ``DynamicUnet``.

    Returns:
        The ``DynamicUnet`` model, moved to CUDA when available, else to CPU.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # cut=-2 removes the last two children of resnet18 (presumably the
    # pooling/classifier head) so only the convolutional trunk remains.
    encoder = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
    return DynamicUnet(encoder, n_output, (size, size)).to(target_device)

In [None]:
# Function to load configuration from YAML file
def load_config(config_path='params.yaml'):
    """Parse a YAML configuration file.

    ``yaml.safe_load`` is used, so only plain YAML data types are built.

    Args:
        config_path: path of the YAML file. The default is ``params.yaml``,
            relative to the current working directory.

    Returns:
        The parsed document (a dict for a mapping-style config).
    """
    with open(config_path, 'r') as handle:
        return yaml.safe_load(handle)

# Function to select optimizer
def get_optimizer(optimizer_config, model_params):
    """Create a torch optimizer described by a config mapping.

    Args:
        optimizer_config: mapping with ``type`` ("Adam" or "SGD") and ``lr``.
            Adam also reads ``beta1`` and ``beta2``; SGD also reads ``momentum``.
        model_params: iterable of parameters (or param groups) to optimize.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: if ``type`` is not one of the supported names.
        KeyError: if a key required by the chosen type is missing.
    """
    kind = optimizer_config['type']
    learning_rate = optimizer_config['lr']

    if kind == "Adam":
        betas = (optimizer_config['beta1'], optimizer_config['beta2'])
        return torch.optim.Adam(model_params, lr=learning_rate, betas=betas)
    if kind == "SGD":
        return torch.optim.SGD(model_params, lr=learning_rate,
                               momentum=optimizer_config['momentum'])
    raise ValueError(f"Optimizer type '{kind}' not recognized. Please choose 'Adam' or 'SGD'.")

# Main function and its private helpers
def _setup_device():
    """Return the CUDA device when available, otherwise the CPU device.

    The choice is reported both on stdout and in the training log.
    """
    if torch.cuda.is_available():
        print("Cuda is available!")
        logging.info("CUDA is available. Using GPU.")
        return torch.device("cuda")
    logging.info("CUDA is not available. Using CPU.")
    print("Unable to connect to CUDA!!!!")
    return torch.device("cpu")


def _load_or_init_models(generator, discriminator, initializer, device,
                         load_states, load_gen_weights, states_path, gen_weights_path):
    """Restore model weights from a checkpoint, or initialize them from scratch.

    Args:
        generator, discriminator: freshly constructed models.
        initializer: object whose ``init_model(model)`` returns the initialized model.
        device: passed as ``map_location`` to ``torch.load`` so a checkpoint saved
            on GPU can also be loaded on a CPU-only machine.
        load_states: restore both models from the full training checkpoint at
            ``states_path`` (keys ``generator_state_dict`` / ``discriminator_state_dict``).
        load_gen_weights: used only when ``load_states`` is false. Restores the
            generator from ``gen_weights_path`` (key ``model_state_dict``) and
            initializes the discriminator.
        states_path, gen_weights_path: checkpoint file locations.

    Returns:
        ``(generator, discriminator, states_checkpoint)``. ``states_checkpoint`` is
        the loaded full training checkpoint (used to restore optimizer state), or
        ``None`` when no such checkpoint was loaded.
    """
    if load_states:
        try:
            checkpoint = torch.load(states_path, map_location=device)
        except FileNotFoundError:
            print("Error loading model weights!")
            logging.warning(f"Checkpoint not found at {states_path}; initializing models from scratch.")
        else:
            generator.load_state_dict(checkpoint["generator_state_dict"])
            discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
            print("Previous Weights Loaded!!!")
            return generator, discriminator, checkpoint
    elif load_gen_weights:
        try:
            checkpoint = torch.load(gen_weights_path, map_location=device)
        except FileNotFoundError:
            print("Error loading generator weights!")
            logging.warning(f"Generator weights not found at {gen_weights_path}; initializing models from scratch.")
        else:
            generator.load_state_dict(checkpoint["model_state_dict"])
            discriminator = initializer.init_model(discriminator)
            print("Generator weights loaded successfully!")
            return generator, discriminator, None

    # A missing checkpoint file falls through to here as well. That way the
    # models still get the configured initialization, and no half-loaded
    # state or unbound checkpoint is used later on.
    generator = initializer.init_model(generator)
    discriminator = initializer.init_model(discriminator)
    print("Models initialized!")
    return generator, discriminator, None


def _build_plateau_scheduler(optimizer, sched_config):
    """Create a ``ReduceLROnPlateau`` scheduler from a config mapping.

    Reads the ``mode``, ``factor``, ``patience`` and ``verbose`` keys.

    Every option is passed by keyword on purpose. In current PyTorch the fifth
    positional parameter is ``threshold``, not ``verbose``. Passing ``verbose``
    positionally therefore sets the improvement threshold instead: ``True``
    becomes 1.0, which makes every epoch count as "no improvement".
    """
    return optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=sched_config['mode'],
        factor=sched_config['factor'],
        patience=sched_config['patience'],
        verbose=sched_config['verbose'],
    )


def main():
    """Train the grayscale-to-color cGAN using the settings in ``params.yaml``.

    Steps:
      1. Load the config and create ``<base_dir>/<run_dir>``.
      2. Set up the log file, save a copy of the config, and pick the device.
      3. Build the train/val loaders and sanity-check the shape of one batch.
      4. Build the generator and discriminator, then restore or initialize them.
      5. Create optimizers and ``ReduceLROnPlateau`` schedulers, and restore
         optimizer state when resuming.
      6. Run ``GANDriver`` and write the per-epoch results CSV into the run directory.

    An optional ``checkpoint`` config section may override ``load_states``,
    ``load_gen_weights``, ``states_path`` and ``gen_weights_path``. The defaults
    below reproduce the previously hard-coded behaviour.
    """
    config = load_config()

    # Every artifact of this run (log, config copy, results CSV) lives here.
    run_path = os.path.join(config['output']['base_dir'], config['output']['run_dir'])
    os.makedirs(run_path, exist_ok=True)

    # NOTE(review): basicConfig does nothing if the root logger already has
    # handlers (e.g. a second run in the same kernel); on Python 3.8+, force=True
    # would replace them.
    logging.basicConfig(filename=os.path.join(run_path, "training.log"),
                        level=logging.INFO, format='%(asctime)s %(message)s')

    # Checkpoint settings, overridable from the config.
    ckpt_cfg = config.get('checkpoint') or {}
    load_states = ckpt_cfg.get('load_states', True)
    load_gen_weights = ckpt_cfg.get('load_gen_weights', False)
    states_path = ckpt_cfg.get(
        'states_path',
        "/home/farrell.jo/cGAN_grey_to_color/models/training_runs/pretrained_gen_200_2/model_weights/checkpoint.pth")
    gen_weights_path = ckpt_cfg.get(
        'gen_weights_path',
        "/home/farrell.jo/cGAN_grey_to_color/models/generator_train/Res_full_data_1/gen_weights/checkpoint_epoch_20.pth")
    # Record the effective values so the saved config fully describes this run.
    config['checkpoint'] = {
        'load_states': load_states,
        'load_gen_weights': load_gen_weights,
        'states_path': states_path,
        'gen_weights_path': gen_weights_path,
    }

    # Save the config before training so an interrupted run still documents its settings.
    yaml_filepath = os.path.join(run_path, "config.yml")
    with open(yaml_filepath, 'w') as file:
        yaml.dump(config, file, default_flow_style=False)
    logging.info(f"Configuration saved to {yaml_filepath}")

    device = _setup_device()

    # Data: select image files, then split them into train/val.
    coco_path = config['data']['coco_path']
    paths = glob.glob(os.path.join(coco_path, "*.jpg"))
    train_paths, val_paths = select_images(paths, config['data']['num_imgs'], config['data']['split'])
    logging.info(f"Training set: {len(train_paths)} images")
    logging.info(f"Validation set: {len(val_paths)} images")

    size = config['data']['image_size']
    train_ds = ColorizationDataset(size, paths=train_paths, split="train")
    val_ds = ColorizationDataset(size, paths=val_paths, split="val")

    batch_size = config['training']['batch_size']
    # Reshuffle training batches every epoch unless the config disables it.
    train_dl = DataLoader(train_ds, batch_size=batch_size,
                          shuffle=config['training'].get('shuffle', True))
    val_dl = DataLoader(val_ds, batch_size=batch_size)

    # Sanity-check one batch: 1-channel L input and 2-channel ab target.
    batch = next(iter(train_dl))
    Ls, abs_ = batch['L'], batch['ab']
    assert Ls.shape == torch.Size([batch_size, 1, size, size]), f"unexpected L batch shape {tuple(Ls.shape)}"
    assert abs_.shape == torch.Size([batch_size, 2, size, size]), f"unexpected ab batch shape {tuple(abs_.shape)}"

    # Models: build for the configured image size, then restore or initialize.
    generator = build_res_unet(size=size)
    discriminator = PatchDiscriminator(3)
    initializer = ModelInitializer(device, init_type=config['model']['init_type'], gain=config['model']['gain'])
    generator, discriminator, states_checkpoint = _load_or_init_models(
        generator, discriminator, initializer, device,
        load_states, load_gen_weights, states_path, gen_weights_path)

    generator.to(device)
    discriminator.to(device)

    # Losses (explicit torch.nn rather than relying on a star import for `nn`).
    adversarial_loss = torch.nn.BCEWithLogitsLoss()
    content_loss = torch.nn.L1Loss()
    lambda_l1 = config['training']['lambda_l1']

    optimizer_G = get_optimizer(config['optimizer_G'], generator.parameters())
    optimizer_D = get_optimizer(config['optimizer_D'], discriminator.parameters())

    # Optimizer state only comes from a successfully loaded full training checkpoint.
    if (states_checkpoint is not None
            and 'optimizer_gen_state_dict' in states_checkpoint
            and 'optimizer_disc_state_dict' in states_checkpoint):
        optimizer_G.load_state_dict(states_checkpoint['optimizer_gen_state_dict'])
        optimizer_D.load_state_dict(states_checkpoint['optimizer_disc_state_dict'])
        print("Optimizer states loaded successfully!")

    scheduler_G = _build_plateau_scheduler(optimizer_G, config['scheduler_G'])
    scheduler_D = _build_plateau_scheduler(optimizer_D, config['scheduler_D'])

    driver = GANDriver(
        generator=generator,
        discriminator=discriminator,
        train_dl=train_dl,
        val_dl=val_dl,
        optimizer_G=optimizer_G,
        optimizer_D=optimizer_D,
        adversarial_loss=adversarial_loss,
        content_loss=content_loss,
        lambda_l1=lambda_l1,
        device=device,
        epochs=config['training']['epochs'],
        scheduler_D=scheduler_D,
        scheduler_G=scheduler_G,
        run_dir=config['output']['run_dir'],
        base_dir=config['output']['base_dir']
    )

    train_results = driver.run(show_fig=config['training']['show_fig'],
                               save_images=config['training']['save_images'])

    results_df = pd.DataFrame(train_results)
    result_path = os.path.join(run_path, config['output']['training_results_csv'])
    results_df.to_csv(result_path, index=False)
    logging.info(f"Training complete. Results saved to {result_path}.")

if __name__ == "__main__":
    main()

Cuda is available!




Previous Weights Loaded!!!
Optimizer states loaded successfully!



Training Epoch 1/200: 100%|██████████| 500/500 [12:06<00:00,  1.45s/it, D_loss=0.626, G_loss=5.21]
Validation Epoch 1/200: 100%|██████████| 125/125 [01:07<00:00,  1.86it/s, D_loss=0.705, G_loss=15.5]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 2/200: 100%|██████████| 500/500 [07:15<00:00,  1.15it/s, D_loss=0.629, G_loss=5.19]
Validation Epoch 2/200: 100%|██████████| 125/125 [00:50<00:00,  2.45it/s, D_loss=0.707, G_loss=15.5]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 3/200: 100%|██████████| 500/500 [04:52<00:00,  1.71it/s, D_loss=0.628, G_loss=5.24]
Validation Epoch 3/200: 100%|██████████| 125/125 [00:50<00:00,  2.50it/s, D_loss=0.708, G_loss=15.5]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 4/200: 100%|██████████| 500/500 [04:49<00:00,  1.72it/s, D_loss=0.627, G_loss=5.23]
Validation Epoch 4/200: 100%|██████████| 125/125 [00:49<00:00,  2.54it/s, D_loss=0.71, G_loss=15.6] 


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 5/200: 100%|██████████| 500/500 [04:46<00:00,  1.74it/s, D_loss=0.628, G_loss=5.23]
Validation Epoch 5/200: 100%|██████████| 125/125 [00:47<00:00,  2.62it/s, D_loss=0.709, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 6/200: 100%|██████████| 500/500 [04:50<00:00,  1.72it/s, D_loss=0.624, G_loss=5.23]
Validation Epoch 6/200: 100%|██████████| 125/125 [00:47<00:00,  2.65it/s, D_loss=0.711, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 7/200: 100%|██████████| 500/500 [05:16<00:00,  1.58it/s, D_loss=0.623, G_loss=5.28]
Validation Epoch 7/200: 100%|██████████| 125/125 [00:52<00:00,  2.37it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 8/200: 100%|██████████| 500/500 [04:59<00:00,  1.67it/s, D_loss=0.622, G_loss=5.24]
Validation Epoch 8/200: 100%|██████████| 125/125 [00:50<00:00,  2.49it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 9/200: 100%|██████████| 500/500 [04:51<00:00,  1.72it/s, D_loss=0.628, G_loss=5.23]
Validation Epoch 9/200: 100%|██████████| 125/125 [00:51<00:00,  2.42it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 10/200: 100%|██████████| 500/500 [04:45<00:00,  1.75it/s, D_loss=0.627, G_loss=5.24]
Validation Epoch 10/200: 100%|██████████| 125/125 [00:46<00:00,  2.71it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 11/200: 100%|██████████| 500/500 [05:24<00:00,  1.54it/s, D_loss=0.624, G_loss=5.26]
Validation Epoch 11/200: 100%|██████████| 125/125 [00:56<00:00,  2.22it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 12/200: 100%|██████████| 500/500 [05:03<00:00,  1.65it/s, D_loss=0.629, G_loss=5.23]
Validation Epoch 12/200: 100%|██████████| 125/125 [00:51<00:00,  2.41it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 13/200: 100%|██████████| 500/500 [05:16<00:00,  1.58it/s, D_loss=0.631, G_loss=5.25]
Validation Epoch 13/200: 100%|██████████| 125/125 [00:51<00:00,  2.42it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 14/200: 100%|██████████| 500/500 [05:06<00:00,  1.63it/s, D_loss=0.628, G_loss=5.23]
Validation Epoch 14/200: 100%|██████████| 125/125 [00:52<00:00,  2.38it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 15/200: 100%|██████████| 500/500 [04:59<00:00,  1.67it/s, D_loss=0.622, G_loss=5.28]
Validation Epoch 15/200: 100%|██████████| 125/125 [00:45<00:00,  2.74it/s, D_loss=0.713, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 16/200: 100%|██████████| 500/500 [05:57<00:00,  1.40it/s, D_loss=0.629, G_loss=5.24]
Validation Epoch 16/200: 100%|██████████| 125/125 [00:48<00:00,  2.56it/s, D_loss=0.711, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 17/200: 100%|██████████| 500/500 [04:47<00:00,  1.74it/s, D_loss=0.627, G_loss=5.24]
Validation Epoch 17/200: 100%|██████████| 125/125 [00:50<00:00,  2.50it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 18/200: 100%|██████████| 500/500 [04:53<00:00,  1.71it/s, D_loss=0.627, G_loss=5.26]
Validation Epoch 18/200: 100%|██████████| 125/125 [00:48<00:00,  2.59it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 19/200: 100%|██████████| 500/500 [05:16<00:00,  1.58it/s, D_loss=0.624, G_loss=5.23]
Validation Epoch 19/200: 100%|██████████| 125/125 [00:47<00:00,  2.62it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 20/200: 100%|██████████| 500/500 [04:43<00:00,  1.76it/s, D_loss=0.629, G_loss=5.25]
Validation Epoch 20/200: 100%|██████████| 125/125 [00:46<00:00,  2.68it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 21/200: 100%|██████████| 500/500 [05:15<00:00,  1.59it/s, D_loss=0.628, G_loss=5.27]
Validation Epoch 21/200: 100%|██████████| 125/125 [00:53<00:00,  2.34it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 22/200: 100%|██████████| 500/500 [05:32<00:00,  1.50it/s, D_loss=0.627, G_loss=5.22]
Validation Epoch 22/200: 100%|██████████| 125/125 [00:49<00:00,  2.55it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 23/200: 100%|██████████| 500/500 [06:31<00:00,  1.28it/s, D_loss=0.628, G_loss=5.27]
Validation Epoch 23/200: 100%|██████████| 125/125 [00:51<00:00,  2.43it/s, D_loss=0.712, G_loss=15.6]


Model state saved to training_runs/pretrained_gen_200_3/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 24/200: 100%|██████████| 500/500 [09:02<00:00,  1.09s/it, D_loss=0.63, G_loss=5.26] 
Validation Epoch 24/200:   2%|▏         | 3/125 [00:04<03:34,  1.76s/it, D_loss=0.703, G_loss=14.7]