In [75]:
#import torch
import pandas as pd
import numpy as np
import json
import os
import sys
from sklearn.metrics import f1_score
import h3pandas
import torch
import h3
from sklearn.metrics import precision_recall_curve
from shapely.geometry import Polygon
import geopandas as gpd
import matplotlib.pyplot as plt
import math

#sys.path.append('../')
import datasets
import models
import utils
import setup

Load the training data

In [2]:
# Overrides for the training configuration used to rebuild the training set.
train_params = {
    # This will be the name of the directory where results for this run are saved.
    'experiment_name': 'demo',
    'species_set': 'all',
    'hard_cap_num_per_class': 1000,
    'num_aux_species': 0,
    'input_enc': 'sin_cos',
    'loss': 'an_full',
}

In [3]:
# Build the full parameter set from the overrides in `train_params`
# (presumably merged over the project's training defaults — confirm in `setup`).
params = setup.get_default_params_train(train_params)

In [4]:
# Load the training observations described by `params`. Per the cell output this
# reads data/train/geo_prior_train.csv and subsamples up to 1000 records per class.
train_dataset = datasets.get_train_data(params)


Loading  data/train/geo_prior_train.csv
Number of unique classes 47375
subsampling (up to) 1000 per class for the training set
final training set size: 15132683


In [5]:
# One row per training observation: coordinates plus class label.
# The stored coordinates are rescaled by 180 / 90 (assumes `train_dataset.locs`
# holds (lng, lat) normalised to [-1, 1] — TODO confirm against `datasets`).
train_df = (
    pd.DataFrame(train_dataset.locs, columns=['lng', 'lat'])
    .assign(
        lng=lambda obs: obs['lng'] * 180,
        lat=lambda obs: obs['lat'] * 90,
        label=train_dataset.labels,
    )
)

In [6]:
# H3 resolution used to bin the training observations.
h3_resolution = 4

# Index every observation by the H3 cell it falls in (the cell id becomes the
# frame's index, which is what gets counted below).
train_df_h3 = train_df.h3.geo_to_h3(h3_resolution)

# Number of training observations (any species) per H3 cell: the "background".
all_spatial_grid_counts = train_df_h3.index.value_counts()
presence_absence = all_spatial_grid_counts.to_frame(name="background").fillna(0)

In [7]:
# `resolution` is an alias of `h3_resolution`; `area` is the hexagon area that
# h3 reports for that resolution (in the library's default unit — presumably
# km^2, TODO confirm).
# NOTE(review): `area` is not used by any of the cells visible in this notebook.
resolution = h3_resolution
area = h3.hex_area(resolution)

In [8]:
# load model
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this rebinds `train_params`. From here on it is the loaded
# checkpoint (indexed below by 'params' and 'state_dict'), not the override
# dict defined at the top of the notebook.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
train_params = torch.load('./pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt', map_location='cpu')
# Rebuild the architecture from the parameters stored in the checkpoint, then
# restore the weights; strict=True requires the state dict keys to match exactly.
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)
model = model.to(DEVICE)
# Inference mode (e.g. the Dropout layers are disabled); as the last expression
# of the cell this also displays the module structure.
model.eval()

ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [9]:
# The environmental raster is only loaded for the encodings that use it
# ('env' and 'sin_cos_env'); otherwise the encoder gets no raster.
raster = (
    datasets.load_env()
    if train_params['params']['input_enc'] in ('env', 'sin_cos_env')
    else None
)

# Encoder that turns raw coordinates into the model's input features, using the
# same input encoding the checkpoint was trained with.
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

In [10]:
# Load the IUCN reference data used for evaluation.
iucn_path = os.path.join('./data/eval/iucn/', 'iucn_res_5.json')
with open(iucn_path, 'r') as f:
    data = json.load(f)

# Taxon ids (the keys of 'taxa_presence', as strings) that have an IUCN reference.
species_ids = list(data['taxa_presence'])

