# Imports

In [None]:
import os
# Switch the working directory to the toolbox root before the project-local
# imports further down (config, metric, pipeline, util) are executed.
# NOTE(review): the path is relative, so re-running this cell moves the working
# directory again relative to wherever the previous run left it.
os.chdir('../../vlm_toolbox/')

In [None]:
# Load (and re-load, if it was already loaded) IPython's autoreload extension,
# then select mode 2 so edited project modules are re-imported automatically.
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

In [None]:
import gc

import networkx as nx
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from tqdm.notebook import tqdm

from config.enums import ImageDatasets, ModelType, Setups, Trainers
from config.metric import MetricIOConfig
from config.path import ANNOTATIONS_TEMPLATE_PATH
from config.setup import Setup
from metric.visualization.accuracy import plot_model_accuracy
from pipeline.pipeline import Pipeline
from util.logging import LoggerFactory
from util.numpy_helper import softmax

In [None]:
# `warnings` is not among the notebook's imports, so bring it into scope here;
# without it the filterwarnings call below raises NameError.
import warnings

# Set the tokenizers parallelism flag to "false" and silence all Python
# warnings for the rest of the session.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings('ignore')

In [None]:
def flush() -> None:
    """Run a garbage-collection pass, then release torch's cached CUDA memory."""
    # Collect first so that unreachable objects are gone before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

# Config

In [None]:
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE_TYPE = 'cuda'
else:
    DEVICE_TYPE = 'cpu'

DEVICE = torch.device(DEVICE_TYPE)
DEVICE

In [None]:
# Logger handed to every Pipeline built in this notebook (created in notebook mode, silenced).
logger = LoggerFactory.create_logger(name='pipeline', notebook=True, silent=True)

In [None]:
# Keyword arguments shared by every Setup built from this notebook's helpers.
SETUP_DEFAULTS = dict(
    dataset_name=ImageDatasets.IMAGENET_1K,
    trainer_name=Trainers.COOP,
    n_shots=16,
    # Optional row filter on the annotations file (disabled):
    # annotations_key_value_criteria={'kingdom': ['Animalia']},
)

In [None]:
# Annotations CSV of the dataset selected in SETUP_DEFAULTS.
ANNOTATIONS_PATH = ANNOTATIONS_TEMPLATE_PATH + SETUP_DEFAULTS['dataset_name'] + '/labels.csv'
TOP_K = 5

# level -> its parent level (None marks the root).
HIERARCHY_LEVELS = dict(
    coarse=None,
    default='coarse',
)

# level -> its child level (None marks the leaf).
REVERSED_HIERARCHY_LEVELS = dict(
    coarse='default',
    default=None,
)


# Alternative taxonomy-style configuration (disabled):
#
# HIERARCHY_LEVELS = {
#     'phylum': None,
#     'class': 'phylum',
#     'order': 'class',
#     'family': 'order',
#     # 'genus': 'family',
#     # 'specific_epithet': 'genus',
#     'default': 'family'
# }

# REVERSED_HIERARCHY_LEVELS = {
#     'phylum': 'class',
#     'class': 'order',
#     'order': 'family',
#     'family': 'default',
#     # 'genus': 'specific_epithet',
#     # 'specific_epithet': 'default',
#     'default': None,
# }


# LEVELS_TOP_N_CHILDREN_TO_PROPAGATE = {
#     'phylum': 2,
#     'class': 4,
#     'order': 2,
#     'family': 3,
#     # 'genus': 4,
#     # 'specific_epithet': 4,
#     'default': TOP_K,
# }

# Level names in the insertion order of HIERARCHY_LEVELS (root first).
LEVEL_NAMES = list(HIERARCHY_LEVELS)

