## Setfit

In [1]:
import torch

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

DEVICE

'cuda'

### Tracking using Weights&Biases

In [2]:
%env WANDB_LOG_MODEL='end'
%env WANDB_WATCH='all'

env: WANDB_LOG_MODEL='end'
env: WANDB_WATCH='all'


In [3]:
config = {
    "model": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "num_samples": 16,
    "batch_size": 4,
    "num_iterations": 20,
    "num_epochs": 5,
    "seed": 42,
}

In [4]:
import wandb

wandb.login()
run = wandb.init(
    project="significance_classification", group="transformer_setfit", config=config
)

[34m[1mwandb[0m: Currently logged in as: [33mpaul_ww[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
from transformers import set_seed

set_seed(wandb.config["seed"])

  from .autonotebook import tqdm as notebook_tqdm


### Loading the dataset

In [6]:
from datasets import load_dataset

In [7]:
ds = load_dataset("paul-ww/ei-abstract-significance")

Found cached dataset parquet (/dhc/home/paul.wullenweber/.cache/huggingface/datasets/paul-ww___parquet/paul-ww--ei-abstract-significance-1c087dddb8b05c98/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|████████████████████████████████████████████| 3/3 [00:00<00:00, 549.90it/s]


In [8]:
class_labels = ds["train"].features["label"]
label2id = {name: class_labels.str2int(name) for name in class_labels.names}
id2label = {v: k for k, v in label2id.items()}

In [9]:
from setfit import sample_dataset

ds["train_setfit"] = sample_dataset(
    ds["train"], num_samples=wandb.config["num_samples"], seed=wandb.config["seed"]
)

Loading cached shuffled indices for dataset at /dhc/home/paul.wullenweber/.cache/huggingface/datasets/paul-ww___parquet/paul-ww--ei-abstract-significance-1c087dddb8b05c98/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-a5382d79f656cf13.arrow


### Model Setup

In [10]:
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
from transformers import AutoTokenizer

In [11]:
tokenizer = AutoTokenizer.from_pretrained(wandb.config["model"], model_max_length=512)

In [12]:
def tokenize_text(ds):
    return tokenizer(ds["text"], truncation=True)

In [13]:
ds_tokenized = ds.map(tokenize_text, batched=True)

Loading cached processed dataset at /dhc/home/paul.wullenweber/.cache/huggingface/datasets/paul-ww___parquet/paul-ww--ei-abstract-significance-1c087dddb8b05c98/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-f446dd680fa7c22a.arrow
Loading cached processed dataset at /dhc/home/paul.wullenweber/.cache/huggingface/datasets/paul-ww___parquet/paul-ww--ei-abstract-significance-1c087dddb8b05c98/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-61eaa06ad9f7d334.arrow
Loading cached processed dataset at /dhc/home/paul.wullenweber/.cache/huggingface/datasets/paul-ww___parquet/paul-ww--ei-abstract-significance-1c087dddb8b05c98/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-932faea0c00ec146.arrow
                                                                                

In [14]:
model = SetFitModel.from_pretrained(
    wandb.config["model"],
).to(DEVICE)

No sentence-transformers model found with name /dhc/home/paul.wullenweber/.cache/torch/sentence_transformers/microsoft_BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext. Creating a new one with MEAN pooling.
Some weights of the model checkpoint at /dhc/home/paul.wullenweber/.cache/torch/sentence_transformers/microsoft_BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you ar

In [15]:
trainer = SetFitTrainer(
    model=model,
    train_dataset=ds_tokenized["train_setfit"],
    eval_dataset=ds_tokenized["validation"],
    loss_class=CosineSimilarityLoss,
    metric="f1",
    batch_size=wandb.config["batch_size"],
    num_iterations=wandb.config[
        "num_iterations"
    ],  # The number of text pairs to generate for contrastive learning
    num_epochs=wandb.config[
        "num_epochs"
    ],  # The number of epochs to use for contrastive learning
    seed=42,
)

In [16]:
trainer.train()

Generating Training Pairs: 100%|███████████████| 20/20 [00:00<00:00, 169.47it/s]
***** Running training *****
  Num examples = 5120
  Num epochs = 5
  Total optimization steps = 6400
  Total train batch size = 4
Iteration: 100%|████████████████████████████| 1280/1280 [06:08<00:00,  3.47it/s]
Iteration: 100%|████████████████████████████| 1280/1280 [06:09<00:00,  3.46it/s]
Iteration: 100%|████████████████████████████| 1280/1280 [06:09<00:00,  3.47it/s]
Iteration: 100%|████████████████████████████| 1280/1280 [06:10<00:00,  3.46it/s]
Iteration: 100%|████████████████████████████| 1280/1280 [06:08<00:00,  3.47it/s]
Epoch: 100%|█████████████████████████████████████| 5/5 [30:46<00:00, 369.28s/it]


In [17]:
trainer.evaluate()

***** Running evaluation *****


{'f1': 0.5573770491803279}

In [18]:
from pathlib import Path

model._save_pretrained(Path(run.dir) / "model_finetuned")

#### Evaluation

In [19]:
predictions = model.predict_proba(ds_tokenized["test"]["text"]).numpy()

In [22]:
from classification.utils import log_metrics_to_wandb

log_metrics_to_wandb(
    y_pred_proba=predictions,
    y_true_num=ds["test"]["label"],
    id2label=id2label,
    labels=class_labels.names,
    run=run,
)