# Data visualization

## Library imports

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from sklearn.cluster import KMeans
import contextily as ctx
import pyproj
import geopandas as gpd
import os
import sys
import pickle
import sklearn.metrics
import torch

## Path setup

In [None]:
# Directory of this notebook. Notebooks have no `__file__`, so the literal
# string '__file__' is resolved against the current working directory; its
# dirname is therefore the working directory itself.
notebook_directory = os.path.dirname(os.path.abspath('__file__'))
# The framework root is the parent directory of the notebook directory.
framework_directory = os.path.abspath(os.path.join(notebook_directory, '..'))

# Make the framework's modules importable. The membership check keeps the
# cell idempotent: re-running it no longer appends duplicate entries.
if framework_directory not in sys.path:
    sys.path.append(framework_directory)

print(framework_directory)

## Time–performance characteristic

In [None]:
# Models in plotting order; the list is repeated once per facet
# (training first, then prediction).
model_order = ["MLP", "RFC", "XGB", "TNC", "FTT"]
model = model_order * 2

# Training accuracy (%) of every model over 10 runs.
mlp_training_accuracies = [97.82, 97.88, 97.87, 97.73, 97.82, 97.72, 97.80, 97.64, 97.71, 97.72]
rfc_training_accuracies = [99.26, 99.26, 99.26, 99.26, 99.26, 99.26, 99.26, 99.26, 99.26, 99.26]
xgb_training_accuracies = [99.26, 99.26, 99.23, 99.25, 99.24, 99.26, 99.24, 99.25, 99.26, 99.26]
tnc_training_accuracies = [86.95, 86.23, 86.40, 86.71, 85.94, 86.36, 86.01, 85.90, 86.78, 86.93]
ftt_training_accuracies = [96.14, 96.26, 96.16, 96.14, 96.54, 96.52, 96.00, 96.18, 96.16, 96.00]

# Training time (s) of every model over the same 10 runs.
mlp_training_times = [9193.57, 9111.05, 9277.34, 9226.49, 9377.13, 9282.78, 9323.88, 8691.53, 9244.66, 9173.71]
rfc_training_times = [2283.03, 2280.12, 2278.89, 2337.44, 2275.51, 2332.83, 2312.09, 2290.51, 2275.04, 2293.37]
xgb_training_times = [53811.31, 54621.94, 53881.44, 54926.00, 54850.32, 54543.94, 54685.12, 54283.54, 54034.83, 54879.43]
tnc_training_times = [7787.45, 7364.54, 6167.79, 10989.60, 7021.01, 6635.65, 5719.83, 7344.03, 8014.12, 10486.35]
ftt_training_times = [11008.77, 11461.27, 11393.96, 11269.92, 11260.76, 10791.19, 11442.52, 11144.11, 11126.66, 10478.45]

# (accuracy runs, time runs) per model, in `model_order`.
training_runs = [
    (mlp_training_accuracies, mlp_training_times),
    (rfc_training_accuracies, rfc_training_times),
    (xgb_training_accuracies, xgb_training_times),
    (tnc_training_accuracies, tnc_training_times),
    (ftt_training_accuracies, ftt_training_times),
]

# Training facet: mean over the runs. Prediction facet: single measured values.
accuracy = [np.mean(run_accuracies) for run_accuracies, _ in training_runs]
accuracy += [37.42, 38.24, 36.75, 36.59, 35.71]
time = [np.mean(run_times) for _, run_times in training_runs]
time += [0.42, 0.62, 28.24, 4.43, 2.30]


def _rescale_to_one_two(raw_values):
    """Min-max rescale `raw_values` onto [1, 2] so they can serve as marker sizes."""
    smallest, largest = min(raw_values), max(raw_values)
    spread = largest - smallest
    return [(value - smallest) / spread + 1 for value in raw_values]


# Marker size: spread (std) of the training accuracy for the training facet,
# number of model parameters for the prediction facet.
standard_deviation = _rescale_to_one_two([np.std(run_accuracies) for run_accuracies, _ in training_runs])
number_parameters = _rescale_to_one_two([2172628, 200, 200, 1315210, 1239372])
size = standard_deviation + number_parameters

pipeline = ["Training"] * 5 + ["Prediction"] * 5

df = pd.DataFrame({
    'Model': model,
    'Time': time,
    'Accuracy': accuracy,
    'Size': size,
    'Pipeline': pipeline,
})

fig = px.scatter(df, x="Time", y="Accuracy", color="Model",
                 size='Size', facet_col='Pipeline',
                 text="Model", labels={"Time": "Time (s)", "Accuracy": "Accuracy (%)"})

# Facet titles read "Pipeline=Training"; keep only the part after "=".
fig.for_each_annotation(lambda annotation: annotation.update(text=annotation.text.split("=")[-1]))

# Model names below the markers; black marker outlines.
fig.update_traces(textposition='bottom center', marker_line_width=2, marker_line_color='Black')

fig.update_layout(height=500, width=2000, showlegend=False, font=dict(size=18))

# The facets have very different time/accuracy ranges, so unlink their axes.
fig.update_xaxes(matches=None)
fig.update_yaxes(matches=None, showticklabels=True)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/time–performance_characteristic.pdf'))

## Accuracy

In [None]:
fig = go.Figure()

# Accuracy (%) recorded after each of the 100 training epochs.
epochs = list(range(1, 101))
values = [88.8699, 89.9944, 90.9012, 91.2493, 91.753, 91.9377, 92.0493, 92.2255, 92.2432, 92.4019, 92.3454, 92.5178, 92.5326, 92.8125, 92.7535, 92.8742, 92.8202, 92.8393, 92.8905, 92.7829, 92.9184, 92.9879, 92.9368, 92.9946, 93.0307, 93.0823, 93.106, 93.1555, 93.1434, 93.0361, 93.1622, 93.0456, 93.1882, 93.1678, 93.2963, 93.1519, 93.1102, 93.1803, 93.2212, 93.1467, 93.2695, 95.4257, 95.9578, 96.2402, 96.3902, 96.5103, 96.6018, 96.6814, 96.7414, 96.8154, 96.8526, 96.9051, 96.9651, 96.9938, 97.0473, 97.0858, 97.1216, 97.1505, 97.1843, 97.2018, 97.2446, 97.2592, 97.2745, 97.3146, 97.3346, 97.3554, 97.3842, 97.4067, 97.4106, 97.433, 97.4537, 97.4652, 97.4634, 97.5021, 97.5126, 97.5373, 97.5499, 97.5478, 97.5692, 97.589, 97.5984, 97.6089, 97.6309, 97.6377, 97.6367, 97.6603, 97.6662, 97.6578, 97.6898, 97.6819, 97.702, 97.6992, 97.7158, 97.7294, 97.7484, 97.7416, 97.7403, 97.7713, 97.7668, 97.7726]