In [11]:
def generate_h3_cells_atRes(resolution=4):
    """Return every H3 cell at ``resolution``.

    The cells are collected by expanding each resolution-0 cell into its
    children at the requested resolution.

    Parameters
    ----------
    resolution : int, default 4
        Target H3 resolution passed to ``h3.h3_to_children``.

    Returns
    -------
    list
        The unique cell indexes at ``resolution``, sorted so the order is the
        same on every run (plain set iteration order is not stable across
        interpreter sessions).
    """
    cells = set()
    for base_cell in h3.get_res0_indexes():
        # In-place update: avoids copying the accumulated set on every
        # iteration, which is what repeated `set.union` does.
        cells.update(h3.h3_to_children(base_cell, resolution))
    return sorted(cells)

In [12]:
# Build `gdfk`: one row per H3 cell at `h3_resolution`, indexed by the cell id,
# holding the coordinates at which the model will be evaluated.
h3_atRes_cells = generate_h3_cells_atRes(h3_resolution)
# h3pandas attaches a point `geometry` for each cell in the index
# (presumably the cell centre — TODO confirm).
gdfk = pd.DataFrame(index=h3_atRes_cells).h3.h3_to_geo()
# Keep the point coordinates as plain numeric columns.
gdfk["lng"] = gdfk["geometry"].x
gdfk["lat"] = gdfk["geometry"].y
# The geometry column is no longer needed; the popped column is discarded.
_ = gdfk.pop("geometry")
gdfk = gdfk.rename_axis('h3index')


In [13]:
# (lng, lat) of every grid cell as a float32 CPU tensor, then encoded into the
# model's input features.
obs_locs = torch.from_numpy(
    gdfk[['lng', 'lat']].values.astype(np.float32)
).to('cpu')
loc_feat = enc.encode(obs_locs)

In [14]:
# Map every IUCN taxon id to the index of its output class in the model.
# A lookup table is built once instead of calling `list.index` (a linear scan
# over all classes) for each species.
taxon_to_class = {}
for class_idx, taxon in enumerate(train_params['params']['class_to_taxa']):
    # setdefault keeps the first occurrence, matching `list.index` semantics.
    taxon_to_class.setdefault(taxon, class_idx)

missing_taxa = [tt for tt in species_ids if int(tt) not in taxon_to_class]
if missing_taxa:
    # Same exception type `list.index` raised, but with an actionable message.
    raise ValueError(
        f"{len(missing_taxa)} IUCN taxa are not classes of the loaded model, "
        f"e.g. {missing_taxa[:5]}"
    )

# classes_of_interest[i] is the model class index of species_ids[i].
classes_of_interest = torch.tensor(
    [taxon_to_class[int(tt)] for tt in species_ids], dtype=torch.int64
)
# NOTE(review): `taxa_ids` is allocated but never filled in the visible cells;
# kept in case a later cell relies on it.
taxa_ids = torch.zeros(len(species_ids), dtype=torch.int64)
    

In [15]:
# Embed every grid cell once and pull out the classifier weights of the species
# of interest; per-species predictions are then sigmoid(loc_emb @ weight).
with torch.no_grad():
    # The features were built on the CPU while the model may live on the GPU;
    # move them to the model's device to avoid a device-mismatch error.
    loc_emb = model(loc_feat.to(model.class_emb.weight.device), return_feats=True)
    wt = model.class_emb.weight[classes_of_interest, :]

In [80]:
def subsample_thresh(presences, sample_size, n_times, rng=None):
    """Estimate a presence threshold from random subsamples of the scores.

    For each of ``n_times`` repetitions a fraction ``sample_size`` of the
    presence scores is drawn without replacement and its minimum is recorded;
    the mean of those minima is returned.

    Parameters
    ----------
    presences : pandas.Series or array-like
        Model scores at the locations where the species was observed.
    sample_size : float
        Fraction of ``presences`` drawn per repetition; the number of draws is
        ``ceil(sample_size * len(presences))``.
    n_times : int
        Number of repetitions to average over.
    rng : numpy.random.Generator, optional
        Random generator used for the draws. When omitted, NumPy's global
        random state is used (as before), so ``np.random.seed`` still applies.

    Returns
    -------
    float
        Mean of the subsample minima, or NaN when ``presences`` is empty
        (there is nothing to derive a threshold from).
    """
    scores = np.asarray(presences)
    if scores.size == 0:
        return float('nan')

    n_draw = math.ceil(sample_size * len(scores))
    sampler = np.random if rng is None else rng
    minima = [
        sampler.choice(scores, size=n_draw, replace=False).min()
        for _ in range(n_times)
    ]
    return np.mean(minima)

