# 01 - Adjusting Predictions Using Country (SnakeCLEF-Reduced)

In [1]:
import os

# Switch to the parent directory so that the `src` package imported below and the
# relative data/prediction paths resolve (assumes `src` lives in the repository root).
# Guarded so that re-running this cell, or starting the kernel in the root already,
# does not climb one more directory up.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.core import training, metrics
from src.utils import io

# directories (relative to the working directory) holding the cached
# prediction arrays and the dataset files
DATA_DIR = 'data/snake_clef2021_dataset/'
PREDICTIONS_DIR = 'predictions/'

# file with the target labels, stored next to the prediction files
TARG_FILE = 'snake_mini_targ.npy'

# display name of each evaluated model -> file name of its saved predictions
# (insertion order is the order in which models are processed and reported)
PRED_FILES = dict([
    ('EfficientNet-B0', 'snake_mini_efficientnet_b0_pred.npy'),
    ('ViT-Base-224', 'snake_mini_vit_base_224_pred.npy'),
    # ---
    ('EfficientNet-B4', 'snake_mini_efficientnet_b4_pred.npy'),
    ('NoisyStudent-B4', 'snake_mini_efficientnet_b4_ns_pred.npy'),
    ('EfficientNetV2-S', 'snake_mini_efficientnetv2_s_pred.npy'),
    # ---
    ('ViT-Base-384', 'snake_mini_vit_base_384_pred.npy'),
    ('DeiT-Base-384', 'snake_mini_deit_base_384_pred.npy'),
    ('BEiT-Base-384', 'snake_mini_beit_base_384_pred.npy'),
    # ---
    ('ViT-Large-384', 'snake_mini_vit_large_384_pred.npy'),
])

## Load the Data

In [3]:
# read the cleaned metadata table of the evaluation split
metadata_path = DATA_DIR + 'SnakeCLEF2021_test_metadata_cleaned.csv'
valid_df = pd.read_csv(metadata_path)

# the sorted unique values of the 'binomial' column define the label set
classes = np.unique(valid_df['binomial'])
no_classes = len(classes)

# quick sanity summary of what was loaded
print(f'No classes: {no_classes}')
print(f'Test set length: {len(valid_df):,d}')

No classes: 770
Test set length: 26,227


In [4]:
# sorted unique species and countries found in the evaluation metadata;
# records without a country are grouped under the 'unknown' label
species = np.unique(valid_df['binomial'])
country_col = valid_df['country'].fillna('unknown')
countries = np.unique(country_col)

# country-species table -> weights used by the country-aware F1 metric
# (the cleaner is called with missing_val=0 here; presumably that is the weight given
#  to pairs absent from the table - confirm in src.core.metrics)
country_map_df = pd.read_csv(DATA_DIR + 'species_to_country_mapping.csv', index_col=0)
country_weights = metrics.clean_country_map(country_map_df, species, missing_val=0)
country_f1_score = metrics.CountryF1Score(country_weights)

# second weight table, used to rescale predictions: the columns are first renamed
# through the country look-up table, then cleaned against the countries seen in the
# metadata, this time with missing_val=1
country_lut = io.read_json(DATA_DIR + 'country_lut.json')
renamed_map_df = country_map_df.rename(columns=country_lut)
country_weights_adj = metrics.clean_country_map(
    renamed_map_df, species, countries, missing_val=1)

## Compute Predictions

In [5]:
if not os.path.isdir(PREDICTIONS_DIR):
    os.mkdir(PREDICTIONS_DIR)

# compute predictions
if not all([os.path.isfile(PREDICTIONS_DIR + x) for x in PRED_FILES.values()]):
    !sh test_snake.sh

## Load Predictions

In [6]:
from tqdm import tqdm

def get_scores(pred, targ):
    """Score one set of predictions against the targets.

    Returns whatever `training.classification_scores(pred, targ)` produces, with an
    extra 'country_f1_score' entry computed from the arg-max over axis 1 of `pred`.
    """
    scores = training.classification_scores(pred, targ)
    top1_pred = pred.argmax(1)
    scores['country_f1_score'] = country_f1_score(top1_pred, targ)
    return scores


# load target file
targ = np.load(PREDICTIONS_DIR + TARG_FILE)

