In [None]:
import kagglehub

# Download the Kaggle dataset; the returned value is the local location of the
# files, which the following cells list and read from.
path = kagglehub.dataset_download(
    "umitka/chest-x-ray-balanced",
)

print("Path to dataset files:", path)

In [1]:
import torch
import torchvision
from torch import nn
from torchvision import transforms, datasets

In [4]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [None]:
import os

# Show the contents of the download root and of the nested dataset folder.
for folder in (path, f"{path}/chest_xray_balanced"):
    print(os.listdir(folder))

In [None]:
# Point `path` at the nested dataset folder. The suffix check keeps this cell
# idempotent: re-running it does not append the folder name a second time.
if not path.endswith("/chest_xray_balanced"):
    path = path + "/chest_xray_balanced"

In [None]:
# One sub-folder per split below the dataset root.
train_dir, test_dir, val_dir = (
    f"{path}/{split}" for split in ("train", "test", "val")
)

In [None]:
def _build_transform(augment):
    """Compose the image pipeline shared by all splits.

    Every pipeline resizes to 288x288, converts to a tensor and normalises the
    three channels. When ``augment`` is true, TrivialAugmentWide is inserted
    between the resize and the tensor conversion.
    """
    steps = [transforms.Resize((288, 288))]
    if augment:
        steps.append(transforms.TrivialAugmentWide(num_magnitude_bins=31))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


# Training images are augmented; evaluation images are only resized and normalised.
data_transforms = _build_transform(augment=True)

test_transform = _build_transform(augment=False)

In [None]:
def _image_folder(root, transform):
    """Wrap one split directory in an ImageFolder that applies ``transform`` to each image."""
    return datasets.ImageFolder(
        root=root,
        transform=transform,
        target_transform=None,
    )


# The training split uses the augmenting pipeline; test and validation share
# the plain evaluation pipeline.
train_data = _image_folder(train_dir, data_transforms)

test_data = _image_folder(test_dir, test_transform)

val_data = _image_folder(val_dir, test_transform)

In [None]:
# Class labels as exposed by the training ImageFolder; reused below for plot
# titles, the saved class list and the confusion-matrix axes.
class_names = train_data.classes

In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader needs an int. Falling back to 0 loads batches in the main process.
NUM_WORKERS = os.cpu_count() or 0

# Only the training split is shuffled. Validation and test keep dataset order,
# which the evaluation cells rely on when they compare the collected
# predictions with test_data.targets.
train_data_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
)

test_data_loader = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)

val_data_loader = DataLoader(
    dataset=val_data,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)

In [None]:
# Take the first (image, label) pair of the test set and display the image
# tensor's shape together with its integer label.
img, label = test_data[0]
img.shape, label

In [None]:
import matplotlib.pyplot as plt

# `img` has already been through transforms.Normalize, so its values fall
# outside the [0, 1] range imshow expects for float RGB data. Undo the
# normalisation (x * std + mean per channel) before plotting so the picture
# is not clipped or tinted.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img_display = (img * norm_std + norm_mean).clamp(0, 1)

# imshow wants channels last: (C, H, W) -> (H, W, C).
plt.imshow(img_display.permute(1, 2, 0))
plt.title(class_names[label])

In [None]:
# Build EfficientNet-B2 from torchvision's default weight set, then place it on
# the selected device.
weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
model = torchvision.models.efficientnet_b2(weights=weights)
model = model.to(device)

In [None]:
# Display the current classifier head (it is replaced with a single-logit head further down).
model.classifier

In [None]:
# Step 1: freeze every parameter of the feature extractor.
for param in model.features.parameters():
    param.requires_grad = False

# Step 2: re-enable gradients for only the last 7 parameter tensors, so just the
# tail of the backbone is fine-tuned.
# The previous slice `[:-7]` selected everything *except* the last 7 tensors,
# which unfroze almost the whole backbone and made step 1 pointless.
for param in list(model.features.parameters())[-7:]:
    param.requires_grad = True

In [None]:
# Seed the CPU and CUDA generators so the new head is initialised reproducibly.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Replace the classifier with a single-logit head for BCEWithLogitsLoss.
# nn.Sequential takes its modules as positional arguments; passing them wrapped
# in a list raises "TypeError: list is not a Module subclass".
model.classifier = nn.Sequential(
    nn.Dropout(p=0.4, inplace=True),
    nn.Linear(in_features=1408, out_features=1, bias=True),
)

In [None]:
# Move the model to the selected device again so the replacement classifier
# head sits on the same device as the rest of the network.
model.to(device)

In [None]:
# Smoke-test the forward pass with one random input of the training image size.
# Run it in eval mode under inference_mode so the random sample does not go
# through train-mode behaviour (e.g. the Dropout in the classifier head) and no
# autograd graph is built. The training loop calls model.train() itself.
x = torch.rand([1, 3, 288, 288]).to(device)
model.eval()
with torch.inference_mode():
    smoke_logits = model(x)
