# Prediction based on texts containing rare animals

In [None]:
import torch
from torch import cuda


# Pick the compute device: use the GPU whenever PyTorch reports one, else the CPU.
device = 'cpu'
if cuda.is_available():
    device = 'cuda'
device

### Load models

In [None]:
# Registry of the classifiers used below. For each one it holds the Kaggle
# model slug and the artefact file name(s) produced by the download.
# The "svm" entry lists two pickles, which the loading cell reads as the
# classifier and its vectorizer.
model_names = {
    "bert": dict(slug="bert-base-uncased-pubmed", file_name="bert-base-uncased.pt"),
    "roberta": dict(slug="roberta-base-pubmed", file_name="roberta-base.pt"),
    "deberta": dict(slug="deberta-base-pubmed", file_name="deberta-base.pt"),
    "bluebert": dict(
        slug="bluebert-large-pubmed",
        file_name="bluebert_pubmed_uncased_L-24_H-1024_A-16.pt",
    ),
    "xlnet": dict(slug="xlnet-large-pubmed", file_name="xlnet-large-cased.pt"),
    "svm": dict(slug="svm-linear-pubmed", file_name=["svm.pkl", "vectorizer.pkl"]),
}

In [None]:
import kaggle
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
import joblib

from Source_code.z_utils.BERTClassifier import BERTClassifier
from Source_code.z_utils.RoBERTaClassifier import RoBERTaClassifier
from Source_code.z_utils.DeBERTaClassifier import DeBERTaClassifier
from Source_code.z_utils.BlueBERTClassifier import BlueBERTClassifier
from Source_code.z_utils.XLNetClassifier import XLNetClassifier


# Download (if missing) and load every classifier listed in `model_names`.
# The result is `models`, keyed like `model_names`:
#   - "svm"  -> (fitted classifier, vectorizer) loaded with joblib
#   - others -> a PyTorch module loaded with torch.load and put in eval mode
kaggle.api.authenticate()
data_path = "./models/"
if not os.path.exists(data_path):
    os.makedirs(data_path)
    print(f"Directory created: {data_path}")
models = {}

for model, spec in model_names.items():
    file_name = spec["file_name"]
    # The SVM entry lists several artefacts; every other entry names a single
    # checkpoint. Normalize to a list so the existence check below tests the
    # real files. Formatting the list into one path string produced a path
    # that never exists, which forced a re-download on every run.
    file_names = file_name if isinstance(file_name, list) else [file_name]
    target_paths = [os.path.join(data_path, name) for name in file_names]

    if not all(os.path.exists(path) for path in target_paths):
        slug = spec["slug"]
        framework = "scikitlearn" if model == "svm" else "pytorch"
        kaggle.api.model_instance_version_download_cli(
            f"marcelhiltner/{slug}/{framework}/{slug}/1", data_path, untar=True
        )

    # NOTE(review): both joblib.load and torch.load unpickle files downloaded
    # from Kaggle, which can execute arbitrary code. Only use trusted sources.
    if model == "svm":
        svm = joblib.load(target_paths[0])
        print("svm loaded.")
        vectorizer = joblib.load(target_paths[1])
        print("vectorizer loaded.")
        models[model] = (svm, vectorizer)
    else:
        # map_location lets a checkpoint that was saved from a GPU load on a
        # CPU-only machine. It also puts it on the device chosen earlier.
        # NOTE(review): this unpickles a whole nn.Module. On PyTorch >= 2.6 the
        # default weights_only=True rejects that; pass weights_only=False there
        # if the files are trusted.
        models[model] = torch.load(target_paths[0], map_location=device)
        models[model].eval()
        print(f"{model} loaded.")

### Preprocess text and predict

In [None]:
# IPython shell magic: fetch spaCy's small English pipeline, which the SVM branch of the next cell loads by name.
!python -m spacy download en_core_web_sm

In [None]:
import spacy
import numpy as np

from Source_code.z_utils.data_preprocessing import preprocess_text
from Source_code.z_utils.lemmatize import lemmatize
from Source_code.z_utils.predict import predict


# Put the text to classify here. Every loaded model then prints its prediction.
text = ''
text_pp = preprocess_text(text)

banner = "=" * 30
for model_key, entry in models.items():
    print(f"{banner}\nModel {model_key}\n{banner}")

    if model_key != "svm":
        # PyTorch classifier: move it to the chosen device and run it on the single text.
        # Then take the index of the highest-scoring class.
        model = entry
        model.to(device)
        model.eval()
        scores, _ = predict(model, texts=[text_pp], device=device)
        pred = np.argmax(scores[0])
    else:
        # Linear SVM: preprocess a second time (with numbers=True) and lemmatize with spaCy.
        # Then vectorize and classify.
        # NOTE(review): presumably this mirrors how the SVM's training text
        # was prepared. Confirm the double preprocess_text pass is intended.
        lemmatizer = spacy.load('en_core_web_sm')
        text_prep = preprocess_text(text_pp, numbers=True)
        text_lemm = lemmatize(lemmatizer, text_prep)
        clf, vec = entry
        pred = clf.predict(vec.transform([text_lemm]))[0]

    print(f"Prediction: {pred}")