# In [None]:
import subprocess

# Function to check and install necessary packages
def setup_environment(packages=None):
    """Ensure the notebook's third-party packages are importable, installing any missing.

    Each name is imported. If that raises ``ImportError``, the package is
    installed with pip for the running interpreter. A status line is printed
    for every name.

    Args:
        packages: optional iterable of names to check. Each name is used both as
            the import name and as the pip requirement, so it must be valid for
            both. ``None`` (the default) checks the packages this notebook uses.

    Raises:
        subprocess.CalledProcessError: if a pip installation fails.
    """
    import importlib
    import sys

    if packages is None:
        packages = ['pylandstats', 'openpyxl', 'rasterio', 'folium', 'pandas', 'numpy', 'geopandas',
                    'matplotlib', 'shapely', 'IPython', 'ipywidgets']

    for package in packages:
        try:
            importlib.import_module(package)
        except ImportError:
            print(f'{package} not installed. Installing {package}...')
            # `python -m pip` targets the interpreter running this code; a bare
            # `pip` found on PATH may belong to a different Python installation.
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
            # The packages are imported later in this same process, so make the
            # import system notice the newly installed files.
            importlib.invalidate_caches()
        else:
            print(f'{package} already installed.')

# Check that every required package is importable, installing any that are
# missing, before the third-party imports below are executed.
setup_environment()

import os
import pandas as pd
import rasterio
import folium
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import box
from matplotlib.patches import Patch
from matplotlib.colors import ListedColormap
from IPython.display import display, HTML
import ipywidgets as widgets
import pylandstats as pls
from google.colab import drive

def check_drive_mount():
    """Mount Google Drive at ``/content/drive/`` unless that path already exists.

    Only the existence of the mount point is checked. A status message is
    printed in either case.
    """
    mount_point = '/content/drive/'
    if os.path.exists(mount_point):
        print('Drive is already mounted.')
        return
    drive.mount(mount_point)
    print('Drive is not mounted. Mounted now.')

# Function to convert hexadecimal color code to RGB tuple
def hex_to_rgb(hex_color):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of floats in [0, 1].

    Leading ``#`` characters are ignored and case does not matter. Both the
    short form (``abc``, expanded to ``aabbcc``) and the long form
    (``aabbcc``) are accepted.

    Raises:
        ValueError: if the string is not exactly 3 or 6 hexadecimal digits.
    """
    digits = hex_color.lstrip('#').lower()

    if len(digits) not in (3, 6) or not set(digits) <= set('0123456789abcdef'):
        raise ValueError("Invalid hexadecimal color code")

    if len(digits) == 3:
        # Short form: each digit stands for a doubled pair, e.g. 'f' -> 'ff'.
        digits = ''.join(ch + ch for ch in digits)

    return tuple(int(digits[pos:pos + 2], 16) / 255.0 for pos in (0, 2, 4))

def _load_color_map(color_excel):
    """Read the legend workbook into a ``{Class_ID: hex colour string}`` dict.

    Rows with a missing ``Class_ID`` or ``Color`` are skipped, and colour
    strings are stripped of surrounding whitespace. When a ``Class_ID`` is
    listed more than once the first row wins, so the raster plot, the web map
    and the pie chart all use the same colour for a class.
    """
    color_df = pd.read_excel(color_excel)
    color_df = color_df.dropna(subset=['Class_ID', 'Color'])
    color_df = color_df.drop_duplicates(subset='Class_ID', keep='first')
    colors = color_df['Color'].astype(str).str.strip()
    return dict(zip(color_df['Class_ID'], colors))


def _build_rgba_image(raster_array, color_map):
    """Colour a 2-D class raster using *color_map*.

    Returns ``(rgba, used)``:

    - ``rgba`` is a ``(rows, cols, 4)`` uint8 array. Pixels whose value has no
      legend entry stay fully transparent.
    - ``used`` lists ``(pixel_value, hex_color)`` for every legend class present
      in the raster, sorted by pixel value.

    Raises:
        ValueError: if a colour of a class present in the raster is not valid hex.
    """
    rgba = np.zeros(raster_array.shape + (4,), dtype=np.uint8)
    used = []
    for class_id, hex_color in color_map.items():
        # np.asarray keeps `.any()` valid even if the comparison collapses to a plain bool.
        mask = np.asarray(raster_array == class_id)
        if not mask.any():
            continue
        r, g, b = hex_to_rgb(hex_color)
        rgba[mask] = (round(r * 255), round(g * 255), round(b * 255), 255)
        # Report the class with the raster's own value/dtype (e.g. 3 rather than 3.0).
        used.append((raster_array.flat[mask.argmax()], hex_color))
    used.sort(key=lambda item: item[0])
    return rgba, used


# Function to calculate and save landscape metrics and plot the raster interactively
def calculate_save_metrics_and_plot(input_raster, output_folder, output_base_name, color_excel, nodata_value=0):
    """Compute class-level landscape metrics for *input_raster* and present them.

    Runs these steps, printing progress as it goes:

    1. Compute pylandstats class metrics and save them as a tab-separated CSV,
       ``<output_base_name>_metrics_output.csv``, in *output_folder*.
    2. Colour the raster from the legend workbook. Ask (via ``input``) for a
       legend title and a display name for every legend class present in the
       raster. Save the figure as ``<output_base_name>_raster_plot.png`` and
       show it.
    3. Show a dropdown that draws a pie chart of the selected metric per class.
    4. Overlay the coloured raster on a folium web map, save it as
       ``<output_base_name>_map_output.html`` and display it.

    Args:
        input_raster: path to the class raster; only band 1 is used.
        output_folder: existing directory that receives every output file.
        output_base_name: prefix for the output file names.
        color_excel: path to a workbook with ``Class_ID`` and ``Color`` (hex
            string) columns.
        nodata_value: passed to ``pls.Landscape`` as its ``nodata`` value.

    Exceptions derived from ``Exception`` are reported with ``print`` and are
    not re-raised.
    """
    from rasterio.warp import transform_bounds  # only needed for the web map bounds

    rule = "============================================================================================================================================================================="
    metrics_csv_path = os.path.join(output_folder, f'{output_base_name}_metrics_output.csv')

    try:
        # Landscape metrics, computed by pylandstats and saved as tab-separated CSV.
        print("Loading raster data and calculating metrics...")
        landscape = pls.Landscape(input_raster, nodata=nodata_value)
        landscape.compute_class_metrics_df().to_csv(metrics_csv_path, sep='\t', encoding='utf-8')
        print(f"Metrics calculated and saved to {metrics_csv_path}")

        print("Reading color legend Excel file...")
        color_map = _load_color_map(color_excel)

        print("Loading raster data using rasterio...")
        with rasterio.open(input_raster) as src:
            bounds = src.bounds
            # Read the CRS here, while the dataset is still open.
            crs = src.crs
            raster_array = src.read(1)

        print("Creating custom colormap...")
        rgba_image, used_classes = _build_rgba_image(raster_array, color_map)

        # Static plot: one RGBA image placed at the raster's map coordinates.
        # This is far cheaper than building a polygon per pixel.
        print("Plotting raster...")
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.imshow(
            rgba_image,
            extent=(bounds.left, bounds.right, bounds.bottom, bounds.top),
            interpolation='nearest',
        )
        ax.set_title(f"Raster Plot: {output_base_name}")

        legend_title = input('Enter title for the legend:\n').strip()

        # One legend entry per legend class that occurs in the raster, in class order.
        legend_elements = []
        for pixel_value, hex_color in used_classes:
            class_name = input(f"Enter name for Class {pixel_value}:\n").strip()
            legend_elements.append(
                Patch(facecolor=hex_to_rgb(hex_color), edgecolor='black', label=class_name)
            )
        ax.legend(handles=legend_elements, loc='lower right', title=legend_title, frameon=True)

        print(f"Plot Raster: {output_base_name}")
        raster_plot_output = os.path.join(output_folder, f'{output_base_name}_raster_plot.png')
        print("Saving raster plot as PNG...")
        fig.tight_layout()  # before saving, so the PNG gets the same layout as the displayed figure
        fig.savefig(raster_plot_output, dpi=300)
        print(f"Raster plot saved as {raster_plot_output}")

        print("\n" + rule)
        print(rule)
        display(HTML(f"<div style='font-size:30px; font-weight:bold;'>Raster Map: {output_base_name}</div>"))
        plt.show()
        print(rule)
        print(rule + "\n")

        # Interactive pie chart, driven by the metrics CSV written above.
        metrics_df = pd.read_csv(metrics_csv_path, sep='\t')
        # Only offer metrics that are defined for every class.
        valid_columns = metrics_df.dropna(axis=1)

        dropdown = widgets.Dropdown(
            options=list(valid_columns.columns[1:]),  # column 0 is the written index, class_val
            description='Metric:',
            disabled=False,
        )

        def plot_pie_chart_by_metric(metric):
            """Draw a pie chart of *metric* per class, coloured like the raster plot."""
            if metric is None:
                return
            values = metrics_df[metric].tolist()
            categories = metrics_df['class_val'].tolist()
            # plt.pie rejects negative wedges and cannot normalise an all-zero series.
            if any(v < 0 for v in values) or sum(values) == 0:
                print(f"Cannot draw a pie chart for {metric}: values must be non-negative and not all zero.")
                return

            colors = [hex_to_rgb(color_map.get(cat, '#CCCCCC')) for cat in categories]

            plt.figure(figsize=(8, 8))
            plt.pie(values, labels=categories, colors=colors, autopct='%1.1f%%', startangle=140)
            plt.title(f'Pie Chart for {metric}')
            plt.axis('equal')
            plt.show()

        out = widgets.interactive_output(plot_pie_chart_by_metric, {'metric': dropdown})
        display(HTML(f"<div style='font-size:30px; font-weight:bold;'>Displaying the Pie chart: {output_base_name}</div>"))
        display(dropdown, out)
        print(rule)
        print(rule + "\n")

        # Web map: folium works in WGS84 latitude/longitude, while the raster
        # bounds are expressed in the raster's own CRS.
        print("Creating folium map...")
        if crs is not None:
            west, south, east, north = transform_bounds(crs, 'EPSG:4326', *bounds)
        else:
            west, south, east, north = bounds
        m = folium.Map(
            location=[(south + north) / 2, (west + east) / 2],
            zoom_start=10,
            width='50%',
            height='50%',
        )

        print("Adding raster data overlay to the map...")
        # Pass the already-coloured uint8 RGBA image. Non-uint8 image data is
        # rescaled when encoded to PNG, which would shift the legend colours.
        folium.raster_layers.ImageOverlay(
            image=rgba_image,
            bounds=[[south, west], [north, east]],
            opacity=0.6,
        ).add_to(m)

        print("Adding layer control to the map...")
        folium.LayerControl().add_to(m)

        print("Map creation complete.")

        html_output = os.path.join(output_folder, f'{output_base_name}_map_output.html')
        print("Saving the map as HTML...")
        m.save(html_output)
        print(f'Map saved to {html_output}')

        print("\n" + rule)
        print(rule)
        display(HTML(f"<div style='font-size:30px; font-weight:bold;'>Displaying the interactive map: {output_base_name}</div>"))
        display(m)

    except FileNotFoundError as fe:
        print(f"File not found error: {fe}")
        print("Please ensure the file path is correct and the file exists.")
    except ValueError as ve:
        print(f"Value error: {ve}")
        print("Please check the file type or content.")
    except IOError as ioe:
        # Raster read failures are OSErrors too, so do not blame only the Excel file.
        print(f"IOError: {ioe}")
        print(f"Please check that {input_raster} and {color_excel} are readable and correctly formatted.")
    except Exception as e:
        print(f"Error during processing: {e}")
        print("An unexpected error occurred. Please check your inputs and try again.")

# Function to check and create output folder if it doesn't exist
def check_create_output_folder(output_folder):
    """Ensure *output_folder* exists as a directory, creating it and any parents if needed.

    A message is printed only when the folder is actually created.

    Raises:
        FileExistsError: if *output_folder* exists but is not a directory.
        OSError: if the directory cannot be created (e.g. an empty path or no permission).
    """
    if os.path.isdir(output_folder):
        return
    # exist_ok=True tolerates the folder appearing between the check and this
    # call. makedirs still raises FileExistsError when the path is an existing file.
    os.makedirs(output_folder, exist_ok=True)
    print(f"Created output folder: {output_folder}")

# Main execution starts here
if __name__ == "__main__":

    # Mount Google Drive so that paths under /content/drive/ are available.
    check_drive_mount()

    # Prompt for all paths. .strip() drops whitespace picked up when pasting paths.
    output_folder = input('Enter the path to the output folder where results will be saved:\n').strip()
    input_raster = input('Enter the path to the input raster (.tif) file:\n').strip()
    color_excel = input('Enter the path to the color legend Excel (.xlsx) file:\n').strip()
    output_base_name = input('Enter the base name for output files (without extension):\n').strip()

    # Validate with plain branching: exit() is an interactive-shell helper and
    # terminates the whole kernel in Jupyter/Colab.
    if not os.path.isfile(input_raster) or not input_raster.lower().endswith(('.tif', '.tiff')):
        print(f"Error: No .tif file found at the specified path or incorrect file type: {input_raster}")
    elif not os.path.isfile(color_excel) or not color_excel.lower().endswith('.xlsx'):
        print(f"Error: No .xlsx file found at the specified path or incorrect file type: {color_excel}")
    else:
        # Only create the output folder once the inputs are known to be usable.
        check_create_output_folder(output_folder)

        print(f'Input raster: {input_raster}')
        print(f'Color Excel: {color_excel}')
        print(f'Output folder: {output_folder}')
        print(f'Output base name: {output_base_name}')

        # Perform the calculations, save the metrics, and plot the raster.
        calculate_save_metrics_and_plot(input_raster, output_folder, output_base_name, color_excel)