fig.add_trace(go.Scatter(x=epochs, y=values))

# Epoch at which the learning rate was lowered (list index 41 -> epoch 42).
lr_switch_epoch = epochs[41]
phase_label_font = dict(size=15, family="Times New Roman")

# Shade the two learning-rate phases.
for phase_start, phase_end, phase_text, phase_color in (
    (epochs[0], lr_switch_epoch, "LR: 0.001", "#1f77b4"),
    (lr_switch_epoch, epochs[-1], "LR: 0.0001", "#ff7f0e"),
):
    fig.add_vrect(
        x0=phase_start, x1=phase_end,
        label=dict(text=phase_text, textposition="top center", font=phase_label_font),
        fillcolor=phase_color, opacity=0.25, line_width=0,
    )

# Point at the last epoch: training ran for all epochs.
fig.add_annotation(
    x=epochs[-1], y=values[-1],
    text="Early stopping did not occur",
    showarrow=True,
    arrowsize=3,
    arrowwidth=1,
    arrowhead=1,
    ax=-50,
    ay=50,
)

# Dashed separator at the learning-rate switch.
fig.add_vline(x=lr_switch_epoch, line_width=3, line_dash="dash", line_color="black")

fig.update_layout(
    height=500,
    width=1000,
    showlegend=False,
    xaxis_title='Epochs',
    yaxis_title='Accuracy (%)',
)

fig.update_xaxes(range=[epochs[0], epochs[-1]])

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/accuracy.pdf'))

## Species per plot

In [None]:
eva_species = pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_species.csv'))
eva_header = pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_header.csv'))

# Number of matched (non-null) species per plot, looked up for every plot of
# the header in one vectorised step. A plot without any species row gets 0
# instead of raising a KeyError.
species_per_plot = eva_species.groupby('PlotObservationID')['Matched concept'].count()
species_counts = eva_header['PlotObservationID'].map(species_per_plot).fillna(0)

number_of_plots = len(species_counts)
rich_plots = species_counts[species_counts > 30]
# Denominator for the breakdown of species-rich plots; max(..., 1) avoids a
# ZeroDivisionError when no plot has more than 30 species.
number_of_rich_plots = max(len(rich_plots), 1)

# Shares are deliberately left unrounded: rounded shares need not sum to 1,
# which misaligned the stacked bar with the lines connecting it to the wedge.
counter_less_10 = (species_counts <= 10).sum() / number_of_plots
counter_11_30 = ((species_counts > 10) & (species_counts <= 30)).sum() / number_of_plots
counter_more_31 = len(rich_plots) / number_of_plots

# Breakdown of the plots with more than 30 species.
counter_31_40 = (rich_plots <= 40).sum() / number_of_rich_plots
counter_41_50 = ((rich_plots > 40) & (rich_plots <= 50)).sum() / number_of_rich_plots
counter_51_60 = ((rich_plots > 50) & (rich_plots <= 60)).sum() / number_of_rich_plots
counter_more_61 = (rich_plots > 60).sum() / number_of_rich_plots

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
fig.subplots_adjust(wspace=0)

# Left: pie of all plots by species richness. The first, exploded wedge
# (more than 30 species) is broken down by the bar on the right.
ax1.set_title('All plots')
overall_ratios = [counter_more_31, counter_11_30, counter_less_10]
labels = ['More than 30', 'Between 11 and 30', '10 or fewer']
explode = [0.1, 0, 0]
# Rotate the pie so that the exploded wedge is centred on the 0° direction.
angle = -180 * overall_ratios[0]
wedges, *_ = ax1.pie(overall_ratios, autopct='%1.1f%%', startangle=angle,
                     labels=labels, explode=explode)

# Right: stacked bar of the species-rich plots, built top-down from 1 with the
# richest class on top; opacity grows with every segment drawn.
richness_ratios = [counter_31_40, counter_41_50, counter_51_60, counter_more_61]
richness_labels = ['31-40', '41-50', '51-60', 'More than 60']
bottom = 1
width = .2

for j, (height, label) in enumerate(reversed([*zip(richness_ratios, richness_labels)])):
    bottom -= height
    bc = ax2.bar(0, height, width, bottom=bottom, color='C0', label=label,
                 alpha=0.1 + 0.25 * j)
    ax2.bar_label(bc, labels=[f"{height:.0%}"], label_type='center')

ax2.set_title('Manifold plots')
ax2.legend()
ax2.axis('off')
ax2.set_xlim(- 2.5 * width, 2.5 * width)

# Connect the edges of the exploded wedge to the top and bottom of the bar.
theta1, theta2 = wedges[0].theta1, wedges[0].theta2
center, r = wedges[0].center, wedges[0].r
bar_height = sum(richness_ratios)

x = r * np.cos(np.pi / 180 * theta2) + center[0]
y = r * np.sin(np.pi / 180 * theta2) + center[1]
con = ConnectionPatch(xyA=(-width / 2, bar_height), coordsA=ax2.transData,
                      xyB=(x, y), coordsB=ax1.transData)
con.set_color([0, 0, 0])
con.set_linewidth(4)
ax2.add_artist(con)

x = r * np.cos(np.pi / 180 * theta1) + center[0]
y = r * np.sin(np.pi / 180 * theta1) + center[1]
con = ConnectionPatch(xyA=(-width / 2, 0), coordsA=ax2.transData,
                      xyB=(x, y), coordsB=ax1.transData)
con.set_color([0, 0, 0])
con.set_linewidth(4)
ax2.add_artist(con)

plt.savefig(os.path.join(framework_directory, 'Images/species_per_plot.pdf'))

plt.show()

## Plots by habitat group

In [None]:
eva_header = pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_header.csv'), usecols=['Expert System'])

# Habitat type: 'MA2' for codes starting with MA2, otherwise the first letter.
eva_header['Habitat type'] = eva_header['Expert System'].apply(lambda x: x[:3] if x.startswith('MA2') else x[0])
# Plots per (habitat type, expert-system code), plus the habitat-type total.
eva_header = eva_header.groupby(['Habitat type', 'Expert System']).size().reset_index(name='Count')
eva_header['Habitat type count'] = eva_header.groupby('Habitat type')['Count'].transform('sum')
# Largest habitat types first and, within a type, largest codes first.
# 'Habitat type' breaks ties between types with equal totals so that their
# rows never interleave.
eva_header = eva_header.sort_values(by=['Habitat type count', 'Habitat type', 'Count'],
                                    ascending=[False, True, False])

