In [1]:
#imports
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from torchvision.io import read_image, ImageReadMode
import segmentation_models_pytorch as smp
import time
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


ValueError: mutable default <class 'timm.models.maxxvit.MaxxVitConvCfg'> for field conv_cfg is not allowed: use default_factory

In [2]:
# constants
# Directory names holding the input images and their masks; SegmentationDataset
# pairs each image with the mask that has the same file name.
labels = "labels"
images = "images"
channels = 3          # input channels fed to the U-Net
rate_learning = 1e-3  # Adam learning rate
epochs = 300
bs = 10               # batch size
k_prop = 0.8          # fraction of samples used for training
wd = 1e-3             # Adam weight decay
h = 256               # resize height
w = 192               # resize width
SEED = 42             # seeds torch/numpy so split, init and shuffling are repeatable

# Per-batch training losses (appended by run_model).
loss_list = []

# Per-epoch metric buffers, one slot per epoch.
total_acc_train = np.zeros(epochs)
total_acc_test = np.zeros(epochs)
total_dice_train = np.zeros(epochs)
total_dice_test = np.zeros(epochs)

# Seed before any dataset split / model construction happens.
torch.manual_seed(SEED)
np.random.seed(SEED)

# cuda
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# create model
def model_create():
    """Build the network together with its optimiser, loss and metric.

    Returns:
        A 4-tuple ``(model, optimizer, loss, accuracy)``:
        - a segmentation_models_pytorch ``Unet`` with an efficientnet-b1
          encoder, ``channels`` input channels and one output channel with no
          final activation (raw logits), moved to ``device``;
        - ``optim.Adam`` over the model parameters using ``rate_learning``
          and weight decay ``wd``;
        - ``nn.BCEWithLogitsLoss()``;
        - the ``dice_score`` function, used as the training metric.
    """
    net = smp.Unet(
        encoder_name='efficientnet-b1',
        in_channels=channels,
        classes=1,
        activation=None,
    )
    net = net.to(device)
    adam = optim.Adam(net.parameters(), lr=rate_learning, weight_decay=wd)
    criterion = nn.BCEWithLogitsLoss()
    metric = dice_score
    return net, adam, criterion, metric

In [4]:
class SegmentationDataset(Dataset):
    """Image / mask pairs read from two directories.

    The mask for ``<image_dir>/<name>`` is read from ``<label_dir>/<name>``
    (same file name in both directories).

    Args:
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the mask tensor.
        image_dir: directory of input images; defaults to the global ``images``.
        label_dir: directory of masks; defaults to the global ``labels``.
    """

    def __init__(self, transform=None, target_transform=None, image_dir=None, label_dir=None):
        self.image_dir = images if image_dir is None else image_dir
        self.label_dir = labels if label_dir is None else label_dir
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sorting makes indices (and therefore the random split) reproducible.
        self.images = sorted(os.listdir(self.image_dir))
        self.labels = sorted(os.listdir(self.label_dir))

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th image file.

        The image is decoded as 3-channel RGB (so grayscale or RGBA files still
        match the model's 3 input channels); the mask as 1-channel grayscale.
        """
        name = self.images[index]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.label_dir, name)
        image = read_image(img_path, mode=ImageReadMode.RGB)
        label = read_image(mask_path, mode=ImageReadMode.GRAY)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

In [5]:
# splitting
def splitting(data, train_prop=None, generator=None):
    """Randomly split ``data`` into a train subset and a test subset.

    Args:
        data: any sized dataset accepted by ``torch.utils.data.random_split``.
        train_prop: fraction of samples placed in the training subset;
            defaults to the global ``k_prop``.
        generator: optional ``torch.Generator`` for a reproducible split;
            defaults to torch's global default generator (the same generator
            ``random_split`` uses when none is given).

    Returns:
        ``(train, test)`` subsets; ``train`` has ``int(train_prop * len(data))``
        samples and ``test`` the remainder.
    """
    prop = k_prop if train_prop is None else train_prop
    length = len(data)
    train_length = int(prop * length)
    test_length = length - train_length
    if generator is None:
        generator = torch.default_generator

    train, test = torch.utils.data.random_split(
        data, [train_length, test_length], generator=generator
    )
    return train, test


# dataload function
def dataload(train, test, batch_size=None):
    """Wrap the train and test datasets in DataLoaders.

    Training batches are reshuffled every epoch. The test loader keeps a fixed
    order: evaluation gains nothing from shuffling, and a stable order makes
    the displayed test predictions comparable across epochs.

    Args:
        train: training dataset (e.g. a subset returned by ``splitting``).
        test: test dataset.
        batch_size: samples per batch; defaults to the global ``bs``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    size = bs if batch_size is None else batch_size
    train_loader = DataLoader(train, batch_size=size, shuffle=True)
    test_loader = DataLoader(test, batch_size=size, shuffle=False)
    return train_loader, test_loader

