In [1]:
#import torch
import pandas as pd
import numpy as np
import json
import os
import sys
from sklearn.metrics import f1_score
import h3pandas
import torch
import h3
from sklearn.metrics import precision_recall_curve
from shapely.geometry import Polygon
import geopandas as gpd
import matplotlib.pyplot as plt

sys.path.append('../')
import datasets
import models
import utils
import setup

from sklearn.linear_model import LogisticRegression


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


Load the training data (subsampled to at most 1000 observations per class)

In [6]:
# Overrides applied on top of the default training parameters.
train_params = {
    # This will be the name of the directory where results for this run are saved.
    'experiment_name': 'demo',
    'species_set': 'all',
    'hard_cap_num_per_class': 1000,
    'num_aux_species': 0,
    'input_enc': 'sin_cos',
    'loss': 'an_full',
}

In [7]:
# Build the parameter set for data loading from the overrides in `train_params`.
params = setup.get_default_params_train(train_params)

In [8]:
# Load the training set; the cell output reports the class count and the final (subsampled) size.
train_dataset = datasets.get_train_data(params)


Loading  ../data/train/geo_prior_train.csv
Number of unique classes 47375
subsampling (up to) 1000 per class for the training set
final training set size: 15132683


In [9]:
# Tabulate the training locations as (lng, lat) and rescale them by 180 / 90.
# NOTE(review): presumably train_dataset.locs is normalised to [-1, 1] so this yields degrees - confirm.
train_df = pd.DataFrame(train_dataset.locs, columns=['lng','lat']).assign(
    lng=lambda frame: frame['lng'] * 180,
    lat=lambda frame: frame['lat'] * 90,
)
# One label per training location.
train_df['label'] = train_dataset.labels

In [10]:
# Bin every training observation into an H3 cell and count observations per cell.
h3_resolution = 5
train_df_h3 = train_df.h3.geo_to_h3(h3_resolution)  # h3pandas accessor; the H3 cell id becomes the index
all_spatial_grid_counts = train_df_h3.index.value_counts()
# "background" = number of training observations (any label) that fall in each cell.
presence_absence = pd.DataFrame({"background": all_spatial_grid_counts}).fillna(0)

In [11]:
# Keep `resolution` in sync with the H3 resolution used for binning the training data.
resolution = h3_resolution
# Hexagon area reported by h3 for this resolution.
# NOTE(review): units are whatever h3.hex_area defaults to - presumably km^2; confirm.
area = h3.hex_area(resolution)

In [12]:
# Load the pretrained model checkpoint and put it in inference mode.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): torch.load may unpickle arbitrary objects - only load checkpoints from a trusted source.
# NOTE: this rebinds `train_params` (previously the dict of training options) to the
# checkpoint dict; later cells read its 'params' and 'state_dict' entries.
train_params = torch.load('../pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt', map_location='cpu')
model = models.get_model(train_params['params'])
# strict=True: fail if the checkpoint keys do not exactly match the model's parameters.
model.load_state_dict(train_params['state_dict'], strict=True)
model = model.to(DEVICE)
# eval() returns the module, so as the last expression it also displays the architecture.
model.eval()

ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [13]:
# Build the coordinate encoder that matches the checkpoint's input encoding.
input_enc = train_params['params']['input_enc']
# The environmental raster is only loaded for the encodings that consume it.
raster = datasets.load_env() if input_enc in ['env', 'sin_cos_env'] else None
enc = utils.CoordEncoder(input_enc, raster=raster)

In [14]:
# Load the IUCN reference data and collect the taxon ids it covers.
with open(os.path.join('../data/eval/iucn/', 'iucn_res_5.json'), 'r') as f:
    data = json.load(f)
# Iterating a dict yields its keys, i.e. the taxon ids (as strings) with presence data.
species_ids = list(data['taxa_presence'])

In [15]:
def generate_h3_cells_atRes(resolution=4):
    """Return every H3 cell at ``resolution`` as a list.

    Starts from the resolution-0 base cells and collects the children of each
    one at the requested resolution.

    Args:
        resolution: Target H3 resolution passed to ``h3.h3_to_children``.

    Returns:
        list: The de-duplicated child cells; the order is that of set
        iteration and therefore not meaningful.
    """
    h3_atRes_cells = set()
    for cell in h3.get_res0_indexes():
        # update() grows the set in place; the previous `set.union` call built a
        # brand-new copy of the accumulated set on every iteration.
        h3_atRes_cells.update(h3.h3_to_children(cell, resolution))
    return list(h3_atRes_cells)

