In [1]:
#!/usr/bin/python3
import glob
import yaml
import torch
import logging
import os
import pandas as pd
import torch.optim as optim
from torch.utils.data import DataLoader
from discriminator import *
from generator import UNet
from gan_utils_new import *

import warnings

# Suppress specific warning related to CIE-LAB conversion
warnings.filterwarnings("ignore", message=".*negative Z values that have been clipped to zero.*")

Expected output: (4, 1, 13, 13)
Actual output: torch.Size([4, 1, 11, 11])


In [2]:
# pip install fastai==2.4
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

In [3]:
def build_res_unet(n_input=1, n_output=2, size=256, pretrained=True, device=None):
    """Build a fastai ``DynamicUnet`` whose encoder is a ResNet-18 body.

    Args:
        n_input: Number of input channels accepted by the encoder (default 1).
        n_output: Number of output channels produced by the U-Net (default 2).
        size: The U-Net is built for ``(size, size)`` inputs (default 256).
        pretrained: Forwarded to ``create_body``; whether to start the ResNet-18
            body from pretrained weights (default True, as before).
        device: Device the network is moved to. ``None`` (default) selects CUDA
            when available and CPU otherwise, matching the previous behaviour.

    Returns:
        The ``DynamicUnet`` module, already moved to ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # cut=-2 keeps every child of the ResNet except the last two
    # (torchvision's resnet18 ends with avgpool and fc), leaving the conv encoder.
    body = create_body(resnet18, pretrained=pretrained, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G

In [None]:
# Function to load configuration from YAML file
def load_config(config_path='params.yaml'):
    """Parse a YAML configuration file.

    Args:
        config_path: Location of the YAML file to read (default ``'params.yaml'``).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)

# Function to select optimizer
def get_optimizer(optimizer_config, model_params):
    """Instantiate a torch optimizer described by a config section.

    Args:
        optimizer_config: Mapping with ``type`` ("Adam" or "SGD") and ``lr``.
            "Adam" additionally needs ``beta1`` and ``beta2``; "SGD" needs ``momentum``.
        model_params: Parameters (or param groups) to optimize.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        KeyError: If a required key is missing from ``optimizer_config``.
        ValueError: If ``type`` is neither "Adam" nor "SGD".
    """
    kind = optimizer_config['type']
    learning_rate = optimizer_config['lr']

    if kind == "Adam":
        betas = (optimizer_config['beta1'], optimizer_config['beta2'])
        return torch.optim.Adam(model_params, lr=learning_rate, betas=betas)

    if kind == "SGD":
        return torch.optim.SGD(
            model_params,
            lr=learning_rate,
            momentum=optimizer_config['momentum'],
        )

    raise ValueError(f"Optimizer type '{kind}' not recognized. Please choose 'Adam' or 'SGD'.")

# Checkpoints used by main()'s optional resume modes (see load_states / load_gen_weights there).
FULL_CHECKPOINT_PATH = "/home/farrell.jo/cGAN_grey_to_color/models/training_runs/pretrained_gen_200_with_entropy_loss/model_weights/checkpoint.pth"
GEN_CHECKPOINT_PATH = "/home/farrell.jo/cGAN_grey_to_color/models/generator_train/Res_full_data_3/gen_weights/checkpoint_epoch_201.pth"


def _run_path(config):
    """Return the run output directory ``<base_dir>/<run_dir>`` from ``config['output']``."""
    return os.path.join(config['output']['base_dir'], config['output']['run_dir'])


def _select_device():
    """Return CUDA when available, otherwise CPU, reporting the choice to stdout and the log."""
    if torch.cuda.is_available():
        print("Cuda is available!")
        logging.info("CUDA is available. Using GPU.")
        return torch.device("cuda")
    logging.info("CUDA is not available. Using CPU.")
    print("Unable to connect to CUDA!!!!")
    return torch.device("cpu")


