# Testing HuggingFace DistilBERT model

The goal of this notebook is to train a DistilBERT Hugging Face classifier on the GPT-labelled OpenAlex/patents data and then test it on further data.

The embeddings are generated using the "distilbert-base-uncased" sentence-transformer model.

The equivalent refactored file resides in pipeline/models/binary_classifier. This notebook is for testing purposes only.

## 1. Import packages

In [None]:
import wandb
import numpy as np
from nesta_ds_utils.loading_saving import S3
from discovery_child_development.getters.binary_classifier.embeddings_hugging_face import (
    get_embeddings,
)
from discovery_child_development.getters.binary_classifier.prompts_edge_cases import get_examples
from discovery_child_development.utils.huggingface_pipeline import (
    load_model,
    load_training_args,
    load_trainer,
    saving_huggingface_model
)
from discovery_child_development.utils import wandb as wb
from discovery_child_development.utils import classification_utils
from discovery_child_development.utils.testing_examples_utils import testing_examples_huggingface
from discovery_child_development import (
    logging,
    S3_BUCKET,
    config,
    binary_config,
    PROJECT_DIR
)

## 2. Setting parameters

In [None]:
# Seed NumPy's global random number generator from the project config so
# that any NumPy-based randomness in this notebook is reproducible.
SEED = config["seed"]
np.random.seed(SEED)

# S3 key prefix under which trained models are stored (combined with
# S3_BUCKET when the model reference is logged to wandb).
S3_PATH = "models/binary_classifier/"
# Location and base file name of the pre-computed embedding vectors passed
# to get_embeddings.
VECTORS_PATH = "data/labels/binary_classifier/vectors/"
VECTORS_FILE = "distilbert_sentence_vectors_384_labelled"

In [None]:
# Run-time switches for this notebook:
#   wandb_run  - when True, log the run, metrics and confusion matrix to wandb
#   save_model - when True, save the trained model and push it to S3
#   production - when False, the "_test" variant of the vectors file is loaded
wandb_run, save_model, production = False, False, False

## 3. Load data

In [None]:
# Non-production runs read the "_test" variant of the vectors file.
# Derive the file name from the un-suffixed base every time so that
# re-running this cell does not append "_test" twice. It also means that
# switching `production` to True later actually drops the suffix.
_TEST_SUFFIX = "_test"
_base_vectors_file = (
    VECTORS_FILE[: -len(_TEST_SUFFIX)]
    if VECTORS_FILE.endswith(_TEST_SUFFIX)
    else VECTORS_FILE
)
VECTORS_FILE = _base_vectors_file if production else _base_vectors_file + _TEST_SUFFIX

# Arguments shared by both splits; only `set_type` differs.
embedding_kwargs = dict(
    identifier="",
    production=production,
    vectors_path=VECTORS_PATH,
    vectors_file=VECTORS_FILE,
)

# Loading the training and validation embeddings
embeddings_training = get_embeddings(set_type="train", **embedding_kwargs)
embeddings_validation = get_embeddings(set_type="validation", **embedding_kwargs)

In [None]:
# Example texts from the prompts_edge_cases getter; they are filtered and
# scored with the trained model in section 6.
examples = get_examples()

## 4. Training the model

In [None]:
# Start a Weights & Biases run (only when wandb_run is switched on). The
# returned `run` handle is used later to log the model reference and the
# confusion matrix.
if wandb_run:
    print("Logging in wandb")
    run = wandb.init(
        project="ISS supervised ML",
        job_type="Binary classifier - huggingface",
        save_code=True,
        # Tag previously misspelled as "openealex/patents"; runs logged
        # before this fix carry the misspelled tag.
        tags=["gpt-labelled", "distilbert", "openalex/patents"],
    )

In [None]:
# Build a two-label classification model and its Trainer from binary_config.
# Early stopping is presumably configured inside load_trainer / binary_config
# (not visible here; confirm).
model = load_model(config=binary_config, num_labels=2)
training_args = load_training_args(**binary_config["training_args"])

trainer = load_trainer(
    config=binary_config,
    model=model,
    args=training_args,
    train_dataset=embeddings_training,
    eval_dataset=embeddings_validation,
)

# Fine-tune; the training summary is displayed as the cell output.
trainer.train()

## 5. Evaluating the model

In [None]:
# Run the trainer's evaluation loop on its eval dataset (the validation split).
trainer.evaluate()

# Predict on the validation set. The result carries the raw predictions, the
# gold label ids and a metrics dict (f1, precision, recall, accuracy), which
# is inspected below.
model_predictions = trainer.predict(embeddings_validation)

In [None]:
# Display the validation metrics (keys carry a "test_" prefix, e.g. "test_f1").
model_predictions.metrics