color_scales = [px.colors.sequential.Reds, px.colors.sequential.Blues, px.colors.sequential.Greens, px.colors.sequential.Oranges,
                px.colors.sequential.Purples, px.colors.sequential.Greys, px.colors.sequential.YlOrBr, px.colors.sequential.RdPu]

# Explicit colour per expert-system code. Each habitat type gets its own
# sequential scale (two lightest shades skipped); shades are cycled if a type
# has more codes than its scale has colours, and scales are cycled if there are
# more types than scales. An explicit map does not depend on the order in which
# plotly encounters the codes.
expert_system_colors = {}
for type_index, (_, type_rows) in enumerate(eva_header.groupby('Habitat type', sort=False)):
    shades = color_scales[type_index % len(color_scales)][2:]
    for code_index, code in enumerate(type_rows['Expert System']):
        expert_system_colors[code] = shades[code_index % len(shades)]

fig = px.bar(eva_header, x='Habitat type', y='Count', color='Expert System',
             color_discrete_map=expert_system_colors)

fig.update_layout(
    height=500,
    width=1000,
    showlegend=False
)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/habitat_type.pdf'))

## Observation of species

In [None]:
# Occurrences of every matched species across all plots, most frequent first
# (a rank-abundance curve).
eva_species = (
    pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_species.csv'), usecols=['Matched concept'])
    .groupby('Matched concept')
    .size()
    .reset_index(name='Count')
    .sort_values(by='Count', ascending=False)
)

fig = px.line(eva_species, x="Matched concept", y="Count", log_y=True)

fig.update_layout(height=500, width=1000, xaxis_title='Species')
# One tick every 1000 species along the x-axis.
fig.update_xaxes(tickmode='linear', tick0=0, dtick=1000)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/species.pdf'))

## Ablation study

In [None]:
# Feature subsets evaluated in the ablation, one per x position.
x = [1, 2, 3, 4, 5, 6, 7]
x_labels = ["All", "Species + Location", "Species + Environmental", "Location + Environmental", "Species", "Location", "Environmental"]
x_rev = x[::-1]

# Per model: name, mean accuracy (%) per subset, standard deviation per
# subset, RGB colour, and vertical offset of the name label at x=0.75.
ablation_results = [
    ('MLP',
     [88.7437, 88.0000, 88.7580, 21.4009, 87.8343, 12.7983, 19.8927],
     [0.3514, 0.3693, 0.3668, 0.8639, 0.3358, 1.0368, 0.7909],
     (31, 119, 180), 1),
    ('RFC',
     [79.3921, 79.6920, 79.4754, 24.3493, 79.0983, 22.9658, 15.1615],
     [0.4478, 0.4765, 0.4571, 0.7567, 0.4757, 0.6759, 0.6180],
     (255, 127, 14), -1),
    ('XGB',
     [86.8039, 86.6270, 86.8484, 26.5687, 86.0186, 24.0786, 19.4963],
     [0.4524, 0.4254, 0.4598, 0.9531, 0.4686, 0.6943, 0.9863],
     (44, 160, 44), -1),
    ('TNC',
     [80.2178, 79.3296, 80.2858, 21.4163, 79.2267, 16.7509, 19.5843],
     [0.5378, 1.0547, 0.5652, 1.2249, 0.5257, 0.8521, 1.0226],
     (214, 39, 40), 1),
    ('FTT',
     [86.9774, 86.1416, 87.0251, 18.2139, 86.0481, 13.3969, 17.8147],
     [0.3791, 0.3617, 0.3547, 0.7813, 0.3298, 1.4059, 0.8566],
     (148, 103, 189), 1),
]

fig = go.Figure()

# ±1 std bands first, so the mean lines are drawn on top. Each band is a
# closed polygon: upper edge left-to-right, then lower edge right-to-left.
for _, means, stds, (red, green, blue), _ in ablation_results:
    upper_edge = [mean + std for mean, std in zip(means, stds)]
    lower_edge = [mean - std for mean, std in zip(means, stds)][::-1]
    fig.add_trace(go.Scatter(
        x=x + x_rev,
        y=upper_edge + lower_edge,
        fill='toself',
        fillcolor=f'rgba({red}, {green}, {blue}, 0.2)',
        line_color='rgba(255,255,255,0)',
    ))

# Mean accuracy lines.
for _, means, _, (red, green, blue), _ in ablation_results:
    fig.add_trace(go.Scatter(x=x, y=means, line_color=f'rgb({red}, {green}, {blue})'))

fig.update_layout(
    xaxis=dict(tickmode='array', tickvals=x, ticktext=x_labels),
    width=1000,
    height=1000,
    showlegend=False,
)

fig.update_xaxes(range=[0.5, 7.5], title_text='Used features')
fig.update_yaxes(title_text='Accuracy (%)')

# Model names next to the start of each line, nudged up or down to avoid overlap.
for name, means, _, (red, green, blue), label_offset in ablation_results:
    fig.add_annotation(
        x=0.75, y=means[0] + label_offset, text=name,
        showarrow=False, font=dict(color=f'rgb({red}, {green}, {blue})', size=20),
    )

fig.update_traces(mode='lines')

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/ablation_study.pdf'))

## Distribution EVA

In [None]:
eva_header = pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_header.csv'))

# Cluster the plot locations into 50 groups to keep the map readable. A fixed
# random_state makes the clustering, and therefore the saved figure,
# reproducible from run to run.
data = eva_header[['Latitude', 'Longitude']]
kmeans = KMeans(n_clusters=50, n_init=10, random_state=0)
kmeans.fit(data)
eva_header['Cluster'] = kmeans.labels_

# Mean position and number of plots of every cluster.
grouped_eva_header = eva_header.groupby('Cluster').agg({'Latitude': 'mean', 'Longitude': 'mean', 'Cluster': 'size'})
grouped_eva_header.columns = ['Latitude', 'Longitude', 'Size']
grouped_eva_header.reset_index(inplace=True)
grouped_eva_header.sort_values('Cluster', inplace=True)

# Reproject the cluster centres from WGS 84 to Web Mercator (basemap CRS).
transformer = pyproj.Transformer.from_crs("epsg:4326", "epsg:3857", always_xy=True)

