In [1]:

import torch
import pandas as pd
import numpy as np
import json
import os
import sys
from sklearn.metrics import f1_score, precision_recall_curve
from sklearn.model_selection import train_test_split
import argparse
import logging
from tqdm import tqdm

sys.path.append('../')
import datasets
import models
import utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint evaluated below, resolved relative to the notebook's directory.
MODEL_PATH = os.path.join('../pretrained_models/', "model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt")


In [3]:
# Prefer the GPU when one is present; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Deserialise the checkpoint onto the CPU first so it opens on GPU-less hosts.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
train_params = torch.load(MODEL_PATH, map_location='cpu')

# Rebuild the architecture from the saved hyper-parameters and restore its
# weights; strict=True turns any missing/unexpected key into an error.
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)

# Move to the target device and switch to inference mode (disables dropout).
# The module is the cell's last expression so its repr is displayed.
model = model.to(DEVICE).eval()
model


ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [15]:
# Load the IUCN reference data. `taxa_presence` maps each taxon id (as a
# string) to the indices of the entries in `locs` where that taxon is present;
# `locs` is presumably a list of coordinates accepted by CoordEncoder.
with open(os.path.join('../data/eval/iucn/', 'iucn_res_5.json'), 'r') as f:
    data = json.load(f)
species_ids = list(data['taxa_presence'].keys())

# Only the environment-based encodings need the raster covariates.
if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
    raster = datasets.load_env()
else:
    raster = None
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

obs_locs = np.array(data['locs'], dtype=np.float32)
obs_locs = torch.from_numpy(obs_locs).to('cpu')
# Encode on the CPU, then move the features to the model's device. The model
# lives on DEVICE, so passing CPU features fails whenever CUDA is available.
loc_feat = enc.encode(obs_locs).to(DEVICE)

# Map each IUCN taxon id to the model's output class index. A dict built once
# replaces a linear `list.index` scan per taxon; setdefault keeps the first
# occurrence so the mapping matches `list.index` semantics exactly.
class_to_taxa = train_params['params']['class_to_taxa']
taxa_to_class = {}
for class_idx, taxon in enumerate(class_to_taxa):
    taxa_to_class.setdefault(taxon, class_idx)

missing_taxa = [tt for tt in species_ids if int(tt) not in taxa_to_class]
if missing_taxa:
    raise ValueError(
        f'{len(missing_taxa)} IUCN taxa are not classes of the model, e.g. {missing_taxa[:5]}'
    )

classes_of_interest = torch.tensor([taxa_to_class[int(tt)] for tt in species_ids], dtype=torch.int64)
taxa_ids = torch.tensor([int(tt) for tt in species_ids], dtype=torch.int64)

with torch.no_grad():
    # One forward pass over all locations; return_feats=True presumably yields
    # the shared location embedding (it is dotted with class_emb rows later).
    loc_emb = model(loc_feat, return_feats=True)
    # Rows of the class-embedding matrix for the IUCN taxa, in species_ids order.
    wt = model.class_emb.weight[classes_of_interest, :]

def f1_at_thresh(y_true, y_pred, thresh, type = 'binary'):
    """F1 score of `y_pred` binarised at `thresh`, measured against `y_true`.

    A score is treated as positive when it is >= thresh. This matches
    sklearn.metrics.precision_recall_curve, whose returned thresholds classify
    `score >= threshold` as positive. Using `>` would silently drop the samples
    that sit exactly on the threshold chosen from that curve.

    Args:
        y_true: ground-truth labels (0/1 for the default 'binary' average).
        y_pred: array-like of scores, same length as y_true.
        thresh: decision threshold.
        type: forwarded to f1_score as `average`. The name is kept for
            backward compatibility even though it shadows the builtin.

    Returns:
        The value returned by sklearn.metrics.f1_score.
    """
    # asarray lets plain lists work as well as numpy arrays.
    y_thresh = np.asarray(y_pred) >= thresh
    return f1_score(y_true, y_thresh, average=type)

In [20]:
# Number of IUCN presence locations per taxon, in taxa_presence key order.
lengths = list(map(len, data['taxa_presence'].values()))

In [26]:
# Position (and size) of the taxon with the fewest presence locations;
# min() returns the first minimal pair, so ties resolve to the earliest index.
min_index, min_value = min(enumerate(lengths), key=lambda pair: pair[1])
min_index

452

In [31]:
# For every IUCN taxon, the loop does three things:
#   1. score every location;
#   2. choose the F1-maximising threshold on a small held-out subset of locations;
#   3. report F1 on the remaining locations.
#
# NOTE(review): the previous comment here said "10% used for setting thresholds
# 90% for final eval", but test_size has always been 0.01 (1% / 99%), and the
# mean F1 reported below was produced with 1%. Kept at 1% — confirm intent.
THRESH_FRACTION = 0.01

# train_test_split's shuffle depends only on the number of rows and the seed.
# Every taxon is scored on the same locations, so the partition is identical
# for every taxon. Compute it once instead of re-shuffling inside the loop.
eval_idx, thresh_idx = train_test_split(
    np.arange(loc_emb.shape[0]), test_size=THRESH_FRACTION, random_state=42,
)

output = list()
for tt_id, taxa in tqdm(enumerate(taxa_ids), total=len(taxa_ids)):
    taxa = taxa.item()
    # Sigmoid score of this taxon at every location.
    preds = torch.sigmoid(torch.matmul(loc_emb, wt[tt_id, :])).cpu().numpy()

    # Binary ground truth from the IUCN presence indices. Index rather than
    # .get(): a missing key would return None, and y_test[None] = 1 would mark
    # every location as present instead of failing.
    y_test = np.zeros(preds.shape, int)
    y_test[data['taxa_presence'][str(taxa)]] = 1

    eval_preds, thresh_preds = preds[eval_idx], preds[thresh_idx]
    eval_y_test, thresh_y_test = y_test[eval_idx], y_test[thresh_idx]

    # Threshold maximising F1 on the held-out subset. precision/recall carry one
    # extra final point (precision=1, recall=0) with no matching threshold, so
    # drop it to keep the argmax aligned with `thresholds`.
    precision, recall, thresholds = precision_recall_curve(thresh_y_test, thresh_preds)
    p1 = 2 * precision[:-1] * recall[:-1]
    p2 = precision[:-1] + recall[:-1]
    fscore = np.divide(p1, p2, out=np.zeros(len(p1)), where=p2 != 0)
    thres = thresholds[np.argmax(fscore)]

    # Evaluate on the locations not used for threshold selection.
    f1 = f1_at_thresh(eval_y_test, eval_preds, thres)

    output.append({
        "taxon_id": taxa,
        "thres": thres,
        "fscore": f1,
    })
output_pd = pd.DataFrame(output)

100%|██████████| 2418/2418 [17:39<00:00,  2.28it/s]


In [33]:
# Mean held-out F1 across all evaluated taxa.
output_pd['fscore'].mean()

0.6365878347289665

In [36]:
# Column-wise maxima. NOTE(review): the taxon_id maximum is just the largest id,
# not a metric; only thres and fscore are meaningful here.
output_pd.max(axis=0)

taxon_id    1.369303e+06
thres       9.817630e-01
fscore      9.277161e-01
dtype: float64