# Classifier prototype

This prototype uses SetFit and only the data available from the "text blocks or sentences" task. As a result, much of the training data consists of single sentences rather than full text blocks.


In [31]:
import sys

!{sys.executable} -m pip install argilla
!{sys.executable} -m pip install setfit scikit-learn scikit-multilearn


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Collecting scikit-multilearn
  Downloading scikit_multilearn-0.2.0-py3-none-any.whl (89 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.4/89.4 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scikit-multilearn
Successfully installed scikit-multilearn-0.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [91]:
import os
import random

from datasets import Dataset
from dotenv import load_dotenv, find_dotenv
import argilla as rg
from tqdm.auto import tqdm
import numpy as np
import evaluate
from setfit import SetFitModel, SetFitTrainer
from sklearn.preprocessing import MultiLabelBinarizer
from skmultilearn.model_selection import iterative_train_test_split
import pandas as pd

load_dotenv(find_dotenv(), override=True)

True

In [3]:
# Config

DATASET_NAME = "sectors-sentence-or-text-block"

In [5]:
# User management is done at a workspace level
rg.init(
    workspace="gst",
    api_key=os.environ["ARGILLA_API_KEY"],
)

dataset = rg.load(DATASET_NAME).to_datasets()
dataset_df = dataset.to_pandas()

## 1. Data analysis

In [14]:
dataset_df = dataset_df.dropna(subset=["annotation"])
dataset_df["annotation"].explode().value_counts()

agriculture, forestry and other land use    44
energy                                      42
industry                                    23
transport                                   19
insurance & financial services              17
water services                              15
buildings                                    6
fisheries & aquaculture                      4
health services                              4
tourism                                      4
Name: annotation, dtype: int64

In [16]:
# assign() returns a new DataFrame, which avoids pandas' SettingWithCopyWarning
dataset_df = dataset_df.assign(
    has_labels=dataset_df["annotation"].apply(lambda labels: len(labels) > 0)
)
dataset_df.groupby("has_labels")["text"].count()

has_labels
False    283
True     121
Name: text, dtype: int64

## 2. Train-test split & format data

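Ordinary stratified splitting does not carry over directly to the multilabel case, so the split below uses iterative stratification from scikit-multilearn (`iterative_train_test_split`), which tries to keep each label's share similar in the train and test sets.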

In [98]:
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(dataset_df["annotation"].values)
# iterative_train_test_split expects a 2D feature array, so keep the texts as a single column
X = dataset_df["text"].values.reshape(-1, 1)

X_train, y_train, X_test, y_test = iterative_train_test_split(X, y, test_size=0.3)
X_train_1d = np.array([i[0] for i in X_train])
X_test_1d = np.array([i[0] for i in X_test])
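
One way to sanity-check a multilabel split is to compare how often each label occurs in the two partitions. The sketch below is illustrative only: `label_share` is a hypothetical helper and the toy matrices are made up, standing in for the `y_train` and `y_test` indicator matrices produced above.

```python
import numpy as np


def label_share(y):
    """Fraction of rows that carry each label in a binary indicator matrix."""
    return np.asarray(y).mean(axis=0)


# Made-up indicator matrices (rows = examples, columns = labels).
y_train_toy = np.array([[1, 0], [1, 1], [0, 0], [1, 0]])
y_test_toy = np.array([[1, 0], [0, 1]])

# Absolute gap in label frequency between the two partitions, per label.
gap = np.abs(label_share(y_train_toy) - label_share(y_test_toy))
print(gap)
```

With the real matrices, large gaps for a label would suggest that the split did not balance that label well, which is plausible for labels with only a handful of examples.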

## 3. Train setfit classifier

In [99]:
train_dataset = Dataset.from_dict({"text": X_train_1d, "label": y_train})
test_dataset = Dataset.from_dict({"text": X_test_1d, "label": y_test})

In [100]:
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    multi_target_strategy="multi-output",  # one-vs-rest; multi-output; classifier-chain
)

multilabel_f1_metric = evaluate.load("f1", "multilabel")
multilabel_precision_metric = evaluate.load("precision", "multilabel")
multilabel_recall_metric = evaluate.load("recall", "multilabel")
multilabel_accuracy_metric = evaluate.load("accuracy", "multilabel")


# micro metrics
def compute_metrics(y_pred, y_test):
    return {
        "f1": multilabel_f1_metric.compute(
            predictions=y_pred, references=y_test, average="micro"
        )["f1"],
        "precision": multilabel_precision_metric.compute(
            predictions=y_pred, references=y_test, average="micro"
        )["precision"],
        "recall": multilabel_recall_metric.compute(
            predictions=y_pred, references=y_test, average="micro"
        )["recall"],
        "accuracy": multilabel_accuracy_metric.compute(
            predictions=y_pred, references=y_test
        )["accuracy"],
    }

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [101]:
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    metric=compute_metrics,
    num_iterations=10,
)

# TODO: optimise num_iterations parameter using optuna: https://github.com/huggingface/setfit#running-hyperparameter-search

trainer.train()
metrics = trainer.evaluate()
print(metrics)

***** Running training *****
  Num examples = 4060
  Num epochs = 1
  Total optimization steps = 254
  Total train batch size = 16

{'f1': 0.6436781609195402, 'precision': 0.8484848484848485, 'recall': 0.5185185185185185, 'accuracy': 0.8442622950819673}
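
To make the micro-averaged scores easier to interpret, the sketch below recomputes them from pooled counts over all label slots. The counts (28 true positives, 5 false positives, 26 false negatives) are not reported by the notebook; they are an assumption, chosen because they reproduce the printed precision, recall and F1 exactly. The accuracy figure appears to be exact-match accuracy over whole label sets, since 103/122 matches both the printed value and the 19/122 incorrect rows reported further down.

```python
def micro_scores(tp, fp, fn):
    """Micro-averaged precision, recall and F1 from counts pooled over all labels."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


# Assumed counts, inferred from the printed metrics rather than taken from the notebook.
scores = micro_scores(tp=28, fp=5, fn=26)
print(scores)

# Exact-match accuracy: share of test rows whose full label set is predicted correctly.
exact_match_accuracy = (122 - 19) / 122
print(exact_match_accuracy)
```

Read this way, the model rarely predicts a wrong label (high precision) but misses roughly half of the true labels (low recall).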


In [108]:
def compute_perclass_metrics(y_pred, y_test):
    return {
        "f1": multilabel_f1_metric.compute(
            predictions=y_pred, references=y_test, average=None
        ),
        "precision": multilabel_precision_metric.compute(
            predictions=y_pred, references=y_test, average=None
        ),
        "recall": multilabel_recall_metric.compute(
            predictions=y_pred, references=y_test, average=None
        ),
        "accuracy": multilabel_accuracy_metric.compute(
            predictions=y_pred, references=y_test
        ),
    }


trainer.metric = compute_perclass_metrics

print(mlb.classes_)
print(trainer.evaluate())

***** Running evaluation *****


['agriculture, forestry and other land use' 'buildings' 'energy'
 'fisheries & aquaculture' 'health services' 'industry'
 'insurance & financial services' 'tourism' 'transport' 'water services']
{'f1': {'f1': array([0.63157895, 0.        , 0.88888889, 0.        , 0.        ,
       0.33333333, 0.57142857, 0.        , 0.90909091, 0.33333333])}, 'precision': {'precision': array([1.        , 0.        , 0.85714286, 0.        , 0.        ,
       0.4       , 1.        , 0.        , 1.        , 1.        ])}, 'recall': {'recall': array([0.46153846, 0.        , 0.92307692, 0.        , 0.        ,
       0.28571429, 0.4       , 0.        , 0.83333333, 0.2       ])}, 'accuracy': {'accuracy': 0.8442622950819673}}
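
The per-class arrays are easier to read when paired with the class names in `mlb.classes_`. The sketch below copies the values printed above into plain lists (so it runs without the trained model) and sorts the classes by F1; the variable names are illustrative.

```python
classes = [
    "agriculture, forestry and other land use", "buildings", "energy",
    "fisheries & aquaculture", "health services", "industry",
    "insurance & financial services", "tourism", "transport", "water services",
]
# Values copied from the evaluation output above.
f1 = [0.63157895, 0.0, 0.88888889, 0.0, 0.0, 0.33333333, 0.57142857, 0.0, 0.90909091, 0.33333333]
precision = [1.0, 0.0, 0.85714286, 0.0, 0.0, 0.4, 1.0, 0.0, 1.0, 1.0]
recall = [0.46153846, 0.0, 0.92307692, 0.0, 0.0, 0.28571429, 0.4, 0.0, 0.83333333, 0.2]

# Sort classes from best to worst F1 and print one line per class.
rows = sorted(zip(classes, f1, precision, recall), key=lambda row: row[1], reverse=True)
for name, f, p, r in rows:
    print(f"{name:<42} f1={f:.2f} precision={p:.2f} recall={r:.2f}")

# Classes the model never predicts correctly on the test set.
zero_f1 = [name for name, f, _, _ in rows if f == 0.0]
print(zero_f1)
```

The four classes with an F1 of zero (buildings, fisheries & aquaculture, health services, tourism) are also the four with the fewest annotations in the data analysis section (6, 4, 4 and 4).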




In [102]:
y_pred = model.predict(X_test_1d, as_numpy=True)

df_test = pd.DataFrame(
    {
        "text": X_test_1d,
        "y_true": mlb.inverse_transform(y_test),
        "y_pred": mlb.inverse_transform(y_pred),
    }
)
print(
    f"incorrect: {len(df_test[df_test['y_true'] != df_test['y_pred']])}/{len(df_test)}"
)

# max_colwidth=None shows the full text; the older -1 sentinel is deprecated in pandas
with pd.option_context("display.max_colwidth", None):
    display(df_test[df_test["y_true"] != df_test["y_pred"]])

incorrect: 19/122


index,text,y_true,y_pred
1,"As such, the GMW can be used as a tool for countries that do not yet have their mangrove monitoring systems to design, implement and track the progress of their national climate commitments and identify opportunities to include mangroves in the next round of nationally determined contributions (NDCs), in support of the Paris Agreement's long-term goals for mitigation and adaptation.","(agriculture, forestry and other land use, fisheries & aquaculture)",()
2,"This has led to energy production facilities, in this case thermal power plants, reducing GHG emissions (greenhouse gases) estimated at 3.565Gg of","(energy,)","(energy, industry)"
3,"56. The Party also reported information on the results achieved from the implementation of its mitigation actions, as estimated outcomes, mitigation co-benefits and emission reductions, to the extent possible. Details on the results achieved for some of the mitigation actions, across all sectors, were not reported in the BUR. During the technical analysis, the Party clarified that it faced constraints and challenges in reporting results for some of the mitigation actions in the energy, IPPU, AFOLU and waste sectors. The TTE noted that specifying these constraints for the relevant mitigation actions could facilitate a better understanding of the information reported.","(agriculture, forestry and other land use, energy, water services)","(energy,)"
4,(1) Establishing an online platform that enables easy access to the underlying data and information used for calculating the REDD+ FRLs and results (see paras. 29-30 above).,"(agriculture, forestry and other land use,)",()
6,"Even though the government has put a limit on age for importation of used vehicles, the emission factor for fuel used in transport sector is as shown in Table 2.23 are anticipated to remain applicable in the foreseeable future.","(transport,)","(energy, transport)"
7,Finance,"(insurance & financial services,)",()
9,"The purpose of this part of the Action Table is to highlight specific and promotable actions to deliver the vision for the Industry thematic area, with a sector-based\napproach focused on ICT and Mobile.","(industry,)",()
10,"41. Mitigation actions planned in the energy sector consist of eight measures, mainly on renewable energy, with one on substitution of fuel. Six mitigation measures in the AFOLU sector focus on reforestation and rehabilitation of forests and pastures and two of these are of national priority as they contribute to the preservation of natural resources. One project relates to the IPPU sector. The measures for the period 2016-2020 in the energy and AFOLU sectors cover 49.5 per cent and 46.9 per cent of the total national GHG mitigation potential, respectively. The TTE notes a discrepancy in the information provided in the BUR (page 64, table 26) and annex 2 to the BUR concerning the qualitative assumptions of the impact of mitigation action in the energy sector (Diffusion 30000 LED lamps). The TTE notes that the transparency of future reports could be improved by strengthening the QA/QC procedure in relation to the information reported.","(agriculture, forestry and other land use, energy, industry)","(energy,)"
14,"Figure 5.6- Energy industries, actual and projected GHG emissions (Mt CO₂ eq.)","(energy,)","(energy, industry)"
15,"As for water borne and food borne diseases, diarrheal diseases are directly influenced by climate change\ndue to the occurrence and the survival of bacterial agents, toxic algal blooms in water, and viral pathogens,","(agriculture, forestry and other land use, health services, water services)",()
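
A rough way to categorise these errors is to split each mismatch into labels the model missed and labels it added. The sketch below is a minimal illustration: `diff_labels` is a hypothetical helper, and the two examples are the label tuples from rows 1 and 2 of the table above.

```python
def diff_labels(y_true, y_pred):
    """Split a mismatch into labels that were missed and labels that were wrongly added."""
    true_set, pred_set = set(y_true), set(y_pred)
    return {
        "missed": sorted(true_set - pred_set),
        "spurious": sorted(pred_set - true_set),
    }


# (y_true, y_pred) pairs taken from rows 1 and 2 of the table above.
examples = [
    (("agriculture, forestry and other land use", "fisheries & aquaculture"), ()),
    (("energy",), ("energy", "industry")),
]

diffs = [diff_labels(y_true, y_pred) for y_true, y_pred in examples]
for diff in diffs:
    print(diff)
```

In the rows shown, five of the ten errors are empty predictions and two more drop some of the true labels, while the remaining three add a label to an otherwise correct `energy` or `transport` prediction. That pattern is in line with the high precision and low recall reported earlier.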