def _load_checkpoint(path, device):
    """Load a checkpoint from ``path`` onto ``device``; return ``None`` if the file does not exist."""
    try:
        # map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
        return torch.load(path, map_location=device)
    except FileNotFoundError:
        return None


def _build_plateau_scheduler(optimizer, scheduler_config):
    """Create a ``ReduceLROnPlateau`` scheduler for ``optimizer`` from one scheduler config section.

    ``scheduler_config`` must provide ``mode``, ``factor``, ``patience`` and ``verbose``.
    All options are passed by keyword: the fifth positional parameter of
    ``ReduceLROnPlateau`` is ``threshold``, so a positional ``verbose`` would
    silently be used as the improvement threshold.
    """
    return optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=scheduler_config['mode'],
        factor=scheduler_config['factor'],
        patience=scheduler_config['patience'],
        verbose=scheduler_config['verbose'],
    )


def main():
    """Train the colorization cGAN configured by ``params_2.yaml``.

    Loads the config, sets up logging and the device, builds the train/val
    loaders, creates (and optionally restores) the generator and discriminator,
    builds optimizers and LR schedulers, runs ``GANDriver``, and finally writes
    the per-epoch results CSV and a copy of the config into ``<base_dir>/<run_dir>``.
    """
    # Resume switches; with both False the models are freshly initialized.
    load_states = False       # restore G, D and both optimizers from FULL_CHECKPOINT_PATH
    load_gen_weights = False  # restore only G from GEN_CHECKPOINT_PATH; D is freshly initialized
    config = load_config("params_2.yaml")
    run_path = _run_path(config)

    # Keep the log next to the other run outputs. The directory must exist
    # because logging's FileHandler does not create it; force=True makes a
    # re-run in the same kernel switch to this run's log file instead of
    # silently keeping the previous handler.
    os.makedirs(run_path, exist_ok=True)
    logging.basicConfig(
        filename=os.path.join(run_path, "training.log"),
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        force=True,
    )

    device = _select_device()

    # Collect image paths and split them into train/val.
    coco_path = config['data']['coco_path']
    paths = glob.glob(os.path.join(coco_path, "*.jpg"))
    num_imgs = config['data']['num_imgs']
    split = config['data']['split']
    train_paths, val_paths = select_images(paths, num_imgs, split)
    logging.info(f"Training set: {len(train_paths)} images")
    logging.info(f"Validation set: {len(val_paths)} images")

    size = config['data']['image_size']
    train_ds = ColorizationDataset(size, paths=train_paths, split="train")
    val_ds = ColorizationDataset(size, paths=val_paths, split="val")

    batch_size = config['training']['batch_size']
    # Reshuffle training samples every epoch; validation order stays fixed.
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size)

    # Sanity-check one batch: 'L' has 1 channel and 'ab' has 2, both size x size.
    # The first batch holds fewer than batch_size samples if the dataset is smaller.
    data = next(iter(train_dl))
    Ls, abs_ = data['L'], data['ab']
    first_batch = min(batch_size, len(train_ds))
    assert Ls.shape == torch.Size([first_batch, 1, size, size]), f"unexpected L batch shape {tuple(Ls.shape)}"
    assert abs_.shape == torch.Size([first_batch, 2, size, size]), f"unexpected ab batch shape {tuple(abs_.shape)}"

    # NOTE(review): the notebook imports `UNet` from `generator`, but `Unet` here
    # resolves through one of the star imports -- confirm it is the intended generator.
    generator = Unet()
    discriminator = PatchDiscriminator(3)
    initializer = ModelInitializer(device, init_type=config['model']['init_type'], gain=config['model']['gain'])

    # Restore weights when requested; anything not restored is freshly initialized
    # below, so a missing checkpoint no longer leaves a model uninitialized.
    full_checkpoint = None
    generator_restored = discriminator_restored = False
    if load_states:
        full_checkpoint = _load_checkpoint(FULL_CHECKPOINT_PATH, device)
        if full_checkpoint is not None:
            generator.load_state_dict(full_checkpoint["generator_state_dict"])
            discriminator.load_state_dict(full_checkpoint["discriminator_state_dict"])
            generator_restored = discriminator_restored = True
            print("Previous Weights Loaded!!!")
        else:
            print("Error loading model weights!")
    elif load_gen_weights:
        gen_checkpoint = _load_checkpoint(GEN_CHECKPOINT_PATH, device)
        if gen_checkpoint is not None:
            generator.load_state_dict(gen_checkpoint["model_state_dict"])
            generator_restored = True
            print("Generator weights loaded successfully!")
        else:
            print("Error loading generator weights!")

    if not generator_restored:
        generator = initializer.init_model(generator)
    if not discriminator_restored:
        discriminator = initializer.init_model(discriminator)
    if not (generator_restored or discriminator_restored):
        print("Models initialized!")

    generator.to(device)
    discriminator.to(device)

    # Losses: adversarial term on discriminator logits plus an L1 content term weighted by lambda_l1.
    adversarial_loss = torch.nn.BCEWithLogitsLoss()
    content_loss = torch.nn.L1Loss()
    lambda_l1 = config['training']['lambda_l1']

    optimizer_G = get_optimizer(config['optimizer_G'], generator.parameters())
    optimizer_D = get_optimizer(config['optimizer_D'], discriminator.parameters())

    # Optimizer state is only restored together with a full checkpoint.
    if (full_checkpoint is not None
            and 'optimizer_gen_state_dict' in full_checkpoint
            and 'optimizer_disc_state_dict' in full_checkpoint):
        optimizer_G.load_state_dict(full_checkpoint['optimizer_gen_state_dict'])
        optimizer_D.load_state_dict(full_checkpoint['optimizer_disc_state_dict'])
        print("Optimizer states loaded successfully!")

    scheduler_G = _build_plateau_scheduler(optimizer_G, config['scheduler_G'])
    scheduler_D = _build_plateau_scheduler(optimizer_D, config['scheduler_D'])

    epochs = config['training']['epochs']
    show_fig = config['training']['show_fig']
    save_images = config['training']['save_images']

    driver = GANDriver(
        generator=generator,
        discriminator=discriminator,
        train_dl=train_dl,
        val_dl=val_dl,
        optimizer_G=optimizer_G,
        optimizer_D=optimizer_D,
        adversarial_loss=adversarial_loss,
        content_loss=content_loss,
        lambda_l1=lambda_l1,
        device=device,
        epochs=epochs,
        scheduler_D=scheduler_D,
        scheduler_G=scheduler_G,
        run_dir=config['output']['run_dir'],
        base_dir=config['output']['base_dir']
    )

    train_results = driver.run(show_fig=show_fig, save_images=save_images)

    # Persist per-epoch results and the exact config used for this run.
    results_df = pd.DataFrame(train_results)
    result_path = os.path.join(run_path, config['output']['training_results_csv'])
    results_df.to_csv(result_path, index=False)
    logging.info(f"Training complete. Results saved to {result_path}.")

    yaml_filepath = os.path.join(run_path, "config.yml")
    with open(yaml_filepath, 'w') as file:
        yaml.dump(config, file, default_flow_style=False)
    logging.info(f"Configuration saved to {yaml_filepath}")


