In [1]:
import os
import re
from functools import partial

import torch
import numpy as np
import pandas as pd

from datasets import Dataset, Value, ClassLabel, Features, Sequence
from transformers import pipeline, AutoTokenizer
from tqdm.auto import tqdm

from dataset.textdataset import ArticleDataset
from dataset.transformers_dataset import load_data



In [2]:
df = load_data("multi_label_dataset.csv", "./articles", use_original_text=True)

In [11]:
df.head()

Unnamed: 0,File,Text,adulting-101,big-read,commentary,gen-y-speaks,gen-z-speaks,singapore,voices,world,fp
0,16_bear_cubs_rescued_from_home_in_Laos.txt,bangkok sixteen undernourished asiatic black b...,1,1,1,1,1,1,0,0,./articles\singapore\16_bear_cubs_rescued_from...
1,2_separate_fire_incidents_caused_by_active_mob...,singapore span le four hour tuesday march 19 a...,0,0,0,1,0,0,0,0,./articles\gen-y-speaks\2_separate_fire_incide...
2,2_years_jail_caning_and_fine_for_man_who_splas...,singapore believing girlfriend time cheating s...,0,0,0,1,0,0,0,0,./articles\gen-y-speaks\2_years_jail_caning_an...
3,5_smart_ways_to_stretch_your_dollar_with_GrabF...,app new feature serve convenience value whethe...,1,1,1,1,1,1,1,1,./articles\world\5_smart_ways_to_stretch_your_...
4,A_Barefaced_Charmaine_Sheh_48_Looks_Amazing_On...,used seeing hong kong actress charmaine sheh 4...,0,0,1,1,0,0,1,0,./articles\voices\A_Barefaced_Charmaine_Sheh_4...


In [3]:
def get_dict(df):
    dataset = {}
    for _, row in df.iterrows():
        binary_targets = row[2:-1].to_numpy()
        labels = df.columns[2:-1][binary_targets == 1]
        targets = [df.columns[2:-1].get_loc(label) for label in labels]
        labels = list(map(lambda x: x.replace("-", " "), labels))
        if dataset.get("text") is None:
            dataset["text"] = [row["Text"]]
            dataset["binary_targets"] = [binary_targets]
            dataset["targets"] = [targets]
            dataset["labels"] = [labels]
        else:
            dataset["text"].append(row["Text"])
            dataset["binary_targets"].append(binary_targets)
            dataset["targets"].append(targets)
            dataset["labels"].append(labels)
    return dataset

In [8]:
# find examples that are not every label
df[df.iloc[:, 2:-1].sum(axis=1) < len(df.columns[2:])]

Unnamed: 0,File,Text,adulting-101,big-read,commentary,gen-y-speaks,gen-z-speaks,singapore,voices,world,fp
0,16_bear_cubs_rescued_from_home_in_Laos.txt,bangkok sixteen undernourished asiatic black b...,1,1,1,1,1,1,0,0,./articles\singapore\16_bear_cubs_rescued_from...
1,2_separate_fire_incidents_caused_by_active_mob...,singapore span le four hour tuesday march 19 a...,0,0,0,1,0,0,0,0,./articles\gen-y-speaks\2_separate_fire_incide...
2,2_years_jail_caning_and_fine_for_man_who_splas...,singapore believing girlfriend time cheating s...,0,0,0,1,0,0,0,0,./articles\gen-y-speaks\2_years_jail_caning_an...
3,5_smart_ways_to_stretch_your_dollar_with_GrabF...,app new feature serve convenience value whethe...,1,1,1,1,1,1,1,1,./articles\world\5_smart_ways_to_stretch_your_...
4,A_Barefaced_Charmaine_Sheh_48_Looks_Amazing_On...,used seeing hong kong actress charmaine sheh 4...,0,0,1,1,0,0,1,0,./articles\voices\A_Barefaced_Charmaine_Sheh_4...
...,...,...,...,...,...,...,...,...,...,...,...
161,Woman_arrested_for_Taylor_Swift_concert_ticket...,singapore 29 year old woman arrested monday ma...,1,1,1,1,0,0,0,0,./articles\gen-y-speaks\Woman_arrested_for_Tay...
162,Woman_charged_with_cheating_Taylor_Swift_fan_o...,singapore 29 year old woman tuesday mar 12 cha...,0,0,0,0,0,1,1,0,./articles\voices\Woman_charged_with_cheating_...
163,Worm_Moon_to_light_up_Singapore_sky.txt,singapore may appealing name worm moon still s...,0,1,1,1,0,0,0,0,./articles\gen-y-speaks\Worm_Moon_to_light_up_...
164,Your_Say_3_reasons_why_S_pore_is_justified_in_...,refer article baby step toughest period life s...,1,1,1,1,1,1,1,1,./articles\world\Your_Say_3_reasons_why_S_pore...


