In [1]:
import sys
sys.path.append('../scripts')
from spatioformer import SpatioformerModel
from cnn import CNNModel
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import matplotlib.pyplot as plt
import pickle
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import scipy
from scipy.stats import gaussian_kde
import os

In [2]:
class MyDataset(Dataset):
    """Map-style dataset over a pickled table of image samples.

    The pickled object is indexed with ``.loc`` and must provide the columns
    ``'Image'``, ``'Longitude'``, ``'Latitude'`` and ``'Richness'``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at trusted files.
    """

    def __init__(self, pickle_dir='../data_to_release/samples_to_release.pkl'):
        """Load the whole sample table into memory.

        Args:
            pickle_dir: Path to the pickle *file* (despite the name, not a
                directory) containing the sample table.
        """
        with open(pickle_dir, 'rb') as f:
            self.imgs = pickle.load(f)

    def __len__(self):
        """Return the number of rows in the sample table."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, richness, lon, lat)`` for the row labelled ``idx``.

        ``idx`` is looked up with ``.loc`` (label-based), so it must be an index
        label of the table rather than a positional offset -- presumably the
        split indices are labels; confirm against the split file.

        Returns:
            image: float32 tensor built from the row's ``'Image'`` array.
            richness: ``np.float32`` scalar target.
            lon: float32 tensor built from the row's ``'Longitude'`` array.
            lat: float32 tensor built from the row's ``'Latitude'`` array.
        """
        row = self.imgs.loc[idx]
        image = torch.from_numpy(np.asarray(row['Image']).astype('float32'))
        # np.float32(...) accepts NumPy scalars and plain Python numbers alike;
        # ``.astype`` exists only on the former, so this also works when
        # 'Richness' holds Python ints/floats (e.g. an object-dtype column).
        richness = np.float32(row['Richness'])
        lon = torch.from_numpy(np.asarray(row['Longitude']).astype('float32'))
        lat = torch.from_numpy(np.asarray(row['Latitude']).astype('float32'))
        return image, richness, lon, lat

    
def get_dataloaders(
        batch_size=2048,
        num_workers=None,
        split_file='../data_to_release/split_to_release.pkl',
        data_file='../data_to_release/samples_to_release.pkl',
        ):
    """Build train/val/test DataLoaders over ``MyDataset`` from a saved split.

    Args:
        batch_size: Samples per batch for all three loaders.
        num_workers: Worker processes per loader. ``None`` (the default) uses
            ``os.cpu_count()``, falling back to 0 (load in the main process)
            when the CPU count cannot be determined.
        split_file: Pickle holding a mapping with ``'train'``, ``'val'`` and
            ``'test'`` index collections.
        data_file: Pickle with the sample table, passed to ``MyDataset``.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Each loader draws its
        split's indices in random order via ``SubsetRandomSampler``.
    """
    if num_workers is None:
        # os.cpu_count() may return None, which DataLoader rejects.
        num_workers = os.cpu_count() or 0

    dataset = MyDataset(data_file)

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
    with open(split_file, 'rb') as f:
        split = pickle.load(f)

    return tuple(
        DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=SubsetRandomSampler(split[name]),
        )
        for name in ('train', 'val', 'test')
    )

train_loader, val_loader, test_loader = get_dataloaders()

