<a href="https://colab.research.google.com/github/mmubeen-6/Entity-Relation-Model/blob/master/Entity_Relation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This is a simple repo to create an entity relation model and train it to predict the relation between two entities in a sentence. The model is trained on the [kbp37_formatted](https://huggingface.co/datasets/DFKI-SLT/kbp37) dataset.

An entity relation model is a model that takes two entities and a sentence as input and predicts the relation between the two entities in the sentence.

For example, given the sentence "The company Apple was founded by Steve Jobs", the model should predict that the relation between "Apple" and "Steve Jobs" is "founders".

# Setup

Repo setup:

In [1]:
# Remove any previous clone (-f: no error if it does not exist), then clone the repo
!rm -rf /content/Entity-Relation-Model
!git clone https://github.com/mmubeen-6/Entity-Relation-Model.git

Cloning into 'Entity-Relation-Model'...
remote: Enumerating objects: 15, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (14/14), done.
remote: Total 15 (delta 4), reused 0 (delta 0), pack-reused 0
Receiving objects: 100% (15/15), 11.58 KiB | 11.58 MiB/s, done.
Resolving deltas: 100% (4/4), done.


In [2]:
import sys
# Make the repo's modules (dataset_loader, models) importable
sys.path.insert(0, "/content/Entity-Relation-Model")

Installing the project's dependencies

In [3]:
!pip3 install -q -r /content/Entity-Relation-Model/requirements.txt

#### Load relevant packages for the project

In [23]:
import os
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import (
    classification_report,
    precision_recall_fscore_support,
)
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from dataset_loader import EntityRelationDataset, get_dataset, get_tokenizer
from models import EntityRelationModel, get_base_model

Setting up a config that holds all hyperparameters and settings

In [5]:
# Hyperparameters, paths and other settings used throughout the notebook
configs = {
    # Batch sizes for training and evaluation (tunable)
    "train_batch_size": 128,
    "eval_batch_size": 32,
    # Pretrained base model the relation classifier is built on
    # (alternatives are discussed in the next section)
    "base_model": "bert-base-uncased",
    # Run on the GPU if one is available, otherwise on the CPU
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Number of passes over the training set
    "num_epochs": 10,
    # Initial learning rate for AdamW (tunable)
    "learning_rate": 1e-4,
    # Directory where the model checkpoint and tokenizer are saved
    "model_path": "./model_ckpt/",
    # Dataset to load; only "kbp37_formatted" is supported for now
    "dataset_name": "kbp37_formatted",
    # Number of relation classes. This is dataset specific (KBP37 has 37),
    # so change it if you switch to a different dataset
    "num_classes": 37,
    # Maximum number of input tokens (tunable). The longest tokenized inputs
    # in the dataset appear to be around 180 tokens, so 200 leaves some
    # headroom. Longer inputs are truncated and shorter ones are padded.
    "seq_max_len": 200,
    # Log the training loss every `log_interval` batches
    "log_interval": 50,
}
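The value of `seq_max_len` above is picked by eye. A minimal sketch of making that choice from data instead, assuming you already have a list of per-example token counts (for instance collected with `len(tokenizer(text)["input_ids"])` over the training sentences). The helpers `choose_max_len` and `fraction_truncated` are hypothetical and not part of the repo, and the lengths in the example are illustrative, not measured on KBP37:

```python
import math


def fraction_truncated(token_lengths, max_len):
    """Share of examples that would lose tokens at a given max_len."""
    return sum(1 for n in token_lengths if n > max_len) / len(token_lengths)


def choose_max_len(token_lengths, headroom=1.1, multiple=8, model_limit=512):
    """Pick a truncation/padding length covering the longest observed input.

    headroom: safety factor applied to the longest sequence.
    multiple: round the result up to a multiple of this value.
    model_limit: hard cap; BERT-style models accept at most 512 positions.
    """
    target = math.ceil(max(token_lengths) * headroom)
    # Round up to the next multiple of `multiple`
    rounded = math.ceil(target / multiple) * multiple
    return min(rounded, model_limit)


lengths = [37, 95, 182, 120]  # illustrative token counts
max_len = choose_max_len(lengths, headroom=1.0)
print(max_len, fraction_truncated(lengths, max_len))
```

Checking `fraction_truncated` for a few candidate lengths makes the trade-off explicit: a shorter `seq_max_len` speeds up training but cuts off part of some sentences, possibly including an entity marker.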

#### Select a base Model for the project
Here we select a base model for the project: a pretrained model on which our relation classifier is built. We use the [bert-base-uncased](https://huggingface.co/bert-base-uncased) model.

Another option is the [bert-large-uncased](https://huggingface.co/bert-large-uncased) model, a larger version of bert-base-uncased (about 340M parameters versus 110M) that is slower to train and needs more GPU memory.

In [6]:
model_name = configs["base_model"]
num_classes = configs["num_classes"]

To train our model, we add a classification head on top of the base model. The head takes the base model's output representation and predicts which of the `num_classes` relation types holds between the two entities in the sentence.
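As a rough picture of what a classification head computes, here is a NumPy sketch of a linear head over one pooled vector per sentence. It assumes the head is a single linear layer reading a single sentence-level vector (for example BERT's [CLS] representation) followed by a softmax; the repo's `EntityRelationModel` may pool differently (for instance from the entity-marker positions), so the names and parameter values here are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)

hidden_size = 768  # hidden size of bert-base-uncased
n_classes = 37     # KBP37 relation classes

# Hypothetical head parameters: one weight row and one bias per relation class
W = rng.normal(scale=0.02, size=(n_classes, hidden_size))
b = np.zeros(n_classes)


def classify(pooled):
    """Map pooled sentence vectors (batch, hidden) to class probabilities."""
    logits = pooled @ W.T + b  # (batch, n_classes)
    # Numerically stable softmax over the class dimension
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)


pooled = rng.normal(size=(4, hidden_size))  # stand-in for base-model output
probs = classify(pooled)
predicted = probs.argmax(axis=1)            # one relation id per sentence
print(probs.shape, predicted)
```

During training the softmax is usually left out of the model: `nn.CrossEntropyLoss`, used later in this notebook, applies log-softmax itself and expects raw logits.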

For the classification head to know which two entities it should relate, we mark them in the sentence with special tokens: [E1] and [/E1] around the first entity, and [E2] and [/E2] around the second.

For example, given the sentence "The company Apple was founded by Steve Jobs" and the two entities "Apple" and "Steve Jobs", the marked sentence is: "The company [E1] Apple [/E1] was founded by [E2] Steve Jobs [/E2]".
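The marking step can be sketched in plain Python as below. `mark_entities` is a hypothetical helper that takes character offsets for the two entities; the repo's `EntityRelationDataset` may insert the markers differently, or the dataset may already provide marked text:

```python
def mark_entities(sentence, e1_span, e2_span):
    """Wrap two non-overlapping character spans in entity-marker tokens.

    e1_span, e2_span: (start, end) character offsets, end exclusive.
    Works whichever entity appears first in the sentence.
    """
    spans = sorted(
        [(e1_span, "[E1]", "[/E1]"), (e2_span, "[E2]", "[/E2]")],
        key=lambda item: item[0][0],  # order by start offset
    )
    pieces, cursor = [], 0
    for (start, end), open_tok, close_tok in spans:
        pieces.append(sentence[cursor:start])  # text before this entity
        pieces.append(f"{open_tok} {sentence[start:end]} {close_tok}")
        cursor = end
    pieces.append(sentence[cursor:])  # text after the last entity
    return "".join(pieces)


sentence = "The company Apple was founded by Steve Jobs"
e1 = (sentence.index("Apple"), sentence.index("Apple") + len("Apple"))
e2 = (sentence.index("Steve Jobs"), sentence.index("Steve Jobs") + len("Steve Jobs"))
print(mark_entities(sentence, e1, e2))
# The company [E1] Apple [/E1] was founded by [E2] Steve Jobs [/E2]
```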

In [7]:
e1_start_token = '[E1]'
e1_end_token = '[/E1]'
e2_start_token = '[E2]'
e2_end_token = '[/E2]'
new_tokens = [e1_start_token, e1_end_token, e2_start_token, e2_end_token]

Next we load the tokenizer for the base model. The tokenizer converts the marked sentence into token ids. We also register the four marker strings as new tokens; otherwise the tokenizer would split each marker into several ordinary pieces instead of treating it as a single token.
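A toy illustration of why the markers must be registered. This is a simplification, not BERT's actual tokenizer: the hypothetical `basic_tokenize` only lowercases and splits off punctuation (roughly the pre-tokenization BERT applies before WordPiece, which is omitted here), and the hypothetical `tokenize_with_added` first cuts the text on registered tokens so they survive intact, which is the idea behind `tokenizer.add_tokens`:

```python
import re


def basic_tokenize(text):
    """Lowercase, then split into words and single punctuation characters."""
    return re.findall(r"\w+|[^\w\s]", text.lower())


def tokenize_with_added(text, added_tokens):
    """Keep registered tokens whole; pre-tokenize everything else."""
    if not added_tokens:
        return basic_tokenize(text)
    # Longest first, so a longer added token is preferred over a shorter one
    alternatives = sorted(added_tokens, key=len, reverse=True)
    pattern = "(" + "|".join(re.escape(t) for t in alternatives) + ")"
    pieces = []
    # The capturing group makes re.split return the matched tokens as well
    for chunk in re.split(pattern, text):
        if chunk in added_tokens:
            pieces.append(chunk)  # kept as a single token
        else:
            pieces.extend(basic_tokenize(chunk))
    return pieces


markers = ["[E1]", "[/E1]", "[E2]", "[/E2]"]
text = "The company [E1] Apple [/E1] was founded by [E2] Steve Jobs [/E2]"
print(tokenize_with_added(text, []))       # markers broken into pieces
print(tokenize_with_added(text, markers))  # markers kept intact
```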

In [8]:
tokenizer = get_tokenizer(model_name)
num_added_toks = tokenizer.add_tokens(new_tokens)
print(f"Added {num_added_toks} new tokens")

Added 4 new tokens


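Next we load the base model for the selected checkpoint and resize its token embedding matrix to match the enlarged tokenizer vocabulary, so that the four new marker tokens get their own embedding vectors. We then wrap the base model with the classification head and move it to the selected device.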
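Conceptually, growing the vocabulary appends one new row per added token to the embedding matrix while leaving the existing rows untouched. A NumPy sketch of that idea, with toy sizes (bert-base-uncased itself has 30522 tokens with 768-dimensional embeddings). The random-normal initialisation of the new rows is an assumption for illustration; how transformers' `resize_token_embeddings` initialises them depends on the model and library version:

```python
import numpy as np


def resize_embeddings(embeddings, new_vocab_size, std=0.02, seed=0):
    """Return an embedding matrix with rows added for new token ids.

    Existing rows are copied unchanged; each new row is drawn from N(0, std^2).
    This sketch only grows the vocabulary.
    """
    old_vocab_size, hidden = embeddings.shape
    if new_vocab_size < old_vocab_size:
        raise ValueError("this sketch only grows the vocabulary")
    rng = np.random.default_rng(seed)
    new_rows = rng.normal(0.0, std, size=(new_vocab_size - old_vocab_size, hidden))
    return np.vstack([embeddings, new_rows])


old = np.ones((10, 8))                   # toy vocabulary of 10 tokens, 8 dims
resized = resize_embeddings(old, 10 + 4)  # 4 added marker tokens
print(resized.shape)
```

The new rows start out uninformative, which is why the marker embeddings have to be learned during fine-tuning.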

In [9]:
model = get_base_model(model_name)
model.resize_token_embeddings(len(tokenizer))
model = EntityRelationModel(base_model=model, num_classes=num_classes)
model = model.to(configs["device"])

Now we load the dataset: sentences annotated with two entities and the relation between them. The dataset loaded from Hugging Face contains train, validation and test splits, which we separate out.

In [10]:
# Load the KBP37 dataset
dataset = get_dataset(configs["dataset_name"])
train_data_raw = dataset['train']
test_data_raw = dataset['test']
val_data_raw = dataset['validation']

Generating train split: 15807 examples
Generating validation split: 1714 examples
Generating test split: 3379 examples

In [11]:
# Now we will create the pytorch dataset objects
train_data = EntityRelationDataset(train_data_raw, tokenizer, max_length=configs["seq_max_len"])
val_data = EntityRelationDataset(val_data_raw, tokenizer, max_length=configs["seq_max_len"])
test_data = EntityRelationDataset(test_data_raw, tokenizer, max_length=configs["seq_max_len"])

In [12]:
# Now we will create the pytorch dataloader objects
train_data_loader = DataLoader(
    train_data,
    batch_size=configs["train_batch_size"],
    shuffle=True
)
val_data_loader = DataLoader(
    val_data,
    batch_size=configs["eval_batch_size"],
    shuffle=False
)

test_data_loader = DataLoader(
    test_data,
    batch_size=configs["eval_batch_size"],
    shuffle=False
)
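The "Batch 50/124" counter in the training log follows from the split sizes (15,807 training, 1,714 validation and 3,379 test examples) and the batch sizes in `configs`. With the DataLoader's default `drop_last=False`, the final partial batch is kept, so the number of batches is the ceiling of the split size divided by the batch size:

```python
import math

split_sizes = {"train": 15807, "validation": 1714, "test": 3379}
batch_sizes = {"train": 128, "validation": 32, "test": 32}

batch_info = {}
for split, n in split_sizes.items():
    bs = batch_sizes[split]
    num_batches = math.ceil(n / bs)          # what len(DataLoader) reports
    last_batch = n - (num_batches - 1) * bs  # size of the final, partial batch
    batch_info[split] = (num_batches, last_batch)
    print(f"{split}: {num_batches} batches, last batch has {last_batch} examples")
```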

In [13]:
# Let's define the optimizer, loss function and learning rate scheduler

# Define the optimizer
optimizer = AdamW(model.parameters(), lr=configs["learning_rate"])

# Define the loss function
loss_fn = nn.CrossEntropyLoss()

# Define the learning rate scheduler: multiply the learning rate by 0.1
# halfway through training. train_model calls scheduler.step() once per
# epoch, so step_size is counted in epochs (an int; at least 1).
scheduler = StepLR(optimizer, step_size=max(1, configs["num_epochs"] // 2), gamma=0.1)
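With `num_epochs = 10`, the scheduler above uses `step_size = 5` and `gamma = 0.1`, so the initial learning rate is kept for the first five epochs and divided by 10 for the last five. StepLR's closed form is `lr = base_lr * gamma ** (steps_taken // step_size)`; a quick check in plain Python, assuming one `scheduler.step()` call per epoch as in `train_model`:

```python
base_lr, gamma, step_size, num_epochs = 1e-4, 0.1, 5, 10

# Learning rate used during each epoch (1-indexed), i.e. after
# (epoch - 1) calls to scheduler.step()
lr_per_epoch = {
    epoch: base_lr * gamma ** ((epoch - 1) // step_size)
    for epoch in range(1, num_epochs + 1)
}
for epoch, lr in lr_per_epoch.items():
    print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```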

In [14]:
def train_model(
    model: nn.Module,
    data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    scheduler,
    device: torch.device,
    epoch: int,
    log_interval: int = 10,
):
    """Train the model for one epoch.

    Args:
        model (nn.Module): Model to be trained.
        data_loader (DataLoader): DataLoader object.
        optimizer (torch.optim.Optimizer): Optimizer object.
        loss_fn (nn.Module): Loss function.
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler.
        device (torch.device): Device to be used.
        epoch (int): Epoch number.
        log_interval (int, optional): Number of batches after which the loss
                is logged. Defaults to 10.

    Returns:
        float: Average training loss over all batches in the epoch.
    """
    model.train()  # Set the model to training mode
    total_loss = 0

    for batch_idx, batch in enumerate(data_loader):
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            print(
                f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(data_loader)}, "
                f"Loss: {loss.item():.3f}"
            )

    avg_train_loss = total_loss / len(data_loader)

    scheduler.step()  # Update the learning rate after each epoch

    print(
        f"End of Epoch {epoch+1}, Average Training Loss: {avg_train_loss:.3f}"
    )

    return avg_train_loss


In [15]:
def validate_model(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    loss_fn: Optional[nn.Module] = None,
    phase: str = "Validation",
    verbose: bool=False
):
    """Validate the model.

    Args:
        model (nn.Module): Model to be validated.
        data_loader (DataLoader): DataLoader object.
        loss_fn (nn.Module): Loss function.
        device (torch.device): Device to be used.
        verbose (bool, optional): Whether to print the classification report.
                Defaults to False.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    all_predictions = []
    all_labels = []
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask)

            if loss_fn is not None:
                loss = loss_fn(outputs, labels)
                total_loss += loss.item()

            _, predicted_labels = torch.max(outputs, dim=1)
            all_predictions.extend(predicted_labels.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            correct_predictions += (predicted_labels == labels).sum().item()
            total_predictions += labels.size(0)

    avg_loss = total_loss / len(data_loader) if loss_fn is not None else None

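    # zero_division=0 scores undefined precision/recall (e.g. a class that is
    # never predicted) as 0 instead of emitting an UndefinedMetricWarning
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average="weighted", zero_division=0
    )
    accuracy = correct_predictions / total_predictions

    if verbose:
        print(classification_report(all_labels, all_predictions, zero_division=0))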

    # Conditionally format the output string based on whether loss is computed
    if avg_loss is not None:
        print(
            f"{phase} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}"
        )
    else:
        print(
            f"{phase} - Acc: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}"
        )

    return avg_loss, precision, recall, f1, accuracy
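`validate_model` reports support-weighted averages (`average="weighted"`): each class's score is weighted by its number of true examples, so frequent classes dominate, whereas macro averaging counts every class equally. A small pure-Python sketch of both averages, written for illustration rather than taken from scikit-learn; like `zero_division=0`, it scores an undefined precision or recall as 0:

```python
def per_class_f1(y_true, y_pred):
    """Return {label: (f1, support)}; undefined precision/recall count as 0."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = {}
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[c] = (f1, tp + fn)  # support = number of true examples of c
    return scores


def weighted_and_macro_f1(y_true, y_pred):
    scores = per_class_f1(y_true, y_pred)
    total_support = sum(support for _, support in scores.values())
    weighted = sum(f1 * support for f1, support in scores.values()) / total_support
    macro = sum(f1 for f1, _ in scores.values()) / len(scores)
    return weighted, macro


# Class 0 is frequent and well predicted; classes 1 and 2 are rare and missed
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 0, 0, 1]
weighted, macro = weighted_and_macro_f1(y_true, y_pred)
print(f"weighted F1 = {weighted:.3f}, macro F1 = {macro:.3f}")
```

Here the weighted F1 is twice the macro F1 because the one well-predicted class carries most of the support. Passing `average="macro"` to `precision_recall_fscore_support` gives the macro variant directly.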

Start training the model

In [17]:
# Let's train our model now
for epoch in range(configs["num_epochs"]):
    avg_train_loss = train_model(
        model=model,
        data_loader=train_data_loader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        scheduler=scheduler,
        device=configs["device"],
        epoch=epoch,
        log_interval=configs["log_interval"],
    )
    val_loss, val_precision, val_recall, val_f1, val_acc = validate_model(
        model=model,
        data_loader=val_data_loader,
        loss_fn=loss_fn,
        device=configs["device"],
        phase="Validation",
        verbose=True
    )

    print(f"End of Epoch {epoch+1}/{configs['num_epochs']}")
    print(f"Average Training Loss: {avg_train_loss:.3f}")
    print(
        f"Validation Loss: {val_loss:.2f}, Validation Acc: {(val_acc * 100):.2f}%, "
        f"Precision: {val_precision:.2f}, Recall: {val_recall:.2f}, F1-Score: {val_f1:.2f}"
    )




Epoch 1, Batch 50/124, Loss: 2.813
Epoch 1, Batch 100/124, Loss: 2.233
End of Epoch 1, Average Training Loss: 2.740



              precision    recall  f1-score   support

           0       0.35      0.16      0.22       208
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        27
           3       0.41      0.94      0.57       117
           4       0.00      0.00      0.00        38
           5       0.56      0.86      0.67        69
           6       0.78      0.38      0.51        48
           7       0.00      0.00      0.00        34
           8       0.00      0.00      0.00        19
           9       0.00      0.00      0.00        20
          10       0.00      0.00      0.00        14
          11       0.00      0.00      0.00        28
          12       0.31      0.63      0.41        54
          13       0.36      0.08      0.13        49
          14       0.00      0.00      0.00        16
          15       0.21      0.46      0.29        48
          16       0.00      0.00      0.00        54
          17       0.00    


              precision    recall  f1-score   support

           0       0.45      0.21      0.29       208
           1       0.52      0.44      0.48        36
           2       0.00      0.00      0.00        27
           3       0.74      0.86      0.80       117
           4       0.67      0.74      0.70        38
           5       0.76      0.91      0.83        69
           6       0.75      0.94      0.83        48
           7       0.83      0.74      0.78        34
           8       0.92      0.63      0.75        19
           9       0.52      0.65      0.58        20
          10       0.00      0.00      0.00        14
          11       0.00      0.00      0.00        28
          12       0.62      0.39      0.48        54
          13       0.60      0.73      0.66        49
          14       0.43      0.81      0.57        16
          15       0.40      0.38      0.39        48
          16       0.26      0.46      0.33        54
          17       0.62    


              precision    recall  f1-score   support

           0       0.49      0.18      0.27       208
           1       0.69      0.25      0.37        36
           2       0.00      0.00      0.00        27
           3       0.67      0.90      0.77       117
           4       0.51      0.89      0.65        38
           5       0.75      0.96      0.84        69
           6       0.78      0.88      0.82        48
           7       0.70      0.91      0.79        34
           8       0.71      0.89      0.79        19
           9       0.68      0.65      0.67        20
          10       0.43      0.43      0.43        14
          11       1.00      0.07      0.13        28
          12       0.61      0.57      0.59        54
          13       0.79      0.63      0.70        49
          14       0.48      0.69      0.56        16
          15       0.31      0.60      0.41        48
          16       0.32      0.39      0.35        54
          17       0.59    


              precision    recall  f1-score   support

           0       0.44      0.18      0.26       208
           1       0.46      0.53      0.49        36
           2       0.00      0.00      0.00        27
           3       0.70      0.88      0.78       117
           4       0.67      0.76      0.72        38
           5       0.77      0.91      0.83        69
           6       0.77      0.83      0.80        48
           7       0.70      0.88      0.78        34
           8       0.75      0.79      0.77        19
           9       0.61      0.70      0.65        20
          10       0.33      0.43      0.38        14
          11       0.56      0.36      0.43        28
          12       0.57      0.61      0.59        54
          13       0.78      0.71      0.74        49
          14       0.77      0.62      0.69        16
          15       0.38      0.40      0.39        48
          16       0.34      0.57      0.43        54
          17       0.58    


              precision    recall  f1-score   support

           0       0.45      0.14      0.22       208
           1       0.60      0.33      0.43        36
           2       0.38      0.37      0.38        27
           3       0.64      0.90      0.75       117
           4       0.52      0.84      0.64        38
           5       0.79      0.78      0.79        69
           6       0.76      0.85      0.80        48
           7       0.74      0.91      0.82        34
           8       0.80      0.84      0.82        19
           9       0.54      0.70      0.61        20
          10       0.50      0.43      0.46        14
          11       0.55      0.39      0.46        28
          12       0.61      0.52      0.56        54
          13       0.66      0.55      0.60        49
          14       0.59      0.62      0.61        16
          15       0.42      0.56      0.48        48
          16       0.36      0.50      0.42        54
          17       0.61    


              precision    recall  f1-score   support

           0       0.39      0.26      0.31       208
           1       0.50      0.36      0.42        36
           2       0.35      0.41      0.38        27
           3       0.78      0.86      0.82       117
           4       0.63      0.68      0.66        38
           5       0.79      0.91      0.85        69
           6       0.79      0.88      0.83        48
           7       0.78      0.91      0.84        34
           8       0.78      0.74      0.76        19
           9       0.64      0.70      0.67        20
          10       0.43      0.43      0.43        14
          11       0.43      0.46      0.45        28
          12       0.62      0.57      0.60        54
          13       0.79      0.67      0.73        49
          14       0.69      0.69      0.69        16
          15       0.48      0.52      0.50        48
          16       0.50      0.50      0.50        54
          17       0.68    


              precision    recall  f1-score   support

           0       0.40      0.25      0.31       208
           1       0.53      0.44      0.48        36
           2       0.41      0.33      0.37        27
           3       0.77      0.88      0.82       117
           4       0.60      0.68      0.64        38
           5       0.79      0.90      0.84        69
           6       0.79      0.88      0.83        48
           7       0.78      0.91      0.84        34
           8       0.78      0.74      0.76        19
           9       0.67      0.70      0.68        20
          10       0.43      0.43      0.43        14
          11       0.46      0.43      0.44        28
          12       0.65      0.59      0.62        54
          13       0.77      0.67      0.72        49
          14       0.73      0.69      0.71        16
          15       0.46      0.58      0.51        48
          16       0.53      0.52      0.52        54
          17       0.65    


              precision    recall  f1-score   support

           0       0.38      0.25      0.30       208
           1       0.48      0.44      0.46        36
           2       0.35      0.26      0.30        27
           3       0.75      0.86      0.80       117
           4       0.60      0.68      0.64        38
           5       0.79      0.90      0.84        69
           6       0.79      0.85      0.82        48
           7       0.78      0.91      0.84        34
           8       0.76      0.68      0.72        19
           9       0.61      0.70      0.65        20
          10       0.40      0.43      0.41        14
          11       0.41      0.43      0.42        28
          12       0.62      0.56      0.59        54
          13       0.78      0.63      0.70        49
          14       0.65      0.69      0.67        16
          15       0.49      0.56      0.52        48
          16       0.54      0.52      0.53        54
          17       0.68    


              precision    recall  f1-score   support

           0       0.38      0.25      0.30       208
           1       0.48      0.44      0.46        36
           2       0.43      0.37      0.40        27
           3       0.76      0.88      0.82       117
           4       0.61      0.71      0.66        38
           5       0.79      0.90      0.84        69
           6       0.80      0.83      0.82        48
           7       0.78      0.91      0.84        34
           8       0.81      0.68      0.74        19
           9       0.59      0.65      0.62        20
          10       0.43      0.43      0.43        14
          11       0.45      0.46      0.46        28
          12       0.62      0.56      0.59        54
          13       0.79      0.67      0.73        49
          14       0.79      0.69      0.73        16
          15       0.48      0.60      0.53        48
          16       0.52      0.52      0.52        54
          17       0.64    

Evaluating the model now

In [18]:
# Now we evaluate the model on the held-out test set
test_loss, test_precision, test_recall, test_f1, test_acc = validate_model(
    model=model,
    data_loader=test_data_loader,
    loss_fn=loss_fn,  # pass the loss function so a test loss is computed
    device=configs["device"],
    phase="Test",
    verbose=True
)

print(
    f"Test Loss: {test_loss:.3f}, Test Acc: {(test_acc * 100):.3f}%, Precision: {test_precision:.3f}, "
    f"Recall: {test_recall:.3f}, F1-Score: {test_f1:.3f}"
)




              precision    recall  f1-score   support

           0       0.48      0.31      0.38       412
           1       0.46      0.40      0.43        68
           2       0.37      0.39      0.38        57
           3       0.80      0.89      0.84       227
           4       0.63      0.74      0.68        74
           5       0.79      0.86      0.83       133
           6       0.81      0.85      0.83        93
           7       0.88      0.94      0.91        68
           8       0.90      0.90      0.90        39
           9       0.62      0.55      0.58        42
          10       0.56      0.58      0.57        38
          11       0.39      0.33      0.36        55
          12       0.62      0.55      0.58       105
          13       0.73      0.78      0.76        93
          14       0.62      0.64      0.63        33
          15       0.48      0.48      0.48        87
          16       0.51      0.61      0.56       102
          17       0.76    


In [20]:
# Let's save the model now
def save_model_and_tokenizer(
    model: nn.Module,
    tokenizer,
    directory: str,
    model_filename: str = "model.pt",
):
    """
    Save a custom PyTorch model and a Hugging Face tokenizer to disk.

    The model's state dict is written to <directory>/model/<model_filename>
    and the tokenizer files to <directory>/tokenizer/.

    Args:
        model (torch.nn.Module): The custom PyTorch model to be saved.
        tokenizer (PreTrainedTokenizer): The Hugging Face tokenizer to be saved.
        directory (str): The directory where the model and tokenizer should be saved.
        model_filename (str, optional): Filename for the saved model. Defaults to "model.pt".
    """
    # os.path.join works whether or not `directory` ends with a slash
    model_dir = os.path.join(directory, "model")
    tokenizer_dir = os.path.join(directory, "tokenizer")
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(tokenizer_dir, exist_ok=True)

    # Save only the weights (state dict); reloading requires rebuilding the
    # same architecture (base model + resized embeddings + head) first
    model_path = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_path)

    # Save the tokenizer, including the added entity-marker tokens
    tokenizer.save_pretrained(tokenizer_dir)

    print(f"Model saved to {model_path}")
    print(f"Tokenizer saved to {tokenizer_dir}")


save_model_and_tokenizer(model, tokenizer, configs["model_path"])