In [None]:
def create_complemented_labels_df(path=ANNOTATIONS_PATH, levels=LEVEL_NAMES, logger=logger, setup_defaults=SETUP_DEFAULTS):
    """Load the annotations CSV and attach the pipeline's label ids for every level.

    Args:
        path: CSV file to read; its 'simplified' and 'class_id' columns are
            renamed to 'default' and 'label_id'.
        levels: Hierarchy level column names to keep and complement.
        logger: Logger passed to each Pipeline.
        setup_defaults: Keyword arguments forwarded to every Setup. An optional
            'annotations_key_value_criteria' entry ({column: allowed values})
            filters the rows first.

    Returns:
        DataFrame with 'label_id', one lower-cased column per level and one
        '<level>_label_id' column per level, joined on the level's label text.
    """
    labels_df = pd.read_csv(path)
    labels_df = labels_df.rename(columns={'simplified': 'default', 'class_id': 'label_id'})
    for col, values in setup_defaults.get('annotations_key_value_criteria', {}).items():
        labels_df = labels_df[labels_df[col].isin(values)]

    # Work on an explicit copy: the frame may be a filtered slice at this point,
    # and the column assignments below must not target a view of the raw data.
    labels_df = labels_df[['label_id', *levels]].copy()
    for level in levels:
        # Vectorised lower-casing; missing values pass through instead of raising.
        labels_df[level] = labels_df[level].str.lower()

    for level in levels:
        # The 'default' level is treated as the leaf and gets a different setup.
        is_leaf = level == 'default'
        setup = Setup(
            **setup_defaults,
            setup_type=Setups.FULL if is_leaf else Setups.EVAL_ONLY,
            model_type=ModelType.FEW_SHOT if is_leaf else ModelType.ZERO_SHOT,
            label_column_name=None if is_leaf else level,
            top_k=labels_df[level].nunique(),
        )
        pipeline = Pipeline(setup, device_type=DEVICE_TYPE, logger=logger)
        try:
            pipeline.setup_labels()
            level_labels_df = pipeline.label_handler.get_labels_df()
        finally:
            # Always tear the pipeline down, even when label setup fails.
            pipeline.tear_down()

        level_labels_df = level_labels_df.rename(columns={
            'label_id': f'{level}_label_id',
            'label': level,
        })
        # Left join keeps every annotation row even if the pipeline lacks the label.
        labels_df = labels_df.merge(level_labels_df, on=level, how='left')
    return labels_df
 
 

In [None]:
# NOTE(review): `labels_df` is first assigned in the following cell, so on a
# fresh kernel this display raises NameError unless that cell has been run.
labels_df

In [None]:
# Build the complemented labels table, then count the distinct labels per level.
labels_df = create_complemented_labels_df()
per_level_cnt = {
    level_name: labels_df[level_name].nunique()
    for level_name in LEVEL_NAMES
}
per_level_cnt

# Create & Process Labels

In [None]:
class Node:
    """A label at one hierarchy level, used as a node of the labels graph.

    Two nodes compare equal when they are on the same level and share either
    the label id or the label text, which lets a partially specified node
    (only `label` or only `label_id`) match a fully specified one.

    NOTE(review): the hash uses `label_id` only, so two nodes that are equal
    purely by label text may hash differently. Lookups that go through hashing
    (e.g. graph membership) therefore need a node carrying the right label id.
    """

    def __init__(self, label, label_id, level):
        self.label = label
        self.label_id = label_id
        self.level = level

    def __hash__(self):
        return hash((self.label_id, self.level))

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # failing with AttributeError on the attribute access below.
        if not isinstance(other, Node):
            return NotImplemented
        return (self.level == other.level) and (self.label_id == other.label_id or self.label == other.label)

    def __repr__(self):
        return f"{self.__class__.__name__}(level={self.level}, label={self.label}, label_id={self.label_id})"

def create_labels_graph(df, levels=LEVEL_NAMES):
    """Build a directed graph with one chain of Node objects per row of `df`.

    For every row a Node is created per level (label text from the `<level>`
    column, id from the `<level>_label_id` column) and each node is linked to
    the node of the following level.
    """
    graph = nx.DiGraph()
    level_columns = [(level, f"{level}_label_id") for level in levels]
    for _, record in df.iterrows():
        chain = [
            Node(record[level], record[id_column], level)
            for level, id_column in level_columns
        ]
        for node in chain:
            graph.add_node(node)
        # Connect consecutive levels: parent -> child.
        for parent, child in zip(chain, chain[1:]):
            graph.add_edge(parent, child)

    return graph
    
    
