In [1]:
import torch
import torch.nn as nn

In [146]:
class Encoder(nn.Module):
    def __init__(self, channels=[3, 40, 60, 120, 160, 240], kernel=5, padding=2, pool_kernel=2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels[0], channels[1], kernel, padding=padding)
        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel, padding=padding)
        self.conv3 = nn.Conv2d(channels[2], channels[3], kernel, padding=padding)
        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel, padding=padding)
        self.conv5 = nn.Conv2d(channels[4], channels[5], kernel, padding=padding)
        
        self.relu = nn.ReLU()
        # setting stride to equal kernel
        self.pool = nn.MaxPool2d(pool_kernel, stride=pool_kernel)
    
    def forward(self, x):
        x = self.conv2(self.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.conv4(self.relu(self.conv3(x)))
        x = self.pool(x)
        
        return self.relu(self.conv5(x))
    
class Decoder(nn.Module):
    def __init__(self, channels=[240, 120, 60, 2, 1], kernel=5, padding=2, mid_kernel=2):
        super().__init__()
        self.deconv1 = nn.ConvTranspose2d(channels[0], channels[1], kernel, padding=padding)
        # kernel and stride to match the pool layer in the encoder
        self.deconv2 = nn.ConvTranspose2d(channels[1], channels[2], mid_kernel, stride=mid_kernel)
        self.deconv3 = nn.ConvTranspose2d(channels[2], channels[3], mid_kernel, stride=mid_kernel)
        # for generating output (out channel is 1 mask is one layer)
        self.deconv4 = nn.ConvTranspose2d(channels[3], channels[4], kernel, padding=padding)
    
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        return self.deconv4(self.deconv3(x)).squeeze()

class WickUnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))

# model training

In [128]:
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
import cv2
import scipy
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

In [147]:
class ImageDataset(Dataset):
    def __init__(self, image_names, transform, test=False):
        self.image_names = image_names
        self.transform = transform
        self.test = test

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img = cv2.imread(f'prima/{self.image_names[idx]}.tif')
        mask = scipy.sparse.load_npz(f'prima/mask_{self.image_names[idx]}.npz').todense()
        
        transformed = self.transform(image=img, mask=mask)
        
        if self.test:
            img_size = img.shape[:2]
            return transformed["image"], transformed["mask"], img_size
        return transformed["image"], transformed["mask"]

In [148]:
train_transform = A.Compose(
    [
        A.Resize(392, 260),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

val_transform = A.Compose(
    [
         A.Resize(392, 260),
         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
         ToTensorV2()
    ]
)

In [65]:
training_files = ['00008152', '00008153', '00008150', '00008148', '00325453',
'00008063', '00008230', '00322470', '00325454', '00008332',
'00008062', '00008143', '00008344', '00008141', '00008146',
'00008088', '00008346', '00325450', '00008149', '00008147',
'00325448', '00008227', '00008151', '00322599', '00322596',
'00008334', '00008336', '00008089', '00008229', '00322468',
'00008144', '00008340', '00008061', '00008084', '00325449',
'00322471', '00008140', '00322469', '00008086', '00008145']

test_files = ['00008228',
 '00322597',
 '00008338',
 '00008064',
 '00322598',
 '00325451',
 '00008142',
 '00325452',
 '00008154',
 '00008342']

In [137]:
train_files = np.random.choice(training_files, 32, replace=False)
val_files = list(set(training_files) - set(train_files))

In [149]:
train_ds = ImageDataset(train_files, train_transform)
val_ds = ImageDataset(val_files, val_transform)

In [150]:
batch_size = 16

In [151]:
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
)
val_dl = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
)

In [127]:
def train(model, dataloader, loss_fun, optimizer):
    total_loss = 0
    total_row = 0

    model.train()

    for img, mask in dataloader:
        pred = model(img)
        loss = loss_fun(pred, mask.float())

        total_loss += loss.item() * img.shape[0]
        total_row += img.shape[0]
                
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return total_loss / total_row

def validate(model, dataloader, loss_fun):
    total_loss = 0
    total_row = 0
    
    model.eval()
    
    with torch.no_grad():
        for img, mask in dataloader:
            pred = model(img)
            loss = loss_fun(pred, mask.float())

            total_loss += loss.item() * img.shape[0]
            total_row += img.shape[0]

    return total_loss / total_row

In [152]:
model = WickUnet()
optimizer = Adam(model.parameters(), lr=0.001)
loss_fun = nn.BCEWithLogitsLoss()

In [None]:
model

In [154]:
tl = train(model, train_dl, loss_fun, optimizer)
vl = validate(model, val_dl, loss_fun)

In [156]:
tl, vl

(0.7293292880058289, 0.6918942332267761)

In [159]:
img = cv2.imread('prima/00008062.tif')

In [161]:
transformed_img = val_transform(image=img)

In [165]:
output = model(transformed_img['image'].unsqueeze(0))

In [245]:
T, thresh = cv2.threshold(output_np, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

error: OpenCV(4.5.3) /private/var/folders/24/8k48jl6d249_n_qfxwsl6xvm0000gn/T/pip-req-build-gt63l4kp/opencv/modules/imgproc/src/thresh.cpp:1557: error: (-2:Unspecified error) in function 'double cv::threshold(cv::InputArray, cv::OutputArray, double, double, int)'
> THRESH_OTSU mode:
>     'src_type == CV_8UC1 || src_type == CV_16UC1'
> where
>     'src_type' is 5 (CV_32FC1)
