In [2]:
import sys

import torch
import pandas as pd
import numpy as np
import json
import os

import matplotlib.pyplot as plt

In [3]:
sys.path.append('../')
import datasets
import models
import utils


In [4]:
HIGH_RES = True
THRESHOLD = 0.5
DISABLE_OCEAN_MASK = False
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SET_MAX_CMAP_TO_1 = False

In [5]:
# load model
train_params = torch.load('../pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt', map_location='cpu')
model = models.get_model(train_params['params'])
model.load_state_dict(train_params['state_dict'], strict=True)
model = model.to(DEVICE)
model.eval()

ResidualFCNet(
  (class_emb): Linear(in_features=256, out_features=47375, bias=False)
  (feats): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (3): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=256, bias=True)
    )
    (4): ResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (dropout1): Dropout(p=0.5, inplace=False)
      (w1): Linear(in_features=256, out_features=256, bias=True)
      (w2): Linear(in_features=256, out_features=

In [6]:
#load reference from iucn
with open(os.path.join('../data/eval/iucn/', 'iucn_res_5.json'), 'r') as f:
            data = json.load(f)

In [7]:
species_ids = list((data['taxa_presence'].keys()))

In [8]:
len(species_ids)

2418

In [9]:
if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
    raster = datasets.load_env()
else:
    raster = None
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)

In [10]:
# Define the number of points
num_long_points = 8
num_lat_points = 16

