# Testing Simple Classifiers

The goal of this notebook is to train simple classifiers on the GPT-labelled OpenAlex/patents data and then test them on further examples. We will use the following classifiers:

* Logistic Regression
* K-Nearest Neighbors
* Random Forest
* SGD Classifier
* Support Vector Machine

The embeddings are generated using the "all-MiniLM-L6-v2" sentence-transformer model. 

The equivalent refactored file resides in pipeline/models/binary_classifier. This notebook is for testing purposes only.

## 1. Import packages

In [None]:
# Import packages
from dotenv import load_dotenv
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC
import wandb

In [None]:
## Nesta DS utils
from nesta_ds_utils.loading_saving import S3

In [None]:
## Import from project
from discovery_child_development import PROJECT_DIR, logging, config, S3_BUCKET
from discovery_child_development.utils import classification_utils
from discovery_child_development.utils.general_utils import replace_binary_labels
from discovery_child_development.getters.binary_classifier.gpt_labelled_datasets import (
    get_labelled_data_for_classifier,
)
from discovery_child_development.getters.openalex import get_sentence_embeddings
from discovery_child_development.getters.binary_classifier.prompts_edge_cases import get_examples
from discovery_child_development.utils.testing_examples_utils import testing_examples_simple
from discovery_child_development.utils.general_utils import replace_binary_labels
from discovery_child_development.utils import wandb as wb

In [None]:
# Load environment variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the credentials used for W&B / S3 — confirm.
load_dotenv()

## 2. Setting parameters

In [None]:
# Local output directory for models (not referenced by the cells in this notebook section).
MODEL_PATH = PROJECT_DIR / "outputs/models/"
# S3 key prefix under which fitted classifiers are uploaded as .pkl files.
S3_PATH = "models/binary_classifier/"

# S3 key prefix of the processed GPT-labelled train/validation data.
PATH_FROM = "data/labels/binary_classifier/processed/"
# S3 location (prefix + file name) of the sentence embeddings for the labelled texts.
VECTORS_PATH = "data/labels/binary_classifier/vectors/"
VECTORS_FILE = "sentence_vectors_384_labelled.parquet"

In [None]:
# Setting the seed
# SEED comes from the project config and is also passed as `random_state` to the
# classifiers trained below, so results are reproducible across runs.
SEED = config["seed"]
# Seed NumPy's global (legacy) random number generator.
np.random.seed(SEED)

In [None]:
# Run parameters
# wandb_run: when True, every model gets its own Weights & Biases run and its metrics are logged.
wandb_run = False
# save_model: when True, every fitted classifier is uploaded to S3.
save_model = False

## 3. Load data

In [None]:
# Fetch the GPT-labelled training and validation splits (in that order) from PATH_FROM.
labelled_text_training, labelled_text_validation = (
    get_labelled_data_for_classifier(set_type=split_name, path_from=PATH_FROM)
    for split_name in ("train", "validation")
)

In [None]:
# Load the example texts used in section 6 to try out a trained classifier.
# The later cells read its `text` and `labels` columns.
examples = get_examples()

## 4. Setting up training and validation sets

In [None]:
# Embeddings from all-MiniLM-L6-v2
# Loaded from VECTORS_PATH/VECTORS_FILE in the project S3 bucket. The next cell joins
# them to the labelled texts on `id` and reads the `miniLM_384_vector` column.
embeddings_all = get_sentence_embeddings(
        s3_bucket=S3_BUCKET, filepath=VECTORS_PATH, filename=VECTORS_FILE
)

In [None]:
# Create training and validation sets
def _build_labelled_set(labelled_text, embeddings, set_name):
    """Join labelled texts to their embeddings and convert the labels to binary.

    Args:
        labelled_text: DataFrame of labelled texts with an "id" column.
        embeddings: DataFrame with an "id" column and a "miniLM_384_vector" column.
        set_name: Name of the split, used only in the error message.

    Returns:
        The merged DataFrame after `replace_binary_labels` has been applied.

    Raises:
        ValueError: If any labelled text has no matching embedding.
    """
    merged = labelled_text.merge(embeddings, on="id", how="left")
    # The left join keeps labelled rows that have no embedding. Fail here with a
    # clear message rather than passing NaN feature rows on to the classifiers.
    n_missing = int(merged["miniLM_384_vector"].isna().sum())
    if n_missing:
        raise ValueError(
            f"{n_missing} rows in the {set_name} set have no sentence embedding"
        )
    return replace_binary_labels(merged, replace_cat=["Relevant", "Not-relevant"])


