This notebook shows an example of the results obtained when classifying sentences into the categories Obligation, Right, or Neither with a model trained on very little data, using few-shot learning.

In [None]:
!pip install setfit
!pip install huggingface-hub==0.11.0

In [36]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset
from huggingface_hub import notebook_login

We log in to Hugging Face so that the trained model can be uploaded.

In [60]:
notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.huggingface/token
Login successful


We load the training and validation data. For training, 8 examples per category are available.

In [38]:
data_files = {"train": "train.csv", "validation": "validation.csv"}
dataset = load_dataset("csv", data_files=data_files)

dataset

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-215fdcedc601983f/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-215fdcedc601983f/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'type'],
        num_rows: 300
    })
    validation: Dataset({
        features: ['text', 'label', 'type'],
        num_rows: 301
    })
})

In [40]:
eval_dataset = dataset["validation"]
train_dataset = dataset["train"]
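A few-shot subset with a fixed number of examples per category, such as the 8 mentioned above, can be drawn by sampling per label. The block below is a minimal standard-library sketch of that idea; the helper name `sample_per_label` and the toy rows are hypothetical and do not come from the notebook. With setfit, the already imported `sample_dataset` helper is meant for the same purpose.

```python
import random
from collections import defaultdict


def sample_per_label(rows, num_samples=8, seed=42):
    """Return up to `num_samples` rows for each label (hypothetical helper)."""
    rng = random.Random(seed)

    # Group the rows by their label value.
    by_label = defaultdict(list)
    for row in rows:
        by_label[row["label"]].append(row)

    sampled = []
    for label in sorted(by_label):
        group = by_label[label]
        rng.shuffle(group)                   # random order within the class
        sampled.extend(group[:num_samples])  # keep at most num_samples rows

    rng.shuffle(sampled)                     # mix the classes together
    return sampled


# Toy rows standing in for the CSV contents (assumption: three labels 0, 1, 2).
toy_rows = [{"text": f"sentence {i}", "label": i % 3} for i in range(300)]
few_shot = sample_per_label(toy_rows, num_samples=8)
print(len(few_shot))  # 24 rows: 8 for each of the 3 labels
```

Fixing the seed keeps the subset reproducible between runs, which matters when results are compared across few-shot experiments.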

We download the model to be trained with the SetFit framework.

In [41]:
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
model = SetFitModel.from_pretrained(model_id)

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


Fine-tuning with SetFitModel

In [45]:
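trainer = SetFitTrainer(
    model=model,
    # Note: the splits are crossed here. The 301-row "validation" split is
    # used for training and the 300-row "train" split for evaluation.
    train_dataset=eval_dataset,
    eval_dataset=train_dataset,
    loss_class=CosineSimilarityLoss,  # loss computed over sentence pairs
    num_epochs=1,
    num_iterations=5,  # rounds of pair generation per training example
    learning_rate=2e-5,
    column_mapping={"text": "text", "label": "label"},  # dataset column -> expected name
)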
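SetFit fine-tunes the sentence-embedding body on pairs of sentences, where a pair is scored 1.0 if both sentences share a label and 0.0 otherwise; `CosineSimilarityLoss` then pushes the cosine similarity of each pair toward that score. The block below is a simplified sketch of the pair-sampling idea, not setfit's actual implementation: the function `generate_pairs` and the toy sentences are hypothetical.

```python
import random


def generate_pairs(texts, labels, num_iterations=5, seed=0):
    """Build (sentence_a, sentence_b, target_similarity) triples.

    For every iteration and every example, one same-label pair (1.0) and one
    different-label pair (0.0) are sampled. Requires at least two classes and
    at least two examples per class.
    """
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_iterations):
        for i, (text, label) in enumerate(zip(texts, labels)):
            positives = [t for j, (t, lab) in enumerate(zip(texts, labels))
                         if lab == label and j != i]
            negatives = [t for t, lab in zip(texts, labels) if lab != label]
            pairs.append((text, rng.choice(positives), 1.0))
            pairs.append((text, rng.choice(negatives), 0.0))
    return pairs


# Toy data (assumption): two sentences for each of three labels.
toy_texts = ["a0", "a1", "b0", "b1", "c0", "c1"]
toy_labels = [0, 0, 1, 1, 2, 2]
label_of = dict(zip(toy_texts, toy_labels))

pairs = generate_pairs(toy_texts, toy_labels, num_iterations=5)
print(len(pairs))  # 6 examples * 2 pairs * 5 iterations = 60
```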

In [46]:
trainer.train()

Applying column mapping to training dataset


***** Running training *****
  Num examples = 3010
  Num epochs = 1
  Total optimization steps = 189
  Total train batch size = 16
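The figures in this log can be reproduced by hand. The sketch below assumes that the trainer builds one positive and one negative pair per example in each iteration, and that the split passed as `train_dataset` has 301 rows; the variable names are illustrative only.

```python
import math

num_rows = 301          # rows in the split passed as train_dataset
num_iterations = 5      # value given to SetFitTrainer
num_epochs = 1          # value given to SetFitTrainer
batch_size = 16         # "Total train batch size" in the log
pairs_per_example = 2   # assumption: one positive and one negative pair

# Number of sentence pairs generated for contrastive fine-tuning.
num_pairs = num_rows * pairs_per_example * num_iterations

# The last, partially filled batch still counts as one optimization step.
total_steps = math.ceil(num_pairs / batch_size) * num_epochs

print(num_pairs, total_steps)  # 3010 189
```

Both numbers match the "Num examples" and "Total optimization steps" lines of the log.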



Model evaluation.

In [47]:
metrics = trainer.evaluate()
metrics

Applying column mapping to evaluation dataset
***** Running evaluation *****


{'accuracy': 0.96}
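Accuracy is the fraction of examples whose predicted label equals the true label. The block below is a minimal sketch of that computation on made-up values; the function `accuracy` and the toy lists are hypothetical and are not the notebook's data.

```python
def accuracy(preds, labels):
    """Fraction of positions where the prediction equals the true label."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)


# Toy values (assumption): five predictions, one of them wrong.
toy_preds = [2, 2, 0, 1, 2]
toy_labels = [2, 2, 2, 1, 2]
toy_accuracy = accuracy(toy_preds, toy_labels)
print(toy_accuracy)  # 0.8
```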

Upload the trained model to the Hugging Face repository.

In [None]:
trainer.push_to_hub('marmolpen3/sla-obligations-rights')

The test data can be classified with the trained model as follows:

In [48]:
data_file = {"test": "test.csv"}
test_data = load_dataset("csv", data_files=data_file)
test_data

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-0127b7a797680cae/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-0127b7a797680cae/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


DatasetDict({
    test: Dataset({
        features: ['type', 'text', 'label'],
        num_rows: 580
    })
})

In [49]:
preds = model(test_data["test"][:]["text"])
preds

tensor([2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2,
        2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1,
        2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2,


# Results

In [51]:
def visualize_results(preds):
    label_names = {
        2: "Neither",
        0: "Obligation",
        1: "Right"
    }

    preds = preds.tolist()  # Convert tensor to a regular list
    results = [label_names[ps] for ps in preds]

    print(results)


In [52]:
visualize_results(preds)

['Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Neither', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Obligation', 'Obligation', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Obligation', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Obligation', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Right', 'Righ

In [57]:
def visualize_results(preds, true_labels):
    """Print predicted vs. true label per example, the text of each error, and totals."""
    label_names = {
        2: "Neither",
        0: "Obligation",
        1: "Right"
    }

    preds = preds.tolist()  # Convert tensor to a regular list
    texts = test_data["test"]["text"]  # Fetch the text column once

    total_errors = 0
    for text, pred, true in zip(texts, preds, true_labels):
        print("Predicted label:", label_names[pred])
        print("True label:", label_names[true])
        print()
        if pred != true:
            print(text)  # Show the misclassified sentence
            total_errors += 1
    print(total_errors)  # Number of misclassified examples
    print(len(preds))    # Number of examples evaluated

In [58]:
true_labels = test_data["test"][:]["label"]
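With several hundred test sentences, a per-class summary is easier to read than a line-by-line listing. The block below is a minimal sketch that counts (true label, predicted label) combinations with `collections.Counter`; the function `confusion_counts` and the toy lists are hypothetical, and the label mapping is the one used in `visualize_results`.

```python
from collections import Counter

label_names = {0: "Obligation", 1: "Right", 2: "Neither"}


def confusion_counts(preds, true_labels):
    """Count how often each (true, predicted) label combination occurs."""
    return Counter(zip(true_labels, preds))


# Toy values (assumption): six examples, two of them misclassified.
toy_true = [1, 1, 2, 0, 2, 2]
toy_pred = [1, 2, 2, 0, 2, 1]

counts = confusion_counts(toy_pred, toy_true)
for (true, pred), n in sorted(counts.items()):
    print(f"true={label_names[true]:<10} predicted={label_names[pred]:<10} count={n}")

# Off-diagonal entries are the errors.
total_errors = sum(n for (true, pred), n in counts.items() if true != pred)
print(total_errors)  # 2
```

On the real data, the same function could be called as `confusion_counts(preds.tolist(), true_labels)`.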

In [59]:
visualize_results(preds, true_labels)

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Neither
True label: Right

. Additionally, You are encouraged to develop a business continuity plan to ensure continuity of Your own operations in the event of a disaster.
Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted label: Right
True label: Right

Predicted lab