In [4]:
classes = [x.replace("-", " ") for x in df.columns[2:-1].to_list()]

In [16]:
ds = ArticleDataset("./articles", "multi_label_dataset.csv", None, 512)
dict_ds = {
    "text": ds.articles,
    "labels": ds.labels,
    "targets": ds.targets
}

dataset = Dataset.from_dict(dict_ds, features=Features({
    "text": Value("string"),
    "labels": Sequence(ClassLabel(names=classes)),
    "targets": Sequence(Value("int64"))
}))

In [5]:
dataset = Dataset.from_dict(
    get_dict(df),
    features=Features(
        {
            "text": Value("string"),
            "targets": Sequence(ClassLabel(num_classes=8, names=list(range(8)))),
            "binary_targets": Sequence(Value("int32")),
            "labels": Sequence(ClassLabel(names=classes)),
        }
    ),
)

In [6]:
train, test = dataset.train_test_split(test_size=0.2)

In [7]:
dataset.features

{'text': Value(dtype='string', id=None),
 'binary_targets': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'targets': Sequence(feature=ClassLabel(names=[0, 1, 2, 3, 4, 5, 6, 7], id=None), length=-1, id=None),
 'labels': Sequence(feature=ClassLabel(names=['adulting 101', 'big read', 'commentary', 'gen y speaks', 'gen z speaks', 'singapore', 'voices', 'world'], id=None), length=-1, id=None)}

In [19]:
dataset[0]

{'text': 'BANGKOK\xa0—\xa0Sixteen\xa0undernourished Asiatic black\xa0bear\xa0cubs\xa0have been found in a home in Laos capital Vientiane by a conservation charity, the largest rescue of the year.\nThe clutch of\xa0cubs, also known as moon bears after the white crescent of fur across their chests, are classified as vulnerable on the International Union for Conservation of Nature\xa0(IUCN) Red List of endangered species.\xa0\nAcross Asia, thousands of the animals are kept as pets or farmed to extract their bile for use in costly traditional medicine.\nWildlife conservation charity Free the Bears said they found 17\xa0cubs\xa0in a private home in Laos early last week, but that one of them had already died.\n"When we arrived at the house there were\xa0bear\xa0cubs\xa0everywhere," said Fatong Yang, an animal manager with the charity.\nThe group found ten males and six females, weighing between 1.3 to 4kg and believed to be around two to four months old.\n"Cubs\xa0this small are extremely vu

In [8]:
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    dtype=torch.bfloat16,
    device="cuda",
    fp16=True
)
candidate_labels = list(map(lambda x: x.replace("-", " "), df.columns[2:-1]))

In [9]:
def tokenize_text(instance, tokenizer):
    return tokenizer(instance["text"], truncation=True)

In [12]:
dataset[0]

