In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader
import pathlib
import sys
import torchmetrics
from torchmetrics.classification import (
    MulticlassAUROC,
    MulticlassJaccardIndex,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
    MulticlassAccuracy,
    BinaryAccuracy,
    BinaryAUROC,
    BinaryF1Score,
    BinaryPrecision,
    BinaryRecall,
    BinaryJaccardIndex,
)
import torch
import torch.nn as nn

# Project root is the parent of the current working directory, i.e. the
# notebook is expected to be launched from a sub-folder of the repository.
root = pathlib.Path().absolute().parent
DATASET_PATH = root / 'datasets'
MODEL_REGISTRY = root / 'model_registry'

# Make the `src` package importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

from src.data.classification import TumorClassificationDataset, TumorBinaryClassificationDataset, CLASSIFICATION_NORMALIZER
from src.utils.config import get_device
from src.enums import DataSplit
from src.models.classification.cnn import ClassificationMulticlassCNN, ClassificationCNN
from src.trainer import eval_classification, train_classification
from src.utils.visualize import create_classification_results

In [None]:
DIM = 256          # images are resized to DIM x DIM before being fed to the model
N_EPOCHS = 20
BATCH_SIZE = 32
SEED = 42

# Seed torch's global RNG so that weight initialisation and DataLoader
# shuffling are repeatable across "Restart Kernel -> Run All" executions
# (bit-exact results can still differ between devices/backends).
torch.manual_seed(SEED)

# Preprocessing shared by the binary and the multiclass datasets.
transform = transforms.Compose(
    [
        transforms.Resize((DIM, DIM)),  # TODO: make this larger
        transforms.ToTensor(),
        CLASSIFICATION_NORMALIZER,
    ]
)

device = get_device()

# Checkpoint locations for the two trained models.
CNN_MULTI_MODEL = MODEL_REGISTRY / 'cnn_multi.pth'
CNN_BINARY_MODEL = MODEL_REGISTRY / 'cnn_binary.pth'

def build_model_for_job(is_multiclass: bool, lr: float = 0.01, weight_decay: float = 0.001):
    """
    Builds the model, loss function and optimizer for a classification task.

    Args:
        is_multiclass (bool): Whether the task is multiclass or binary
        lr (float): Learning rate passed to the SGD optimizer. Defaults to 0.01.
        weight_decay (float): Weight decay passed to the SGD optimizer.
            Defaults to 0.001.

    Returns:
        tuple: ``(model, criterion, optimizer)`` where ``model`` is a
        ``ClassificationMulticlassCNN`` with ``nn.CrossEntropyLoss`` when
        ``is_multiclass`` is True, otherwise a ``ClassificationCNN`` with
        ``nn.BCEWithLogitsLoss``; ``optimizer`` is an SGD optimizer over the
        model's parameters.
    """
    if is_multiclass:
        model = ClassificationMulticlassCNN()
        criterion = nn.CrossEntropyLoss()
    else:
        model = ClassificationCNN()
        # BCEWithLogitsLoss applies the sigmoid itself, so it expects raw logits.
        criterion = nn.BCEWithLogitsLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

    return model, criterion, optimizer

### Binary Classification

In [None]:
# Train and test splits of the binary dataset share the same root directory
# and preprocessing; only the split differs.
train_dataset, test_dataset = (
    TumorBinaryClassificationDataset(
        root_dir=DATASET_PATH,
        split=split,
        transform=transform,
    )
    for split in (DataSplit.TRAIN, DataSplit.TEST)
)

print("Train dataset length: ", len(train_dataset))
print("Test dataset length: ", len(test_dataset))

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
model, criterion, optimizer = build_model_for_job(is_multiclass=False)
model.to(device)

# Train only when no saved checkpoint exists: a fresh "Restart Kernel -> Run All"
# then produces the weights that the evaluation cell loads, while later runs
# reuse them. Delete the checkpoint file to force retraining.
# NOTE(review): assumes train_classification writes the trained weights to
# `model_path` (the multiclass cells rely on the same behaviour) - confirm.
if not CNN_BINARY_MODEL.exists():
    CNN_BINARY_MODEL.parent.mkdir(parents=True, exist_ok=True)
    train_classification(
        model,
        train_loader,
        optimizer,
        criterion,
        device,
        N_EPOCHS,
        is_multiclass=False,
        model_path=CNN_BINARY_MODEL
    )