In [11]:
# Generate list of tuples with evenly spaced points
lowres_locs = [[x, y] for x in range(-180, 180, 360 // num_lat_points) for y in range(-90, 91, 180 // num_long_points)]

In [12]:
len(lowres_locs)

153

In [13]:
obs_locs = np.array(lowres_locs, dtype=np.float32)
obs_locs = torch.from_numpy(obs_locs).to('cpu')
loc_feat = enc.encode(obs_locs)

In [14]:
obs_locs

tensor([[-1.0000, -1.0000],
        [-1.0000, -0.7556],
        [-1.0000, -0.5111],
        [-1.0000, -0.2667],
        [-1.0000, -0.0222],
        [-1.0000,  0.2222],
        [-1.0000,  0.4667],
        [-1.0000,  0.7111],
        [-1.0000,  0.9556],
        [-0.8778, -1.0000],
        [-0.8778, -0.7556],
        [-0.8778, -0.5111],
        [-0.8778, -0.2667],
        [-0.8778, -0.0222],
        [-0.8778,  0.2222],
        [-0.8778,  0.4667],
        [-0.8778,  0.7111],
        [-0.8778,  0.9556],
        [-0.7556, -1.0000],
        [-0.7556, -0.7556],
        [-0.7556, -0.5111],
        [-0.7556, -0.2667],
        [-0.7556, -0.0222],
        [-0.7556,  0.2222],
        [-0.7556,  0.4667],
        [-0.7556,  0.7111],
        [-0.7556,  0.9556],
        [-0.6333, -1.0000],
        [-0.6333, -0.7556],
        [-0.6333, -0.5111],
        [-0.6333, -0.2667],
        [-0.6333, -0.0222],
        [-0.6333,  0.2222],
        [-0.6333,  0.4667],
        [-0.6333,  0.7111],
        [-0.6333,  0

In [15]:
loc_feat

tensor([[ 8.7423e-08,  8.7423e-08, -1.0000e+00, -1.0000e+00],
        [ 8.7423e-08, -6.9466e-01, -1.0000e+00, -7.1934e-01],
        [ 8.7423e-08, -9.9939e-01, -1.0000e+00, -3.4900e-02],
        [ 8.7423e-08, -7.4314e-01, -1.0000e+00,  6.6913e-01],
        [ 8.7423e-08, -6.9756e-02, -1.0000e+00,  9.9756e-01],
        [ 8.7423e-08,  6.4279e-01, -1.0000e+00,  7.6604e-01],
        [ 8.7423e-08,  9.9452e-01, -1.0000e+00,  1.0453e-01],
        [ 8.7423e-08,  7.8801e-01, -1.0000e+00, -6.1566e-01],
        [ 8.7423e-08,  1.3917e-01, -1.0000e+00, -9.9027e-01],
        [-3.7461e-01,  8.7423e-08, -9.2718e-01, -1.0000e+00],
        [-3.7461e-01, -6.9466e-01, -9.2718e-01, -7.1934e-01],
        [-3.7461e-01, -9.9939e-01, -9.2718e-01, -3.4900e-02],
        [-3.7461e-01, -7.4314e-01, -9.2718e-01,  6.6913e-01],
        [-3.7461e-01, -6.9756e-02, -9.2718e-01,  9.9756e-01],
        [-3.7461e-01,  6.4279e-01, -9.2718e-01,  7.6604e-01],
        [-3.7461e-01,  9.9452e-01, -9.2718e-01,  1.0453e-01],
        

In [16]:
classes_of_interest = torch.zeros(len(species_ids), dtype=torch.int64)
taxa_ids = torch.zeros(len(species_ids), dtype=torch.int64)
for tt_id, tt in enumerate(species_ids):
    class_of_interest = np.array([train_params['params']['class_to_taxa'].index(int(tt))])
    classes_of_interest[tt_id] = torch.from_numpy(class_of_interest)
    taxa_ids[tt_id] = int(tt)

In [17]:
with torch.no_grad():
    loc_emb = model(loc_feat, return_feats=True)
    wt = model.class_emb.weight[classes_of_interest, :]

In [18]:
out_df = pd.DataFrame(columns=list(range(0,len(loc_emb))))

for tt_id, taxa in enumerate(taxa_ids):
    wt_1 = wt[tt_id,:]
    preds = torch.sigmoid(torch.matmul(loc_emb, wt_1)).cpu().numpy()
    out_df.loc[taxa.item()] = preds

In [19]:
out_df.to_csv('../output/lowres_pred.csv')

same but ocean mask

In [20]:
mask = np.load('../data/masks/ocean_mask.npy')
reduced_mask = mask[1:-1,2:-2]
reduced_mask.shape

(1000, 2000)

In [21]:
pooling_factor_rows = reduced_mask.shape[0] // num_long_points
pooling_factor_cols = reduced_mask.shape[1] // num_lat_points

# Reshape and apply average pooling
pooled_array = reduced_mask.reshape((num_long_points, pooling_factor_rows, num_lat_points, pooling_factor_cols)).mean(axis=(1, 3))

# Apply thresholding to convert average values to 0 or 1
threshold = 0.25
binary_pooled_array = (pooled_array > threshold).astype(int)
mask_inds = np.where(binary_pooled_array.reshape(-1) == 1)[0]

In [24]:
locs = utils.coord_grid(binary_pooled_array.shape)
locs = locs[mask_inds, :]

In [25]:
obs_locs = np.array(locs, dtype=np.float32)
obs_locs = torch.from_numpy(obs_locs).to('cpu')
loc_feat = enc.encode(obs_locs)

In [26]:
classes_of_interest = torch.zeros(len(species_ids), dtype=torch.int64)
taxa_ids = torch.zeros(len(species_ids), dtype=torch.int64)
for tt_id, tt in enumerate(species_ids):
    class_of_interest = np.array([train_params['params']['class_to_taxa'].index(int(tt))])
    classes_of_interest[tt_id] = torch.from_numpy(class_of_interest)
    taxa_ids[tt_id] = int(tt)

In [27]:
with torch.no_grad():
    loc_emb = model(loc_feat, return_feats=True)
    wt = model.class_emb.weight[classes_of_interest, :]

In [28]:
out_df = pd.DataFrame(columns=list(range(0,len(loc_emb))))

for tt_id, taxa in enumerate(taxa_ids):
    wt_1 = wt[tt_id,:]
    preds = torch.sigmoid(torch.matmul(loc_emb, wt_1)).cpu().numpy()
    out_df.loc[taxa.item()] = preds

In [29]:
out_df.to_csv('../output/lowres_oceanmask_pred.csv')

In [31]:
preds.shape

(48,)