{'text': 'bangkok sixteen undernourished asiatic black bear cub found home lao capital vientiane conservation charity largest rescue year clutch cub also known moon bear white crescent fur across chest classified vulnerable international union conservation nature iucn red list endangered specie across asia thousand animal kept pet farmed extract bile use costly traditional medicine wildlife conservation charity free bear said found 17 cub private home lao early last week one already died arrived house bear cub everywhere said fatong yang animal manager charity group found ten male six female weighing 1 3 4kg believed around two four month old cub small extremely vulnerable wild mother would never leave suspect mother killed poacher mr fatong said statement weekend charity head matt hunt said organisation would bring expert cambodia cope number rescued surpassing 2019 mission five cub saved country north bear rescued single year three month 2024 said free bear said police alerted house 

In [10]:
results = classifier(dataset["text"], candidate_labels=candidate_labels, multi_label=True)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [23]:
from sklearn.metrics import (accuracy_score, average_precision_score,
                             coverage_error, f1_score,
                             label_ranking_average_precision_score,
                             label_ranking_loss, multilabel_confusion_matrix,
                             precision_score, recall_score, roc_auc_score)
from sklearn.preprocessing import MultiLabelBinarizer

from metrics.auc import godbole_accuracy

In [54]:
results[1]

{'sequence': 'SINGAPORE — In the span of less than four hours on Tuesday (March 19) afternoon, the Singapore Civil Defence Force (SCDF) put out two fires involving mobility devices in separate incidents, the first of which caused three people to be taken to hospital due to smoke inhalation.\nIn a Facebook post, SCDF said that it was alerted to a fire at about 2.10pm\xa0at Block 706 Clementi West Street 2, which is a public housing block.\nBlack smoke was seen emitting from a fourth-floor flat when emergency responders arrived at the scene.\n“The occupants had evacuated the unit before SCDF’s arrival. Three persons from the unit were assessed for smoke inhalation and conveyed to the Singapore General Hospital,” SCDF said.\nThe fire, which engulfed contents of a room, was extinguished by firefighters from Clementi Fire Station using a water jet.\nThe living room area also sustained heat and smoke damage due to the fire.\nAn initial investigation by SCDF indicated that the fire had likely

In [46]:
mlb = MultiLabelBinarizer(classes=candidate_labels)
sample_labels = df[df.columns[2:-1]].apply(lambda x: list(df.columns[2:-1][x == 1]), axis=1)

In [47]:
label_to_dashed_labels = {label: label.replace(" ", "-") for label in candidate_labels}
dashed_lables_to_labels = {label.replace(" ", "-"): label for label in candidate_labels}

In [48]:
mlb.classes

['adulting 101',
 'big read',
 'commentary',
 'gen y speaks',
 'gen z speaks',
 'singapore',
 'voices',
 'world']

In [49]:
mlb.fit(sample_labels)

In [50]:
def get_scores(results, mlb):
    for result in results:
        score = result["scores"]
        labels = result["labels"]
        scores = [score[labels.index(label)] if label in labels else 0 for label in mlb.classes]
        yield scores

In [51]:
y_scores = np.array([score for score in get_scores(results, mlb)])

In [58]:
y_true = np.array(dataset["targets"])

In [59]:
y_true.shape, y_scores.shape

((165, 8), (165, 8))

In [60]:
y_true.ndim, y_scores.ndim

(2, 2)

In [66]:
best_thresh = 0
max_f1 = 0.0
for thresh in sorted(y_scores.flatten()):
    y_pred = (y_scores > thresh).astype(int)
    f1 = f1_score(y_true, y_pred, average="samples")
    if f1 > max_f1:
        max_f1 = f1
        best_thresh = thresh

In [67]:
y_test = y_true
y_pred = y_scores > best_thresh
y_prob = y_scores
acc = accuracy_score(y_test, y_pred)
godbole_acc = godbole_accuracy(y_test, y_pred, "macro")
godbole_chance_acc = godbole_accuracy(
    y_test, np.ones_like(y_test) * np.mean(y_test), "macro"
)
cov_error = coverage_error(y_test, y_prob)
f1 = f1_score(y_test, y_pred, average="micro")
lrap = label_ranking_average_precision_score(y_test, y_prob)
lrap_chance = label_ranking_average_precision_score(
    y_test, np.ones_like(y_test) * np.mean(y_test)
)
lrl = label_ranking_loss(y_test, y_prob)
prec = precision_score(y_test, y_pred, average="micro")
rec = recall_score(y_test, y_pred, average="micro")
y_test_inv = 1 - y_test
y_pred_inv = 1 - y_pred
spec = recall_score(y_test_inv, y_pred_inv, average="micro")
mlm = multilabel_confusion_matrix(y_test, y_pred)
auroc = roc_auc_score(y_test, y_prob, average="micro")
ap = average_precision_score(y_test, y_prob, average="micro")
ap_chance_level = average_precision_score(
    y_test, np.ones_like(y_test) * np.mean(y_test), average="micro"
)
fill_rate_pred = np.sum(y_pred) / y_pred.size
fill_rate = np.sum(y_test) / y_test.size

print(
    f"acc: {acc:.4f}",
    f"jaccard_index: {godbole_acc:.4f} / {godbole_chance_acc:.4f} (chance)",
    f"lrap: {lrap:.4f} / {lrap_chance:.4f} (chance)",
    f"f1: {f1:.4f}",
    f"lrl: {lrl:.4f}",
    f"rec: {rec:.4f}",
    f"prec: {prec:.4f}",
    f"spec: {spec: 4f}",
    f"cov_err: {cov_error:.4f}",
    f"auroc: {auroc:.4f}",
    f"ap: {ap:.4f} / {ap_chance_level:.4f} (chance)",
    f"fill_rate_pred: {fill_rate_pred:.4f} / {fill_rate:.4f} (true)",
    sep="\n",
    end="\n\n",
)

acc: 0.6606
jaccard_index: 0.8053 / 0.8061 (chance)
lrap: 0.8575 / 0.8061 (chance)
f1: 0.8922
lrl: 0.1581
rec: 0.9991
prec: 0.8059
spec:  0.000000
cov_err: 7.4485
auroc: 0.5293
ap: 0.8199 / 0.8061 (chance)
fill_rate_pred: 0.9992 / 0.8061 (true)



In [86]:
((y_scores > 0.5) == y_true).mean()

0.6929824561403509