In [65]:
#import torch
import pandas as pd
import numpy as np
import json
import os
import sys
from sklearn.metrics import f1_score
import h3pandas
import torch
import h3
from sklearn.metrics import precision_recall_curve
from tqdm import tqdm
import pickle

sys.path.append('../')
import datasets
import models
import utils
import setup

In [150]:
train_params = {}

train_params['experiment_name'] = 'demo' # This will be the name of the directory where results for this run are saved.
train_params['species_set'] = 'all'
train_params['hard_cap_num_per_class'] = -1
train_params['num_aux_species'] = 0
train_params['input_enc'] = 'sin_cos'
train_params['loss'] = 'an_full'

params = setup.get_default_params_train(train_params)

train_dataset = datasets.get_train_data(params)

train_df = pd.DataFrame(train_dataset.locs, columns=['lng','lat'])
train_df['lng'] = train_df['lng']*180
train_df['lat'] = train_df['lat']*90
train_df['label'] = train_dataset.labels

h3_resolution = 5
train_df_h3 = train_df.h3.geo_to_h3(h3_resolution)
all_spatial_grid_counts = train_df_h3.index.value_counts()
presence_absence = pd.DataFrame({
    "background": all_spatial_grid_counts,
})
presence_absence = presence_absence.fillna(0)

In [151]:
# load model
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_params = torch.load('../pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt', map_location='cpu')
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)
model = model.to(DEVICE)
model.eval()

ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [152]:
def find_mapping_between_models(vision_taxa, geo_taxa):
    # this will output an array of size N_overlap X 2
    # the first column will be the indices of the vision model, and the second is their
    # corresponding index in the geo model
    taxon_map = np.ones((vision_taxa.shape[0], 2), dtype=np.int32)*-1
    taxon_map[:, 0] = np.arange(vision_taxa.shape[0])
    geo_taxa_arr = np.array(geo_taxa)
    for tt_id, tt in enumerate(vision_taxa):
        ind = np.where(geo_taxa_arr==tt)[0]
        if len(ind) > 0:
            taxon_map[tt_id, 1] = ind[0]
    inds = np.where(taxon_map[:, 1]>-1)[0]
    taxon_map = taxon_map[inds, :]
    return taxon_map

