In [92]:
# !pip install -U accelerate
# !pip install -U transformers

In [93]:
import pandas as pd
df = pd.read_csv("./dataset/problems.csv", usecols=["description", "labels"])
df.head(10)

Unnamed: 0,description,labels
0,John gave Jack a very hard problem. He wrote a...,['math']
1,Due to the recent popularity of the Deep learn...,"['dynamic programming', 'matrices']"
2,Bill is a famous mathematician in BubbleLand. ...,"['greedy', 'sorting']"
3,The competitors of Bubble Cup X gathered after...,"['shortest path', 'graphs', 'binary search']"
4,John has just bought a new car and is planning...,['dynamic programming']
5,"Consider an array A with N elements, all being...","['combinatorics', 'number theory', 'math']"
6,The citizens of BubbleLand are celebrating the...,"['dynamic programming', 'geometry']"
7,This story is happening in a town named Bubble...,"['trees', 'graphs']"
8,You are given an integer $$$x$$$ of $$$n$$$ di...,"['greedy', 'strings']"
9,You are given a Young diagram. Given diagram ...,"['greedy', 'dynamic programming', 'math']"


In [94]:
# df.shape
df.info()
# df.duplicated().sum()
# df['description'].str.len().plot.hist(bins=50)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10912 entries, 0 to 10911
Data columns (total 2 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   description  10912 non-null  object
 1   labels       10912 non-null  object
dtypes: object(2)
memory usage: 170.6+ KB


In [95]:
import ast

df['labels'] = df['labels'].apply(ast.literal_eval)
labels_cnt = [l for lab in df['labels'] for l in lab]
label_series = pd.Series(labels_cnt).value_counts()
print(label_series)

print("總共有", label_series.index.nunique(), "種 labels")

math                   3466
greedy                 3293
data structures        2710
dynamic programming    2630
graphs                 1950
sorting                1488
strings                1415
binary search          1366
trees                  1055
number theory           819
bit manipulation        785
combinatorics           736
two pointers            732
union find              424
geometry                411
divide and conquer      334
matrices                333
shortest path           289
game theory             257
hashing                 239
probabilities           229
interactive             210
Name: count, dtype: int64
總共有 22 種 labels


In [96]:
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
import torch
multilabel = MultiLabelBinarizer()
labels = multilabel.fit_transform(df["labels"]).astype('float32') # NumPy ndarray # To align label format with model prediction (both are float)
texts = df["description"].tolist()

co_matrix = np.dot(labels.T, labels)  # labels: shape (N_samples, N_labels)
total = np.sum(co_matrix)
# 機率矩陣
P_ij = co_matrix / total
# 邊際機率 P(i)
P_i = np.diag(co_matrix) / total  # shape: (n_labels,)
# 外積計算 P(i) * P(j)
P_i_P_j = np.outer(P_i, P_i)
# PMI 計算，加上小常數避免 log(0)
PMI = np.log(P_ij / (P_i_P_j + 1e-10) + 1e-10)
np.fill_diagonal(PMI, PMI.max())
# Normalize PMI to [0, 1] for soft label weight
PMI_norm = (PMI - PMI.min()) / (PMI.max() - PMI.min())
PMI_tensor = torch.tensor(PMI_norm, dtype=torch.float32)

k = 3  # 每個 label 最多擴展 3 個其他 label
PMI_topk = torch.zeros_like(PMI_tensor)
for i in range(PMI_tensor.shape[0]):
    topk_indices = torch.topk(PMI_tensor[i], k=k+1).indices  # +1 保留自己
    PMI_topk[i, topk_indices] = PMI_tensor[i, topk_indices]

# soft_labels = torch.matmul(torch.tensor(labels), PMI_tensor)
# soft_labels = torch.clamp(soft_labels, 0.0, 1.0)
# soft_labels = torch.round(soft_labels * 10) / 10

# import seaborn as sns
# import matplotlib.pyplot as plt
# label_names = multilabel.classes_
# plt.figure(figsize=(12, 10))
# sns.heatmap(PMI, xticklabels=label_names, yticklabels=label_names, cmap='YlGnBu', annot=False)
# plt.title("Label Co-occurrence Matrix")
# plt.show()


In [97]:
# calculate loss_weights to deal with dataset imbalance
label_counts = labels.sum(axis=0)
coefficient = 100
weights = 1.0 / np.log(label_counts + coefficient)
weights = weights / np.max(weights)  # normalize to [0, 1]
loss_weights = torch.tensor(weights, dtype=torch.float32)
# print(multilabel.classes_)
# print(label_counts)
print(loss_weights.max())
print(loss_weights.min())

tensor(1.)
tensor(0.7014)


In [98]:
from transformers import DistilBertTokenizer
from transformers import DistilBertForSequenceClassification
# from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import numpy as np

In [99]:
from skmultilearn.model_selection import iterative_train_test_split
texts_np = np.array(texts)  
labels_np = np.array(labels)

X_train, y_train, X_val, y_val = iterative_train_test_split( # need  multi-hot vector should be integer
        texts_np.reshape(-1, 1), labels_np, test_size=0.2
    )

# 還原回原本格式
train_texts = X_train.ravel().tolist()
val_texts = X_val.ravel().tolist()

# train_labels = torch.matmul(torch.tensor(y_train), PMI_topk)
# train_labels = torch.clamp(train_labels, 0.0, 1.0)
hard = torch.tensor(y_train, dtype=torch.float32)
soft = torch.matmul(hard, PMI_tensor)
soft = torch.clamp(soft, 0.0, 1.0)
alpha = 0.7 
train_labels = alpha * hard + (1 - alpha) * soft

val_labels = torch.tensor(y_val, dtype=torch.float32)  # hard label
val_labels = torch.clamp(val_labels, 0.0, 1.0)

print(y_train[0]) # before 
print(train_labels.numpy()[0]) # after


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0.10938767 0.1214942  0.1554979  0.08815955 0.10727646 0.11198483
 0.12614143 0.1325789  0.08376686 0.1150557  0.09712635 0.11303254
 1.         0.10799404 0.16193527 0.1496429  0.06120315 0.09811347
 0.07484833 0.07347731 0.09515471 0.08525355]