def get_to_children_mapping(level, labels_df=labels_df, reversed_hierarchy=REVERSED_HIERARCHY_LEVELS):
    """Map each label id of `level` to the list of label ids of its child level.

    The child level is looked up in `reversed_hierarchy`; both id columns are
    expected in `labels_df` as `<level>_label_id`.
    """
    child_level = reversed_hierarchy[level]
    parent_col = f"{level}_label_id"
    child_col = f"{child_level}_label_id"

    # One row per distinct (parent id, child id) pair, ordered by parent id.
    pairs = labels_df[[parent_col, child_col]].drop_duplicates(keep='first')
    pairs = pairs.sort_values(by=parent_col).reset_index(drop=True)

    children_per_parent = pairs.groupby(parent_col)[child_col].apply(list)
    return children_per_parent.to_dict()

  

  
def get_to_children_cnt_mapping(level, labels_df=labels_df, reversed_hierarchy=REVERSED_HIERARCHY_LEVELS):
    """Map each label id of `level` to the number of distinct child label ids.

    The child level is looked up in `reversed_hierarchy`; both id columns are
    expected in `labels_df` as `<level>_label_id`.
    """
    child_level = reversed_hierarchy[level]
    parent_col = f"{level}_label_id"
    child_col = f"{child_level}_label_id"

    # One row per distinct (parent id, child id) pair, ordered by parent id.
    pairs = labels_df[[parent_col, child_col]].drop_duplicates(keep='first')
    pairs = pairs.sort_values(by=parent_col).reset_index(drop=True)

    children_per_parent = pairs.groupby(parent_col)[child_col].apply(list).to_dict()
    counts = {}
    for parent_id, children in children_per_parent.items():
        counts[parent_id] = len(children)
    return counts

def get_to_parent_mapping(level, labels_df=labels_df, hierarchy=HIERARCHY_LEVELS):
    """Map each label id of `level` to the list of label ids of its parent level.

    The parent level is looked up in `hierarchy` (level -> parent level); both
    id columns are expected in `labels_df` as `<level>_label_id`.

    Fixes two defects of the previous version: it defaulted to the
    child-direction mapping (REVERSED_HIERARCHY_LEVELS), and it returned
    counts instead of the id lists its name promises.
    """
    source_col = f"{level}_label_id"
    dest_col = f"{hierarchy[level]}_label_id"
    # One row per distinct (child id, parent id) pair, ordered by child id.
    mapping_df = labels_df[[source_col, dest_col]].drop_duplicates(keep='first').sort_values(by=source_col).reset_index(drop=True)
    return mapping_df.groupby(source_col)[dest_col].apply(list).to_dict()

def get_to_parent_cnt_mapping(level, labels_df=labels_df, hierarchy=HIERARCHY_LEVELS):
    """Map each label id of `level` to the number of distinct parent label ids.

    The parent level is looked up in `hierarchy` (level -> parent level); both
    id columns are expected in `labels_df` as `<level>_label_id`.

    Fixes two defects of the previous version: it defaulted to the
    child-direction mapping (REVERSED_HIERARCHY_LEVELS), and it returned the
    id lists instead of the counts its name promises.
    """
    source_col = f"{level}_label_id"
    dest_col = f"{hierarchy[level]}_label_id"
    # One row per distinct (child id, parent id) pair, ordered by child id.
    mapping_df = labels_df[[source_col, dest_col]].drop_duplicates(keep='first').sort_values(by=source_col).reset_index(drop=True)
    parents_per_child = mapping_df.groupby(source_col)[dest_col].apply(list).to_dict()
    return {child_id: len(parents) for child_id, parents in parents_per_child.items()}



def get_node(G, level, label=None, label_id=None):
    """Return a Node for `level` identified by `label` and/or `label_id`.

    When both identifiers are given, a fresh Node is built without consulting
    the graph. Otherwise the graph's nodes are scanned for the first one equal
    to the partially specified node.

    Returns:
        The matching Node, or None when the scan finds no match.
    """
    # Compare against None explicitly: a label id of 0 (or an empty label) is
    # a real value and must not be treated as "not provided".
    if label is not None and label_id is not None:
        return Node(label, label_id, level)
    search_node = Node(label, label_id, level)
    return next((node for node in G.nodes if node == search_node), None)

