# 01 - Adjusting Predictions Using Country (SnakeCLEF)

In [1]:
import os

# Run from the project root: later cells import `src.*` and use root-relative
# paths such as 'data/...' and 'predictions/'.
# Guarded so that re-running this cell does not climb one directory further up
# each time; a `src` directory in the working directory is taken as the sign
# that we are already at the root.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.core import training, metrics
from src.utils import io

# Directories are relative to the working directory and keep their trailing
# slash because file names are appended by plain string concatenation.
PREDICTIONS_DIR = 'predictions/'
DATA_DIR = 'data/snake_clef2021_dataset/'

# (display name, identifier embedded in that model's prediction file name)
_MODEL_SLUGS = [
    ('EfficientNet-B0', 'efficientnet_b0'),
    ('ViT-Base-224', 'vit_base_224'),

    ('EfficientNet-B4', 'efficientnet_b4'),
    ('NoisyStudent-B4', 'efficientnet_b4_ns'),
    ('EfficientNetV2-S', 'efficientnetv2_s'),

    ('ViT-Base-384', 'vit_base_384'),
    ('DeiT-Base-384', 'deit_base_384'),
    ('BEiT-Base-384', 'beit_base_384'),

    ('ViT-Large-384', 'vit_large_384'),
]

# Display name -> prediction file name inside PREDICTIONS_DIR. Insertion order
# is the order in which the models are evaluated and reported.
PRED_FILES = {name: f'snake_full_{slug}_pred.npy' for name, slug in _MODEL_SLUGS}
# Ground-truth targets, stored next to the prediction files.
TARG_FILE = 'snake_full_targ.npy'

## Load the Data

In [3]:
# Read the cleaned SnakeCLEF 2021 test-set metadata.
metadata_path = DATA_DIR + 'SnakeCLEF2021_test_metadata_cleaned.csv'
valid_df = pd.read_csv(metadata_path)

# Distinct values of the `binomial` (species name) column, sorted by np.unique.
classes = np.unique(valid_df['binomial'])
no_classes = len(classes)

# Quick sanity summary of what was loaded.
print(f'No classes: {no_classes}')
print(f'Test set length: {len(valid_df):,d}')

No classes: 770
Test set length: 26,227


In [4]:
# Label sets observed in the test metadata; samples without a country are
# grouped under the placeholder 'unknown'.
species = np.unique(valid_df['binomial'])
countries = np.unique(valid_df['country'].fillna('unknown'))

# Species/country table shared by the metric and the adjustment weights.
country_map_df = pd.read_csv(
    DATA_DIR + 'species_to_country_mapping.csv',
    index_col=0,
)

# Country-aware F1 metric built from the cleaned table (missing_val=0).
country_weights = metrics.clean_country_map(country_map_df, species, missing_val=0)
country_f1_score = metrics.CountryF1Score(country_weights)

# Weights used to adjust predictions: the table's columns are first renamed
# through country_lut, then cleaned against `species` and `countries`.
# NOTE(review): presumably missing_val=1 makes unknown species/country pairs
# neutral when predictions are multiplied by these weights — confirm in
# metrics.clean_country_map.
country_lut = io.read_json(DATA_DIR + 'country_lut.json')
_renamed_map_df = country_map_df.rename(columns=country_lut)
country_weights_adj = metrics.clean_country_map(
    _renamed_map_df, species, countries, missing_val=1)

## Compute Predictions

In [6]:
if not os.path.isdir(PREDICTIONS_DIR):
    os.mkdir(PREDICTIONS_DIR)

# compute predictions
if not all([os.path.isfile(PREDICTIONS_DIR + x) for x in PRED_FILES.values()]):
    !sh test_snake.sh

## Load Predictions

In [7]:
from tqdm import tqdm

def get_scores(pred, targ):
    """Return the classification scores for `pred` plus the country F1 score.

    Args:
        pred: per-class prediction scores; the predicted class is taken as the
            argmax over axis 1.
        targ: ground-truth labels, passed unchanged to both scorers.

    Returns:
        The mapping produced by `training.classification_scores`, with an
        extra 'country_f1_score' entry added to it.
    """
    results = training.classification_scores(pred, targ)
    top1 = pred.argmax(1)
    results['country_f1_score'] = country_f1_score(top1, targ)
    return results


