In [3]:
import datetime
import time
from collections import defaultdict

import albumentations as A
import albumentations.augmentations.functional as F
import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.sparse
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

In [4]:
class ImageDataset(Dataset):
    """Dataset pairing each image with its sparse-stored mask.

    For every name in ``image_names`` the image is read from
    ``{data_dir}/{name}.tif`` and the mask from
    ``{data_dir}/mask_{name}.npz`` (a SciPy sparse matrix saved with
    ``scipy.sparse.save_npz``).

    Args:
        image_names: Sequence of file stems (no directory, no extension).
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``"image"`` and ``"mask"`` entries.
        img_size: When truthy, each item additionally carries the
            ``(height, width)`` of the image as read from disk, i.e. before
            ``transform`` is applied.
        data_dir: Directory holding the images and masks. Defaults to
            ``'prima'``, the location this notebook has always used.
    """

    def __init__(self, image_names, transform, img_size=False, data_dir='prima'):
        self.image_names = image_names
        self.transform = transform
        self.img_size = img_size
        self.data_dir = data_dir

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load, transform and return the sample at position ``idx``.

        Returns:
            ``(image, mask)``, or ``(image, mask, (height, width))`` when
            ``img_size`` is truthy.

        Raises:
            FileNotFoundError: If the image cannot be read, or the mask file
                does not exist (raised by ``scipy.sparse.load_npz``).
        """
        name = self.image_names[idx]
        image_path = f'{self.data_dir}/{name}.tif'

        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # toarray() gives a plain ndarray; todense() would give the legacy
        # np.matrix type, which has surprising indexing/arithmetic semantics.
        mask = scipy.sparse.load_npz(f'{self.data_dir}/mask_{name}.npz').toarray()

        transformed = self.transform(image=img, mask=mask)

        if self.img_size:
            return transformed["image"], transformed["mask"], img.shape[:2]
        return transformed["image"], transformed["mask"]

In [6]:
# File stems (no directory, no extension) of the images used for testing.
test_files = {
    '00008228',
    '00322597',
    '00008338',
    '00008064',
    '00322598',
    '00325451',
    '00008142',
    '00325452',
    '00008154',
    '00008342',
}

In [None]:
# Path to the serialized model. The notebook leaves this blank; it must be set
# to the saved checkpoint before this cell can run.
MODEL_PATH = ''

if not MODEL_PATH:
    raise ValueError(
        'MODEL_PATH is empty: set it to the saved model checkpoint before running this cell.'
    )

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
model = torch.load(MODEL_PATH)

In [16]:
# Validation preprocessing, applied in this order:
#   1. pad to at least 10296 x 7020 (height x width),
#   2. resize to 396 x 264,
#   3. normalize channels with the standard ImageNet mean / std,
#   4. convert to a torch tensor.
val_pipeline_steps = [
    A.PadIfNeeded(min_height=10296, min_width=7020),
    A.Resize(height=396, width=264),
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
    ToTensorV2(),
]
val_transform = A.Compose(val_pipeline_steps)

In [None]:
# test_files is a set, whose iteration order for strings can differ between
# kernel runs; sort it so sample indices are reproducible.
# img_size=True makes each item carry the original (height, width), which is
# needed to map predictions back to the source resolution.
ds = ImageDataset(sorted(test_files), val_transform, img_size=True)

In [None]:
# Canvas size used by the validation pipeline. Must stay in sync with the
# A.PadIfNeeded(min_height=10296, min_width=7020) step of val_transform.
PAD_HEIGHT, PAD_WIDTH = 10296, 7020

model.eval()

# Run inference on whatever device the model lives on (CPU if it has no
# parameters), so inputs and weights never end up on different devices.
first_param = next(model.parameters(), None)
device = first_param.device if first_param is not None else torch.device('cpu')

with torch.no_grad():
    for i in range(len(ds)):
        file_name = ds.image_names[i]
        img, _, (h, w) = ds[i]  # the ground-truth mask is not needed here

        # The dataset yields one unbatched (C, H, W) tensor; add the batch axis.
        logits = model(img.unsqueeze(0).to(device))

        # Assumes the model emits one logit map per image, shaped (1, 1, H, W)
        # or (1, H, W); either way indexing [0] leaves a 2-D map.
        probabilities = torch.sigmoid(logits.squeeze(1))[0]
        predicted_mask = (probabilities >= 0.5).cpu().numpy().astype(np.uint8)

        # Undo val_transform. PadIfNeeded leaves an already-large side
        # untouched, so the padded canvas is at least as big as the original.
        canvas_h, canvas_w = max(h, PAD_HEIGHT), max(w, PAD_WIDTH)

        # cv2.resize takes dsize as (width, height); nearest-neighbour keeps
        # the mask strictly binary.
        upscaled = cv2.resize(
            predicted_mask, (canvas_w, canvas_h), interpolation=cv2.INTER_NEAREST
        )

        # Centre crop, mirroring PadIfNeeded's default centred padding
        # (the extra pixel of an odd pad goes to the bottom/right).
        top, left = (canvas_h - h) // 2, (canvas_w - w) // 2
        restored = upscaled[top:top + h, left:left + w]

        sparse_mask = scipy.sparse.csr_matrix(restored)
        scipy.sparse.save_npz(f'prima/pred_mask_{file_name}.npz', sparse_mask)