def get_parent(G, level, label=None, label_id=None):
    """Return the first predecessor of the identified node in `G`, or None if it has none."""
    node = get_node(G, level, label, label_id)
    return next(iter(G.predecessors(node)), None)

def get_parents_to_root(G, level, label=None, label_id=None):
    """Return the chain of ancestors of the identified node, nearest parent first."""
    ancestors = []
    current = get_node(G, level, label, label_id)
    parent = get_parent(G, current.level, current.label, current.label_id)
    # Climb one level at a time until a node without a parent is reached.
    while parent:
        ancestors.append(parent)
        current = parent
        parent = get_parent(G, current.level, current.label, current.label_id)
    return ancestors

def get_children(G, level, label=None, label_id=None):
    """Return the successors of the identified node in `G` as a (possibly empty) list."""
    node = get_node(G, level, label, label_id)
    return list(G.successors(node))

In [None]:
# Hierarchy graph built from the complemented labels table (edges run parent level -> child level).
labels_graph = create_labels_graph(labels_df)

# Load All Predictions

In [None]:
def load_and_process_prediction_df(setup, top_k=None):
    """Load the per-sample predictions of `setup` and reorder the probabilities by label id.

    Args:
        setup: Setup whose metric directory (from MetricIOConfig.get_config)
            contains 'per_sample.parquet'.
        top_k: Number of `pred@k_*` column pairs to read; falls back to
            `setup.get_top_k()` when falsy.

    Returns:
        DataFrame with 'dataset_idx' (row position in the file), 'class_id',
        'label_id' (copy of 'actual_label_id'), 'probs' (float16 array of the
        top-k probabilities ordered by ascending predicted label id) and
        'is_correct' (whether the top-1 prediction equals the actual label).

    NOTE(review): position i of 'probs' corresponds to label id i only if the
    top-k columns cover every label and ids run 0..n-1 — confirm with callers.
    """
    path = MetricIOConfig.get_config(setup) + 'per_sample.parquet'
    print(path)
    top_k = top_k or setup.get_top_k()

    prediction_df = (
        pd.read_parquet(path)
        .reset_index(drop=True)
        .reset_index(drop=False)
        .rename(columns={'index': 'dataset_idx'})
    )
    prediction_df['label_id'] = prediction_df['actual_label_id']

    ranks = range(1, top_k + 1)
    label_ids = prediction_df[[f"pred@{k}_label_id" for k in ranks]].values
    probs = prediction_df[[f"pred@{k}_prob" for k in ranks]].values

    # Reorder each row's probabilities so they follow ascending label id.
    order = np.argsort(label_ids, axis=1)
    sorted_probs = np.take_along_axis(probs, order, axis=1).astype(np.float16)

    cleaned_prediction_df = prediction_df[['dataset_idx', 'class_id', 'label_id']].copy()
    cleaned_prediction_df['probs'] = list(sorted_probs)
    cleaned_prediction_df['is_correct'] = prediction_df['label_id'] == prediction_df['pred@1_label_id']
    return cleaned_prediction_df

def load_all_predictions(path=ANNOTATIONS_PATH, levels=LEVEL_NAMES, setup_defaults=SETUP_DEFAULTS, top_k=None):
    """Load the processed per-sample predictions for every hierarchy level.

    Returns a dict mapping each level name to the frame produced by
    load_and_process_prediction_df. The last entry of `levels` is treated as
    the leaf and uses a different model type and no label column name.
    `path` is accepted but not used by this function.
    """
    leaf_position = len(levels) - 1
    predictions_by_level = {}
    for position, level in enumerate(levels):
        if position == leaf_position:
            model_type, label_column_name = ModelType.PRETRAINED, None
        else:
            model_type, label_column_name = ModelType.FEW_SHOT, level
        level_setup = Setup(
            **setup_defaults,
            model_type=model_type,
            label_column_name=label_column_name,
            # Request as many predictions as the level has distinct labels.
            top_k=len(labels_df[level].value_counts()),
        )
        predictions_by_level[level] = load_and_process_prediction_df(level_setup, top_k=top_k)
    return predictions_by_level