In [None]:
# Restore the trained binary weights. `map_location` remaps the stored tensors
# onto the active device, so a checkpoint saved on a GPU still loads on a
# CPU-only machine.
model.load_state_dict(torch.load(CNN_BINARY_MODEL, map_location=device))
model.to(device)
model.eval()

# Moving the collection moves every metric it contains.
metrics = torchmetrics.MetricCollection(
    [
        BinaryAUROC(),
        BinaryJaccardIndex(),
        BinaryAccuracy(),
        BinaryF1Score(),
        BinaryPrecision(),
        BinaryRecall(),
    ]
).to(device)

# Gradients are not needed for evaluation (mirrors the multiclass cell).
with torch.no_grad():
    y_true, y_pred, total_metrics = eval_classification(
        model,
        test_loader,
        metrics,
        device,
        is_multiclass=False,
    )

# MetricCollection keys its results by metric class name.
bin_accuracy = total_metrics["BinaryAccuracy"]
print(f'Accuracy on test set: {bin_accuracy:.2%}')


In [None]:
# Display names for the two labels.
# NOTE(review): assumes index 0 = no tumor and index 1 = tumor in the dataset's
# label encoding - confirm against TumorBinaryClassificationDataset.
class_names = ["No Tumor", "Tumor"]
create_classification_results(
    y_true,
    y_pred,
    class_names=class_names,
)

### Multiclass Classification

In [None]:
# Train and test splits of the multiclass dataset share the same root
# directory and preprocessing; only the split differs.
train_dataset, test_dataset = (
    TumorClassificationDataset(
        root_dir=DATASET_PATH,
        split=split,
        transform=transform,
    )
    for split in (DataSplit.TRAIN, DataSplit.TEST)
)

print("Train dataset length: ", len(train_dataset))
print("Test dataset length: ", len(test_dataset))

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# Fresh multiclass model with its loss and optimizer, moved to the active
# device before training.
model, criterion, optimizer = build_model_for_job(is_multiclass=True)
model = model.to(device)  # nn.Module.to returns the module itself

# Runs on every execution of this cell; the checkpoint path is handed to the
# trainer via `model_path`.
train_classification(
    model, train_loader, optimizer, criterion, device, N_EPOCHS,
    is_multiclass=True,
    model_path=CNN_MULTI_MODEL,
)

In [None]:
# Number of tumor classes the multiclass metrics are configured for.
NUM_CLASSES = 4

# Load the model. `map_location` remaps the stored tensors onto the active
# device, so a checkpoint saved on a GPU still loads on a CPU-only machine.
model.load_state_dict(torch.load(CNN_MULTI_MODEL, map_location=device))
model.eval()  # Set the model to evaluation mode
model.to(device)

# Moving the collection moves every metric it contains.
# NOTE(review): torchmetrics' Multiclass* metrics default to average="macro",
# so "MulticlassAccuracy" below is presumably the mean per-class accuracy
# rather than the overall fraction of correct predictions - confirm that this
# is the number you want to report.
metrics = torchmetrics.MetricCollection(
    [
        MulticlassAUROC(NUM_CLASSES),
        MulticlassJaccardIndex(NUM_CLASSES),
        MulticlassAccuracy(NUM_CLASSES),
        MulticlassF1Score(NUM_CLASSES),
        MulticlassPrecision(NUM_CLASSES),
        MulticlassRecall(NUM_CLASSES),
    ]
).to(device)

# Gradients are not needed for evaluation.
with torch.no_grad():
    y_true, y_pred, total_metrics = eval_classification(
        model, test_loader, metrics, device, is_multiclass=True
    )

print("Validation Metrics: ", total_metrics)

# MetricCollection keys its results by metric class name.
multi_class_accuracy = total_metrics["MulticlassAccuracy"]
print(f"Multi-class accuracy: {multi_class_accuracy:.2%}")

In [None]:
# Display names for the four labels.
# NOTE(review): assumes this order matches the dataset's label encoding -
# confirm against TumorClassificationDataset.
class_names = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
create_classification_results(
    y_true,
    y_pred,
    class_names=class_names,
)