In [100]:
# model
import torch
import torch.nn as nn
from transformers import DistilBertModel


class DistilBertWithSoftLabel(nn.Module):
    def __init__(self, num_labels, loss_weights):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')  # No reduction 
        self.loss_weights = loss_weights


    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.last_hidden_state[:, 0])  # [CLS] token
        
        if labels is not None: # label is soft_labels
            loss_matrix = self.loss_fn(logits, labels)
            # loss = loss_matrix.mean()  # or weighted sum
            loss_weights = self.loss_weights.to(logits.device)
            loss = (loss_matrix * loss_weights).mean()

            return {"logits": logits, "loss": loss}

        return {"logits": logits}

def data_collator(batch):
    return {
        'input_ids': torch.stack([x['input_ids'] for x in batch]),
        'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
        # 'labels': torch.stack([torch.tensor(x['labels'], dtype=torch.float32) for x in batch])
        'labels': torch.stack([x['labels'] for x in batch])
    }


In [101]:
# model settings
from transformers import DistilBertTokenizerFast
checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
label_counts = labels.sum(axis=0)
model = DistilBertWithSoftLabel(num_labels=len(labels[0]),loss_weights=loss_weights)

In [102]:
# tokenize
class CustomDataset(Dataset):
  def __init__(self, texts, labels, tokenizer, max_len=128):
    self.texts = texts
    self.labels = labels
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, idx):
    text = str(self.texts[idx])
    label = self.labels[idx]
    # label = torch.tensor(self.labels[idx])
    if not isinstance(label, torch.Tensor):
        label = torch.tensor(label, dtype=torch.float32)
    else:
        label = label.detach().clone().float()

    encoding = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_len, return_tensors='pt')

    return {
        'input_ids': encoding['input_ids'].squeeze(0),
        'attention_mask': encoding['attention_mask'].squeeze(0),
        'labels': label
    }