# load target file
targ = np.load(PREDICTIONS_DIR + TARG_FILE)

# Per-sample country weights. They depend only on the metadata, not on the
# model, so they are built once here instead of once per model.
sample_countries = valid_df['country'].fillna('unknown')
# One weight column is selected per sample, then transposed so that rows
# follow the samples.
# NOTE(review): assumes the rows of valid_df are in the same order as the rows
# of the prediction files, and that the weight rows follow the class order of
# the predictions — confirm.
bin_map = country_weights_adj.loc[:, sample_countries].values.T

# load prediction file of each model and compute scores
scores_dict = {}
adj_scores_dict = {}
for model_name, pred_file in tqdm(PRED_FILES.items()):
    # load prediction file
    pred = np.load(PREDICTIONS_DIR + pred_file)

    # adjust predictions using country
    pred_adj = pred * bin_map

    # compute scores
    scores_dict[model_name] = get_scores(pred, targ)
    adj_scores_dict[model_name] = get_scores(pred_adj, targ)

100%|██████████| 9/9 [01:33<00:00, 10.42s/it]


## Evaluate Scores

In [8]:
def _scores_frame(scores_by_model, variant):
    """Build a models x metrics frame whose columns are (metric, variant) pairs."""
    frame = pd.DataFrame.from_dict(scores_by_model, orient='index')
    frame.columns = pd.MultiIndex.from_product([frame.columns, [variant]])
    return frame


# One frame per variant, placed side by side (rows are model names).
scores_df = _scores_frame(scores_dict, 'Original')
adj_scores_df = _scores_frame(adj_scores_dict, 'Adjusted')
eval_df = pd.concat([scores_df, adj_scores_df], axis=1)

# Gain from the country adjustment for the reported metrics.
for met in ('accuracy', 'top_3', 'f1_score'):
    original = eval_df[met, 'Original']
    adjusted = eval_df[met, 'Adjusted']
    eval_df[met, 'Diff'] = adjusted - original

In [9]:
# Scores of the reported metrics as percentages, for display.
_df = eval_df[['accuracy', 'top_3', 'f1_score']].round(3) * 100


def _format_diff(value):
    """Render a score difference with an explicit sign, e.g. '+1.2' or '-0.2'."""
    # `+ 0.0` turns a rounded negative zero into plain zero, so it prints as
    # '+0.0' rather than '-0.0'.
    return f'{round(value, 1) + 0.0:+.1f}'


# The sign comes from the format spec instead of an unconditional '+' prefix,
# which rendered negative differences as '+-0.2'. Missing values stay NaN.
for met in ['accuracy', 'top_3', 'f1_score']:
    _df[met, 'Diff'] = _df[met, 'Diff'].map(_format_diff, na_action='ignore')
_df

Unnamed: 0_level_0,accuracy,accuracy,accuracy,top_3,top_3,top_3,f1_score,f1_score,f1_score
Unnamed: 0_level_1,Original,Adjusted,Diff,Original,Adjusted,Diff,Original,Adjusted,Diff
EfficientNet-B0,85.1,88.0,2.9,94.2,95.4,+1.2,70.6,75.9,5.3
ViT-Base-224,87.5,88.6,1.0,94.3,96.5,+2.2,75.5,79.1,3.6
EfficientNet-B4,91.4,92.2,0.8,96.9,97.2,+0.3,80.2,82.5,2.3
NoisyStudent-B4,91.7,92.6,0.9,97.4,97.5,+0.1,81.0,83.5,2.5
EfficientNetV2-S,91.6,92.6,1.0,96.0,97.0,+1.0,81.3,84.0,2.8
ViT-Base-384,91.7,92.9,1.2,97.3,97.5,+0.2,82.0,84.6,2.6
DeiT-Base-384,90.6,91.7,1.0,96.7,97.2,+0.5,81.2,83.8,2.6
BEiT-Base-384,93.7,93.9,0.2,97.8,97.7,+-0.2,84.6,85.8,1.2
ViT-Large-384,92.0,92.6,0.6,97.5,97.6,+0.0,82.9,84.9,2.0
