In [1]:
import torch
import pandas as pd
import numpy as np
import json
import os
import sys
from sklearn.metrics import f1_score


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make the directory two levels up importable so that the project-local
# `datasets`, `models` and `utils` modules can be imported below.
# NOTE(review): the path is *appended*, i.e. searched last. If a package with
# the same name is installed in the environment (e.g. a pip-installed
# `datasets`), it would be imported instead of the local module -- confirm the
# project modules are the ones being picked up.
sys.path.append('../../')
import datasets
import models
import utils

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [4]:
# load model
# The checkpoint is a dict holding the model configuration under 'params' and
# the weights under 'state_dict'. It is deserialized onto the CPU first and the
# model is moved to DEVICE afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
train_params = torch.load('../../pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt', map_location='cpu')
# Rebuild the architecture from the saved configuration, then restore the
# weights; strict=True makes missing or unexpected keys an error.
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)
model = model.to(DEVICE)
# Switch to inference mode (this turns the Dropout layers off). As the last
# expression of the cell, it also displays the module summary.
model.eval()

ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [5]:
# Load the IUCN reference file and collect the taxon ids (the keys of
# 'taxa_presence', kept as strings) in file order.
with open(os.path.join('../../data/eval/iucn/', 'iucn_res_5.json'), 'r') as iucn_file:
    data = json.load(iucn_file)
species_ids = list(data['taxa_presence'])

In [6]:
# Only the environmental input encodings need the raster; every other
# encoding gets raster=None.
raster = (
    datasets.load_env()
    if train_params['params']['input_enc'] in ('env', 'sin_cos_env')
    else None
)
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

In [7]:
# Turn the reference locations into a float32 CPU tensor and encode them into
# the model's input features.
obs_locs = torch.from_numpy(np.array(data['locs'], dtype=np.float32)).to('cpu')
loc_feat = enc.encode(obs_locs)

In [10]:
# Resolve every IUCN species to (a) its integer taxon id and (b) the index of
# that taxon among the model's output classes.
#
# The lookup table is built once instead of calling list.index() for every
# species (a linear scan of the whole class list each time). setdefault keeps
# the first occurrence of a taxon id, which is what list.index() returned.
class_index_by_taxon = {}
for class_idx, class_taxon in enumerate(train_params['params']['class_to_taxa']):
    class_index_by_taxon.setdefault(class_taxon, class_idx)

taxa_id_list = [int(tt) for tt in species_ids]

unknown_taxa = [tt for tt in taxa_id_list if tt not in class_index_by_taxon]
if unknown_taxa:
    # Same exception type that list.index() raised for an unknown taxon.
    raise ValueError(
        f'{len(unknown_taxa)} taxa are not in class_to_taxa, e.g. {unknown_taxa[:5]}'
    )

taxa_ids = torch.tensor(taxa_id_list, dtype=torch.int64)
classes_of_interest = torch.tensor(
    [class_index_by_taxon[tt] for tt in taxa_id_list], dtype=torch.int64
)

In [13]:
# Compute the location embeddings for all reference locations in one forward
# pass and pick the classifier rows of the species being evaluated.
#
# The model was moved to DEVICE, whereas the location tensor was created on
# the CPU. Moving the inputs to DEVICE here keeps the forward pass from
# failing with a device mismatch when CUDA is available; on a CPU-only machine
# the .to() calls are no-ops.
with torch.no_grad():
    loc_emb = model(loc_feat.to(DEVICE), return_feats=True)
    wt = model.class_emb.weight[classes_of_interest.to(DEVICE), :]

In [14]:
def f1_at_thresh(y_true, y_pred, thresh, type = 'binary'):
    """Binarize scores at a threshold and return the resulting F1 score.

    Args:
        y_true: Ground-truth labels, passed to ``f1_score`` unchanged.
        y_pred: Predicted scores; values strictly greater than ``thresh``
            count as positive predictions.
        thresh: Decision threshold applied to ``y_pred``.
        type: Forwarded to ``f1_score`` as its ``average`` argument.

    Returns:
        Whatever ``f1_score`` returns for the thresholded predictions.
    """
    return f1_score(y_true, y_pred > thresh, average=type)

### include inat preds

In [16]:
# Per-taxon decision thresholds; the columns used downstream are `taxon_id`
# and `thres`. The two `Unnamed` columns in the displayed frame are unnamed
# columns that are already present in the CSV file.
inat_threshs = pd.read_csv('../../inat_thresholds_v1.csv')
# Last expression of the cell: show the table for a quick visual check.
inat_threshs

Unnamed: 0.1,Unnamed: 0,taxon_id,thres,area
0,0,17090,0.291031,1.033869e+06
1,1,18938,0.813246,6.585604e+05
2,2,17556,0.012559,8.400185e+06
3,3,18295,0.089120,7.084835e+06
4,4,14152,0.206174,8.975540e+05
...,...,...,...,...
2413,2413,1368519,0.436164,1.649942e+06
2414,2414,1367368,0.455293,6.939668e+05
2415,2415,1369291,0.435408,4.921499e+05
2416,2416,1369292,0.316032,2.867924e+05


In [20]:
# Per-species F1 of the thresholded model predictions against the IUCN
# presence locations.
#
# Thresholds are looked up by taxon_id instead of by row position, so the
# right threshold is used even if the CSV rows are ordered differently from
# `taxa_ids`. When the two are already aligned the result is unchanged.
thresh_by_taxon = dict(
    zip(inat_threshs['taxon_id'].tolist(), inat_threshs['thres'].tolist())
)

per_species_f1 = np.zeros((len(taxa_ids)))
for tt_id, taxa in enumerate(taxa_ids):
    taxa = taxa.item()
    if taxa not in thresh_by_taxon:
        raise KeyError(f'no threshold for taxon_id {taxa} in inat_threshs')
    thresh = thresh_by_taxon[taxa]

    # Probability of presence at every location for this species.
    wt_1 = wt[tt_id,:]
    preds = torch.sigmoid(torch.matmul(loc_emb, wt_1)).cpu().numpy()

    # Index directly instead of using .get(): a missing key would yield None,
    # and `y_test[None] = 1` would silently mark every location as present.
    species_locs = data['taxa_presence'][str(taxa)]
    y_test = np.zeros(preds.shape, int)
    y_test[species_locs] = 1

    per_species_f1[tt_id] = f1_at_thresh(y_test, preds, thresh, type='binary')
    # Lightweight progress indicator.
    if tt_id % 100 == 0:
        print(tt_id)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400


In [22]:
# Unweighted mean of the per-species F1 scores (every species counts equally).
per_species_f1.mean()

0.5952574603319684

In [None]:
# Optional: uncomment to write the per-species F1 scores to disk.
#np.save('./results/f1_scores_linspace.npy', per_species_f1)