In [6]:
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
import cv2
import scipy
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from collections import defaultdict
import time
import datetime

from sklearn.metrics import f1_score

In [2]:
class ImageDataset(Dataset):
    def __init__(self, image_names, transform, img_size=False):
        self.image_names = image_names
        self.transform = transform
        self.img_size = img_size

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img = cv2.imread(f'prima/{self.image_names[idx]}.tif')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = scipy.sparse.load_npz(f'mask/mask_{self.image_names[idx]}.npz').todense()
        
        transformed = self.transform(image=img, mask=mask)
        
        if self.img_size:
            h, w = img.shape[:2]
            return transformed["image"], h, w
        return transformed["image"], transformed["mask"]

In [3]:
test_files = ['00008228',
 '00322597',
 '00008338',
 '00008064',
 '00322598',
 '00325451',
 '00008142',
 '00325452',
 '00008154',
 '00008342']

In [4]:
class Encoder(nn.Module):
    def __init__(self, channels=[3, 40, 60, 120, 160, 240], kernel=5, padding=2, pool_kernel=2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels[0], channels[1], kernel, padding=padding)
        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel, padding=padding)
        self.conv3 = nn.Conv2d(channels[2], channels[3], kernel, padding=padding)
        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel, padding=padding)
        self.conv5 = nn.Conv2d(channels[4], channels[5], kernel, padding=padding)
        
        self.relu = nn.ReLU()
        # setting stride to equal kernel
        self.pool = nn.MaxPool2d(pool_kernel, stride=pool_kernel)
    
    def forward(self, x):
        x = self.conv2(self.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.conv4(self.relu(self.conv3(x)))
        x = self.pool(x)
        
        return self.relu(self.conv5(x))
    
class Decoder(nn.Module):
    def __init__(self, channels=[240, 120, 60, 2, 1], kernel=5, padding=2, mid_kernel=2):
        super().__init__()
        self.deconv1 = nn.ConvTranspose2d(channels[0], channels[1], kernel, padding=padding)
        # kernel and stride to match the pool layer in the encoder
        self.deconv2 = nn.ConvTranspose2d(channels[1], channels[2], mid_kernel, stride=mid_kernel)
        self.deconv3 = nn.ConvTranspose2d(channels[2], channels[3], mid_kernel, stride=mid_kernel)
        # for generating output (out channel is 1 mask is one layer)
        self.deconv4 = nn.ConvTranspose2d(channels[3], channels[4], kernel, padding=padding)
    
        self.relu = nn.ReLU()
      
    def forward(self, x):
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        return self.deconv4(self.deconv3(x)).squeeze()

class WickUnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [5]:
model = torch.load('rotate_noise_epoch_90.pt', map_location=torch.device('cpu'))

In [6]:
val_transform = A.Compose(
    [
         A.PadIfNeeded(min_height=10296, min_width=7020),
         A.Resize(396, 264),
         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
         ToTensorV2()
    ]
)

In [7]:
ds = ImageDataset(list(test_files), val_transform, img_size=True)
val_dl = DataLoader(ds, batch_size = 1, shuffle=False)

In [31]:
model.eval()
    
with torch.no_grad():
    for i,(img, h,w) in enumerate(val_dl):
        pred = model(img)
        
        probabilities = torch.sigmoid(pred)
        predicted_masks = (probabilities >= 0.5).float().detach().numpy()

        aug = A.Compose(
            [
                A.Resize(10296, 7020),
                A.CenterCrop(h, w)
            ]
        )
        del probabilities
        transformed = aug(image=predicted_masks)
        
        sparse_mask = scipy.sparse.csr_matrix(transformed['image'])
        scipy.sparse.save_npz(f'mask/pred_mask_{test_files[i]}.npz', sparse_mask)

In [8]:
f1 = 0
for file in test_files:
    y = scipy.sparse.load_npz(f'prima/mask_{file}.npz').todense()
    pred = scipy.sparse.load_npz(f'Archive/pred_mask_{file}.npz').todense()
    f1 += f1_score(np.array(y).flatten(),np.array(pred).flatten().astype(int))
f1 = f1/len(test_files)    
f1

0.18110487719112883