In [52]:
#imports
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from torchvision.io import read_image, ImageReadMode
import segmentation_models_pytorch as smp

In [53]:
#constants
labels = "labels"
images = "images"
channels = 3
rate_learning = 1e-3
epochs = 5000
bs = 10
k_prop = 0.8
wd = 1e-3
h  = 256
w = 192
loss_list = []

total_acc_train = np.zeros(epochs)
total_acc_test = np.zeros(epochs)
total_dice_train = np.zeros(epochs)
total_dice_test = np.zeros(epochs)
#cuda
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [54]:
#create model
def model_create():
    accuracy = pixel_accuracy
    model = smp.Unet(encoder_name='resnet34', in_channels=channels, classes=1, activation=None).to(device)
    loss = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=rate_learning, weight_decay=wd)
    return model, optimizer, loss, accuracy

In [55]:
class SegmentationDataset(Dataset):
    def __init__(self, transform=None, target_transform=None):
        self.images  = os.listdir(images)
        self.labels = os.listdir(labels)
        
        self.transform  = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path    = os.path.join(images, self.images[index])
        mask_path   = os.path.join(labels, self.images[index])
        image        = (read_image(img_path))
        label        = (read_image(mask_path, mode=ImageReadMode.GRAY))
        
        
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
            label = torch.cat([label], dim=0)
        
        return image, label

In [56]:
#splitting
def splitting(data):
    length = data.__len__()
    test_length = length - int(k_prop * length)
    train_length = int(k_prop * length)

    (train, test) = torch.utils.data.random_split(data, [train_length, test_length])
    return train, test


# dataload function
def dataload(train, test):
    train_loader = DataLoader(train, batch_size=bs, shuffle=True)
    test_loader = DataLoader(test, batch_size=bs, shuffle=True )
    return train_loader, test_loader

In [57]:
#pixel accuracy
def pixel_accuracy(preds, labels):
    num_correct = 0
    num_pixels  = 0
    dice_score  = 0
    preds = torch.sigmoid(preds)
    preds = (preds > 0.5).float()
    num_correct += (preds == labels).sum()
    num_pixels += torch.numel(preds)
    dice_score += (2 * (preds * labels).sum())/((preds + labels).sum() + 1e-7)
    return num_correct/num_pixels, dice_score
    

In [58]:
def run_model(model, optim, trainloader, testloader, loss_func, accuracy):
    train_acc = []
    train_dice = []
    
    test_dice = []
    test_acc = []
    
    for data, labels in trainloader:
        model.train()
        data, labels = data.to(device), labels.to(device)
        optim.zero_grad()
        target = model(data)

        acc, dice = accuracy(target, labels)
        
        loss = loss_func(target, labels)
        
        loss.backward()
        optim.step()
        
        loss_list.append(loss.item())
        
        train_acc.append(acc.cpu())
        
        train_dice.append(dice.cpu())
        

    train_acc = (np.array(train_acc)).mean()
    
    train_dice = (np.array(train_dice)).mean()

    for data, labels in testloader:
        model.eval()
        data, labels = data.to(device), labels.to(device)

        target = model(data)
        
        acc, dice = accuracy(target, labels)

        test_acc.append(acc.cpu())
        
        test_dice.append(dice.cpu())

    test_acc = (np.array(test_acc)).mean()
    
    test_dice = (np.array(test_dice)).mean()

    return model, test_acc, train_acc, train_dice, test_dice

In [59]:
def run():
    
    train_transform = transforms.Compose([
    transforms.ConvertImageDtype(dtype=torch.float32),
    transforms.Resize([h, w]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(mean=[0.5], std=[0.25])
    ])
    
    test_transform = transforms.Compose([
    transforms.ConvertImageDtype(dtype=torch.float32),
    transforms.Resize([h, w]),
    ])
    
    nails = SegmentationDataset(train_transform, test_transform)
    
    (train, test) = splitting(nails)
    
    train_loader, test_loader = dataload(train, test)

    model, optimizer, loss, acc = model_create()

    total_acc = np.zeros(epochs)
    class_total_acc = np.zeros(epochs)
    for i in range(epochs):
        model, test_acc, train_acc, train_dice, test_dice = run_model(model, optimizer, train_loader, test_loader, loss, acc)
        print(i, test_acc, train_acc, train_dice, test_dice)
        total_acc_train[i] = train_acc
        total_acc_test[i] = test_acc
        total_dice_train[i] = train_dice
        total_dice_test[i] = test_dice
        

In [None]:
if __name__=="__main__":
    run()

0 0.046752933 0.14076702 0.075209334 0.0817077
1 0.12455954 0.40323204 0.08262907 0.2237002
2 0.09913026 0.75579184 0.08667944 0.12822677
3 0.35000917 0.89999515 0.15103476 0.091440275
4 0.63190514 0.9266337 0.13243793 0.0996043
5 0.91302395 0.9304908 0.14334193 0.068811245
6 0.92129517 0.93237513 0.040005736 0.015035783
7 0.89417326 0.9459311 0.032233186 5.9371033e-08
8 0.9519328 0.94825 0.013098863 0.0038018161
9 0.92818 0.956211 0.015192983 0.0006762561
10 0.92552596 0.9322758 0.0069163926 0.00042549934
11 0.9020946 0.9407194 0.0072894716 0.0
12 0.95336 0.95800257 0.0018264897 5.266209e-05
13 0.95336 0.9403756 0.0025400205 5.2662082e-05
14 0.9255992 0.940647 0.05762433 0.00019806773
15 0.8953156 0.9582903 0.048884086 6.214954e-05
16 0.9255992 0.95954436 0.0011509255 0.0
17 0.92803955 0.9573165 0.0019332909 0.00425291
18 0.93112695 0.9553031 0.0035524103 0.002812219
19 0.9287486 0.9352161 0.023338798 0.12434309
20 0.89912415 0.95482063 0.37852898 0.20996657
21 0.9470856 0.9604977 0.3

In [None]:
plt.plot(total_acc_train)
plt.plot(total_acc_test)
plt.show()

In [None]:
plt.plot(total_dice_train)
plt.plot(total_dice_test)
plt.show()