smoke_logits

In [None]:
# Binary cross-entropy computed directly on the raw logit.
loss_fn = nn.BCEWithLogitsLoss()

# Two parameter groups: the backbone gets a smaller learning rate than the
# classifier head. Weight decay is applied to both groups.
param_groups = [
    dict(params=model.features.parameters(), lr=1e-4),
    dict(params=model.classifier.parameters(), lr=5e-4),
]
optimizer = torch.optim.Adam(param_groups, weight_decay=1e-4)

# Scale the learning rates by 0.1 once the monitored value (the validation loss
# passed to scheduler.step) has stopped decreasing for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.1,
    patience=5,
)

In [None]:
def acc_fn(y_true, y_pred):
    """Return the percentage of elements of ``y_pred`` that equal ``y_true``.

    Args:
        y_true: tensor of ground-truth labels.
        y_pred: tensor of predicted labels, comparable element-wise with ``y_true``.

    Returns:
        Accuracy as a float in the range 0-100. The denominator is
        ``len(y_true)``, i.e. the size of the first dimension.
    """
    matches = (y_true == y_pred).sum().item()
    total = len(y_true)
    return matches / total * 100

In [None]:
import os

# All artefacts of this run (class list, logs, checkpoints, reports) go here.
results_path = "./chest_xray_results"
# makedirs(exist_ok=True) replaces the check-then-mkdir pair: it is safe to
# re-run and has no window between the existence check and the creation.
os.makedirs(results_path, exist_ok=True)

# Persist the class names, one per line, next to the saved model.
with open(f"{results_path}/class_names", "w") as f:
    f.write("\n".join(class_names))

In [None]:
# Epoch budget and early-stopping bookkeeping used by the training loop.
epochs = 10       # upper bound on the number of training epochs
patience = 3      # non-improving epochs tolerated before stopping early
early_stop = 0    # counter of consecutive epochs without a better validation loss
best_loss = None  # lowest validation loss seen so far (None until the first epoch)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer; the training loop logs loss, accuracy and learning rate
# to the "runs" folder inside the results directory.
writer = SummaryWriter(log_dir=f"{results_path}/runs")

In [None]:
from tqdm.auto import tqdm

# Train for at most `epochs` epochs. Each epoch runs one training pass and one
# validation pass, logs the metrics, steps the plateau scheduler on the
# validation loss, checkpoints the best model and stops early after `patience`
# consecutive epochs without improvement.
for epoch in tqdm(range(1, epochs + 1), desc="Epochs"):
    # ---- Training pass ----
    train_loss, train_acc = 0, 0
    model.train()

    # tqdm takes the total from len(train_data_loader); the batch index was unused.
    for X, y in tqdm(
        train_data_loader,
        desc=f"Training epoch {epoch}",
        leave=False,
    ):
        X, y = X.to(device), y.to(device)
        logits = model(X)
        # The loss needs float targets with the same trailing dimension as the logits.
        y = y.unsqueeze(dim=1).float()
        loss = loss_fn(logits, y)
        # sigmoid -> probability, round -> hard 0/1 prediction.
        y_pred = torch.round(torch.sigmoid(logits))

        train_loss += loss.item()
        train_acc += acc_fn(y, y_pred)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Metrics are the mean of the per-batch values.
    train_loss /= len(train_data_loader)
    train_acc /= len(train_data_loader)

    # ---- Validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_loss, val_acc = 0, 0
    with torch.inference_mode():
        for X, y in tqdm(
            val_data_loader,
            desc=f"Validating epoch {epoch}",
            leave=False,
        ):
            X, y = X.to(device), y.to(device)
            logits = model(X)
            y = y.unsqueeze(dim=1).float()
            loss = loss_fn(logits, y)
            y_pred = torch.round(torch.sigmoid(logits))

            val_loss += loss.item()
            val_acc += acc_fn(y, y_pred)

        val_loss /= len(val_data_loader)
        val_acc /= len(val_data_loader)

    # ---- Logging ----
    writer.add_scalars(
        main_tag="Loss",
        tag_scalar_dict={"train_loss": train_loss, "val_loss": val_loss},
        global_step=epoch,
    )

    writer.add_scalars(
        main_tag="Accuracy",
        tag_scalar_dict={"train_acc": train_acc, "val_acc": val_acc},
        global_step=epoch,
    )

    # Only the first parameter group (the backbone) is tracked here.
    writer.add_scalar(
        tag="Learning Rate",
        scalar_value=optimizer.param_groups[0]["lr"],
        global_step=epoch,
    )

    info = (
        f"Epoch: {epoch} | "
        f"Train acc: {train_acc:.5f} | Train loss: {train_loss:.5f} | "
        f"Val acc: {val_acc:.5f} | Val loss: {val_loss:.5f}"
    )

    print(info)
    with open(f"{results_path}/training_info.txt", "a") as f:
        f.write(info + "\n")

    # ---- Learning-rate schedule ----
    old_lr = optimizer.param_groups[0]["lr"]
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]["lr"]

    if new_lr != old_lr:
        # The arrow was mojibake ("â†’") in the original message; plain ASCII avoids that.
        print(f"Learning rate reduced: {old_lr} -> {new_lr} at epoch {epoch}")

    # ---- Checkpointing / early stopping ----
    if best_loss is None or val_loss < best_loss:
        best_loss = val_loss
        # The whole module object is saved (not just a state_dict) because the
        # evaluation cell restores it with a plain torch.load().
        torch.save(model, f"{results_path}/model.pth")
        print(f"Best model saved after epoch: {epoch}")
        early_stop = 0
    else:
        early_stop += 1
        # ">=" rather than "==" so a counter left over from an earlier run of
        # this cell still triggers the stop.
        if early_stop >= patience:
            print(f"Early stopping after epoch: {epoch}")
            break

