<a href="https://colab.research.google.com/github/hsandaver/hsandaver/blob/main/ML_Fader_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required libraries
!pip install --quiet opencv-python-headless ipywidgets Pillow scipy pandas scikit-learn seaborn

# Enable ipywidgets extension for Google Colab
!jupyter nbextension enable --py widgetsnbextension

# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import ipywidgets as widgets
from PIL import Image
from google.colab import files
import io  # For handling file input/output
from IPython.display import display, clear_output, HTML
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import warnings
import os  # For file operations

# For enhanced visualizations
import seaborn as sns
sns.set(style="whitegrid")

warnings.filterwarnings('ignore')  # Suppress warnings for cleaner output

# ---------------------------- #
# 1. Synthetic Data Generation #
# ---------------------------- #

def create_more_accurate_synthetic_data(num_samples=1000, seed=42):
    """
    Generates synthetic data simulating the fading of different materials under various environmental conditions.

    Every random quantity is drawn once for the whole dataset (vectorised), so the
    cost is O(num_samples). Drawing a full array per row and keeping a single
    element would make it O(num_samples**2).

    Parameters:
        num_samples (int): Number of samples (rows) to generate; 0 yields an empty frame.
        seed (int): Seed passed to ``np.random.seed`` so the dataset is reproducible.
            This reseeds NumPy's *global* RNG, which later unseeded
            ``np.random`` calls in the session also draw from.

    Returns:
        pd.DataFrame: Columns ``material``, ``uv_exposure``, ``lux_hours``, ``humidity``,
        ``temperature``, ``manufacture_year``, ``time_years``, ``delta_L``, ``delta_A``
        and ``delta_B``. The three delta columns are clipped to [-100, 100].
    """
    np.random.seed(seed)

    # Material-specific fading parameters. Sensitivities are (low, high) bounds of a
    # uniform draw; fading curves are vectorised functions of time in years.
    material_fading_types = {
        'Paper': {
            'uv_sensitivity': (0.4, 1.0),
            'fading_curve': lambda t: 50 * np.exp(-0.05 * t),  # Rapid initial fading followed by slow fading
        },
        'Textiles': {
            'uv_sensitivity': (0.3, 0.8),
            'fading_curve': lambda t: 30 * np.log1p(t),  # Gradual fading after a threshold
        },
        'Albumen Prints': {
            'uv_sensitivity': (0.2, 0.7),
            'fading_curve': lambda t: 10 + 0.3 * t,  # Sensitive to UV, with pronounced yellowing
        },
        'Silver Gelatin Photographs': {
            'uv_sensitivity': (0.1, 0.5),
            'humidity_sensitivity': (0.1, 0.6),  # Humidity increases yellowing
            'fading_curve': lambda t: -20 * np.exp(-0.02 * t) + 30,  # Initial darkening, then fading
        },
    }

    # Randomly assign materials to the samples
    materials = np.random.choice(list(material_fading_types), num_samples)
    masks = {name: materials == name for name in material_fading_types}

    # Per-sample UV bounds, looked up from each sample's material, so one
    # vectorised draw respects every material's sensitivity range.
    uv_low = np.zeros(num_samples)
    uv_high = np.zeros(num_samples)
    for name, params in material_fading_types.items():
        uv_low[masks[name]], uv_high[masks[name]] = params['uv_sensitivity']

    data = pd.DataFrame({
        'material': materials,
        'uv_exposure': np.random.uniform(uv_low, uv_high, num_samples),
        'lux_hours': np.random.uniform(0, 500, num_samples),
        'humidity': np.random.uniform(0, 100, num_samples),
        'temperature': np.random.uniform(-10, 50, num_samples),
        'manufacture_year': np.random.randint(1600, 2025, num_samples),
        'time_years': np.random.uniform(0, 100, num_samples),
    })

    # Apply material-specific fading curves to generate the lightness shift
    time_years = data['time_years'].to_numpy()
    uv_exposure = data['uv_exposure'].to_numpy()
    humidity = data['humidity'].to_numpy()
    delta_L = np.zeros(num_samples)
    for name, params in material_fading_types.items():
        mask = masks[name]
        shift = params['fading_curve'](time_years[mask]) * uv_exposure[mask]
        if 'humidity_sensitivity' in params:
            # Humidity-sensitive materials are additionally scaled by the raw
            # humidity (%) and a per-sample sensitivity drawn from their range.
            sensitivity = np.random.uniform(*params['humidity_sensitivity'], int(mask.sum()))
            shift = shift * humidity[mask] * sensitivity
        delta_L[mask] = shift

    data['delta_L'] = delta_L
    data['delta_A'] = np.random.uniform(-50, 50, num_samples) * 0.1  # Smaller variation
    data['delta_B'] = np.random.uniform(-50, 50, num_samples) * 0.1  # Smaller variation

    # Clip delta values to realistic ranges
    for column in ('delta_L', 'delta_A', 'delta_B'):
        data[column] = data[column].clip(-100, 100)

    return data

