In [None]:
import warnings
warnings.filterwarnings("ignore")
import os
import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
import rasterio
import random
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support

import torch
from torch.utils.data import Dataset, DataLoader
# `sys` was used here without ever being imported, which raises NameError on a
# fresh kernel; import it right where it is needed.
import sys
# Extend the module search path so the `from utils import ...` that follows
# can resolve (presumably utils.py lives in ../prithvi/ — confirm).
sys.path.append('../prithvi/')
from utils import set_seed, f1_score

In [None]:
# Root directory holding all inputs and outputs used by this notebook.
path_data = "/home/gt/DATA/geolifeclef-2025"

# Settings that identify which stored prediction file is loaded below.
# NOTE(review): the meaning of nc/ns/nu/nf is not defined in this notebook;
# they only feed the file-name tags (nu is written under the "np" tag).
nc = 409
ns = 2519
nu = 100
# nf is forced to 0 whenever nu is 0; otherwise it keeps its configured value.
nf = 20 if nu != 0 else 0
n_samples = 50
thin = 4000

# Zero-padded tag shared by the prediction (input) and result (output) file names.
model_type_string = f"nc{nc:04d}_ns{ns:04d}_np{nu:04d}_nf{nf:02d}"
pred_filename = f"pred_{model_type_string}_sam{n_samples:04d}_thin{thin:04d}.csv"
# Load the stored predictions from <path_data>/hmsc/pred/.
# NOTE(review): every column is later fed to torch.logit, so the file is
# presumably one probability column per species with no ID column — confirm.
pred_path = os.path.join(path_data, "hmsc", "pred", pred_filename)
test_pred = pd.read_csv(pred_path)

In [None]:
# Minimum number of training records a species needs to keep its own class.
# With a threshold of 1 every observed species qualifies, so nothing is
# collapsed into the -1 bucket below.
pa_presence_threshold = 1

# Locations of the PA-train inputs (defined here, not read in this section).
train_path_sentinel = os.path.join(path_data, "SatelitePatches/PA-train")
train_path_landsat = os.path.join(path_data, "SateliteTimeSeries-Landsat/cubes/PA-train")
train_path_bioclim = os.path.join(path_data, "BioclimTimeSeries/cubes/PA-train")

# Training metadata: drop rows without a species and store the IDs as integers.
train_metadata = (
    pd.read_csv(os.path.join(path_data, "GLC25_PA_metadata_train.csv"))
    .dropna(subset="speciesId")
    .reset_index(drop=True)
)
train_metadata['speciesId'] = train_metadata['speciesId'].astype(int)
# Keep the untouched IDs; `speciesId` itself is overwritten with class codes below.
train_metadata["speciesIdOrig"] = train_metadata['speciesId']

# Species seen fewer than `pa_presence_threshold` times are merged into -1.
tmp = train_metadata["speciesId"].value_counts() >= pa_presence_threshold
below_threshold = ~train_metadata["speciesId"].isin(tmp[tmp].index)
train_metadata.loc[below_threshold, "speciesId"] = -1

# Encode species as consecutive integer codes; `sp_categorical.categories`
# maps a code back to the species ID it stands for.
sp_categorical = train_metadata["speciesId"].astype("category").values
num_classes = len(sp_categorical.categories)
train_metadata['speciesId'] = sp_categorical.codes

# Test metadata indexed by and sorted on surveyId (the column is kept as well).
test_metadata = (
    pd.read_csv(os.path.join(path_data, "GLC25_PA_metadata_test.csv"))
    .set_index("surveyId", drop=False)
    .sort_index()
)

In [None]:
# Number of prediction rows per DataLoader batch.
batch_size = 64
# Project helper imported from `utils`; presumably seeds the RNGs for reproducibility — confirm.
set_seed(42)

class PredDataset(Dataset):
    """Row-wise dataset over a table of stored predictions.

    Item ``idx`` is row ``idx`` of ``pred`` returned as a NumPy array.
    """

    def __init__(self, pred):
        """Keep a reference to ``pred``; it must expose ``.shape`` and ``.iloc``."""
        self.pred = pred

    def __len__(self):
        """Return the number of rows in the wrapped table."""
        n_rows, *_ = self.pred.shape
        return n_rows

    def __getitem__(self, idx):
        """Return the row at position ``idx`` as a NumPy array."""
        row = self.pred.iloc[idx]
        return row.values

# shuffle=False keeps batches in the row order of `test_pred`; the results are
# later matched to survey IDs by position, so this order must not change.
test_loader = DataLoader(
    PredDataset(test_pred),
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print("DEVICE = CUDA")

In [None]:
# One sorted array of species IDs per test row, in loader order.
top_indices = []
with torch.no_grad():
    for outputs in tqdm(test_loader, total=len(test_loader),  desc="prediction"):
        # Convert the stored values to logits before handing them to f1_score.
        # NOTE(review): torch.logit returns -inf/+inf for inputs of exactly 0/1
        # unless `eps` is passed — confirm f1_score tolerates infinite logits.
        outputs = torch.logit(outputs.to(device))
        # f1_score comes from the project `utils` module; called without targets
        # it is used here as a per-row selector of predicted class indices.
        top_batch_list_orig = f1_score(outputs, None, device=device)
        # Map the selected class indices back to species IDs and sort them.
        top_batch_list_proc = [
            np.sort(sp_categorical.categories[row_pred.cpu().numpy()])
            for row_pred in top_batch_list_orig
        ]
        top_indices.extend(top_batch_list_proc)

In [None]:
# Inspection cell: shows `outputs` as left by the prediction loop, i.e. the
# logit-transformed tensor of the last processed batch.
outputs

In [None]:
# One space-separated string of species IDs per test row.
data_concatenated = [' '.join(map(str, row)) for row in top_indices]

# Predictions are paired with survey IDs purely by position, so fail with a
# clear message if the two sources do not have the same number of rows.
# NOTE(review): `test_metadata` is sorted by surveyId; this assumes the rows of
# the prediction file follow the same order — confirm.
if len(data_concatenated) != len(test_metadata):
    raise ValueError(
        f"Got {len(data_concatenated)} prediction rows for "
        f"{len(test_metadata)} test surveys; cannot align them by position."
    )

# The `timestamp` previously computed here was a dead store: it is recomputed
# (with a different format) right before it is first used when saving results.
res = pd.DataFrame({'surveyId': test_metadata.surveyId.values, 'predictions': data_concatenated})

In [None]:
# Results are written to <path_data>/hmsc/result/, created on first use.
result_dir = os.path.join(path_data, "hmsc", "result")
os.makedirs(result_dir, exist_ok=True)

# A timestamp prefix keeps repeated runs from overwriting each other.
timestamp = datetime.now().strftime('%m%d_%H%M%S')
res_filename = f"{timestamp}_res_{model_type_string}_sam{n_samples:04d}_thin{thin:04d}.csv"
res.to_csv(os.path.join(result_dir, res_filename), index=False)

In [None]:
# Sanity check on the shortest and longest prediction strings.
# Note: `len` counts characters of the joined string, not the number of
# species IDs; a minimum of 0 still flags a survey with no predictions.
pred_lengths = res.predictions.apply(len)
print(pred_lengths.min(), pred_lengths.max())