In [1]:
import os
import time
import torch
import dagshub

from PIL import Image
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

from xrkit.base import CONFIG
from xrkit.models import *
from xrkit.data.dataset import NIHDataset

from torch.utils.data import DataLoader

import pytorch_lightning as L
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Work from the parent directory of the notebook (presumably the project root, since
# relative paths such as "models/..." are used further down).
# The target is resolved only once per kernel session: a bare os.chdir("..") would
# climb one directory further every time this cell is re-executed.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.abspath("..")
os.chdir(_PROJECT_ROOT)

# Allow PyTorch to trade some float32 matmul precision for speed.
torch.set_float32_matmul_precision("high")

# Initialise DagsHub for the repository named in the project config, with MLflow enabled.
dagshub.init(CONFIG.dagshub.repository_name, CONFIG.dagshub.repository_owner, mlflow=True)

In [2]:
def _build_loader(dataset, shuffle=False):
    """Wrap ``dataset`` in a DataLoader configured from ``CONFIG.base``.

    Args:
        dataset: The dataset to iterate over.
        shuffle: Whether the sample order is reshuffled at every epoch.

    Returns:
        A DataLoader using the configured batch size and worker count, with
        pinned memory and without dropping the last incomplete batch.
    """
    return DataLoader(
        dataset,
        batch_size=CONFIG.base.batch_size,
        shuffle=shuffle,
        num_workers=CONFIG.base.n_workers,
        pin_memory=True,
        drop_last=False,
    )


train_dataset = NIHDataset("train")
# Only the training split is shuffled: with a fixed order every epoch would see the
# exact same batches in the exact same sequence.
# NOTE(review): assumes NIHDataset is a map-style dataset (shuffle is not supported
# for iterable-style datasets) - confirm.
train_loader = _build_loader(train_dataset, shuffle=True)

# Evaluation splits keep a fixed order so metrics and predictions are comparable across runs.
validation_dataset = NIHDataset("validation")
validation_loader = _build_loader(validation_dataset)

test_dataset = NIHDataset("test")
test_loader = _build_loader(test_dataset)

In [3]:
epochs = 100
model = DenseNet201Model(n_epochs=epochs)

# While True, the run is logged and checkpointed under the throw-away "debug" experiment.
# Set to False for a real run so it is filed under the model's own name instead.
DEBUG_RUN = True

# Experiment name derived from the model class, e.g. "DenseNet201Model" -> "densenet201".
# removesuffix only strips a trailing "model" if it is actually there.
model_experiment_name = model.__class__.__name__.lower().removesuffix("model")
experiment_name = "debug" if DEBUG_RUN else model_experiment_name

# Metric watched by both checkpointing and early stopping, and the direction that counts as better.
metric, mode = "validation_f1_score", "max"

logger = MLFlowLogger(experiment_name=experiment_name, tracking_uri=CONFIG.dagshub.tracking_uri)

# Keep only the single best checkpoint (by `metric`) under models/<experiment_name>/.
checkpoint_callback = ModelCheckpoint(
    monitor=metric,
    dirpath=f"models/{experiment_name}",
    filename="model-{epoch:03d}-{validation_f1_score:.2f}",
    save_top_k=1,
    mode=mode,
    enable_version_counter=False,
)

# Stop training once `metric` has not improved for 5 consecutive validation checks.
early_stop_callback = EarlyStopping(monitor=metric, min_delta=0.00, patience=5, mode=mode)
trainer = L.Trainer(max_epochs=epochs, logger=logger, callbacks=[checkpoint_callback, early_stop_callback])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [4]:
# Time the whole fit call. perf_counter is a monotonic clock meant for measuring
# intervals, so the elapsed time cannot be skewed by system clock adjustments
# during a long training run. Only the difference between the two readings is meaningful.
start_training_time = time.perf_counter()
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=validation_loader)
end_training_time = time.perf_counter()

# Keep the MLflow run id and the path of the best checkpoint for the following cells.
run_id = trainer.logger.run_id
checkpoint_path = checkpoint_callback.best_model_path

/home/yullhan/miniconda3/envs/nih/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/yullhan/Projects/NIH-ChestXRay/models/debug exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | network   | DenseNet201       | 18.1 M
1 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
25.0 K    Trainable params
18.1 M    Non-trainable params
18.1 M    Total params
72.472    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Time the evaluation of the best checkpoint on the test split. perf_counter is a
# monotonic clock meant for measuring intervals; only the difference between the
# two readings is meaningful.
start_testing_time = time.perf_counter()
trainer.test(model=model, dataloaders=test_loader, ckpt_path=checkpoint_path)
end_testing_time = time.perf_counter()

In [None]:
# Predict on an explicitly unshuffled view of the training data, so predictions keep
# the dataset's sample order even if the training loader reshuffles every epoch.
train_eval_loader = DataLoader(
    train_loader.dataset,
    batch_size=train_loader.batch_size,
    shuffle=False,
    num_workers=train_loader.num_workers,
    pin_memory=train_loader.pin_memory,
    drop_last=train_loader.drop_last,
)

dataloaders = {
    "train": train_eval_loader,
    "validation": validation_loader,
    "test": test_loader,
}

# With several dataloaders, predict returns one list of batch outputs per dataloader,
# in the order the dataloaders were passed.
results = trainer.predict(model=model, dataloaders=list(dataloaders.values()), ckpt_path=checkpoint_path)

# Same predictions keyed by split name, so callers do not have to rely on list position.
results_by_split = dict(zip(dataloaders, results))

In [None]:
# Elapsed time of the fit and test cells, in seconds.
# The test duration is reported as measured: trainer.test is called once, so the
# measured interval is not divided by anything.
training_seconds = end_training_time - start_training_time
testing_seconds = end_testing_time - start_testing_time

print(f"Tempo de treinamento: {training_seconds:.1f} segundos")
print(f"Tempo de teste: {testing_seconds:.1f} segundos")