In [17]:
import sys

import torch
import pandas as pd
import numpy as np
import json
import os

import geopandas as gpd

import matplotlib.pyplot as plt

In [18]:
# Add the parent directory to the import search path before the three
# imports below (presumably project-local modules that live there — TODO confirm).
sys.path.append('../')
import datasets
import models
import utils


In [19]:
# Notebook-wide settings.
HIGH_RES, DISABLE_OCEAN_MASK, SET_MAX_CMAP_TO_1 = True, False, False
THRESHOLD = 0.5
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [20]:
# Continent polygons read from the shapefile.
kings_county_map = gpd.read_file('../other_data/continent-poly/Continents.shp')

# Pull each continent's geometry out by row position.
# NOTE(review): the row order is hard-coded — confirm it against the
# CONTINENT column of the shapefile.
continent_shapes = kings_county_map['geometry']
Africa = continent_shapes.iloc[0]
Asia = continent_shapes.iloc[1]
Australia = continent_shapes.iloc[2]
NorthAmerica = continent_shapes.iloc[3]
Oceania = continent_shapes.iloc[4]
SouthAmerica = continent_shapes.iloc[5]
Antarctica = continent_shapes.iloc[6]
Europe = continent_shapes.iloc[7]


In [21]:
# Restore the pretrained checkpoint and rebuild the network it describes.
# NOTE(review): torch.load unpickles the file — only point it at checkpoints you trust.
checkpoint_path = '../pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt'
train_params = torch.load(checkpoint_path, map_location='cpu')
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)
# Move the network to the selected device and put it in evaluation mode;
# the trailing expression makes the notebook echo the module summary.
model = model.to(DEVICE).eval()
model

ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [22]:
# Load the IUCN reference data from its JSON file.
iucn_json_path = os.path.join('../data/eval/iucn/', 'iucn_res_5.json')
with open(iucn_json_path, 'r') as iucn_file:
    data = json.load(iucn_file)

In [23]:
# Keys of the taxa_presence mapping, in file order.
species_ids = list(data['taxa_presence'])

In [24]:
# Sanity check: number of species keys found under data['taxa_presence'].
len(species_ids)

2418

In [25]:
# The environmental raster is only loaded for the encodings named below.
input_enc = train_params['params']['input_enc']
raster = datasets.load_env() if input_enc in ('env', 'sin_cos_env') else None
enc = utils.CoordEncoder(input_enc, raster=raster)

In [26]:
# Turn the reference locations into a float32 CPU tensor and encode them.
obs_locs = torch.from_numpy(np.asarray(data['locs'], dtype=np.float32)).to('cpu')
loc_feat = enc.encode(obs_locs)

In [27]:
# Display the encoded location features produced by enc.encode.
loc_feat

tensor([[-0.4897,  0.8307, -0.8719,  0.5566],
        [-0.6273,  0.9677, -0.7788,  0.2521],
        [-0.6769,  0.9661, -0.7361,  0.2583],
        ...,
        [-0.2567,  0.4512,  0.9665, -0.8924],
        [-0.5892,  0.7716,  0.8080, -0.6361],
        [-0.5621,  0.6373,  0.8271, -0.7706]])

In [28]:
# Map each species id to its position in the checkpoint's class_to_taxa list.
# The lookup table is built once so every species costs O(1), instead of
# calling list.index (a linear scan of class_to_taxa) once per species.
class_to_taxa = train_params['params']['class_to_taxa']
taxon_to_class = {}
for class_idx, taxon in enumerate(class_to_taxa):
    # setdefault keeps the first occurrence, matching list.index semantics.
    taxon_to_class.setdefault(taxon, class_idx)

missing_taxa = [tt for tt in species_ids if int(tt) not in taxon_to_class]
if missing_taxa:
    # list.index raised ValueError for an unknown taxon; keep that exception type.
    raise ValueError(f'taxa not present in class_to_taxa: {missing_taxa[:10]}')

# Both tensors are int64 and have one entry per species, in species_ids order.
taxa_ids = torch.tensor([int(tt) for tt in species_ids], dtype=torch.int64)
classes_of_interest = torch.tensor(
    [taxon_to_class[int(tt)] for tt in species_ids], dtype=torch.int64
)

In [29]:
with torch.no_grad():
    # The model may live on the GPU while loc_feat was built on the CPU.
    # Move the inputs to the model's own device so the forward pass does
    # not fail with a device-mismatch error when CUDA is available.
    model_device = model.class_emb.weight.device
    loc_emb = model(loc_feat.to(model_device), return_feats=True)
    # One row of the class-embedding matrix per species of interest.
    wt = model.class_emb.weight[classes_of_interest.to(model_device), :]

