In [1]:
#import torch
import pandas as pd
import numpy as np
import json
import os
import sys
from sklearn.metrics import f1_score
import h3pandas
import torch
import h3
from sklearn.metrics import precision_recall_curve
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss


#sys.path.append('../')
import datasets
import models
import utils
import setup

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


get training data

In [2]:
train_params = {}

train_params['experiment_name'] = 'demo' # This will be the name of the directory where results for this run are saved.
train_params['species_set'] = 'all'
train_params['hard_cap_num_per_class'] = 1000
train_params['num_aux_species'] = 0
train_params['input_enc'] = 'sin_cos'
train_params['loss'] = 'an_full'

In [3]:
params = setup.get_default_params_train(train_params)

In [4]:
# load model
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_params = torch.load('./pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt', map_location='cpu')
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)
model = model.to(DEVICE)
model.eval()

ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [5]:
if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
    raster = datasets.load_env()
else:
    raster = None
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

In [6]:
#load reference from iucn
with open(os.path.join('./data/eval/iucn/', 'iucn_res_5.json'), 'r') as f:
            data = json.load(f)
species_ids = list((data['taxa_presence'].keys()))

In [7]:
def generate_h3_cells_atRes(resolution=4):
    h3_cells = list(h3.get_res0_indexes())
    h3_atRes_cells = set()
    for cell in h3_cells:
        h3_atRes_cells = h3_atRes_cells.union(h3.h3_to_children(cell, resolution))
    return list(h3_atRes_cells)

In [8]:
obs_locs = np.array(data['locs'], dtype=np.float32)
obs_locs = torch.from_numpy(obs_locs).to('cpu')
loc_feat = enc.encode(obs_locs)

In [9]:
classes_of_interest = torch.zeros(len(species_ids), dtype=torch.int64)
taxa_ids = torch.zeros(len(species_ids), dtype=torch.int64)
for tt_id, tt in enumerate(species_ids):
    class_of_interest = np.array([train_params['params']['class_to_taxa'].index(int(tt))])
    classes_of_interest[tt_id] = torch.from_numpy(class_of_interest)
    taxa_ids[tt_id] = int(tt)

In [10]:
with torch.no_grad():
    loc_emb = model(loc_feat, return_feats=True)
    wt = model.class_emb.weight[classes_of_interest, :]

In [11]:
output = []
species_dict = dict()
for class_index, class_id in enumerate(classes_of_interest):
    wt_1 = wt[class_index,:]
    preds = torch.sigmoid(torch.matmul(loc_emb, wt_1)).cpu().numpy()
    
    taxa = taxa_ids[class_index].item()
    species_locs = data['taxa_presence'].get(str(taxa))

    truth_array = np.zeros(preds.shape, int)
    truth_array[species_locs] = 1

    #generate calibration curve data for clustering
    n_bins = 20
    y_true = truth_array
    y_prob = preds
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy='uniform')
    calib_dict = dict()
    calib_dict['pred_probs'] = prob_pred
    calib_dict['emp_probs'] = prob_true


    # Compute Expected Calibration Error (ECE)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    binids = np.searchsorted(bins[1:-1], y_prob)
    bin_total = np.bincount(binids, minlength=len(bins))
    nonzero = bin_total != 0
    ece = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true)))

    # Compute Brier score
    brier_score = brier_score_loss(y_true, y_prob)

    row = {
        "taxon_id": taxa,
        "ece": ece,
        "brier": brier_score
    }
    row_dict = dict(row)
    output.append(row_dict)

    if(class_index%100==0):
            print(class_index)

output_pd = pd.DataFrame(output)


0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400


In [21]:
from torchmetrics.classification import BinaryCalibrationError
y_prob_t, y_true_t = torch.tensor(y_prob), torch.tensor(y_true)


metric = BinaryCalibrationError(n_bins=10, norm='l1')
metric(y_prob_t, y_true_t)

tensor(0.0003)

In [23]:
taxa

1369303

In [15]:
output_pd

Unnamed: 0,taxon_id,ece,brier
0,17090,0.001350,0.001237
1,18938,0.003294,0.003910
2,17556,0.003930,0.004511
3,18295,0.009091,0.007797
4,14152,0.000371,0.000749
...,...,...,...
2413,1368519,0.000960,0.001386
2414,1367368,0.000707,0.000511
2415,1369291,0.000311,0.000444
2416,1369292,0.000237,0.000172


In [12]:
output_pd.describe()

Unnamed: 0,taxon_id,ece,brier
count,2418.0,2418.0,2418.0
mean,107035.6,0.006202,0.006176
std,229829.4,0.016927,0.016409
min,14.0,1.8e-05,8e-06
25%,11957.25,0.000658,0.00068
50%,28961.5,0.001761,0.001991
75%,67792.5,0.006133,0.006421
max,1369303.0,0.511264,0.506477


In [13]:
output_pd[output_pd['ece']>0.2]

Unnamed: 0,taxon_id,ece,brier
787,4146,0.511264,0.506477
830,4535,0.292255,0.275345


In [14]:
output_pd[output_pd['taxon_id'] == 4537]

Unnamed: 0,taxon_id,ece,brier
823,4537,0.152197,0.143562