# Push buffered events to disk so TensorBoard shows the complete run.
writer.flush()

In [None]:
# Restore the best checkpoint written by the training loop.
# map_location places the stored tensors on the current device, so a checkpoint
# saved on a GPU still loads on a CPU-only machine.
# NOTE(review): weights_only=False unpickles the whole module object; only load
# checkpoint files that this notebook produced itself.
model = torch.load(
    f"{results_path}/model.pth",
    map_location=device,
    weights_only=False,
)
model.to(device)

test_preds = []
model.eval()
test_loss, test_acc = 0, 0

with torch.inference_mode():
    # The loader is not shuffled, so the collected predictions stay in dataset order.
    for X, y in tqdm(
        test_data_loader,
        desc="Testing",
        leave=False,
    ):
        X, y = X.to(device), y.to(device)
        logits = model(X)
        y = y.unsqueeze(dim=1).float()
        loss = loss_fn(logits, y)
        # sigmoid -> probability, round -> hard 0/1 prediction.
        y_pred = torch.round(torch.sigmoid(logits))
        test_preds.append(y_pred.cpu())
        test_loss += loss.item()
        test_acc += acc_fn(y, y_pred)

    # Metrics are the mean of the per-batch values.
    test_loss /= len(test_data_loader)
    test_acc /= len(test_data_loader)

# One tensor with a prediction row per test sample.
test_preds = torch.cat(test_preds)
test_acc, test_loss

In [None]:
from torchmetrics import ConfusionMatrix

# Compare the hard test predictions with the dataset's stored targets.
test_targets = torch.Tensor(test_data.targets).type(torch.int64)
cm = ConfusionMatrix(task="binary", num_classes=len(class_names))
conf_mat = cm(test_preds.squeeze(), test_targets)

In [None]:
from mlxtend.plotting import plot_confusion_matrix
import matplotlib.pyplot as plt

# Draw the confusion matrix with class names on the axes and save it with the
# other run artefacts.
fig, ax = plot_confusion_matrix(
    conf_mat=conf_mat.numpy(),
    class_names=class_names,
    figsize=(7, 7),
)
ax.set_title("Confusion matrix")
fig.savefig(f"{results_path}/confusion_matrix.png", dpi=1000)

In [None]:
# Move the model to the CPU and store a second copy of it under its own file name.
cpu_model_file = f"{results_path}/cpu_model.pth"
model.cpu()
torch.save(model, cpu_model_file)

In [None]:
from sklearn.metrics import classification_report

# Ground truth and predictions as flat integer vectors. The predictions were
# collected as a float column of 0.0/1.0 values, so flatten and cast them
# rather than handing scikit-learn mixed int/float label types.
report_true = torch.tensor(test_data.targets, dtype=torch.int64).numpy()
report_pred = test_preds.reshape(-1).to(torch.int64).numpy()

# Label the report rows with the class names instead of bare indices.
# `labels` is given explicitly so it always lines up with `target_names`.
# classification_report already returns a str, so no str() wrapper is needed.
with open(f"{results_path}/classification_report.txt", "w") as f:
    f.write(
        classification_report(
            report_true,
            report_pred,
            labels=list(range(len(class_names))),
            target_names=class_names,
        )
    )

In [None]:
from pprint import pprint

# Build the text report from flat integer vectors, with rows labelled by class
# name. `labels` is given explicitly so it always lines up with `target_names`.
report = classification_report(
    torch.tensor(test_data.targets, dtype=torch.int64).numpy(),
    test_preds.reshape(-1).to(torch.int64).numpy(),
    labels=list(range(len(class_names))),
    target_names=class_names,
)
# The report is a pre-formatted multi-line string. pprint() would show its
# escaped repr (quotes and literal "\n"), so use print() to render the table.
print(report)