In [None]:
import os
import sys
import glob
import time
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from skimage.metrics import structural_similarity as ssim
from typing import Tuple, Dict, Optional, List

from IPPy.utilities.data import TrainDataset
from IPPy.utilities.test import *
from IPPy.nn.models import UNet, ResUNet
from IPPy.utilities._utilities import *
from IPPy.utilities.metrics import *
from IPPy.nn.train import *


## Configuration

In [None]:

# Data paths
# INPUT_* / OUTPUT_* are directory pairs for training and for evaluation;
# test_model() passes the *_TEST pair to TrainDataset as in_path / out_path.
INPUT_PATH_TRAIN = 'data/input/Canova4'
OUTPUT_PATH_TRAIN = 'data/input/Canova2'
INPUT_PATH_TEST = 'data/target/Canova4_target_67932'
OUTPUT_PATH_TEST = 'data/target/Canova2_target_67932'

# Model parameters
# DATA_SHAPE is forwarded to TrainDataset(data_shape=...);
# presumably the side length the images are resized to — TODO confirm.
DATA_SHAPE = 512
# Forwarded to get_model(); presumably the channel width of each network level — TODO confirm.
MIDDLE_CHANNELS = [64, 128, 256, 512, 1024]
# Forwarded to get_model() as the name of the final activation.
FINAL_ACTIVATION = 'sigmoid'

# Training parameters
# Both values are also embedded in the results folder name built by get_output_dir().
N_EPOCHS = 100
BATCH_SIZE = 6

# Model and loss function options
# main() trains and tests every (model, loss) combination of these two lists.
MODELS = ['UNet', 'ResUNet']
LOSS_FUNCTIONS = ['MSE', 'L1']

# Device configuration
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Output paths
# Root directory under which get_output_dir() places the per-run folders.
RESULTS_BASE_DIR = 'results'

def get_output_dir(model_name, loss_name, *, base_dir=None, n_epochs=None, batch_size=None):
    """Generate output directory path for a specific model and loss.

    This is a plain module-level function. It must not be wrapped in
    ``@staticmethod``: outside a class body that produces a ``staticmethod``
    object, which is not callable on Python < 3.10.

    Args:
        model_name: Name of the model; becomes a folder level of the path.
        loss_name: Name of the loss function; embedded in the last folder name.
        base_dir: Root results directory. Defaults to ``RESULTS_BASE_DIR``.
        n_epochs: Epoch count embedded in the last folder name.
            Defaults to ``N_EPOCHS``.
        batch_size: Batch size embedded in the last folder name.
            Defaults to ``BATCH_SIZE``.

    Returns:
        The joined path
        ``<base_dir>/output_reconstruction_test_67932/<model_name>/``
        ``<n_epochs>epochs_<batch_size>batch_<loss_name>Loss``.
    """
    # Resolve defaults at call time (not at definition time) so that re-running
    # the configuration cell with new values is picked up without redefining
    # this function.
    if base_dir is None:
        base_dir = RESULTS_BASE_DIR
    if n_epochs is None:
        n_epochs = N_EPOCHS
    if batch_size is None:
        batch_size = BATCH_SIZE

    return os.path.join(
        base_dir,
        f'output_reconstruction_test_67932/{model_name}',
        f'{n_epochs}epochs_{batch_size}batch_{loss_name}Loss'
    )


## Training Setup

