### Creating maps from the trained models

In [None]:
from collections import Counter, defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import pandas as pd
import torchvision
import os

from dataset import SINR_DS
from models import *
from utils import DefaultParams
from embedders import get_embedder

In [None]:
# Root folder of the GLC23 data. It is used as a plain string prefix, so the
# trailing slash is required.
Data_FOLDER = "glc23_data/"
# Alternative input table in the same folder: 'Pot_10_to_1000.csv'
# Occurrence table: semicolon-separated, column names taken from the header row.
dataset_file = pd.read_csv(
    filepath_or_buffer=Data_FOLDER + "Pot_10_to_1000_nofrance.csv",
    low_memory=False,
    header="infer",
    sep=";",
)

In [None]:
# Listing all checkpoints in a folder whose file name contains "lf".
# CP_PATH is the checkpoint directory prefix; it is combined with a checkpoint
# file name when a model is loaded.
CP_PATH = ""
# os.listdir("") raises FileNotFoundError, while an empty prefix means
# "current directory" when the checkpoint path is built, so list "." then.
cps = os.listdir(CP_PATH or ".")
for cp in cps:
    if "lf" in cp:
        try:
            print(cp)
        except UnicodeEncodeError:
            # Best effort: skip names the output stream cannot encode.
            continue

In [None]:
# Code-Snippet to match GLC23 occurrence with gbif occurrence to retrieve the original species
# Step 1: print the first row of the occurrence table whose speciesId is 265.
print(dataset_file[dataset_file["speciesId"] == 265].iloc[0])
# pygbif is imported here, not with the imports at the top of the file.
import pygbif

# Replace key with key from previously printed snippet
# Step 2: fetch the occurrence with that key through pygbif
# (presumably a request to the GBIF web API - confirm). As the last expression
# of the cell, the returned value is what the notebook displays.
pygbif.occurrences.get(key=3951621754)

In [None]:
def get_model(name):
    """Build the dataset and model for a checkpoint and load its weights.

    The checkpoint file name selects the configuration:

    * contains "sat": predictors "loc_env_sent2", SAT_SINR model
    * contains "env": predictors "loc_env", SINR model
    * otherwise:      predictors "loc", SINR model

    Args:
        name: Checkpoint file name, located under CP_PATH. For "sat"
            checkpoints the text before the first space is used as
            ``default_params.model``.

    Returns:
        Tuple ``(model, sinr, dataset)``: the model switched to eval mode, a
        flag that is False for "sat" checkpoints and True otherwise, and the
        SINR_DS dataset built for the selected predictors.
    """
    uses_satellite = "sat" in name
    if uses_satellite:
        predictors = "loc_env_sent2"
    elif "env" in name:
        predictors = "loc_env"
    else:
        predictors = "loc"
    sinr = not uses_satellite

    dataset = SINR_DS(
        dataset_file,
        predictors,
        sent_data_path=Data_FOLDER + "SatelliteImages/",
        bioclim_path=Data_FOLDER
        + "bioclim+elev/bioclim_elevation_scaled_europe.npy",
    )

    default_params = DefaultParams(sinr)
    default_params.dataset.predictors = predictors

    if uses_satellite:
        default_params.model = name.split(" ")[0]
        model = SAT_SINR(default_params, dataset, get_embedder(default_params))
    else:
        model = SINR(default_params, dataset)

    # os.path.join also works when CP_PATH has no trailing separator, and
    # returns `name` unchanged when CP_PATH is empty.
    path = os.path.join(CP_PATH, name)

    # Deserialize onto the CPU so checkpoints saved from a GPU run can be
    # loaded on machines without CUDA; load_state_dict copies the values into
    # the model's own parameters.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    return model.eval(), sinr, dataset

In [None]:
# Number of grid steps along longitude and latitude used when sampling the
# prediction map.
RES_LON, RES_LAT = 502, 408

# Occurrence count per speciesId in the loaded occurrence table.
c = Counter(dataset_file["speciesId"].to_numpy())

# Longitude / latitude limits of the mapped area.
min_lon, max_lon = -10.53904, 34.55792
min_lat, max_lat = 34.56858, 71.18392

In [None]:
def _grid_lon_lat():
    """Yield ``(lon, lat)`` for every grid cell, longitude as the outer loop.

    Cell ``(i, j)`` maps to ``min + (index / resolution) * (max - min)`` on
    each axis, so the upper bound itself is never produced. A tqdm progress
    bar tracks the outer (longitude) loop.
    """
    for i in tqdm(range(RES_LON)):
        for j in range(RES_LAT):
            # Keep this exact arithmetic: str(lat)/str(lon) are used to build
            # image file names, so the float values must not change.
            lon = i / RES_LON
            lat = j / RES_LAT
            lon = lon * (max_lon - min_lon) + min_lon
            lat = lat * (max_lat - min_lat) + min_lat
            yield lon, lat