In [83]:
def f1_at_thresh(y_true, y_pred, thresh, type = 'binary'):
    """F1 score of ``y_pred`` binarised at ``thresh`` against ``y_true``.

    A location counts as predicted present when its score is strictly greater
    than ``thresh``. ``type`` is forwarded to ``f1_score`` as ``average``.
    """
    predicted_present = y_pred > thresh
    return f1_score(y_true, predicted_present, average=type)

In [88]:
# Encode the IUCN evaluation locations the same way as the H3 grid.
obs_locs_iucn = np.array(data['locs'], dtype=np.float32)
obs_locs_iucn = torch.from_numpy(obs_locs_iucn).to('cpu')
loc_feat_iucn = enc.encode(obs_locs_iucn)

with torch.no_grad():
    # The features were built on the CPU while the model may live on the GPU;
    # move them to the model's device to avoid a device-mismatch error.
    loc_emb_iucn = model(loc_feat_iucn.to(model.class_emb.weight.device), return_feats=True)
    wt_iucn = model.class_emb.weight[classes_of_interest, :]

In [101]:
# For every species: derive presence thresholds from the model scores at the
# H3 cells containing training observations, then score the thresholded model
# predictions against the IUCN reference with F1.

# Seed NumPy's global RNG so the random subsampling in `subsample_thresh`
# (and therefore every threshold and F1 below) is reproducible on re-run.
SUBSAMPLE_SEED = 0
np.random.seed(SUBSAMPLE_SEED)

# Grid of subsampling settings; identical for every species, so built once.
sample_sizes = np.arange(0.1, 1, 0.1)
n_times_values = [1, 5, 10]

output = []
for class_index, class_id in enumerate(classes_of_interest):
    class_label = class_id.item()

    # Model score of this species at every H3 grid cell.
    wt_1 = wt[class_index,:]
    preds = torch.sigmoid(torch.matmul(loc_emb, wt_1)).cpu().numpy()
    gdfk["pred"] = preds

    # Training observations of this species per H3 cell.
    target_spatial_grid_counts = train_df_h3[train_df_h3.label==class_label].index.value_counts()

    # Column assignments align on the H3 index; cells without observations of
    # this species become 0. ("forground" is the column name used throughout.)
    presence_absence["forground"] = target_spatial_grid_counts
    presence_absence["predictions"] = gdfk["pred"]
    presence_absence.forground = presence_absence.forground.fillna(0)

    # Scores at the cells where the species was observed.
    presences = presence_absence[(presence_absence["forground"]>0)]["predictions"]

    ###load iucn data
    taxa = train_dataset.class_to_taxa[class_label]
    species_locs = data['taxa_presence'].get(str(taxa))
    if species_locs is None:
        # Indexing with None would silently mark *every* location as present.
        # This can only happen if the dataset's class -> taxon mapping differs
        # from the one the class indices were derived from.
        raise KeyError(f"taxon {taxa} (class {class_label}) has no IUCN presence entry")
    wt_1_iucn = wt_iucn[class_index,:]
    preds_iucn = torch.sigmoid(torch.matmul(loc_emb_iucn, wt_1_iucn)).cpu().numpy()

    # Reference labels do not depend on the subsampling settings, so they are
    # built once per species rather than once per setting.
    y_test = np.zeros(preds_iucn.shape, int)
    y_test[species_locs] = 1

    for sample_size in sample_sizes:
        for n_times in n_times_values:

            thres = subsample_thresh(presences, sample_size, n_times)
            f1 = f1_at_thresh(y_test, preds_iucn, thres, type='binary')

            output.append({"taxon_id": taxa,
                           'sample_size': sample_size,
                           'n_times': n_times,
                           'thres': thres,
                           "iucn_f1": f1})

    # Progress indicator every 100 species.
    if class_index % 100 == 0:
        print(class_index)
output_pd = pd.DataFrame(output)
    

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400


In [102]:
# Make sure the output directory exists so the results of the long evaluation
# loop are not lost to a missing-directory error at write time.
os.makedirs("./results/thresholds", exist_ok=True)
output_pd.to_csv("./results/thresholds/thres_masking_sumsample_all.csv")