In [None]:
def train_model(
    model_name: str,
    loss_fn_name: str,
) -> None:
    """
    Train a model with specified configuration.

    Creates the results directory for this (model, loss) combination, trains
    the model on the image pairs found in INPUT_PATH_TRAIN (network input) and
    OUTPUT_PATH_TRAIN (target), and saves the trained ``state_dict`` as
    ``model_weights.pth`` inside that directory. The loss-curve and loss-values
    file paths are handed to ``train()``.

    Args:
        model_name: Name of the model to train
        loss_fn_name: Name of the loss function to use

    Returns:
        None. Results are written to the directory given by
        ``get_output_dir(model_name, loss_fn_name)``.
    """
    print("\n" + "="*80)
    print(f"TRAINING {model_name} WITH {loss_fn_name} LOSS")
    print("="*80)

    # Setup paths
    results_dir = get_output_dir(model_name, loss_fn_name)
    create_directory(results_dir)

    weights_path = os.path.join(results_dir, 'model_weights.pth')
    graphic_path = os.path.join(results_dir, 'loss_curve.png')
    loss_file = os.path.join(results_dir, 'loss_values.csv')

    # Load dataset
    print("\nLoading training data...")
    print(f"  Input path: {INPUT_PATH_TRAIN}")
    print(f"  Output path: {OUTPUT_PATH_TRAIN}")
    print(f"  Files in input: {count_files(INPUT_PATH_TRAIN)}")
    # Count the target folder here (previously the input folder was counted twice).
    print(f"  Files in output: {count_files(OUTPUT_PATH_TRAIN)}")

    # The target images must come from OUTPUT_PATH_TRAIN; pairing the input
    # folder with itself would train the network to reproduce its own input.
    train_data = TrainDataset(
        in_path=INPUT_PATH_TRAIN,
        out_path=OUTPUT_PATH_TRAIN,
        data_shape=DATA_SHAPE
    )
    print(f"  Dataset size: {len(train_data)}")

    # Initialize model
    print(f"\nInitializing {model_name} model...")
    model = get_model(model_name, MIDDLE_CHANNELS, FINAL_ACTIVATION)

    # Get loss function
    loss_fn = get_loss_function(loss_fn_name)
    print(f"Loss function: {loss_fn_name}")

    # Train
    print("\nStarting training...")
    print(f"  Epochs: {N_EPOCHS}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Device: {DEVICE}")

    start_time = time.time()
    train(
        model=model,
        training_data=train_data,
        loss_fn=loss_fn,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        device=DEVICE,
        graphic_path=graphic_path,
        loss_file=loss_file
    )

    training_time = time.time() - start_time
    print(f"\nTraining completed in {training_time/60:.2f} minutes")

    # Save model
    torch.save(model.state_dict(), weights_path)
    print(f"Model weights saved to: {weights_path}")

    # Release cached GPU memory before the next (model, loss) run starts.
    clear_gpu_cache()

## Testing and Visualization