grouped_eva_header['Longitude'], grouped_eva_header['Latitude'] = transformer.transform(
    grouped_eva_header['Longitude'].values,
    grouped_eva_header['Latitude'].values
)

fig, ax = plt.subplots(figsize=(20, 10))

# Marker area proportional to the number of plots in the cluster.
scatter = ax.scatter(
    grouped_eva_header['Longitude'],
    grouped_eva_header['Latitude'],
    s=grouped_eva_header['Size'] / 10,
    alpha=0.7,
    c='lime',
)

# x-limits: extent of the cluster centres plus a 5 % margin on each side.
margin = 0.05
x_min, x_max = grouped_eva_header['Longitude'].min(), grouped_eva_header['Longitude'].max()
x_range = x_max - x_min
ax.set_xlim(x_min - margin * x_range, x_max + margin * x_range)

# y-limits: hand-tuned window whose height is half the padded x-range, centred
# 1,000,000 projected units (metres in EPSG:3857) below the data's vertical centre.
y_min, y_max = grouped_eva_header['Latitude'].min(), grouped_eva_header['Latitude'].max()
y_range = x_range + 2 * margin * x_range
ax.set_ylim(((y_min + y_max) / 2) - (y_range / 4) - 1000000, ((y_min + y_max) / 2) + (y_range / 4) - 1000000)

ax.axis('off')

scatter.set_clip_on(True)

# Write the number of plots at the centre of every cluster marker.
for centre_x, centre_y, cluster_size in zip(grouped_eva_header['Longitude'], grouped_eva_header['Latitude'], grouped_eva_header['Size']):
    text = ax.text(centre_x, centre_y, str(cluster_size), ha='center', va='center', fontsize=6)
    text.set_clip_on(True)

ctx.add_basemap(ax, crs='EPSG:3857', source=ctx.providers.Esri.WorldImagery, attribution='')

plt.savefig(os.path.join(framework_directory, 'Images/distribution_eva.pdf'), dpi=300, bbox_inches='tight')

plt.show()

## Distribution NPMS

In [None]:
test_header = pd.read_csv(os.path.join(framework_directory, 'Datasets/test_header.csv'))

# Reproject the plot coordinates from WGS 84 (EPSG:4326) to Web Mercator
# (EPSG:3857), the CRS of the basemap tiles.
to_web_mercator = pyproj.Transformer.from_crs("epsg:4326", "epsg:3857", always_xy=True)
projected_x, projected_y = to_web_mercator.transform(
    test_header['Longitude'].values,
    test_header['Latitude'].values,
)
test_header['Longitude'] = projected_x
test_header['Latitude'] = projected_y

fig, ax = plt.subplots(figsize=(10, 10))

scatter = ax.scatter(test_header['Longitude'], test_header['Latitude'], c='red', s=3)

# Square view: the vertical extent of the points plus a 1 % margin, and a
# horizontal window of the same size centred on the points.
margin = 0.01
y_min, y_max = test_header['Latitude'].min(), test_header['Latitude'].max()
y_range = y_max - y_min
ax.set_ylim(y_min - margin * y_range, y_max + margin * y_range)

x_min, x_max = test_header['Longitude'].min(), test_header['Longitude'].max()
x_range = y_range + 2 * margin * y_range
x_center = (x_min + x_max) / 2
ax.set_xlim(x_center - (x_range / 2), x_center + (x_range / 2))

ax.axis('off')

scatter.set_clip_on(True)

ctx.add_basemap(ax, crs='EPSG:3857', source=ctx.providers.OpenStreetMap.Mapnik, attribution='')

plt.savefig(os.path.join(framework_directory, 'Images/distribution_npms.pdf'), dpi=300, bbox_inches='tight')

plt.show()

## Split assignment

In [None]:
eva_header = pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_header.csv'))

fig, ax = plt.subplots(figsize=(20, 10))

# One qualitative colour per fold, cycled if there are more folds than colours.
colors = px.colors.qualitative.Plotly

for fold, group in eva_header.groupby('Fold'):
    # The `label` is what the legend below picks up.
    scatter = ax.scatter(
        group['Longitude'],
        group['Latitude'],
        c=colors[fold % len(colors)],
        s=10,
        linewidths=0.1,
        edgecolors='black',
        label=f'Fold {fold}'
    )

# Zoom on a 1° x 0.5° window around Montpellier (lon, lat).
coordinates_montpellier = (3.876716, 43.610769)
x_min = coordinates_montpellier[0] - 0.5
x_max = coordinates_montpellier[0] + 0.5
y_min = coordinates_montpellier[1] - 0.25
y_max = coordinates_montpellier[1] + 0.25

ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

ax.axis('off')

# Dashed grid anchored at 0° with a fixed cell size (presumably the spatial
# blocks used for the fold split — confirm). Only the lines that fall inside
# the visible window are drawn, instead of ~900 per axis over 0-80°.
grid_step = 0.0897222222224867
grid_positions = np.arange(0, 80, grid_step)

for x in grid_positions[(grid_positions >= x_min) & (grid_positions <= x_max)]:
    ax.axvline(x, color='black', linestyle='--', linewidth=0.5)

for y in grid_positions[(grid_positions >= y_min) & (grid_positions <= y_max)]:
    ax.axhline(y, color='black', linestyle='--', linewidth=0.5)

ctx.add_basemap(ax, crs='EPSG:4326', source=ctx.providers.OpenStreetMap.Mapnik, attribution='')

# Build the legend from the labelled fold scatters only. Passing `labels=`
# alone assigned the strings to artists by position, which could attach them
# to the unlabelled grid lines instead of the folds.
legend = ax.legend(
    title='Folds',
    loc='upper right',
    fontsize='medium',
    frameon=True
)
legend.get_frame().set_edgecolor('black')

plt.savefig(os.path.join(framework_directory, 'Images/split_assignment.pdf'), dpi=300, bbox_inches='tight')

plt.show()

## Threatened plots

In [None]:
habitats = pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_header.csv'), usecols=['Expert System'])
red_list = pd.read_excel(os.path.join(framework_directory, 'Datasets/red_list_habitats.xlsx'), sheet_name='Terrestrial cross to EUNIS 2021', usecols=['Overall category EU28+', 'EUNIS 2019/2021 code'])

