In [1]:
import os

import numpy as np
import numpy.testing as npt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg13, VGG13_Weights
from torchvision import transforms

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

import matplotlib_inline
import matplotlib.pyplot as plt

%matplotlib inline
# Ask the inline backend to render figures in the 'pdf' and 'svg' formats.
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

In [2]:
import cv2 as cv

class PhotosDataset(Dataset):
    """Segmentation dataset of ``.bmp`` images paired with ``.png`` masks.

    Every ``<name>.bmp`` found in ``images_dir`` is one sample; its mask is read
    from ``<name>.png`` in the same directory.
    """

    def __init__(self, images_dir, transforms=None):
        """
        Arguments
        ---------
        images_dir : str
            Path to the directory that holds both the images (``.bmp``) and
            their masks (``.png``). A mask has the same base name as its image.

        transforms : Optional[callable]
            Called as ``transforms(image=image, mask=mask)``; must return a mapping
            with ``"image"`` and ``"mask"`` keys. ``None`` disables transformation.
        """
        self.images_dir = images_dir
        self.transforms = transforms

        # os.listdir returns entries in arbitrary order; sort so that a given index
        # always refers to the same sample across runs and platforms.
        self.names = sorted(
            stem
            for stem, ext in map(os.path.splitext, os.listdir(images_dir))
            if ext == '.bmp'
        )

    def __len__(self):
        """Number of ``.bmp`` images discovered in ``images_dir``."""
        return len(self.names)

    @staticmethod
    def _read(path):
        """Read ``path`` with OpenCV, failing loudly when it cannot be loaded."""
        data = cv.imread(path)
        # cv.imread signals failure by returning None instead of raising.
        if data is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    def __getitem__(self, idx):
        """
        Arguments
        ---------
        idx : int
            Index of image and mask

        Returns
        -------
        (image, mask)
            ``image`` is the RGB picture; ``mask`` is the grayscale mask divided by 255
            with a trailing channel axis added. Both are passed through ``transforms``
            when it is set.

        Raises
        ------
        FileNotFoundError
            If the image or its mask cannot be read from disk.
        """
        name = self.names[idx]
        image = self._read(os.path.join(self.images_dir, name + '.bmp'))
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        mask = self._read(os.path.join(self.images_dir, name + '.png'))
        mask = (cv.cvtColor(mask, cv.COLOR_BGR2GRAY) / 255)[..., None]

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        return image, mask

In [3]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

def _normalize_and_tensorize():
    """Build fresh instances of the final steps shared by both pipelines."""
    return [A.Normalize(), ToTensorV2(transpose_mask=True)]


# Random augmentations applied only to training samples.
_train_augmentations = [
    A.HorizontalFlip(p=0.5),
    # Disabled: A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
    A.RandomCrop(width=512, height=512),
    A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
]

# Training: augmentations followed by normalization and tensor conversion.
transform_train = A.Compose(_train_augmentations + _normalize_and_tensorize())

# Evaluation: normalization and tensor conversion only, no random augmentation.
transform_test = A.Compose(_normalize_and_tensorize())

In [None]:
# Both datasets read the same directory and differ only in their transforms.
# NOTE(review): because the directory is shared, "test" metrics are measured on the
# same images the model is trained on — confirm this is intended.
IMAGES_DIR = 'img'

train_dataset = PhotosDataset(images_dir=IMAGES_DIR, transforms=transform_train)
test_dataset = PhotosDataset(images_dir=IMAGES_DIR, transforms=transform_test)

In [None]:
# Settings common to both loaders; only shuffling and last-batch handling differ.
_loader_common = dict(batch_size=2, num_workers=0)

train_data_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_common)
test_data_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **_loader_common)

