# In [None]:
import subprocess

def setup_environment(packages=None):
    """Make sure every package in *packages* is importable, installing missing ones with pip.

    Each name is first tried with ``__import__``; when that raises ``ImportError``
    the package is installed with the pip that belongs to the *running*
    interpreter (``sys.executable -m pip``). Installation failures are reported
    and skipped rather than raised, so one bad package does not stop the rest.

    Args:
        packages: Iterable of names that are both the import name and the pip
            name. Defaults to the dependencies this notebook uses.
    """
    import importlib
    import sys

    if packages is None:
        packages = [
            'pylandstats', 'openpyxl', 'rasterio', 'folium', 'pandas',
            'numpy', 'geopandas', 'matplotlib', 'shapely', 'IPython', 'ipywidgets'
        ]

    def check_install(package):
        """Import *package*, or install it with pip if the import fails."""
        try:
            __import__(package)
            print(f'{package} already installed.')
        except ImportError:
            print(f'{package} not installed. Installing {package}...')
            try:
                # A bare 'pip' on PATH may belong to a different Python than the
                # one running this code, so install through sys.executable.
                subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
            except (subprocess.CalledProcessError, OSError) as e:
                # OSError covers a missing/unrunnable interpreter or pip module.
                print(f"Failed to install {package}: {e}")
            else:
                # Let the import system see the freshly installed package in this process.
                importlib.invalidate_caches()

    for package in packages:
        check_install(package)

# Check for and install any missing third-party packages before the imports below need them.
setup_environment()

# Import required libraries
import os
import pandas as pd
import folium
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import box
from matplotlib.patches import Patch
from matplotlib.colors import ListedColormap
from IPython.display import display
import ipywidgets as widgets
import pylandstats as pls
from google.colab import drive

def check_drive_mount():
    """Mount Google Drive at ``/content/drive/`` unless that path already exists.

    Specific to Google Colab: relies on ``google.colab.drive`` being imported at
    module level. Prints whether a mount was performed.
    """
    mount_point = '/content/drive/'
    if os.path.exists(mount_point):
        print('Drive is already mounted.')
        return
    drive.mount(mount_point)
    print('Drive is not mounted. Mounted now.')

def hex_to_rgb(hex_color):
    """Convert a hex colour code such as ``'#1a2B3c'`` or ``'#abc'`` to an RGB tuple.

    Surrounding whitespace and leading ``'#'`` characters are ignored and the
    digits are case-insensitive; 3-digit codes are expanded (``'abc'`` ->
    ``'aabbcc'``).

    Args:
        hex_color: The colour code as a string.

    Returns:
        ``(r, g, b)`` with each channel a float in ``[0.0, 1.0]``.

    Raises:
        ValueError: if *hex_color* is not a string (e.g. a NaN spreadsheet cell)
            or is not a valid 3- or 6-digit hexadecimal code.
    """
    # Non-strings (NaN from an empty legend cell, None, ...) would otherwise fail
    # with AttributeError; report them as the same ValueError callers handle.
    if not isinstance(hex_color, str):
        raise ValueError("Invalid hexadecimal color code")
    digits = hex_color.strip().lstrip('#').lower()
    if len(digits) not in (3, 6) or any(c not in '0123456789abcdef' for c in digits):
        raise ValueError("Invalid hexadecimal color code")
    if len(digits) == 3:
        digits = ''.join(c * 2 for c in digits)
    return tuple(int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))

def read_color_legend(color_file, csv_sep=';'):
    """Load the colour-legend table from an Excel (``.xlsx``) or CSV file.

    The reader is chosen from the case-insensitive file extension. Surrounding
    whitespace in string column headers is stripped, so a header typed as
    ``"Color "`` is still found as ``Color`` by the callers.

    Args:
        color_file: Path to the legend file (``str`` or path-like).
        csv_sep: Field separator used for CSV files (default ``';'``).

    Returns:
        pandas.DataFrame with one row per legend entry.

    Raises:
        ValueError: if the extension is neither ``.xlsx`` nor ``.csv``.
    """
    path = os.fspath(color_file)  # accept pathlib.Path as well as str
    lowered = path.lower()
    if lowered.endswith('.xlsx'):
        print(f"Reading Excel color legend file: {color_file}")
        legend = pd.read_excel(path)
    elif lowered.endswith('.csv'):
        print(f"Reading CSV color legend file: {color_file}")
        legend = pd.read_csv(path, sep=csv_sep)
    else:
        raise ValueError(f"Unsupported file format: {color_file}. Only .xlsx and .csv are supported.")
    return legend.rename(columns=lambda col: col.strip() if isinstance(col, str) else col)

