In [1]:
import sys
from pathlib import Path
# Make the repository root importable so that the `src` package resolves.
# The root is taken as the parent of the current working directory, so this
# assumes the kernel was started from a directory one level below the root.
project_root = Path().resolve().parent
# Only register the path once: re-running this cell in a live kernel would
# otherwise keep appending duplicate entries to sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src import preprocessing
from src import constants
from src import training
from src import evaluation

## Load and validate the training data

In [2]:
# Read the texts and their labels from the configured data file.
# sample_frac=0.1 presumably limits the run to a 10% sample of the rows --
# TODO confirm against preprocessing.load_data.
texts, labels = preprocessing.load_data(constants.DATA_FILE_PATH, sample_frac=0.1)

2024-07-15 20:29:03,307 - JaneAustenLogger - INFO - Data validation successful.


## Preprocess the data

Tokenise the texts and split them into training, validation and test sets.

In [3]:
# Tokenise the texts and build one data loader per split (train / validation / test).
# The tokenizer name, split sizes and output directory all come from src.constants.
(
    tokenizer,
    train_loader,
    val_loader,
    test_loader,
) = preprocessing.preprocess_data(
    texts, labels,
    tokenizer=constants.TOKENIZER,
    train_size=constants.TRAIN_SIZE,
    val_size=constants.VAL_SIZE,
    directory=constants.MODEL_FILE_PATH,
)

2024-07-15 20:29:18,775 - JaneAustenLogger - INFO - Tokenizer saved to ../artefacts/austen_classifier_model_v2
2024-07-15 20:29:18,776 - JaneAustenLogger - INFO - Datasets and data loaders prepared successfully.


## Model Training

Fine-tune a pre-trained Hugging Face DistilBERT model for sequence classification.

In [4]:
# Train on train_loader and validate on val_loader for constants.NUM_EPOCHS epochs.
# `directory` is presumably where the trained model is written -- confirm in src.training.
model = training.train_and_validate_model(
    constants.MODEL, train_loader, val_loader,
    num_epochs=constants.NUM_EPOCHS, directory=constants.MODEL_FILE_PATH,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2024-07-15 20:29:19,591 - JaneAustenLogger - INFO - Beginning training and validation process.
2024-07-15 20:29:19,591 - JaneAustenLogger - INFO - Total Epochs: 3
2024-07-15 20:29:19,591 - JaneAustenLogger - INFO - Training rows: 3783
2024-07-15 20:29:19,592 - JaneAustenLogger - INFO - Validation rows: 3783


## Evaluation

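Evaluate the model on the test set and log the classification report (precision, recall and F1 score per class). Note: the output saved below comes from an earlier run (its timestamps precede those of the cells above) in which the model predicted class 0 for every example, so re-run this cell after training completes before relying on these scores.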

In [None]:
# Score the trained model on the held-out test loader; the results are emitted via the logger.
evaluation.evaluate_model(model, test_loader)

2024-07-15 20:26:58,391 - JaneAustenLogger - INFO - Evaluation rows: 1237
2024-07-15 20:28:43,416 - JaneAustenLogger - INFO - Test Set Evaluation
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
2024-07-15 20:28:43,438 - JaneAustenLogger - INFO - 
              precision    recall  f1-score   support

           0       0.61      1.00      0.76       759
           1       0.00      0.00      0.00       478

    accuracy                           0.61      1237
   macro avg       0.31      0.50      0.38      1237
weighted avg       0.38      0.61      0.47      1237

