In [None]:
import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR

In [None]:
class TwoConvLayers(nn.Module):
    """Two stacked 3x3 convolutions, each followed by ReLU and then BatchNorm.

    Spatial size is preserved (``padding=1``); the first convolution maps
    ``in_channels`` -> ``out_channels`` and the second keeps ``out_channels``.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Build conv -> ReLU -> BN twice; the first conv changes the channel
        # count, the second one works at out_channels -> out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(in_channels=stage_in, out_channels=out_channels,
                          kernel_size=3, padding=1, bias=False),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
            ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv/ReLU/BN stack."""
        out = self.model(x)
        return out

In [None]:
class Encoder(nn.Module):
    """U-Net contracting step: a ``TwoConvLayers`` block followed by 2x2 max pooling.

    ``forward`` returns ``(pooled, features)``: the pooled tensor feeds the
    next stage, while ``features`` (the pre-pooling activations) is the skip
    connection consumed by the matching decoder.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = TwoConvLayers(in_channels=in_channels, out_channels=out_channels)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.block(x)
        return self.max_pool(features), features

In [None]:
class Decoder(nn.Module):
    """U-Net expanding step: upsample, concatenate the skip connection, refine.

    The transposed convolution doubles the spatial size and maps
    ``in_channels`` -> ``out_channels``. The result is concatenated on the
    channel axis with the skip tensor ``y``, which is assumed to carry
    ``out_channels`` channels (as produced by the matching ``Encoder``), so
    the ``TwoConvLayers`` block sees ``in_channels`` channels again.

    Args (forward):
        x: deeper feature map to upsample.
        y: skip-connection feature map; its spatial size is the target size.

    Returns:
        The refined tensor with ``out_channels`` channels and ``y``'s height/width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.transpose = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        self.block = TwoConvLayers(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x, y):
        x = self.transpose(x)
        # If the encoder input had an odd height/width, max-pooling floored it
        # and the upsampled tensor is smaller than the skip. Pad (or crop, for
        # negative differences) so torch.cat sees matching spatial sizes.
        diff_h = y.size(-2) - x.size(-2)
        diff_w = y.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        u = torch.cat([x, y], dim=1)
        return self.block(u)

