In [3]:
%cd ..
!pwd

/Users/andrew/courses/MIPT_MLOps/mb_opc
/Users/andrew/courses/MIPT_MLOps/mb_opc


In [18]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp

import cv2
import os
import numpy as np
import random

from src.dataset import OPCDataset, apply_transform
from src.losses import BoundaryLoss, TVLoss, ContourLoss, IouLoss
from src.metrics import PixelAccuracy, IoU
from src.utils import set_random_seed, load_model

In [7]:
DATASET_PATH = 'data/processed/gds_dataset'
BATCH_SIZE = 1
MODEL_TYPE = 'manet'

TEST_DATASET = OPCDataset(
    os.path.join(DATASET_PATH, "origin/test_origin/"),
    os.path.join(DATASET_PATH, "correction/test_correction/"),
    transform=apply_transform(binarize_flag=True),
)

TEST_LOADER = DataLoader(
    TEST_DATASET, batch_size=BATCH_SIZE, shuffle=True, num_workers=8
)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available() and MODEL_TYPE not in ["cfno", "pspnet"]:
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using: {DEVICE}")

Using: mps


In [11]:
model = load_model(model_type = MODEL_TYPE, 
                   weights_path = 'checkpoints/manet_checkpoint', 
                   device = DEVICE)

Loading weights from local directory
Model 'manet' loaded.
Total parameters: 31777361


In [13]:
tv_loss = TVLoss(weight=1.0)
contour_loss = ContourLoss(weight=1.0, device=DEVICE)
mae_loss = torch.nn.L1Loss()
iou_loss = IouLoss(weight=1.0)
pixel_acc = PixelAccuracy()
bce = torch.nn.BCELoss()
boundary_loss = BoundaryLoss(device=DEVICE)

In [12]:
image, target = next(iter(TEST_LOADER))
image, target = image.to(DEVICE), target.to(DEVICE)
print(f'Image shape: {image.shape}')
print(f'Target shape: {target.shape}')

Image shape: torch.Size([1, 1, 1024, 1024])
Target shape: torch.Size([1, 1, 1024, 1024])


In [15]:
with torch.no_grad():
  pred = model(image).sigmoid()

In [23]:
contour_loss_iter = contour_loss(pred, target)
mae_loss_iter = mae_loss(pred, target)
iou_loss_iter = iou_loss(pred, target)
bce_loss = bce(pred, target)
bd_loss = boundary_loss(pred, target)
total_loss = mae_loss_iter + contour_loss_iter + iou_loss_iter

In [24]:
print(f'contour loss:{contour_loss_iter}')
print(f'mse loss:{mae_loss_iter}')
print(f'iou_loss:{iou_loss_iter}')
print(f'bce loss:{bce_loss}')
print(f'Boundary loss:{bd_loss}')

contour loss:0.106756791472435
mse loss:0.0005883198464289308
iou_loss:0.004180252552032471
bce loss:0.002033837605267763
Boundary loss:2.135075092315674


In [19]:
pixel_acc = PixelAccuracy()
iou = IoU()
accuracy = pixel_acc(pred, target)

print(f'Pixel Accuracy: {accuracy}')
print(f'IoU: {iou(pred, target)}')

Pixel Accuracy: 0.9994430541992188
IoU: 0.9958197474479675
