In [None]:
%load_ext autoreload
%autoreload 2
from ml.ingest.textocr_to_torch import TextOCRDoctrDetDataset, doctr_detection_collate
from torch.utils.data import DataLoader
from doctr.models import db_resnet50, DBNet
from pathlib import Path
import torch
import os
import mlflow
import mlflow.pytorch
from mlflow.tracking import MlflowClient


In [None]:
# Configuration: everything a reader might tune for this smoke-training run lives here.
LR = 0.001
BATCH_SIZE = 4
MAX_STEPS = 6
NUM_SAMPLES = 32  # size of the TextOCR subset pulled for this run
SEED = 42  # seeds torch so Restart & Run All reproduces model init and batch order
MODEL_DIR = Path('../../model')

# Seeding the global torch RNG here fixes the random init of the model built in the
# training cell (when cells run in order). Whether the dataset itself draws from torch's
# RNG is not visible here -- NOTE(review): confirm it has no unseeded numpy/random use.
torch.manual_seed(SEED)

dataset = TextOCRDoctrDetDataset(num_samples=NUM_SAMPLES)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=doctr_detection_collate,
    # A dedicated seeded generator makes the shuffle order independent of how many
    # other torch random draws happened before iteration starts.
    generator=torch.Generator().manual_seed(SEED),
)

# Connect to the MLflow tracking server. MLFLOW_TRACKING_URI is the same variable the
# mlflow library/CLI reads, so the notebook and the CLI target the same server.
tracking_uri = os.getenv('MLFLOW_TRACKING_URI', 'http://localhost:65500')
mlflow.set_tracking_uri(tracking_uri)
base_experiment_name = os.getenv('MLFLOW_EXPERIMENT_NAME', 'textflow-doctr')
# Optional explicit artifact root for experiments this notebook creates; None when unset.
artifact_location_override = os.getenv('MLFLOW_EXPERIMENT_ARTIFACT_LOCATION')
# Pass the URI explicitly so the client cannot drift from the fluent API's server.
client = MlflowClient(tracking_uri=tracking_uri)

def _create_experiment(name: str):
    """Create the MLflow experiment `name` and return its freshly fetched Experiment record.

    The artifact root is MLFLOW_EXPERIMENT_ARTIFACT_LOCATION when that is set (non-empty);
    otherwise it is the ``mlflow-artifacts:/<name>`` URI, which routes artifact uploads
    through the tracking server instead of a server-local filesystem path.
    """
    if artifact_location_override:
        location = artifact_location_override
    else:
        location = f"mlflow-artifacts:/{name}"
    client.create_experiment(name, artifact_location=location)

    # Re-read by name so the printed/returned location is what the server actually stored.
    experiment_record = client.get_experiment_by_name(name)
    print(f"Created experiment '{name}' with artifact location {experiment_record.artifact_location}")
    return experiment_record

def _is_container_path(location: str) -> bool:
    if location is None:
        return False
    location = location.strip()
    remote_prefixes = ('mlflow-artifacts:', 's3://', 'gs://', 'http://', 'https://')
    if location.startswith(remote_prefixes):
        return False
    return location.startswith('/')

def _unusable_reason(exp):
    """Return a short reason why experiment `exp` cannot be used from this host, or None.

    - Deleted experiments are rejected because ``mlflow.set_experiment`` refuses to
      activate an experiment whose lifecycle stage is 'deleted'.
    - When no artifact-location override is configured, experiments whose artifact root
      is a server-local absolute path are rejected because the host cannot write there.
    """
    if exp.lifecycle_stage == 'deleted':
        return 'is deleted'
    if artifact_location_override is None and _is_container_path(exp.artifact_location):
        return f'points at {exp.artifact_location}, which the host cannot write'
    return None


# Start from the configured experiment name, creating it if missing, then fall back to
# "<base>-client1", "<base>-client2", ... until an experiment is usable from this host.
experiment_name = base_experiment_name
experiment = client.get_experiment_by_name(experiment_name)
if experiment is None:
    experiment = _create_experiment(experiment_name)

suffix = 1
reason = _unusable_reason(experiment)
while reason is not None:
    print(f"Experiment '{experiment_name}' {reason}. Switching to a new experiment name.")
    experiment_name = f"{base_experiment_name}-client{suffix}"
    suffix += 1
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment = _create_experiment(experiment_name)
    reason = _unusable_reason(experiment)

mlflow.set_experiment(experiment_name)
print(f"Using experiment '{experiment_name}' with artifact location {experiment.artifact_location}")


In [None]:
# Build a DBNet detector without pretrained detection weights (pretrained=False) and train
# it for at most MAX_STEPS batches. The run logs to MLflow:
# - hyperparameters
# - per-step loss
# - the state_dict checkpoint
# - the full model
GRAD_CLIP_NORM = 5.0  # max gradient norm passed to clip_grad_norm_; logged with the run

model: DBNet = db_resnet50(pretrained=False).train()

optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], LR)

run_name = os.getenv('MLFLOW_RUN_NAME', 'doctr-dbnet-lite')
# Unset or empty -> None, which means the model is logged but not registered.
register_name = os.getenv('MLFLOW_REGISTER_MODEL_NAME') or None

with mlflow.start_run(run_name=run_name) as run:
    mlflow.log_params({
        'learning_rate': LR,
        'batch_size': BATCH_SIZE,
        'train_samples': len(dataset),
        'max_steps': MAX_STEPS,
        'grad_clip_norm': GRAD_CLIP_NORM,
    })
    mlflow.set_tag('model_architecture', 'db_resnet50')
    mlflow.set_tag('dataset', 'textocr_subset')

    steps_completed = 0
    for step, (images, targets) in enumerate(loader, start=1):
        optimizer.zero_grad(set_to_none=True)
        train_loss = model(images, targets)['loss']
        train_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
        optimizer.step()

        loss_value = train_loss.item()
        mlflow.log_metric('train_loss', loss_value, step=step)
        print(step, loss_value)
        steps_completed = step

        if step >= MAX_STEPS:
            break

    # A single pass over `loader` may end before MAX_STEPS; record what actually ran so
    # the logged max_steps is not mistaken for the amount of training done.
    mlflow.log_param('steps_completed', steps_completed)

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    model_path = MODEL_DIR / 'dbnet_textocr.pt'
    torch.save(model.state_dict(), model_path)
    mlflow.log_artifact(str(model_path), artifact_path='artifacts')

    # registered_model_name=None (the default) skips Model Registry registration.
    mlflow.pytorch.log_model(model, artifact_path='model', registered_model_name=register_name)

    run_id = run.info.run_id
print('MLflow run logged:', run_id)


In [None]:
# Same file the training cell writes with torch.save; path is relative to the notebook's cwd.
checkpoint_path = MODEL_DIR / 'dbnet_textocr.pt'
print('Local checkpoint stored at', checkpoint_path)
