generate upper bound for st dataset

In [50]:
import pandas as pd
import numpy as np
import json
import os
import sys
from sklearn.metrics import f1_score
import h3pandas
import torch
import h3
from sklearn.metrics import precision_recall_curve
from tqdm import tqdm

sys.path.append('../')
import datasets
import models
import utils
import setup

In [51]:
MODEL_PATH = '../pretrained_models/' + "1000_cap_models/final_loss_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000/model.pt"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_params = torch.load(MODEL_PATH, map_location='cpu')
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)
model = model.to(DEVICE)
model.eval()

if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
    raster = datasets.load_env()
else:
    raster = None
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

In [52]:
data2 = np.load(os.path.join('../data/eval/snt/', 'snt_res_5.npy'), allow_pickle=True)
data2 = data2.item()

In [53]:
data2.keys()

dict_keys(['loc_indices_per_species', 'labels_per_species', 'taxa', 'obs_locs', 'obs_locs_idx'])

In [54]:
loc_indices_per_species = data2['loc_indices_per_species']


In [55]:
species_ids2 = data2['taxa']
labels_per_species = data2['labels_per_species']


In [56]:
train_params['params']['class_to_taxa']

[7,
 14,
 19,
 25,
 26,
 34,
 39,
 41,
 50,
 53,
 62,
 68,
 77,
 85,
 91,
 95,
 101,
 108,
 115,
 132,
 156,
 162,
 178,
 199,
 243,
 248,
 288,
 290,
 297,
 300,
 323,
 324,
 329,
 357,
 401,
 436,
 443,
 458,
 460,
 473,
 478,
 479,
 482,
 486,
 487,
 488,
 489,
 491,
 519,
 522,
 535,
 542,
 585,
 649,
 728,
 804,
 831,
 840,
 846,
 863,
 867,
 871,
 880,
 882,
 890,
 906,
 913,
 931,
 949,
 974,
 981,
 1026,
 1050,
 1066,
 1070,
 1078,
 1081,
 1094,
 1096,
 1134,
 1138,
 1158,
 1204,
 1224,
 1230,
 1241,
 1264,
 1280,
 1300,
 1321,
 1339,
 1360,
 1392,
 1399,
 1406,
 1409,
 1415,
 1419,
 1425,
 1428,
 1439,
 1449,
 1459,
 1468,
 1478,
 1485,
 1486,
 1489,
 1495,
 1501,
 1507,
 1514,
 1525,
 1530,
 1538,
 1550,
 1559,
 1590,
 1593,
 1607,
 1612,
 1616,
 1626,
 1630,
 1644,
 1656,
 1670,
 1677,
 1691,
 1692,
 1705,
 1709,
 1719,
 1723,
 1730,
 1736,
 1738,
 1758,
 1787,
 1788,
 1789,
 1793,
 1827,
 1832,
 1850,
 1856,
 1865,
 1877,
 1903,
 1904,
 1907,
 1917,
 1927,
 1930,
 1940,
 19

In [57]:
obs_locs = np.array(data2['obs_locs'], dtype=np.float32)
obs_locs = torch.from_numpy(obs_locs).to('cpu')
loc_feat = enc.encode(obs_locs)

# get classes to eval
classes_of_interest = torch.zeros(len(species_ids2), dtype=torch.int64)
for tt_id, tt in enumerate(species_ids2):
    class_of_interest = np.array([train_params['params']['class_to_taxa'].index(tt)])
    if len(class_of_interest) != 0:
        classes_of_interest[tt_id] = torch.from_numpy(class_of_interest)

In [58]:
with torch.no_grad():
    loc_emb = model(loc_feat, return_feats=True)
    wt = model.class_emb.weight[classes_of_interest, :]

In [59]:
def f1_at_thresh(y_true, y_pred, thresh, type = 'binary'):
    y_thresh = y_pred > thresh
    return f1_score(y_true, y_thresh, average=type)

In [60]:
threshs1 = pd.read_csv("../snt_core_results/rdm_background_results/an_full_1000/"+f"/thresholds.csv")
threshs2 = pd.read_csv("../snt_core_results/masking_results/an_full_1000/"+f"/thresholds.csv")
threshs3 = pd.read_csv("../snt_core_results/tgt_background_results/an_full_1000/"+f"/thresholds.csv")

In [61]:
threshs1

Unnamed: 0.1,Unnamed: 0,taxon_id,thres,area,pseudo_fscore
0,0,7,0.174615,1.370584e+07,0.809074
1,1,162,0.041433,2.169532e+07,0.798486
2,2,243,0.032748,1.443876e+07,0.820774
3,3,473,0.056576,1.384570e+07,0.813376
4,4,519,0.165063,1.002357e+07,0.789298
...,...,...,...,...,...
530,530,979756,0.032159,1.747132e+07,0.830421
531,531,979757,0.101276,1.132299e+07,0.802454
532,532,1286843,0.021704,7.985930e+06,0.886598
533,533,1289467,0.099973,3.218448e+06,0.922118


In [62]:
threshs2

Unnamed: 0.1,Unnamed: 0,taxon_id,thres
0,0,7,0.075577
1,1,162,0.056274
2,2,243,0.057868
3,3,473,0.064698
4,4,519,0.029784
...,...,...,...
530,530,979756,0.038190
531,531,979757,0.048353
532,532,1286843,0.046857
533,533,1289467,0.120552


In [63]:
threshs3

Unnamed: 0.1,Unnamed: 0,taxon_id,thres,area,pseudo_fscore
0,0,7,0.092541,1.889643e+07,0.924000
1,1,162,0.010249,2.899613e+07,0.968539
2,2,243,0.007171,2.124565e+07,0.970484
3,3,473,0.029878,1.777582e+07,0.957160
4,4,519,0.043879,2.241053e+07,0.955326
...,...,...,...,...,...
530,530,979756,0.017539,2.005423e+07,0.952819
531,531,979757,0.024039,2.039059e+07,0.970190
532,532,1286843,0.037028,7.005170e+06,0.951977
533,533,1289467,0.050488,3.834521e+06,0.961538


In [67]:
per_species_f1 = np.zeros(len(classes_of_interest))
for tt_id, tt in tqdm(enumerate(classes_of_interest), total=len(classes_of_interest)):

    wt_1 = wt[tt_id,:]
    preds = torch.sigmoid(torch.matmul(loc_emb, wt_1)).cpu().numpy()

    # generate ground truth labels for current taxa
    cur_loc_indices = np.array(loc_indices_per_species[tt_id])
    cur_labels = np.array(labels_per_species[tt_id])

    pred = preds[cur_loc_indices]

    thresh = threshs1['thres'][tt_id]
    per_species_f1[tt_id] = f1_at_thresh(cur_labels, pred, thresh, type='binary')


100%|██████████| 535/535 [00:30<00:00, 17.60it/s]


In [68]:
per_species_f1.mean()

0.6678959835445306

In [None]:
per_species_f1.mean()

0.7019053670673474