In [None]:
import os

import hydra
import numpy as np
import omegaconf
import torch
import transformers
from sklearn.metrics import f1_score, matthews_corrcoef, precision_score, recall_score
from tqdm.auto import tqdm

import classifier
import dataloader

In [None]:
# Register the custom interpolation resolvers referenced by the Hydra configs.
# `replace=True` keeps this cell idempotent: without it OmegaConf raises if a
# resolver with the same name is already registered (e.g. when the cell is re-run).
omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd, replace=True)
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count, replace=True)
# NOTE(review): `eval` executes arbitrary Python found in config strings; only
# compose configs from trusted sources.
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval, replace=True)
# Integer ceiling division: div_up(x, y) == ceil(x / y) for positive ints.
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y, replace=True)
omegaconf.OmegaConf.register_new_resolver(
  'if_then_else',
  lambda condition, x, y: x if condition else y,
  replace=True
)

In [None]:
# Load the trained ten-species HyenaDNA classifier from its best checkpoint.
# The Hydra run dir and the checkpoint both live under the same experiment
# directory; defining it once keeps the two overrides below from drifting apart.
classifier_run_dir = (
    f"{os.path.dirname(os.getcwd())}/outputs/ten_species/eval_classifier/"
    "hyenadna-small-32k_from-scratch_nlayer-8"
)
with hydra.initialize(version_base=None, config_path='../configs/'):
    classifier_config = hydra.compose(
        config_name='config',
        overrides=[
            'hydra.output_subdir=null',
            f"hydra.run.dir={classifier_run_dir}",
            'hydra/job_logging=disabled',
            'hydra/hydra_logging=disabled',
            '+is_eval_classifier=True',
            'mode=train_classifier',
            'loader.global_batch_size=32',
            'loader.eval_global_batch_size=64',
            'loader.batch_size=1',
            'loader.eval_batch_size=1',
            'data=ten_species',
            'data.label_col=species_label',
            'data.num_classes=10',
            'classifier_model=hyenadna-classifier',
            'classifier_model.hyena_model_name_or_path=LongSafari/hyenadna-small-32k-seqlen-hf',
            'classifier_model.n_layer=8',
            'classifier_backbone=hyenadna',
            'model.length=32768',
            'diffusion=null',
            'T=null',
            f"eval.checkpoint_path={classifier_run_dir}/checkpoints/best.ckpt",
        ]
    )
# Re-wrap the composed config via OmegaConf.create (presumably to obtain a
# standalone copy detached from Hydra's compose result).
classifier_config = omegaconf.OmegaConf.create(classifier_config)
# NOTE(review): `trust_remote_code=True` runs code shipped with the tokenizer repo.
# A second tokenizer is built later via `dataloader.get_tokenizer`, overwriting
# this name -- TODO confirm the two are equivalent.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    classifier_config.data.tokenizer_name_or_path, trust_remote_code=True)
pretrained_classifier = classifier.Classifier.load_from_checkpoint(
    classifier_config.eval.checkpoint_path,
    tokenizer=tokenizer,
    config=classifier_config, logger=False)
# Switch to eval mode; the trailing `;` suppresses the module repr in cell output.
pretrained_classifier.eval();

In [None]:
# Build the validation loader from the classifier's config. The training split is
# skipped and the validation ordering is seeded with the config's seed.
# NOTE(review): this rebinds `tokenizer`, replacing the one used to load the checkpoint.
tokenizer = dataloader.get_tokenizer(classifier_config)
_, val_dl = dataloader.get_dataloaders(
    classifier_config,
    tokenizer,
    valid_seed=classifier_config.seed,
    skip_train=True,
)

In [None]:
# Run the classifier over the validation loader, collecting per-batch predicted
# class ids and ground-truth labels (they are concatenated in the next cell).
# - `torch.no_grad()` stops autograd from recording the forward pass, which would
#   otherwise keep activations alive and waste memory during evaluation.
# - The GPU->CPU copy is blocking: a `non_blocking=True` device-to-host copy may
#   still be in flight when `.numpy()` reads the buffer.
labels = []
preds = []
with torch.no_grad():
    for batch in tqdm(val_dl):
        outputs = pretrained_classifier(
            batch['input_ids'].to(pretrained_classifier.device))
        # argmax over the last dim -> predicted class id per example.
        preds.append(outputs.argmax(dim=-1).cpu().numpy())
        # Label column comes from the config (`data.label_col=species_label` above).
        labels.append(batch[classifier_config.data.label_col].numpy())

In [None]:
def _concat_batches(chunks):
    """Join a list of per-batch arrays into one array along axis 0.

    If `chunks` is already an ndarray (i.e. this cell is being re-run), it is
    returned unchanged, so the cell stays idempotent instead of failing inside
    `np.concatenate`.
    """
    if isinstance(chunks, np.ndarray):
        return chunks
    return np.concatenate(chunks)


labels = _concat_batches(labels)
preds = _concat_batches(preds)
# Every prediction must line up with exactly one ground-truth label.
assert len(labels) == len(preds), (
    f"label/prediction count mismatch: {len(labels)} vs {len(preds)}")

In [None]:
# Headline metrics over the whole validation set: plain accuracy, F1 macro-averaged
# over every class id in range(data.num_classes), and Matthews correlation.
overall_accuracy_score = np.count_nonzero(preds == labels) / preds.size
overall_f1_score = f1_score(
    y_true=labels,
    y_pred=preds,
    labels=[*range(classifier_config.data.num_classes)],
    average="macro",
)
overall_mcc_score = matthews_corrcoef(y_true=labels, y_pred=preds)

print(
    f"Overall Acc: {overall_accuracy_score:0.3f}",
    f"Overall F1:  {overall_f1_score:0.3f}",
    f"Overall MCC: {overall_mcc_score:0.3f}",
    sep="\n",
)

In [None]:
# Per-class F1 / precision / recall, one entry per class id 0..num_classes-1.
num_classes = classifier_config.data.num_classes
class_ids = list(range(num_classes))
f1_scores = f1_score(y_pred=preds, y_true=labels, average=None, labels=class_ids)
precision_scores = precision_score(y_pred=preds, y_true=labels, average=None, labels=class_ids)
recall_scores = recall_score(y_pred=preds, y_true=labels, average=None, labels=class_ids)

# Human-readable names indexed by class id.
# NOTE(review): this ordering is assumed to match the integer labels produced by
# the ten_species dataset -- TODO confirm against the dataset's label encoding.
species_list = ['Homo_sapiens', 'Mus_musculus', 'Drosophila_melanogaster', 'Danio_rerio',
                'Caenorhabditis_elegans', 'Gallus_gallus', 'Gorilla_gorilla', 'Felis_catus',
                'Salmo_trutta', 'Arabidopsis_thaliana']
# Fail loudly on a config/name-list mismatch instead of an IndexError mid-print
# (or silently unnamed classes).
assert len(species_list) == num_classes, (
    f"species_list has {len(species_list)} names but data.num_classes={num_classes}")

for s, (species, f1, precision, recall) in enumerate(
        zip(species_list, f1_scores, precision_scores, recall_scores)):
    print(f"Class {s} - {species}:")
    print(f"   F1:        {f1:0.3f}")
    print(f"   Precision: {precision:0.3f}")
    print(f"   Recall:    {recall:0.3f}")