In [None]:
class VGG13Encoder(torch.nn.Module):
    """Encoder assembled from the convolutional stages of a torchvision VGG13.

    Arguments
    ---------
    num_blocks : int
        Number of VGG13 stages to use as encoder blocks.
    weights
        Passed to ``vgg13(weights=...)``; defaults to the pretrained weights.
    """

    def __init__(self, num_blocks, weights=VGG13_Weights.DEFAULT):
        super().__init__()
        self.num_blocks = num_blocks

        # A pretrained VGG13 serves as the backbone.
        feature_extractor = vgg13(weights=weights).features

        # Each stage is sliced as 4 layers out of every 5, i.e. a VGG13 stage
        # without its trailing MaxPool2d. Asking for more stages than exist would
        # silently yield empty (identity) blocks, so reject that up front.
        max_blocks = len(feature_extractor) // 5
        if num_blocks > max_blocks:
            raise ValueError(
                f"num_blocks={num_blocks} exceeds the {max_blocks} stages available in VGG13"
            )

        # Slicing an nn.Sequential yields an nn.Sequential holding the selected layers.
        self.blocks = torch.nn.ModuleList()
        for idx in range(self.num_blocks):
            self.blocks.append(
                feature_extractor[5 * idx:5 * idx + 4]
            )

    def forward(self, x):
        """Return the list of per-block activations, shallowest first."""
        activations = []
        last_idx = len(self.blocks) - 1
        for idx, block in enumerate(self.blocks):
            # Apply the next encoder block.
            x = block(x)

            # Keep the activation so the decoder can use it as a skip connection.
            activations.append(x)

            # Downsample between blocks only: pooling after the last block would
            # produce a tensor that is never used.
            if idx < last_idx:
                x = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)

        return activations

class DecoderBlock(torch.nn.Module):
    """One decoder stage: upsample, merge with the skip connection, refine.

    Arguments
    ---------
    out_channels : int
        Channel count of the skip connection and of the block output.
        The deeper input is expected to have ``2 * out_channels`` channels.
    """

    def __init__(self, out_channels):
        super().__init__()

        # Halves the channel count of the upsampled deeper feature map.
        self.upconv = torch.nn.Conv2d(
            in_channels=out_channels * 2, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.conv1 = torch.nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.relu = torch.nn.ReLU()

    def forward(self, down, left):
        """
        Arguments
        ---------
        down : torch.Tensor
            Output of the previous (deeper) stage, ``2 * out_channels`` channels.
        left : torch.Tensor
            Encoder skip connection, ``out_channels`` channels.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``left``.
        """
        # Upsample to the skip connection's spatial size, then convolve. Matching
        # `left` directly (instead of a fixed x2 factor) also handles odd sizes,
        # where pooling in the encoder rounded the deeper map down.
        x = torch.nn.functional.interpolate(down, size=tuple(left.shape[-2:]))
        x = self.upconv(x)

        # Merge by element-wise addition (not channel concatenation).
        x = x + left

        # Two convolutions, each followed by ReLU. Without the non-linearity in
        # between, the two convolutions would collapse into a single linear map.
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))

        return x

class Decoder(torch.nn.Module):
    """Chain of decoder blocks that folds encoder activations back up.

    Arguments
    ---------
    num_filters : int
        Output channels of the last (shallowest) decoder block.
    num_blocks : int
        Number of decoder blocks; block ``k`` counted from the shallowest end
        has ``num_filters * 2 ** k`` output channels.
    """

    def __init__(self, num_filters, num_blocks):
        super().__init__()

        # Instantiate from the narrowest block upward, then store in reverse so
        # that iteration starts with the widest (deepest) block.
        stages = [DecoderBlock(num_filters * 2 ** level) for level in range(num_blocks)]
        self.blocks = torch.nn.ModuleList(stages[::-1])

    def forward(self, acts):
        """Fold ``acts`` (shallowest first) into a single output tensor."""
        # Start from the deepest activation and walk the rest back to front,
        # using each one as the skip connection of the next block.
        up = acts[-1]
        skips = reversed(acts[:-1])
        for stage, left in zip(self.blocks, skips):
            up = stage(up, left)
        return up

class LinkNet(torch.nn.Module):
    """Segmentation network: VGG13 encoder, additive-skip decoder, 1x1 head.

    Arguments
    ---------
    num_classes : int
        Number of output channels of the final 1x1 convolution.
    num_blocks : int
        Number of encoder blocks; the decoder gets ``num_blocks - 1`` blocks.
    """

    def __init__(self, num_classes=1, num_blocks=4):
        super().__init__()
        base_filters = 64

        self.encoder = VGG13Encoder(num_blocks)
        self.decoder = Decoder(base_filters, num_blocks - 1)

        # 1x1 convolution that aggregates channels independently at every pixel.
        self.final = torch.nn.Conv2d(base_filters, num_classes, 1)

    def forward(self, x):
        """Map an input batch to per-pixel outputs with ``num_classes`` channels."""
        activations = self.encoder(x)
        decoded = self.decoder(activations)
        return self.final(decoded)

