In [None]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from scipy.spatial import cKDTree
from matplotlib import pyplot as plt
sys.path.append('../prithvi/')
from glc_datasets import TrainDataset, TestDataset, read_train_data, read_test_data

# Choose input/output directories by host name: on the local workstation ("gtbase")
# data and outputs share one directory; anywhere else (the Mahti cluster) both come
# from environment variables set by the job environment.
if os.uname()[1] == "gtbase":
    path_save = path_data = "/home/gt/DATA/geolifeclef-2025"
    print("local, using", f"path_data=path_save={path_data}")
else:
    path_data = os.environ['LOCAL_SCRATCH']
    path_save = os.environ['GLC_SCRATCH']
    print("mahti, using", f"path_data={path_data};", f"path_save={path_save}")
    # A bare expression inside an if/else is never rendered by Jupyter, so the
    # listing has to be printed explicitly to be visible.
    print("contents of path_data:", sorted(os.listdir(path_data)))

In [None]:
# Presence-only (PO) observation points, each already tagged with a grid cell_id.
po_points_csv = os.path.join(path_data, "PO_points_with_cell_id_net_90km.csv")
data_po = pd.read_csv(po_points_csv)

In [None]:
# Drop rows whose index_right is missing (presumably PO points that were not matched
# to any grid cell — confirm), then cast the species and cell ids to integers.
data_po = (
    data_po
    .dropna(subset=["index_right"])
    .astype({"speciesId": int, "cell_id": int})
)

In [None]:
# One row per survey: its location and cell (taken from the first record) plus all
# species ids observed in that survey, joined into a comma-separated string.
metadata_po = data_po.groupby("surveyId").agg(
    lon=("lon", "first"),
    lat=("lat", "first"),
    cell_id=("cell_id", "first"),
    speciesId=("speciesId", lambda ids: ",".join(ids.astype(str))),
)
metadata_po

In [None]:
# WorldCover land-cover class looked up for each PO training survey point.
worldcover_csv = os.path.join(path_data, "worldcover", "po_train_survey_points_with_worldcover.csv")
data_worldcover = pd.read_csv(worldcover_csv, index_col=0)
data_worldcover

In [None]:
# Attach the WorldCover `class` to every survey by exact (lon, lat) match.
# Several surveys can share a coordinate. If the WorldCover table lists a coordinate
# more than once, a plain merge would duplicate those surveys (non-unique surveyId,
# inflated species counts downstream). So keep one WorldCover row per coordinate and
# let pandas verify the many-to-one relation. NOTE(review): if repeated coordinates
# ever carry different classes, the first one wins — confirm that cannot happen.
# The inner join drops surveys with no WorldCover match; their count is reported.
worldcover_by_coord = (
    data_worldcover.loc[:, ["lon", "lat", "class"]]
    .drop_duplicates(subset=["lon", "lat"])
)
meta = (
    metadata_po.reset_index()
    .merge(worldcover_by_coord, on=["lon", "lat"], how="inner", validate="many_to_one")
    .set_index("surveyId")
)
assert meta.index.is_unique, "duplicate surveyId after WorldCover merge"
print(f"{len(metadata_po) - len(meta)} surveys without a WorldCover match were dropped")
meta

In [None]:
# Pool species ids per (cell_id, class) group, then show for each group the total
# number of species records next to the number of distinct species.
tmp = meta.groupby(["cell_id", "class"]).agg({"speciesId": lambda ids: ",".join(ids.astype(str))})
species_lists = tmp["speciesId"].str.split(",")
pd.concat([species_lists.str.len(), species_lists.apply(lambda ids: len(set(ids)))], axis=1)

In [None]:
# Loader settings; their meaning is defined by glc_datasets.read_train_data.
cov_flag_list = [1, 0, 0, 0, 0, 0]
pa_presence_threshold = 1
sel_countries = ["France", "Denmark", "Netherlands", "Italy"]
(
    train_combined,
    train_label_series,
    sp_categories,
    cov_columns,
    cov_norm_coef,
    num_classes,
) = read_train_data(path_data, cov_flag_list, sel_countries, pa_presence_threshold)

