Skip to content

hellebuyckf/aria_back

Repository files navigation

aria_back — Couche LLM ARIA

Inférence et fine-tuning LLM pour le système ARIA de diagnostic biomécanique course à pied. Sert un endpoint OpenAI-compatible via vLLM sur alpha-server (port 8001).

[aria_middle — alpha-server :8000]
        │  POST /v1/chat/completions
        ▼
[aria_back — alpha-server :8001]
        ├── vLLM · xgrammar · bfloat16
        └── google/medgemma-4b-it → aria-ft-sft → aria-ft-dpo (PROD)

Prérequis

  • GPU NVIDIA avec ≥ 16 GB VRAM (alpha-server)
  • CUDA 12.4+
  • Python ≥ 3.11
  • uv
  • Token HuggingFace avec accès à MedGemma 4B-it (modèle gated)

Installation

make install
cp .env.example .env
# Éditer .env — renseigner HF_TOKEN au minimum

Variables d'environnement :

VLLM_PORT=8001
VLLM_MODEL=google/medgemma-4b-it
VLLM_SERVED_NAME=aria-ft
VLLM_API_KEY=aria-local
VLLM_MAX_MODEL_LEN=4096
VLLM_GPU_MEMORY_UTILIZATION=0.90
HF_TOKEN=hf_xxx
HF_MODEL_REPO=username/aria-ft-dpo      # pour make hf-push-model
HF_DATASET_REPO=username/aria-datasets  # pour make hf-push-datasets

Mise en route complète

# 1. Datasets — copier depuis aria_middle
make copy-datasets

# 2. Vérifier le GPU
make check-gpu

# 3. Valider le serving de base
make serve          # terminal 1
make health         # terminal 2
make test

# 4. Fine-tuning (~40 min sur alpha-server)
make train-sft      # Phase 1 — SFT QLoRA
make train-dpo      # Phase 2 — DPO + fusion LoRA

# 5. Serving du modèle fine-tuné
make serve-ft

# 6. Benchmark qualité
make benchmark

Règle GPU : ne jamais lancer train-* et serve simultanément — VRAM insuffisante.

Commandes

Serving

Commande Description
make serve vLLM · MedGemma 4B-it base
make serve-ft vLLM · aria-ft-dpo (PROD)
make serve-sft vLLM · aria-ft-sft (validation)
make docker-serve Docker · MedGemma 4B-it base
make docker-serve-ft Docker · aria-ft-dpo (PROD)
make docker-stop Arrêter le container
make docker-logs Logs du container

Training

Commande Description Durée
make train-sft SFT QLoRA — MedGemma → aria-ft-sft ~20-30 min
make train-dpo DPO + fusion LoRA — SFT → aria-ft-dpo ~10-15 min
make train-all Pipeline complet SFT → DPO ~40 min

Tests & Qualité

Commande Description
make test Tous les tests (skip si vLLM absent)
make test-fast Tests sans GPU
make benchmark Latence P50/P99 + exact match sur 6 pathologies

Documentation

Commande Description
make doc Générer la documentation HTML dans site/
make doc-serve Servir avec hot-reload sur http://localhost:8080

HuggingFace Hub

Commande Description
make hf-push-model Pousser aria-ft-dpo + Model Card
make hf-push-datasets Pousser data/training/ + Dataset Card

API

Deux appels dans le pipeline ARIA, endpoint OpenAI-compatible.

Diagnostic structuré — JSON garanti

POST /v1/chat/completions
{
  "model": "aria-ft",
  "messages": [{"role": "user", "content": "{prompt}"}],
  "temperature": 0.1,
  "max_tokens": 256,
  "response_format": {
    "type": "json_schema",
    "json_schema": {
      "name": "diagnostic",
      "schema": {
        "type": "object",
        "properties": {
          "pathologie":    {"type": "string"},
          "confiance":     {"enum": ["élevée", "modérée", "faible"]},
          "justification": {"type": "string"}
        },
        "required": ["pathologie", "confiance", "justification"]
      }
    }
  }
}

xgrammar contraint la génération token par token — JSON valide à 100%.

Rapport libre

POST /v1/chat/completions
{
  "model": "aria-ft",
  "messages": [{"role": "user", "content": "{prompt}"}],
  "temperature": 0.3,
  "max_tokens": 2048
}

Pathologies couvertes

Pathologie Signal discriminant
Lombalgie mécanique inclinaison_tronc > 7°
Syndrome fémoro-patellaire (SFP) valgus_genou > 10°
Syndrome bandelette ilio-tibiale (SBIT) adduction_hanche > 10°
Périostite tibiale longueur_foulée > 1.35 m
Fasciite plantaire pronation_pied > 6°
Tendinite Achille heel_strike_index élevé + dorsiflexion_max > 15°

Pipeline d'entraînement

MedGemma 4B-it (base gated)
    │
    ├─ Phase 1 — SFT QLoRA NF4
    │   r=32, α=64, lr=2e-4, 560 steps
    │   ~1500 paires prompt/completion
    │   → data/models/aria-ft-sft (bfloat16 fusionné)
    │
    └─ Phase 2 — DPO TRL QLoRA NF4
        r=16, α=32, lr=2e-5, β=0.2, 135 steps
        ~420 triplets prompt/chosen/rejected
        → data/models/aria-ft-dpo (bfloat16 fusionné) ✓ PROD

Performances

Métrique Cible Mesure
JSON valide (guided decoding) 100% make test
Exact match pathologie ≥ 80% make benchmark
Latence P50 diagnostic < 1.5s make benchmark
Latence P99 diagnostic < 3s make benchmark
VRAM serving < 10 GB nvidia-smi

Structure

aria_back/
├── serve/
│   └── vllm_server.py          # launcher vLLM
├── training/
│   ├── train_sft.py            # SFT QLoRA TRL
│   └── train_dpo.py            # DPO TRL + fusion
├── scripts/
│   ├── benchmark.py            # latence + exact match
│   ├── generate_diagnostic_sft.py
│   ├── generate_dpo.py
│   ├── generate_model_card.py
│   ├── generate_dataset_card.py
│   └── merge_sft.py
├── tests/
│   └── test_inference.py       # tests live vLLM
├── data/
│   ├── training/               # .gitignore — datasets JSONL
│   └── models/                 # .gitignore — checkpoints
└── .github/workflows/ci.yml    # ruff + pip-audit

CI

GitHub Actions déclenché sur chaque push/PR sur main :

  • Code quality : ruff check + ruff format --check
  • Security : pip-audit sur les dépendances prod

Licence

Projet privé — usage recherche. Modèle de base soumis à la licence Google Health AI Developer Foundations.

About

Aria LLM

Resources

Stars

Watchers

Forks

Packages

 
 
 

Contributors