In [None]:
# Hard class predictions: the index of the highest score along the last axis.
raw_scores = model_predictions.predictions
predictions = np.argmax(raw_scores, axis=-1)

# Gold labels flattened into a plain Python list.
labels = model_predictions.label_ids.ravel().tolist()

# Plot the confusion matrix; the returned object is also logged to wandb in
# the logging cell further down.
confusion_matrix = classification_utils.plot_confusion_matrix(
    labels,
    predictions,
    None,
    "Relevant works",
)

In [None]:
# Persist the fine-tuned model (only when save_model is switched on):
# it is written under outputs/data/models/ and, via s3_path, sent to S3.
if save_model:
    SAVE_TRAINING_RESULTS_PATH = PROJECT_DIR / "outputs/data/models/"
    model_name = f"gpt_labelled_binary_classifier_distilbert_production_{production}"
    saving_huggingface_model(
        trainer,
        model_name,
        save_path=SAVE_TRAINING_RESULTS_PATH,
        s3_path=S3_PATH,
    )

In [None]:
# Log validation metrics, the model reference and the confusion matrix to the
# wandb run started earlier, then close the run.
if wandb_run:
    model_name = f"gpt_labelled_binary_classifier_distilbert_production_{production}"

    # Copy the validation metrics into the run summary. trainer.predict
    # reports them with a "test_" prefix.
    for metric in ("f1", "accuracy", "precision", "recall"):
        run.summary[metric] = model_predictions.metrics[f"test_{metric}"]

    # Only reference the S3 artefact when this session actually saved it.
    # Otherwise the reference would point at a missing or stale file.
    # NOTE(review): the ".tar.gz" name is assumed to match what
    # saving_huggingface_model uploads; confirm.
    if save_model:
        wb.add_ref_to_data(
            run=run,
            name=model_name,
            description="Distilbert model trained on binary classifier training data",
            bucket=S3_BUCKET,
            filepath=f"{S3_PATH}{model_name}.tar.gz",
        )

    # Log confusion matrix
    wb_confusion_matrix = wandb.Table(
        data=confusion_matrix, columns=confusion_matrix.columns
    )
    run.log({"confusion_matrix": wb_confusion_matrix})

    # End the weights and biases run
    wandb.finish()

## 6. Trialling some test examples

In [None]:
# Keep only examples with a definite label; rows labelled 'Not specified'
# cannot be scored against the binary classifier's output.
examples = examples.query("labels!='Not specified'")

# Peek at the last few remaining rows.
examples.tail()

In [None]:
# Score the example texts with the trained model; the helper's return value
# is displayed as the cell output.
testing_examples_huggingface(trainer, examples, binary_config)

## 7. Trialling the model on the OpenAlex concepts

In [2]:
from discovery_child_development.getters.openalex import get_abstracts
from discovery_child_development.getters.openalex_broad_concepts import get_abstracts_broad
from discovery_child_development.getters.binary_classifier.gpt_labelled_datasets import get_labelled_data_for_classifier
import pandas as pd

In [None]:
# IDs of the works in the GPT-labelled training split, so they can be
# excluded from the evaluation pools in the next cell.
labelled_data = get_labelled_data_for_classifier(set_type="train")
labelled_data_ids = labelled_data["id"].unique()

In [None]:
# Candidate evaluation pools with every training-split work removed
# (`@labelled_data_ids` refers to the Python variable inside DataFrame.query).
# NOTE(review): only training ids are excluded; validation ids may still
# appear in these pools. Confirm whether that is intended.
abstracts = get_abstracts().query("id not in @labelled_data_ids")
abstracts_broad = get_abstracts_broad().query("id not in @labelled_data_ids")

In [None]:
# Build a balanced, labelled hold-out set. Works drawn from get_abstracts()
# are labelled relevant (1) and works from get_abstracts_broad() are labelled
# not relevant (0).
# NOTE(review): assumes the two pools do not share ids; a shared id would
# appear with both labels. Confirm.
N_TEST_SAMPLES = 500
# Cap at the smaller pool so DataFrame.sample (no replacement) cannot raise
# ValueError. Both classes use the same size to keep the set balanced.
n_per_class = min(N_TEST_SAMPLES, len(abstracts), len(abstracts_broad))
relevant = abstracts.sample(n_per_class, random_state=SEED).assign(labels=1)
not_relevant = abstracts_broad.sample(n_per_class, random_state=SEED).assign(labels=0)
test_set = pd.concat([relevant, not_relevant])

In [None]:
# Score the hold-out sample, passing only the label and text columns.
results = testing_examples_huggingface(trainer, test_set[["labels", "text"]], binary_config)

In [None]:
# Second element returned by testing_examples_huggingface (contents defined by that helper).
results[1]

In [None]:
# First element returned by testing_examples_huggingface (contents defined by that helper).
results[0]