In [30]:
import geopandas as gpd
# Dataset that plot_map draws in light grey behind the scattered points.
# NOTE(review): gpd.datasets is deprecated in recent GeoPandas releases — confirm before upgrading.
naturalearth_path = gpd.datasets.get_path('naturalearth_lowres')
df = gpd.read_file(naturalearth_path)

def plot_map(taxa,preds,species_locs):
    """Plot reference locations next to thresholded predictions for one taxon.

    Args:
        taxa: Identifier shown in the figure title.
        preds: Per-location scores, compared against 0.02, 0.1 and 0.5.
            Presumably aligned with data['locs'] — TODO confirm.
        species_locs: Index (or mask) into data['locs'] selecting the
            reference locations drawn on the left panel.

    Reads the module-level `data` and `df`. Returns nothing.
    """
    all_locs = np.array(data['locs'])
    reference_locs = all_locs[species_locs]

    # One scatter layer per threshold, each with its own shade of blue.
    layer_styles = [(0.02, 'lightsteelblue'), (0.1, 'cornflowerblue'), (0.5, 'mediumblue')]
    layers = [(all_locs[preds > cutoff], colour) for cutoff, colour in layer_styles]

    fig, (iucn_ax, pred_ax) = plt.subplots(1, 2, figsize=(10,5))
    fig.suptitle("taxa " + str(taxa),  y=0.75)

    # Left panel: IUCN reference locations.
    df.plot(ax=iucn_ax, color="lightgray")
    iucn_ax.set_xticks([])
    iucn_ax.set_yticks([])
    iucn_ax.scatter(reference_locs[:,0], reference_locs[:,1], color='firebrick', s=10, alpha=0.05)

    # Right panel: locations whose score exceeds each threshold.
    df.plot(ax=pred_ax, color="lightgray")
    pred_ax.set_xticks([])
    pred_ax.set_yticks([])
    for layer_locs, colour in layers:
        pred_ax.scatter(layer_locs[:,0], layer_locs[:,1], color=colour, s=10, alpha=0.05)

    plt.tight_layout()


  df = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))


In [31]:
# All reference locations as a NumPy array, for index-based lookups.
locations = np.asarray(data['locs'])

In [64]:
from shapely.geometry import Point

# Number of highest-scoring locations examined per species.
TOP_K = 10

all_species_continents = []
for tt_id in range(len(taxa_ids)):
    # Score every location for this species.
    preds = torch.sigmoid(torch.matmul(loc_emb, wt[tt_id, :])).cpu().numpy()

    # Coordinates of the TOP_K highest-probability locations. Their order
    # is irrelevant because only containment is tested below.
    top_indices = np.argsort(preds)[-TOP_K:]
    points = [Point(lon, lat) for lon, lat in locations[top_indices]]

    # Keep every continent whose polygon contains at least one of those
    # points. The list stays empty when no polygon contains any of them.
    species_in_continent = [
        row.CONTINENT
        for _, row in kings_county_map.iterrows()
        if any(row.geometry.contains(p) for p in points)
    ]
    all_species_continents.append(species_in_continent)

In [66]:
# Display the per-species continent lists; an empty list means none of the
# continent polygons contained any of that species' selected points.
all_species_continents

[['South America'],
 ['South America'],
 ['Africa'],
 ['South America'],
 ['Australia'],
 ['Asia'],
 ['Africa'],
 ['North America'],
 ['Australia'],
 ['Africa'],
 ['South America'],
 ['Australia'],
 ['Africa'],
 ['North America'],
 ['South America'],
 ['North America'],
 ['South America'],
 ['South America'],
 ['South America'],
 ['South America'],
 ['Asia'],
 ['North America'],
 ['Australia'],
 ['Australia'],
 ['South America'],
 ['South America'],
 ['Australia'],
 ['Africa'],
 ['Europe'],
 ['Australia'],
 ['Australia'],
 ['Australia'],
 ['Africa'],
 [],
 ['South America'],
 ['Africa'],
 ['Asia'],
 ['South America'],
 ['Australia'],
 ['Australia'],
 [],
 ['South America'],
 ['South America'],
 [],
 ['Australia'],
 ['South America'],
 ['South America'],
 ['Africa'],
 ['South America'],
 ['Africa'],
 ['North America'],
 ['Australia'],
 ['Asia'],
 ['Africa'],
 ['Oceania'],
 ['South America'],
 ['Africa'],
 ['South America'],
 ['Africa'],
 ['South America'],
 ['Africa'],
 ['Australia'],
 

In [77]:
# One row per species: its id and the list of continents found for it.
species_continent_columns = {'speciesID': species_ids, 'continent': all_species_continents}
outdf = pd.DataFrame(species_continent_columns)


In [79]:
import pickle

# Persist the species-to-continent table. The output directory is created
# first so the write does not fail with FileNotFoundError when it is missing.
output_path = '../output/species_continents.pkl'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as fp:
    pickle.dump(outdf, fp)
