### Creating maps from the trained models

In [None]:
import random
from collections import Counter, defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from PIL import Image
import numpy as np
import os
import json
import sys

sys.path.append('/home/jdolli/')
from sentinel2_foundation_model.models import Sent2AE as Sent2VAE, View

from dataset import SINR_DS
from models import *
from utils import DefaultParams
from embedders import get_embedder

In [None]:
# Occurrence table (semicolon-separated). It is handed to SINR_DS in
# get_model() and queried per species for the occurrence overlay. The sibling
# table '/data/jdolli/glc23_data/Pot_10_to_1000.csv' can be loaded instead.
dataset_file = pd.read_csv(
    '/data/jdolli/glc23_data/Pot_10_to_1000_nofrance.csv',
    sep=";",
    header='infer',
    low_memory=False,
)

# File names found in the checkpoint directory.
cps = os.listdir("/scratch/jdolli/sent-sinr/checkpoints")

# Parsed as JSON despite the ".csv" extension. print_and_save_res() splits
# each key on "/" into (lon, lat) and looks species ids up in the values.
with open('/data/jdolli/glc23_data/Presence_Absence_surveys/loc_to_spec.csv', "r") as val_file:
    val_data = json.load(val_file)

In [None]:
# Print every checkpoint file whose name contains "plus".
# The former bare try/except around print() was dropped: it guarded nothing
# that can fail here and would have swallowed KeyboardInterrupt as well.
cps = os.listdir("/scratch/jdolli/sent-sinr/checkpoints")
for cp in cps:
    if "plus" in cp:
        print(cp)
        

In [None]:
#print(dataset_file[dataset_file["speciesId"] == 265].iloc[0])

In [None]:
#import pygbif
#pygbif.occurrences.get(key=3951621754)