def plot_pie_chart_by_metric(metric):
    """Plot the share of each landscape class for one class-level metric as a pie chart.

    Used as the callback of the metric dropdown. Reads the module globals
    ``metrics_df`` (a ``class_val`` column plus one column per metric),
    ``color_df`` (``Class_ID``/``Color``) and ``class_name_df``
    (``Class_ID``/``Class_Name``), which ``calculate_save_metrics_and_plot``
    fills in. The new figure is stored in the globals ``fig``/``ax`` and
    ``current_metric[0]`` is set to *metric* so ``save_current_figure`` can
    save it.

    Classes whose value is missing (NaN) are left out. Nothing is drawn (a
    message is printed instead) when no value remains, when any value is
    negative, or when the values sum to zero.

    Args:
        metric: Name of a column of ``metrics_df``; ``None`` is ignored.
    """
    global fig, ax, metrics_df, color_df, current_metric, class_name_df
    if metric is None:
        return

    print(f"Plotting metric: {metric}")

    # Keep each class paired with its value so dropping NaNs cannot misalign them.
    pairs = [
        (cat, value)
        for cat, value in zip(metrics_df['class_val'].tolist(), metrics_df[metric].tolist())
        if pd.notna(value)
    ]
    values = [value for _, value in pairs]
    if not values:
        print(f"No values available for metric '{metric}'; nothing to plot.")
        return
    if any(value < 0 for value in values):
        print(f"Metric '{metric}' has negative values and cannot be shown as a pie chart.")
        return
    total = sum(values)
    if total == 0:
        print(f"All values of metric '{metric}' are zero; nothing to plot.")
        return
    categories = [cat for cat, _ in pairs]
    percentages = [value / total * 100 for value in values]

    color_map = {row['Class_ID']: row['Color'] for _, row in color_df.iterrows()}
    colors = [color_map.get(cat, '#CCCCCC') for cat in categories]

    # The first name entered for a class wins, as with a first-matching-row lookup.
    class_names = {}
    for class_id, class_name in zip(class_name_df['Class_ID'], class_name_df['Class_Name']):
        class_names.setdefault(class_id, class_name)
    legend_labels = []
    for cat, pct in zip(categories, percentages):
        # Classes that were never named (e.g. no legend colour) fall back to their ID.
        label = class_names.get(cat, f"Class {cat}")
        legend_labels.append(f'{label}: {pct:.1f}%')

    fig, ax = plt.subplots(figsize=(14, 10))
    wedges = ax.pie(
        values, labels=None, colors=colors, startangle=140,
        wedgeprops={'linewidth': 3.0, 'edgecolor': 'white'}
    )[0]
    ax.set_title(f'Pie Chart for {metric}', fontsize=16, pad=20)
    ax.axis('equal')
    ax.legend(wedges, legend_labels, title="Categories", loc="center left",
              bbox_to_anchor=(1, 0, 0.5, 1), fontsize=12)
    plt.show()

    current_metric[0] = metric  # lets save_current_figure know what to name the file

def _pie_chart_filename(metric):
    """Return the PNG file name used when saving the pie chart of *metric*."""
    safe = str(metric).replace(' ', '_').replace(':', '')
    # Neutralise path separators so the file always lands inside output_folder.
    safe = safe.replace('/', '_').replace('\\', '_')
    return f"Pie_Chart_for_{safe}.png"