In [None]:
# Spot check: map the species of the first (cell_id, class) group onto sp_categories.
first_group_species = list(map(int, tmp["speciesId"].iloc[0].split(",")))
pd.Categorical(first_group_species, categories=sp_categories)

In [None]:
def _sorted_species_codes(joined_ids):
    """Sorted category codes of the distinct ids in a comma-joined string (-1 = not in sp_categories)."""
    unique_ids = [int(token) for token in set(joined_ids.split(","))]
    return np.sort(pd.Categorical(unique_ids, categories=sp_categories).codes)

sp = tmp["speciesId"].apply(_sorted_species_codes)
# Spot check one grid cell (cell_id 3456) across its `class` values.
sp.loc[(3456,)]

In [None]:
# PO covariates are stored in 7 feather shards (po_cov000 ... po_cov006); stack them
# and index the rows by their "index" column.
x_data_list = [
    pd.read_feather(os.path.join(path_data, "hmsc", "po", f"po_cov{shard:03d}.feather"))
    for shard in range(7)
]
x_data = pd.concat(x_data_list).set_index("index")

In [None]:
# Attach the covariates to each survey (left join on surveyId). The last two x_data
# columns are excluded — NOTE(review): presumably bookkeeping columns that would
# clash with meta's own; confirm.
covariate_cols = x_data.columns[:-2]
meta_cov = meta[["lon", "lat", "cell_id", "class", "speciesId"]].join(x_data[covariate_cols])
meta_cov

In [None]:
# Aggregate surveys per (cell_id, class): pool species ids into one comma-joined
# string and average location and every covariate.
covariate_cols = list(x_data.columns[:-2])
agg_dict = {"speciesId": lambda values: ",".join(values.astype(str))}
agg_dict.update(dict.fromkeys(["lon", "lat"] + covariate_cols, "mean"))
print(agg_dict)
df_po = meta_cov.groupby(["cell_id", "class"]).agg(agg_dict)
df_po

In [None]:
# Design matrix for the PO groups: number of pooled species records ("obs") followed
# by the averaged location/covariates, with cell_id and class as ordinary columns.
obs_per_group = df_po["speciesId"].str.split(",").str.len().rename("obs")
po_X = pd.concat([obs_per_group, df_po.drop(columns="speciesId")], axis=1).reset_index()
display(po_X)
po_X.to_csv(os.path.join(path_data, "hmsc", "po_X.csv"), index=False)

In [None]:
def splist_to_vector(val, categories=None, n_classes=None):
    """Turn a comma-separated string of species ids into a multi-hot integer vector.

    Parameters
    ----------
    val : str
        Species ids joined by commas, e.g. ``"12,7,12"``. Duplicates and empty
        tokens are ignored.
    categories : sequence, optional
        Ordered species ids defining the vector positions. Defaults to the
        notebook-level ``sp_categories``.
    n_classes : int, optional
        Length of the output vector. Defaults to the notebook-level ``num_classes``.

    Returns
    -------
    numpy.ndarray
        int64 array of length ``n_classes`` with 1 at the position of every id in
        ``val`` that appears in ``categories``; ids not in ``categories`` are skipped.

    Raises
    ------
    IndexError
        If a matched category position is >= ``n_classes``.
    """
    if categories is None:
        categories = sp_categories
    if n_classes is None:
        n_classes = num_classes
    species = sorted({int(token) for token in val.split(",") if token.strip()})
    # Categorical codes give each id's position in `categories`; -1 marks unknown ids.
    codes = np.asarray(pd.Categorical(species, categories=categories).codes)
    codes = codes[codes >= 0].astype(np.int64)
    vec = np.zeros(n_classes, dtype=np.int64)
    vec[codes] = 1
    return vec

# Response matrix: one multi-hot species row per (cell_id, class) group, in the same
# row order as po_X (both are built from df_po).
df_Y = df_po["speciesId"].apply(splist_to_vector)
po_Y = pd.DataFrame(np.vstack(df_Y.to_list()), index=df_Y.index)
display(po_Y)
po_Y.to_csv(os.path.join(path_data, "hmsc", "po_Y.csv"), index=False)

