In [None]:
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import gc
import numpy as np
import rasterio
from tqdm.notebook import tqdm
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as v2
from terratorch.models.backbones.prithvi_mae import PrithviViT
sys.path.append('../prithvi/')
from utils import set_seed
from glc_datasets import TrainDataset, TestDataset, read_train_data, read_test_data
from models import ModifiedPrithviResNet18, prithvi_terratorch
# Select the 'file_system' tensor-sharing strategy for torch multiprocessing
# (presumably to avoid file-descriptor limits with several DataLoader workers — TODO confirm).
torch.multiprocessing.set_sharing_strategy('file_system')
# Show the contents of the scratch directory as the cell output.
# Raises KeyError if the LOCAL_SCRATCH environment variable is not set.
os.listdir(os.environ['LOCAL_SCRATCH'])

In [None]:
# Feature-preparation mode. The cells below branch on exactly two values:
# "orig" (raw covariates are exported) and "deep" (model features are exported).
prepare_type = "deep"

In [None]:
# --- DataLoader settings ---
batch_size = 32
num_workers = 6

# --- Dataset / label settings ---
pa_presence_threshold = 1
num_classes_total = 11255
landsat_year_len = 18
validation_prop = 0.1
sel_countries = ["France", "Denmark", "Netherlands", "Italy"] #, "Austria"

# --- Covariate-group switches (collected into cov_flag_list further down;
#     presumably 1 = include, 0 = exclude — confirm against read_train_data) ---
cov_countries = 1
cov_area = cov_elevation = cov_snow = 1
cov_soil = cov_worldcover = cov_landcover = 1
if prepare_type == "deep":
    # In "deep" mode the snow and land-cover switches are set to 0.
    cov_snow, cov_landcover = 0, 0
# Resolve the input (path_data) and output (path_save) directories from the host.
# HOSTNAME is read with .get(): many environments do not export it, and a plain
# os.environ["HOSTNAME"] lookup would abort the cell with a KeyError instead of
# falling through to the cluster branch.
if os.environ.get("HOSTNAME") == "gtbase":
    # Local workstation: data and outputs share one directory.
    path_save = path_data = "/home/gt/DATA/geolifeclef-2025"
    print("local, using", f"path_data=path_save={path_data}")
else:
    # Cluster: both locations come from environment variables (KeyError if unset).
    path_data = os.environ['LOCAL_SCRATCH']
    path_save = os.environ['GLC_SCRATCH']
    print("mahti, using", f"path_data={path_data};", f"path_save={path_save}")

# Per-channel normalization statistics: 8 values for the Landsat input,
# 4 values for the Sentinel input.
mean_landsat = 1*np.array([15.0698, 16.0923, 7.9312, 68.9794, 47.9505, 24.8804, 7089.4349, 2830.6658])
std_landsat = 1*np.array([11.7218, 10.2417, 9.6499, 18.7112, 13.1681, 9.2436, 3332.3618, 56.7270])
mean_sentinel = 1*np.array([624.8547, 684.7646, 456.7674, 2924.1753])
std_sentinel = 1*np.array([416.0408, 351.1005, 315.8956, 943.6141])

# Note: `v2` is the alias this notebook gives to `torchvision.transforms` at import time.
transform_landsat = v2.Compose([v2.Normalize(mean_landsat, std_landsat)])
transform_sentinel = v2.Compose([v2.Normalize(mean_sentinel, std_sentinel)])

### Train data

In [None]:
# Dataset options that depend on the preparation mode; both flags are passed
# to TrainDataset / TestDataset below.
if prepare_type == "orig":
    image_mean = True
    sentinel_mask_channel = False
elif prepare_type == "deep":
    image_mean = False
    sentinel_mask_channel = True
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(f"Unknown prepare_type {prepare_type!r}; expected 'orig' or 'deep'")

