# Classifier prototype

Uses SetFit, trained only on the data available from the text-blocks or sentences annotation task. As a result, much of the training data consists of sentences rather than full text blocks.


In [31]:
import sys

!{sys.executable} -m pip install argilla
!{sys.executable} -m pip install setfit sklearn scikit-multilearn


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Collecting scikit-multilearn
  Downloading scikit_multilearn-0.2.0-py3-none-any.whl (89 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.4/89.4 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scikit-multilearn
Successfully installed scikit-multilearn-0.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [20]:
import os
import random
from typing import Sequence

from datasets import Dataset
from dotenv import load_dotenv, find_dotenv
import argilla as rg
from tqdm.auto import tqdm
import numpy as np
import evaluate
from setfit import SetFitModel, SetFitTrainer
from sklearn.preprocessing import MultiLabelBinarizer
from skmultilearn.model_selection import iterative_train_test_split
import evaluate
import pandas as pd
from cpr_data_access.models import Dataset as CPRDataset
from cpr_data_access.models import BaseDocument, TextBlock

# Load variables from the .env file located by find_dotenv() into os.environ.
# override=True lets .env values replace variables already set in the environment.
# Later cells read ARGILLA_API_KEY and DOCS_DIR_GST from os.environ.
load_dotenv(find_dotenv(), override=True)

True

In [2]:
# Config

# Name of the Argilla dataset holding the sector annotations (passed to rg.load).
DATASET_NAME = "sector-text-classifier"

In [3]:
# Argilla scopes user management to a workspace, so connect to the "gst" workspace.
rg.init(api_key=os.environ["ARGILLA_API_KEY"], workspace="gst")

# Export the annotated records as a Hugging Face dataset, then as a pandas DataFrame.
dataset = rg.load(DATASET_NAME).to_datasets()
dataset_df = dataset.to_pandas()

## 1. Data analysis

In [4]:
# Drop records without an annotation, then count how often each label is used.
# Each annotation is a list of labels, so explode to one row per label before counting.
dataset_df = dataset_df.dropna(subset=["annotation"])
dataset_df.explode("annotation")["annotation"].value_counts()

energy                                      44
tourism                                     40
agriculture, forestry and other land use    38
water services                              36
industry                                    33
health services                             32
transport                                   28
insurance & financial services              24
buildings                                   23
fisheries & aquaculture                     23
Name: annotation, dtype: int64

In [5]:
# Flag records whose annotation list is non-empty, then count records per flag value.
dataset_df["has_labels"] = dataset_df["annotation"].map(len) > 0
dataset_df.groupby("has_labels")["text"].count()

has_labels
False    108
True     208
Name: text, dtype: int64

## 2. Train-test split & format data

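Stratified train-test split using iterative stratification (`skmultilearn.model_selection.iterative_train_test_split`). This method is designed for the multilabel case, where standard single-label stratification does not apply.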

In [6]:
# iterative_train_test_split takes no random_state argument; its tie-breaking is assumed to use
# numpy's global RandomState (NOTE(review): per skmultilearn's implementation — verify on
# upgrade). Seeding it here makes the split reproducible across kernel restarts.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)

# One-hot encode the label lists; mlb.classes_ fixes the label column order used below.
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(dataset_df["annotation"].values)

# The splitter is given the texts as a 2-D (n_samples, 1) matrix.
X = dataset_df["text"].to_numpy().reshape(-1, 1)

X_train, y_train, X_test, y_test = iterative_train_test_split(X, y, test_size=0.3)

# Flatten back to 1-D string arrays for the Hugging Face Dataset / SetFit APIs.
X_train_1d = np.array([row[0] for row in X_train])
X_test_1d = np.array([row[0] for row in X_test])

## 3. Train setfit classifier

In [7]:
# Wrap each split as a Hugging Face Dataset with "text" and "label" columns.
train_dataset, test_dataset = (
    Dataset.from_dict({"text": texts, "label": labels})
    for texts, labels in ((X_train_1d, y_train), (X_test_1d, y_test))
)

In [8]:
# Pretrained sentence-transformer body; the classification head starts from random weights
# and is learned during training.
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    multi_target_strategy="multi-output",  # one-vs-rest; multi-output; classifier-chain
)

# Multilabel configurations of the `evaluate` metrics, loaded in a fixed order.
(
    multilabel_f1_metric,
    multilabel_precision_metric,
    multilabel_recall_metric,
    multilabel_accuracy_metric,
) = (
    evaluate.load(metric_name, "multilabel")
    for metric_name in ("f1", "precision", "recall", "accuracy")
)


# micro metrics
def compute_metrics(y_pred, y_test):
    """Micro-averaged multilabel F1, precision and recall, plus accuracy.

    Args:
        y_pred: predicted labels (presumably in the same one-hot layout as ``y_test``).
        y_test: reference labels.

    Returns:
        Dict with float scores under "f1", "precision", "recall" and "accuracy".
    """
    micro_averaged = (
        ("f1", multilabel_f1_metric),
        ("precision", multilabel_precision_metric),
        ("recall", multilabel_recall_metric),
    )
    scores = {
        name: metric.compute(predictions=y_pred, references=y_test, average="micro")[name]
        for name, metric in micro_averaged
    }
    # The accuracy metric is called without an averaging argument.
    scores["accuracy"] = multilabel_accuracy_metric.compute(
        predictions=y_pred, references=y_test
    )["accuracy"]
    return scores

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [9]:
# num_iterations=1 is chosen for demo purposes to make training quick.
# TODO: optimise num_iterations parameter using optuna: https://github.com/huggingface/setfit#running-hyperparameter-search
trainer = SetFitTrainer(
    model=model,
    metric=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    num_iterations=1,
)

trainer.train()

# Micro-averaged scores on the held-out split.
metrics = trainer.evaluate()
print(metrics)

Generating Training Pairs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 32.93it/s]
***** Running training *****
  Num examples = 446
  Num epochs = 1
  Total optimization steps = 28
  Total train batch size = 16
Epoch:   0%|                                                                                                                                                                                                                                                                            | 0/1 [00:00<?, ?it/s]
Iteration:   0%|                                                                                                                                                                                                                                                         

{'f1': 0.7225806451612903, 'precision': 0.9491525423728814, 'recall': 0.5833333333333334, 'accuracy': 0.6842105263157895}


In [67]:
def compute_perclass_metrics(y_pred, y_test):
    """Sample-averaged and per-class multilabel metrics.

    Args:
        y_pred: predicted labels (presumably in the same one-hot layout as ``y_test``).
        y_test: reference labels, one column per entry of ``mlb.classes_``.

    Returns:
        ``{"sample_level": {"f1", "precision", "recall", "accuracy"},
        "per_class": {class_name: {"f1", "precision", "recall"}}}``, where class names
        come from ``mlb.classes_`` in column order.
    """

    def score(metric, key, **averaging):
        # Run one metric and pull its named result out of the returned dict.
        return metric.compute(predictions=y_pred, references=y_test, **averaging)[key]

    sample_metrics = {
        "f1": score(multilabel_f1_metric, "f1", average="samples"),
        "precision": score(multilabel_precision_metric, "precision", average="samples"),
        "recall": score(multilabel_recall_metric, "recall", average="samples"),
        # The accuracy metric is called without an averaging argument.
        "accuracy": score(multilabel_accuracy_metric, "accuracy"),
    }

    # average=None returns one score per label column, indexed like mlb.classes_.
    per_class_scores = {
        "f1": score(multilabel_f1_metric, "f1", average=None),
        "precision": score(multilabel_precision_metric, "precision", average=None),
        "recall": score(multilabel_recall_metric, "recall", average=None),
    }

    per_class_metrics_by_class = {
        class_name: {
            metric_name: class_scores[idx]
            for metric_name, class_scores in per_class_scores.items()
        }
        for idx, class_name in enumerate(mlb.classes_)
    }

    return {"sample_level": sample_metrics, "per_class": per_class_metrics_by_class}


# Re-evaluate the already-trained model with the sample-level + per-class metric function.
trainer.metric = compute_perclass_metrics

metrics = trainer.evaluate()
display(metrics)

***** Running evaluation *****
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'sample_level': {'f1': 0.44210526315789467,
  'precision': 0.48947368421052634,
  'recall': 0.424561403508772,
  'accuracy': 0.6842105263157895},
 'per_class': {'agriculture, forestry and other land use': {'f1': 0.8421052631578948,
   'precision': 1.0,
   'recall': 0.7272727272727273},
  'buildings': {'f1': 0.8333333333333333,
   'precision': 1.0,
   'recall': 0.7142857142857143},
  'energy': {'f1': 0.6666666666666667,
   'precision': 0.875,
   'recall': 0.5384615384615384},
  'fisheries & aquaculture': {'f1': 0.8333333333333333,
   'precision': 1.0,
   'recall': 0.7142857142857143},
  'health services': {'f1': 0.7499999999999999,
   'precision': 1.0,
   'recall': 0.6},
  'industry': {'f1': 0.4285714285714285, 'precision': 0.75, 'recall': 0.3},
  'insurance & financial services': {'f1': 0.6,
   'precision': 1.0,
   'recall': 0.42857142857142855},
  'tourism': {'f1': 0.9090909090909091,
   'precision': 1.0,
   'recall': 0.8333333333333334},
  'transport': {'f1': 0.6666666666666666, 'pr

In [None]:
# `metrics` now holds the nested result of compute_perclass_metrics, so the accuracy score
# lives under "sample_level"; there is no top-level "accuracy" key any more.
metrics["sample_level"]["accuracy"]

In [14]:
y_pred = model.predict(X_test_1d, as_numpy=True)

# Decode one-hot rows back into tuples of class names so mistakes are readable.
df_test = pd.DataFrame(
    {
        "text": X_test_1d,
        "y_true": mlb.inverse_transform(y_test),
        "y_pred": mlb.inverse_transform(y_pred),
    }
)

# Exact-match comparison: a row is incorrect unless its whole label tuple matches.
is_incorrect = df_test["y_true"] != df_test["y_pred"]
print(f"incorrect: {is_incorrect.sum()}/{len(df_test)}")

# None means "no column-width limit". The previous -1 is deprecated since pandas 1.0
# and disallowed by later pandas versions.
with pd.option_context("display.max_colwidth", None):
    display(df_test[is_incorrect])

incorrect: 30/95


Unnamed: 0,text,y_true,y_pred
5,"Over 7,000 businesses and 500 of the biggest investors have joined the UN Race to Zero, and financial leaders managing $130 trillion in assets are committed to 1.5°C.Over 1800 businesses have pledged to become net zero by 2030. SMEs are also rising to the climate change challenge in a number of ways. An increasing amount of initiatives and commitments registered on the NAZCA portal and the growing participation of business and industry representatives at the annual Conference of the Parties (COP) as well as intersessional technical meetings demonstrates the ever-sharpening focus of business on climate change action.","(industry, insurance & financial services)",()
7,Ocean acidification and other changes in high-latitude oceans are also a cause for concern for Iceland because of the importance of the fishing industry and the ocean in general in the Icelandic economy.,"(fisheries & aquaculture, industry)","(fisheries & aquaculture,)"
8,"The Declaration brings together industry, governments and multipliers to work at sectoral level and signatories increased twofold since the launch in November 2021. UNWTO is leading the implementation of the Glasgow Declaration, which was endorsed by its Executive Council in June 2022 in Decision CE/116.","(industry,)",()
12,"The Law sets requirement to the Cabinet of Ministers to regulate the criteria for determination and managing of highly vulnerable territories with increased requirements for the protection of water and soill. Law on Pollution also classifying polluting activities into Categories A, B, and C, considering the quantity and effect or the risk of pollution caused to human health and the environment. In agriculture sector polluting activities requiring a Category A permit are farms for the intensive rearing of pigs and poultry with more than 40 000 places for poultry or with more than 2 000 places for production pigs with weight over 30 kg (with more than 750 places for sows). These farms shall apply the best available techniques to prevent pollution.","(agriculture, forestry and other land use, water services)","(agriculture, forestry and other land use,)"
17,(1) Establishing an online platform that enables easy access to the underlying data and information used for calculating the REDD+ FRLs and results (see paras. 29-30 above).,"(agriculture, forestry and other land use,)",()
19,"Create an insurance fund f for the agricultural sector during emergency situations and in the context of climate change, improve existing and construct new storage facilities for crop and livestock products;","(agriculture, forestry and other land use, insurance & financial services)","(agriculture, forestry and other land use,)"
21,"The health sub-sector will also be affected by climate change. Indeed, the high temperatures and the increase in precipitation would lead to the proliferation of vector-borne diseases such as malaria, meningitis, diarrhea, water-borne diseases such as cholera, respiratory diseases such as rhinitis and sinusitis and infectious diseases.","(health services,)","(health services, water services)"
22,"30. Non-economic losses refer to a broad range of losses that are not financially quantifiable or commonly traded in markets and may impact individuals, society or the environment."" Parties mandated in 2012 a technical paper on non-economic losses,"" which describes the concept and brought into view eight types of such losses, which were emerging in three areas: loss of life, health or mobility (incidence of direct loss on individuals); loss of territory, cultural heritage, indigenous or local knowledge, or societal or cultural identity (society); and loss of biodiversity or ecosystem services (environment).","(health services,)",()
23,5) Improve Shipping,"(transport,)",()
24,"The Climate Change Law envisages the Ministry of Tourism and Environment, and the NEA having a more active role in the coordination and compilation work regarding the GHG inventory. As such, the knowledge of the application of these tools will need to be transferred to these organisations. There is currently no system used to collate information related to these activities and Albania envisages the implementation of an MRV tool to support this collation and tracking of information.","(tourism,)",()


## 4. Predict on a sample of unlabelled text blocks

In [18]:
DOC_LIMIT = 500  # maximum number of documents loaded from DOCS_DIR_GST
TEXT_BLOCKS_PER_DOCUMENT = 1  # number of text blocks sampled from each document

# Note: this rebinds `dataset` (previously the Argilla export) to the English-only
# CPR document dataset.
dataset = CPRDataset(BaseDocument).load_from_local(
    os.environ["DOCS_DIR_GST"], limit=DOC_LIMIT
).filter_by_language("en")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:31<00:00, 15.83it/s]


In [31]:
# Dedicated, seeded RNG so the sampled text blocks are reproducible across kernel restarts
# without touching the global `random` module state.
TEXT_BLOCK_SAMPLE_SEED = 42
text_block_rng = random.Random(TEXT_BLOCK_SAMPLE_SEED)

# (text_block, document_metadata) pairs to run the classifier on.
text_blocks_doc_metadata_sample = []

for document in tqdm(dataset.documents):
    if document.text_blocks is None:
        print(f"Skipping {document.document_id} as no text blocks")
        continue

    # Document-level fields, excluding the bulky block and page data.
    doc_metadata = document.dict(exclude={"text_blocks", "page_metadata"})

    # Randomly sample a fixed number of text blocks per document; documents with no more
    # than TEXT_BLOCKS_PER_DOCUMENT blocks are kept whole, in their original order.
    if len(document.text_blocks) <= TEXT_BLOCKS_PER_DOCUMENT:
        blocks = document.text_blocks
    else:
        blocks = text_block_rng.sample(document.text_blocks, TEXT_BLOCKS_PER_DOCUMENT)

    # All pairs from one document share the same metadata dict object (it is only read later).
    text_blocks_doc_metadata_sample.extend((block, doc_metadata) for block in blocks)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 401/401 [00:00<00:00, 16196.10it/s]


In [68]:
def predict_from_text_blocks(
    model: SetFitModel, text_blocks_and_doc_metadata: Sequence[tuple[TextBlock, dict]]
) -> pd.DataFrame:
    """Predict sector labels for text blocks and join them with block/document metadata.

    Args:
        model: trained multilabel SetFit model whose outputs follow ``mlb.classes_``.
        text_blocks_and_doc_metadata: ``(text_block, document_metadata_dict)`` pairs.

    Returns:
        One row per input pair, with these columns:
        - the whitespace-normalised ``text``;
        - the TextBlock's language/text_block_id/type/page_number;
        - the document metadata;
        - one prediction column per class in ``mlb.classes_``.
    """
    # Collapse every whitespace run (newlines, repeated spaces) into a single space.
    # Chained str.replace("  ", " ") only halved runs of spaces (e.g. 4 spaces -> 2).
    texts = [
        " ".join(block.to_string().split()) for block, _ in text_blocks_and_doc_metadata
    ]
    # as_numpy=True (as in the test-set cell) so pandas receives a plain ndarray.
    y_pred = model.predict(texts, as_numpy=True)
    y_pred_df = pd.DataFrame(y_pred, columns=mlb.classes_)

    metadata = [
        {"text": text}
        | block.dict(include={"language", "text_block_id", "type", "page_number"})
        | doc_metadata
        for text, (block, doc_metadata) in zip(texts, text_blocks_and_doc_metadata)
    ]
    metadata_df = pd.DataFrame.from_records(metadata)

    # Both frames carry a default RangeIndex, so an axis=1 concat pairs rows positionally.
    return pd.concat([metadata_df, y_pred_df], axis=1)


pred_df = predict_from_text_blocks(model, text_blocks_doc_metadata_sample[:20])

In [70]:
# Inspect the joined metadata + per-class prediction columns and the first rows.
print(pred_df.columns)
pred_df.head()

Index(['text', 'text_block_id', 'language', 'type', 'page_number',
       'document_id', 'document_name', 'document_source_url',
       'document_content_type', 'document_md5_sum', 'languages', 'translated',
       'has_valid_text', 'document_metadata',
       'agriculture, forestry and other land use', 'buildings', 'energy',
       'fisheries & aquaculture', 'health services', 'industry',
       'insurance & financial services', 'tourism', 'transport',
       'water services'],
      dtype='object')


Unnamed: 0,text,text_block_id,language,type,page_number,document_id,document_name,document_source_url,document_content_type,document_md5_sum,...,"agriculture, forestry and other land use",buildings,energy,fisheries & aquaculture,health services,industry,insurance & financial services,tourism,transport,water services
0,"16. Lastly, the CGE collaborates with other co...",p_5_b_2,en,BlockType.TEXT,5,CCLW.GST.623.623,CGE_GST20submission,https://unfccc.int/sites/default/files/resourc...,application/pdf,4c219bb0d9c9c0ea4764090fc8526e61,...,0,0,0,0,0,0,0,0,0,0
1,Constructing a BAU scenario is a collaborative...,p_49_b_13,en,BlockType.TEXT,49,CCLW.GST.722.722,BUR20Report_Final,https://unfccc.int/sites/default/files/resourc...,application/pdf,8e8c6e42e66efc05534454cb3a640e33,...,0,0,0,0,0,0,0,0,0,0
2,18. Lesotho reported in its first BUR informat...,p_4_b_2,en,BlockType.TEXT,4,CCLW.GST.817.817,tasr1_2022_LSO,https://unfccc.int/sites/default/files/resourc...,application/pdf,d6c614d4ca70544d2d53e28ac53e9856,...,0,0,0,0,0,0,0,0,0,0
3,"For Singapore to achieve net zero, we are emba...",p_28_b_6,en,BlockType.TEXT,28,CCLW.GST.520.520,Singapore20-20NC5BUR5,https://unfccc.int/sites/default/files/resourc...,application/pdf,448cd627f1022a776f74dc8faa316f8f,...,0,0,1,0,0,0,0,0,0,0
4,. Minimize GHG emissions from transport sector...,p_93_b_15,en,BlockType.GOOGLE_BLOCK,93,CCLW.GST.208.208,PakistanE28099s20First20Biennial20Update20Repo...,https://unfccc.int/sites/default/files/resourc...,application/pdf,38a54cc6f9c4578b740586f124779ec2,...,0,0,1,0,0,0,0,0,1,0


In [72]:
# Label names in one-hot column order; these are also the prediction column names in pred_df.
class_names = mlb.classes_

In [95]:
from cpr_data_access.models import Span

In [90]:
def reverse_encoding(row):
    """Return the names of the classes whose one-hot column in ``row`` equals 1."""
    one_hot = row[class_names]
    return one_hot[one_hot == 1].index.tolist()


# One list of predicted class names per text block.
pred_df["predictions"] = pred_df.apply(reverse_encoding, axis=1)

In [99]:
# Build one Span per (text block, predicted class) pair; each span covers the whole block text.
spans = []

for idx, row in pred_df.head(20).iterrows():
    for pred_class in row["predictions"]:
        spans.append(
            Span(
                document_id=row["document_id"],
                # FIXME(review): "text_hash" is not among pred_df's printed columns, so this
                # lookup raises KeyError as written. The text block hash must be added to
                # pred_df before this cell can run.
                text_block_text_hash=row["text_hash"],
                type=pred_class,
                id=pred_class,
                text=row["text"],
                # The span covers the entire text block.
                start_idx=0,
                end_idx=len(row["text"]),
                sentence=row["text"],
                pred_probability=1,  # FIXME
                annotator="sector-classifier",  # FIXME
            )
        )