In [6]:
# Dice coefficient (used as the training metric)
def dice_score(preds, labels, threshold=0.5, eps=1e-7):
    """Dice coefficient between thresholded predictions and a target mask.

    Args:
        preds: raw model logits; a sigmoid is applied here before thresholding.
        labels: target mask broadcastable to ``preds``; presumably values in
            {0, 1} (TODO confirm against the mask transform).
        threshold: probability above which a pixel is predicted foreground.
        eps: added to the denominator so that an empty prediction and an empty
            target give 0 instead of dividing by zero.

    Returns:
        0-d tensor ``2 * sum(P * L) / (sum(P + L) + eps)``, computed over the
        whole batch at once (one global score, not a per-image mean).
    """
    probs = torch.sigmoid(preds)
    binary = (probs > threshold).float()
    intersection = (binary * labels).sum()
    total = (binary + labels).sum()
    return (2 * intersection) / (total + eps)
    

In [7]:
def _train_one_epoch(model, optimizer, trainloader, loss_func, accuracy):
    """Run one optimisation pass over ``trainloader``.

    Appends every batch loss to the global ``loss_list`` and returns the
    per-batch metric values as Python floats.
    """
    model.train()
    batch_scores = []
    for data, labels in trainloader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(data)

        score = accuracy(logits, labels)
        loss = loss_func(logits, labels)

        loss.backward()
        optimizer.step()

        loss_list.append(loss.item())
        batch_scores.append(score.item())
    return batch_scores


def _show_test_predictions(model, testloader):
    """Open an image viewer for every test prediction (after sigmoid) and its mask."""
    to_pil = transforms.ToPILImage()
    model.eval()
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for data, labels in testloader:
            probs = torch.sigmoid(model(data.to(device))).cpu()
            for pred, mask in zip(probs, labels.cpu()):
                to_pil(pred).show()
                to_pil(mask).show()


def run_model(model, optim, trainloader, testloader, loss_func, accuracy, cnt, show_every=100):
    """Train ``model`` for one epoch and periodically display test predictions.

    Args:
        model: segmentation network returning raw logits.
        optim: optimiser over ``model``'s parameters.
        trainloader: yields ``(images, masks)`` batches for training.
        testloader: yields ``(images, masks)`` batches used for display.
        loss_func: callable ``(logits, masks) -> scalar loss tensor``.
        accuracy: callable ``(logits, masks) -> scalar metric tensor``.
        cnt: epoch counter; predictions are displayed when
            ``cnt % show_every == 0``.
        show_every: display period in epochs; 0/None disables display.

    Returns:
        ``(model, train_dice, cnt + 1)`` where ``train_dice`` is the mean of the
        per-batch metric over this epoch (NaN if ``trainloader`` is empty).

    Side effects:
        Appends each batch loss to the global ``loss_list``; may open image
        viewer windows; leaves ``model`` in eval mode.
    """
    batch_scores = _train_one_epoch(model, optim, trainloader, loss_func, accuracy)
    train_dice = np.array(batch_scores).mean()

    model.eval()
    # The test pass only produces displayed images, so skip it on other epochs.
    if show_every and cnt % show_every == 0:
        _show_test_predictions(model, testloader)

    return model, train_dice, cnt + 1

In [9]:
def run():
    """Build data, model and optimiser, then train for ``epochs`` epochs.

    The mean training Dice of each epoch is stored in the global
    ``total_dice_train`` array, and the last value is printed at the end.
    """
    # Images: uint8 -> float32 in [0, 1], resize, then normalise.
    image_transform = transforms.Compose([
        transforms.ConvertImageDtype(dtype=torch.float32),
        transforms.Resize([h, w]),
        transforms.Normalize(mean=[0.5], std=[0.25]),
    ])

    # Masks are targets for BCEWithLogitsLoss and the Dice metric, so they
    # must stay in [0, 1]: no normalisation, and nearest-neighbour resizing
    # so resized masks stay binary.
    # NOTE(review): assumes masks are stored as 0/255 so ConvertImageDtype
    # maps them to 0/1 — confirm against the label files.
    mask_transform = transforms.Compose([
        transforms.ConvertImageDtype(dtype=torch.float32),
        transforms.Resize([h, w], interpolation=transforms.InterpolationMode.NEAREST),
    ])

    nails = SegmentationDataset(image_transform, mask_transform)
    train, test = splitting(nails)
    train_loader, test_loader = dataload(train, test)

    model, optimizer, loss, dice = model_create()

    cnt = 0
    for i in range(epochs):
        model, train_dice, cnt = run_model(model, optimizer, train_loader, test_loader, loss, dice, cnt)
        total_dice_train[i] = train_dice

    if epochs:
        print(f"final train dice: {total_dice_train[epochs - 1]:.4f}")
        

In [None]:
# Entry point: build the dataset, split it, create the loaders and model, then train for `epochs` epochs.
run()