In [None]:
# Covariate switches in the order handed to read_train_data
# (presumably positional — keep the order unless the reader is changed too).
cov_flag_list = [
    cov_area,
    cov_elevation,
    cov_countries,
    cov_soil,
    cov_worldcover,
    cov_landcover,
    cov_snow,
]
# Load the training table, labels, covariate column names, normalization coefficients and class count.
(train_combined, train_label_series, cov_columns, cov_norm_coef, num_classes) = read_train_data(
    path_data, cov_flag_list, sel_countries, pa_presence_threshold
)
#cov_norm_coef.to_csv(os.path.join(path_data, "hmsc", "train_cov_mean_std.csv"))

In [None]:
set_seed(42)

# Input directories for the three training modalities.
train_path_sentinel = os.path.join(path_data, "SatelitePatches/PA-train")
train_path_landsat = os.path.join(path_data, "SateliteTimeSeries-Landsat/cubes/PA-train")
train_path_bioclim = os.path.join(path_data, "BioclimTimeSeries/cubes/PA-train")

train_data = train_combined.reset_index(drop=True)
train_label_dict = train_label_series.to_dict()
train_dataset = TrainDataset(
    train_path_sentinel,
    train_path_landsat,
    train_path_bioclim,
    train_data,
    cov_columns,
    train_label_dict,
    subset="train",
    num_classes=num_classes,
    transform_sentinel=transform_sentinel,
    transform_landsat=transform_landsat,
    landsat_year_len=landsat_year_len,
    image_mean=image_mean,
    sentinel_mask_channel=sentinel_mask_channel,
)
# shuffle=False keeps the extracted rows in dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# Sanity check: shapes of the first five fields of the first sample.
print(*(train_dataset[0][i].shape for i in range(5)))

In [None]:
if prepare_type == "deep":
    # Prithvi ViT hyperparameters.
    patch_size = [1,16,16]
    n_frame, n_channel = 1, 5
    embed_dim = 1024
    decoder_depth = 8
    num_heads = 16
    mlp_ratio = 4
    head_dropout = 0.0
    # Sizes forwarded to ModifiedPrithviResNet18.
    resnet_dim = 1000
    hidden_last_dim = 1000 + 128

    prithvi_instance = PrithviViT(
        patch_size=patch_size,
        num_frames=n_frame,
        in_chans=n_channel,
        embed_dim=embed_dim,
        decoder_depth=decoder_depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        head_dropout=head_dropout,
        backbone_input_size=[1,64,64],
        encoder_only=False,
        padding=True,
    )
    prithvi_model = prithvi_terratorch(None, prithvi_instance, [1,64,64])

    # Use the GPU when one is visible, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("DEVICE = CUDA")
    else:
        device = torch.device("cpu")
    prithvi_model.to(device)

    model = ModifiedPrithviResNet18(num_classes, len(cov_columns), resnet_dim, hidden_last_dim, prithvi_model).to(device)
    weights_path = os.path.join(path_data, "hmsc", "0507_135023_weights75.pth")
    state_dict = torch.load(weights_path, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)
    # Replace the final layer with Identity so the forward pass returns its input features.
    model.fc_final = nn.Identity()
    model.eval()

In [None]:
# Run the whole training loader once and collect, per batch, the covariate/feature
# matrix, the label array and the survey ids. Gradients are disabled throughout.
# NOTE: the loop variables (sentinel, landsat, features, cov) stay in the notebook
# namespace afterwards and are read by the column-naming cell that follows.
surv_list, cov_list, y_list = [], [], []
with torch.no_grad():
    for sentinel, landsat, data_cov, lonlat, label, survey in tqdm(train_loader):
        if prepare_type == "orig":
            # Raw covariates: sentinel values, flattened landsat block, tabular covariates, coordinates.
            cov = torch.concat([sentinel, landsat.reshape([landsat.shape[0], -1]), data_cov, lonlat], dim=1).numpy()
        elif prepare_type == "deep":
            # Insert a singleton axis at position 2 before feeding the sentinel batch to the model.
            sentinel = sentinel.to(device)[:,:,None,:,:]
            landsat, data_cov, lonlat = landsat.to(device), data_cov.to(device), lonlat.to(device)
            # model.fc_final was replaced by nn.Identity, so this returns features rather than class scores.
            features = model(sentinel, landsat, data_cov, lonlat)
            # Append the coordinates as the last two columns.
            cov = torch.concat([features, lonlat], dim=1).cpu().numpy()
        cov_list.append(cov)
        y_list.append(label.numpy())
        surv_list.append(survey.numpy())