def _load_sent2(lat, lon, to_tensor):
    """Return the 4x128x128 Sentinel-2 tensor for a location, or zeros.

    The RGB and NIR JPEGs named "<lat>,<lon>.jpeg" are concatenated along the
    channel axis and scaled by 1/255. A zero tensor is returned when the
    images cannot be loaded or the result does not have shape (4, 128, 128).
    """
    pos = str(lat) + "," + str(lon)
    # Requires downloading and cropping fitting Sentinel-2 images from the Ecodatacube
    rgb_path = "sentinel_2 2021 Europe/rgb/" + pos + ".jpeg"
    nir_path = "sentinel_2 2021 Europe/nir/" + pos + ".jpeg"
    try:
        # Context managers close the image files as soon as they are decoded.
        with Image.open(rgb_path) as rgb, Image.open(nir_path) as nir:
            sent2 = torch.concat([to_tensor(rgb), to_tensor(nir)], dim=0) / 255
    except Exception:
        # Best effort: any missing or unreadable image falls back to zeros.
        # Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit
        # through so the long loop can still be interrupted.
        return torch.zeros(4, 128, 128)
    if sent2.shape != torch.Size([4, 128, 128]):
        return torch.zeros(4, 128, 128)
    return sent2


def get_preds(model, sinr, dataset):
    """Evaluate a model on the full RES_LON x RES_LAT grid.

    Args:
        model: Model returned by ``get_model``; must expose ``model.net``.
        sinr: True to encode all locations and run them as one batch on the
            CPU; False to run location + Sentinel-2 image one cell at a time
            on CUDA.
        dataset: Object providing ``encode(lon, lat)`` for a location.

    Returns:
        Sigmoid of the model outputs, one entry per grid cell in
        longitude-major order (all latitudes of the first longitude first).
        No autograd graph is attached to the result.
    """
    if sinr:
        model.net.to("cpu")
        locs = torch.stack(
            [dataset.encode(lon, lat) for lon, lat in _grid_lon_lat()]
        )
        # Inference only: skip building an autograd graph for the whole grid.
        with torch.no_grad():
            return model(locs).sigmoid()

    model.net.to("cuda")
    # The transform is stateless, so build it once instead of once per cell.
    to_tensor = torchvision.transforms.PILToTensor()
    preds = []
    for lon, lat in _grid_lon_lat():
        loc = dataset.encode(lon, lat)
        sent2 = _load_sent2(lat, lon, to_tensor)
        with torch.no_grad():
            preds.append(
                model.net((loc.to("cuda"), sent2.to("cuda")), no_sent2=False)
                .detach()
                .to("cpu")
            )
    return torch.stack(preds).sigmoid()

