# 🫁 CheXpert Training on Google Colab

**Chest X-Ray Multi-Label Classification**
- Model: DenseNet121
- Loss: BCEWithLogitsLoss (with pos_weight)
- Metric: AUC per class
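With `pos_weight`, the loss scales only the positive term of the binary cross-entropy, and it does so separately for each class. The sketch below is a minimal pure-Python version of the per-element formula that the PyTorch documentation gives for `torch.nn.BCEWithLogitsLoss`, using the default `reduction="mean"`. `bce_with_logits` is a hypothetical helper for illustration only. Training itself uses the PyTorch module inside `CheXpertLightning`.

```python
import math


def bce_with_logits(logits, targets, pos_weight):
    """Mean weighted binary cross-entropy over multi-label logits (illustrative sketch).

    logits, targets: lists of rows, one value per class.
    pos_weight: one weight per class, applied to the positive term only.
    """
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, targets):
        for x, y, pw in zip(row_x, row_y, pos_weight):
            # Numerically stable log(sigmoid(x)) for both signs of x
            if x >= 0:
                log_p = -math.log1p(math.exp(-x))
            else:
                log_p = x - math.log1p(math.exp(x))
            # log(1 - sigmoid(x)) = log(sigmoid(x)) - x
            log_not_p = log_p - x
            total += -(pw * y * log_p + (1 - y) * log_not_p)
            count += 1
    return total / count  # "mean" reduction over every element
```

For example, `bce_with_logits([[0.0, 0.0]], [[1, 0]], [2.0, 2.0])` averages 2·ln 2 (the positive, doubled by its weight) and ln 2 (the negative, unweighted). A class with a large weight is therefore penalized more for missed positives than for false alarms.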

## 1. Set Up the Environment

In [None]:
# Install dependencies
!pip install pytorch-lightning albumentations torchmetrics -q

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the repo (skip this cell if it is already cloned)
!git clone https://github.com/folklazy/chest-xray-detection.git
%cd chest-xray-detection

## 2. Configuration

In [None]:
# === Edit these paths to match your Google Drive ===
DATA_DIR = "/content/drive/MyDrive/data"  # Folder that contains CheXpert
CSV_PATH = "/content/drive/MyDrive/data/CheXpert-v1.0-small/train.csv"

# Training Config
IMG_SIZE = 320
BATCH_SIZE = 32  # Colab has a larger GPU than the local machine
NUM_WORKERS = 2
EPOCHS = 15
LR = 1e-4

# Class weights (from calc_pos_weight.py)
POS_WEIGHT = [5.68, 7.29, 14.01, 3.28, 1.59]
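`POS_WEIGHT` has one entry per output class, in the same order as the model's five logits. The contents of `calc_pos_weight.py` are not shown here. The sketch below assumes the convention from the `BCEWithLogitsLoss` docs, where each class's weight is its number of negatives divided by its number of positives (for example, 300 negatives and 100 positives give 3.0). It also assumes CheXpert's uncertain (`-1`) and blank labels have already been mapped to 0 or 1. The real script may handle those labels differently. `compute_pos_weight` is a hypothetical name.

```python
def compute_pos_weight(labels):
    """Return negatives / positives for each class column (illustrative sketch).

    labels: list of rows with one 0/1 value per class. Uncertain or blank
    CheXpert labels are ASSUMED to be mapped to 0 or 1 beforehand.
    """
    num_classes = len(labels[0])
    weights = []
    for c in range(num_classes):
        pos = sum(1 for row in labels if row[c] == 1)
        neg = sum(1 for row in labels if row[c] == 0)
        if pos == 0:
            # The ratio is undefined for a class with no positive examples
            raise ValueError(f"class {c} has no positive labels")
        weights.append(neg / pos)
    return weights
```

Applying `round(w, 2)` to each result gives the same precision as the values above. Rarer findings get larger weights.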

## 3. Import & Setup

In [None]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from src.dataset import CheXpertDataModule
from src.model import CheXpertLightning

# Check GPU
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

## 4. Training

In [None]:
pl.seed_everything(42, workers=True)

# DataModule
dm = CheXpertDataModule(
    data_dir=DATA_DIR,
    csv_path=CSV_PATH,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Model
model = CheXpertLightning(
    model_name="densenet121",
    num_classes=5,
    lr=LR,
    pos_weight=POS_WEIGHT,
)

In [None]:
# Callbacks
checkpoint = ModelCheckpoint(
    monitor="val_auc",
    mode="max",
    save_top_k=1,
    filename="best-{epoch:02d}-{val_auc:.4f}",
    dirpath="/content/drive/MyDrive/checkpoints",  # Save to Drive
)

early_stop = EarlyStopping(
    monitor="val_auc",
    mode="max",
    patience=5,
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")

logger = TensorBoardLogger("logs", name="chexpert")
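Both `ModelCheckpoint` and `EarlyStopping` monitor `val_auc`, so `CheXpertLightning` must log a metric under exactly that name. This notebook does not show how the model combines the five classes into that value. Assuming a macro average (the mean of per-class ROC AUCs, which is the default `average` of torchmetrics' `MultilabelAUROC`), the quantity can be sketched in pure Python using the rank interpretation of AUC. `roc_auc` and `macro_auc` are hypothetical helpers.

```python
def roc_auc(scores, targets):
    """Binary ROC AUC: probability that a random positive outscores a random negative."""
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))


def macro_auc(scores, targets):
    """Mean of per-class AUCs; scores/targets are lists of rows, one value per class."""
    num_classes = len(scores[0])
    per_class = [
        roc_auc([row[c] for row in scores], [row[c] for row in targets])
        for c in range(num_classes)
    ]
    return sum(per_class) / num_classes, per_class
```

The pairwise loop is O(positives × negatives). That is fine for a sanity check but too slow for a full validation set, so the torchmetrics implementation is the practical choice during training.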

In [None]:
# Trainer
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator="auto",
    devices="auto",
    precision="16-mixed",
    callbacks=[checkpoint, early_stop, lr_monitor],
    logger=logger,
    log_every_n_steps=20,
)

# Train!
trainer.fit(model, dm)
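With `monitor="val_auc"`, `mode="max"`, and `patience=5`, training stops after `val_auc` goes five consecutive validation checks (one per epoch by default) without a new best. Regardless of early stopping, `max_epochs=EPOCHS` caps the run at 15 epochs. The function below is a simplified model of that rule. It assumes Lightning's default `min_delta=0.0`, is a hypothetical helper, and ignores the callback's other options (e.g. `stopping_threshold`, `check_finite`).

```python
def early_stop_epoch(val_aucs, patience=5, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    Simplified mode="max" rule: a value counts as an improvement only if it
    exceeds the best value so far by more than min_delta.
    """
    best = float("-inf")
    wait = 0
    for epoch, auc in enumerate(val_aucs):
        if auc - min_delta > best:
            best = auc  # new best: reset the patience counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

For example, `early_stop_epoch([0.70, 0.75, 0.74, 0.74, 0.75, 0.73, 0.74, 0.80])` returns 6. Epoch 1 sets the best value, and the tie at epoch 4 does not count as an improvement, so the later 0.80 is never reached.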

In [None]:
print("\n✅ Best checkpoint:", checkpoint.best_model_path)

## 5. TensorBoard

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs