In [17]:
import rasterio
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import colorsys

from geonate import raster, tools

### Define critical functions

In [18]:
# Create change matrix for all periods

def create_change_matrix(input): 
    # Create an empty dataframe
    change_matrix = pd.DataFrame(columns=['source', 'target', 'value',  'type'])

    # Run over each period (pairs of years)
    for x in range(0, len(input)-1):
        before_path = input[x]
        after_path = input[x+1]

        with rasterio.open(before_path) as img_before, rasterio.open(after_path) as img_after:
            if (img_before.count != 1) | (img_after.count != 1):
                raise ValueError('Images must be single band')
            else:
                # Read the data as a 1D array and mask no data values
                data_before = img_before.read(1)
                data_after = img_after.read(1)

                # Check if the metadata is the same
                if ((img_before.crs != img_after.crs) & (img_before.transform != img_after.transform) & (img_before.width != img_after.width) & (img_before.height != img_after.height) & (img_before.res != img_after.res)):        
                    raise ValueError('Images must have the same metadata')
                else:
                    # Adjust no data values
                    valid_mask = (data_after > 0) & (data_before > 0)   
                    data_before = data_before[valid_mask]
                    data_after =   data_after[valid_mask]

                    # stack the data together as pairs [before, after]
                    data_stack = np.stack((data_before, data_after), axis=1)

                    # count the unique pairs in row > axis= 0 is row
                    unique, counts = np.unique(data_stack, axis=0, return_counts=True)

                    # Calculate area in hectares based on image resolution
                    areas = counts * img_before.res[0] * img_before.res[1] / 10000
                    
                    # Create a dataframe of transitions
                    transition = pd.DataFrame(unique, columns=['source', 'target'])
                    transition['value'] = areas
                    transition['type'] = years[x+1] # the year of the end of period
                    
                    # Combine data
                    change_matrix = pd.concat([change_matrix, transition], axis=0).reset_index(drop=True)
    
    return change_matrix


# Create link data for Sankey
def create_sankey_links(transitions, node_indices, offset):
    
    # Function to modify color brightness
    def adjust_brightness(hex_color, factor=1):
        """Adjust the brightness of a hex color."""
        try:
            # 
            rgb = tuple(int(hex_color[i:i+2], 16) for i in (1, 3, 5))
            h, l, s = colorsys.rgb_to_hls(rgb[0]/255, rgb[1]/255, rgb[2]/255)
            l = min(1, max(0, l * factor))  # Adjust lightness
            r, g, b = colorsys.hls_to_rgb(h, l, s)
            return f'#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}'
        except Exception as e:
            print(f"Error adjusting brightness for color: {hex_color}")
            raise e
        
    # Function to calculate link colors
    def get_link_color(source, target):
        try:
            base_color = landuse_colors[source]
            if source == target:
                return base_color  # Same-to-same retains the source color
            else:
                # Slightly adjust the brightness for different target types
                return adjust_brightness(base_color, factor= 1.15)
        except KeyError as e:
            print(f"Invalid land use type in source/target: {source}, {target}")
            raise e

    # ********************************************************* #   
    # Main function
    links = {
        'source': [],
        'target': [],
        'value': [],
        'color': []  # Link colors
    }
    for _, row in transitions.iterrows():
        # Ensure the values are valid indices
        from_idx = int(row['source']) - 1
        to_idx = int(row['target']) - 1

        if from_idx not in range(len(nodes)) or to_idx not in range(len(nodes)):
            print(f"Skipping invalid transition: source={row['source']}, target={row['target']}")
            continue

        source = node_indices[nodes[from_idx]] + offset
        target = node_indices[nodes[to_idx]] + offset + len(nodes)
        links['source'].append(source)
        links['target'].append(target)
        links['value'].append(row['value'])
        
        # Debug: Ensure source_label and target_label exist
        try:
            source_label = nodes[from_idx]
            target_label = nodes[to_idx]
        except IndexError:
            print(f"IndexError: from_idx={from_idx}, to_idx={to_idx}")
            raise

        # Assign color to the link based on source and target land use
        links['color'].append(get_link_color(source_label, target_label))
    
    return links