# Harmonise inconsistent spellings of the red-list categories.
replacement_dict = {
    'Least concern': 'Least Concern',
    'Near threatened': 'Near Threatened',
    'Critically EndangeredR': 'Critically Endangered'
}
# Red-list categories from least to most severe.
category_priority = ['Data Deficient', 'Least Concern', 'Near Threatened', 'Vulnerable', 'Endangered', 'Critically Endangered']
category_priority_map = dict(zip(category_priority, range(len(category_priority))))
severity_dtype = pd.CategoricalDtype(categories=category_priority, ordered=True)

# One category per EUNIS code: when a code appears with several categories,
# the most severe one is kept (descending sort on the ordered categorical).
red_list = (
    red_list
    .dropna()
    .rename(columns={'Overall category EU28+': 'Category', 'EUNIS 2019/2021 code': 'EUNIS'})
    .assign(Category=lambda frame: frame['Category'].replace(replacement_dict))
    .drop_duplicates()
    .astype({'Category': severity_dtype})
    .sort_values(by=['EUNIS', 'Category'], ascending=[True, False])
    .drop_duplicates(subset='EUNIS', keep='first')
    .reset_index(drop=True)
)

# Habitat codes missing from the red list are treated as 'Data Deficient'.
eunis_to_category = red_list.set_index('EUNIS')['Category'].to_dict()
habitats['Category'] = habitats['Expert System'].apply(lambda code: eunis_to_category.get(code, 'Data Deficient'))

# Coarse grouping of the categories for the inner ring of the sunburst.
category_group = {
    'Data Deficient': 'Incomplete Category',
    'Least Concern': 'Stable Category',
    'Near Threatened': 'Stable Category',
    'Vulnerable': 'Threatened Category',
    'Endangered': 'Threatened Category',
    'Critically Endangered': 'Threatened Category',
}
habitats['Group'] = habitats['Category'].apply(lambda category: category_group.get(category, 'Unknown Category'))

habitats_counts = habitats.groupby(['Group', 'Category']).size().reset_index(name='Count')
# A single shared root so that the sunburst has one centre node.
habitats_counts['Parent'] = ' '

total_counts = habitats_counts['Count'].sum()
habitats_counts['Percentage'] = (habitats_counts['Count'] / total_counts) * 100

fig = px.sunburst(habitats_counts, path=['Parent', 'Group', 'Category'], values='Count', color='Group',
                  color_discrete_map={'(?)':'white', 'Stable Category':'#00CC96', 'Incomplete Category':'#636EFA', 'Threatened Category':'#EF553B'}
                  )

fig.update_traces(textinfo="label+percent parent")

fig.update_layout(height=500, width=1000)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/threatened_plots.pdf'))

## Date range

In [None]:
eva_header = pd.read_csv(os.path.join(framework_directory, 'Datasets/eva_header.csv'))
original_header = pd.read_csv(os.path.join(framework_directory, 'Datasets/EVA/eva_header.csv'), delimiter='\t', usecols=['PlotObservationID', 'Date of recording'])

# Recording year = last four characters of the date string once ':' characters
# are removed. A missing date is encoded as year 0.
original_header['Year'] = original_header['Date of recording'].fillna('0')
original_header = original_header.drop(['Date of recording'], axis=1)
original_header['Year'] = original_header['Year'].str.replace(':', '', regex=True).apply(lambda x: int(x[-4:]))

eva_header = eva_header.merge(original_header[['PlotObservationID', 'Year']], on='PlotObservationID', how='left')

# Plots per recording year, excluding plots without a date (year 0). Filtering
# before counting replaces the former rename-to-'?'-then-drop, which raised a
# KeyError whenever every plot had a date.
count_eva_years = eva_header.loc[eva_header['Year'] != 0, 'Year'].value_counts().to_frame(name="Occurrences")

list_of_years = count_eva_years.index.to_list()

fig = px.bar(
    count_eva_years,
    y='Occurrences',
    x=count_eva_years.index,
    text='Occurrences',
    color='Occurrences',
    labels={
        "Occurrences": "Number of plots",
        "index": "Year"
    }
)

# Bar labels in SI notation with two significant digits, hidden if too small.
fig.update_traces(texttemplate='%{text:.2s}', textposition='outside')
fig.update_layout(uniformtext_minsize=8, uniformtext_mode='hide')
fig.update_layout(height=500, width=1000)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/date_range.pdf'))

## Performance by region

In [None]:
united_kingdom_regions_columns = ['rgn_name', 'geometry']

# Characters stripped from the region names.
punctuation_table = str.maketrans('', '', "'[]")

united_kingdom_regions = (
    gpd.read_file(os.path.join(framework_directory, 'Datasets/united_kingdom_regions.shp'))
    [united_kingdom_regions_columns]
    .rename(columns={'rgn_name': 'region'})
)
united_kingdom_regions['region'] = united_kingdom_regions['region'].apply(lambda name: name.translate(punctuation_table))
united_kingdom_regions = united_kingdom_regions.sort_values(by='region').reset_index(drop=True)
# Accuracy values are given in the alphabetical order of the cleaned region names.
united_kingdom_regions['accuracy (%)'] = [35.98, 41.40, 87.69, 32.88, 31.79, 20.92, 33.33, 43.44, 38.77, 19.25, 47.91, 40.45]
united_kingdom_regions = united_kingdom_regions.set_index('region').to_crs(crs=3857)

fig, ax = plt.subplots(figsize=(10, 10))

united_kingdom_regions.plot(column="accuracy (%)", cmap="viridis", linewidth=0.8, ax=ax, legend=True)

ax.axis('off')

ctx.add_basemap(ax, crs='EPSG:3857', source=ctx.providers.OpenStreetMap.Mapnik, attribution='')

# legend=True added a colorbar; it is the figure's second Axes.
colorbar_axes = fig.get_axes()[1]
colorbar_axes.set_title('Accuracy (%)', fontsize=12)

plt.savefig(os.path.join(framework_directory, 'Images/performance_by_region.pdf'), dpi=300, bbox_inches='tight')

plt.show()

## Confusion matrix

In [None]:
# NOTE: pickle can execute arbitrary code; only load trusted experiment outputs.
with open(os.path.join(framework_directory, 'Experiments/predictions.pkl'), 'rb') as file:
    predictions = pickle.load(file)
target_values = np.load(os.path.join(framework_directory, 'Data/target_values.npy'))
split_assignments = np.load(os.path.join(framework_directory, 'Data/split_assignments.npy'))
with open(os.path.join(framework_directory, 'Data/le_header.pkl'), 'rb') as f:
    le_header = pickle.load(f)


