In [9]:
import logging
import numpy as np
import os
from pathlib import Path
import sys
import tempfile
import torch

from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import Dataset, DataLoader
from monai.engines import SupervisedTrainer
from monai.handlers import StatsHandler
from monai.inferers import SimpleInferer
from monai.networks import eval_mode
from monai.networks.nets import densenet121
from monai.transforms import ResizeD, LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose

In [10]:
# Data locations. Each directory is assumed to contain one subdirectory per class holding
# that class's images (the layout the data-list cells below rely on).
# Override with environment variables instead of editing machine-specific absolute paths.
root_dir = os.environ.get("MONAI_TRAIN_DIR", "/Users/eaidy/Datasets/deneme_set")
test_dir = os.environ.get("MONAI_TEST_DIR", "/Users/eaidy/Datasets/deneme_set_test")

In [11]:
# Per-sample preprocessing for {"image": <path>, "label": <class index>} dicts:
# load the image file, move the channel axis first, rescale intensities, resize to 224x224.
# Only "image" is spatial; "label" is the class-index target for CrossEntropyLoss, so it must
# NOT be passed to ResizeD (resizing a scalar label fails).
transform = Compose(
    [
        LoadImageD(keys="image", image_only=True),
        EnsureChannelFirstD(keys="image"),
        ScaleIntensityD(keys="image"),
        ResizeD(keys="image", spatial_size=(224, 224)),
    ]
)

In [None]:
# MONAI's Dataset expects a sequence of samples, not a directory path. Build the usual
# list-of-dicts input: one {"image": <file path>, "label": <class index>} entry per image,
# treating each non-hidden subdirectory of root_dir as one class (sorted for a stable mapping).
train_root = Path(root_dir)
class_names = sorted(
    p.name for p in train_root.iterdir() if p.is_dir() and not p.name.startswith(".")
)
train_files = [
    {"image": str(image_path), "label": label}
    for label, class_name in enumerate(class_names)
    for image_path in sorted((train_root / class_name).iterdir())
    # skip sub-folders and hidden files such as macOS .DS_Store, which LoadImageD cannot read
    if image_path.is_file() and not image_path.name.startswith(".")
]
if not train_files:
    raise FileNotFoundError(f"No images found in class subdirectories of {train_root}")

dataset = Dataset(data=train_files, transform=transform)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)
dataset[0]  # preview one transformed sample


In [None]:
max_epochs = 5
# Prefer Apple's Metal (MPS) backend but fall back to CPU so the notebook runs on any machine.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# One output unit per class: count the non-hidden class subdirectories of the training root
# instead of hard-coding the number of classes.
num_classes = sum(
    1 for p in Path(root_dir).iterdir() if p.is_dir() and not p.name.startswith(".")
)
# NOTE(review): in_channels=1 assumes grayscale images — confirm for this dataset.
model = densenet121(spatial_dims=2, in_channels=1, out_channels=num_classes).to(device)

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
trainer = SupervisedTrainer(
    device=device,
    max_epochs=max_epochs,
    # NOTE(review): 512 images of 224x224 per batch through DenseNet-121 may exceed
    # accelerator memory — lower batch_size if you hit out-of-memory errors.
    train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4),
    network=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
    loss_function=torch.nn.CrossEntropyLoss(),
    inferer=SimpleInferer(),
    train_handlers=[StatsHandler()],
)

In [None]:
# Kick off training with the configuration assembled above; progress is logged through
# the StatsHandler attached to the trainer.
run_training = trainer.run
run_training()

In [None]:
# Spot-check the trained model on the held-out directory.
# Class indices are taken from the *training* root so they line up with the network's output
# units; only test subdirectories whose name is a known training class are used.
class_names = sorted(
    p.name for p in Path(root_dir).iterdir() if p.is_dir() and not p.name.startswith(".")
)
class_to_index = {name: index for index, name in enumerate(class_names)}

dataset_dir = Path(test_dir)
test_files = [
    {
        "image": str(image_path),
        "label": class_to_index[class_dir.name],
        "class_name": class_dir.name,  # carried through the loader for readable output
    }
    for class_dir in sorted(dataset_dir.iterdir())
    if class_dir.is_dir() and class_dir.name in class_to_index
    for image_path in sorted(class_dir.iterdir())
    # skip sub-folders and hidden files such as macOS .DS_Store
    if image_path.is_file() and not image_path.name.startswith(".")
]
if not test_files:
    raise FileNotFoundError(f"No test images found in class subdirectories of {dataset_dir}")

testdata = Dataset(data=test_files, transform=transform)
data_loader = DataLoader(dataset=testdata, batch_size=1, num_workers=0)

max_items_to_print = 10
# Feed inputs to whatever device the training cell placed the network on.
model_device = next(model.parameters()).device
with eval_mode(model):
    for index, item in enumerate(data_loader):
        if index >= max_items_to_print:
            break
        logits = model(item["image"].to(model_device))  # shape (1, num_classes) for batch_size=1
        pred = class_names[int(logits.argmax(dim=1)[0])]
        gt = item["class_name"][0]
        print(f"Class prediction is {pred}. Ground-truth: {gt}")