In [70]:
import json
import os

import datasets
import numpy as np
import setup

In [3]:
train_params = {}

train_params[
    "experiment_name"
] = "test_all_training"  # This will be the name of the directory where results for this run are saved.

"""
species_set
- Which set of species to train on.
- Valid values: 'all', 'snt_birds'
"""
train_params["species_set"] = "all"

"""
hard_cap_num_per_class
- Maximum number of examples per class to use for training.
- Valid values: positive integers or -1 (indicating no cap).
"""
train_params["hard_cap_num_per_class"] = 1000

"""
num_aux_species
- Number of random additional species to add.
- Valid values: Nonnegative integers. Should be zero if params['species_set'] == 'all'.
"""
train_params["num_aux_species"] = 0

"""
input_enc
- Type of inputs to use for training.
- Valid values: 'sin_cos', 'env', 'sin_cos_env'
"""
train_params["input_enc"] = "sin_cos"

"""
loss
- Which loss to use for training.
- Valid values: 'an_full', 'an_slds', 'an_ssdl', 'an_full_me', 'an_slds_me', 'an_ssdl_me'
"""
train_params["loss"] = "an_full"

params = setup.get_default_params_train(overrides=train_params)

In [73]:
with open("paths.json", "r") as f:
    paths = json.load(f)

Read in the S&T taxa ids

In [4]:
data_dir = paths["train"]
obs_file = os.path.join(data_dir, params["obs_file"])
taxa_file = os.path.join(data_dir, params["taxa_file"])
taxa_file_snt = os.path.join(data_dir, "taxa_subsets.json")

with open(taxa_file_snt, "r") as f:  #
    taxa_subsets = json.load(f)
taxa_of_interest = list(taxa_subsets["snt_birds"])

[7,
 162,
 243,
 316,
 473,
 519,
 542,
 804,
 831,
 890,
 906,
 981,
 1280,
 1392,
 1399,
 1406,
 1409,
 1907,
 1960,
 1965,
 1986,
 2090,
 2191,
 2548,
 2552,
 2599,
 2669,
 2938,
 2969,
 3017,
 3108,
 3117,
 3280,
 3454,
 3460,
 3544,
 3545,
 3759,
 3849,
 3857,
 3862,
 3864,
 3869,
 3873,
 3875,
 3892,
 3893,
 3896,
 3901,
 3906,
 3931,
 3936,
 3938,
 3962,
 4246,
 4345,
 4364,
 4368,
 4381,
 4385,
 4399,
 4449,
 4450,
 4457,
 4496,
 4512,
 4626,
 4627,
 4647,
 4665,
 4672,
 4793,
 4798,
 4817,
 4835,
 4836,
 4838,
 4840,
 4857,
 4885,
 4892,
 4893,
 4937,
 4938,
 4940,
 4956,
 4981,
 4999,
 5009,
 5020,
 5034,
 5063,
 5097,
 5108,
 5112,
 5180,
 5196,
 5206,
 5212,
 5227,
 5268,
 5277,
 5305,
 5349,
 5355,
 5416,
 6317,
 6359,
 6363,
 6369,
 6432,
 6433,
 6553,
 6557,
 6571,
 6638,
 6643,
 6893,
 6899,
 6915,
 6917,
 6921,
 6930,
 6933,
 6937,
 7019,
 7024,
 7087,
 7089,
 7107,
 7170,
 7266,
 7294,
 7347,
 7363,
 7428,
 7429,
 7458,
 7462,
 7464,
 7470,
 7486,
 7493,
 7497,
 7498,

In [11]:
len(taxa_of_interest)

540

Read in the S&T evaluation data to see what that looks like

In [22]:
D = np.load(os.path.join(paths["snt"], "snt_res_5.npy"), allow_pickle=True)
D = D.item()

In [24]:
D.keys()

dict_keys(['loc_indices_per_species', 'labels_per_species', 'taxa', 'obs_locs', 'obs_locs_idx'])

In [41]:
print(len(D["loc_indices_per_species"]))
print(D["obs_locs"].shape)

535
(554882, 2)


In [40]:
len(D["loc_indices_per_species"][0])

174746

## Rarify the data for a set of S&T species
First the inat data for S&T species

In [5]:
locs, labels, _, _, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)