In [18]:
# Evaluation locations from the IUCN file; each entry is indexed as (lng, lat).
iucn_locs = np.array(data['locs'])
# h3.geo_to_h3 takes (lat, lng), hence the swapped indices.
h3_atRes_cells = [h3.geo_to_h3(loc[1], loc[0], resolution=5) for loc in data['locs']]
# One row per evaluation location, indexed by the H3 cell that contains it.
gdfk = pd.DataFrame(
    {"lng": iucn_locs[:, 0], "lat": iucn_locs[:, 1]},
    index=h3_atRes_cells,
).rename_axis('h3index')

In [19]:
# Encode the evaluation locations into the model's input features.
coords_f32 = np.array(gdfk[['lng', 'lat']].to_numpy(), dtype=np.float32)  # (lng, lat) column order
obs_locs = torch.from_numpy(coords_f32).to('cpu')
# NOTE(review): presumably enc.encode handles any coordinate normalisation itself - confirm.
loc_feat = enc.encode(obs_locs)

In [20]:
# Map every IUCN taxon id to the model's class index (its position in class_to_taxa).
class_to_taxa = train_params['params']['class_to_taxa']
# Build the lookup once instead of calling list.index() (a linear scan) per species.
# setdefault keeps the first position of a taxon, matching list.index() semantics.
taxa_to_class = {}
for class_idx, taxon_id in enumerate(class_to_taxa):
    taxa_to_class.setdefault(taxon_id, class_idx)
# A taxon that the model does not know raises KeyError here.
classes_of_interest = torch.tensor(
    [taxa_to_class[int(tt)] for tt in species_ids],
    dtype=torch.int64,
)
# Allocated but not filled in by this cell; kept so any later cell that reads it still works.
taxa_ids = torch.zeros(len(species_ids), dtype=torch.int64)
    

In [21]:
# Compute location embeddings and pull out the classifier rows for the species of interest.
with torch.no_grad():
    # The model lives on DEVICE while loc_feat was built on the CPU, so feeding it
    # directly fails with a device mismatch whenever CUDA is available.
    # Results are moved back to the CPU because later cells convert them to numpy.
    loc_emb = model(loc_feat.to(DEVICE), return_feats=True).cpu()
    class_weights = model.class_emb.weight
    # One weight row per entry of classes_of_interest, in the same order.
    wt = class_weights[classes_of_interest.to(class_weights.device), :].cpu()

In [22]:
def fscore_and_thres(y_test, preds):
    """Find the best F1 score on the precision-recall curve and its threshold.

    Args:
        y_test: Binary ground-truth labels.
        preds: Predicted scores, one per label.

    Returns:
        tuple: ``(max_fscore, thres)`` - the highest F1 along the curve and the
        score threshold at that point.
    """
    precision, recall, thresholds = precision_recall_curve(y_test, preds)
    numerator = 2 * precision * recall
    denominator = precision + recall
    # Points where precision + recall == 0 keep an F1 of 0 instead of dividing by zero.
    fscore = np.divide(
        numerator,
        denominator,
        out=np.zeros(len(numerator)),
        where=denominator != 0,
    )
    best = np.argmax(fscore)
    return fscore[best], thresholds[best]

In [30]:
# Re-read the raw training CSV (before any per-class subsampling) to get per-taxon observation counts.
freq_df = pd.read_csv('../data/train/geo_prior_train.csv')

In [37]:
# Number of raw training observations per taxon, indexed by taxon_id.
counts = freq_df.taxon_id.value_counts()
# NOTE(review): 1000 matches the 'hard_cap_num_per_class' training option - keep the two in sync.
cap = 1000
# Vectorised cap: every count above `cap` becomes `cap`, the rest are unchanged.
counts_capped = counts.clip(upper=cap)

In [None]:
# NOTE(review): this cell re-defines `fscore_and_thres` with the same logic as the
# earlier cell; whichever definition runs last silently shadows the other.
def fscore_and_thres(y_test, preds):
    """Return the maximum F1 over the precision-recall curve and its threshold.

    Args:
        y_test: Binary ground-truth labels.
        preds: Predicted scores, one per label.

    Returns:
        tuple: ``(max_fscore, thres)``.
    """
    precision, recall, thresholds = precision_recall_curve(y_test, preds)
    pr_sum = precision + recall
    # F1 = 2PR / (P + R); entries with P + R == 0 are left at 0.
    fscore = np.divide(
        2 * precision * recall,
        pr_sum,
        out=np.zeros(len(pr_sum)),
        where=pr_sum != 0,
    )
    index = np.argmax(fscore)
    return fscore[index], thresholds[index]

