## Setup


In [1]:
import numpy as np
import pandas as pd
from datasets import load_from_disk, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import precision_score, recall_score
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import TrainingArguments, Trainer
from transformers import AutoModelForSequenceClassification

In [2]:
seed = 42
data_path = "../data/20-news-groups/"
model_ckpt = "bert-base-cased"
model_name = "bert-news-groups-classifier"
model_path = f"../models/{model_name}"

## Data Preprocessing


In [3]:
data = load_from_disk(dataset_path=data_path)
data

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 11314
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 7532
    })
})

In [4]:
def clean_text(text):
    return " ".join(text.split("\n\n")[1:]).replace("\n", " ")

In [5]:
clean_data = data.map(
    lambda x: {"text": [clean_text(t) for t in x["text"]]},
    batched=True,
    remove_columns=["text"],
)
clean_data

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 11314
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 7532
    })
})

In [6]:
label_names = data["train"].features["labels"].names
label_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [7]:
label2id = {label_names[i]: i for i in range(len(label_names))}
id2label = {i: label_names[i] for i in range(len(label_names))}

In [8]:
def split(ds, split="stratified", seed=42, train_size=0.75):
    splits = None
    if split == "stratified":
        splits = train_test_split(
            ds, stratify=ds.labels, random_state=seed, train_size=train_size
        )
    elif split == "balanced":
        class_ratios = ds.labels.value_counts(normalize=True)
        classes = ds.labels.unique()
        num_classes = len(classes)
        min_ratio = min(class_ratios.to_list())
        train_size = min(train_size, num_classes * min_ratio)
        print(f"Train size used: {train_size}")
        class_ratio = train_size / num_classes
        examples_per_class = int(class_ratio * len(ds))

        inds = []
        for c in classes:
            sample = ds[ds.labels == c].sample(examples_per_class, random_state=seed)
            inds.extend(sample.index.to_list())
        splits = (ds.iloc[inds, :], ds.drop(index=inds))
    else:
        raise Exception("Unknown split method")
    return splits

In [9]:
splits = split(clean_data["train"].to_pandas(), split="balanced", train_size=0.6)

Train size used: 0.6


In [10]:
clean_data["train"] = Dataset.from_pandas(splits[0].reset_index(drop=True))
clean_data["valid"] = Dataset.from_pandas(splits[1].reset_index(drop=True))
clean_data

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 6780
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 7532
    })
    valid: Dataset({
        features: ['text', 'labels'],
        num_rows: 4534
    })
})

## Tokenization


In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

In [12]:
def tokenize_ds(batch):
    return tokenizer(batch["text"], truncation=True)

In [13]:
encoded_data = clean_data.map(tokenize_ds, batched=True, remove_columns=["text"])

Map:   0%|          | 0/6780 [00:00<?, ? examples/s]

Map:   0%|          | 0/4534 [00:00<?, ? examples/s]

In [14]:
encoded_data

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 6780
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 7532
    })
    valid: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 4534
    })
})

## Modeling


In [15]:
def compute_scores(preds):
    logits, labels = preds
    pred = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, pred)
    f1 = f1_score(labels, pred, average="weighted")
    prec = precision_score(labels, pred, average="weighted")
    rec = recall_score(labels, pred, average="weighted")
    return {"Accuracy": acc, "Precision": prec, "Recall": rec, "F1": f1}

In [16]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    num_labels=len(label_names),
    id2label=id2label,
    label2id=label2id,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

### Training


In [18]:
import json
import wandb

with open("../data/access_tokens.json") as f:
    login_key = json.load(f)["wandb"]["login"]

wandb.login(key=login_key)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

[34m[1mwandb[0m: Currently logged in as: [33me_hossam96[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [19]:
wandb.init(project="train-test-split", name="balanced-split")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [20]:
training_args = TrainingArguments(
    output_dir=model_path,
    overwrite_output_dir=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    log_level="error",
    learning_rate=2e-5,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
)

In [21]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_data["train"],
    eval_dataset=encoded_data["valid"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_scores,
)

In [22]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.842741,0.784076,0.796659,0.784076,0.786029
2,1.682800,0.597751,0.83745,0.847563,0.83745,0.840787
3,0.598100,0.565576,0.846493,0.858365,0.846493,0.850558


TrainOutput(global_step=1272, training_loss=0.9811556954053963, metrics={'train_runtime': 1217.0335, 'train_samples_per_second': 16.713, 'train_steps_per_second': 1.045, 'total_flos': 5347856353086720.0, 'train_loss': 0.9811556954053963, 'epoch': 3.0})

In [23]:
trainer.evaluate(encoded_data["test"], metric_key_prefix="test")

{'test_loss': 0.7723978161811829,
 'test_Accuracy': 0.7842538502389803,
 'test_Precision': 0.7893519134358521,
 'test_Recall': 0.7842538502389803,
 'test_F1': 0.7852225916255242,
 'test_runtime': 128.4815,
 'test_samples_per_second': 58.623,
 'test_steps_per_second': 3.666,
 'epoch': 3.0}

In [24]:
wandb.finish()

VBox(children=(Label(value='0.034 MB of 0.050 MB uploaded (0.005 MB deduped)\r'), FloatProgress(value=0.665385…

0,1
eval/Accuracy,▁▇█
eval/F1,▁▇█
eval/Precision,▁▇█
eval/Recall,▁▇█
eval/loss,█▂▁
eval/runtime,▁█▅
eval/samples_per_second,█▁▄
eval/steps_per_second,█▁▅
test/Accuracy,▁
test/F1,▁

0,1
eval/Accuracy,0.84649
eval/F1,0.85056
eval/Precision,0.85837
eval/Recall,0.84649
eval/loss,0.56558
eval/runtime,77.5999
eval/samples_per_second,58.428
eval/steps_per_second,3.66
test/Accuracy,0.78425
test/F1,0.78522


In [25]:
outs = trainer.predict(encoded_data["test"])
preds = np.argmax(outs[0], axis=-1)
labels = outs[1]

print(classification_report(labels, preds, target_names=label_names))

                          precision    recall  f1-score   support

             alt.atheism       0.69      0.59      0.64       319
           comp.graphics       0.74      0.70      0.72       389
 comp.os.ms-windows.misc       0.72      0.75      0.74       394
comp.sys.ibm.pc.hardware       0.62      0.74      0.67       392
   comp.sys.mac.hardware       0.79      0.75      0.77       385
          comp.windows.x       0.85      0.87      0.86       395
            misc.forsale       0.84      0.84      0.84       390
               rec.autos       0.85      0.88      0.87       396
         rec.motorcycles       0.89      0.80      0.84       398
      rec.sport.baseball       0.95      0.90      0.92       397
        rec.sport.hockey       0.94      0.94      0.94       399
               sci.crypt       0.83      0.79      0.81       396
         sci.electronics       0.72      0.62      0.67       393
                 sci.med       0.88      0.87      0.87       396
         