In [None]:
class IoUScore(torch.nn.Module):
    """Intersection-over-union between thresholded logits and target masks."""

    def __init__(self, threshold, reduction=None):
        """
        Arguments
        ---------
        threshold : float
            Logits strictly greater than this value count as positive predictions
        reduction : Optional[str] (None, 'mean' or 'sum')
            specifies the reduction to apply to the output:

            None: no reduction will be applied
            'mean': the sum of the output will be divided by the number of elements in the batch
            'sum':  the output will be summed.
        """
        super().__init__()

        self.threshold = threshold
        self.reduction = reduction

    @torch.no_grad()
    def forward(self, logits, true_labels):
        """
        Arguments
        ---------
        logits: torch.Tensor
            Unnormalized probability of true class. Shape: [B, ...]
        true_labels: torch.Tensor
            Mask of correct predictions. Shape: [B, ...]
        Returns
        -------
        torch.Tensor
            If reduction is 'mean' or 'sum' returns a tensor with a single element
            Otherwise, returns a tensor of shape [B].
            A sample whose prediction and mask are both empty scores 1.
        """
        dims = list(range(1, logits.ndim))
        predicted = logits > self.threshold

        # Do the arithmetic in a floating dtype: bool/integer labels would
        # otherwise break on subtraction or integer division.
        if not true_labels.is_floating_point():
            true_labels = true_labels.to(torch.float32)
        predicted = predicted.to(true_labels.dtype)

        overlap = predicted * true_labels
        intersection = torch.sum(overlap, dim=dims)
        union = torch.sum(predicted + true_labels - overlap, dim=dims)

        # An empty union (nothing predicted, nothing to find) would give 0/0 = NaN;
        # treat it as perfect agreement instead.
        has_union = union > 0
        safe_union = torch.where(has_union, union, torch.ones_like(union))
        score = torch.where(has_union, intersection / safe_union, torch.ones_like(union))

        if self.reduction == 'sum':
            score = torch.sum(score)
        elif self.reduction == 'mean':
            score = torch.mean(score)

        return score

In [None]:
class DiceLoss(torch.nn.Module):
    """Soft Dice loss: ``1 - 2 * sum(p * t) / (sum(p + t) + eps)`` per sample."""

    def __init__(self, eps=1e-7, reduction=None, with_logits=True):
        """
        Arguments
        ---------
        eps : float
            Constant added once to the denominator to avoid division by zero
        reduction : Optional[str] (None, 'mean' or 'sum')
            specifies the reduction to apply to the output:

            None: no reduction will be applied
            'mean': the sum of the output will be divided by the number of elements in the batch
            'sum':  the output will be summed.
        with_logits : bool
            If True, use additional sigmoid for inputs
        """
        super().__init__()
        self.eps = eps
        self.reduction = reduction
        self.with_logits = with_logits

    def forward(self, logits, true_labels):
        """
        Arguments
        ---------
        logits: torch.Tensor
            Unnormalized probability of true class. Shape: [B, ...]
        true_labels: torch.Tensor
            Mask of correct predictions. Shape: [B, ...]
        Returns
        -------
        torch.Tensor
            If reduction is 'mean' or 'sum' returns a tensor with a single element
            Otherwise, returns a tensor of shape [B]
        """
        # Integer cast truncates, so only label values >= 1 count as positive.
        true_labels = true_labels.to(torch.long)

        probs = torch.sigmoid(logits) if self.with_logits else logits

        # Reduce over everything except the batch axis.
        dims = list(range(1, probs.ndim))

        intersection = torch.sum(probs * true_labels, dim=dims)
        cardinality = torch.sum(probs + true_labels, dim=dims)

        # eps is added once to the summed denominator; adding it inside the sum
        # would scale the smoothing term with the number of pixels.
        loss_value = 1 - 2 * intersection / (cardinality + self.eps)

        if self.reduction == 'sum':
            loss_value = torch.sum(loss_value)
        elif self.reduction == 'mean':
            loss_value = torch.mean(loss_value)
        return loss_value