def save_current_figure(b):
    """Save the most recently drawn pie chart as a PNG in ``output_folder``.

    Click handler for the "Save" button; *b* is the clicked button and is not
    used. Reads the module globals ``fig``, ``current_metric`` and
    ``output_folder``; if they have not been set yet (or no metric has been
    plotted), a message is printed instead of raising inside the widget callback.
    Save errors (``OSError``) are reported rather than raised.
    """
    module_state = globals()
    figure = module_state.get('fig')
    metric = (module_state.get('current_metric') or [None])[0]
    folder = module_state.get('output_folder', '')
    if metric is None or figure is None:
        print("No metric selected to save or figure not generated.")
        return

    filename = _pie_chart_filename(metric)
    filepath = os.path.join(folder, filename)
    try:
        figure.savefig(filepath, bbox_inches='tight')
    except OSError as err:
        print(f"Failed to save figure to {filepath}: {err}")
        return
    print(f"Figure saved as '{filename}' in the directory: {folder}")

def _legend_colors(class_ids, color_df):
    """Map each class ID in *class_ids* that appears in the legend to ``(hex, rgb)``.

    The first legend row wins when a ``Class_ID`` is listed more than once; the
    RGB part is the float triple returned by ``hex_to_rgb`` (which raises
    ``ValueError`` for an invalid colour). Insertion order follows *class_ids*.
    """
    hex_by_class = {}
    for class_id, hex_color in zip(color_df['Class_ID'], color_df['Color']):
        hex_by_class.setdefault(class_id, hex_color)
    class_colors = {}
    for class_id in class_ids:
        if class_id in hex_by_class:
            hex_color = hex_by_class[class_id]
            class_colors[class_id] = (hex_color, hex_to_rgb(hex_color))
    return class_colors


def _pixel_polygons(raster_array, bounds, class_colors):
    """Build one square polygon per pixel whose class has a legend colour.

    Pixels are laid out on a regular grid spanning *bounds*
    ``(xmin, ymin, xmax, ymax)`` with row 0 at the top, and are returned in
    row-major order as dicts with ``geometry``, ``value`` and ``color`` keys.
    """
    if not class_colors:
        return []
    xmin, ymin, xmax, ymax = bounds
    n_rows, n_cols = raster_array.shape
    pixel_width = (xmax - xmin) / n_cols
    pixel_height = (ymax - ymin) / n_rows

    # Select coloured pixels up front instead of filtering the legend per pixel.
    mask = np.isin(raster_array, list(class_colors))
    polys = []
    for row, col in zip(*np.nonzero(mask)):
        pixel_value = raster_array[row, col]
        xleft = xmin + col * pixel_width
        xright = xmin + (col + 1) * pixel_width
        ybottom = ymin + (n_rows - row - 1) * pixel_height
        ytop = ymin + (n_rows - row) * pixel_height
        polys.append({
            'geometry': box(xleft, ybottom, xright, ytop),
            'value': pixel_value,
            'color': class_colors[pixel_value][1],
        })
    return polys


def _rgba_overlay(raster_array, class_colors):
    """Return a ``(rows, cols, 4)`` uint8 RGBA image of *raster_array* for folium.

    Classes without a legend colour stay fully transparent. The image is built
    as uint8 because folium passes uint8 images through unchanged but rescales
    other dtypes (see folium's ``write_png``), which would distort the colours.
    """
    rgba = np.zeros(raster_array.shape + (4,), dtype=np.uint8)
    for class_id, (_, rgb) in class_colors.items():
        rgba[raster_array == class_id] = [int(round(c * 255)) for c in rgb] + [255]
    return rgba


def _latlon_bounds(bounds, crs):
    """Return *bounds* as ``(left, bottom, right, top)`` in EPSG:4326 for folium.

    Bounds of a projected raster are reprojected; rasters with no CRS or an
    already geographic CRS are returned unchanged.
    """
    if not crs or crs.is_geographic:
        return tuple(bounds)
    from rasterio.warp import transform_bounds
    return transform_bounds(crs, 'EPSG:4326', *bounds)