In [3]:
%%time
# Spatioformer - with geolocational encoding
# For each input patch size, load the matching pretrained model, run it over
# the test split and print a dict of regression metrics.
# After the loop, `net`, `true_spatioformer` and `predicted_spatioformer` hold
# the results for the last size in the list (9).
for input_size in [1, 3, 5, 7, 9]:

    print(f'input image size is {input_size}')

    net = SpatioformerModel(device='cpu', if_encode=True, patchsize=input_size)
    # NOTE(review): the file is loaded as a state dict; on torch >= 1.13 consider
    # torch.load(..., weights_only=True) to avoid unpickling arbitrary objects.
    net.load_state_dict(torch.load(f'../models/spatioformer/diff_insize/input_size_{input_size}/model.pth', map_location=torch.device('cpu')))
    net.eval()

    true_spatioformer = []
    predicted_spatioformer = []

    # An input_size-wide window starting at 4 - input_size // 2, i.e. centred on
    # index 4 along axes 1 and 2 for odd sizes (size 9 covers indices 0..8).
    crop_start = 4 - input_size // 2
    crop_end = crop_start + input_size
    # Keep the batch axis whole, crop axes 1 and 2; trailing axes are untouched.
    window = (slice(None), slice(crop_start, crop_end), slice(crop_start, crop_end))

    # Pure inference: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for array, richness, lon, lat in test_loader:
            predicted = net(array[window], lon[window], lat[window]).squeeze(-1)
            true_spatioformer.extend(richness.numpy().tolist())
            predicted_spatioformer.extend(predicted.numpy().tolist())

    # Coefficient of correlation (r)
    correlation_coefficient = np.corrcoef(true_spatioformer, predicted_spatioformer)[0, 1]
    # Coefficient of determination (r2)
    coefficient_of_determination = r2_score(true_spatioformer, predicted_spatioformer)
    # Mean Absolute Error (MAE)
    mae = mean_absolute_error(true_spatioformer, predicted_spatioformer)
    # Relative Absolute Error (RAE), computed here as MAE / mean(|y|).
    # NOTE(review): the textbook RAE is sum|y - y_hat| / sum|y - mean(y)|;
    # confirm which definition the write-up intends.
    rae = mae / np.mean(np.abs(true_spatioformer))
    # Mean Squared Error (MSE)
    mse = mean_squared_error(true_spatioformer, predicted_spatioformer)
    # Relative Squared Error (RSE), computed here as MSE / mean(y^2).
    # NOTE(review): the textbook RSE is sum(y - y_hat)^2 / sum(y - mean(y))^2,
    # which equals 1 - r2; confirm which definition the write-up intends.
    rse = mse / np.mean(np.square(true_spatioformer))
    # Root Mean Squared Error (RMSE)
    rmse = np.sqrt(mse)

    spatioformer_metrics = {
        'Coefficient of correlation (r)': correlation_coefficient,
        'Coefficient of determination (r2)': coefficient_of_determination,
        'Mean Absolute Error (MAE)': mae,
        'Relative Absolute Error (RAE)': rae,
        'Mean Squared Error (MSE)': mse,
        'Relative Squared Error (RSE)': rse,
        'Root Mean Squared Error (RMSE)': rmse,
    }
    # float() first: under NumPy >= 2 a NumPy scalar inside a dict prints as
    # 'np.float64(0.68)' instead of '0.68'.
    print({name: round(float(value), 2) for name, value in spatioformer_metrics.items()})

input image size is 1
{'Coefficient of correlation (r)': 0.68, 'Coefficient of determination (r2)': 0.46, 'Mean Absolute Error (MAE)': 9.02, 'Relative Absolute Error (RAE)': 0.32, 'Mean Squared Error (MSE)': 140.56, 'Relative Squared Error (RSE)': 0.13, 'Root Mean Squared Error (RMSE)': 11.86}
input image size is 3
{'Coefficient of correlation (r)': 0.69, 'Coefficient of determination (r2)': 0.48, 'Mean Absolute Error (MAE)': 8.87, 'Relative Absolute Error (RAE)': 0.32, 'Mean Squared Error (MSE)': 135.1, 'Relative Squared Error (RSE)': 0.13, 'Root Mean Squared Error (RMSE)': 11.62}
input image size is 5
{'Coefficient of correlation (r)': 0.7, 'Coefficient of determination (r2)': 0.49, 'Mean Absolute Error (MAE)': 8.74, 'Relative Absolute Error (RAE)': 0.31, 'Mean Squared Error (MSE)': 132.9, 'Relative Squared Error (RSE)': 0.13, 'Root Mean Squared Error (RMSE)': 11.53}
input image size is 7
{'Coefficient of correlation (r)': 0.72, 'Coefficient of determination (r2)': 0.5, 'Mean Absolut