In [None]:
# Assign every train_combined row and every po_X row to its nearest centroid for each
# clustering resolution in n_vec, and save the ids (one column per resolution).
# Distances are plain Euclidean in (lon, lat) degrees, not geodesic.
# Frames are built from dicts so the columns are integer-typed; pre-allocating an empty
# DataFrame and filling it with iloc left them as object dtype.
n_vec = [100, 200, 400]
colnames = [f"clusters{n}" for n in n_vec]
centroid_dir = os.path.join(path_data, "hmsc", "centroids_po_pa")
train_coords = train_combined.loc[:, ["lon", "lat"]].to_numpy()
po_coords = po_X.loc[:, ["lon", "lat"]].to_numpy()
train_assign, po_assign = {}, {}
for n, col in zip(n_vec, colnames):
    centroids = pd.read_csv(os.path.join(centroid_dir, f"centroids_k{n}.csv"))
    btree = cKDTree(centroids.loc[:, ["lon", "lat"]].to_numpy())
    # query(...)[1] holds the index of the nearest centroid for each point.
    train_assign[col] = btree.query(train_coords, k=1)[1]
    po_assign[col] = btree.query(po_coords, k=1)[1]
train_clusters = pd.DataFrame(train_assign, index=train_combined.index, columns=colnames)
po_clusters = pd.DataFrame(po_assign, index=po_X.index, columns=colnames)
train_clusters.to_csv(os.path.join(centroid_dir, "train_clusters.csv"), index=False)
po_clusters.to_csv(os.path.join(centroid_dir, "po_clusters.csv"), index=False)

In [None]:
def plot_cluster_assignment(ax, centroids, points, clusters, positions, col, color, title=None):
    """Draw numbered centroids and overlay selected points labelled with their cluster id.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    centroids : DataFrame with ``lon``/``lat`` columns; each centroid is labelled by its index.
    points : DataFrame with ``lon``/``lat`` columns.
    clusters : DataFrame of cluster ids aligned row-by-row with ``points``.
    positions : integer row positions of ``points`` to highlight; positions outside
        ``[0, len(points))`` are skipped instead of raising.
    col : integer column position in ``clusters`` (which resolution to show).
    color : colour of the highlighted points and their labels.
    title : optional axes title.
    """
    ax.scatter(centroids.lon, centroids.lat, s=1, color="black")
    for label, row in centroids.iterrows():
        ax.annotate(f"{label:03d}", (row.lon, row.lat), color="black")
    # The positions are hard-coded for inspection; tolerate tables shorter than expected.
    valid = [p for p in positions if 0 <= p < len(points)]
    ax.scatter(points.lon.iloc[valid], points.lat.iloc[valid], s=10, color=color)
    for p in valid:
        ax.annotate(f"{clusters.iloc[p, col]:03d}", (points.lon.iloc[p], points.lat.iloc[p]), color=color)
    ax.set(xlabel="lon", ylabel="lat")
    if title is not None:
        ax.set_title(title)


# Visual check of the nearest-centroid assignment for a few sample rows at one resolution.
k = 0
n = n_vec[k]
ind = [0, 111, 7557, 753, 8745]
centroids = pd.read_csv(os.path.join(path_data, "hmsc", "centroids_po_pa", f"centroids_k{n}.csv"))
fig, ax = plt.subplots(ncols=2, figsize=[18, 6])
plot_cluster_assignment(ax[0], centroids, train_combined, train_clusters, ind, k, "red", f"train_combined, {n} centroids")
plot_cluster_assignment(ax[1], centroids, po_X, po_clusters, ind, k, "blue", f"po_X, {n} centroids")
plt.show()

In [None]:
# Sorted listing of the PO covariate shard directory.
file_list = sorted(os.listdir(os.path.join(path_data, "hmsc", "po")))
file_list

In [None]:
# Re-import the local glc_datasets module after editing it, without restarting the kernel.
from importlib import reload

# Add the module directory only once; unguarded, every re-run of this cell appended
# another duplicate entry to sys.path.
if '../prithvi/' not in sys.path:
    sys.path.append('../prithvi/')
import glc_datasets
reload(glc_datasets)
# Rebind the names imported in the first cell so they refer to the reloaded definitions.
from glc_datasets import TrainDataset, TestDataset, read_train_data, read_test_data