In [None]:
import numbers
import os

from dotenv import load_dotenv
from transformers import AutoTokenizer

from bioext.doccano_utils import DoccanoSession
from bioext.hfpipeline import GlobalConfig, DataSource, TaskType, DataHandler, HFSequenceClassificationTrainer

In [None]:
# This notebook exercises the functionality of hfpipeline.py.
# Before running it, start your local Doccano instance, create a project, and load data into it;
# then set doc_project_id, task and num_labels in the config cell below to match that project.
# Sample pre-labelled data is provided in ./imports for binary classification, multiclass (3-label) classification, and multilabel (4-label) classification.
# A pre-labelled NER dataset is also provided, but NER is not yet implemented in hfpipeline.

In [None]:
# Populate os.environ from a local .env file (variables already set in the
# environment are left untouched), then open a Doccano session.
# NOTE(review): presumably DoccanoSession reads its connection settings from
# those environment variables — confirm in bioext.doccano_utils.
load_dotenv(override=False)
docsesh = DoccanoSession()

In [None]:
# Show the projects returned by the Doccano client so the right
# doc_project_id can be picked for the config cell below.
projects = docsesh.client.list_projects()

for proj in projects:
    summary = ", ".join([
        f"Project ID: {proj.id}",
        f"Name: {proj.name}",
        f"Type: {proj.project_type}",
    ])
    print(summary)

In [None]:
# Doccano project to train on — pick an ID from the listing above.
DOC_PROJECT_ID = 1

# Pipeline configuration. task and num_labels must match the chosen project;
# the multilabel sample data in ./imports has 4 labels.
config = GlobalConfig(
    doc_project_id=DOC_PROJECT_ID,
    source=DataSource.DOCCANO,
    task=TaskType.MULTILABEL,
    num_labels=4,
    model_name="distilbert-base-uncased",
    max_length=256,
    batch_size=16,
    learning_rate=3e-5,
    num_train_epochs=3,
    output_dir="./model_output"
)

# data handler to load and preprocess data into train/test splits
data_handler = DataHandler(config=config)

n_train = len(data_handler.train_dataset)
n_test = len(data_handler.test_dataset)
print(f"Training samples: {n_train}")
print(f"Testing samples: {n_test}")

# Fail fast with a readable message instead of an IndexError on the sample
# lookup below (or a failure deep inside evaluation later) when the Doccano
# project yielded no data for a split.
assert n_train > 0, f"No training samples loaded from Doccano project {DOC_PROJECT_ID}"
assert n_test > 0, f"No test samples loaded from Doccano project {DOC_PROJECT_ID}"

# Peek at one preprocessed training example.
sample = data_handler.train_dataset[0]
print(sample)

In [None]:
# Build the sequence-classification trainer from the shared config, reusing the
# tokenizer the data handler preprocessed with (presumably so inference and
# training tokenize identically — confirm in hfpipeline).
trainer = HFSequenceClassificationTrainer(config=config, tokenizer=data_handler.tokenizer)

In [None]:
# Hand the data handler's splits to the trainer; the test split is used as the eval set.
trainer.setup_trainer(train_dataset=data_handler.train_dataset, eval_dataset=data_handler.test_dataset)

In [None]:
def _print_metrics(title, metrics):
    """Print ``title`` followed by one ``key: value`` line per entry in ``metrics``.

    Real-number values (int, float, bool and numpy scalars, which numpy
    registers with ``numbers.Real``) are shown with 4 significant digits, so
    small losses are not rounded to ``0.00`` and large counters stay readable.
    Any other value is printed as-is rather than raising on a numeric format spec.
    """
    print(title)
    for key, value in metrics.items():
        if isinstance(value, numbers.Real):
            print(f"{key}: {value:.4g}")
        else:
            print(f"{key}: {value}")


training_metrics = trainer.train()

# NOTE(review): this assumes trainer.train() persists the model to
# config.output_dir — confirm in hfpipeline before relying on this message.
print(f"Model saved to: {os.path.abspath(config.output_dir)}")
_print_metrics("Training metrics:", training_metrics)

# Evaluate via the underlying trainer object exposed by the wrapper.
eval_results = trainer.trainer.evaluate()
_print_metrics("Evaluation metrics:", eval_results)