In [None]:
# Per-level prediction frames, keyed by level name.
prediction_dfs = load_all_predictions()

# Top-Down Prediction

In [None]:
def perform_hierarchical_prediction(
    prediction_dfs=prediction_dfs,
    labels_df=labels_df,
    labels_graph=labels_graph,
    levels=LEVEL_NAMES,
    top_n_dict=None,
):
    """Run top-down prediction through the label hierarchy.

    Starting at `levels[0]`, the `top_n` most probable labels of each sample
    are kept, expanded to their children, and the next level's probabilities
    for those children are weighted by a softmax over the parents'
    probabilities before the procedure repeats one level down.

    Args:
        prediction_dfs: {level: frame} as produced by load_all_predictions.
        labels_df: Complemented labels table used to derive parent -> children ids.
        labels_graph: Labels graph used to trace the final prediction to the root.
        levels: Level names ordered from root to leaf.
        top_n_dict: {level: number of labels kept at that level}; every level
            must have an entry (a missing level raises KeyError).

    Returns:
        (level_wise_accuracies_df, trace_df): a dict of per-level top-k
        accuracy frames, and a per-sample frame with the kept label ids and
        probabilities of each level plus the ancestor trace of the top-1 leaf
        prediction.
    """
    # Avoid a shared mutable default argument.
    top_n_dict = {} if top_n_dict is None else top_n_dict

    def initialize_label_ids(df):
        # Candidate ids at the first level are simply 0..n-1, one per probability.
        n_labels = df.iloc[0]['probs'].shape[0]
        label_ids = np.arange(n_labels)
        df['label_ids'] = [label_ids for _ in range(df.shape[0])]

    def add_prediction_stats(df, top_n):
        final_df = df[['dataset_idx', 'class_id', 'label_id']].copy()
        # Probabilities reordered to line up with 'top_label_ids'; indexing the
        # raw 'probs' by rank would report the probability of the wrong label.
        top_probs = pd.Series(
            [np.asarray(probs)[indices] for probs, indices in zip(df['probs'], df['top_indices'])],
            index=df.index,
        )
        for k in range(top_n):
            final_df[f"pred@{k+1}_label_id"] = df['top_label_ids'].apply(lambda x: x[k] if k < len(x) else None)
            final_df[f"pred@{k+1}_prob"] = top_probs.apply(lambda x: x[k] if k < len(x) else None)

        def find_correct_prediction_rank(row):
            label_id = row['label_id']
            for k in range(top_n):
                if row.get(f"pred@{k+1}_label_id") == label_id:
                    return k + 1
            return -1

        final_df['correct_pred_rank'] = final_df.apply(find_correct_prediction_rank, axis=1)
        return final_df

    def get_accuracy_df(df, top_n):
        ranks = add_prediction_stats(df, top_n)['correct_pred_rank']
        top_ks = list(range(1, top_n + 1))
        # A sample counts as correct at k when its true label was ranked within the first k.
        accuracies = [((ranks != -1) & (ranks <= k)).mean() for k in top_ks]
        return pd.DataFrame({'accuracy': accuracies, 'top_k': top_ks})

    def sorted_top_n_indices(arr, n):
        return np.argsort(-np.array(arr))[:n]

    def fetch_probabilities_by_indices(df, next_label_ids):
        # Rows are paired by position; each id list indexes that row's probability array.
        return [probs[indices] for probs, indices in zip(df['probs'], next_label_ids)]

    def expand_to_children(top_label_ids, top_indices, probs, children_mapping):
        """Return (child ids, propagation coefficients) built in one aligned pass."""
        next_ids, weights = [], []
        seen = set()
        # Probability of each kept label, in the same order as `top_label_ids`.
        top_probs = np.asarray(probs)[top_indices]
        for parent_id, parent_prob in zip(top_label_ids, top_probs):
            for child_id in children_mapping.get(parent_id, []):
                if child_id in seen:
                    continue
                seen.add(child_id)
                next_ids.append(child_id)
                # Every child inherits its parent's probability as raw weight.
                weights.append(parent_prob)
        return next_ids, softmax(np.array(weights))

    df_dict = {level: prediction_dfs[level].copy() for level in levels}
    level_wise_accuracies_df = {}
    current_predictions = df_dict[levels[0]].copy()
    initialize_label_ids(current_predictions)
    trace_df = current_predictions[['dataset_idx', 'class_id']].copy()

    for i, current_level in enumerate(tqdm(levels, desc="Processing Levels")):
        top_n = top_n_dict[current_level]

        current_predictions['probs'] = current_predictions['probs'].apply(softmax)

        current_predictions['top_indices'] = current_predictions['probs'].apply(
            lambda x: sorted_top_n_indices(x, top_n)
        )
        current_predictions['top_label_ids'] = current_predictions.apply(
            lambda x: [x['label_ids'][idx] for idx in x['top_indices']], axis=1,
        )
        trace_df[f"{current_level}_top_label_ids"] = current_predictions['top_label_ids'].copy()
        # Holds the full (softmaxed) candidate probabilities, not only the top-n.
        trace_df[f"{current_level}_top_probs"] = current_predictions['probs'].copy()

        level_wise_accuracies_df[current_level] = get_accuracy_df(current_predictions, top_n)

        if i == len(levels) - 1:
            continue

        next_level = levels[i + 1]

        # Use the labels table passed to this function rather than the helper's default.
        mapping_children = get_to_children_mapping(current_level, labels_df=labels_df)

        next_predictions = df_dict[next_level].copy()

        expansions = [
            expand_to_children(ids, indices, probs, mapping_children)
            for ids, indices, probs in zip(
                current_predictions['top_label_ids'],
                current_predictions['top_indices'],
                current_predictions['probs'],
            )
        ]
        current_predictions['next_label_ids'] = pd.Series(
            [next_ids for next_ids, _ in expansions], index=current_predictions.index,
        )
        current_predictions['propagation_coefficients'] = pd.Series(
            [coefficients for _, coefficients in expansions], index=current_predictions.index,
        )

        child_probs = fetch_probabilities_by_indices(next_predictions, current_predictions['next_label_ids'])
        # Weight each child's own probability by the coefficient inherited from its parent.
        current_predictions['probs'] = pd.Series(
            [
                probs * coefficients
                for probs, coefficients in zip(child_probs, current_predictions['propagation_coefficients'])
            ],
            index=current_predictions.index,
        )
        current_predictions['label_ids'] = current_predictions['next_label_ids']
        current_predictions['label_id'] = next_predictions['label_id']

    final_leaf_df = add_prediction_stats(current_predictions, top_n_dict[levels[-1]])
    trace_df['trace'] = final_leaf_df['pred@1_label_id'].apply(lambda label_id: get_parents_to_root(labels_graph, levels[-1], label_id=label_id))
    return level_wise_accuracies_df, trace_df