In [None]:
class UNet(nn.Module):
    """Four-level U-Net producing per-pixel logits.

    Encoder widths are 64/128/256/512 with a 1024-channel bottleneck; each
    decoder consumes the skip connection from the encoder at the same depth.
    The final 1x1 convolution maps 64 channels to ``num_classes`` outputs
    with no activation applied.

    Args:
        in_channels: channels of the input image (3 for RGB).
        num_classes: number of output channels.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them.
        self.enc_block1 = Encoder(in_channels=in_channels, out_channels=64)
        self.enc_block2 = Encoder(in_channels=64, out_channels=128)
        self.enc_block3 = Encoder(in_channels=128, out_channels=256)
        self.enc_block4 = Encoder(in_channels=256, out_channels=512)

        self.bottleneck = TwoConvLayers(in_channels=512, out_channels=1024)

        self.dec_block1 = Decoder(in_channels=1024, out_channels=512)
        self.dec_block2 = Decoder(in_channels=512, out_channels=256)
        self.dec_block3 = Decoder(in_channels=256, out_channels=128)
        self.dec_block4 = Decoder(in_channels=128, out_channels=64)

        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        encoders = (self.enc_block1, self.enc_block2, self.enc_block3, self.enc_block4)
        decoders = (self.dec_block1, self.dec_block2, self.dec_block3, self.dec_block4)

        skips = []
        for encoder in encoders:
            x, skip = encoder(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # The deepest skip pairs with the first decoder, and so on upward.
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.out(x)

In [None]:
class NailsImageDataset(Dataset):
    """Paired image/mask dataset read from ``<folder_path>/images`` and ``<folder_path>/labels``.

    Each name in ``valid_images`` must exist in both sub-folders. The image is
    loaded as RGB and the mask as single-channel ``'L'``.

    Args:
        folder_path: root directory containing ``images/`` and ``labels/``.
        valid_images: file names to use (duplicates are allowed, e.g. to
            oversample when random augmentations are applied).
        transform_img: optional callable applied to BOTH the image and the
            mask. The torch RNG is re-seeded with the same value before each
            call so that random transforms drawing from the torch RNG pick
            identical parameters for the pair.
            NOTE(review): any interpolation in this pipeline (e.g. a bilinear
            ``Resize``) is applied to the mask too and may leave non-binary
            values at mask edges — confirm this is acceptable.
    """

    def __init__(self, folder_path, valid_images, transform_img=None):
        self.folder_path = folder_path
        self.valid_images = valid_images
        self.transform_img = transform_img

        images_path = os.path.join(self.folder_path, 'images')
        masks_path = os.path.join(self.folder_path, 'labels')
        self.images = [os.path.join(images_path, image) for image in self.valid_images]
        self.masks = [os.path.join(masks_path, image) for image in self.valid_images]

    def __getitem__(self, index):
        image_path, mask_path = self.images[index], self.masks[index]
        # Close the file handles deterministically; convert() returns a loaded
        # copy that stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as msk:
            mask = msk.convert('L')

        if self.transform_img is not None:
            seed = np.random.randint(2147483647)
            image = self._apply_seeded(self.transform_img, image, seed)
            mask = self._apply_seeded(self.transform_img, mask, seed)

        return image, mask

    @staticmethod
    def _apply_seeded(transform, sample, seed):
        """Apply ``transform`` to ``sample`` with the CPU torch RNG seeded to ``seed``.

        ``fork_rng`` restores the global torch RNG state on exit, so the
        re-seeding does not clobber randomness used elsewhere in the process.
        """
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            return transform(sample)

    def __len__(self):
        return len(self.valid_images)



In [None]:


IMAGE_SIZE = (512, 512)

# Training pipeline: resize, then random flips / rotation / affine shear for
# augmentation, then tensor conversion.
train_transformations = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.RandomAffine(10, shear=(-5, 5)),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: deterministic resize and tensor conversion only.
test_transformations = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)



In [None]:
# Dataset root; set NAILS_DATA_DIR to point elsewhere instead of editing
# this absolute path.
data_folder = os.environ.get('NAILS_DATA_DIR', '/home/matv864/it/AI_work/data/nails')
SPLIT_SEED = 42

# sorted() removes the filesystem-dependent order of os.listdir, and a
# dedicated seeded RNG makes the train/val split reproducible across runs
# without touching the global `random` state.
images = sorted(os.listdir(os.path.join(data_folder, 'images')))
random.Random(SPLIT_SEED).shuffle(images)
split_index = int(len(images) * 0.15)
# Duplicate the training list: random augmentations are applied on every
# draw, so the two copies of an image are practically never identical.
train_images = images[split_index:] * 2
# Validation is not duplicated, so each held-out image is counted once.
val_images = images[:split_index]

In [None]:
batch_size = 4

# Validation must be deterministic: use the resize-only test pipeline (no
# random augmentation) and a fixed iteration order.
val_dataset = NailsImageDataset(data_folder, val_images, transform_img=test_transformations)
train_dataset = NailsImageDataset(data_folder, train_images, transform_img=train_transformations)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
def visualize_train_data(image, mask):
    """Display an image next to its mask.

    If ``mask`` is 4-D (a batch), only the first sample of both ``image`` and
    ``mask`` is shown. A mask with a single leading channel is squeezed to
    2-D. The image is expected channel-first (it is permuted to H, W, C) and is
    clipped to [0, 1] before plotting.
    """
    if mask.dim() == 4:
        # batched input: keep only the first sample
        image, mask = image[0], mask[0]

    if mask.dim() == 3 and mask.size(0) == 1:
        mask = mask.squeeze(0)

    image_np = np.clip(image.permute(1, 2, 0).cpu().numpy(), 0, 1)
    mask_np = mask.cpu().numpy()

    plt.figure(figsize=(12, 6))

    panels = ((image_np, 'Image', None), (mask_np, 'Mask', 'gray'))
    for position, (data, title, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis('off')

    plt.show()

In [None]:
# Preview only the first few validation batches (first sample of each);
# iterating the whole loader would emit one figure per batch.
N_PREVIEW_BATCHES = 3
for batch_idx, (image, mask) in enumerate(val_loader):
    if batch_idx >= N_PREVIEW_BATCHES:
        break
    visualize_train_data(image, mask)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits: ``1 - mean_over_batch(dice)``.

    Per sample, ``dice = (2 * sum(p * t) + s) / (sum(p) + sum(t) + s)``, where
    ``p`` are sigmoid probabilities, ``t`` the targets and ``s`` the
    smoothing term. For ``p`` and ``t`` in [0, 1] the score lies in [0, 1].

    Args:
        smooth_coef: additive smoothing ``s``. It keeps the ratio defined when
            both prediction and target are empty; that case scores 1 (loss 0).
    """

    def __init__(self, smooth_coef=1):
        super().__init__()
        self.smooth_coef = smooth_coef

    def forward(self, logits, targets):
        num = targets.size(0)
        # reshape (unlike view) also accepts non-contiguous tensors
        probs = torch.sigmoid(logits).reshape(num, -1)
        truth = targets.reshape(num, -1)
        intersection = (probs * truth).sum(1)
        total = probs.sum(1) + truth.sum(1)
        # Add smoothing after doubling the intersection. Doubling it as well
        # made an empty-target/empty-prediction sample score 2 (loss -1).
        score = (2 * intersection + self.smooth_coef) / (total + self.smooth_coef)
        return 1 - score.sum() / num