# Build the synthetic training set (1000 samples) and show a short preview of it
lab_data = create_more_accurate_synthetic_data(num_samples=1000)

print("### Sample of Synthetic Data:")
display(lab_data.head(5))

# ----------------------- #
# 2. Data Preprocessing    #
# ----------------------- #

def preprocess_data(data, test_size=0.2, random_state=42):
    """
    Preprocesses the synthetic data: one-hot encodes ``material``, splits into
    training and validation sets, and standardises the features.

    The scaler is fitted on the training split only and then applied to the
    validation split. Fitting before the split would let validation statistics
    leak into the training features.

    Parameters:
        data (pd.DataFrame): The synthetic dataset.
        test_size (float): Fraction of rows held out for validation.
        random_state (int): Seed for the train/validation shuffle.

    Returns:
        tuple: ``(X_train, X_val, y_train, y_val, scaler, feature_cols)``.
        ``X_train``/``X_val`` are scaled NumPy arrays.
        ``y_train``/``y_val`` are DataFrames with ``delta_L``, ``delta_A`` and ``delta_B``.
        ``scaler`` is the fitted StandardScaler.
        ``feature_cols`` is the ordered list of feature column names.
    """
    # Drop any potential NaN values
    data = data.dropna()

    # One-hot encode the material; drop_first removes the alphabetically first
    # category, which becomes the all-zeros baseline. Float dummies keep the
    # feature matrix purely numeric.
    data_encoded = pd.get_dummies(data, columns=['material'], drop_first=True, dtype=float)

    # Features and targets
    feature_cols = ['uv_exposure', 'lux_hours', 'humidity', 'temperature', 'manufacture_year', 'time_years']
    feature_cols += [col for col in data_encoded.columns if col.startswith('material_')]
    features = data_encoded[feature_cols]
    targets = data_encoded[['delta_L', 'delta_A', 'delta_B']]

    # Split first so the scaler never sees validation rows
    X_train_raw, X_val_raw, y_train, y_val = train_test_split(
        features, targets, test_size=test_size, random_state=random_state
    )

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_raw)
    X_val = scaler.transform(X_val_raw)

    return X_train, X_val, y_train, y_val, scaler, feature_cols

# Preprocess data
X_train, X_val, y_train, y_val, scaler, feature_cols = preprocess_data(lab_data)

# ---------------------------- #
# 3. Model Training & Selection #
# ---------------------------- #