In [None]:
# Column names for the exported matrix; the shapes are taken from the last batch
# left over by the extraction loop.
if prepare_type == "orig":
    cols = (
        [f"sentinel{i}" for i in range(sentinel.shape[-1])]
        + [f"landsatbio{i}{j}" for i in range(landsat.shape[-2]) for j in range(landsat.shape[-1])]
        + cov_columns
        + ["lon","lat"]
    )
elif prepare_type == "deep":
    cols = [f"f{i:04}" for i in range(features.shape[-1])] + ["lon","lat"]

# Stack the per-batch arrays into one table and show it.
cov = pd.DataFrame(data=np.concatenate(cov_list), columns=cols)
display(cov)
Y = pd.DataFrame(data=np.concatenate(y_list)).astype(int)
os.makedirs(os.path.join(path_data, "hmsc"), exist_ok=True)
# Saving is currently disabled:
# if prepare_type == "orig":
#     cov.to_csv(os.path.join(path_data, "hmsc", "train_cov.csv"), index_label=False)
# elif prepare_type == "deep":
#     cov.to_csv(os.path.join(path_data, "hmsc", "train_deepfeatures.csv"), index_label=False)
#Y.to_csv(os.path.join(path_data, "hmsc", "train_Y.csv"), index_label=False)
surv = pd.DataFrame({"surveyId": np.concatenate(surv_list)})

### Test data

In [None]:
# Input directories for the three test modalities.
test_path_sentinel = os.path.join(path_data, "SatelitePatches/PA-test")
test_path_landsat = os.path.join(path_data, "SateliteTimeSeries-Landsat/cubes/PA-test")
test_path_bioclim = os.path.join(path_data, "BioclimTimeSeries/cubes/PA-test")

