# Aprendizaje *few-shot* multietiqueta con SetFit y paraphrase-mpnet-base-v2

Instalamos e importamos las dependencias.

In [None]:
!pip install setfit

In [15]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer

Cargamos los conjuntos de datos de entrenamiento y validación.

In [None]:
# Read both CSV splits into a single DatasetDict keyed by split name.
data_files = dict(train="train.csv", validation="validation.csv")
dataset = load_dataset("csv", data_files=data_files)
dataset

In [4]:
# Every training column other than the raw text and the excluded categories
# ('obligation', 'right', 'neither') is treated as a target label.
# NOTE(review): the reason for excluding those three categories is not stated here.
labels = [
    column
    for column in dataset["train"].features
    if column not in ("text", "obligation", "right", "neither")
]
# Index <-> name lookups, in the same order as `labels`.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
labels

['service',
 'metric',
 'objective',
 'remedy',
 'claim',
 'exception',
 'definition']

In [5]:
def encode_labels(record):
    """Gather the value of each label column of `record` into one list.

    Intended for `Dataset.map`: the returned dict is merged into the row as a
    new "labels" column, ordered like the global `labels` list.
    """
    label_vector = [record[name] for name in labels]
    return {"labels": label_vector}


dataset = dataset.map(encode_labels)

  0%|          | 0/117 [00:00<?, ?ex/s]

  0%|          | 0/51 [00:00<?, ?ex/s]

In [6]:
# Split used to fine-tune the model.
train_ds = dataset["train"]
train_ds

Dataset({
    features: ['text', 'service', 'metric', 'objective', 'remedy', 'claim', 'exception', 'definition', 'obligation', 'right', 'neither', 'labels'],
    num_rows: 117
})

In [7]:
# Held-out split handed to the trainer for evaluation.
eval_ds = dataset["validation"]
eval_ds

Dataset({
    features: ['text', 'service', 'metric', 'objective', 'remedy', 'claim', 'exception', 'definition', 'obligation', 'right', 'neither', 'labels'],
    num_rows: 51
})

Descargamos el modelo base que ajustaremos con el framework SetFit.

In [8]:
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
# "one-vs-rest" trains one binary classifier per label on top of the sentence
# embeddings, so a single text can be assigned several labels at once.
model = SetFitModel.from_pretrained(
    model_id,
    multi_target_strategy="one-vs-rest",
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/594 [00:00<?, ?B/s]

Downloading (…)f39ef/.gitattributes:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)0182ff39ef/README.md:   0%|          | 0.00/3.70k [00:00<?, ?B/s]

Downloading (…)82ff39ef/config.json:   0%|          | 0.00/594 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)f39ef/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

Downloading (…)0182ff39ef/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)2ff39ef/modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


Ajuste fino (*fine-tuning*) de un SetFitModel multietiqueta empleando la estrategia *one-vs-rest*.

In [9]:
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # Loss applied to the sentence pairs SetFit builds from the training texts.
    loss_class=CosineSimilarityLoss,
    # Training hyperparameters.
    num_iterations=50,  # controls how many text pairs are generated for training
    num_epochs=3,
    batch_size=12,
    learning_rate=2e-5,
    # The trainer expects the target column to be named "label"; our encoded
    # label vectors live in "labels", so rename it (text keeps its name).
    column_mapping={"text": "text", "labels": "label"},
)

In [10]:
# Run SetFit training with the settings above (long-running: thousands of optimization steps).
trainer.train()

Applying column mapping to training dataset
***** Running training *****
  Num examples = 15400
  Num epochs = 3
  Total optimization steps = 3852
  Total train batch size = 12


Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3852 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3852 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3852 [00:00<?, ?it/s]

Evaluación del modelo

In [11]:
# Score the fine-tuned model on the validation split.
# NOTE(review): with multi-label targets this "accuracy" is presumably exact-match
# (every label of a row must be right) -- confirm against the installed setfit version.
metrics = trainer.evaluate()
metrics

Applying column mapping to evaluation dataset
***** Running evaluation *****


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

{'accuracy': 0.43137254901960786}

Inferencia sobre los datos de test para clasificarlos.

In [18]:
# The test CSV carries only a "text" column; its labels are predicted below.
test_data = load_dataset("csv", data_files=dict(test="test.csv"))
test_data



Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-70ede6224e3584f0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-70ede6224e3584f0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 27
    })
})

In [19]:
# Predict one 0/1 row per test text; columns follow the order of `labels`.
test_texts = test_data["test"][:]["text"]
preds = model(test_texts)
preds

array([[1, 1, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [1, 0, 0, 0, 1, 0, 0],
       [0, 1, 1, 1, 0, 0, 0],
       [0, 1, 1, 1, 0, 0, 0],
       [0, 1, 1, 1, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 1],
       [1, 0, 0, 0, 1, 1, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 1, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0]])

Resultados obtenidos

In [20]:
# Map each prediction row back to the names of the labels it switched on.
[
    [label for label, is_on in zip(labels, row) if is_on]
    for row in preds
]

[['service', 'metric'],
 ['service', 'metric'],
 ['service', 'metric'],
 ['service', 'metric', 'objective'],
 ['claim'],
 ['service', 'claim'],
 ['metric', 'objective', 'remedy'],
 ['metric', 'objective', 'remedy'],
 ['metric', 'objective', 'remedy'],
 ['claim'],
 ['claim'],
 ['claim', 'definition'],
 ['service', 'claim', 'exception'],
 ['claim'],
 ['definition'],
 ['definition'],
 ['service'],
 ['service'],
 ['service'],
 ['service', 'claim'],
 ['service'],
 ['service'],
 ['service', 'definition'],
 ['service'],
 ['service', 'metric'],
 ['service'],
 ['service']]