# Per-sample country weights for adjusting predictions. They depend only on the
# metadata, not on the model, so they are built once here instead of once per model
# inside the loop (the lookup creates one column per sample and is not cheap).
# Records without a country use the 'unknown' column. After the transpose there is
# one row per sample; columns presumably follow the species order of
# `country_weights_adj` - verify against `metrics.clean_country_map`.
sample_countries = valid_df['country'].fillna('unknown')
bin_map = country_weights_adj.loc[:, sample_countries].values.T

# load prediction file of each model and compute scores
scores_dict = {}
adj_scores_dict = {}
for model_name, pred_file in tqdm(PRED_FILES.items()):
    # load prediction file
    pred = np.load(PREDICTIONS_DIR + pred_file)
    assert pred.shape == bin_map.shape, (
        f'{model_name}: predictions {pred.shape} do not match country weights {bin_map.shape}')

    # adjust predictions using country (element-wise rescaling of every score)
    pred_adj = pred * bin_map

    # compute scores for the raw and the country-adjusted predictions
    scores_dict[model_name] = get_scores(pred, targ)
    adj_scores_dict[model_name] = get_scores(pred_adj, targ)

100%|██████████| 9/9 [03:35<00:00, 23.91s/it]


## Evaluate Scores

In [7]:
# one row per model; columns become a two-level index of (metric, variant)
scores_df = pd.DataFrame.from_dict(scores_dict, orient='index')
adj_scores_df = pd.DataFrame.from_dict(adj_scores_dict, orient='index')
for frame, variant in [(scores_df, 'Original'), (adj_scores_df, 'Adjusted')]:
    frame.columns = pd.MultiIndex.from_product([frame.columns, [variant]])

# side-by-side table plus the gain obtained by the country adjustment
eval_df = pd.concat([scores_df, adj_scores_df], axis=1)
for met in ('accuracy', 'top_3', 'f1_score'):
    original, adjusted = eval_df[met, 'Original'], eval_df[met, 'Adjusted']
    eval_df[met, 'Diff'] = adjusted - original

In [8]:
def _format_signed(value):
    """Format a difference with an explicit sign and one decimal, e.g. 6.9 -> '+6.9'.

    Negative values keep their own sign ('-1.2'); missing values stay NaN.
    The previous expression prepended '+' unconditionally and applied
    `.replace('+', np.nan)` before the '+' was added, so negative values were
    rendered as '+-1.2' and missing values as a bare '+'.
    """
    if pd.isna(value):
        return np.nan
    # `+ 0.0` turns a rounded negative zero into 0.0, so it is shown as '+0.0'
    return f'{round(value, 1) + 0.0:+.1f}'


# metrics as percentages, with the 'Diff' columns rendered as signed strings
_df = eval_df[['accuracy', 'top_3', 'f1_score']].round(3) * 100
for met in ['accuracy', 'top_3', 'f1_score']:
    _df[met, 'Diff'] = _df[met, 'Diff'].map(_format_signed)
_df

Unnamed: 0_level_0,accuracy,accuracy,accuracy,top_3,top_3,top_3,f1_score,f1_score,f1_score
Unnamed: 0_level_1,Original,Adjusted,Diff,Original,Adjusted,Diff,Original,Adjusted,Diff
EfficientNet-B0,69.1,76.0,6.9,80.7,89.4,8.6,57.5,64.8,7.3
ViT-Base-224,71.5,75.9,4.4,82.8,87.2,4.4,64.7,70.3,5.6
EfficientNet-B4,73.5,79.5,6.0,89.2,94.4,5.2,68.7,74.1,5.4
NoisyStudent-B4,80.2,83.8,3.6,91.6,96.2,4.6,71.0,75.3,4.3
EfficientNetV2-S,79.4,84.1,4.7,92.1,95.6,3.5,70.3,75.0,4.7
ViT-Base-384,78.5,82.4,3.8,92.8,94.6,1.8,73.0,77.8,4.8
DeiT-Base-384,77.9,80.2,2.3,94.6,95.7,1.1,72.2,76.2,4.0
BEiT-Base-384,83.5,84.9,1.3,92.7,94.3,1.6,74.8,78.1,3.3
ViT-Large-384,84.7,85.9,1.2,94.6,96.4,1.7,76.4,79.4,3.1
