### Set up the imports

In [None]:
from datasets import ClassLabel, load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset

### Load the data

In [None]:
# Read the CSV and shuffle it with a fixed seed so the row order is reproducible.
dataset = load_dataset("csv", data_files="../data/new_punc_data_tr.csv")
dataset = dataset.shuffle(seed=42)["train"]

# Drop the columns that the rest of the notebook does not use.
dataset = dataset.remove_columns(["Unnamed: 0", "title", "src"])

### Preprocess

In [None]:
# Label names for the `alg` column. They are passed to ClassLabel below, so the
# position of a name in this list is its integer class id.
ALGS = [
    "ctrl",
    "fair",
    "gpt",
    "gpt2",
    "grover",
    "human",
    "pplm",
    "xlm",
    "xlnet",
    "instructgpt",
    "gpt3",
]

In [None]:
# Encode the string `alg` column as integer class ids.
# The number of classes is derived from ALGS instead of being hard-coded, so
# the label list and the feature definition cannot drift apart.
new_features = dataset.features.copy()
new_features["alg"] = ClassLabel(names=ALGS)
dataset = dataset.cast(new_features)

# Stratified split on the label column: 15% of rows go to "train" (the pool the
# few-shot samples are drawn from) and 85% to "test".
# The seed makes the split reproducible, matching the seeded shuffle at load time.
dataset = dataset.train_test_split(
    test_size=0.85,
    stratify_by_column="alg",
    seed=42,
)

In [None]:
# Display the feature schema of the training split, to check that `alg` was
# cast to a ClassLabel.
dataset['train'].features

In [None]:
# The whole held-out split is used for evaluation.
eval_dataset = dataset["test"]

# Few-shot training set: draw `num_samples` examples for each `alg` label
# from the training split.
train_dataset = sample_dataset(
    dataset["train"],
    label_column="alg",
    num_samples=40,
)

In [None]:
# Display how many rows the few-shot training set contains.
len(train_dataset)

### Load the model

In [None]:
# Load the sentence-transformer body and attach a differentiable (torch)
# classification head. The head needs one output per label, so its size is
# taken from ALGS rather than repeated as a literal.
model = SetFitModel.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2",
    use_differentiable_head=True,
    head_params={"out_features": len(ALGS)},
)

In [None]:
# Display the classification head attached to the model.
model.model_head

### Train the model

In [None]:
trainer = SetFitTrainer(
    model=model,
    # Data: map the dataset's column names onto the text/label names the trainer expects.
    column_mapping={"generation": "text", "alg": "label"},
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Contrastive learning settings.
    loss_class=CosineSimilarityLoss,
    num_iterations=20,  # number of text pairs to generate for contrastive learning
    num_epochs=1,  # number of epochs to use for contrastive learning
    batch_size=16,
    # Metric computed on `eval_dataset`.
    metric="accuracy",
)

In [None]:
# The model was created with `use_differentiable_head=True`. For that setup
# SetFit trains in two phases, and a single `trainer.train()` call runs only
# one of them, so both phases are run explicitly here.

# Phase 1: freeze the head and fine-tune the sentence-transformer body with
# the contrastive settings given to the trainer (num_iterations, num_epochs).
trainer.freeze()
trainer.train()

# Phase 2: unfreeze the head, keep the tuned body frozen, and fit the classifier.
# These hyperparameters follow the SetFit differentiable-head example and are
# a starting point; tune them for this dataset.
trainer.unfreeze(keep_body_frozen=True)
trainer.train(
    num_epochs=25,
    batch_size=16,
    learning_rate=1e-2,
    l2_weight=0.0,
)

# Evaluate on the held-out split and show the result.
metrics = trainer.evaluate()
print(metrics)