In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader
import pathlib
import sys
import torchmetrics
from torchmetrics.classification import (
    MulticlassAUROC,
    MulticlassJaccardIndex,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
    MulticlassAccuracy,
    BinaryAccuracy,
    BinaryAUROC,
    BinaryF1Score,
    BinaryPrecision,
    BinaryRecall,
    BinaryJaccardIndex,
)
import torch
import torch.nn as nn

# Project root is taken to be the parent of the current working directory
# (i.e. the notebook is expected to be launched from a sub-folder of the repo).
root = pathlib.Path().absolute().parent
DATASET_PATH = root / 'datasets'
MODEL_REGISTRY = root / 'model_registry'

# Make the project's `src` package importable. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

from src.data.classification import TumorBinaryClassificationDataset, CLASSIFICATION_NORMALIZER
from src.utils.config import get_device
from src.enums import DataSplit
from src.models.classification.logreg import LogisiticRegression
from src.trainer import eval_classification, train_classification
from src.utils.visualize import create_classification_results

In [None]:
# --- Tunable configuration -------------------------------------------------
DIM = 256          # images are resized to DIM x DIM before being fed to the model
N_EPOCHS = 15      # number of training epochs
BATCH_SIZE = 32    # mini-batch size for both data loaders
SEED = 42          # fixed seed so shuffling / weight init are reproducible

# Seed PyTorch's global RNG early: it drives DataLoader shuffling and parameter
# initialisation, so runs are repeatable after Restart & Run All.
torch.manual_seed(SEED)

# Preprocessing applied to every image: resize -> tensor -> normalise.
transform = transforms.Compose(
    [
        transforms.Resize((DIM, DIM)),  # TODO: make this larger
        transforms.ToTensor(),
        CLASSIFICATION_NORMALIZER
    ]
)

# Compute device chosen by the project helper.
device = get_device()

# Location of the logistic-regression checkpoint inside the model registry.
LOG_REG_MODEL = MODEL_REGISTRY / 'log_reg.pth'

In [None]:
def _load_split(split):
    """Build the binary tumor dataset for one split with the shared transform."""
    return TumorBinaryClassificationDataset(
        root_dir=DATASET_PATH,
        split=split,
        transform=transform,
    )


# Same dataset class and preprocessing for both splits; only the split differs.
train_dataset = _load_split(DataSplit.TRAIN)
test_dataset = _load_split(DataSplit.TEST)

print("Train dataset length: ", len(train_dataset))
print("Test dataset length: ", len(test_dataset))

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Logistic regression with DIM * DIM * 3 input features and a single output.
# NOTE(review): presumably the model flattens the 3-channel DIM x DIM image
# internally -- confirm against src.models.classification.logreg.
model = LogisiticRegression(DIM * DIM * 3, 1).to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the model must emit raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)

# Train only when no checkpoint is present. This keeps Restart & Run All cheap
# when a checkpoint exists, and avoids relying on a manually (un)commented
# cell when it does not.
# NOTE(review): assumes train_classification writes the trained weights to
# `model_path` -- confirm against src.trainer.
if not LOG_REG_MODEL.exists():
    # Ensure the registry folder exists so the checkpoint can be written.
    LOG_REG_MODEL.parent.mkdir(parents=True, exist_ok=True)
    train_classification(
        model,
        train_loader,
        optimizer,
        criterion,
        device,
        N_EPOCHS,
        is_multiclass=False,
        model_path=LOG_REG_MODEL,
    )

In [None]:
# Restore the trained weights. `map_location` remaps the stored tensors onto
# the active device, so a checkpoint saved on a GPU machine still loads on a
# CPU-only one.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(LOG_REG_MODEL, map_location=device))
model.to(device)
model.eval()  # switch layers to inference behaviour before evaluating

# Binary classification metrics, moved to the device as one collection.
metrics = torchmetrics.MetricCollection(
    [
        BinaryAUROC(),
        BinaryJaccardIndex(),
        BinaryAccuracy(),
        BinaryF1Score(),
        BinaryPrecision(),
        BinaryRecall(),
    ]
).to(device)

# Run the model over the held-out test split.
y_true, y_pred, total_metrics = eval_classification(
    model,
    test_loader,
    metrics,
    device,
    is_multiclass=False,
)

# NOTE(review): total_metrics is presumably the computed MetricCollection
# output, keyed by metric class name -- confirm against src.trainer.
accuracy = total_metrics["BinaryAccuracy"]
print(f"Accuracy on test set: {accuracy:.2%}")

In [None]:
# Display names for the two classes, passed to the project's results helper.
# NOTE(review): presumably index 0 = negative class, index 1 = positive class
# -- confirm against the dataset's label encoding.
class_names = [
    "No Tumor",
    "Tumor",
]

# Render the classification results for the test-set labels and predictions.
create_classification_results(y_true, y_pred, class_names)