# Inside a notebook kernel __name__ is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    main()

Cuda is available!




Models initialized!



Training Epoch 1/200: 100%|██████████| 125/125 [03:52<00:00,  1.86s/it, D_loss=0.345, G_loss=18.9]
Validation Epoch 1/200: 100%|██████████| 32/32 [00:47<00:00,  1.50s/it, D_loss=0.845, G_loss=18.6]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 2/200: 100%|██████████| 125/125 [03:25<00:00,  1.64s/it, D_loss=0.407, G_loss=20.5]
Validation Epoch 2/200: 100%|██████████| 32/32 [00:42<00:00,  1.31s/it, D_loss=0.812, G_loss=23.9]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 3/200: 100%|██████████| 125/125 [03:25<00:00,  1.65s/it, D_loss=0.233, G_loss=21.5]
Validation Epoch 3/200: 100%|██████████| 32/32 [00:41<00:00,  1.30s/it, D_loss=0.848, G_loss=19.9]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 4/200: 100%|██████████| 125/125 [03:26<00:00,  1.65s/it, D_loss=0.301, G_loss=21]  
Validation Epoch 4/200: 100%|██████████| 32/32 [00:41<00:00,  1.29s/it, D_loss=0.772, G_loss=19.2]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 5/200: 100%|██████████| 125/125 [03:35<00:00,  1.73s/it, D_loss=0.295, G_loss=20.8]
Validation Epoch 5/200: 100%|██████████| 32/32 [00:47<00:00,  1.48s/it, D_loss=0.733, G_loss=20.1]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 6/200: 100%|██████████| 125/125 [03:42<00:00,  1.78s/it, D_loss=0.378, G_loss=20.2]
Validation Epoch 6/200: 100%|██████████| 32/32 [00:42<00:00,  1.32s/it, D_loss=0.735, G_loss=19.9]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 7/200: 100%|██████████| 125/125 [03:27<00:00,  1.66s/it, D_loss=0.327, G_loss=20]  
Validation Epoch 7/200: 100%|██████████| 32/32 [00:42<00:00,  1.34s/it, D_loss=0.684, G_loss=21.6]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 8/200: 100%|██████████| 125/125 [03:26<00:00,  1.66s/it, D_loss=0.293, G_loss=19.3]
Validation Epoch 8/200: 100%|██████████| 32/32 [00:41<00:00,  1.30s/it, D_loss=0.694, G_loss=20.7]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 9/200: 100%|██████████| 125/125 [03:27<00:00,  1.66s/it, D_loss=0.347, G_loss=18.8]
Validation Epoch 9/200: 100%|██████████| 32/32 [00:42<00:00,  1.32s/it, D_loss=0.655, G_loss=18.3]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 10/200: 100%|██████████| 125/125 [03:31<00:00,  1.69s/it, D_loss=0.368, G_loss=18.8]
Validation Epoch 10/200: 100%|██████████| 32/32 [00:42<00:00,  1.34s/it, D_loss=0.641, G_loss=18.4]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 11/200: 100%|██████████| 125/125 [03:40<00:00,  1.76s/it, D_loss=0.386, G_loss=18.7]
Validation Epoch 11/200: 100%|██████████| 32/32 [00:44<00:00,  1.40s/it, D_loss=0.642, G_loss=18.8]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 12/200: 100%|██████████| 125/125 [03:34<00:00,  1.71s/it, D_loss=0.42, G_loss=18.4] 
Validation Epoch 12/200: 100%|██████████| 32/32 [00:44<00:00,  1.38s/it, D_loss=0.643, G_loss=19.2]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 13/200: 100%|██████████| 125/125 [03:41<00:00,  1.77s/it, D_loss=0.386, G_loss=18]  
Validation Epoch 13/200: 100%|██████████| 32/32 [00:45<00:00,  1.41s/it, D_loss=0.638, G_loss=19.5]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 14/200: 100%|██████████| 125/125 [03:47<00:00,  1.82s/it, D_loss=0.366, G_loss=17.6]
Validation Epoch 14/200: 100%|██████████| 32/32 [00:43<00:00,  1.36s/it, D_loss=0.641, G_loss=19.7]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 15/200: 100%|██████████| 125/125 [03:32<00:00,  1.70s/it, D_loss=0.37, G_loss=17.3] 
Validation Epoch 15/200: 100%|██████████| 32/32 [00:43<00:00,  1.36s/it, D_loss=0.64, G_loss=19]   


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 16/200: 100%|██████████| 125/125 [03:29<00:00,  1.67s/it, D_loss=0.382, G_loss=16.9]
Validation Epoch 16/200: 100%|██████████| 32/32 [00:40<00:00,  1.28s/it, D_loss=0.63, G_loss=19.3] 


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 17/200: 100%|██████████| 125/125 [03:28<00:00,  1.66s/it, D_loss=0.368, G_loss=16.9]
Validation Epoch 17/200: 100%|██████████| 32/32 [00:41<00:00,  1.30s/it, D_loss=0.635, G_loss=19.2]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 18/200: 100%|██████████| 125/125 [03:31<00:00,  1.69s/it, D_loss=0.365, G_loss=16.4]
Validation Epoch 18/200: 100%|██████████| 32/32 [00:41<00:00,  1.30s/it, D_loss=0.634, G_loss=18.6]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 19/200: 100%|██████████| 125/125 [03:31<00:00,  1.69s/it, D_loss=0.361, G_loss=15.9]
Validation Epoch 19/200: 100%|██████████| 32/32 [00:44<00:00,  1.40s/it, D_loss=0.634, G_loss=19.1]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 20/200: 100%|██████████| 125/125 [03:39<00:00,  1.76s/it, D_loss=0.375, G_loss=15.8]
Validation Epoch 20/200: 100%|██████████| 32/32 [00:42<00:00,  1.32s/it, D_loss=0.638, G_loss=19.3]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 21/200: 100%|██████████| 125/125 [03:33<00:00,  1.71s/it, D_loss=0.364, G_loss=15.8]
Validation Epoch 21/200: 100%|██████████| 32/32 [00:44<00:00,  1.38s/it, D_loss=0.636, G_loss=19.6]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 22/200: 100%|██████████| 125/125 [03:26<00:00,  1.65s/it, D_loss=0.352, G_loss=15.8]
Validation Epoch 22/200: 100%|██████████| 32/32 [00:42<00:00,  1.32s/it, D_loss=0.636, G_loss=19.2]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 23/200: 100%|██████████| 125/125 [03:26<00:00,  1.65s/it, D_loss=0.351, G_loss=15.8]
Validation Epoch 23/200: 100%|██████████| 32/32 [00:41<00:00,  1.30s/it, D_loss=0.636, G_loss=19.6]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 24/200: 100%|██████████| 125/125 [03:30<00:00,  1.68s/it, D_loss=0.357, G_loss=15.8]
Validation Epoch 24/200: 100%|██████████| 32/32 [00:42<00:00,  1.33s/it, D_loss=0.639, G_loss=19.1]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 25/200: 100%|██████████| 125/125 [03:29<00:00,  1.68s/it, D_loss=0.347, G_loss=16]  
Validation Epoch 25/200: 100%|██████████| 32/32 [00:41<00:00,  1.30s/it, D_loss=0.634, G_loss=19.3]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 26/200: 100%|██████████| 125/125 [03:30<00:00,  1.68s/it, D_loss=0.355, G_loss=15.6]
Validation Epoch 26/200: 100%|██████████| 32/32 [00:45<00:00,  1.41s/it, D_loss=0.632, G_loss=19.1]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 27/200: 100%|██████████| 125/125 [03:35<00:00,  1.73s/it, D_loss=0.346, G_loss=15.6]
Validation Epoch 27/200: 100%|██████████| 32/32 [00:43<00:00,  1.35s/it, D_loss=0.634, G_loss=19.2]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 28/200: 100%|██████████| 125/125 [03:40<00:00,  1.76s/it, D_loss=0.356, G_loss=15.6]
Validation Epoch 28/200: 100%|██████████| 32/32 [00:42<00:00,  1.32s/it, D_loss=0.633, G_loss=19.1]


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 29/200: 100%|██████████| 125/125 [03:25<00:00,  1.64s/it, D_loss=0.353, G_loss=15.4]
Validation Epoch 29/200: 100%|██████████| 32/32 [00:41<00:00,  1.29s/it, D_loss=0.629, G_loss=19]  


Model state saved to training_runs/base_model/model_weights/checkpoint.pth

Training complete and model weights saved.



Training Epoch 30/200:  99%|█████████▉| 124/125 [03:25<00:01,  1.62s/it, D_loss=0.504, G_loss=15.1]