def _coarse_labels(labels):
    """Drop the last character of every label, merging classes one level up."""
    return np.asarray([label[:-1] for label in labels])


# Reorder the targets fold by fold (folds 0-9) so that they line up with the
# concatenated predictions (presumably stored in the same fold order — confirm).
target_values = np.concatenate([target_values[split_assignments == i] for i in range(10)])
target_values = _coarse_labels(le_header.inverse_transform(target_values))

predictions = _coarse_labels(le_header.inverse_transform(np.concatenate(predictions)))

# Every coarse class known to the encoder, sorted. Passing it as `labels`
# guarantees one matrix row/column per display label. Without it the matrix
# only covers classes present in targets or predictions, and tick labels shift
# onto the wrong rows whenever a class never occurs.
habitat_labels = np.unique(_coarse_labels(le_header.classes_))

# Row-normalised: each row shows how the plots of one true class were predicted.
confusion_matrix = sklearn.metrics.confusion_matrix(target_values, predictions, labels=habitat_labels, normalize="true")

fig, ax = plt.subplots(figsize=(12, 12))

disp = sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=habitat_labels)

disp.plot(include_values=False, colorbar=False, ax=ax)

ax.set_xticklabels(disp.display_labels, rotation=45, fontsize=8)
ax.set_yticklabels(disp.display_labels, rotation=45, fontsize=8)

plt.savefig(os.path.join(framework_directory, 'Images/confusion_matrix.pdf'))

plt.show()

## Top-20 features

In [None]:
attributions = torch.load(os.path.join(framework_directory, 'Experiments/attributions.pt'))
with open(os.path.join(framework_directory, 'Data/le_species.pkl'), 'rb') as f:
    le_species = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_country.pkl'), 'rb') as f:
    ohe_country = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_ecoregion.pkl'), 'rb') as f:
    ohe_ecoregion = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_dune.pkl'), 'rb') as f:
    ohe_dune = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_coast.pkl'), 'rb') as f:
    ohe_coast = pickle.load(f)

number_of_species = len(le_species.classes_)

# Feature names, assumed to follow the attribution columns: species, the
# three coordinates, then one-hot countries, ecoregions, dune and coast values.
feature_names = (le_species.classes_.tolist()
                 + ["Longitude", "Latitude", "Altitude"]
                 + ohe_country.categories_[0].tolist()
                 + [f"Ecoregion {ecoregion}" for ecoregion in ohe_ecoregion.categories_[0].tolist()]
                 + ohe_dune.categories_[0].tolist()
                 + ohe_coast.categories_[0].tolist())

# Mean attribution of every feature over the attribution rows.
features_attributions = torch.mean(attributions, dim=0)

k = 20

# Total attribution per feature family: species, location (longitude and
# latitude only) and environment (everything after them, altitude included).
all_species = torch.sum(features_attributions[:number_of_species])
all_location = torch.sum(features_attributions[number_of_species:number_of_species + 2])
all_environmental = torch.sum(features_attributions[number_of_species + 2:])

# The k individual features with the highest mean attribution.
top_k_indices = torch.argsort(features_attributions, descending=True)[:k]
top_k_features = [feature_names[i] for i in top_k_indices.tolist()]
top_k_scores = features_attributions[top_k_indices]

bar_scores = torch.cat((all_species.unsqueeze(0), all_environmental.unsqueeze(0), all_location.unsqueeze(0), top_k_scores), dim=0)

trace = go.Bar(x=["Species", "Environment", "Location"] + top_k_features,
               y=bar_scores.detach().cpu().numpy(),
               # Red: the three feature families; blue: one bar per top feature.
               # Sized from the actual number of top features rather than a fixed 20.
               marker=dict(color=['red'] * 3 + ["blue"] * len(top_k_features))
)

layout = go.Layout(
    xaxis=dict(title='Feature'),
    yaxis=dict(title='Mean attribution score'),
    height=500,
    width=1000
)

fig = go.Figure(data=[trace], layout=layout)

fig.show()

fig.write_image(os.path.join(framework_directory, f'Images/top_{k}_features.pdf'))

## Importance by criteria

In [None]:
attributions = torch.load(os.path.join(framework_directory, 'Experiments/attributions.pt'))
with open(os.path.join(framework_directory, 'Data/le_species.pkl'), 'rb') as species_encoder_file:
    le_species = pickle.load(species_encoder_file)

# Mean attribution of every feature over the attribution rows.
features_attributions = attributions.mean(dim=0)

# Column layout: species indicators, then the two coordinates (location), then
# all remaining features (environment).
species_end = len(le_species.classes_)
location_end = species_end + 2

features_attributions_species = features_attributions[:species_end]
features_attributions_location = features_attributions[species_end:location_end]
features_attributions_environment = features_attributions[location_end:]

labels = ['Species', 'Location', 'Environment']
values = [
    family.sum()
    for family in (features_attributions_species, features_attributions_location, features_attributions_environment)
]

# Donut chart of the share of total attribution per feature family.
pie = go.Pie(labels=labels, values=values, textinfo='label+percent',
             insidetextorientation='horizontal', hole=.3)
fig = go.Figure(data=[pie])

fig.update_layout(height=500, width=500, showlegend=False)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/criteria_importance.pdf'))

## Criteria importance per habitat

In [None]:
attributions = torch.load(os.path.join(framework_directory, 'Experiments/attributions.pt'))
with open(os.path.join(framework_directory, 'Data/le_header.pkl'), 'rb') as f:
    le_header = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/le_species.pkl'), 'rb') as f:
    le_species = pickle.load(f)

habitat_types = ['MA2', 'N', 'Q', 'R', 'S', 'T', 'U', 'V']
number_of_species = len(le_species.classes_)

# NOTE(review): attribution rows are assumed to follow le_header.classes_
# (one row per habitat class) — confirm against the producer of attributions.pt.
# Collect the row index of every class under the first habitat prefix it
# matches. Selecting rows by index rather than slicing consecutive counts keeps
# the grouping correct even when a class code matches no prefix.
rows_per_habitat = {habitat: [] for habitat in habitat_types}
for row, code in enumerate(le_header.classes_):
    matching_habitat = next((habitat for habitat in habitat_types if code.startswith(habitat)), None)
    if matching_habitat is not None:
        rows_per_habitat[matching_habitat].append(row)