def train_model(X_train, y_train, model_type='RandomForest'):
    """
    Fits a multi-output regressor that wraps one base estimator per target column.

    Parameters:
        X_train (np.array): Training features.
        y_train (pd.DataFrame): Training targets.
        model_type (str): Base estimator to wrap: 'LinearRegression' or 'RandomForest'.

    Returns:
        MultiOutputRegressor: The fitted model.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    estimator_factories = (
        ('LinearRegression', LinearRegression),
        ('RandomForest', lambda: RandomForestRegressor(n_estimators=100, random_state=42)),
    )
    for name, make_estimator in estimator_factories:
        if model_type == name:
            break
    else:
        raise ValueError("Unsupported model type. Choose 'LinearRegression' or 'RandomForest'.")

    fitted = MultiOutputRegressor(make_estimator())
    fitted.fit(X_train, y_train)
    return fitted

# Base estimator for the fading model ('LinearRegression' or 'RandomForest')
model_type = 'RandomForest'

# Fit on the training split
model = train_model(X_train, y_train, model_type=model_type)

# ------------------------- #
# 4. Model Evaluation       #
# ------------------------- #

def evaluate_model(model, X_val, y_val):
    """
    Evaluates the regression model using per-channel RMSE, prints the scores,
    and shows them as a bar chart.

    Parameters:
        model (MultiOutputRegressor): The trained regression model.
        X_val (np.array): Validation features.
        y_val (pd.DataFrame): Validation targets (delta_L, delta_A, delta_B).

    Returns:
        np.ndarray: RMSE per target column, in ``y_val``'s column order
        (L, A, B). The scores are returned so callers can log or compare them.
    """
    predictions = model.predict(X_val)
    rmse = np.sqrt(mean_squared_error(y_val, predictions, multioutput='raw_values'))
    channels = ['L', 'A', 'B']

    print("### Calibration Model RMSE (Validation Data):")
    for channel, score in zip(channels, rmse):
        print(f"- **{channel} channel:** {score:.2f}")

    # Visualize RMSE on an explicit axes object rather than pyplot's implicit state
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(x=channels, y=rmse, palette='viridis', ax=ax)
    ax.set_title('RMSE per LAB Channel')
    ax.set_ylabel('RMSE')
    ax.set_xlabel('Channel')
    plt.show()

    return rmse

print("\n### Model Evaluation:")
validation_rmse = evaluate_model(model, X_val, y_val)

# ------------------------------------ #
# 5. Enhanced Interactive Interface    #
# ------------------------------------ #

# Define material-specific environmental factor ranges
# Per-material environmental ranges. Each value is a (min, max) tuple used to
# bound the matching slider when that material is selected.
material_env_ranges = {
    'Paper': {
        'uv_exposure': (0.4, 1.0),
        'lux_hours': (100, 500),
        'humidity': (20, 80),
        'temperature': (10, 40),
        'time_years': (0, 50)
    },
    'Textiles': {
        'uv_exposure': (0.3, 0.8),
        'lux_hours': (150, 450),
        'humidity': (30, 90),
        'temperature': (15, 35),
        'time_years': (0, 70)
    },
    'Albumen Prints': {
        'uv_exposure': (0.2, 0.7),
        'lux_hours': (80, 400),
        'humidity': (25, 85),
        'temperature': (5, 45),
        'time_years': (0, 90)
    },
    'Silver Gelatin Photographs': {
        'uv_exposure': (0.1, 0.5),
        'lux_hours': (50, 300),
        'humidity': (10, 70),
        'temperature': (0, 50),
        'time_years': (0, 100)
    }
}

# Input widgets. Widget descriptions are not Markdown, so labels are plain text
# (Markdown "**bold**" markers would be displayed literally).
material_dropdown = widgets.Dropdown(
    options=list(material_env_ranges.keys()),
    value='Paper',
    description='Material:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

# Construction-time bounds span the full synthetic-data ranges; per-material
# bounds are applied by the dropdown's observer.
uv_slider = widgets.FloatSlider(
    value=0.5, min=0.0, max=1.0, step=0.01, description='UV Exposure',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
lux_slider = widgets.FloatSlider(
    value=250, min=0, max=500, step=10, description='Lux Hours',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
humidity_slider = widgets.FloatSlider(
    value=50, min=0, max=100, step=1, description='Humidity (%)',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
temp_slider = widgets.FloatSlider(
    value=20, min=-10, max=50, step=1, description='Temperature (°C)',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
time_slider = widgets.FloatSlider(
    value=10, min=0, max=100, step=1, description='Years of Aging',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
# Upper bound 2024 matches the synthetic data's randint(1600, 2025) range
manufacture_year_slider = widgets.IntSlider(
    value=1850, min=1600, max=2024, step=1, description='Year of Manufacture',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)

# Function to update slider ranges based on selected material
def update_sliders(change):
    """
    Re-bounds the environmental sliders to the newly selected material's ranges.

    Parameters:
        change (dict): ipywidgets change notification; only ``change['new']``
            (the selected material name, a key of ``material_env_ranges``) is read.
    """
    ranges = material_env_ranges[change['new']]
    slider_keys = (
        (uv_slider, 'uv_exposure'),
        (lux_slider, 'lux_hours'),
        (humidity_slider, 'humidity'),
        (temp_slider, 'temperature'),
        (time_slider, 'time_years'),
    )
    for slider, key in slider_keys:
        low, high = ranges[key]
        # Bounded sliders reject min > max at every intermediate assignment, so
        # pick the order that never inverts the bounds: if the new minimum is
        # above the current maximum, raise the maximum first.
        if low > slider.max:
            slider.max = high
            slider.min = low
        else:
            slider.min = low
            slider.max = high
        slider.value = float(np.clip(slider.value, low, high))

# Re-bound the sliders whenever a different material is chosen
material_dropdown.observe(update_sliders, names='value')
# The observer only fires on *changes*, so apply the initially selected
# material's ranges once up front. Otherwise the sliders would start with their
# generic construction-time bounds instead of that material's ranges.
update_sliders({'new': material_dropdown.value})

# Arrange sliders in a Grid layout for better organization
slider_grid = widgets.GridspecLayout(6, 2, height='auto')
slider_grid[0, 0] = material_dropdown
slider_grid[1, 0] = uv_slider
slider_grid[2, 0] = lux_slider
slider_grid[3, 0] = humidity_slider
slider_grid[4, 0] = temp_slider
slider_grid[5, 0] = time_slider
slider_grid[1, 1] = manufacture_year_slider

# Style adjustments
slider_grid.layout.width = '100%'
slider_grid.layout.margin = '10px'

# Display sliders (HTML heading; Markdown "###" would be shown literally)
display(HTML("<h2>Environmental Factors for Fading Simulation</h2>"))
display(slider_grid)

# ------------------------------------ #
# 6. Image Upload and Processing Setup #
# ------------------------------------ #

def upload_image():
    """
    Prompts for an image upload, shows it, and converts it to OpenCV's LAB space.

    Only the first uploaded file is used.

    Returns:
        tuple: ``(rgb_array, lab_array)`` for the uploaded image, or
        ``(None, None)`` if nothing was uploaded or any error occurred.
    """
    try:
        uploads = files.upload()
        if not uploads:
            print("**No file uploaded. Please upload an image to proceed.**")
            return None, None

        first_name = next(iter(uploads))
        pil_image = Image.open(io.BytesIO(uploads[first_name])).convert('RGB')

        display(HTML("<h3>Original Image:</h3>"))
        display(pil_image)

        rgb = np.array(pil_image)
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        return rgb, cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
    except Exception as err:
        print(f"**An error occurred while uploading the image: {err}**")
        return None, None

# Ask for the source image now; both globals are None if the upload fails
print("\n### Image Upload:")
print("**Please upload an image to apply the fading effect.**")
original_image_rgb, original_image_lab = upload_image()

# ------------------------------------ #
# 7. Fading Function with Enhancements #
# ------------------------------------ #

def fade_image(model, scaler, feature_cols, material, uv_exposure, lux_hours, humidity, temperature, manufacture_year, time_years, per_pixel_variation=10, image_lab=None):
    """
    Applies the calibrated fading effect to an image based on environmental factors.

    The model predicts (delta_L, delta_A, delta_B) in CIELAB units, with L* on a
    0-100 scale. OpenCV's 8-bit LAB encoding stores L* scaled by 255/100, while
    a*/b* are only offset by 128. The lightness shift is therefore rescaled by
    255/100 before being subtracted from the L channel. The a*/b* shifts are
    added unscaled.

    Parameters:
        model (MultiOutputRegressor): Trained regression model.
        scaler (StandardScaler): Scaler used for feature scaling.
        feature_cols (list): Feature column names, in training order.
        material (str): Selected material type.
        uv_exposure (float): UV exposure value.
        lux_hours (float): Lux hours value.
        humidity (float): Humidity percentage.
        temperature (float): Temperature in °C.
        manufacture_year (int): Year of manufacture.
        time_years (float): Years of aging.
        per_pixel_variation (float): Standard deviation (CIELAB units) of the
            Gaussian noise added independently per pixel to each delta.
        image_lab (np.ndarray | None): 8-bit LAB image of shape (H, W, 3), as
            produced by ``cv2.COLOR_BGR2LAB``. Defaults to the module-level
            ``original_image_lab``.

    Returns:
        tuple: ``(faded_lab, delta_L_map, delta_A_map, delta_B_map)``.
        ``faded_lab`` is a uint8 LAB image.
        The maps are (H, W) float arrays in CIELAB units. ``delta_L_map`` is
        clipped to [0, 100]; the A/B maps are clipped to [-100, 100].

    Raises:
        ValueError: If no image is passed and no module-level image is available.
    """
    if image_lab is None:
        image_lab = original_image_lab
    if image_lab is None:
        raise ValueError("No image available: upload an image or pass image_lab.")

    # Prepare input features; the selected material's one-hot column is 1.
    input_features = {
        'uv_exposure': uv_exposure,
        'lux_hours': lux_hours,
        'humidity': humidity,
        'temperature': temperature,
        'manufacture_year': manufacture_year,
        'time_years': time_years,
        f'material_{material}': 1,
    }
    # Align to the training columns. The other materials' one-hot columns are
    # filled with 0. The baseline material's column (dropped by drop_first)
    # is not in feature_cols, so reindex discards it.
    input_df = pd.DataFrame([input_features]).reindex(columns=feature_cols, fill_value=0)

    # Scale features and predict color shifts
    input_scaled = scaler.transform(input_df)
    delta_L, delta_A, delta_B = model.predict(input_scaled)[0][:3]

    # Introduce per-pixel variation by adding Gaussian noise (drawn L, then A, then B)
    height, width = image_lab.shape[:2]
    delta_L_map = delta_L + np.random.normal(0, per_pixel_variation, (height, width))
    delta_A_map = delta_A + np.random.normal(0, per_pixel_variation, (height, width))
    delta_B_map = delta_B + np.random.normal(0, per_pixel_variation, (height, width))

    # Clip delta values to realistic ranges
    delta_L_map = np.clip(delta_L_map, 0, 100)  # Only consider positive fading in L
    delta_A_map = np.clip(delta_A_map, -100, 100)
    delta_B_map = np.clip(delta_B_map, -100, 100)

    # Apply color shifts in float (astype returns a copy), converting the L*
    # delta to OpenCV's 8-bit L scale.
    faded_lab = image_lab.astype(np.float32)
    faded_lab[:, :, 0] -= delta_L_map * (255.0 / 100.0)  # L channel
    faded_lab[:, :, 1] += delta_A_map  # A channel
    faded_lab[:, :, 2] += delta_B_map  # B channel

    # Round to nearest (a plain cast would truncate) and clamp to 8-bit range
    faded_uint8 = np.clip(np.rint(faded_lab), 0, 255).astype(np.uint8)
    return faded_uint8, delta_L_map, delta_A_map, delta_B_map

# --------------------------------------- #
# 8. Apply Fading, Heat Maps, Histograms, and Percentage Function  #
# --------------------------------------- #

output = widgets.Output()


def _show_side_by_side(original_rgb, faded_rgb, time_years):
    """Plot the original and faded RGB images next to each other."""
    fig, ax = plt.subplots(1, 2, figsize=(18, 9))

    ax[0].imshow(original_rgb)
    ax[0].set_title("Original Image")
    ax[0].axis('off')

    ax[1].imshow(faded_rgb)
    ax[1].set_title(f"Faded Image ({time_years} Years)")
    ax[1].axis('off')

    plt.tight_layout()
    plt.show()


def _show_heatmap_overlay(faded_rgb, delta_L_map, delta_A_map, delta_B_map, time_years):
    """
    Overlay the per-pixel delta maps semi-transparently on the faded image.

    ``imshow`` draws each map on exactly the same pixel grid as the image.
    ``sns.heatmap`` places cells on integer edges, half a pixel off that grid,
    and builds one mesh cell per pixel, which is slow for photographs.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 9))
    ax.imshow(faded_rgb)
    ax.imshow(delta_L_map, cmap='viridis', alpha=0.5)
    ax.imshow(delta_A_map, cmap='coolwarm', alpha=0.3)
    ax.imshow(delta_B_map, cmap='coolwarm', alpha=0.3)
    ax.set_title(f"Faded Image with Heatmaps ({time_years} Years)")
    ax.axis('off')

    plt.tight_layout()
    plt.show()