In [None]:
criterion_1 = nn.BCEWithLogitsLoss()
criterion_2 = DiceLoss()
EPOCH = 40
lr_rate = 0.0001
# Probability cut-off for binarising predictions
THRESHOLD = 0.5
SEED = 42
# Fall back to CPU so the notebook still runs on machines without CUDA
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
best_loss = float('inf')
# Seed before building the model so weight initialisation is reproducible
torch.manual_seed(SEED)
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate)
# Multiply the learning rate by 0.1 every 20 scheduler.step() calls
scheduler = StepLR(optimizer, step_size=20, gamma=0.1)

In [None]:
def _count_correct_pixels(logits, mask, threshold=THRESHOLD):
    """Return how many pixels the thresholded prediction gets right.

    ``logits`` are raw model outputs (the losses are BCEWithLogitsLoss and a
    sigmoid-based Dice), so they go through a sigmoid before being compared
    with the probability ``threshold``. The mask is binarised at 0.5 so soft
    edge values count as 0 or 1 instead of never matching.
    """
    pred = torch.sigmoid(logits) > threshold
    target = mask > 0.5
    return (pred == target).sum().item()


for epoch in range(EPOCH):
    print('\nEpoch {}/{}'.format(epoch + 1, EPOCH))
    print('=' * 100)

    # ---- training ----
    model.train()
    total_train_loss = 0.0
    correct_train = 0
    total_train_pixels = 0
    for image, mask in train_loader:
        image, mask = image.to(device), mask.to(device)
        optimizer.zero_grad()
        outputs = model(image)

        train_loss = criterion_1(outputs, mask) + criterion_2(outputs, mask)
        total_train_loss += train_loss.item()

        correct_train += _count_correct_pixels(outputs.detach(), mask)
        total_train_pixels += mask.numel()
        train_loss.backward()
        optimizer.step()

    average_train_loss = total_train_loss / len(train_loader)
    average_train_accuracy = correct_train / total_train_pixels
    scheduler.step()

    # ---- validation ----
    model.eval()
    total_val_loss = 0.0
    correct_val = 0
    total_val_pixels = 0

    with torch.no_grad():
        for image, mask in val_loader:
            image, mask = image.to(device), mask.to(device)
            outputs = model(image)

            val_loss = criterion_1(outputs, mask) + criterion_2(outputs, mask)
            total_val_loss += val_loss.item()

            correct_val += _count_correct_pixels(outputs, mask)
            total_val_pixels += mask.numel()

    average_val_loss = total_val_loss / len(val_loader)
    average_val_accuracy = correct_val / total_val_pixels

    # Checkpoint whenever the validation loss improves
    if average_val_loss < best_loss:
        best_loss = average_val_loss
        torch.save(model.state_dict(), 'best_loss_unet.pt')

    # Print loss and accuracy
    print(f"Train Loss: {average_train_loss:.4f}, Train Accuracy: {average_train_accuracy:.4f}")
    print(f"Val Loss: {average_val_loss:.4f}, Val Accuracy: {average_val_accuracy:.4f}")
