In [1]:
# Install required packages
!pip install pytorch-crf datasets spacy fasttext seqeval ipdb

# Import the os module to modify environment variables
import os

# Set the CUDA_LAUNCH_BLOCKING environment variable to "1"
# This variable is specific to the CUDA library used for GPU acceleration
# Setting it to "1" enables synchronization mode, causing the program to wait for GPU kernel completion before proceeding
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Collecting datasets
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/486.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Collecting fasttext
  Downloading fasttext-0.9.2.tar.gz (68 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.8/68.8 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ipdb
  Downloading ipdb-0.13.13-py3-none-any.whl (12 kB)
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloa

In [2]:
# Import the load_dataset function from the datasets module
from datasets import load_dataset

# Load the "conll2002" dataset with the language set to Spanish ('es')
dataset = load_dataset("conll2002", 'es')

# Print the number of examples in the training split of the dataset
print("Number of examples in the training split:", len(dataset['train']))

# Print the number of examples in the validation split of the dataset
print("Number of examples in the validation split:", len(dataset['validation']))

# Print the number of examples in the test split of the dataset
print("Number of examples in the test split:", len(dataset['test']))


Downloading builder script:   0%|          | 0.00/9.23k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/7.46k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/12.9k [00:00<?, ?B/s]

Downloading and preparing dataset conll2002/es to /root/.cache/huggingface/datasets/conll2002/es/1.0.0/a3a8a8612caf57271f5b35c5ae1dd25f99ddb9efb9c1667abaa70ede33e863e5...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/713k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/141k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/138k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/8324 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1916 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1518 [00:00<?, ? examples/s]

Dataset conll2002 downloaded and prepared to /root/.cache/huggingface/datasets/conll2002/es/1.0.0/a3a8a8612caf57271f5b35c5ae1dd25f99ddb9efb9c1667abaa70ede33e863e5. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Number of examples in the training split: 8324
Number of examples in the validation split: 1916
Number of examples in the test split: 1518


In [3]:
# Get the sentences from the dataset
sentences = dataset["train"]["tokens"] + dataset["validation"]["tokens"] + dataset["test"]["tokens"]

# Create a unique set of tokens
unique_tokens = set(token for sentence in sentences for token in sentence)

# Calculate the number of unique tokens
num_unique_tokens = len(unique_tokens)

print("Number of tokens in the vocabulary:", num_unique_tokens)


Number of tokens in the vocabulary: 31405


In [4]:
# Example of entities in a train sample
ejemplo = dataset['train'][2]
# Detokenized sentence
' '.join(ejemplo['tokens']).replace(' ,', ',').replace(' .', '.')

'El Abogado General del Estado, Daryl Williams, subrayó hoy la necesidad de tomar medidas para proteger al sistema judicial australiano frente a una página de internet que imposibilita el cumplimiento de los principios básicos de la Ley.'

In [5]:
ner_lista = dataset["train"].features["ner_tags"].feature.names
for indice, elem in enumerate(ejemplo['ner_tags']):
  print("TOKEN: {:<15} Entity: {}".format(ejemplo['tokens'][indice], ner_lista[elem]))

TOKEN: El              Entity: O
TOKEN: Abogado         Entity: B-PER
TOKEN: General         Entity: I-PER
TOKEN: del             Entity: I-PER
TOKEN: Estado          Entity: I-PER
TOKEN: ,               Entity: O
TOKEN: Daryl           Entity: B-PER
TOKEN: Williams        Entity: I-PER
TOKEN: ,               Entity: O
TOKEN: subrayó         Entity: O
TOKEN: hoy             Entity: O
TOKEN: la              Entity: O
TOKEN: necesidad       Entity: O
TOKEN: de              Entity: O
TOKEN: tomar           Entity: O
TOKEN: medidas         Entity: O
TOKEN: para            Entity: O
TOKEN: proteger        Entity: O
TOKEN: al              Entity: O
TOKEN: sistema         Entity: O
TOKEN: judicial        Entity: O
TOKEN: australiano     Entity: O
TOKEN: frente          Entity: O
TOKEN: a               Entity: O
TOKEN: una             Entity: O
TOKEN: página          Entity: O
TOKEN: de              Entity: O
TOKEN: internet        Entity: O
TOKEN: que             Entity: O
TOKEN: imposibilita

In [6]:
# Obtain the mapping from numerical labels to named labels
id2label = dataset["train"].features["ner_tags"].feature.names

# Create a mapping from named labels to numerical labels
label2id = {v: i for i, v in enumerate(id2label)}

In [7]:
# Import required modules
from collections import Counter
from torchtext.vocab import vocab as Vocab
from collections import OrderedDict

# Initialize a counter to keep track of token frequencies
counter = Counter()

# Iterate over each dataset split (train, validation, test)
for dataset_part in ['train', 'validation', 'test']:
    # Get the tokens from the current dataset split
    textos = dataset[dataset_part]['tokens']
    # Update the counter with the tokens from the current dataset split
    for texto in textos:
        counter.update(texto)

# Define special tokens for the vocabulary
specials = ["<unk>", "<pad>", "<bos>", "<eos>"]

# Create a vocabulary object based on the token frequencies
vocab = Vocab(counter, min_freq=1, specials=["<unk>", "<pad>", "<bos>", "<eos>"])

# Get the index-to-token (itos) and token-to-index (stoi) mappings from the vocabulary
itos = vocab.get_itos()
stoi = vocab.get_stoi()

# Get the index of the special tokens in the stoi mapping
UNK_IDX = stoi["<unk>"]
PAD_IDX = stoi["<pad>"]
BOS_IDX = stoi["<bos>"]
EOS_IDX = stoi["<eos>"]

# Print the size of the vocabulary
print("Vocabulary Size:", len(vocab))

Vocabulary Size: 31409


In [8]:
def tokenize_and_format(example):
    """
    Tokenizes and formats an example.
    Arguments:
    - example: An input data example in the form of a dictionary with 'tokens' and 'ner_tags' keys.
    Returns:
    - A new dictionary with 'input_ids' and 'labels' keys containing the tokens converted to IDs and NER tags respectively.
    """
    tokens = example['tokens']
    ner_tags = example['ner_tags']

    # Add BOS token at the beginning and EOS token at the end
    tokens = ['<bos>'] + tokens + ['<eos>']
    token_ids = [stoi.get(token, UNK_IDX) for token in tokens]

    # Add 0 to the left and right of the NER tags
    ner_tags = [0] + ner_tags + [0]

    return {'input_ids': token_ids, 'labels': ner_tags}

# Apply the tokenize_and_format function to the dataset
dataset = dataset.map(tokenize_and_format, batched=False)

Map:   0%|          | 0/8324 [00:00<?, ? examples/s]

Map:   0%|          | 0/1916 [00:00<?, ? examples/s]

Map:   0%|          | 0/1518 [00:00<?, ? examples/s]

In [9]:
# Download the FastText word vectors for Spanish language
!wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.bin.gz

# Decompress the downloaded file
!gunzip cc.es.300.bin.gz

--2023-07-04 20:46:28--  https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.bin.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.224.2.42, 13.224.2.21, 13.224.2.88, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.224.2.42|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4500107671 (4.2G) [application/octet-stream]
Saving to: ‘cc.es.300.bin.gz’


2023-07-04 20:48:26 (36.7 MB/s) - ‘cc.es.300.bin.gz’ saved [4500107671/4500107671]



In [10]:
# Import the fasttext module
import fasttext

# Load the FastText model for Spanish language
ft = fasttext.load_model('cc.es.300.bin')



In [11]:
import torch

# Calculate the dimension of the embeddings
DIM = ft["random"].shape[0]

# Create an embedding matrix of random values with dimensions (vocab_size, DIM)
emb_matrix = torch.randn(len(vocab), DIM)

# Set the embedding for the PAD_IDX to zero
emb_matrix[PAD_IDX] = 0

# Build the embedding matrix

# Load all the embeddings for our vocabulary

for i, word in enumerate(itos):
    """
    Complete the embedding matrix
    """
    if i == UNK_IDX or i == PAD_IDX or i == BOS_IDX or i == EOS_IDX:
        # Skip UNK, PAD, BOS, and EOS embeddings
        pass
    else:
        # Fill in the embedding matrix
        word_vector = ft.get_word_vector(word)
        emb_matrix[i] = torch.tensor(word_vector)

In [12]:
del ft

In [13]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    """
    Function used to collate a list of examples into a batch.
    Args:
        batch (list): List of examples, where each example is a dictionary with 'labels' and 'input_ids' keys.
    Returns:
        tuple: Tuple containing the tensors of input_ids and labels after applying padding.
    """
    # Extract the 'labels' tensors from each example in the batch
    labels = [torch.tensor(example["labels"]) for example in batch]

    # Extract the 'input_ids' tensors from each example in the batch
    input_ids = [torch.tensor(example["input_ids"]) for example in batch]

    # Pad the input_ids tensors with the PAD_IDX as the padding value
    input_ids_padded = pad_sequence(input_ids, padding_value=PAD_IDX, batch_first=True)

    # Pad the labels tensors with -100 as the padding value
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)

    return input_ids_padded, labels_padded