def _show_histograms(original_lab, faded_lab):
    """
    Compare original vs faded value distributions for each 8-bit LAB channel.

    Both series share the same bin edges, so the bars are directly comparable.
    The original is drawn in neutral grey so it stays distinguishable from the
    faded series.
    """
    bin_edges = np.linspace(0, 255, 51)
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    channel_specs = (('L', 'green'), ('A', 'red'), ('B', 'blue'))

    for idx, (ax, (name, color)) in enumerate(zip(axs, channel_specs)):
        ax.hist(original_lab[:, :, idx].ravel(), bins=bin_edges, alpha=0.5, label='Original', color='gray')
        ax.hist(faded_lab[:, :, idx].ravel(), bins=bin_edges, alpha=0.5, label='Faded', color=color)
        ax.set_title(f'{name} Channel Distribution')
        ax.set_xlabel(f'{name} Value')
        ax.set_ylabel('Frequency')
        ax.legend()

    plt.tight_layout()
    plt.show()


def _make_download_button(faded_rgb, time_years):
    """Save the faded image as a PNG in the working directory and return a button that downloads it."""
    filename = f"faded_image_{int(time_years)}_years.png"
    Image.fromarray(faded_rgb).save(filename, format='PNG')

    download_button = widgets.Button(
        description="📥 Download Faded Image",
        button_style='info',
        tooltip='Download the faded image as PNG',
        icon='download'
    )

    def on_download_clicked(b):
        try:
            files.download(filename)
        except FileNotFoundError:
            print("**File not found. Please try applying the fading effect again.**")

    download_button.on_click(on_download_clicked)
    return download_button