def convert_to_inat_vision_order(geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map):
        # this is slow as we turn the sparse input back into the same size as the dense one
        vision_pred = np.zeros((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
        geo_pred = np.ones((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
        vision_pred[np.arange(vision_pred.shape[0])[..., np.newaxis], vision_top_k_inds] = vision_top_k_prob

        geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]

        return geo_pred, vision_pred



In [153]:
with open('paths.json', 'r') as f:
            paths = json.load(f)
# load vision model predictions:
data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
print(data['probs'].shape[0], 'total test observations')
# load locations:
meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
obs_locs  = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)
# taxonomic mapping:
taxon_map = find_mapping_between_models(data['model_to_taxa'], train_params['params']['class_to_taxa'])
print(taxon_map.shape[0], 'out of', len(data['model_to_taxa']), 'taxa in both vision and geo models')

282974 total test observations
44877 out of 55378 taxa in both vision and geo models


In [154]:
data['model_to_taxa']

array([461052, 891919, 401834, ..., 321494, 140477,  68764])

In [155]:
if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
    raster = datasets.load_env()
else:
    raster = None
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

In [156]:
freq_df = pd.read_csv('../data/train/geo_prior_train.csv')
counts = freq_df.taxon_id.value_counts()

In [None]:
def norm_cal_activation(preds,N,gamma):
    ac = N**gamma
    return (np.exp(preds)/ ac) / ((np.exp(preds) / ac) + 1)

In [159]:
train_params['params']['class_to_taxa'][0]

7

In [163]:
for i in range(raw_preds.shape[1]):
    id = train_params['params']['class_to_taxa'][i]
    print(counts[id])

5115
432
313
60
173
1091
161
5444
69
528
107
240
348
140
151
191
70
285
202
304
161
3398
305
2365
1619
134
65
564
75
154
84
54
147
3746
677
202
305
270
448
37491
352
459
22370
940
1241
1123
119
697
3750
73
95
54
81
83
51
197
1930
59
1258
83
1406
53
103
1579
5067
35864
478
674
636
527
10156
540
54
153
564
65
422
227
59
66
51
165
4997
60
51
1053
320
2071
138
152
55
54
342
350
3755
9771
241
1096
66
1959
127
144
64
1016
54
54
86
75
62
98
183
348
69
91
204
113
426
72
105
168
62
124
750
359
288
87
308
1013
61
449
157
128
231
299
340
90
336
3540
502
61
389
157
72
118
722
193
72
56
192
103
3065
67
144
141
403
988
480
3838
287
656
2716
5172
1687
65
10214
416
299
491
2655
86
244
728
597
127
865
63
234
51
177
173
211
124
1171
97
63
507
285
865
3076
135
72
170
78
107
1664
388
57
459
75
1778
99
55
937
4466
99
222
398
1209
207
218
151
1713
1616
145
249
62
396
753
130
51
71
249
1320
888
369
1188
462
327
559
105
616
65
215
527
6690
244
101
1815
582
19475
3238
87
230
54
7812
86
1331
3955
81
505
1700
61

In [183]:
K = len(presence_absence)
IF = len(train_df_h3[train_df_h3.label==0].index.value_counts())
IIF = np.log(K/IF)

In [184]:
IIF

5.201911636243623

In [187]:
K = len(presence_absence)

def calibrate_predictions(raw_preds):
    calibrated_preds = raw_preds.detach().clone()

    # Iterate over each column of raw_preds
    for col_idx in range(raw_preds.shape[1]):
        col_value = raw_preds[:, col_idx]

        IF = len(train_df_h3[train_df_h3.label==col_idx].index.value_counts())
        IIF = np.log(K/IF)

        scaled_raw_preds = col_value*IIF
        col_preds = torch.sigmoid(scaled_raw_preds).cpu().numpy()

        calibrated_preds[:, col_idx] = torch.tensor(col_preds)
    return calibrated_preds

In [147]:
results = {}

# loop over in batches
batch_start = np.hstack((np.arange(0, data['probs'].shape[0], 2048), data['probs'].shape[0]))
correct_pred = np.zeros(data['probs'].shape[0])

for bb_id, bb in tqdm(enumerate(range(len(batch_start)-1)), total=len(batch_start)-1):
    batch_inds = np.arange(batch_start[bb], batch_start[bb+1])

    vision_probs = data['probs'][batch_inds, :]
    vision_inds = data['inds'][batch_inds, :]
    gt = data['labels'][batch_inds]

    obs_locs_batch = torch.from_numpy(obs_locs[batch_inds, :]).to('cpu')
    loc_feat = enc.encode(obs_locs_batch)

    with torch.no_grad():
        geo_pred = model(loc_feat).cpu().numpy()
        lolo = geo_pred

    geo_pred, vision_pred = convert_to_inat_vision_order(geo_pred, vision_probs, vision_inds,
                                                            data['model_to_taxa'], taxon_map)

    comb_pred = np.argmax(vision_pred*geo_pred, 1)
    comb_pred = (comb_pred==gt)
    correct_pred[batch_inds] = comb_pred
    if bb_id== 20: break
    break

results['vision_only_top_1'] = float((data['inds'][:, -1] == data['labels']).mean())
results['vision_geo_top_1'] = float(correct_pred.mean())


  0%|          | 0/139 [00:07<?, ?it/s]


In [149]:
lolo.shape

(2048, 47375)

In [108]:
results

{'vision_only_top_1': 0.7540940157046231,
 'vision_geo_top_1': 0.005898068373772856}

In [188]:
results = {}

# loop over in batches
batch_start = np.hstack((np.arange(0, data['probs'].shape[0], 2048), data['probs'].shape[0]))
correct_pred = np.zeros(data['probs'].shape[0])

for bb_id, bb in tqdm(enumerate(range(len(batch_start)-1)),total=len(batch_start)-1):
    batch_inds = np.arange(batch_start[bb], batch_start[bb+1])

    vision_probs = data['probs'][batch_inds, :]
    vision_inds = data['inds'][batch_inds, :]
    gt = data['labels'][batch_inds]

    obs_locs_batch = torch.from_numpy(obs_locs[batch_inds, :]).to('cpu')
    loc_feat = enc.encode(obs_locs_batch)

    with torch.no_grad():
        loc_emb = model(loc_feat, return_feats=True)
        wt = model.class_emb.weight.detach().clone()
        wt.requires_grad = False
    raw_preds = torch.matmul(loc_emb, wt.T)
    geo_pred = calibrate_predictions(raw_preds)

    lala = geo_pred

    geo_pred, vision_pred = convert_to_inat_vision_order(geo_pred, vision_probs, vision_inds,
                                                            data['model_to_taxa'], taxon_map)

    comb_pred = np.argmax(vision_pred*geo_pred, 1)
    comb_pred = (comb_pred==gt)
    correct_pred[batch_inds] = comb_pred
    if bb_id== 20: break
    break
    

results['vision_only_top_1'] = float((data['inds'][:, -1] == data['labels']).mean())
results['vision_geo_top_1'] = float(correct_pred.mean())

  0%|          | 0/139 [05:11<?, ?it/s]


KeyboardInterrupt: 

In [97]:
def report(results):
        print('Overall accuracy vision only model', round(results['vision_only_top_1'], 3))
        print('Overall accuracy of geo model     ', round(results['vision_geo_top_1'], 3))
        print('Gain                              ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))


In [98]:
report(results)

Overall accuracy vision only model 0.754
Overall accuracy of geo model      0.12
Gain                               -0.634