In [None]:
# Number of labels kept per level during the top-down pass
# (the leaf level keeps `top_k`).
top_k = 3
top_n_dict = dict(
    coarse=5,
    default=top_k,
)

In [None]:
# Run the top-down prediction; `level_wise_acc_df` is a dict of per-level accuracy frames.
level_wise_acc_df, trace_df = perform_hierarchical_prediction(top_n_dict=top_n_dict)
level_wise_acc_df

In [None]:
# Compare the hierarchical accuracies against CLIP / CoOp baselines, level by level.
# NOTE(review): AccuracyMetricEvaluator is not imported in the visible cells; add
# its import before running this cell.
# NOTE(review): the baselines are loaded for INATURALIST while SETUP_DEFAULTS
# targets IMAGENET_1K — confirm the comparison is intended.
for col in LEVEL_NAMES:
    clip_setup = Setup(
        dataset_name=ImageDatasets.INATURALIST,
        trainer_name=Trainers.CLIP,
        setup_type=Setups.EVAL_ONLY,
        label_column_name=col,
        annotations_key_value_criteria={'kingdom': ['Animalia']},
        top_k=TOP_K,
    )
    baseline_coop_setup = Setup(
        dataset_name=ImageDatasets.INATURALIST,
        trainer_name=Trainers.COOP,
        n_shots=16,
        setup_type=Setups.EVAL_ONLY,
        model_type=ModelType.ZERO_SHOT,
        label_column_name=col,
        annotations_key_value_criteria={'kingdom': ['Animalia']},
        enable_novelty=True,
        top_k=TOP_K,
    )
    coop_setup = Setup(
        dataset_name=ImageDatasets.INATURALIST,
        trainer_name=Trainers.COOP,
        n_shots=16,
        label_column_name=col,
        annotations_key_value_criteria={'kingdom': ['Animalia']},
        enable_novelty=True,
        top_k=TOP_K,
    )

    clip_metrics = AccuracyMetricEvaluator.load(clip_setup)['overall']
    baseline_coop_metrics = AccuracyMetricEvaluator.load(baseline_coop_setup)['overall']
    baseline_coop_metrics['trainer_name'] = 'baseline_coop'
    coop_metrics = AccuracyMetricEvaluator.load(coop_setup)['overall']

    col_name = clip_setup.get_label_column_name() or 'default'
    # The hierarchical results were bound to `level_wise_acc_df` by the run cell.
    hierarchy_metrics = level_wise_acc_df[col]
    hierarchy_metrics['trainer_name'] = 'hierarchy_coop'
    plot_model_accuracy(
        [clip_metrics, baseline_coop_metrics, coop_metrics, hierarchy_metrics],
        title=f"Overall Performance on the '{col_name}' Column - Dataset: {clip_setup.get_dataset_name()}",
    )
    plt.show()