def calculate_save_metrics_and_plot(input_raster, output_folder, output_base_name, color_data, nodata_value=0):
    """Compute class metrics for a raster, save them, and show interactive plots.

    Steps:
      1. Compute pylandstats class metrics and write them (tab-separated) to
         ``<output_folder>/<output_base_name>_metrics_output.csv``.
      2. Read the colour legend (``Class_ID``/``Color`` columns) from *color_data*.
      3. Draw the raster with legend colours, prompting (``input()``) for a legend
         title and a name for each coloured class, and save it as
         ``<output_base_name>_raster_plot.png``.
      4. Display a metric dropdown + Save button driving ``plot_pie_chart_by_metric``.
      5. Save a folium map overlaying the raster as ``<output_base_name>_map_output.html``.

    Sets the module globals ``fig``, ``ax``, ``metrics_df``, ``color_df``,
    ``current_metric``, ``class_name_df`` and ``output_folder`` used by the
    widget callbacks. Errors are reported with ``print`` rather than raised.

    Args:
        input_raster: Path to the single-band input raster.
        output_folder: Folder that receives all output files.
        output_base_name: Prefix for output file names.
        color_data: Path to the ``.xlsx``/``.csv`` colour legend.
        nodata_value: Raster value that pylandstats treats as nodata.
    """
    global fig, ax, metrics_df, color_df, current_metric, class_name_df
    import rasterio  # not imported at module level; needed to read the raster

    # save_current_figure reads the global output_folder; keep it in step with
    # the folder this call writes to (a parameter cannot be declared global).
    globals()['output_folder'] = output_folder

    current_metric = [None]  # mutable holder shared with the widget callbacks
    class_name_df = pd.DataFrame(columns=['Class_ID', 'Class_Name'])
    metrics_csv_path = os.path.join(output_folder, f'{output_base_name}_metrics_output.csv')
    try:
        print("Loading raster data and calculating metrics...")
        ls = pls.Landscape(input_raster, nodata=nodata_value)
        metrics_df = ls.compute_class_metrics_df()
        print(f"Metrics DataFrame head:\n{metrics_df.head()}")
        metrics_df.to_csv(metrics_csv_path, sep='\t', encoding='utf-8')
        print(f"Metrics calculated and saved to {metrics_csv_path}")

        print("Reading color legend file...")
        color_df = read_color_legend(color_data)
        print(f"Color legend DataFrame head:\n{color_df.head()}")

        print("Loading raster data using rasterio...")
        with rasterio.open(input_raster) as src:
            bounds = src.bounds
            raster_crs = src.crs  # read while the dataset is still open
            raster_array = src.read(1)
        print(f"Raster data shape: {raster_array.shape}")

        print("Creating custom colormap...")
        unique_classes = np.unique(raster_array)
        # Sorted (np.unique) order also fixes the order of the class-name prompts.
        class_colors = _legend_colors(unique_classes, color_df)
        print(f"Unique classes in raster: {unique_classes}")

        print("Plotting raster with GeoPandas...")
        polys = _pixel_polygons(raster_array, bounds, class_colors)
        fig, ax = plt.subplots(1, 1, figsize=(12, 12))
        if polys:
            gdf = gpd.GeoDataFrame(polys, crs=raster_crs)
            print(f"Geopandas DataFrame head:\n{gdf.head()}")
            gdf.plot(ax=ax, facecolor=gdf['color'].tolist(), edgecolor='none')
        else:
            print("No raster pixels match a class in the color legend; the raster plot will be empty.")
        ax.set_title(f"{output_base_name}")

        legend_title = input('Enter title for the legend:\n').strip()
        legend_elements = []
        name_rows = []
        for pixel_value, (_, rgb_color) in class_colors.items():
            class_name = input(f"Enter name for Class {pixel_value}:\n").strip()
            name_rows.append({'Class_ID': pixel_value, 'Class_Name': class_name})
            legend_elements.append(Patch(facecolor=rgb_color, edgecolor='black', label=class_name))
        class_name_df = pd.DataFrame(name_rows, columns=['Class_ID', 'Class_Name'])

        ax.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left', title=legend_title, frameon=True)
        raster_plot_output = os.path.join(output_folder, f'{output_base_name}_raster_plot.png')
        fig.tight_layout()
        # The legend sits outside the axes; bbox_inches='tight' keeps it in the file.
        fig.savefig(raster_plot_output, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Raster plot saved as {raster_plot_output}")

        # Re-read the CSV so 'class_val' (the index written by to_csv) becomes a column.
        metrics_df = pd.read_csv(metrics_csv_path, sep='\t')

        dropdown = widgets.Dropdown(
            options=metrics_df.columns[1:],
            description='',
            disabled=False,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )
        save_button = widgets.Button(
            description="Save",
            button_style='success',
            layout=widgets.Layout(width='100px')
        )
        save_button.on_click(save_current_figure)
        out = widgets.interactive_output(plot_pie_chart_by_metric, {'metric': dropdown})
        display(widgets.VBox([dropdown, save_button, out]))

        print("Creating folium map...")
        left, bottom, right, top = _latlon_bounds(bounds, raster_crs)
        m = folium.Map(
            location=[(top + bottom) / 2, (left + right) / 2],
            zoom_start=10, width='100%', height='100%'
        )
        folium.raster_layers.ImageOverlay(
            image=_rgba_overlay(raster_array, class_colors),
            bounds=[[bottom, left], [top, right]],
            opacity=0.6
        ).add_to(m)
        folium.LayerControl().add_to(m)
        html_output = os.path.join(output_folder, f'{output_base_name}_map_output.html')
        m.save(html_output)
        print(f'Map saved to {html_output}')

    except FileNotFoundError as fe:
        print(f"File not found error: {fe}")
        print("Please ensure the file path is correct and the file exists.")
    except ValueError as ve:
        print(f"Value error: {ve}")
        print("Please check the file type or content.")
    except IOError:
        print(f"IOError: Unable to read color data from {color_data}.")
        print(f"Please check the file format and contents of {color_data}.")
    except Exception as e:
        print(f"Error during processing: {e}")
        print("An unexpected error occurred. Please check your inputs and try again.")

def check_create_output_folder(output_folder):
    """Create *output_folder* (including missing parents) if it does not exist yet.

    An empty path means the current working directory and is left alone.
    Prints a message only when a folder is actually created.

    Raises:
        FileExistsError: if *output_folder* exists but is not a directory.
    """
    if not output_folder or os.path.isdir(output_folder):
        return
    # exist_ok avoids a failure if the folder appears between the check and the call.
    os.makedirs(output_folder, exist_ok=True)
    print(f"Created output folder: {output_folder}")

# Main execution starts here
if __name__ == "__main__":
    import sys

    check_drive_mount()

    # Kept at module level on purpose: save_current_figure() reads the global output_folder.
    # Inputs are stripped because pasted paths often carry stray whitespace.
    output_folder = input('Enter the path to the output folder where results will be saved:\n').strip()
    check_create_output_folder(output_folder)

    input_raster = input('Enter the path to the input raster (.tif) file:\n').strip()
    color_data = input('Enter the path to the color legend data (.xlsx or .csv) file:\n').strip()
    output_base_name = input('Enter the base name for output files (without extension):\n').strip()

    # sys.exit rather than the site/IPython exit() helper, which may be missing
    # or shut down the notebook kernel.
    if not os.path.isfile(input_raster) or not input_raster.lower().endswith(('.tif', '.tiff')):
        print(f"Error: No .tif file found at the specified path or incorrect file type: {input_raster}")
        sys.exit(1)

    if not os.path.isfile(color_data) or not color_data.lower().endswith(('.xlsx', '.csv')):
        print(f"Error: No .xlsx or .csv file found at the specified path or incorrect file type: {color_data}")
        sys.exit(1)

    print(f'Input raster: {input_raster}')
    print(f'Color Data: {color_data}')
    print(f'Output folder: {output_folder}')
    print(f'Output base name: {output_base_name}')

    calculate_save_metrics_and_plot(input_raster, output_folder, output_base_name, color_data)