# Per class: summed attribution of the species features ("Internal") and of
# all other features ("External"). Frames are collected and concatenated once
# rather than concatenated onto an empty, untyped DataFrame in a loop.
frames = []
for habitat, rows in rows_per_habitat.items():
    if not rows:
        continue
    habitat_attributions = attributions[rows]
    frames.append(pd.DataFrame({'Habitat type': habitat,
                                'Attribution score': habitat_attributions[:, :number_of_species].sum(dim=1).numpy(),
                                'Feature type': 'Internal'}))
    frames.append(pd.DataFrame({'Habitat type': habitat,
                                'Attribution score': habitat_attributions[:, number_of_species:].sum(dim=1).numpy(),
                                'Feature type': 'External'}))

if frames:
    df = pd.concat(frames, ignore_index=True)
else:
    df = pd.DataFrame(columns=['Habitat type', 'Attribution score', 'Feature type'])

# Hand-tuned horizontal offset of the point clouds, keyed by habitat type so
# they stay attached to the right habitat even if some type has no classes.
pointpos_internal = dict(zip(habitat_types, [-0.4, -0.5, -0.5, -0.7, -0.9, -0.6, -0.5, -0.4]))
pointpos_external = dict(zip(habitat_types, [0.4, 0.4, 0.5, 1, 0.8, 0.9, 0.6, 0.4]))

fig = go.Figure()

# One split violin per habitat type: internal on the left, external on the right.
for position, habitat in enumerate(pd.unique(df['Habitat type'])):
    in_habitat = df['Habitat type'] == habitat
    for feature_type, short_name, side, pointpos, line_color in (
        ('Internal', 'I', 'negative', pointpos_internal[habitat], 'lightseagreen'),
        ('External', 'E', 'positive', pointpos_external[habitat], 'mediumpurple'),
    ):
        selection = in_habitat & (df['Feature type'] == feature_type)
        fig.add_trace(go.Violin(x=df.loc[selection, 'Habitat type'],
                                y=df.loc[selection, 'Attribution score'],
                                legendgroup=short_name, scalegroup=short_name, name=short_name,
                                side=side,
                                pointpos=pointpos,
                                line_color=line_color,
                                # Only the first habitat type contributes legend entries.
                                showlegend=position == 0))

# Show every class as a point; violin widths scale with the number of classes.
fig.update_traces(meanline_visible=True,
                  points='all',
                  jitter=0.05,
                  scalemode='count')

fig.update_layout(
    violingap=0, violingroupgap=0, violinmode='overlay',
    width=2000,
    height=500,
    showlegend=False
)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/criteria_importance_per_habitat.pdf'))

## Importance by type

In [None]:
# Mean attribution per habitat type, broken down into feature groups:
# internal (herbaceous / arborescent species) and external (location,
# altitude, country, ecoregion, dune, coast), drawn as two side-by-side
# stacked bars per habitat type.
#
# Relies on `nbr_habitats` from the previous cell to split `attributions`
# into contiguous per-habitat-type row runs, in the order of `index` below.
attributions = torch.load(os.path.join(framework_directory, 'Experiments/attributions.pt'))
with open(os.path.join(framework_directory, 'Data/le_species.pkl'), 'rb') as f:
    le_species = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_country.pkl'), 'rb') as f:
    ohe_country = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_ecoregion.pkl'), 'rb') as f:
    ohe_ecoregion = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_dune.pkl'), 'rb') as f:
    ohe_dune = pickle.load(f)
with open(os.path.join(framework_directory, 'Data/ohe_coast.pkl'), 'rb') as f:
    ohe_coast = pickle.load(f)
arborescent_species = np.load(os.path.join(framework_directory, 'Datasets/arborescent_species.npy'))

# 1/0 flag per species column: 1 if the species is in the arborescent list.
arborescent_species = np.where(np.isin(le_species.classes_, arborescent_species), 1, 0)
arborescent_indices = np.where(arborescent_species == 1)[0]
herbaceous_indices = np.where(arborescent_species == 0)[0]

index = ['MA2', 'N', 'Q', 'R', 'S', 'T', 'U', 'V']

# Column layout of the feature vector, as implied by the slicing:
# [species | location (2) | altitude (1) | country one-hot |
#  ecoregion one-hot | dune one-hot | coast (all remaining columns)]
n_species = len(le_species.classes_)
location_start = n_species
altitude_start = location_start + 2
country_start = altitude_start + 1
ecoregion_start = country_start + len(ohe_country.categories_[0])
dune_start = ecoregion_start + len(ohe_ecoregion.categories_[0])
coast_start = dune_start + len(ohe_dune.categories_[0])

# Column selectors, in the order of the DataFrame columns built below.
feature_groups = [
    herbaceous_indices,
    arborescent_indices,
    slice(location_start, altitude_start),
    slice(altitude_start, country_start),
    slice(country_start, ecoregion_start),
    slice(ecoregion_start, dune_start),
    slice(dune_start, coast_start),
    slice(coast_start, None),
]

scores = []
start_idx = 0
for count in nbr_habitats:
    end_idx = start_idx + count
    habitat_rows = attributions[start_idx:end_idx]
    start_idx = end_idx

    # Sum each group's attributions per row, then average over the habitat
    # type's rows. .item() stores plain floats rather than 0-d tensors, so the
    # DataFrame below gets a numeric dtype instead of object.
    scores.append([
        torch.mean(torch.sum(habitat_rows[:, group], dim=1)).item()
        for group in feature_groups
    ])

df = pd.concat(
    [
        pd.DataFrame(
            [score[0:2] for score in scores],
            index=index,
            columns=["Herbaceous", "Arborescent"]
        ),
        pd.DataFrame(
            [score[2:] for score in scores],
            index=index,
            columns=["Location", "Altitude", "Country", "Ecoregion", "Dune", "Coast"]
        ),
    ],
    axis=1,
    keys=["Internal", "External"]
)

# Height of the tallest Internal/External stack across habitat types.
# Transpose + groupby(level=0) replaces the deprecated groupby(axis=1).
stack_max = float(df.T.groupby(level=0).sum().max().max())

fig = go.Figure(
    layout=go.Layout(
        height=500,
        width=1000,
        barmode="relative",
        yaxis_showticklabels=False,
        yaxis_showgrid=False,
        # 50% headroom above the tallest stack.
        # NOTE(review): the range starts at 0, so negative group means would be
        # clipped — confirm the means are non-negative.
        yaxis_range=[0, stack_max * 1.5],
        yaxis2=go.layout.YAxis(
            visible=False,
            matches="y",
            overlaying="y",
            anchor="x",
        ),
        font=dict(size=20),
        legend_x=0,
        legend_y=1,
        legend_orientation="h",
        hovermode="x",
        margin=dict(b=0, t=10, l=0, r=10)
    )
)