In [None]:
def generate_plot(image, mask, logits):
    """Draw a 1x4 figure: input image, true mask, raw logits, thresholded logits.

    Arguments
    ---------
    image : torch.Tensor
        Image with the channel axis first (it is permuted to channel-last).
        The per-channel scale/offset constants below are applied to undo
        normalization — presumably matching the dataset transforms; confirm.
    mask : torch.Tensor
        Target mask with a leading singleton axis.
    logits : torch.Tensor
        Model output with a leading singleton axis; binarized at 0.0.

    Returns
    -------
    matplotlib figure holding the four panels.
    """
    fig, axs = plt.subplots(1, 4, figsize=(12, 6))

    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    # Channel-last layout for imshow, then undo normalization and clamp to [0, 1].
    picture = image.permute(1, 2, 0).numpy()
    picture = np.clip((picture * channel_std) + channel_mean, 0, 1)

    true_mask = mask.squeeze(0).numpy()
    raw_logits = logits.squeeze(0).numpy()
    binarized = raw_logits > 0.0

    # (data, colormap, title) for each panel, left to right.
    panels = (
        (picture, None, 'Original'),
        (true_mask, 'gray', 'True mask'),
        (raw_logits, 'gray', 'Pred mask'),
        (binarized, 'gray', 'Binarized mask'),
    )
    for ax, (data, cmap, title) in zip(axs, panels):
        ax.imshow(data, cmap=cmap)
        ax.axis('off')
        ax.set_title(title)

    fig.tight_layout()

    return fig

In [None]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

class Runner:
    """Training / evaluation loop with scalar and image logging and checkpointing."""

    def __init__(
        self,
        model,
        n_epochs,
        loss,
        optimizer,
        train_loader,
        test_loader,
        writer,
        metrics=None,
        logging_interval=1.0,
        scheduler=None,
        model_name='model',
        log_images=None,
        image_log_interval=10
    ):
        """
        Arguments
        ---------
        model : torch.nn.Module
            Network to train.
        n_epochs : int
            Number of train/test cycles performed by `run`.
        loss : callable
            Called as ``loss(logits, labels)``; its result is back-propagated.
        optimizer
            Optimizer updating the model parameters.
        train_loader, test_loader
            Sized iterables yielding ``(images, labels)`` batches.
        writer
            Object with ``add_scalar`` / ``add_figure`` methods used for logging.
        metrics : Optional[dict]
            Mapping name -> callable ``metric(logits, labels)``.
        logging_interval : float
            Fraction of the training loader between two training log entries.
        scheduler
            Optional scheduler; stepped as ``scheduler.step(test_loss)`` each epoch.
        model_name : str
            Prefix of the checkpoint file names.
        log_images
            Stored on the instance; not used by the methods of this class.
        image_log_interval : int
            Prediction figures are logged on every ``image_log_interval``-th `test` call.
        """
        self.model = model
        self.n_epochs = n_epochs
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.metrics = metrics
        self.writer = writer
        self.logging_interval = logging_interval
        self.scheduler = scheduler
        self.model_name = model_name
        self.log_images = log_images

        self.global_step = 0
        self.activations = []

        if self.metrics is None:
            self.metrics = dict()

        self.image_log_interval = image_log_interval
        self.image_log_idx = 0

    def _model_device(self):
        """Return the device holding the model parameters (CPU if it has none)."""
        first_param = next(self.model.parameters(), None)
        if first_param is None:
            return torch.device('cpu')
        return first_param.device

    def train(self):
        """Run one pass over the training loader, logging running averages."""
        # At least 1, otherwise the modulo below would divide by zero when
        # logging_interval * len(loader) rounds down to 0.
        logging_batch = max(1, int(self.logging_interval * len(self.train_loader)))
        # Send batches to wherever the model lives rather than to a notebook global.
        target_device = self._model_device()

        running_loss = 0
        running_metrics = {metric_name: 0 for metric_name in self.metrics}

        self.model.train()  # switch network submodules to train mode, e.g. it influences on batch-norm, dropout
        for i, (images, labels) in enumerate(self.train_loader):
            images, labels = images.to(target_device), labels.to(target_device)   # send data to device
            self.optimizer.zero_grad()   # zero out grads, collected from previous batch
            logits = self.model(images)  # forward pass
            loss = self.loss(logits, labels)
            loss.backward()
            self.optimizer.step()

            # Accumulate detached values only: summing the live loss tensor would
            # keep every batch's autograd graph referenced until the next reset.
            running_loss += loss.detach()

            # Metrics are for reporting, so no graph should be built for them.
            with torch.no_grad():
                for metric_name in self.metrics:
                    running_metrics[metric_name] += self.metrics[metric_name](logits, labels)

            if i % logging_batch == logging_batch - 1:
                self.global_step += 1
                self.writer.add_scalar("Loss/train", running_loss / logging_batch, global_step=self.global_step)
                running_loss = 0

                for metric_name in self.metrics:
                    self.writer.add_scalar(f"{metric_name}/train", running_metrics[metric_name] / logging_batch, global_step=self.global_step)
                    running_metrics[metric_name] = 0

    @torch.no_grad()
    def test(self):
        """Evaluate on the test loader.

        Logs the per-batch averages of the loss and metrics, and on every
        ``image_log_interval``-th call also logs one figure per test sample.

        Returns
        -------
        (running_loss, running_metrics)
            Sums over all test batches (not averages).
        """
        target_device = self._model_device()

        running_loss = 0
        running_metrics = {metric_name: 0 for metric_name in self.metrics}

        log_idx = 0
        self.image_log_idx = (self.image_log_idx + 1) % self.image_log_interval
        self.model.eval()  # switch network submodules to test mode
        for images, labels in self.test_loader:
            images, labels = images.to(target_device), labels.to(target_device)
            logits = self.model(images)
            running_loss += self.loss(logits, labels)
            for metric_name in self.metrics:
                running_metrics[metric_name] += self.metrics[metric_name](logits, labels)

            if self.image_log_idx == 0:
                for image, label, logit in zip(images, labels, logits):
                    fig = generate_plot(image.cpu(), label.cpu(), logit.detach().cpu())
                    self.writer.add_figure(f"Images/{log_idx}", fig, global_step=self.global_step)
                    log_idx += 1

        n_batches = len(self.test_loader)

        self.writer.add_scalar("Loss/test", running_loss / n_batches, global_step=self.global_step)
        for metric_name in self.metrics:
            self.writer.add_scalar(f"{metric_name}/test", running_metrics[metric_name] / n_batches, global_step=self.global_step)

        return running_loss, running_metrics

    def run(self):
        """Full training cycle.

        Evaluates once before training, then alternates `train` and `test` for
        ``n_epochs`` epochs. Whenever the test loss improves, the model and
        optimizer states are saved to ``{model_name}_model.pt`` and
        ``{model_name}_optim.pt``.

        Returns
        -------
        (min_loss, min_metrics)
            Lowest summed test loss seen and the metrics from that evaluation.
        """
        min_loss, min_metrics = self.test()

        for epoch in tqdm(range(self.n_epochs)):
            self.train()
            loss, metrics = self.test()

            if self.scheduler is not None:
                # The scheduler is stepped with the test loss, as metric-driven
                # schedulers such as ReduceLROnPlateau expect.
                self.scheduler.step(loss)
                lr = self.scheduler.get_last_lr()[0]
                self.writer.add_scalar("Learning rate", lr, global_step=self.global_step)

            if loss < min_loss:
                min_loss = loss
                min_metrics = metrics
                torch.save(self.model.state_dict(), f"{self.model_name}_model.pt")
                torch.save(self.optimizer.state_dict(), f"{self.model_name}_optim.pt")

        return min_loss, min_metrics