In [None]:
# Same comparison as above, restricted to the 'default' (leaf) level and
# without the zero-shot CoOp baseline curve.
# NOTE(review): AccuracyMetricEvaluator is not imported in the visible cells; add
# its import before running this cell.
# NOTE(review): the baselines are loaded for INATURALIST while SETUP_DEFAULTS
# targets IMAGENET_1K — confirm the comparison is intended.
for col in LEVEL_NAMES:
    if col != 'default':
        continue

    clip_setup = Setup(
        dataset_name=ImageDatasets.INATURALIST,
        trainer_name=Trainers.CLIP,
        setup_type=Setups.EVAL_ONLY,
        label_column_name=col,
        annotations_key_value_criteria={'kingdom': ['Animalia']},
        top_k=TOP_K,
    )
    # Built but not plotted: its metrics lines below are disabled.
    baseline_coop_setup = Setup(
        dataset_name=ImageDatasets.INATURALIST,
        trainer_name=Trainers.COOP,
        n_shots=16,
        setup_type=Setups.EVAL_ONLY,
        model_type=ModelType.ZERO_SHOT,
        label_column_name=col,
        annotations_key_value_criteria={'kingdom': ['Animalia']},
        enable_novelty=True,
        top_k=TOP_K,
    )
    coop_setup = Setup(
        dataset_name=ImageDatasets.INATURALIST,
        trainer_name=Trainers.COOP,
        n_shots=16,
        label_column_name=col,
        annotations_key_value_criteria={'kingdom': ['Animalia']},
        enable_novelty=True,
        top_k=TOP_K,
    )

    clip_metrics = AccuracyMetricEvaluator.load(clip_setup)['overall']
    # baseline_coop_metrics = AccuracyMetricEvaluator.load(baseline_coop_setup)['overall']
    # baseline_coop_metrics['trainer_name'] = 'baseline_coop'
    coop_metrics = AccuracyMetricEvaluator.load(coop_setup)['overall']

    col_name = clip_setup.get_label_column_name() or 'default'
    # The hierarchical results were bound to `level_wise_acc_df` by the run cell.
    hierarchy_metrics = level_wise_acc_df[col]
    hierarchy_metrics['trainer_name'] = 'hierarchy_coop'
    plot_model_accuracy(
        [clip_metrics, coop_metrics, hierarchy_metrics],
        title=f"Overall Performance on the '{col_name}' Column - Dataset: {clip_setup.get_dataset_name()}",
    )
    plt.show()