colors = {
    "Internal": {
        "Herbaceous": px.colors.sequential.Greens[3],
        "Arborescent": px.colors.sequential.Greens[7]
    },
    "External": {
        "Location": px.colors.sequential.Reds[3],
        "Altitude": px.colors.sequential.Reds[4],
        "Country": px.colors.sequential.Reds[5],
        "Ecoregion": px.colors.sequential.Reds[6],
        "Dune": px.colors.sequential.Reds[7],
        "Coast": px.colors.sequential.Reds[8]
    }
}

# Internal (i=0) and External (i=1) stacks sit side by side in each habitat
# slot: each is 1/3 wide, offset to [-1/3, 0] and [0, 1/3]. Each group is
# plotted on its own y-axis ("y1" = primary, "y2" = hidden overlay matching
# it), apparently so the two groups stack independently under
# barmode="relative".
for i, group in enumerate(colors):
    for col in df[group].columns:
        # Skip feature groups with zero attribution in every habitat type.
        if (df[group][col] == 0).all():
            continue
        fig.add_bar(
            x=df.index,
            y=df[group][col],
            yaxis=f"y{i + 1}",
            offsetgroup=str(i),
            offset=(i - 1) * 1/3,
            width=1/3,
            legendgroup=group,
            legendgrouptitle_text=group,
            name=col,
            marker_color=colors[group][col],
            marker_line=dict(width=2, color="#333"),
            hovertemplate="%{y}<extra></extra>"
        )

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/internal_external_per_habitat.pdf'))

## Importance by rank

In [None]:
# Mean attribution of the top-ranked species plotted against their
# dominance rank (rank 1 = first entry of `ranks`).
ranks = torch.load(os.path.join(framework_directory, 'Experiments/ranks.pt'))
with open(os.path.join(framework_directory, 'Data/le_species.pkl'), 'rb') as f:
    le_species = pickle.load(f)

TOP_K = 50  # number of leading ranks to plot

# Clamp to the available data so x and y always have the same length.
top_k = min(TOP_K, len(le_species.classes_), len(ranks))
x = np.arange(1, top_k + 1)
# detach().cpu().numpy() hands plotly a plain array even for GPU/grad tensors.
y = ranks[:top_k].detach().cpu().numpy()

fig = go.Figure(data=go.Scatter(x=x, y=y))

fig.update_layout(
    xaxis_title="Species dominance rank",
    yaxis_title="Mean attribution",
    height=500,
    width=1000
)
# One unit of padding on each side of the plotted ranks.
fig.update_xaxes(range=[0, top_k + 1])

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/importance_by_rank.pdf'))

## Feature ablation

In [None]:
# Box plot of feature-ablation results per habitat group, with one box per
# criterion (Internal / External).
ablations = torch.load(os.path.join(framework_directory, 'Experiments/ablations.pt'))
with open(os.path.join(framework_directory, 'Data/le_header.pkl'), 'rb') as f:
    le_header = pickle.load(f)

habitat_classes = le_header.classes_.tolist()
# Derived from the encoder instead of hard-coding 228, so all three columns
# keep the same length if the class set changes.
n_habitats = len(habitat_classes)

data = {
    # The class list is repeated as a whole (not element-wise): the first copy
    # pairs with the Internal rows, the second with the External rows.
    # habitat[:-2] drops the last two characters of each class code.
    "Habitat": [habitat[:-2] for habitat in habitat_classes * 2],
    "Criteria": ["Internal"] * n_habitats + ["External"] * n_habitats,
    # NOTE(review): assumes `ablations` has one row per class of le_header, in
    # the same order, with column 0 = internal and column 1 = external.
    "Ablations": np.concatenate((ablations[:, 0], ablations[:, 1])),
}

df = pd.DataFrame(data)

fig = px.box(df, x='Habitat', y='Ablations', color='Criteria', height=500, width=1000)

fig.update_layout(showlegend=False)

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/feature_ablation_per_habitat.pdf'))

## Micro average

In [None]:
# Micro-averaged Top-1 and Top-3 accuracy (%) for each of the 10 folds.
folds = [f"Fold {fold}" for fold in range(10)]

top3_accuracy = [98.63, 98.59, 98.50, 98.54, 98.51, 98.52, 98.58, 98.56, 98.42, 98.64]
top1_accuracy = [89.06, 89.10, 88.13, 88.88, 89.08, 88.80, 88.91, 88.80, 88.22, 88.47]

# In overlay mode later traces are drawn on top; Top-1 is lower than Top-3 in
# every fold here, so adding Top-3 first keeps both bars visible.
fig = go.Figure()
for label, values, color in (
    ('Top-3', top3_accuracy, 'red'),
    ('Top-1', top1_accuracy, 'blue'),
):
    fig.add_trace(go.Bar(name=label, x=folds, y=values, marker_color=color))

fig.update_layout(
    barmode='overlay',
    xaxis_title='Folds',
    yaxis_title='Accuracy (%)',
    showlegend=False,
    height=500,
    width=1000,
)
fig.update_yaxes(range=[50, 100])

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/micro_average.pdf'))

## Macro average

In [None]:
# Macro-averaged Top-1 and Top-3 accuracy (%) for each of the 10 folds.
folds = [f"Fold {fold}" for fold in range(10)]

top3_accuracy = [91.89, 93.62, 90.91, 90.39, 88.94, 93.49, 91.57, 89.02, 88.56, 89.71]
top1_accuracy = [74.08, 77.23, 73.64, 72.59, 72.82, 77.10, 74.06, 73.00, 73.62, 71.58]

# In overlay mode later traces are drawn on top; Top-1 is lower than Top-3 in
# every fold here, so adding Top-3 first keeps both bars visible.
fig = go.Figure()
for label, values, color in (
    ('Top-3', top3_accuracy, 'red'),
    ('Top-1', top1_accuracy, 'blue'),
):
    fig.add_trace(go.Bar(name=label, x=folds, y=values, marker_color=color))

fig.update_layout(
    barmode='overlay',
    xaxis_title='Folds',
    yaxis_title='Accuracy (%)',
    showlegend=False,
    height=500,
    width=1000,
)
fig.update_yaxes(range=[50, 100])

fig.show()

fig.write_image(os.path.join(framework_directory, 'Images/macro_average.pdf'))