train_dataset = CustomDataset(train_texts, train_labels, tokenizer)
val_dataset = CustomDataset(val_texts, val_labels, tokenizer)

In [103]:
# Multi-Label Classification Evaluation Metrics
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, hamming_loss, roc_curve
from transformers import EvalPrediction
import torch


def find_optimal_thresholds(y_true, y_probs):
    thresholds = []
    for i in range(y_true.shape[1]):
        fpr, tpr, th = roc_curve(y_true[:, i], y_probs[:, i])
        youdens_j = tpr - fpr
        best_th = th[np.argmax(youdens_j)]
        thresholds.append(best_th)
    print("Optimal thresholds:", thresholds)
    return np.array(thresholds)


def find_f1_optimal_thresholds(y_true, y_probs):
    thresholds = []
    y_true = (y_true >= 0.5).astype(int)
    for i in range(y_true.shape[1]):
        best_f1 = 0
        best_th = 0.5
        for th in np.linspace(0.05, 0.95, 50):
            y_pred_i = (y_probs[:, i] >= th).astype(int)
            f1 = f1_score(y_true[:, i], y_pred_i)
            if f1 > best_f1:
                best_f1 = f1
                best_th = th
        thresholds.append(best_th)
    return np.array(thresholds)


def multi_labels_metrics(predictions, labels):
  sigmoid = torch.nn.Sigmoid()
  # probs = sigmoid(torch.Tensor(predictions))
  probs = sigmoid(torch.tensor(predictions)).detach().cpu().numpy()

  y_pred = np.zeros(probs.shape)
  thresholds = np.maximum(find_f1_optimal_thresholds(labels,probs), 0.05)
#   thresholds = np.full(probs.shape[1], 0.3)

  y_pred = (probs >= thresholds).astype(int)
  y_true = labels
  
  f1 = f1_score(y_true, y_pred, average = 'macro')
  roc_auc = roc_auc_score(y_true, probs, average = 'macro')
  hamming = hamming_loss(y_true, y_pred)

  metrics = {
      "roc_auc": roc_auc,
      "hamming_loss": hamming,
      "f1": f1
  }
  return metrics

def compute_metrics(p:EvalPrediction):
  preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
  labels = (p.label_ids > 0.5).astype(int) # p.label_ids
  result = multi_labels_metrics(predictions=preds,
                                labels=labels)
  
  return result

In [104]:
# Training Arguments
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    output_dir = './results',
    num_train_epochs=5,
    save_steps=1000,
    save_total_limit=2
)

trainer = Trainer(model=model,
                  args=args,
                  train_dataset=train_dataset,
                  eval_dataset = val_dataset,
                  compute_metrics=compute_metrics,
                  data_collator=data_collator)

In [105]:
trainer.train()

Step,Training Loss
500,0.491
1000,0.484
1500,0.4777
2000,0.4739
2500,0.4689
3000,0.4653
3500,0.4594
4000,0.4522
4500,0.452
5000,0.4415


TrainOutput(global_step=5460, training_loss=0.46454372196407107, metrics={'train_runtime': 605.3743, 'train_samples_per_second': 72.129, 'train_steps_per_second': 9.019, 'total_flos': 0.0, 'train_loss': 0.46454372196407107, 'epoch': 5.0})

In [106]:
trainer.evaluate()

{'eval_loss': 0.3183152973651886,
 'eval_roc_auc': 0.7542329924741993,
 'eval_hamming_loss': 0.14028536860110977,
 'eval_f1': 0.45215181529330173,
 'eval_runtime': 18.2242,
 'eval_samples_per_second': 119.566,
 'eval_steps_per_second': 14.98,
 'epoch': 5.0}

## k=100, Threshold - roc optimal , p.label_ids > 0.8 as P
{'eval_loss': 0.5422307252883911,
 'eval_roc_auc': 0.6355803896134508,
 'eval_hamming_loss': 0.39180860367301046,
 'eval_f1': 0.5670835103814952,
 'eval_runtime': 22.0219,
 'eval_samples_per_second': 99.129,
 'eval_steps_per_second': 12.397,
 'epoch': 5.0}

