In [None]:

##################################################### Packages ###################################################################
import os
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch.utils.data import DataLoader


from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
# Module-wide default device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


### Training class

In [None]:


class TrainViT:
    """
    Training harness for a model that predicts ``ab`` tensors from ``L`` tensors.

    Every batch is a dict with ``"L"`` and ``"ab"`` entries; the model is called
    on ``L`` and its output is scored against ``ab`` with ``self.loss``.

    Expected call order: ``set_train_and_val_paths`` -> ``set_data_loaders`` ->
    ``set_model`` -> ``set_optimizer`` -> (optionally ``load_state``) ->
    ``train_model``.
    """
    def __init__(self,
                 image_size = 224,
                 batch_size = 32,
                 epochs = 100,
                 lr = 0.0002,
                 beta1 = 0.5,
                 beta2 = 0.999,
                 weight_decay=0,
                 loss = nn.L1Loss(),
                 run = "training_run",
                 start_epoch = 0):
        """
        Store the hyper-parameters and initialise all lazily-set members to None.

        Args:
            image_size: Side length used when building the datasets.
            batch_size: Batch size handed to both DataLoaders.
            epochs: Number of epochs ``train_model`` runs, counted from ``start_epoch``.
            lr: Adam learning rate.
            beta1: First Adam beta.
            beta2: Second Adam beta.
            weight_decay: Adam weight decay.
            loss: Callable ``loss(prediction, target)`` returning a scalar tensor.
            run: Run name; used as a sub-directory for figures and checkpoints.
            start_epoch: Index of the first epoch (overwritten by ``load_state``).
        """
        self.image_size = image_size
        self.batch_size = batch_size
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.weight_decay = weight_decay
        self.loss = loss
        self.run = run
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Populated by set_model / set_optimizer / set_scheduler.
        self.model = None
        self.optimizer = None
        self.scheduler = None

        self.start_epoch = start_epoch
        self.epochs = epochs
        self.avg_loss = 0

        # Populated by set_data_loaders.
        self.train_ds = None
        self.val_ds = None
        self.train_dl = None
        self.val_dl = None

        # One average loss per completed epoch.
        self.train_loss_generator = []
        self.val_loss_generator = []

        # Populated by set_train_and_val_paths.
        self.val_paths = None
        self.train_paths = None

    def set_train_and_val_paths(self, data_dir:str, num_images:int) -> None:
        """
        Populate ``self.train_paths`` and ``self.val_paths``.

        Both arguments are forwarded unchanged to ``select_images`` (defined
        outside this class), which must return a ``(train, val)`` pair.
        """
        self.train_paths, self.val_paths = select_images(data_dir, num_images)

    def set_model(self, model: nn.Module = None) -> None:
        """
        Register the network to train and move it to ``self.device``.

        Calling this without a model leaves ``self.model`` unchanged; no default
        architecture is constructed here.
        """
        if model is None:
            return
        # nn.Module.to() moves parameters in place and returns the module itself.
        self.model = model.to(self.device)

    def load_state(self, path_to_checkpoint:str) -> None:
        """
        Restore model weights, optimizer state, epoch and loss from a checkpoint.

        The checkpoint must be a dict with the keys written by
        ``save_model_state``. A missing file is reported and otherwise ignored.

        Raises:
            RuntimeError: If the model or optimizer has not been set yet.
        """
        if self.model is None or self.optimizer is None:
            raise RuntimeError("Call set_model() and set_optimizer() before load_state().")

        try:
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
            checkpoint = torch.load(path_to_checkpoint, map_location=self.device)
        except FileNotFoundError:
            print("Error loading generator weights!")
            return

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.start_epoch = checkpoint['epoch']
        self.avg_loss = checkpoint['loss']
        print("Model state loaded successfully!")

    def set_data_loaders(self, perform_checks:bool = True) -> None:
        """
        Build the train/val datasets and DataLoaders from the stored paths.

        Args:
            perform_checks: When True, fetch one training batch and verify the
                channel and spatial dimensions of its ``L`` and ``ab`` tensors.
        """
        # NOTE(review): IsdDataSet's first positional argument is assumed to be the
        # image side length (the original shape check compared H and W against the
        # value passed here) -- confirm against the IsdDataSet definition.
        self.train_ds = IsdDataSet(self.image_size, paths = self.train_paths, split = "train")
        self.val_ds = IsdDataSet(self.image_size, paths = self.val_paths, split = "val")
        self.train_dl = DataLoader(self.train_ds, batch_size = self.batch_size)
        self.val_dl = DataLoader(self.val_ds, batch_size = self.batch_size)

        if perform_checks:
            data = next(iter(self.train_dl))
            Ls, abs_ = data['L'], data['ab']
            side = self.image_size
            # Only the non-batch dimensions are checked: the first batch is smaller
            # than batch_size whenever the dataset holds fewer samples than that.
            assert tuple(Ls.shape[1:]) == (1, side, side)
            assert tuple(abs_.shape[1:]) == (2, side, side)
            assert Ls.shape[0] == abs_.shape[0]
            print(Ls.shape, abs_.shape)
            print(len(self.train_dl), len(self.val_dl))

    def train_loop(self, epoch) -> None:
        """
        Run one optimisation pass over ``self.train_dl``.

        Appends the epoch's mean batch loss to ``self.train_loss_generator`` and,
        if a scheduler has been set, steps it with that mean.

        Raises:
            RuntimeError: If the training loader yields no batches.
        """
        running_loss = 0
        num_batches = 0

        self.model.train()
        pbar = tqdm(self.train_dl, desc=f"Training Epoch {epoch}/{self.start_epoch + self.epochs}")
        for data in pbar:
            l_chan = data["L"].to(self.device)
            target_ab = data["ab"].to(self.device)

            self.optimizer.zero_grad()
            generated_abs = self.model(l_chan)

            batch_loss = self.loss(generated_abs, target_ab)
            batch_loss.backward()
            self.optimizer.step()

            loss_value = batch_loss.item()
            running_loss += loss_value
            num_batches += 1

            # Show the latest batch loss on the progress bar.
            pbar.set_postfix(G_loss=loss_value)

        if num_batches == 0:
            raise RuntimeError("Training DataLoader produced no batches.")

        avg_train_loss = running_loss / num_batches
        self.train_loss_generator.append(avg_train_loss)
        print(f"The average loss for epoch: {epoch} - {avg_train_loss}")

        # The scheduler is optional; it is stepped with the epoch's mean loss.
        if self.scheduler is not None:
            self.scheduler.step(avg_train_loss)

    def val_loop(self, epoch) -> None:
        """
        Evaluate the model on ``self.val_dl`` without gradient tracking.

        Appends the mean batch loss to ``self.val_loss_generator``. On epochs
        divisible by 10 the last validation batch is passed to ``self.plot_batch``
        when such a method exists (it is not defined in this class).

        Raises:
            RuntimeError: If the validation loader yields no batches.
        """
        # Directory for validation images; created once per call.
        image_save_dir = f"{str(Path.cwd())}/training_runs/{self.run}/val_images/"
        os.makedirs(image_save_dir, exist_ok=True)
        image_save_path = image_save_dir + f"epoch_{epoch}.png"

        running_loss = 0
        num_batches = 0
        last_batch = None

        with torch.no_grad():
            self.model.eval()

            pbar = tqdm(self.val_dl, desc=f"Validation Epoch {epoch}/{self.start_epoch + self.epochs}")
            for data in pbar:
                l_chan = data["L"].to(self.device)
                target_ab = data["ab"].to(self.device)

                generated_abs = self.model(l_chan)
                batch_loss = self.loss(generated_abs, target_ab)

                loss_value = batch_loss.item()
                running_loss += loss_value
                num_batches += 1
                pbar.set_postfix(G_loss=loss_value)

                last_batch = (l_chan, generated_abs, target_ab)

            # Every batch used to be written to the same file, so only the last
            # one survived; plotting it once yields the same image.
            plot_batch = getattr(self, "plot_batch", None)
            if epoch % 10 == 0 and last_batch is not None and callable(plot_batch):
                plot_batch(*last_batch, show = False, save_path = image_save_path)

        if num_batches == 0:
            raise RuntimeError("Validation DataLoader produced no batches.")

        avg_val_loss = running_loss / num_batches
        self.val_loss_generator.append(avg_val_loss)
        print(f"Avg Validation Loss: {avg_val_loss}")

    def plot_losses(self, epoch) -> None:
        """
        Save a train/val loss-versus-epoch figure for the run.

        The figure is written to
        ``<cwd>/training_runs/<run>/loss_figs/epoch_<epoch>.png``.
        """
        figs_save_dir = f"{str(Path.cwd())}/training_runs/{self.run}/loss_figs/"
        os.makedirs(figs_save_dir, exist_ok=True)
        figs_save_path = figs_save_dir + f"epoch_{epoch}.png"

        # A dedicated figure keeps curves from earlier calls out of this plot.
        fig, ax = plt.subplots()
        curves = (
            (self.train_loss_generator, "b", "Train Loss"),
            (self.val_loss_generator, "r", "Val Loss"),
        )
        for losses, colour, label in curves:
            # Align each curve so that its last point sits on `epoch`; this matches
            # range(start_epoch, epoch + 1) when one loss was logged per epoch.
            xs = range(epoch + 1 - len(losses), epoch + 1)
            ax.plot(xs, losses, c = colour, label = label)
        ax.legend()
        fig.tight_layout()
        fig.savefig(figs_save_path)
        plt.close(fig)

    def save_model_state(self, epoch:int) -> None:
        """
        Write a checkpoint with the epoch, model/optimizer state and latest train loss.

        The dict keys match what ``load_state`` reads back.
        """
        # NOTE(review): absolute, user-specific location, unlike the cwd-relative
        # paths used for figures and validation images -- confirm it is intended.
        state_save_dir = f"/home/farrell.jo/cGAN_grey_to_color/models/generator_train/{self.run}/gen_weights/"

        os.makedirs(state_save_dir, exist_ok=True)
        state_save_path = os.path.join(state_save_dir, f'checkpoint_epoch_{epoch}.pth')

        # Fall back to the stored average when no epoch has been trained yet.
        latest_loss = self.train_loss_generator[-1] if self.train_loss_generator else self.avg_loss

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': latest_loss,
        }, state_save_path)
        print(f"Model state saved to: {state_save_path}")

    def train_model(self) -> None:
        """
        Run ``self.epochs`` train/validation epochs starting at ``self.start_epoch``.

        Loss figures and checkpoints are written on epochs where
        ``epoch % 10 == 1``.
        """
        for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
            self.train_loop(epoch)
            self.val_loop(epoch)
            # NOTE(review): this fires on epochs 1, 11, 21, ... whereas validation
            # images use epoch % 10 == 0 -- confirm the offset is intended.
            if epoch % 10 == 1:
                self.plot_losses(epoch)
                self.save_model_state(epoch)

    def set_optimizer(self, model_params) -> None:
        """
        Create the Adam optimizer over ``model_params`` using the stored
        learning rate, betas and weight decay.
        """
        self.optimizer = torch.optim.Adam(
            model_params,
            lr=self.lr,
            betas=(self.beta1, self.beta2),
            weight_decay=self.weight_decay,
        )

    def set_scheduler(self) -> None:
        """
        Placeholder for scheduler setup; not implemented.

        Training works without a scheduler. One assigned to ``self.scheduler``
        must accept the epoch's mean training loss in ``step()``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

# Smoke message emitted every time this cell/module is executed.
print("This works!!")
if __name__ == "__main__":
    # Nothing to run when the file is executed as a script.
    ...


### Param YAML 

In [None]:
# Configuration for training
checkpoint_path: null  # Path to checkpoint file if loading a previous state
load_previous_state: false  # Whether to load a previous state
data_dir: "path/to/data"  # Directory containing .tif images
num_images: 10000  # Number of images to use for training
size: 256  # Image size
batch_size: 32  # Batch size for training
epochs: 101  # Number of epochs for training
lr: 0.0002  # Learning rate
beta1: 0.5  # Beta1 for Adam optimizer
beta2: 0.999  # Beta2 for Adam optimizer
run: "test_run_01"  # Run name for logging and saving checkpoints
start_epoch: 0  # Starting epoch

pretrained:
  checkpoint_path: null  # Path to checkpoint file if loading a previous state
  load_model_state: false  # Whether to load the model state
  load_optim_states: false  # Whether to load optimizer states

### Train driver

In [None]:
import glob
import yaml

def load_optimizer_states(optimizer, config):
    """
    Optionally restore an optimizer's state from the configured checkpoint.

    The checkpoint file is only read when ``config['pretrained']['load_optim_states']``
    is truthy, so a stale ``checkpoint_path`` is harmless while loading is off.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer to restore in place.
        config (dict): Configuration with a ``'pretrained'`` mapping that holds
            ``'checkpoint_path'`` and ``'load_optim_states'``.

    Returns:
        The ``optimizer`` that was passed in, restored if loading was requested
        and succeeded, otherwise untouched.

    Raises:
        FileNotFoundError: If loading is requested and the checkpoint file is missing.
    """
    pretrained = config['pretrained']
    checkpoint_path = pretrained['checkpoint_path']

    if not pretrained['load_optim_states']:
        return optimizer

    if not checkpoint_path:
        print("Optimizer states requested but no checkpoint path is configured.")
        return optimizer

    # Load onto the CPU so GPU-saved checkpoints open on CPU-only hosts;
    # Optimizer.load_state_dict moves the state to each parameter's device.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    try:
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print("Optimizer state dictionaries not found in checkpoint.")
    except (ValueError, KeyError, RuntimeError, TypeError) as e:
        # Best effort: a mismatched or malformed state leaves the optimizer as it was.
        print(f"Error loading optimizer states: {e}")

    return optimizer


# Load parameters from YAML
with open("config.yaml", "r", encoding="utf-8") as file:
    params = yaml.safe_load(file)

# Extract parameters
checkpoint_path = params.get("checkpoint_path")
load_previous_state = params.get("load_previous_state", False)
data_dir = params["data_dir"]
paths = glob.glob(data_dir + "/*.tif")
num_images = params["num_images"]
size = params["size"]
batch_size = params["batch_size"]
epochs = params["epochs"]
lr = params["lr"]
beta1 = params["beta1"]
beta2 = params["beta2"]
l1_loss = nn.L1Loss()
run = params["run"]
start_epoch = params["start_epoch"]

# Select model
model = None  # Define or load your model here if needed
if model is None:
    # Fail with an actionable message instead of an AttributeError on None.
    raise RuntimeError("No model defined: assign an nn.Module to `model` before training.")
model_params = model.parameters()

# Train model
# Keyword arguments are required here: TrainViT takes `weight_decay` between
# `beta2` and `loss`, so a positional call shifts loss/run/start_epoch by one.
vit_trainer = TrainViT(
    image_size=size,
    batch_size=batch_size,
    epochs=epochs,
    lr=lr,
    beta1=beta1,
    beta2=beta2,
    loss=l1_loss,
    run=run,
    start_epoch=start_epoch,
)
# NOTE(review): the globbed list of paths is passed where the method's signature
# names a `data_dir: str` -- confirm which one `select_images` expects.
vit_trainer.set_train_and_val_paths(paths, num_images)
vit_trainer.set_data_loaders()
vit_trainer.set_model(model=model)
vit_trainer.set_optimizer(model_params=model_params)
if load_previous_state:
    vit_trainer.load_state(checkpoint_path)
vit_trainer.train_model()

FileNotFoundError: [Errno 2] No such file or directory: 'config.yaml'