In [None]:
# Checkpoint to work with, given as its file name without the ".ckpt" suffix
# (get_model() appends the suffix and the checkpoint directory).
# get_model() also inspects this string for "sat" / "env" to choose the model.
# Alternative checkpoints are kept below for quick switching.
name = "sat_sinr_mf_zc ae_default val_loss=-0.0052"
#name = "sinr loc_env val_loss=-0.0227"
#name = "sinr loc val_loss=-0.0025"
def get_model(name):
    """Build the model that matches a checkpoint name and load its weights.

    The model type is chosen by substring match on ``name``:

    * contains ``"sat"``  -> ``SAT_SINR`` with predictors ``"loc_env_sent2"``
    * contains ``"env"``  -> ``SINR`` with predictors ``"loc_env"``
    * anything else       -> ``SINR`` with predictors ``"loc"``

    Reads the module-level ``dataset_file``.

    Args:
        name: Checkpoint file name without the ``.ckpt`` suffix; the file is
            read from ``/scratch/jdolli/sent-sinr/checkpoints/``.

    Returns:
        A tuple ``(model, sinr, dataset)``: the result of ``model.eval()``,
        the flag that was passed to ``DefaultParams`` (``False`` for the
        ``"sat"`` models, ``True`` otherwise), and the ``SINR_DS`` instance
        the model was built with.
    """
    if "sat" in name:
        sinr = False
        predictors = "loc_env_sent2"
    elif "env" in name:
        sinr = True
        predictors = "loc_env"
    else:
        sinr = True
        predictors = "loc"

    # None of the predictor strings assigned above ends in "LR"; the branch is
    # kept so that such a predictor set can be re-enabled without edits here.
    if predictors.endswith("LR"):
        bioclim_path = "/data/jdolli/glc23_data/sinr_data/data/env/bioclim_elevation_scaled_europe.npy"
    else:
        bioclim_path = "/data/jdolli/bioclim+elev/bioclim_elevation_scaled_europe.npy"
    dataset = SINR_DS(
        dataset_file,
        predictors,
        sent_data_path="/data/jdolli/glc23_data/SatelliteImages/",
        bioclim_path=bioclim_path,
        use_subm_val=False,
    )

    default_params = DefaultParams(sinr)
    default_params.dataset.predictors = predictors

    if "sat" in name:
        # The first space-separated token of the name is used as the model id.
        default_params.model = name.split(" ")[0]
        model = SAT_SINR(default_params, dataset, get_embedder(default_params))
    else:
        model = SINR(default_params, dataset)

    path = "/scratch/jdolli/sent-sinr/checkpoints/" + name + ".ckpt"

    # Load the tensors onto the CPU: this works on hosts without a GPU and
    # avoids staging a second copy of the weights on the GPU. get_preds()
    # moves the network to its target device afterwards.
    state_dict = torch.load(path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    return model.eval(), sinr, dataset

In [None]:
# Number of grid cells along longitude and latitude used for the maps.
RES_LON, RES_LAT = 502, 408

# Occurrence count per speciesId in the loaded table.
c = Counter(dataset_file["speciesId"].to_numpy())
# to_map = [6372, 6805, 5782, 1192, 6325]

# Geographic bounds of the prediction grid, in degrees.
min_lon, max_lon = -10.53904, 34.55792
min_lat, max_lat = 34.56858, 71.18392

In [None]:
def _iter_grid():
    """Yield ``(lon, lat)`` for every grid cell, longitude-major.

    The outer loop runs over ``RES_LON`` longitude steps (with a tqdm progress
    bar), the inner loop over ``RES_LAT`` latitude steps. Coordinates start at
    ``min_lon`` / ``min_lat`` and never reach ``max_lon`` / ``max_lat``.
    """
    lon_span = max_lon - min_lon
    lat_span = max_lat - min_lat
    for i in tqdm(range(RES_LON)):
        lon = (i / RES_LON) * lon_span + min_lon
        for j in range(RES_LAT):
            lat = (j / RES_LAT) * lat_span + min_lat
            yield lon, lat


def get_preds(model, sinr, dataset):
    """Evaluate a model on the full map grid and return sigmoid outputs.

    Reads the module-level grid constants through ``_iter_grid``.

    Args:
        model: Model returned by ``get_model``; must expose ``model.net``.
        sinr: If true, the model only needs encoded locations and is run on
            the CPU in a single batch. Otherwise each grid cell is combined
            with its satellite patch and run on CUDA one cell at a time.
        dataset: Object whose ``encode(lon, lat)`` yields the location input.

    Returns:
        A CPU tensor with one entry along the first axis per grid cell, in the
        order produced by ``_iter_grid`` (longitude-major), after ``sigmoid``.
        No autograd graph is attached.
    """
    if sinr:
        model.net.to("cpu")
        locs = torch.stack([dataset.encode(lon, lat) for lon, lat in _iter_grid()])
        # Inference only: without no_grad the whole-grid forward pass would
        # keep its autograd graph alive.
        with torch.no_grad():
            return model(locs).sigmoid()

    model.net.to("cuda")
    to_tensor = torchvision.transforms.PILToTensor()  # built once, not per cell
    patch_shape = torch.Size([4, 128, 128])
    preds = []
    for lon, lat in _iter_grid():
        loc = dataset.encode(lon, lat)
        # Patch files are named "<lat>,<lon>.jpeg".
        pos = str(lat) + "," + str(lon)
        rgb_path = "/data/jdolli/sentinel_2 2021 Europe/rgb/" + pos + ".jpeg"
        nir_path = "/data/jdolli/sentinel_2 2021 Europe/nir/" + pos + ".jpeg"
        try:
            with Image.open(rgb_path) as rgb, Image.open(nir_path) as nir:
                sent2 = torch.concat([to_tensor(rgb), to_tensor(nir)], dim=0) / 255
        except Exception:
            # Best effort: a missing or unreadable patch is replaced by zeros.
            # ``Exception`` (not a bare except) keeps Ctrl-C working.
            sent2 = torch.zeros(4, 128, 128)
        if sent2.shape != patch_shape:
            # Patches of unexpected size are replaced by zeros as well.
            sent2 = torch.zeros(4, 128, 128)
        with torch.no_grad():
            out = model.net((loc.to("cuda"), sent2.to("cuda")), no_sent2=False)
        preds.append(out.detach().to("cpu"))

    # Author's note carried over: the sigmoid is applied for every model except
    # the "SIGMOID FCKED" checkpoints; for those use ``torch.stack(preds)``
    # without the sigmoid.
    return torch.stack(preds).sigmoid()

In [None]:
def print_and_save_res(preds, name, NOCCS=False, FRANCE_ONLY=False, VAL_DATA=False, STD=False):
    """Plot one map per species in the global ``to_map`` and save each figure.

    Reads the module-level ``to_map``, ``num_samples``, ``dataset_file``,
    ``val_data``, ``RES_LON`` / ``RES_LAT`` and the ``min_*`` / ``max_*``
    bounds. Figures are written to ``./visuals/<name>/`` and shown.

    Args:
        preds: Tensor indexed as ``preds[:, species_id]``. Each column must
            hold ``RES_LON * RES_LAT`` values in longitude-major order (it is
            viewed as ``(RES_LON, RES_LAT)``).
        name: Sub-directory of ``./visuals`` for the output. If it equals
            ``"only_dist"`` no colorbar is drawn.
        NOCCS: If true, the occurrences from ``dataset_file`` are not drawn.
        FRANCE_ONLY: If true, the map is cropped to a fixed France bounding
            box (lon -4.807615..8.238722, lat 42.325170..51.235825).
        VAL_DATA: If true, locations from ``val_data`` that list the species
            are drawn in red.
        STD: If true, the colour scale tops out at 0.5 instead of 1 and the
            file name gets a ``_std`` suffix.

    Raises:
        AssertionError: If the number of occurrences of a species in
            ``dataset_file`` differs from the matching ``num_samples`` entry.
    """
    if VAL_DATA:
        # Keys of val_data are parsed as "<lon>/<lat>".
        id_to_val = defaultdict(list)
        for idx in to_map:
            for key in val_data:
                if idx in val_data[key]:
                    id_to_val[str(idx)+"_lon"].append(float(key.split("/")[0]))
                    id_to_val[str(idx)+"_lat"].append(float(key.split("/")[1]))

    # Creates missing parents too and tolerates an existing directory; any
    # other failure (e.g. permissions) now surfaces here instead of at savefig.
    os.makedirs("./visuals/"+name, exist_ok=True)

    # The mask does not depend on the species, so load and crop it only once.
    # The crop arithmetic treats the array as spanning 360 x 180 degrees with
    # row 0 at 90 degrees north and column 0 at -180 degrees east.
    mask = np.load(os.path.join("/data/jdolli/glc23_data/sinr_data/data/masks", 'ocean_mask_hr.npy'))
    lon_res = mask.shape[1] / 360
    lat_res = mask.shape[0] / 180
    north = int((90-max_lat) * lat_res)
    south = int((90-min_lat) * lat_res)
    west = int((180 + min_lon) * lon_res)
    east = int((180 + max_lon) * lon_res)
    mask = mask[north:south, west:east]
    # Only cells where the mask equals 1 are drawn; the rest stays transparent.
    mask_inds = np.where(mask.reshape(-1) == 1)[0]

    cmap = plt.cm.plasma
    cmap.set_bad(color='none')

    vmin = 0
    vmax = 0.5 if STD else 1
    #vmin = 0.3
    #vmax = 0.8

    if FRANCE_ONLY:
        lon_lim, lat_lim = (-4.807615, 8.238722), (42.325170, 51.235825)
    else:
        lon_lim, lat_lim = (-10.53904, 34.55792), (34.56858, 71.18392)

    # Values at or below this threshold are zeroed; 0 disables the step.
    THRESHOLD = 0

    for sid, species_id in enumerate(to_map):
        occs = dataset_file.query("speciesId == " + str(species_id))
        assert len(occs) == num_samples[sid]
        lon_occs = occs["lon"].to_numpy()
        lat_occs = occs["lat"].to_numpy()
        # lon, lat = dataset._normalize_loc_to_uniform(lon, lat)

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_xlim(lon_lim)
        ax.set_ylim(lat_lim)

        species_preds = preds[:, species_id]
        print("SpeciesId:", species_id, "; Num samples:", num_samples[sid],
              species_preds.min().item(), species_preds.max().item())
        # (lon, lat) grid -> image layout: rot90 puts latitude on the rows.
        species_preds = torch.rot90(species_preds.view(RES_LON, RES_LAT))
        species_preds = torch.reshape(species_preds, (RES_LAT * RES_LON, 1))
        species_preds = species_preds[mask_inds]

        # NaN everywhere except the selected mask cells, then mask the NaNs so
        # that imshow renders them with the colormap's "bad" colour.
        op_im = np.ones(mask.shape[0] * mask.shape[1]) * np.nan
        op_im[mask_inds] = species_preds.detach().view(len(mask_inds)).numpy()
        op_im = np.ma.masked_invalid(op_im)
        op_im = op_im.reshape(RES_LAT, RES_LON)

        if THRESHOLD > 0:
            #op_im[op_im > THRESHOLD] = 1
            op_im[op_im <= THRESHOLD] = 0

        if FRANCE_ONLY:
            # Fixed pixel window that goes with the France bounding box.
            op_im = op_im[408-186:408-86, 64:209]

        image = ax.imshow(op_im, extent=(*lon_lim, *lat_lim), vmin=vmin, vmax=vmax, cmap=cmap)
        if not NOCCS:
            ax.scatter(lon_occs, lat_occs, c="lime", alpha=0.5, s=3)
        if VAL_DATA:
            ax.scatter(id_to_val[str(species_id)+"_lon"], id_to_val[str(species_id)+"_lat"], c="red", alpha=1, s=5)

        if not name == "only_dist":
            fig.colorbar(image, ax=ax)

        fig.savefig("./visuals/"+name+"/"+str(species_id)+("_noccs" if NOCCS else "")+("_france" if FRANCE_ONLY else "")
                   +("_std" if STD else ""))

        plt.show()
        # Release the figure so that long species lists do not pile up figures.
        plt.close(fig)

In [None]:
# Checkpoint names (without ".ckpt") grouped by training configuration.
sat_sinr_lf = [
    "sat_sinr_lf ae_default val_loss=0.0342",
    "sat_sinr_lf ae_default val_loss=0.0358",
    "sat_sinr_lf ae_default val_loss=0.0353",
    "sat_sinr_lf ae_default val_loss=0.0345",
    "sat_sinr_lf ae_default val_loss=0.0372",
]
sat_sinr_mf_zc = [
    "sat_sinr_mf_zc ae_default val_loss=-0.0047",
    "sat_sinr_mf_zc ae_default val_loss=-0.0046-v1",
    "sat_sinr_mf_zc ae_default val_loss=-0.0071-v1",
    "sat_sinr_mf_zc ae_default val_loss=-0.0062",
    "sat_sinr_mf_zc ae_default val_loss=-0.0052",
]
sinr_loc = [
    "sinr loc val_loss=0.0011",
    "sinr loc val_loss=0.0006",
    "sinr loc val_loss=-0.0040",
    "sinr loc val_loss=0.0000",
    "sinr loc val_loss=-0.0028",
]
sinr_loc_env = [
    "sinr loc_env val_loss=-0.0267",
    "sinr loc_env val_loss=-0.0240",
    "sinr loc_env val_loss=-0.0238",
    "sinr loc_env val_loss=-0.0222",
    "sinr loc_env val_loss=-0.0233-v1",
]
SIGM_FCK = [
    "sat_sinr_lf ae_default SIGMOID FCKEDval_loss=0.0263",
    "sat_sinr_lf ae_default SIGMOID FCKEDval_loss=0.0195",
    "sat_sinr_lf ae_default SIGMOID FCKEDval_loss=0.0259",
]
lf_plus = [
    "sat_sinr_lf ae_default plusval_loss=0.0077",
    "sat_sinr_lf ae_default plusval_loss=0.0068",
    "sat_sinr_lf ae_default plusval_loss=0.0083",
    "sat_sinr_lf ae_default plusval_loss=0.0064",
]

# Predict the full grid with the first three checkpoints of the chosen list.
# The loop deliberately leaves `name`, `model`, `sinr` and `dataset` bound to
# the last checkpoint; later cells print and reuse `name`.
preds = []
for name in lf_plus[:3]:
    model, sinr, dataset = get_model(name)
    grid_preds = get_preds(model, sinr, dataset)
    preds.append(grid_preds)

In [None]:
# Stack the per-checkpoint predictions in `preds` along a new leading axis and
# reduce over it. The variable name suggests the `sinr_loc` checkpoints, but
# whatever `preds` currently holds is used.
preds_sinr_loc = torch.stack(preds)
preds_average, preds_std = (
    preds_sinr_loc.mean(dim=0),
    preds_sinr_loc.std(dim=0),
)
print(name)

In [None]:
# Stack the per-checkpoint predictions in `preds` along a new leading axis and
# reduce over it. The variable name suggests the `lf_plus` checkpoints, but
# whatever `preds` currently holds is used.
preds_sinr_lf_plus = torch.stack(preds)
preds_average, preds_std = (
    preds_sinr_lf_plus.mean(dim=0),
    preds_sinr_lf_plus.std(dim=0),
)
print(name)

In [None]:
# Stack the per-checkpoint predictions in `preds` along a new leading axis and
# reduce over it. Presumably meant for the `SIGM_FCK` checkpoints, but
# whatever `preds` currently holds is used.
preds_sinr_fuck = torch.stack(preds)
preds_average, preds_std = (
    preds_sinr_fuck.mean(dim=0),
    preds_sinr_fuck.std(dim=0),
)
print(name)

In [None]:
# Stack the per-checkpoint predictions in `preds` along a new leading axis and
# reduce over it. The variable name suggests the `sinr_loc_env` checkpoints,
# but whatever `preds` currently holds is used.
preds_sinr_loc_env = torch.stack(preds)
preds_average, preds_std = (
    preds_sinr_loc_env.mean(dim=0),
    preds_sinr_loc_env.std(dim=0),
)
print(name)

In [None]:
# Stack the per-checkpoint predictions in `preds` along a new leading axis and
# reduce over it. The variable name suggests the `sat_sinr_lf` checkpoints,
# but whatever `preds` currently holds is used.
preds_lf = torch.stack(preds)
preds_average, preds_std = (
    preds_lf.mean(dim=0),
    preds_lf.std(dim=0),
)
print(name)

In [None]:
# Stack the per-checkpoint predictions in `preds` along a new leading axis and
# reduce over it. The variable name suggests the `sat_sinr_mf_zc` checkpoints,
# but whatever `preds` currently holds is used.
preds_mf_zc = torch.stack(preds)
preds_average, preds_std = (
    preds_mf_zc.mean(dim=0),
    preds_mf_zc.std(dim=0),
)
print(name)

In [None]:
# Species ids to plot. The long list is immediately replaced by the
# single-species selection below; it is kept for switching back.
to_map = [
    265, 268, 271, 439, 751, 905, 966, 1122, 1224, 1303,
    1559, 1957, 2071, 2854, 3207, 3384, 3947, 4062, 4269, 4501,
    5022, 5113, 5400, 5793, 6510, 6612, 6895, 6922, 7519, 7580,
    7760, 8023, 8196, 8267, 8586, 8791, 8994, 9170, 9240, 9315,
    9509, 9753, 9761, 9807, 9983,
]  # classes present in the first PA sample
# to_map = random.sample(c.keys(), NUM_SAMPLES)
to_map = [265]
# Occurrence count of each selected species, aligned with to_map.
num_samples = [c[species_id] for species_id in to_map]
"""name = "only_dist"
preds_average = torch.zeros(502*408, 10000)"""
# Mean maps in all four combinations, in the order
# (occurrences, Europe), (occurrences, France), (no occ., Europe), (no occ., France).
for noccs in (False, True):
    for france_only in (False, True):
        print_and_save_res(preds_average, name, NOCCS=noccs, FRANCE_ONLY=france_only, VAL_DATA=False, STD=False)
"""print_and_save_res(preds_std, name, NOCCS=False, FRANCE_ONLY=False, VAL_DATA=False, STD=True)
print_and_save_res(preds_std, name, NOCCS=False, FRANCE_ONLY=True, VAL_DATA=False, STD=True)
print_and_save_res(preds_std, name, NOCCS=True, FRANCE_ONLY=False, VAL_DATA=False, STD=True)
print_and_save_res(preds_std, name, NOCCS=True, FRANCE_ONLY=True, VAL_DATA=False, STD=True)"""