# Testing Simple Classifiers

The goal of this notebook is to train simple classifiers on the GPT-labelled OpenAlex/patents data and then test them on further data. We use the following classifiers:

* Logistic Regression
* K-Nearest Neighbors
* Random Forest
* SGD Classifier
* Support Vector Machine

The embeddings are generated using the "all-MiniLM-L6-v2" sentence-transformer model. 

The equivalent refactored file resides in pipeline/models/binary_classifier. This notebook is for testing purposes only.

## 1. Import packages

In [1]:
# Import packages
from dotenv import load_dotenv
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC
import wandb

In [18]:
## Nesta DS utils
from nesta_ds_utils.loading_saving import S3
from discovery_child_development.utils.jsonl_utils import load_jsonl
import json

In [3]:
## Import from project
from discovery_child_development import PROJECT_DIR, logging, config, S3_BUCKET
from discovery_child_development.utils import classification_utils
from discovery_child_development.utils.general_utils import replace_binary_labels
from discovery_child_development.getters.binary_classifier.gpt_labelled_datasets import (
    get_labelled_data_for_classifier,
)
from discovery_child_development.getters.openalex import get_sentence_embeddings
from discovery_child_development.getters.binary_classifier.prompts_edge_cases import get_examples
from discovery_child_development.utils.testing_examples_utils import testing_examples_simple
from discovery_child_development.utils import wandb as wb



2024-02-29 11:57:53,766 - botocore.credentials - INFO - Found credentials in environment variables.
2024-02-29 11:57:54,230 - datasets - INFO - PyTorch version 2.1.2 available.


In [4]:
load_dotenv()

True

## 2. Setting parameters

In [8]:
MODEL_PATH = PROJECT_DIR / "outputs/models/taxonomy_cat/binary"
MODEL_PATH.mkdir(parents=True, exist_ok=True)
S3_PATH = "models/taxonomy_cat/binary/"

PATH_FROM = "data/labels/binary_classifier/processed/"
VECTORS_PATH = "data/outputs/vectors/"
VECTORS_FILE = "sentence_vectors_384_labelled.parquet"

In [20]:
LABELS_DIR = PROJECT_DIR / 'outputs/labels/taxonomy_cat'
PATH_TO_TOPICS = PROJECT_DIR / "discovery_child_development/pipeline/labelling/taxonomy_cat/prompts/topics.json"

EVALS_DIR = PROJECT_DIR / 'outputs/labels/evals_data'
LABELS_TAXONOMY = EVALS_DIR / "taxonomy_labels_eval_annotated.jsonl"

In [79]:
# Setting the seed
SEED = config["seed"]
np.random.seed(SEED)

In [80]:
#PARAMS
wandb_run = False
save_model = False

In [81]:
def get_taxonomy_data(path) -> pd.DataFrame:
    """Load annotated taxonomy labels (str or Path) and flag accepted rows as correct."""
    return (
        pd.DataFrame(load_jsonl(path))
        .assign(correct=lambda df: df["answer"] == "accept")
    )

In [229]:
# Embeddings from all-MiniLM-L6-v2
embeddings_all = (
    get_sentence_embeddings(
        s3_bucket=S3_BUCKET, filepath=VECTORS_PATH, filename=VECTORS_FILE, id="id",
    )
    .reset_index()
    .assign(id=lambda df: df['id'].apply(lambda x: x.split('/')[-1]))
    # restore index
    .set_index("id")
)

In [230]:
len(embeddings_all)

51234
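
The `split('/')[-1]` step above keeps only the last path segment of each identifier, so that the embedding index can be matched against the short ids used in the label files. A minimal sketch of that normalisation follows; the URL-style input is an assumption about how the raw ids look, and `normalise_id` is a hypothetical helper name that does not exist in the project.

```python
def normalise_id(raw_id: str) -> str:
    """Return the last '/'-separated segment of an identifier."""
    return raw_id.split("/")[-1]


# Assumed example inputs: a URL-style OpenAlex id and a patent id without slashes
print(normalise_id("https://openalex.org/W4235494632"))
print(normalise_id("KR-101619828-B1"))
```

Ids that contain no slash pass through unchanged, which is why the same expression can be applied to both OpenAlex and patent records.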

## 3. Load data

In [370]:
# topic = "ai2"
# topic = "ar_vr"
# topic = "income"
# topic = "parenting"
topic = "mobile"

In [374]:
# Read the topic definitions; the context manager closes the file afterwards
with open(PATH_TO_TOPICS, "r") as f:
    topics = json.load(f)
topic_info = topics[topic]
topic_name = topic_info['name']

eval_df = (
    get_taxonomy_data(LABELS_TAXONOMY)
    .query("prediction == @topic_name")
    .drop(columns=["prediction"])
    .rename(columns={"answer": "prediction"})
    .assign(prediction=lambda df: df.prediction.map({"reject": "Not-relevant", "accept": "Relevant"}))    
)[['id', 'prediction', 'text']]

labels_df = (
    pd.DataFrame(load_jsonl(LABELS_DIR / f"taxonomy_cat_{topic}.jsonl"))
    .assign(id = lambda df: df['id'].apply(lambda x: x.split('/')[-1]))
    .query("id not in @eval_df.id")
    .rename(columns={"prediction": "labels"})
)[['id', 'labels', 'text']]

In [375]:
# labels_df

In [376]:
# eval_df.merge(labels_df, on="id", how="left").query("labels != prediction")[['id','prediction', 'labels', 'text_x']]

In [377]:
# Train-validation split (80/20)
labelled_text_training = labels_df.sample(frac=0.8, random_state=0)
labelled_text_validation = labels_df.drop(labelled_text_training.index)
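
The split above samples 80% of the rows and then drops those rows by index to obtain the remainder. A small sketch on toy data (hypothetical, not the project's labels) shows that the two subsets are disjoint and together cover every row, provided the index is unique:

```python
import pandas as pd

# Toy stand-in for labels_df (hypothetical data)
toy_labels = pd.DataFrame(
    {"id": [f"W{i}" for i in range(10)], "labels": [0, 1] * 5}
)

train = toy_labels.sample(frac=0.8, random_state=0)
validation = toy_labels.drop(train.index)

# The two subsets share no index values and together contain all rows
overlap = set(train.index) & set(validation.index)
print(len(train), len(validation), len(overlap))
```

Because `drop` works on index labels, a frame with duplicated index values would lose more rows than intended; resetting the index before sampling avoids that.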

In [378]:
# labelled_text_training = get_labelled_data_for_classifier(
#         set_type="train", path_from=PATH_FROM
# )
# labelled_text_validation = get_labelled_data_for_classifier(
#         set_type="validation", path_from=PATH_FROM
# )

In [379]:
# examples = get_examples()

## 4. Setting up training and validation sets

In [380]:
# Create training and validation sets
training_set = labelled_text_training.merge(embeddings_all, on="id", how="left")
validation_set = labelled_text_validation.merge(embeddings_all, on="id", how="left")
training_set = replace_binary_labels(training_set, replace_cat=["Relevant","Not-relevant"]).dropna(subset=["miniLM_384_vector"])
validation_set = replace_binary_labels(validation_set, replace_cat=["Relevant","Not-relevant"]).dropna(subset=["miniLM_384_vector"])

# Setting up the training and validation sets
X_train = training_set["miniLM_384_vector"].apply(pd.Series).values
X_val = validation_set["miniLM_384_vector"].apply(pd.Series).values

Y_train = training_set["labels"]
Y_val = validation_set["labels"]

In [381]:
len(X_train), len(X_val)

(799, 200)
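
The feature matrices above are built by expanding a column of vectors with `.apply(pd.Series)`. As a sketch of an alternative, `np.vstack` should produce the same array without creating an intermediate DataFrame, assuming every row holds a list or array of equal length. The toy frame below is hypothetical and uses 3 dimensions instead of 384.

```python
import numpy as np
import pandas as pd

# Toy frame with a 3-dimensional vector column (the notebook uses 384 dimensions)
toy = pd.DataFrame({"miniLM_384_vector": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]})

# Approach used in the notebook: expand each vector into columns
X_apply = toy["miniLM_384_vector"].apply(pd.Series).values

# Alternative: stack the vectors directly into a 2-D array
X_stack = np.vstack(toy["miniLM_384_vector"].to_list())

print(X_apply.shape, X_stack.shape)
```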

## 5. Training and evaluating the models

In [382]:
models_simple = ["log_regression", "knn", "random_forest", "sgd", "svm"]
if not save_model:
    models_all = {}
for model in models_simple:
    if wandb_run:
        # Initialise a wandb run
        run = wandb.init(
            project="ISS supervised ML",
            job_type="Binary classifier - base models",
            save_code=True,
            tags=["gpt-labelled", "all-MiniLM-L6-v2", model, "openalex/patents"],
        )
        # Add reference to this data in wandb
        wb.add_ref_to_data(
            run=run,
            name="binary_train_data_raw",
            description="Binary classifier training data",
            bucket=S3_BUCKET,
            filepath=f"{PATH_FROM}gpt_labelled_train.csv",
        )
        
    # Creating the classifier
    if model == "log_regression":
        classifier = LogisticRegression(penalty="l2", random_state=SEED)
    elif model == "knn":
        classifier = KNeighborsClassifier()
    elif model == "random_forest":
        classifier = RandomForestClassifier(random_state=SEED)
    elif model == "sgd":
        classifier = SGDClassifier(random_state=SEED)
    elif model == "svm":
        classifier = LinearSVC(random_state=SEED)

    # Fitting the model
    classifier.fit(X_train, Y_train)
    # Predicting on the validation set
    predictions = classifier.predict(X_val)

    # Creating metrics
    metrics = classification_utils.create_average_metrics(
        Y_val, predictions, average="binary"
    )
    logging.info(f"Metrics for {model} model:")
    logging.info(metrics)
    logging.info("------")

    if save_model:
        # Save model to S3
        S3.upload_obj(
            obj=classifier,
            bucket=S3_BUCKET,
            path_to=f"{S3_PATH}gpt_labelled_binary_classifier_{model}.pkl",
        )
    else:
        models_all[model] = classifier

    if wandb_run:
        # Log metrics
        wandb.run.summary["f1"] = metrics["f1"]
        wandb.run.summary["accuracy"] = metrics["accuracy"]
        wandb.run.summary["precision"] = metrics["precision"]
        wandb.run.summary["recall"] = metrics["recall"]

        # Adding reference to this model in wandb
        wb.add_ref_to_data(
            run=run,
            name=f"binary_classifier_{model}",
            description=f"{model} model trained on binary classifier training data",
            bucket=S3_BUCKET,
            filepath=f"{S3_PATH}gpt_labelled_binary_classifier_{model}.pkl",
        )

        # End the weights and biases run
        wandb.finish()

2024-02-29 17:49:13,262 - root - INFO - Metrics for log_regression model:
2024-02-29 17:49:13,272 - root - INFO - {'accuracy': 0.83, 'precision': 0.65, 'recall': 0.325, 'f1': 0.43333333333333335, 'hamming': 0.17, 'jaccard': 0.2765957446808511}
2024-02-29 17:49:13,277 - root - INFO - ------


2024-02-29 17:49:13,331 - root - INFO - Metrics for knn model:
2024-02-29 17:49:13,334 - root - INFO - {'accuracy': 0.76, 'precision': 0.4, 'recall': 0.4, 'f1': 0.4, 'hamming': 0.24, 'jaccard': 0.25}
2024-02-29 17:49:13,342 - root - INFO - ------
2024-02-29 17:49:14,042 - root - INFO - Metrics for random_forest model:
2024-02-29 17:49:14,042 - root - INFO - {'accuracy': 0.825, 'precision': 0.7777777777777778, 'recall': 0.175, 'f1': 0.2857142857142857, 'hamming': 0.175, 'jaccard': 0.16666666666666666}
2024-02-29 17:49:14,042 - root - INFO - ------
2024-02-29 17:49:14,059 - root - INFO - Metrics for sgd model:
2024-02-29 17:49:14,064 - root - INFO - {'accuracy': 0.735, 'precision': 0.3968253968253968, 'recall': 0.625, 'f1': 0.4854368932038835, 'hamming': 0.265, 'jaccard': 0.32051282051282054}
2024-02-29 17:49:14,068 - root - INFO - ------
2024-02-29 17:49:14,110 - root - INFO - Metrics for svm model:
2024-02-29 17:49:14,111 - root - INFO - {'accuracy': 0.785, 'precision': 0.4545454545454
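
To make the logged metrics easier to read, the sketch below recomputes them from confusion-matrix counts. It assumes that `create_average_metrics` follows the standard binary definitions (its implementation is not shown in this notebook), with "hamming" as the fraction of misclassified rows and "jaccard" as the Jaccard index of the positive class. The counts are inferred from the reported log_regression figures and the 200 validation rows; they are not printed anywhere in the notebook. `binary_metrics` is a hypothetical helper.

```python
def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Standard binary classification metrics from confusion-matrix counts."""
    total = tp + fp + fn + tn
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return {
        "accuracy": (tp + tn) / total,
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
        "hamming": (fp + fn) / total,
        "jaccard": tp / (tp + fp + fn),
    }


# Counts inferred from the log_regression metrics above (assumption, not logged)
metrics_check = binary_metrics(tp=13, fp=7, fn=27, tn=153)
print(metrics_check)
```

Under these inferred counts, 40 of the 200 validation rows are positive, so a classifier that always predicts the negative class would already reach 0.80 accuracy. Recall and F1 are therefore the more informative figures when comparing the models.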



## 6. Trialing some examples

In [341]:
# Use the annotated evaluation set as test examples, with binary labels
examples = eval_df.rename(columns={"prediction": "labels"})
examples = replace_binary_labels(examples, replace_cat=["Relevant","Not-relevant"])
examples.tail()

Unnamed: 0,id,labels,text
449,W4248651340,1,Class Inequality in Parental Childcare Time: Evidence from Synthetic Couples in the ATUS. The time that parents spend teaching and playing with their young children has important consequences for ...
455,W4235825001,0,"He Tatau Pounamu. Considerations for an early childhood peace curriculum focusing on criticality, indigeneity, and an ethic of care, in Aotearoa New Zealand. This article discusses some of the phi..."
459,W2946468075,0,The prevalence of premature thelarche in girls and gynecomastia in boys and the associated factors in children in Southern China. To investigate the prevalence and risk factors of premature thelar...
460,W4311597343,1,A genealogical study of the emergence of kindergartens in Iran: an intersectional approach. There are histories describing in detail the development of early childhood education (ECE) around the w...
461,W3109647903,1,Relationship between mother’s knowledge and behaviour with oral health status of early childhood. Introduction: Early childhood period has a high caries risk that needs special attention from pare...


In [361]:
testing_examples_simple(list(examples.text),list(examples.labels),models_all["log_regression"])

2024-02-29 17:46:57,470 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: all-MiniLM-L6-v2
2024-02-29 17:46:57,617 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu


Batches: 100%|██████████| 2/2 [00:01<00:00,  1.86it/s]


(array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0]),
 {'accuracy': 0.56,
  'precision': 0.6,
  'recall': 0.13043478260869565,
  'f1': 0.21428571428571427,
  'hamming': 0.44,
  'jaccard': 0.12})

In [77]:
# testing_examples_simple(['this text is about AI', 'this text is about carrots'],[1, 0],models_all["log_regression"])

## 7. Inference

In [362]:
def inference_simple(examples: pd.DataFrame, classifier) -> pd.DataFrame:
    """Run a trained classifier over pre-computed embeddings

    Args:
        examples (pd.DataFrame): Rows with a "miniLM_384_vector" column
        classifier (sklearn classifier): Trained classifier

    Returns:
        pd.DataFrame: Copy of examples with "labels" (predicted class) and
            "prob_relevant" (probability of class 1, or None if the
            classifier has no predict_proba)
    """
    X_test = examples["miniLM_384_vector"].apply(pd.Series).values
    predictions = classifier.predict(X_test)
    try:
        probs = classifier.predict_proba(X_test)[:, 1]
    except AttributeError:
        # Some classifiers (e.g. LinearSVC) do not implement predict_proba
        probs = None
    return examples.assign(labels=predictions, prob_relevant=probs)
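
The `try`/`except` in `inference_simple` is needed because not every scikit-learn classifier exposes `predict_proba`. The sketch below, on a tiny synthetic dataset (hypothetical), checks for the method explicitly and falls back to `decision_function`, which returns a signed margin rather than a probability. This is one possible approach, not the one used in the notebook.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

# Tiny synthetic dataset (hypothetical) with two well-separated classes
X = np.array([[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 1.1]])
y = np.array([0, 0, 1, 1])

for clf in (LogisticRegression(), LinearSVC()):
    clf.fit(X, y)
    if hasattr(clf, "predict_proba"):
        scores = clf.predict_proba(X)[:, 1]  # probability of class 1
    else:
        scores = clf.decision_function(X)  # signed margin, not a probability
    print(type(clf).__name__, hasattr(clf, "predict_proba"), scores.round(2))
```

Margins and probabilities are on different scales, so they should not be averaged together in an ensemble without calibration.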

In [300]:
ENRICHED_DATA_DIR = PROJECT_DIR / 'outputs/enrichments'
PATH_TO_DATASET = ENRICHED_DATA_DIR / 'openalex_patents_relevance_labels_only_relevant.csv'
relevant_df = (
    pd.read_csv(PATH_TO_DATASET)
    .assign(id = lambda df: df['id'].apply(lambda x: x.split('/')[-1]))
    .merge(embeddings_all.reset_index(), on="id", how="left")
)

In [363]:
results_model = []
for model in models_all:
    results_df = (
        inference_simple(relevant_df, models_all[model])
        .rename(columns={"labels": "prediction"})
        .assign(model = model)
    )[['id', 'prediction', "prob_relevant", "model"]]
    results_model.append(results_df)
results_model = pd.concat(results_model, axis=0)


In [364]:
results_model.groupby(["id"]).prediction.mean().value_counts()

0.0    34863
0.2     7441
0.4     3164
0.6     2137
0.8     2049
1.0     1580
Name: prediction, dtype: int64

In [365]:
ensemble_df = (
    results_model
    .groupby(["id"])
    .agg(
        prediction = ("prediction", "mean"),
        prob_relevant = ("prob_relevant", "mean"),
    )
)
ensemble_df.head(2)

id,prediction,prob_relevant
AR-098372-A1,0.0,0.201884
AR-101703-A4,0.0,0.120133


In [366]:
results_df = (
    relevant_df
    .merge(ensemble_df, on="id", how="left")
    .drop(columns=["miniLM_384_vector", "Unnamed: 0"])
)

In [367]:
len(results_df.query("prob_relevant > 0.5"))

5059

In [368]:
len(results_df.query("prediction >= 0.8"))

3629
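
With five models, the mean of the binary predictions can only take the values 0.0, 0.2, 0.4, 0.6, 0.8 and 1.0, so `prediction >= 0.8` selects documents on which at least four of the five models agree (2049 + 1580 = 3629 in the counts above). The sketch below illustrates the same rule on hypothetical per-model predictions, counting votes instead of comparing a float mean against a threshold.

```python
import pandas as pd

# Hypothetical per-model predictions for two documents and five models
toy_results = pd.DataFrame(
    {
        "id": ["A"] * 5 + ["B"] * 5,
        "model": ["log_regression", "knn", "random_forest", "sgd", "svm"] * 2,
        "prediction": [1, 1, 1, 1, 0, 0, 1, 0, 0, 0],
    }
)

# Sum = number of models voting "relevant"; mean = share of models
votes = toy_results.groupby("id")["prediction"].agg(["sum", "mean"])

# Keep documents with at least four positive votes out of five
selected = votes[votes["sum"] >= 4].index.tolist()
print(votes)
print(selected)
```

Working with integer vote counts keeps the selection rule explicit and remains valid if the number of models changes, as long as the required count is updated accordingly.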

In [384]:
results_df.query("prediction >= 0.8").sort_values("prob_relevant", ascending=False).head(10)

Unnamed: 0,id,text,predictions,source,prediction,prob_relevant
29523,W4235494632,Digital Technology in Kindergarten. This chapter examines the literature surrounding digital technologies within kindergarten. It highlights the ways in which mobile devices and smart gadgets are ...,1,openalex,1.0,0.912876
25087,W2734818634,Digital Technology in Kindergarten. This chapter examines the literature surrounding digital technologies within kindergarten. It highlights the ways in which mobile devices and smart gadgets are ...,1,openalex,1.0,0.912876
3549,KR-20200111350-A,Online to offline language education method using blockable typo convergence design technology. According to the O2O language learning education method using the BTCD technology according to an em...,1,patents,1.0,0.904673
44641,W2909408616,Exploring visual prompts for communicating directional awareness to kindergarten children. Although a myriad of educational applications using tablets and multi-touch technology for kindergarten c...,1,openalex,1.0,0.903075
19438,W2288892440,Teacher knowledge for using technology to foster early literacy: A literature review. A literature review was conducted to describe the knowledge and skills teachers need for using technology to f...,1,openalex,1.0,0.892191
11491,KR-20150091638-A,"Smart-device with app possible colloquial learning of Infant.. The present invention relates to a smart device for learning a daily life language, and more particularly, to a smart device for lear...",1,patents,1.0,0.88
2384,KR-101619828-B1,"Smart-device with app possible colloquial learning of Infant.. The present invention relates to a smart device for learning a daily life language, and more particularly, to a smart device for lear...",1,patents,1.0,0.88
38615,W4285196572,Parents' Perception of the Use of Digital Book Reading App in Improving English Skills for Early Childhood. Literature studies have shown benefits of using technology and multimedia to improve chi...,1,openalex,1.0,0.877434
11599,KR-20220170213-A,"Smart mental care system for childcare and education. The present invention relates to a smart childcare, education, and mental care system for infants and children, and more particularly, a speak...",1,patents,1.0,0.876326
29220,W3020374818,"Screens or no screens: Understanding young children's use of digital technologies. Children are accessing digital-technologies at younger ages and at an increasing rate, especially in the home env...",1,openalex,1.0,0.874945


In [383]:
len(results_df)

51234

In [263]:
# pd.set_option('display.max_colwidth', 200)
# results_df.query("predicted == 1").groupby("source").size()

In [138]:
# results_df.predicted.value_counts()

In [385]:
(
    results_df[['id', 'prediction', 'prob_relevant']]
    .assign(topic=topic)
    .to_csv(ENRICHED_DATA_DIR / f'taxonomy_cat/taxonomy_cat_predictions_{topic}.csv', index=False)
)