# Aplicación de aprendizaje *few-shot* (SetFit) con all-MiniLM-L6-v2

Instalamos e importamos las dependencias.

In [None]:
!pip install setfit
!pip install huggingface-hub==0.11.0

In [2]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
from huggingface_hub import notebook_login

Nos autenticamos en Hugging Face para poder subir después el modelo entrenado al Hub.

In [14]:
# Interactive Hugging Face login; the stored token is what trainer.push_to_hub uses later.
notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.huggingface/token
Login successful


Cargamos los conjuntos de entrenamiento y validación desde sus archivos CSV.

In [4]:
# Load the train/validation splits from local CSV files into one DatasetDict.
data_files = dict(
    train="train.csv",
    validation="validation.csv",
)
dataset = load_dataset("csv", data_files=data_files)

dataset



  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'service', 'metric', 'objective', 'remedy', 'claim', 'exception', 'definition', 'obligation', 'right', 'neither'],
        num_rows: 117
    })
    validation: Dataset({
        features: ['text', 'service', 'metric', 'objective', 'remedy', 'claim', 'exception', 'definition', 'obligation', 'right', 'neither'],
        num_rows: 51
    })
})

In [5]:
# Every feature column except the input text and the columns listed here is a
# target label; the order of `labels` fixes the order of each label vector.
non_target_columns = ("text", "obligation", "right", "neither")
labels = [
    column
    for column in dataset["train"].features
    if column not in non_target_columns
]
id2label = dict(enumerate(labels))
label2id = {name: position for position, name in id2label.items()}
labels

['service',
 'metric',
 'objective',
 'remedy',
 'claim',
 'exception',
 'definition']

In [6]:
def encode_labels(record):
    """Gather one row's per-label columns into a single ``labels`` list.

    Values are read in the order of the global ``labels`` list. Columns not
    listed there (e.g. ``text``) are ignored.

    Args:
        record: mapping from column name to value for one dataset row.

    Returns:
        ``{"labels": [record[name] for each name in labels]}``.
    """
    target_vector = []
    for label_name in labels:
        target_vector.append(record[label_name])
    return {"labels": target_vector}

# Add the "labels" list column to every split (train and validation).
dataset = dataset.map(encode_labels)



In [7]:
# Training split, including the "labels" column produced by encode_labels.
train_ds = dataset["train"]
train_ds

Dataset({
    features: ['text', 'service', 'metric', 'objective', 'remedy', 'claim', 'exception', 'definition', 'obligation', 'right', 'neither', 'labels'],
    num_rows: 117
})

In [8]:
# Validation split; passed to the trainer as eval_dataset.
eval_ds = dataset["validation"]
eval_ds

Dataset({
    features: ['text', 'service', 'metric', 'objective', 'remedy', 'claim', 'exception', 'definition', 'obligation', 'right', 'neither', 'labels'],
    num_rows: 51
})

Descargamos el modelo base que entrenaremos con el framework SetFit.

In [9]:
# Pretrained sentence-transformer body; the multi-label head uses a
# one-vs-rest strategy (one binary classifier per label).
model_id = "sentence-transformers/all-MiniLM-L6-v2"
model = SetFitModel.from_pretrained(
    model_id,
    multi_target_strategy="one-vs-rest",
)

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


Ajuste fino (*fine-tuning*) del `SetFitModel` multietiqueta con la estrategia *one-vs-rest*.

In [10]:
# SetFit training: contrastive fine-tuning of the sentence-transformer body on
# generated sentence pairs (CosineSimilarityLoss), then fitting the head.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    loss_class=CosineSimilarityLoss,
    batch_size=32,
    num_epochs=3,
    num_iterations=80,  # pair-generation iterations; see the setfit docs
    # SetFit expects the columns "text" and "label"; map our "labels" list onto it.
    column_mapping={
        "text": "text",
        "labels": "label",
    },
    # Pair sampling and training are random. Pin the seed explicitly instead of
    # relying on the library default, because setfit is not version-pinned.
    seed=42,
)

In [11]:
# Fine-tune. `trainer` was built around `model`, and that same object is used
# for inference further down.
trainer.train()

Applying column mapping to training dataset
***** Running training *****
  Num examples = 24640
  Num epochs = 3
  Total optimization steps = 2310
  Total train batch size = 32


Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2310 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2310 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2310 [00:00<?, ?it/s]

Evaluación del modelo sobre el conjunto de validación.

In [12]:
# Score the trained model on eval_ds (the validation split given to the trainer).
# NOTE(review): with multi-label targets this "accuracy" presumably counts a row
# as correct only when the whole label vector matches (exact match). Confirm
# this for the installed setfit version before comparing with per-label metrics.
metrics = trainer.evaluate()
metrics

Applying column mapping to evaluation dataset
***** Running evaluation *****


{'accuracy': 0.5294117647058824}

Subimos el modelo entrenado al Hub de Hugging Face.

In [15]:
# Upload the trained model (sentence-transformer weights and model_head.pkl) to
# this Hub repo; requires the earlier notebook_login().
trainer.push_to_hub('marmolpen3/all_MiniLM_L6_v2-sla')

Cloning https://huggingface.co/marmolpen3/all_MiniLM_L6_v2-sla into local empty directory.


Upload file pytorch_model.bin:   0%|          | 32.0k/86.7M [00:00<?, ?B/s]

Upload file model_head.pkl: 100%|##########| 24.3k/24.3k [00:00<?, ?B/s]

remote: Scanning LFS files for validity...        
remote: LFS file scan complete.        
To https://huggingface.co/marmolpen3/all_MiniLM_L6_v2-sla
   bfee124..fb6802b  main -> main

remote: LFS file scan complete.        
To https://huggingface.co/marmolpen3/all_MiniLM_L6_v2-sla
   bfee124..fb6802b  main -> main



'https://huggingface.co/marmolpen3/all_MiniLM_L6_v2-sla/commit/fb6802b809e4fcb19032ba3cee1b0ad115711e7f'

Clasificamos (inferencia) los textos del conjunto de prueba.

In [16]:
# Unlabelled test sentences: this CSV has only a "text" column.
test_data = load_dataset("csv", data_files={"test": "test.csv"})
test_data



  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 27
    })
})

In [17]:
# One 0/1 row per test sentence; column j corresponds to labels[j].
test_texts = test_data["test"][:]["text"]
preds = model.predict(test_texts)
preds

array([[1, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [0, 1, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [0, 1, 1, 0, 0, 0, 0],
       [0, 1, 1, 0, 0, 0, 0],
       [0, 1, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0]])

Resultados obtenidos: etiquetas predichas para cada texto de prueba.

In [18]:
# Map each prediction row back to the names of the labels it switched on.
[
    [label_name for label_name, flag in zip(labels, row) if flag]
    for row in preds
]

[['service'],
 ['service', 'metric', 'objective'],
 ['service'],
 ['metric', 'objective'],
 ['claim'],
 ['service'],
 ['metric', 'objective'],
 ['metric', 'objective'],
 ['metric', 'objective'],
 ['claim'],
 ['claim'],
 ['claim'],
 ['claim'],
 ['claim'],
 ['definition'],
 ['definition'],
 ['service'],
 ['service'],
 ['service'],
 [],
 ['service'],
 ['service'],
 ['service'],
 ['service'],
 ['service'],
 ['service'],
 ['service']]