In [11]:
# Network: a 3-block LinkNet with one output channel, placed on the selected device.
model = LinkNet(num_classes=1, num_blocks=3).to(device)

# Optimisation: SGD with momentum, learning rate reduced when the test loss plateaus.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10)
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean')

# Logging destination and the extra quantities tracked next to the loss.
writer = SummaryWriter('runs/LinkNet')
metrics = {
    'Dice loss': DiceLoss(reduction='mean'),
    'IoU': IoUScore(threshold=0.0, reduction='mean'),
}

runner = Runner(
    model=model,
    n_epochs=100,
    loss=loss_fn,
    optimizer=optimizer,
    train_loader=train_data_loader,
    test_loader=test_data_loader,
    writer=writer,
    metrics=metrics,
    logging_interval=1.0,
    scheduler=scheduler,
    model_name="LinkNet",
)

# Train; the result is the best test loss and the metrics recorded with it.
loss_res, metrics_res = runner.run()

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [04:37<00:00,  2.78s/it]


In [12]:
# Persist the weights the model holds after the last epoch.
# NOTE(review): Runner.run() saves its best checkpoint to f"{model_name}_model.pt",
# which for model_name="LinkNet" is this same path — so this call replaces the
# best-loss weights with the final-epoch weights. Confirm that is intended.
torch.save(model.state_dict(), "LinkNet_model.pt")