### Example

#### Define and calculate paramters

In [19]:
input = tools.list_files('data/', '*tif')
input.sort()

years = [1990, 2000, 2010, 2020]

# Create change matrix
change_matrix = create_change_matrix(input)

# Calculate transitions for each period
transitions_1990_2000 = change_matrix[change_matrix['type'] == 2000]
transitions_2000_2010 = change_matrix[change_matrix['type'] == 2010]
transitions_2010_2020 = change_matrix[change_matrix['type'] == 2020]

# Define nodes of land use, values, and colors
nodes = ['BAR', 'URBN', 'FOR', 'AGR', 'WAT']
values = np.sort(np.unique(change_matrix['source']))-1
colors = [
    '#ebab34', # Barren
    '#f37f6b', # Urban
    '#4c7300', # Forest
    '#faef7d', # Agriculture
    '#74b3ff' # Water
]

node_indices = {node:value for node, value in zip(nodes, values)}
landuse_colors = {node:color for node, color in zip(nodes, colors)}    

# Create Sankey links
links_1990_2000 = create_sankey_links(transitions_1990_2000, node_indices, 0)
links_2000_2010 = create_sankey_links(transitions_2000_2010, node_indices, len(nodes)*1)
links_2010_2020 = create_sankey_links(transitions_2010_2020, node_indices, len(nodes)*2)

# Combine links
links = {
    'source': links_1990_2000['source'] + links_2000_2010['source'] + links_2010_2020['source'],
    'target': links_1990_2000['target'] + links_2000_2010['target'] + links_2010_2020['target'],
    'value': links_1990_2000['value'] + links_2000_2010['value'] + links_2010_2020['value'],
    'color': links_1990_2000['color'] + links_2000_2010['color'] + links_2010_2020['color']
}

# Define consistent node colors based on land use
node_colors = [landuse_colors[node] for node in nodes] * len(years)  # Repeat colors for each year



The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.



#### Create Sankey plot

In [20]:
# Create Sankey plot
fig = go.Figure(go.Sankey(
    node=dict(
        pad= 20, # Size of vertical space between nodes
        thickness= 30, # Size of node
        line=dict(color="gray", width=1), # Color and outline thickness of outline node
        label=nodes * len(years),  # Repeat labels for each time period
        color=node_colors  # Apply the color scheme
    ),
    link=dict(
        source=links['source'],
        target=links['target'],
        value=links['value'],
        color=links['color']  # Apply link colors
    )
))

# Custom layout
fig.update_layout(
    title_text="Land Use Change Sankey Diagram", 
        font=dict(family="Times New Roman", size=20, color="black"),
        title_font=dict(family="Times New Roman", size=25, color="black"),
        plot_bgcolor="white",  # Optional: Set background to white)
        height= 800,  # Set height of the plot (in pixels)
        width= 1200,  # Set width of the plot (in pixels)
        annotations=[
            dict(
                text="1990", x=0,  y=1.05,  showarrow=False,
                font=dict(family="Times New Roman", size=20, color="black")
            ),
            dict(
                text="2000", x=0.315,  y=1.05,  showarrow=False,
                font=dict(family="Times New Roman", size=20, color="black")
            ),
            # Add annotation for the year 2010
            dict(
                text="2010", x=0.62,  y=1.05,  showarrow=False,
                font=dict(family="Times New Roman", size=20, color="black")
            ),
            # Add annotation for the year 2020
            dict(
                text="2020", x=1.0,  y=1.05, showarrow=False,
                font=dict(family="Times New Roman", size=20, color="black")
            ),
        ]

)
fig.show()

# Export output as html file - interactive image
fig.write_html("Sankey_lulcc.html")