def on_apply_fading_clicked(b):
    """
    Event handler for the "Apply Fading" button.

    Reads the current widget values and applies the fading effect to the
    uploaded image. Inside the ``output`` widget it then shows:
    - the result next to the original,
    - the delta heat-map overlay,
    - per-channel histograms,
    - the average lightness loss,
    - a download button.
    Any error is reported in the output area instead of being raised.
    """
    with output:
        clear_output()
        if original_image_lab is None:
            print("**Please upload an image first.**")
            return

        material = material_dropdown.value
        uv_exposure = uv_slider.value
        lux_hours = lux_slider.value
        humidity = humidity_slider.value
        temperature = temp_slider.value
        manufacture_year = manufacture_year_slider.value
        time_years = time_slider.value

        try:
            # Apply fading with per-pixel variation
            faded_image_lab, delta_L_map, delta_A_map, delta_B_map = fade_image(
                model, scaler, feature_cols, material,
                uv_exposure, lux_hours, humidity,
                temperature, manufacture_year, time_years
            )

            # Convert faded LAB image back to RGB
            faded_image_bgr = cv2.cvtColor(faded_image_lab, cv2.COLOR_LAB2BGR)
            faded_image_rgb = cv2.cvtColor(faded_image_bgr, cv2.COLOR_BGR2RGB)

            # delta_L_map is in L* units (0-100), so its mean is the average
            # lightness loss expressed as a percentage of the full L* range.
            percentage_fading = float(np.clip(np.mean(delta_L_map), 0, 100))

            display(HTML(f"<h3>Faded Image ({time_years} Years):</h3>"))
            _show_side_by_side(original_image_rgb, faded_image_rgb, time_years)

            display(HTML("<h3>Percentage of Fading Applied:</h3>"))
            display(widgets.HTML(
                f"<h4>Average lightness loss: <b>{percentage_fading:.2f}%</b> of the full L* range.</h4>"
            ))

            display(HTML(f"<h3>Faded Image with Heatmaps ({time_years} Years):</h3>"))
            _show_heatmap_overlay(faded_image_rgb, delta_L_map, delta_A_map, delta_B_map, time_years)

            display(HTML("<h3>Histograms of Color Shifts:</h3>"))
            _show_histograms(original_image_lab, faded_image_lab)

            display(_make_download_button(faded_image_rgb, time_years))

        except Exception as e:
            # UI boundary: report instead of letting the widget callback fail silently
            print(f"**Error during fading: {e}**")

# Enhanced Apply Button with styling
apply_button = widgets.Button(
    description="🎨 Apply Fading",
    button_style='success',
    tooltip='Click to apply fading effect',
    icon='paint-brush',
    layout=widgets.Layout(width='50%')
)
apply_button.on_click(on_apply_fading_clicked)

# Display Apply Button and Output
display(widgets.HBox([apply_button]))
display(output)