training_set = _build_labelled_set(labelled_text_training, embeddings_all, "training")
validation_set = _build_labelled_set(labelled_text_validation, embeddings_all, "validation")

# Setting up the training and validation sets
# np.stack builds the (n_samples, n_dimensions) matrix directly from the per-row
# vectors; expanding with .apply(pd.Series) creates one Series per row and is far slower.
X_train = np.stack(training_set["miniLM_384_vector"].to_numpy())
X_val = np.stack(validation_set["miniLM_384_vector"].to_numpy())

Y_train = training_set["labels"]
Y_val = validation_set["labels"]

## 5. Training and evaluating the models

In [None]:
# One zero-argument factory per model name, so every iteration gets a fresh,
# unfitted estimator. The dict order is the training order.
classifier_factories = {
    "log_regression": lambda: LogisticRegression(penalty="l2", random_state=SEED),
    "knn": lambda: KNeighborsClassifier(),
    "random_forest": lambda: RandomForestClassifier(random_state=SEED),
    "sgd": lambda: SGDClassifier(random_state=SEED),
    "svm": lambda: LinearSVC(random_state=SEED),
}
models_simple = list(classifier_factories)

# Fitted classifiers are always kept in memory, keyed by model name, so the
# example-testing cells below work whether or not the models are also saved to S3.
models_all = {}

for model in models_simple:
    # S3 key the fitted classifier is (optionally) uploaded to
    model_s3_path = f"{S3_PATH}gpt_labelled_binary_classifier_{model}.pkl"

    if wandb_run:
        # Initialize a wandb run
        # NOTE(review): the last tag is spelled "openealex/patents"; it is kept as-is
        # in case existing runs are filtered by it — confirm before correcting.
        run = wandb.init(
            project="ISS supervised ML",
            job_type="Binary classifier - base models",
            save_code=True,
            tags=["gpt-labelled", "all-MiniLM-L6-v2", model, "openealex/patents"],
        )
        # Add reference to this data in wandb
        wb.add_ref_to_data(
            run=run,
            name="binary_train_data_raw",
            description="Binary classifier training data",
            bucket=S3_BUCKET,
            filepath=f"{PATH_FROM}gpt_labelled_train.csv",
        )

    # Create and fit the classifier, then predict on the validation set
    classifier = classifier_factories[model]()
    classifier.fit(X_train, Y_train)
    predictions = classifier.predict(X_val)

    # Creating metrics
    metrics = classification_utils.create_average_metrics(
        Y_val, predictions, average="binary"
    )
    logging.info(metrics)

    models_all[model] = classifier

    if save_model:
        # Save model to S3
        S3.upload_obj(
            obj=classifier,
            bucket=S3_BUCKET,
            path_to=model_s3_path,
        )

    if wandb_run:
        # Log metrics
        for metric_name in ("f1", "accuracy", "precision", "recall"):
            run.summary[metric_name] = metrics[metric_name]

        # Only reference the model file when this run actually uploaded it;
        # otherwise the reference would point at a missing or stale object.
        if save_model:
            wb.add_ref_to_data(
                run=run,
                name=f"binary_classifier_{model}",
                description=f"{model} model trained on binary classifier training data",
                bucket=S3_BUCKET,
                filepath=model_s3_path,
            )

        # End the weights and biases run
        wandb.finish()

## 6. Trialing some examples

In [None]:
# Drop the examples labelled "Not specified", then convert the remaining labels to binary.
# NOTE(review): this cell passes "Not relevant" while the training cell passes
# "Not-relevant" (hyphenated) — confirm both match the labels in their datasets.
has_specified_label = examples["labels"] != "Not specified"
examples = replace_binary_labels(
    examples.loc[has_specified_label], replace_cat=["Relevant", "Not relevant"]
)
examples.tail()

In [None]:
# Run the fitted logistic regression model on the example texts alongside their labels.
example_texts = list(examples["text"])
example_labels = list(examples["labels"])
testing_examples_simple(example_texts, example_labels, models_all["log_regression"])