In [None]:
def test_model(
    model_name: str,
    loss_fn_name: str,
) -> Dict[str, float]:
    """
    Test a trained model and save results.

    Loads the weights saved under ``get_output_dir(model_name, loss_fn_name)``,
    runs the model on every image pair of the test set, and writes:

    * ``single_images/predicted_<idx>.png`` - the prediction for each image,
    * ``comparison/comparison_<idx>.png`` - a comparison plot for each image,
    * ``metrics.txt`` - the metrics averaged over the whole test set.

    If the weights file is missing, a warning is printed and evaluation
    continues with a randomly initialized model.

    Args:
        model_name: Name of the model to test
        loss_fn_name: Name of the loss function used during training

    Returns:
        Dictionary containing average metrics, keyed
        ``<re|psnr|ssim|rmse>_<input_target|pred_target>``.
        An empty dictionary is returned when the test set contains no images.
    """
    print("\n" + "="*80)
    print(f"TESTING {model_name} WITH {loss_fn_name} LOSS")
    print("="*80)

    # Setup paths
    results_dir = get_output_dir(model_name, loss_fn_name)
    weights_path = os.path.join(results_dir, 'model_weights.pth')
    output_dir = os.path.join(results_dir, 'single_images')
    comparison_dir = os.path.join(results_dir, 'comparison')
    metrics_file = os.path.join(results_dir, 'metrics.txt')

    create_directory(output_dir)
    create_directory(comparison_dir)

    # Load model
    # Build the network with the same arguments train_model() uses, so the
    # architecture matches the saved state_dict and the final activation is
    # the one the weights were trained with.
    print(f"\nLoading {model_name} model...")
    model = get_model(model_name, MIDDLE_CHANNELS, FINAL_ACTIVATION).to(DEVICE)

    if os.path.exists(weights_path):
        # NOTE(review): torch.load unpickles the file; only load weights
        # produced by this pipeline (or another trusted source).
        model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
        print(f"Weights loaded from: {weights_path}")
    else:
        print(f"WARNING: Weights file not found at {weights_path}")
        print("Continuing with randomly initialized model.")

    model.eval()

    # Load test dataset
    print("\nLoading test data...")
    print(f"  Input path: {INPUT_PATH_TEST}")
    print(f"  Output path: {OUTPUT_PATH_TEST}")

    test_dataset = TrainDataset(
        in_path=INPUT_PATH_TEST,
        out_path=OUTPUT_PATH_TEST,
        data_shape=DATA_SHAPE
    )
    n_images = len(test_dataset)
    print(f"  Test dataset size: {n_images}")

    if n_images == 0:
        print("ERROR: No test images found!")
        return {}

    # batch_size=1 so each loop iteration handles exactly one image and the
    # per-image metrics can be averaged by dividing by the dataset size.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Initialize metrics accumulators
    metrics_sum = {
        're_input_target': 0, 'psnr_input_target': 0,
        'ssim_input_target': 0, 'rmse_input_target': 0,
        're_pred_target': 0, 'psnr_pred_target': 0,
        'ssim_pred_target': 0, 'rmse_pred_target': 0
    }

    # Process images
    print("\nProcessing test images...")
    with torch.no_grad():
        for idx, (x, y) in enumerate(test_loader):
            # Get prediction (moved back to the CPU for metrics and saving)
            x_device = x.to(DEVICE)
            x_pred = model(x_device).detach().cpu()

            # Calculate metrics
            metrics = calculate_metrics(x, y, x_pred)
            for key, value in metrics.items():
                metrics_sum[key] += value

            # Save individual prediction
            save_image(x_pred, os.path.join(output_dir, f'predicted_{idx}.png'))

            # Save comparison plot
            save_comparison_plot(
                x, y, x_pred,
                os.path.join(comparison_dir, f'comparison_{idx}.png'),
                idx
            )

            if (idx + 1) % 10 == 0:
                print(f"  Processed {idx + 1}/{n_images} images")

    # Calculate averages
    metrics_avg = {key: value / n_images for key, value in metrics_sum.items()}

    # Save metrics to file
    # Each section lists RE, PSNR, SSIM and RMSE for one comparison; the key in
    # metrics_avg is "<metric>_<suffix>".
    report_sections = (
        ("Input-Target Comparison", "input_target"),
        ("Prediction-Target Comparison", "pred_target"),
    )
    with open(metrics_file, 'w') as f:
        f.write(f"Testing Results for {model_name} with {loss_fn_name} Loss\n")
        f.write("="*80 + "\n\n")
        f.write(f"Number of test images: {n_images}\n\n")
        f.write("Average Metrics:\n")
        f.write("-"*80 + "\n")
        for title, suffix in report_sections:
            f.write(f"\n{title}:\n")
            for metric in ("re", "psnr", "ssim", "rmse"):
                # Pad the label to 5 characters so the values line up in a column.
                label = f"{metric.upper()}:"
                value = metrics_avg[f"{metric}_{suffix}"]
                f.write(f"  {label:<5} {value:.4f}\n")

    print(f"\nMetrics saved to: {metrics_file}")
    print("\nAverage Metrics:")
    print(f"  Prediction PSNR: {metrics_avg['psnr_pred_target']:.4f}")
    print(f"  Prediction SSIM: {metrics_avg['ssim_pred_target']:.4f}")

    return metrics_avg

In [None]:
def main():
    """Run the complete pipeline.

    Prints the active configuration, trains every (model, loss) combination,
    then tests every combination, and finally produces the loss-comparison
    plot and the metrics comparison table.
    """
    banner = "=" * 80

    print("\n" + banner)
    print("NEURAL NETWORK TRAINING AND TESTING PIPELINE")
    print(banner)

    # Show the settings this run will use
    print("\nConfiguration:")
    print(f"  Data shape: {DATA_SHAPE}")
    print(f"  Epochs: {N_EPOCHS}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Device: {DEVICE}")
    models_text = ', '.join(MODELS)
    print(f"  Models: {models_text}")
    losses_text = ', '.join(LOSS_FUNCTIONS)
    print(f"  Loss functions: {losses_text}")

    # Every combination, model-major: all losses of the first model come first
    runs = [
        (model_name, loss_fn_name)
        for model_name in MODELS
        for loss_fn_name in LOSS_FUNCTIONS
    ]

    # Training phase: all combinations are trained before any is tested
    for model_name, loss_fn_name in runs:
        train_model(model_name, loss_fn_name)

    # Testing phase
    for model_name, loss_fn_name in runs:
        test_model(model_name, loss_fn_name)

    # Summary figures and tables across all runs
    plot_loss_comparison(LOSS_FUNCTIONS, MODELS, get_output_dir)
    create_metrics_comparison_table(LOSS_FUNCTIONS, MODELS, get_output_dir)

    print("\n" + banner)
    print("PIPELINE COMPLETED SUCCESSFULLY!")
    print(banner)


# Entry point: run the full pipeline when this code is executed as the main
# program rather than imported as a module.
if __name__ == "__main__":
    main()