## k=100, Threshold = 0.3 , p.label_ids > 0.8 as P
{'eval_loss': 0.5422307252883911,
 'eval_model_preparation_time': 0.0041,
 'eval_roc_auc': 0.6355803896134508,
 'eval_hamming_loss': 0.5140132428268022,
 'eval_f1': 0.6445478260151899,
 'eval_runtime': 21.2506,
 'eval_samples_per_second': 102.727,
 'eval_steps_per_second': 12.847}

## k=100, Threshold - f1 optimal , p.label_ids > 0.5 as P
{'eval_loss': 0.5422307252883911,
 'eval_model_preparation_time': 0.0029,
 'eval_roc_auc': 0.6209105006225866,
 'eval_hamming_loss': 0.2817015783117478,
 'eval_f1': 0.8335942488331942,
 'eval_runtime': 21.624,
 'eval_samples_per_second': 100.953,
 'eval_steps_per_second': 12.625}

## k=100, Threshold - roc optimal , p.label_ids > 0.5 as P
{'eval_loss': 0.5422307252883911,
 'eval_model_preparation_time': 0.0062,
 'eval_roc_auc': 0.6209105006225866,
 'eval_hamming_loss': 0.3794194811143964,
 'eval_f1': 0.7055115285615337,
 'eval_runtime': 19.9769,
 'eval_samples_per_second': 109.276,
 'eval_steps_per_second': 13.666}

{'eval_loss': 0.3183152973651886,
 'eval_roc_auc': 0.7542329924741993,
 'eval_hamming_loss': 0.14028536860110977,
 'eval_f1': 0.45215181529330173,
 'eval_runtime': 18.2242,
 'eval_samples_per_second': 119.566,
 'eval_steps_per_second': 14.98,
 'epoch': 5.0}
 

In [107]:
print(multilabel.classes_)


['binary search' 'bit manipulation' 'combinatorics' 'data structures'
 'divide and conquer' 'dynamic programming' 'game theory' 'geometry'
 'graphs' 'greedy' 'hashing' 'interactive' 'math' 'matrices'
 'number theory' 'probabilities' 'shortest path' 'sorting' 'strings'
 'trees' 'two pointers' 'union find']


In [108]:
# trainer.save_model("distilbert-finetuned-imdb-multi-label")
# trainer.save_model("best_model")
torch.save(model.state_dict(), "./best_model/model.pt")
torch.save(loss_weights, "./best_model/weights.pt")
tokenizer.save_pretrained("./best_model")  # tokenizer 可用 Hugging Face 的


('./best_model/tokenizer_config.json',
 './best_model/special_tokens_map.json',
 './best_model/vocab.txt',
 './best_model/added_tokens.json',
 './best_model/tokenizer.json')

In [109]:
# import pickle
# with open("multi-label-binarizer.pkl", "wb") as f:
#   pickle.dump(multilabel, f)

In [None]:
from sklearn.metrics import classification_report
print("Evaluating...")
preds = trainer.predict(val_dataset).predictions
pred_binary = (preds > 0.5).astype(int)
val_labels_binary = (val_labels > 0.5).int().numpy()
print("\nClassification Report:")
label_names = multilabel.classes_
print(classification_report(val_labels_binary, pred_binary, target_names=label_names))


Evaluating...

Classification Report:
                     precision    recall  f1-score   support

      binary search       0.28      0.07      0.12       273
   bit manipulation       0.62      0.25      0.35       157
      combinatorics       0.43      0.06      0.10       151
    data structures       0.55      0.29      0.38       542
 divide and conquer       0.00      0.00      0.00        71
dynamic programming       0.37      0.24      0.29       526
        game theory       0.66      0.53      0.59        51
           geometry       0.64      0.52      0.58        82
             graphs       0.63      0.53      0.58       390
             greedy       0.50      0.47      0.48       659
            hashing       0.00      0.00      0.00        50
        interactive       1.00      0.88      0.94        42
               math       0.58      0.48      0.53       693
           matrices       0.93      0.40      0.56        67
      number theory       0.69      0.37      

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