# Test covariates are read with the column list and normalization coefficients from training.
test_combined = read_test_data(path_data, cov_columns, cov_norm_coef, sel_countries).reset_index(drop=True)
# NOTE(review): unlike the TrainDataset call, landsat_year_len is not passed here —
# confirm that TestDataset's default matches the value used for training.
test_dataset = TestDataset(
    test_path_sentinel,
    test_path_landsat,
    test_path_bioclim,
    test_combined,
    cov_columns,
    subset="test",
    transform_sentinel=transform_sentinel,
    transform_landsat=transform_landsat,
    image_mean=image_mean,
    sentinel_mask_channel=sentinel_mask_channel,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# Sanity check: shapes of the first four fields of the first sample.
print(*(test_dataset[0][i].shape for i in range(4)))

In [None]:
# Same extraction as for the training set, but test batches carry no label field.
# The buffers are re-initialised, so the training results must already have been
# turned into DataFrames before this cell runs.
surv_list, cov_list = [], []
with torch.no_grad():
    for sentinel, landsat, data_cov, lonlat, survey in tqdm(test_loader):
        if prepare_type == "orig":
            # Raw covariates: sentinel values, flattened landsat block, tabular covariates, coordinates.
            cov = torch.concat([sentinel, landsat.reshape([landsat.shape[0], -1]), data_cov, lonlat], dim=1).numpy()
        elif prepare_type == "deep":
            # Insert a singleton axis at position 2 before feeding the sentinel batch to the model.
            sentinel = sentinel.to(device)[:,:,None,:,:]
            landsat, data_cov, lonlat = landsat.to(device), data_cov.to(device), lonlat.to(device)
            features = model(sentinel, landsat, data_cov, lonlat)
            # Append the coordinates as the last two columns.
            cov = torch.concat([features, lonlat], dim=1).cpu().numpy()
        cov_list.append(cov)
        surv_list.append(survey.numpy())

In [None]:
# Stack the per-batch test arrays into one table; `cols` is reused from the training section.
cov = pd.DataFrame(data=np.concatenate(cov_list), columns=cols)
display(cov)
os.makedirs(os.path.join(path_data, "hmsc"), exist_ok=True)
# Saving is currently disabled:
# if prepare_type == "orig":
#     cov.to_csv(os.path.join(path_data, "hmsc", "test_cov.csv"), index_label=False)
# elif prepare_type == "deep":
#     cov.to_csv(os.path.join(path_data, "hmsc", "test_deepfeatures.csv"), index_label=False)

### PO data

In [None]:
# Deliberate stop: the PO section must not run on the local workstation.
# HOSTNAME is read with .get() so an environment that does not export it
# does not trip a KeyError before the check is even made.
if os.environ.get("HOSTNAME") == "gtbase":
    raise Exception("Do not run PO data on local machine!")

In [None]:
# Load the raw PO metadata table. It is kept unmodified under this name so the
# following cells can re-derive po_metadata from it without re-reading the file.
po_metadata_orig = pd.read_csv(os.path.join(path_data, "GLC25_P0_metadata_train.csv"))

In [None]:
# Build the PO covariate table (po_combined) that TestDataset consumes.
# Changes relative to the earlier version of this cell:
#   * soil NaNs are filled with a plain reassignment (the chained
#     `po_soil[col].fillna(..., inplace=True)` does nothing under pandas Copy-on-Write,
#     and its warning is hidden by the notebook-wide warnings filter);
#   * normalized covariates replace whole columns instead of being written through
#     `.loc`, so non-float columns are not set with incompatible float values;
#   * the per-column NaN count is actually displayed.

# --- Distinct (lat, lon, surveyId) rows, indexed and sorted by surveyId ---
po_metadata = po_metadata_orig
po_metadata = po_metadata.loc[:, ["lat", "lon", "surveyId"]].drop_duplicates().set_index("surveyId", drop=True).sort_index()

# --- Country name per survey (first entry kept when an index value repeats) ---
po_countries = pd.read_csv(os.path.join(path_data, "po_with_countries.csv"), index_col=0)
po_metadata = po_metadata.join(po_countries.loc[~po_countries.index.duplicated(keep='first'), "name"])
po_metadata = po_metadata.rename({"name": "country"}, axis=1)
# One indicator column per selected country ("con" + first three letters) plus "conOther".
country_columns = ["con"+country[:3] for country in sel_countries] + ["conOther"]
for country, col in zip(sel_countries, country_columns[:-1]):
    po_metadata[col] = po_metadata["country"] == country
po_metadata[country_columns[-1]] = ~po_metadata["country"].isin(sel_countries)

# --- WorldCover class, matched on coordinates ---
po_worldcover = pd.read_csv(os.path.join(path_data, "worldcover", "po_train_survey_points_with_worldcover.csv"), index_col=0)
po_combined = po_metadata.reset_index().merge(po_worldcover.loc[:,["lat","lon","class"]], on=["lat","lon"]).set_index("surveyId")
# areaLog is set to a constant 0.0 for every PO row.
po_combined["areaLog"] = 0.0

# --- Elevation ---
po_elevation = pd.read_csv(os.path.join(path_data, "EnvironmentalValues", "Elevation", "GLC25-PO-train-elevation.csv"), index_col=0)
po_combined = po_combined.join(po_elevation)

# --- Soil: missing values replaced by the column mean ---
po_soil = pd.read_csv(os.path.join(path_data, "EnvironmentalValues", "SoilGrids", "GLC25-PO-train-soilgrids.csv"), index_col=0)
po_soil = po_soil.fillna(po_soil.mean())
po_combined = po_combined.join(po_soil)

# --- WorldCover one-hot columns, with classes 70 and 100 dropped ---
po_wcdummy = pd.get_dummies(po_combined["class"], prefix="wc")
po_wcdummy.drop(columns=["wc_70", "wc_100"], inplace=True)
po_combined = po_combined.join(po_wcdummy)

# --- Land cover: keep a fixed subset of columns by position ---
po_landcover = pd.read_csv(os.path.join(path_data, "EnvironmentalValues", "LandCover", "GLC25-PO-train-landcover.csv"), index_col=0)
landcover_col_ind=[0,2,3,5,8,11,12]
po_landcover = po_landcover.iloc[:, landcover_col_ind]
po_combined = po_combined.join(po_landcover)

# --- Snow cover: rename scd_1 -> scd and keep the first row per index value ---
po_snow = pd.read_csv(os.path.join(path_data, "EnvironmentalValues", "chelsa_snow", "po_train_snowcover_chelsa_scd.csv"), index_col=0).sort_index()
po_snow = po_snow.rename({"scd_1": "scd"}, axis=1)
po_snow = po_snow[~po_snow.index.duplicated(keep='first')]
po_combined = po_combined.join(po_snow)

# --- Standardize with the training mean/std, then report remaining NaNs per covariate ---
po_combined[cov_columns] = (po_combined.loc[:,cov_columns] - cov_norm_coef.loc["mean"]) / cov_norm_coef.loc["std"]
display(po_combined.loc[:,cov_columns].isna().sum())
po_combined.reset_index(drop=False, inplace=True)

In [None]:
# PO inputs are read from the scratch directory named by LOCAL_SCRATCH.
po_root = os.environ['LOCAL_SCRATCH']
po_path_sentinel = os.path.join(po_root, "SatelitePatches/po/output/TIFF_64")
po_path_landsat = os.path.join(po_root, "SateliteTimeSeries-Landsat/cubes_landsat")
po_path_bioclim = os.path.join(po_root, "BioclimTimeSeries/cubes_bioclim")

# The PO rows go through TestDataset (no labels) with subset="po".
po_dataset = TestDataset(
    po_path_sentinel,
    po_path_landsat,
    po_path_bioclim,
    po_combined,
    cov_columns,
    subset="po",
    transform_sentinel=transform_sentinel,
    transform_landsat=transform_landsat,
    image_mean=image_mean,
    sentinel_mask_channel=sentinel_mask_channel,
)
po_loader = DataLoader(po_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# Sanity check: shapes of the first four fields of the first sample.
print(*(po_dataset[0][i].shape for i in range(4)))

In [None]:
# Chunked, resumable export of the PO features to feather files.
# Batches are buffered until roughly po_chunk_rows rows are collected, then written
# as <stem><NNN>.feather. A chunk whose file already exists is skipped, so an
# interrupted run can be restarted.
# Fix: the "already exported?" check now probes the same file name that the current
# mode writes ("po_cov" for "orig", "po_deepfeatures" for "deep"); before, it always
# probed the deep-features name, so resuming never worked in "orig" mode.
surv_list, cov_list = [], []
df_count = 0   # index of the chunk currently being skipped or filled
counter = 0    # rows (counted in whole batches) skipped within the current existing chunk
po_chunk_rows = 100000
po_stem = {"orig": "po_cov", "deep": "po_deepfeatures"}[prepare_type]
po_dir = os.path.join(path_save, "hmsc", "po")
os.makedirs(po_dir, exist_ok=True)
with torch.no_grad():
    for sentinel, landsat, data_cov, lonlat, survey in tqdm(po_loader):
        # Resume support: advance through batches that belong to an existing chunk file.
        if os.path.isfile(os.path.join(po_dir, f"{po_stem}{df_count:03d}.feather")):
            counter += batch_size
            if counter >= po_chunk_rows:
                df_count += 1
                counter = 0
            continue

        if prepare_type == "orig":
            cov = torch.concat([sentinel, landsat.reshape([landsat.shape[0], -1]), data_cov, lonlat], dim=1).numpy()
        elif prepare_type == "deep":
            # Insert a singleton axis at position 2 before feeding the sentinel batch to the model.
            sentinel = sentinel.to(device)[:,:,None,:,:]
            landsat, data_cov, lonlat = landsat.to(device), data_cov.to(device), lonlat.to(device)
            features = model(sentinel, landsat, data_cov, lonlat)
            cov = torch.concat([features, lonlat], dim=1).cpu().numpy()
        cov_list.append(cov)
        surv_list.append(survey.numpy())

        # Flush a full chunk; the threshold counts whole batches, mirroring the skip counter above.
        if len(cov_list) * batch_size >= po_chunk_rows:
            cov = pd.DataFrame(np.concatenate(cov_list), columns=cols, index=np.concatenate(surv_list)).reset_index(drop=False)
            cov.to_feather(os.path.join(po_dir, f"{po_stem}{df_count:03d}.feather"))
            surv_list, cov_list = [], []
            df_count += 1
            gc.collect()

# Write the final, partially filled chunk (the buffers are intentionally left populated).
if len(cov_list) > 0:
    cov = pd.DataFrame(np.concatenate(cov_list), columns=cols, index=np.concatenate(surv_list)).reset_index(drop=False)
    cov.to_feather(os.path.join(po_dir, f"{po_stem}{df_count:03d}.feather"))

In [None]:
# Repeat of the chunked, resumable PO export (same procedure as the preceding cell);
# chunks whose output file already exists are skipped, so running it again only
# fills in what is missing.
# Fix: the "already exported?" check now probes the file name that the current mode
# writes ("po_cov" for "orig", "po_deepfeatures" for "deep") instead of always
# probing the deep-features name.
surv_list, cov_list = [], []
df_count = 0   # index of the chunk currently being skipped or filled
counter = 0    # rows (counted in whole batches) skipped within the current existing chunk
po_chunk_rows = 100000
po_stem = {"orig": "po_cov", "deep": "po_deepfeatures"}[prepare_type]
po_dir = os.path.join(path_save, "hmsc", "po")
os.makedirs(po_dir, exist_ok=True)
with torch.no_grad():
    for sentinel, landsat, data_cov, lonlat, survey in tqdm(po_loader):
        # Resume support: advance through batches that belong to an existing chunk file.
        if os.path.isfile(os.path.join(po_dir, f"{po_stem}{df_count:03d}.feather")):
            counter += batch_size
            if counter >= po_chunk_rows:
                df_count += 1
                counter = 0
            continue

        if prepare_type == "orig":
            cov = torch.concat([sentinel, landsat.reshape([landsat.shape[0], -1]), data_cov, lonlat], dim=1).numpy()
        elif prepare_type == "deep":
            # Insert a singleton axis at position 2 before feeding the sentinel batch to the model.
            sentinel = sentinel.to(device)[:,:,None,:,:]
            landsat, data_cov, lonlat = landsat.to(device), data_cov.to(device), lonlat.to(device)
            features = model(sentinel, landsat, data_cov, lonlat)
            cov = torch.concat([features, lonlat], dim=1).cpu().numpy()
        cov_list.append(cov)
        surv_list.append(survey.numpy())

        # Flush a full chunk; the threshold counts whole batches, mirroring the skip counter above.
        if len(cov_list) * batch_size >= po_chunk_rows:
            cov = pd.DataFrame(np.concatenate(cov_list), columns=cols, index=np.concatenate(surv_list)).reset_index(drop=False)
            cov.to_feather(os.path.join(po_dir, f"{po_stem}{df_count:03d}.feather"))
            surv_list, cov_list = [], []
            df_count += 1
            gc.collect()

# Write the final, partially filled chunk (the buffers are intentionally left populated).
if len(cov_list) > 0:
    cov = pd.DataFrame(np.concatenate(cov_list), columns=cols, index=np.concatenate(surv_list)).reset_index(drop=False)
    cov.to_feather(os.path.join(po_dir, f"{po_stem}{df_count:03d}.feather"))

In [None]:
# Manual flush of whatever is still buffered in cov_list / surv_list.
# Guarded against an empty buffer (np.concatenate raises on an empty list), and the
# file name now follows prepare_type so deep features are no longer written under
# the "po_cov" name.
if len(cov_list) > 0:
    cov = pd.DataFrame(np.concatenate(cov_list), columns=cols, index=np.concatenate(surv_list)).reset_index(drop=False)
    if prepare_type == "orig":
        cov.to_feather(os.path.join(path_save, "hmsc", "po", f"po_cov{df_count:03d}.feather"))
    elif prepare_type == "deep":
        cov.to_feather(os.path.join(path_save, "hmsc", "po", f"po_deepfeatures{df_count:03d}.feather"))

In [None]:
# CSV dump of the buffered raw PO covariates, indexed by survey id.
# Only meaningful in "orig" mode: the column names built here describe the raw
# covariate layout and do not match the deep-feature matrix. An empty buffer is
# skipped because np.concatenate raises on an empty list.
# NOTE: the chunked export loop resets cov_list after every full chunk, so after that
# loop the buffer holds only the batches collected since the last flush.
if prepare_type == "orig" and len(cov_list) > 0:
    cols = [f"sentinel{i}" for i in range(sentinel.shape[-1])] + [f"landsatbio{i}{j}" for i in range(landsat.shape[-2]) for j in range(landsat.shape[-1])] + cov_columns + ["lon","lat"]
    cov = pd.DataFrame(np.concatenate(cov_list), columns=cols, index=np.concatenate(surv_list))
    os.makedirs(os.path.join(path_save, "hmsc"), exist_ok=True)
    # index_label=True would write the literal header "True" for the index column;
    # label it with what the index actually contains.
    cov.to_csv(os.path.join(path_save, "hmsc", "po_cov.csv"), index_label="surveyId")

In [None]:
# Exploratory check: attach the soil columns to the survey metadata and display the result.
# NOTE(review): the cell rebinds po_metadata to the joined frame, so running it a second
# time presumably fails on overlapping column names — confirm before relying on it.
po_metadata = po_metadata.join(po_soil)
po_metadata

In [None]:
# Exploratory rebuild of po_metadata from the raw table: distinct coordinates per
# surveyId, country name joined in, then one indicator column per selected country.
# Relies on po_countries and country_columns created in the PO covariate cell.
po_metadata = (
    po_metadata_orig
    .loc[:, ["lat", "lon", "surveyId"]]
    .drop_duplicates()
    .set_index("surveyId", drop=True)
    .sort_index()
    .join(po_countries.loc[~po_countries.index.duplicated(keep='first'), "name"])
    .rename({"name": "country"}, axis=1)
)
for country, col in zip(sel_countries, country_columns[:-1]):
    po_metadata[col] = po_metadata["country"].eq(country)
# Last indicator column flags every survey outside the selected countries.
po_metadata[country_columns[-1]] = ~po_metadata["country"].isin(sel_countries)
po_worldcover = pd.read_csv(
    os.path.join(path_data, "worldcover", "po_train_survey_points_with_worldcover.csv"),
    index_col=0,
)
po_metadata

In [None]:
# Exploratory check: attach the WorldCover class by matching on (lat, lon),
# then restore surveyId as the index and display the result.
po_metadata = (
    po_metadata
    .reset_index()
    .merge(po_worldcover.loc[:, ["lat", "lon", "class"]], on=["lat", "lon"])
    .set_index("surveyId")
)
po_metadata

In [None]:
# Exploratory check: list the po_countries rows whose "name" value is missing.
po_countries.loc[po_countries.name.isna(), "name"]

In [None]:
# Development helper: reload glc_datasets so that edits to the module are picked up
# without restarting the kernel, then rebind the names imported from it.
from importlib import reload
import glc_datasets
reload(glc_datasets)
from glc_datasets import TrainDataset, TestDataset, read_train_data, read_test_data