In [53]:
def norm_cal_activation(preds, N, gamma):
    """Apply a sigmoid to ``preds`` after dividing by the scale ``N ** gamma``.

    Computes ``sigmoid(preds / N**gamma)``, i.e. ``exp(x) / (exp(x) + 1)`` with
    ``x = preds / N**gamma``.

    Args:
        preds: Raw scores (array or scalar).
        N: Base of the scale factor.
        gamma: Exponent of the scale factor; larger values flatten the output
            towards 0.5.

    Returns:
        Values in [0, 1] with the same shape as ``preds``.
    """
    scaled = preds / (N ** gamma)
    # sigmoid(x) = exp(-log(1 + exp(-x))). logaddexp evaluates log(1 + exp(-x))
    # without overflow, whereas exp(x) / (exp(x) + 1) turns into inf / inf = NaN
    # for large positive x.
    return np.exp(-np.logaddexp(0, -scaled))

In [56]:
# NOTE(review): unused in this cell; kept in case a later cell relies on the name.
from sklearn.model_selection import train_test_split

# Sweep the calibration exponent gamma for each species and record the best F1
# score and its threshold against the IUCN presence labels.
gammas = [0.5, 1, 2, 5]  # loop-invariant, so defined once up front
output = []
for class_index, class_id in enumerate(classes_of_interest):
    wt_1 = wt[class_index, :]
    # Raw (pre-activation) score of this species at every evaluation location.
    raw_preds = torch.matmul(loc_emb, wt_1)
    raw_preds_np = raw_preds.numpy()  # converted once rather than once per gamma
    taxa = train_dataset.class_to_taxa[class_id.item()]
    # Capped number of raw training observations for this taxon.
    N = counts_capped[taxa]

    # Direct indexing instead of .get(): a missing taxon used to return None, and
    # `y_test[None] = 1` would then silently mark every location as present.
    species_locs = data['taxa_presence'][str(taxa)]
    # Shape comes from the raw scores; the previous `preds.shape` read `preds`
    # before its first assignment, which is a NameError on a fresh kernel.
    y_test = np.zeros(raw_preds_np.shape, int)
    y_test[species_locs] = 1  # species_locs are indices into the evaluation locations

    row_dict = {'taxon_id': taxa}
    for gamma in gammas:
        preds = norm_cal_activation(raw_preds_np, N, gamma=gamma)
        max_fscore, thres = fscore_and_thres(y_test, preds)
        row_dict[f'max_fscore_gamma_{gamma}'] = max_fscore
        row_dict[f'thres_gamma_{gamma}'] = thres

    output.append(row_dict)

    if class_index % 100 == 0:
        print(class_index)
    # NOTE(review): debugging limit - only the first four species are evaluated.
    if class_index == 3:
        break

output_pd = pd.DataFrame(output)
    

0


In [57]:
# Best F1 score and matching threshold per taxon for every gamma that was tried.
output_pd

Unnamed: 0,taxon_id,max_fscore_gamma_0.5,thres_gamma_0.5,max_fscore_gamma_1,thres_gamma_1,max_fscore_gamma_2,thres_gamma_2,max_fscore_gamma_5,thres_gamma_5
0,17090,0.784927,0.442041,0.784927,0.493091,0.784894,0.499903,0.784927,0.5
1,18938,0.603481,0.460475,0.603481,0.497495,0.603465,0.49999,0.603457,0.5
2,17556,0.573293,0.440308,0.573293,0.491602,0.573269,0.499835,0.573293,0.5
3,18295,0.904918,0.421569,0.904918,0.489958,0.904902,0.499838,0.904918,0.5


In [45]:
# NOTE(review): this cell's execution count (45) predates the evaluation loop (56), so the
# output shown may come from an earlier run; it also prints the whole array.
preds

array([6.6966983e-12, 4.6307237e-18, 7.6709468e-17, ..., 3.3998110e-16,
       8.8218506e-13, 3.9533735e-14], dtype=float32)