Loading  data/train/geo_prior_train.csv
30402074 observation(s) out of 35500262 from different taxa removed
Number of unique classes 536


There are 536 S&T species with inat data here
select 50 species to rarify at random 

In [10]:
len(np.unique(labels))

536

In [16]:
rare_spps = np.random.choice(labels, size=50, replace=False)
rare_spps

array([  3938,   3460,   1965,  10094,   4940,    162,  72458, 117016,
         3892,   9325,   7493,   9607,  12024,   3460,   9602, 144491,
         4381,   4940,   9943,   4938,   6921,   4937,  14995,   3017,
        12937,  10247,  13695,   8229,   5212,  14898,  12942,   5212,
       979756, 145308,   7019,   3857,   9744,   7513, 199840,   9185,
         9083,   9607,  13858,   5112,   9100,   9092,  72486, 144455,
        13851,  13000])

In [17]:
# check how many data points each species has
counts = np.zeros(len(rare_spps))
for i, s in enumerate(rare_spps):
    counts[i] = len(np.where(labels == s)[0])

counts

array([12250., 22351.,  3838., 48527., 40093.,  3398., 19072.,  7983.,
       17281.,  6735., 14361., 34352., 11521., 22351., 19000., 17450.,
        6864., 40093.,  4979., 11920., 29506., 16309., 18596., 24748.,
       16362., 10554., 23613., 37265., 80913.,  9734., 31594., 80913.,
        4238., 27707.,  7481.,  9320., 58504., 19359., 73409., 12190.,
       74243., 34352., 91261., 30554., 57649., 10299., 13283., 80010.,
       21693., 12366.])

In [18]:
locs_2 = locs.copy()
labels_2 = labels.copy()

In [51]:
# Now cut out data from the inat dataset for these species
labels_rare = labels.copy()
locs_rare = locs.copy()
# del_idx = []

for s in rare_spps:
    idx = np.where(labels_rare == s)[0]
    del_idx = np.random.choice(idx, size=len(idx) - 50, replace=False)
    locs_rare = np.delete(locs_rare, del_idx, axis=0)
    labels_rare = np.delete(labels_rare, del_idx)

# del_idx = np.concatenate( del_idx, axis=0 )
# len(del_idx)

In [49]:
counts = np.zeros(len(rare_spps))
for i, s in enumerate(rare_spps):
    counts[i] = len(np.where(labels_rare == s)[0])

counts

array([50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
       50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
       50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.,
       50., 50., 50., 50., 50., 50., 50., 50., 50., 50., 50.])

In [52]:
# Check that the arrays are of equal size)
print(len(locs), len(labels))
print(len(locs_rare), len(labels_rare))

5098188 5098188
3895753 3895753


### Write out the rarified species data

In [68]:
# map(np.array, zip(*locs_rare))


array([[-1.48946518e+02,  6.35737267e+01,  9.79757000e+05],
       [-1.35445007e+02,  5.92358322e+01,  5.30500000e+03],
       [-7.15302658e+01,  4.13766479e+01,  1.44530000e+05],
       ...,
       [-8.07600937e+01,  2.87953091e+01,  4.89200000e+03],
       [-1.00342346e+02,  4.43662376e+01,  6.93000000e+03],
       [-8.28303986e+01,  2.80640793e+01,  1.67980000e+04]])

In [69]:
rare_concat = np.hstack([locs_rare, np.reshape(labels_rare, (len(labels_rare), 1))])

idx = []
for s in rare_spps:
    idx.append(np.where(labels_rare == s)[0])

idx = np.concatenate(idx, axis=0)
rare_concat_subset = rare_concat[idx, :]
rare_concat_subset.shape

(2500, 3)

In [75]:
np.savetxt(os.path.join(paths["rarefied"], "SNT_test.csv"), rare_concat_subset, delimiter=",")