In [151]:
import torch
from torchmetrics.classification import Dice, JaccardIndex, BinaryJaccardIndex

from dataset import SegmentationDataset, DataLoader

In [2]:
dice_metric = Dice(num_classes=2, average="micro")
iou_metric = JaccardIndex(task="binary", average="micro")

In [39]:
targets = torch.tensor([[[[1, 1, 1, 0], [1, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0]]]], dtype=torch.float32)
targets2 = torch.tensor([[[[1, 1, 1, 0], [1, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0]]]], dtype=torch.float32)
targets.shape

torch.Size([1, 1, 4, 4])

In [40]:
dice = dice_metric(targets.int(), targets2.int())
iou = iou_metric(targets.int(), targets2.int())
print(f'Dice: {dice:.2f}')
print(f'IoU: {iou:.2f}')

Dice: 1.00
IoU: 1.00


In [56]:
image_dir = '../datasets/Brain_tumor_segmentation/manual_test'
mask_dir = '../datasets/Brain_tumor_segmentation/manual_test_masks'
dataset = SegmentationDataset(image_dir, mask_dir)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
image, mask = next(iter(dataloader))
mask.shape

torch.Size([1, 1, 256, 256])

In [57]:
dice = dice_metric(mask.int(), mask.int())
iou = iou_metric(mask.int(), mask.int())
print(f'Dice: {dice:.2f}')
print(f'IoU: {iou:.2f}')

Dice: 1.00
IoU: 1.00


In [58]:
from unet import UNet
model = UNet(in_channels=3, out_channels=1)

In [59]:
model.load_state_dict(torch.load("./models/unet_14_03.pth", weights_only=True))

<All keys matched successfully>

In [60]:
model.eval()
output = model(image)
output = torch.sigmoid(output)

In [147]:
dice = dice_metric(output.int(), mask.int())
iou = iou_metric(output, mask.int())
print(f'Dice: {dice:.2f}')
print(f'IoU: {iou:.4f}')

Dice: 0.98
IoU: 0.8428


In [150]:
def dice_score(pred, targets, threshold=0.5, eps=1e-7):
    pred = (pred > threshold).float()
    targets = targets.float()
    intersection = (pred * targets).sum()
    total_sum = pred.sum() + targets.sum() + eps
    return (2. * intersection + eps) / total_sum
dice_score(output, mask).item()

0.9428561925888062

In [129]:
def iou_score(pred, targets, threshold=0.5):
    pred = (pred > threshold).float()
    targets = targets.float()
    intersect = (pred * targets).sum()
    union = pred.sum() + targets.sum() - intersect
    return intersect/union
print(iou_score(output, mask).item())

0.8918901681900024


In [138]:
targets_1 = torch.tensor([1, 1, 0.5, 0], dtype=torch.float32)
target_2 = torch.tensor([0, 0, 1, 1], dtype=torch.float32)
print(iou_score(targets_1, target_2).item())

0.0


In [179]:
def precision_score(pred, targets, threshold=0.5):
    pred = (pred > threshold).float()
    intersect = (pred * targets).sum()
    total_pixel_pred = pred.sum()
    precision = intersect/total_pixel_pred
    return precision.item()

precision_score(output, mask)

0.9131737947463989

In [180]:
def recall_score(pred, targets, threshold=0.5):
    pred = (pred > threshold).float()
    intersect = (pred*targets).sum()
    total_pixel_truth = targets.sum()
    recall = intersect/total_pixel_truth
    return recall.item()
recall_score(output, mask)

0.9745330214500427

In [None]:
|