In [None]:
def print_and_save_res(preds, name, NOCCS=False, FRANCE_ONLY=False, STD=False):
    """Plot, save and show one prediction map per species in ``to_map``.

    Reads the module-level ``to_map``, ``num_samples``, ``dataset_file``,
    ``Data_FOLDER``, ``RES_LON``, ``RES_LAT`` and the lon/lat bounds.

    Args:
        preds: Tensor indexed as ``preds[:, species_id]``; the first axis must
            hold RES_LON * RES_LAT values in longitude-major order, because it
            is viewed as (RES_LON, RES_LAT).
        name: Sub-folder of ``./visuals`` that receives the figures. The value
            "only_dist" suppresses the colorbar.
        NOCCS: If True, the occurrence points are not drawn.
        FRANCE_ONLY: If True, crop the map to a hard-coded France window.
        STD: If True, the colour scale tops out at 0.5 instead of 1 and the
            file name gets a "_std" suffix.

    Raises:
        AssertionError: If the number of occurrences found for a species
            differs from the matching entry of ``num_samples``.
    """
    import copy

    out_dir = "./visuals/" + name
    # Creates missing parents and tolerates an existing folder, while real
    # errors (e.g. permissions) are no longer silently ignored.
    os.makedirs(out_dir, exist_ok=True)
    if not to_map:
        return

    vmin = 0
    vmax = 0.5 if STD else 1

    # Ocean mask can be downloaded from the original SINR repo.
    # It does not depend on the species, so it is loaded and cropped once.
    mask = np.load(
        os.path.join(Data_FOLDER + "sinr_data/data/masks", "ocean_mask_hr.npy")
    )
    # The mask is indexed as a global grid: 360 degrees of longitude along
    # axis 1 and 180 degrees of latitude along axis 0, row 0 at +90.
    lon_res = mask.shape[1] / 360
    lat_res = mask.shape[0] / 180
    north = int((90 - max_lat) * lat_res)
    south = int((90 - min_lat) * lat_res)
    west = int((180 + min_lon) * lon_res)
    east = int((180 + max_lon) * lon_res)
    mask = mask[north:south, west:east]
    # Flat indices of the cells that are kept (mask value 1, presumably land).
    mask_inds = np.where(mask.reshape(-1) == 1)[0]

    # Work on a copy so the globally shared plasma colormap is not modified.
    cmap = copy.copy(plt.cm.plasma)
    cmap.set_bad(color="none")

    if FRANCE_ONLY:
        extent = (-4.807615, 8.238722, 42.325170, 51.235825)
    else:
        # Same bounds that define the mask crop above.
        extent = (min_lon, max_lon, min_lat, max_lat)

    suffix = (
        ("_noccs" if NOCCS else "")
        + ("_france" if FRANCE_ONLY else "")
        + ("_std" if STD else "")
    )

    # Values at or below THRESHOLD are zeroed; 0 disables thresholding.
    THRESHOLD = 0

    for idx, species_id in enumerate(to_map):
        occs = dataset_file.query("speciesId == " + str(species_id))
        # Sanity check that num_samples is in sync with to_map.
        assert len(occs) == num_samples[idx]

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_xlim([extent[0], extent[1]])
        ax.set_ylim([extent[2], extent[3]])

        species_pred = preds[:, species_id]
        print(
            "SpeciesId:",
            species_id,
            "; Num samples:",
            num_samples[idx],
            species_pred.min().item(),
            species_pred.max().item(),
        )
        # (lon, lat) grid -> rotated by 90 degrees into a (lat, lon) image,
        # then flattened so it can be indexed with the flat mask indices.
        grid = torch.rot90(species_pred.view(RES_LON, RES_LAT))
        grid = torch.reshape(grid, (RES_LAT * RES_LON, 1))
        kept = grid[mask_inds]

        # NaN everywhere except the kept cells; NaNs are masked and drawn
        # transparent through the colormap's "bad" colour.
        op_im = np.ones(mask.shape[0] * mask.shape[1]) * np.nan
        op_im[mask_inds] = kept.detach().cpu().view(len(mask_inds)).numpy()
        op_im = np.ma.masked_invalid(op_im)
        op_im = op_im.reshape(RES_LAT, RES_LON)

        if THRESHOLD > 0:
            op_im[op_im <= THRESHOLD] = 0

        if FRANCE_ONLY:
            # Pixel window for France; the indices assume the 408 x 502 grid.
            op_im = op_im[408 - 186 : 408 - 86, 64:209]

        image = ax.imshow(
            op_im,
            extent=extent,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
        )
        if not NOCCS:
            ax.scatter(
                occs["lon"].to_numpy(),
                occs["lat"].to_numpy(),
                c="lime",
                alpha=0.5,
                s=3,
            )

        if name != "only_dist":
            fig.colorbar(image, ax=ax)

        fig.savefig(out_dir + "/" + str(species_id) + suffix)

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [None]:
# Checkpoint file names to ensemble; fill in the empty placeholders.
checkpoint_names = ["", "", ""]
# One prediction tensor per checkpoint, in the order of checkpoint_names.
preds = []
# NOTE: `name` keeps the last checkpoint name after the loop; later cells use
# it for printing and as the output folder name.
for name in checkpoint_names:
    model, sinr, dataset = get_model(name)
    preds.append(get_preds(model, sinr, dataset))

In [None]:
# Stack the per-checkpoint predictions along a new leading axis, then reduce
# over that axis to get the ensemble mean and standard deviation.
preds = torch.stack(preds)
preds_average = torch.mean(preds, dim=0)
preds_std = torch.std(preds, dim=0)
# `name` is the last checkpoint name left over from the loading loop.
print(name)

In [None]:
# Species ids of the classes present in the first PA sample.
# NOTE: this list is replaced by the single-species selection further down.
to_map = [
    265, 268, 271, 439, 751, 905, 966, 1122, 1224,
    1303, 1559, 1957, 2071, 2854, 3207, 3384, 3947, 4062,
    4269, 4501, 5022, 5113, 5400, 5793, 6510, 6612, 6895,
    6922, 7519, 7580, 7760, 8023, 8196, 8267, 8586, 8791,
    8994, 9170, 9240, 9315, 9509, 9753, 9761, 9807, 9983,
]
# to_map = random.sample(c.keys(), NUM_SAMPLES)
to_map = [265]
# Occurrence count of every selected species, aligned with to_map.
num_samples = [c[species] for species in to_map]
"""name = "only_dist"
preds_average = torch.zeros(502*408, 10000)"""
# Render the mean prediction for every combination of the two switches:
# with/without occurrence points, full extent/France only.
for hide_occurrences in (False, True):
    for france_only in (False, True):
        print_and_save_res(
            preds_average,
            name,
            NOCCS=hide_occurrences,
            FRANCE_ONLY=france_only,
            STD=False,
        )
"""print_and_save_res(preds_std, name, NOCCS=False, FRANCE_ONLY=False, STD=True)
print_and_save_res(preds_std, name, NOCCS=False, FRANCE_ONLY=True, STD=True)
print_and_save_res(preds_std, name, NOCCS=True, FRANCE_ONLY=False, STD=True)
print_and_save_res(preds_std, name, NOCCS=True, FRANCE_ONLY=True, STD=True)"""