# Create DataLoaders for the training, validation, and test data.
# The batch_size will be set to 32 or 16 depending on the case.
# collate_fn: Function to collate examples into batches, using the collate_batch function.
train_dataloader = DataLoader(dataset["train"], batch_size=32, collate_fn=collate_batch)
dev_dataloader = DataLoader(dataset["validation"], batch_size=16, collate_fn=collate_batch)
test_dataloader = DataLoader(dataset["test"], batch_size=16, collate_fn=collate_batch)

In [14]:
from tqdm.auto import tqdm
import numpy as np
import seqeval
import torch.nn.functional as F
from torch import nn
from datasets import load_metric

# Load the seqeval metric for sequence labeling evaluation
metric = load_metric("seqeval")

def validate_step(model, dataloader):
    """
    Validate step

    Calculates F1 and other metrics.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        # Disable gradient calculation
        all_labels = []
        all_preds = []
        all_losses = []
        for text, labels in tqdm(dataloader):
            text = text.to(device)
            labels = labels.to(device)
            logits = model(text)

            # Compute the cross-entropy loss
            loss = F.cross_entropy(
                logits.view(-1, 9),
                labels.view(-1),
            )
            all_losses.append(loss.detach().item())

            # Instead of softmax, directly get the maximum value
            preds = logits.argmax(-1)

            # Convert label indices to corresponding labels
            true_labels = [[id2label[l] for l in label if l != -100] for label in labels]
            # Ignore the -100 values
            true_predictions = [
                [id2label[p] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(preds, labels)
            ]

            all_labels += true_labels
            all_preds += true_predictions

        # Compute evaluation metrics
        metrics = metric.compute(predictions=all_preds, references=all_labels)

        metrics["loss"] = np.array(all_losses).mean()
        metrics["micro_f1"] = seqeval.metrics.sequence_labeling.f1_score(
            all_labels, all_preds, average="micro",
        )
        metrics["macro_f1"] = seqeval.metrics.sequence_labeling.f1_score(
            all_labels, all_preds, average="macro",
        )
        return metrics

def log_metrics(writer, metrics):
    """
    Log metrics to tensorboard
    """
    for k, v in metrics.items():
        if type(v) is dict:
            # Handle metrics with sub-categories (e.g., LOC, PER)
            for sub_k, sub_v in v.items():
                if sub_k == "number":
                    continue
                writer.add_scalar(f"dev/{k} {sub_k}", sub_v, global_step=step)
        else:
            writer.add_scalar(f"dev/{k}", sub_v, global_step=step)

  metric = load_metric("seqeval")


Downloading builder script:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

In [23]:
int(512/2)

256

In [51]:
from google.protobuf.reflection import ParseMessage
import torch.nn.functional as F
from torch import nn
from datasets import load_metric

# Load the seqeval metric for sequence labeling evaluation
metric = load_metric("seqeval")

class MyNERModel(nn.Module):
    """
    Custom NER model class.
    Args:
        vocab_size (int): Vocabulary size.
        embedding_dim (int): Dimension of embeddings.
        pad_idx (int): Padding index.
        rnn_units (int): Number of units in the LSTM layer.
        num_labels (int): Number of entity labels.
        num_layers (int, optional): Number of LSTM layers. Default: 1.
        dropout (float, optional): Dropout rate. Default: 0.25.
        embedding_matrix (torch.Tensor, optional): Pre-trained embedding matrix. Default: None.
        freeze_embeddings (bool, optional): Indicator of whether to freeze embeddings during training. Default: True.
    """
    def __init__(self, vocab_size, embedding_dim, pad_idx, rnn_units, num_labels, num_layers=1,
                 dropout=0.25, embedding_matrix=None, freeze_embeddings=True):
        """
        Constructor of the MyNERModel class.
        Creates the necessary layers for the model.
        """
        super().__init__()
        # Embedding layer
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, padding_idx=pad_idx, freeze=freeze_embeddings)
        self.lstm = nn.LSTM(embedding_dim, rnn_units, num_layers, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(2*rnn_units, num_labels)

    def forward(self, text):
        """
        Performs a forward pass in the model.
        Args:
            text (torch.Tensor): Encoded text sequence.
        Returns:
            torch.Tensor: Logits produced by the model.
        """
        # Embedding layer
        embedded = self.embedding(text)
        # LSTM layer
        output, _ = self.lstm(embedded)
        # Dropout layer
        output = self.dropout(output)
        logits = self.fc(output)
        return logits


In [52]:
# CHECK!
# Obtain the number of classes for the NER tags
num_classes = dataset["train"].features["ner_tags"].feature.num_classes

# Create an instance of the MyNERModel
model = MyNERModel(
    vocab_size=len(vocab), embedding_dim=DIM, pad_idx=PAD_IDX, rnn_units=512, #512,
    embedding_matrix=emb_matrix, num_layers=16,
    freeze_embeddings=True, num_labels=num_classes, dropout=0.25
)

# Get a batch of text and labels from the training dataloader
text, labels = next(iter(train_dataloader))

# Forward pass through the model to get predictions
preds = model(text)

# Calculate the loss using cross-entropy
loss = F.cross_entropy(
    preds.view(-1, num_classes),
    labels.view(-1),
)


In [53]:
loss

tensor(2.2096, grad_fn=<NllLossBackward0>)

In [55]:
import torch
from tqdm.auto import tqdm
from pprint import pprint as pp
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter

# Initialize a SummaryWriter for logging training progress
writer = SummaryWriter()

# Determine the device to use (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create an instance of the MyNERModel
model = MyNERModel(
    vocab_size=len(vocab), embedding_dim=DIM, pad_idx=PAD_IDX, rnn_units=512, embedding_matrix=emb_matrix,
    num_layers=8, freeze_embeddings=True, num_labels=num_classes,
)

# Set the number of epochs and initialize the step counter
num_epochs = 15
step = 0

# Set the learning rate for the optimizer
lr = 1e-3

# Define the optimizer (Adam) for updating model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Move the model to the appropriate device
model = model.to(device)

# Training loop
for epoch in range(num_epochs):
    for batch in tqdm(train_dataloader):
        step += 1

        # Get the text and labels from the batch and move them to the device
        text, labels = batch
        text = text.to(device)
        labels = labels.to(device)

        # Reset the gradients
        optimizer.zero_grad()

        # Forward pass through the model to get logits
        logits = model(text)

        # Calculate the loss using cross-entropy
        loss = F.cross_entropy(logits.view(-1, num_classes), labels.view(-1))

        # Backpropagation to calculate gradients
        loss.backward()

        # Update the model parameters
        optimizer.step()

        # Calculate the gradient norm
        total_norm = sum(param.grad.detach().norm(2) ** 2 for param in model.parameters() if param.requires_grad) ** (0.5)

        # Log the loss and gradient norm to TensorBoard
        writer.add_scalar("train/loss", loss, global_step=step)
        writer.add_scalar("train/gradient_norm", total_norm, global_step=step)

    # Perform validation on the development set
    metrics = validate_step(model, dev_dataloader)
    pp(metrics)

    # Log the metrics to TensorBoard
    log_metrics(writer, metrics)


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


{'LOC': {'f1': 0.0264783759929391,
         'number': 985,
         'precision': 0.10135135135135136,
         'recall': 0.015228426395939087},
 'MISC': {'f1': 0.0, 'number': 445, 'precision': 0.0, 'recall': 0.0},
 'ORG': {'f1': 0.35629224332153203,
         'number': 1700,
         'precision': 0.24523704031900753,
         'recall': 0.6511764705882352},
 'PER': {'f1': 0.12365376944555247,
         'number': 1222,
         'precision': 0.12062256809338522,
         'recall': 0.12684124386252046},
 'loss': 0.33180166855454446,
 'macro_f1': 0.1266060971900059,
 'micro_f1': 0.24798524128556174,
 'overall_accuracy': 0.8899656417936745,
 'overall_f1': 0.24798524128556174,
 'overall_precision': 0.2147301160248865,
 'overall_recall': 0.29342830882352944}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.2751637879690291,
         'number': 985,
         'precision': 0.33285302593659943,
         'recall': 0.23451776649746192},
 'MISC': {'f1': 0.0, 'number': 445, 'precision': 0.0, 'recall': 0.0},
 'ORG': {'f1': 0.45919717688575207,
         'number': 1700,
         'precision': 0.36732533521524346,
         'recall': 0.6123529411764705},
 'PER': {'f1': 0.48210023866348456,
         'number': 1222,
         'precision': 0.5784650630011455,
         'recall': 0.4132569558101473},
 'loss': 0.2548438721646865,
 'macro_f1': 0.30411530087956645,
 'micro_f1': 0.4058004110527517,
 'overall_accuracy': 0.9159369218571051,
 'overall_f1': 0.4058004110527517,
 'overall_precision': 0.40331366318656375,
 'overall_recall': 0.40831801470588236}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.2487046632124352,
         'number': 985,
         'precision': 0.2872340425531915,
         'recall': 0.21928934010152284},
 'MISC': {'f1': 0.01675041876046901,
          'number': 445,
          'precision': 0.03289473684210526,
          'recall': 0.011235955056179775},
 'ORG': {'f1': 0.5798974816135503,
         'number': 1700,
         'precision': 0.4668101901686401,
         'recall': 0.7652941176470588},
 'PER': {'f1': 0.6319663512092534,
         'number': 1222,
         'precision': 0.8838235294117647,
         'recall': 0.4918166939443535},
 'loss': 0.19867536413172882,
 'macro_f1': 0.369329728698927,
 'micro_f1': 0.4867591424968474,
 'overall_accuracy': 0.9344198749008897,
 'overall_f1': 0.4867591424968474,
 'overall_precision': 0.48570121253717685,
 'overall_recall': 0.48782169117647056}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.41713062098501075,
         'number': 985,
         'precision': 0.36074074074074075,
         'recall': 0.49441624365482234},
 'MISC': {'f1': 0.018691588785046728,
          'number': 445,
          'precision': 0.030456852791878174,
          'recall': 0.01348314606741573},
 'ORG': {'f1': 0.6116529132283071,
         'number': 1700,
         'precision': 0.531970421922575,
         'recall': 0.7194117647058823},
 'PER': {'f1': 0.7112608277189605,
         'number': 1222,
         'precision': 0.8633177570093458,
         'recall': 0.6047463175122749},
 'loss': 0.17717733945076664,
 'macro_f1': 0.43968398767933126,
 'micro_f1': 0.5423017450850453,
 'overall_accuracy': 0.9413267553519513,
 'overall_f1': 0.5423017450850453,
 'overall_precision': 0.5221182475542322,
 'overall_recall': 0.5641084558823529}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.6182152713891444,
         'number': 985,
         'precision': 0.5651808242220353,
         'recall': 0.682233502538071},
 'MISC': {'f1': 0.22880215343203228,
          'number': 445,
          'precision': 0.28523489932885904,
          'recall': 0.19101123595505617},
 'ORG': {'f1': 0.6874139355549436,
         'number': 1700,
         'precision': 0.6462972553081305,
         'recall': 0.7341176470588235},
 'PER': {'f1': 0.8530140379851362,
         'number': 1222,
         'precision': 0.8608333333333333,
         'recall': 0.8453355155482815},
 'loss': 0.15257231445672612,
 'macro_f1': 0.5968613495903141,
 'micro_f1': 0.6773690078037904,
 'overall_accuracy': 0.957219628226588,
 'overall_f1': 0.6773690078037904,
 'overall_precision': 0.6578605456907752,
 'overall_recall': 0.6980698529411765}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.5928118393234673,
         'number': 985,
         'precision': 0.5079710144927536,
         'recall': 0.7116751269035533},
 'MISC': {'f1': 0.3088803088803089,
          'number': 445,
          'precision': 0.3614457831325301,
          'recall': 0.2696629213483146},
 'ORG': {'f1': 0.6821211243117937,
         'number': 1700,
         'precision': 0.6721873215305539,
         'recall': 0.6923529411764706},
 'PER': {'f1': 0.7651737128505651,
         'number': 1222,
         'precision': 0.7832047986289632,
         'recall': 0.7479541734860884},
 'loss': 0.17078891817169886,
 'macro_f1': 0.5872467463415337,
 'micro_f1': 0.6484079269650411,
 'overall_accuracy': 0.9524447185270021,
 'overall_f1': 0.6484079269650411,
 'overall_precision': 0.6289416846652268,
 'overall_recall': 0.6691176470588235}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7157991708889914,
         'number': 985,
         'precision': 0.6551433389544689,
         'recall': 0.7888324873096447},
 'MISC': {'f1': 0.4014598540145985,
          'number': 445,
          'precision': 0.4376657824933687,
          'recall': 0.3707865168539326},
 'ORG': {'f1': 0.7231467473524963,
         'number': 1700,
         'precision': 0.7445482866043613,
         'recall': 0.7029411764705882},
 'PER': {'f1': 0.8889757623143081,
         'number': 1222,
         'precision': 0.8510479041916168,
         'recall': 0.9304418985270049},
 'loss': 0.13592693464597688,
 'macro_f1': 0.6823453836425986,
 'micro_f1': 0.7393857271906052,
 'overall_accuracy': 0.9644260417584354,
 'overall_f1': 0.7393857271906052,
 'overall_precision': 0.7269094138543517,
 'overall_recall': 0.7522977941176471}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7330049261083743,
         'number': 985,
         'precision': 0.7119617224880382,
         'recall': 0.7553299492385787},
 'MISC': {'f1': 0.44540540540540546,
          'number': 445,
          'precision': 0.42916666666666664,
          'recall': 0.46292134831460674},
 'ORG': {'f1': 0.7480564353584797,
         'number': 1700,
         'precision': 0.7326565143824028,
         'recall': 0.7641176470588236},
 'PER': {'f1': 0.9090909090909091,
         'number': 1222,
         'precision': 0.9098360655737705,
         'recall': 0.9083469721767594},
 'loss': 0.12727322949504014,
 'macro_f1': 0.7088894189907922,
 'micro_f1': 0.7573844419391207,
 'overall_accuracy': 0.9661351422782134,
 'overall_f1': 0.7573844419391207,
 'overall_precision': 0.7434705621956618,
 'overall_recall': 0.7718290441176471}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.6782682512733448,
         'number': 985,
         'precision': 0.5827862873814734,
         'recall': 0.8111675126903554},
 'MISC': {'f1': 0.40360766629086814,
          'number': 445,
          'precision': 0.40497737556561086,
          'recall': 0.40224719101123596},
 'ORG': {'f1': 0.6737201365187714,
         'number': 1700,
         'precision': 0.802439024390244,
         'recall': 0.5805882352941176},
 'PER': {'f1': 0.8610596517228603,
         'number': 1222,
         'precision': 0.7867298578199052,
         'recall': 0.9509001636661211},
 'loss': 0.1513740816677455,
 'macro_f1': 0.6541639264514612,
 'micro_f1': 0.7049143372407575,
 'overall_accuracy': 0.9596158928728746,
 'overall_f1': 0.7049143372407575,
 'overall_precision': 0.6918141592920354,
 'overall_recall': 0.7185202205882353}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7225225225225225,
         'number': 985,
         'precision': 0.6493927125506073,
         'recall': 0.8142131979695432},
 'MISC': {'f1': 0.5141129032258064,
          'number': 445,
          'precision': 0.46617915904936014,
          'recall': 0.5730337078651685},
 'ORG': {'f1': 0.749531542785759,
         'number': 1700,
         'precision': 0.7989347536617842,
         'recall': 0.7058823529411765},
 'PER': {'f1': 0.8731686898283801,
         'number': 1222,
         'precision': 0.8937446443873179,
         'recall': 0.853518821603928},
 'loss': 0.12867556020307044,
 'macro_f1': 0.714833914590617,
 'micro_f1': 0.7497444053163693,
 'overall_accuracy': 0.9658356091974275,
 'overall_f1': 0.7497444053163693,
 'overall_precision': 0.7414064255223545,
 'overall_recall': 0.7582720588235294}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7002893757751136,
         'number': 985,
         'precision': 0.5906555090655509,
         'recall': 0.8598984771573605},
 'MISC': {'f1': 0.3687374749498998,
          'number': 445,
          'precision': 0.33273056057866185,
          'recall': 0.4134831460674157},
 'ORG': {'f1': 0.7155482815057282,
         'number': 1700,
         'precision': 0.8066420664206642,
         'recall': 0.6429411764705882},
 'PER': {'f1': 0.8949698189134809,
         'number': 1222,
         'precision': 0.880443388756928,
         'recall': 0.9099836333878887},
 'loss': 0.14027466926296864,
 'macro_f1': 0.6698862377860557,
 'micro_f1': 0.7225633582672769,
 'overall_accuracy': 0.9615716676944762,
 'overall_f1': 0.7225633582672769,
 'overall_precision': 0.70271444082519,
 'overall_recall': 0.7435661764705882}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7108333333333334,
         'number': 985,
         'precision': 0.602826855123675,
         'recall': 0.8659898477157361},
 'MISC': {'f1': 0.44147157190635455,
          'number': 445,
          'precision': 0.43805309734513276,
          'recall': 0.4449438202247191},
 'ORG': {'f1': 0.7245053272450532,
         'number': 1700,
         'precision': 0.750788643533123,
         'recall': 0.7},
 'PER': {'f1': 0.9055214723926381,
         'number': 1222,
         'precision': 0.9051512673753066,
         'recall': 0.9058919803600655},
 'loss': 0.1411089775278621,
 'macro_f1': 0.6955829262193449,
 'micro_f1': 0.7417746759720838,
 'overall_accuracy': 0.9640560303057,
 'overall_f1': 0.7417746759720838,
 'overall_precision': 0.7161497326203209,
 'overall_recall': 0.7693014705882353}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7713603818615752,
         'number': 985,
         'precision': 0.7279279279279279,
         'recall': 0.8203045685279188},
 'MISC': {'f1': 0.499473129610116,
          'number': 445,
          'precision': 0.47023809523809523,
          'recall': 0.5325842696629214},
 'ORG': {'f1': 0.7818349186867137,
         'number': 1700,
         'precision': 0.8171905067350866,
         'recall': 0.7494117647058823},
 'PER': {'f1': 0.8647834274952918,
         'number': 1222,
         'precision': 0.8011165387299372,
         'recall': 0.939443535188216},
 'loss': 0.12392753655246148,
 'macro_f1': 0.7293629644134241,
 'micro_f1': 0.7740567090868496,
 'overall_accuracy': 0.9675623293101929,
 'overall_f1': 0.7740567090868496,
 'overall_precision': 0.752713851498046,
 'overall_recall': 0.7966452205882353}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7686496694995278,
         'number': 985,
         'precision': 0.7184466019417476,
         'recall': 0.8263959390862944},
 'MISC': {'f1': 0.4888888888888889,
          'number': 445,
          'precision': 0.44403669724770645,
          'recall': 0.5438202247191011},
 'ORG': {'f1': 0.7725030826140568,
         'number': 1700,
         'precision': 0.8115284974093264,
         'recall': 0.7370588235294118},
 'PER': {'f1': 0.9007453903491566,
         'number': 1222,
         'precision': 0.8651092690278824,
         'recall': 0.939443535188216},
 'loss': 0.12522380525479093,
 'macro_f1': 0.7326967578379076,
 'micro_f1': 0.7767666554319738,
 'overall_accuracy': 0.9687076028543741,
 'overall_f1': 0.7767666554319738,
 'overall_precision': 0.7599472411519015,
 'overall_recall': 0.7943474264705882}


  0%|          | 0/261 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?it/s]

{'LOC': {'f1': 0.7285164600256522,
         'number': 985,
         'precision': 0.6292466765140325,
         'recall': 0.86497461928934},
 'MISC': {'f1': 0.5705128205128205,
          'number': 445,
          'precision': 0.5437881873727087,
          'recall': 0.6},
 'ORG': {'f1': 0.773208722741433,
         'number': 1700,
         'precision': 0.8218543046357616,
         'recall': 0.73},
 'PER': {'f1': 0.887149917627677,
         'number': 1222,
         'precision': 0.8930348258706468,
         'recall': 0.881342062193126},
 'loss': 0.12144565283476065,
 'macro_f1': 0.7398469802268957,
 'micro_f1': 0.7712330304050264,
 'overall_accuracy': 0.9684080697735882,
 'overall_f1': 0.7712330304050264,
 'overall_precision': 0.7535628151721114,
 'overall_recall': 0.7897518382352942}


In [56]:
validate_step(model, test_dataloader)

  0%|          | 0/95 [00:00<?, ?it/s]

{'LOC': {'precision': 0.7485620377978636,
  'recall': 0.8404059040590406,
  'f1': 0.7918296392872662,
  'number': 1084},
 'MISC': {'precision': 0.5306666666666666,
  'recall': 0.5852941176470589,
  'f1': 0.5566433566433567,
  'number': 340},
 'ORG': {'precision': 0.7982708933717579,
  'recall': 0.7914285714285715,
  'f1': 0.7948350071736012,
  'number': 1400},
 'PER': {'precision': 0.9144827586206896,
  'recall': 0.9020408163265307,
  'f1': 0.9082191780821918,
  'number': 735},
 'overall_precision': 0.7775978407557355,
 'overall_recall': 0.8094970497330711,
 'overall_f1': 0.793226872246696,
 'overall_accuracy': 0.9774963807289854,
 'loss': 0.0849027257314638,
 'micro_f1': 0.793226872246696,
 'macro_f1': 0.762881795296604}

In [30]:
from spacy.lang.es import Spanish
import spacy
nlp = Spanish()
tokenizer = nlp.tokenizer

def simplify_entities(entities):
    """
    Simplifies the identified entities by combining consecutive elements of a phrase.

    Parameters:
        entities (list): A list of dictionaries representing the identified entities.
                         Each dictionary should contain two keys: 'entidad' (entity type) and 'texto' (entity text).
    Returns:
        list: A list of dictionaries representing the simplified entities.
              Each dictionary contains two keys: 'entidad' (entity type) and 'texto' (simplified entity text).
    Example:
        entities = [{'entidad': 'B-PER', 'texto': 'Juan'},
                    {'entidad': 'I-PER', 'texto': 'Manuel'},
                    {'entidad': 'I-PER', 'texto': 'Pérez'},
                    {'entidad': 'B-ORG', 'texto': 'Universidad'},
                    {'entidad': 'I-ORG', 'texto': 'de'},
                    {'entidad': 'I-ORG', 'texto': 'San'},
                    {'entidad': 'I-ORG', 'texto': 'Andrés'}]

        simplified_entities = simplify_entities(entities)
        print(simplified_entities)
        Output: [{'entidad': 'PER', 'texto': 'Juan Manuel Pérez'},
                 {'entidad': 'ORG', 'texto': 'Universidad de San Andrés'}]
    """
    simplified_entities = []
    current_entity = None
    current_text = ""

    for entity in entities:
        if entity["entidad"].startswith("B-"):
            if current_entity is not None:
                simplified_entities.append({"entidad": current_entity[2:], "texto": current_text})
            current_entity = entity["entidad"]
            current_text = str(entity["texto"])
        elif entity["entidad"].startswith("I-"):
            if current_entity is not None:
                current_text += " " + str(entity["texto"])

    if current_entity is not None:
        simplified_entities.append({"entidad": current_entity[2:], "texto": current_text})

    return simplified_entities

def identificar_entidades(model, tokenizer, text):
    """
    Identifies named entities in the given text using a trained model.

    Parameters:
        model (torch.nn.Module): The trained model for named entity recognition.
        tokenizer: The tokenizer object used to tokenize the input text.
        text (str): The input text from which entities are to be identified.

    Returns:
        list: A list of dictionaries representing the identified entities.
              Each dictionary contains two keys: 'entidad' (entity type) and 'texto' (entity text).
    """
    # Tokenize text and obtain tokens and token IDs
    tokens = tokenizer(text)
    token_ids = [stoi.get(token.text, stoi["<unk>"]) for token in tokens]
    input_ids = torch.tensor(token_ids).unsqueeze(0).to(next(model.parameters()).device)
    output = model(input_ids)
    predicted_labels = output.argmax(dim=-1)[0]
    labels = [id2label[label_id.item()] for label_id in predicted_labels]
    entities = [{"entidad": label, "texto": token} for token, label in zip(tokens,labels) if label != 'O']
    return simplify_entities(entities)

In [57]:
text = 'Las denominaciones adoptadas sucesivamente desde 1810 hasta el presente, a saber: Provincias Unidas del Río de la \
Plata, República Argentina, Confederación Argentina, serán en adelante nombres oficiales indistintamente para la \
designación del Gobierno y territorio de las provincias, empleándose las palabras Nación Argentina en la formación \
y sanción de las leyes.'
identificar_entidades(model, tokenizer, text)

[{'entidad': 'LOC', 'texto': 'Las'},
 {'entidad': 'LOC', 'texto': 'Provincias Unidas del Río de la Plata'},
 {'entidad': 'LOC', 'texto': 'República Argentina'},
 {'entidad': 'ORG', 'texto': 'Confederación Argentina'},
 {'entidad': 'ORG', 'texto': 'Gobierno'},
 {'entidad': 'LOC', 'texto': 'Nación Argentina'}]

In [45]:
identificar_entidades(model, tokenizer, "Juan Manuel Pérez es el profesor de NLP de la Universidad de San Andrés")

[{'entidad': 'PER', 'texto': 'Juan Manuel Pérez'},
 {'entidad': 'LOC', 'texto': 'Universidad de San Andrés'}]

In [60]:
text = "El sitio Aires de los Lagos es un lugar mágico, fui con Juan Pérez"
identificar_entidades(model, tokenizer, text)

[{'entidad': 'LOC', 'texto': 'El sitio Aires de'},
 {'entidad': 'LOC', 'texto': 'Lagos'},
 {'entidad': 'PER', 'texto': 'Juan Pérez'}]

In [59]:
text = "Laura Romano es una nutricionista muy famosa que vive en Buenos Aires."
identificar_entidades(model, tokenizer, text)

[{'entidad': 'PER', 'texto': 'Laura Romano'},
 {'entidad': 'LOC', 'texto': 'Buenos Aires'}]