In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import segmentation_models_pytorch as smp
import cv2
import os
import numpy as np
import random
from tqdm import tqdm
import copy
from skimage.util import random_noise
import matplotlib.pyplot as plt
from datetime import datetime

In [2]:
import torch
import cv2
import os

from sklearn.model_selection import train_test_split


def default_aug(image, mask):
    return image, mask


def default_preprocessing(image):
    return torch.from_numpy(image).permute(2, 0, 1).float()


class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self,
                 tiles,
                 masks,
                 augmentations,  # Augmentations
                 preprocessing   # Processing data for the model
                 ):
        self.has_mask = masks is not None
        self.tiles = tiles
        self.masks = masks
        self.augmentations = augmentations
        self.preprocessing = preprocessing

    def __len__(self):
        return len(self.tiles)

    def get_item(self, idx):
        image = self.tiles[idx]
        mask = None
        if self.has_mask:
            mask = self.masks[idx]
        image, mask = self.augmentations(image, mask)
        if self.has_mask:
            mask = torch.from_numpy(mask).float()
        before_preprocessing = image
        image = self.preprocessing(image)

        return image, mask, before_preprocessing

    def __getitem__(self, idx):
        image, mask, before_preprocessing = self.get_item(idx)
        if self.has_mask:
            return image, mask
        return image


def get_tiles(tile_size,
              path_to_pics,
              path_to_masks=None,  # None if there are no labels
              ):
    # Images loading
    images = []
    masks = None
    for root, dirs, files in os.walk(path_to_pics):
        for file in sorted(files):
            img = cv2.imread(os.path.join(root, file))
            h, w, _ = img.shape

            for h_coord in range(0, h // tile_size):
                for w_coord in range(0, w // tile_size):
                    y = h_coord * tile_size
                    x = w_coord * tile_size
                    images.append(img[y: y + tile_size, x: x + tile_size])

    # Masks loading
    if path_to_masks is not None:
        masks = []
        for root, dirs, files in os.walk(path_to_masks):
            for file in sorted(files):
                mask = cv2.imread(os.path.join(root, file))
                mask = (mask[:, :, 0] > 0).astype('uint8')
                h, w = mask.shape

                for h_coord in range(0, h // tile_size):
                    for w_coord in range(0, w // tile_size):
                        y = h_coord * tile_size
                        x = w_coord * tile_size
                        masks.append(mask[y: y + tile_size, x: x + tile_size])

    return images, masks


def get_datasets(tile_size,
                 path_to_pics,
                 path_to_masks=None,
                 augmentations=default_aug,            # Augmentations
                 preprocessing=default_preprocessing   # Processing data for the model
                 ):
    tiles, masks = get_tiles(tile_size, path_to_pics, path_to_masks)
    train_tiles, val_tiles, train_masks, val_masks = train_test_split(tiles, masks, test_size=0.2, random_state=42)
    train_dataset = SegmentationDataset(train_tiles, train_masks, augmentations, preprocessing)
    val_dataset = SegmentationDataset(val_tiles, val_masks, default_aug, preprocessing)
    return train_dataset, val_dataset


In [3]:
BACKBONE = 'timm-efficientnet-b5'
PRETRAIN = "imagenet"
preprocessing_fn = smp.encoders.get_preprocessing_fn(BACKBONE, pretrained=PRETRAIN)

def preprocessing(image):
    image = preprocessing_fn(image)
    return torch.from_numpy(image).permute(2, 0, 1).float()

def augmentations(image, mask):
    # Random turn
    if random.randint(1, 2) == 1:
        image = cv2.flip(image, 1)
        mask = cv2.flip(mask, 1)
    if random.randint(1, 2) == 1:
        image = cv2.flip(image, 0)
        mask = cv2.flip(mask, 0)

    # Color augmentations
    ch_col = random.randint(80, 120) / 100
    image = image.astype('float64') * ch_col
    image[image > 255] = 255
    image = image.astype('uint8')
    return image, mask

dataset_train, dataset_val = get_datasets(256,
                                          'data/train/images',
                                          'data/train/masks',
                                          augmentations=augmentations,
                                          preprocessing=preprocessing)

In [4]:
dataloader_train = DataLoader(dataset_train, batch_size=14, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=14, shuffle=False)

In [5]:
model = smp.PAN(
    encoder_name=BACKBONE,
    in_channels=3,
    classes=1,
    activation='sigmoid'
)

In [6]:
import torch
import numpy
import wandb
from tqdm import tqdm

TEAM_NAME = "knife_team"
PROJECT_NAME = "building-segmentation"


class SegmentationModel:
    def __init__(self, model: torch.nn.Module, model_name: str = "model", device="cuda"):
        self.device = device
        self.model = model.to(device)
        self.model_name = model_name

    def __call__(self, x):
        return self.model(x)

    def train(self, loss, train_loader, val_loader,
              metrics: dict, optimizer, target_metric="f1",
              epochs=100, wandb_logging=False, path_to_save_model=None,
              verbose=True):

        if wandb_logging:
            run = wandb.init(
                entity=TEAM_NAME,
                project=PROJECT_NAME,
                config={
                    "architecture": self.model_name,
                    "epochs": epochs,
                    "batch_size": train_loader.batch_size,
                    "optimizer": optimizer.__class__.__name__,
                    "loss": loss.__class__.__name__,
                    "lr": optimizer.param_groups[0]['lr'],
                    "target_metric": target_metric,
                }
            )

        best_metric = 0
        train_logs_list, valid_logs_list = [], []

        for epoch in range(1, epochs + 1):
            self.model.train()
            results = {}
            for metric_name, metric in metrics.items():
                results[metric_name] = []
            results['train_loss'] = []

            pbar = tqdm(enumerate(train_loader), total=len(train_loader))
            pbar.set_description(f"Epoch {epoch}")

            # Train cycle
            for batch_idx, (data, target) in pbar:
                data, target = data.to(self.device), target.to(self.device).unsqueeze(1)
                optimizer.zero_grad()
                output = self.model(data)
                loss_value = loss(output, target)
                loss_value.backward()
                optimizer.step()

                # Train metrics
                results['train_loss'].append(loss_value.item())
                for metric_name, metric in metrics.items():
                    results[metric_name].append(metric(output, target).detach().cpu().item())

                pbar.set_postfix({
                    "loss": numpy.mean(results['train_loss']),
                })

            results['train_loss'] = numpy.mean(results['train_loss'])
            for metric_name, metric in metrics.items():
                results[metric_name] = numpy.mean(results[metric_name])

            train_logs_list.append(results)
            if verbose:
                print(f'Train logs: {results}')

            results = {}
            for metric_name, metric in metrics.items():
                results[metric_name] = []
            results['val_loss'] = []

            # Validation
            self.model.eval()
            with torch.no_grad():
                for batch_idx, (data, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
                    data, target = data.to(self.device), target.to(self.device).unsqueeze(1)
                    output = self.model(data)
                    loss_value = loss(output, target)
                    
                    results['val_loss'].append(loss_value.item())
                    for metric_name, metric in metrics.items():
                        results[metric_name].append(metric(output, target).detach().cpu().item())

            results['val_loss'] = numpy.mean(results['val_loss'])
            for metric_name, metric in metrics.items():
                results[metric_name] = numpy.mean(results[metric_name])
            valid_logs_list.append(results)

            if verbose:
                print(f'Val logs: {results}')

            # Saving the model
            if results[target_metric] > best_metric:
                print(f'New best model! Val {target_metric}: {results[target_metric]}')
                best_metric = results[target_metric]
                if path_to_save_model is not None:
                    model_path = os.path.join(path_to_save_model, f'{self.model_name}_{round(best_metric, 4)}.pth')
                    torch.save(self.model.state_dict(), model_path)

            if wandb_logging:
                wandb.log({"train_loss": train_logs_list[-1]['train_loss'],
                           "val_loss": valid_logs_list[-1]['val_loss']})

                for metric_name, metric in metrics.items():
                    wandb.log({f"train_{metric_name}": train_logs_list[-1][metric_name],
                               f"val_{metric_name}": valid_logs_list[-1][metric_name]})

        if wandb_logging:
            run.finish()

        return train_logs_list, valid_logs_list

    def predict(self, test_loader):
        self.model.eval()
        with torch.no_grad():
            predictions = []
            for batch_idx, data in tqdm(enumerate(test_loader), total=len(test_loader)):
                try:
                    data, target = data
                    data, target = data.to(self.device), target.to(self.device).unsqueeze(1)
                except TypeError:
                    data = data.to(self.device)
                output = self.model(data).detach().cpu().numpy()
                for part in output:
                    predictions.append(part[0])

        return predictions

In [7]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [8]:
model_train = SegmentationModel(model, f'PAN-{BACKBONE}')

In [9]:
from segmentation_models_pytorch import utils

f1_score = smp.utils.metrics.Fscore(threshold=0.5)
def my_f1_score(y_pred, y_true):
    score = (f1_score(y_pred, y_true) + f1_score(1 - y_pred, 1 - y_true)) / 2
    return score


def my_dice_loss(p, y):
    loss = 1 - (2 * (p * y).sum() + 1) / (p.sum() + y.sum() + 1)
    return loss


criterion = my_dice_loss
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-3)
metrics = {
    'f1': my_f1_score,
    'iou': smp.utils.metrics.IoU(threshold=0.5)
}
model_train.train(criterion, dataloader_train, dataloader_val, metrics, optimizer, wandb_logging=True,
                 path_to_save_model='baseline_checkpoints/')

[34m[1mwandb[0m: Currently logged in as: [33mpaspasuy[0m ([33mknife_team[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1: 100%|██████████| 692/692 [04:38<00:00,  2.48it/s, loss=0.384]


Train logs: {'f1': 0.794218451807523, 'train_loss': 0.3836157320207254}


100%|██████████| 173/173 [00:20<00:00,  8.50it/s]


Val logs: {'f1': 0.8477157536958684, 'val_loss': 0.29044856811534464}
New best model! Val f1: 0.8477157536958684


Epoch 2: 100%|██████████| 692/692 [04:37<00:00,  2.50it/s, loss=0.283]


Train logs: {'f1': 0.8507938805343098, 'train_loss': 0.2832237013502617}


100%|██████████| 173/173 [00:19<00:00,  8.72it/s]


Val logs: {'f1': 0.860935891294755, 'val_loss': 0.2638193155299721}
New best model! Val f1: 0.860935891294755


Epoch 3: 100%|██████████| 692/692 [04:42<00:00,  2.45it/s, loss=0.258]


Train logs: {'f1': 0.8635163219333384, 'train_loss': 0.25818017822814127}


100%|██████████| 173/173 [00:19<00:00,  9.09it/s]


Val logs: {'f1': 0.866225158892615, 'val_loss': 0.252340308159073}
New best model! Val f1: 0.866225158892615


Epoch 4: 100%|██████████| 692/692 [04:34<00:00,  2.52it/s, loss=0.237]


Train logs: {'f1': 0.874676485454416, 'train_loss': 0.23657636034350865}


100%|██████████| 173/173 [00:19<00:00,  8.68it/s]


Val logs: {'f1': 0.8705851380535633, 'val_loss': 0.2441580932953454}
New best model! Val f1: 0.8705851380535633


Epoch 5: 100%|██████████| 692/692 [04:44<00:00,  2.43it/s, loss=0.223]


Train logs: {'f1': 0.881757160142667, 'train_loss': 0.22312413474727918}


100%|██████████| 173/173 [00:19<00:00,  9.01it/s]


Val logs: {'f1': 0.8711190447641936, 'val_loss': 0.24235271787367804}
New best model! Val f1: 0.8711190447641936


Epoch 6: 100%|██████████| 692/692 [04:30<00:00,  2.56it/s, loss=0.213]


Train logs: {'f1': 0.8871019030409741, 'train_loss': 0.2129389067433473}


100%|██████████| 173/173 [00:19<00:00,  9.00it/s]


Val logs: {'f1': 0.8711944387827305, 'val_loss': 0.2418928697619135}
New best model! Val f1: 0.8711944387827305


Epoch 7: 100%|██████████| 692/692 [04:27<00:00,  2.58it/s, loss=0.202]


Train logs: {'f1': 0.892901382435953, 'train_loss': 0.20181074144178732}


100%|██████████| 173/173 [00:19<00:00,  9.01it/s]


Val logs: {'f1': 0.8744183237842053, 'val_loss': 0.23590297058138543}
New best model! Val f1: 0.8744183237842053


Epoch 8: 100%|██████████| 692/692 [04:27<00:00,  2.58it/s, loss=0.194]


Train logs: {'f1': 0.8967357943685068, 'train_loss': 0.19444798045075698}


100%|██████████| 173/173 [00:19<00:00,  9.09it/s]


Val logs: {'f1': 0.8727697109900459, 'val_loss': 0.2386054220916219}


Epoch 9: 100%|██████████| 692/692 [04:36<00:00,  2.50it/s, loss=0.188]


Train logs: {'f1': 0.9000685911819425, 'train_loss': 0.1881779125660141}


100%|██████████| 173/173 [00:20<00:00,  8.63it/s]


Val logs: {'f1': 0.8740775233748331, 'val_loss': 0.2367624388953854}


Epoch 10: 100%|██████████| 692/692 [04:36<00:00,  2.50it/s, loss=0.175]


Train logs: {'f1': 0.9070152295807193, 'train_loss': 0.17499532511813104}


100%|██████████| 173/173 [00:19<00:00,  8.93it/s]


Val logs: {'f1': 0.8737580659072524, 'val_loss': 0.23682915785409123}


Epoch 11: 100%|██████████| 692/692 [04:31<00:00,  2.55it/s, loss=0.169]


Train logs: {'f1': 0.9100560795709577, 'train_loss': 0.1693261246805246}


100%|██████████| 173/173 [00:19<00:00,  9.06it/s]


Val logs: {'f1': 0.871358471798759, 'val_loss': 0.24209706872873912}


Epoch 12: 100%|██████████| 692/692 [04:32<00:00,  2.54it/s, loss=0.166]


Train logs: {'f1': 0.9120282500637749, 'train_loss': 0.16560156298855136}


100%|██████████| 173/173 [00:18<00:00,  9.27it/s]


Val logs: {'f1': 0.876795468647356, 'val_loss': 0.23069607729167607}
New best model! Val f1: 0.876795468647356


Epoch 13: 100%|██████████| 692/692 [04:32<00:00,  2.54it/s, loss=0.16] 


Train logs: {'f1': 0.9149692215326893, 'train_loss': 0.16009883651499116}


100%|██████████| 173/173 [00:18<00:00,  9.32it/s]


Val logs: {'f1': 0.875690049863275, 'val_loss': 0.23290559873415556}


Epoch 14: 100%|██████████| 692/692 [04:39<00:00,  2.48it/s, loss=0.157]


Train logs: {'f1': 0.916760786506482, 'train_loss': 0.15674485292048812}


100%|██████████| 173/173 [00:20<00:00,  8.62it/s]


Val logs: {'f1': 0.8767502114951955, 'val_loss': 0.23106222655731817}


Epoch 15: 100%|██████████| 692/692 [04:33<00:00,  2.53it/s, loss=0.15] 


Train logs: {'f1': 0.9202768893083396, 'train_loss': 0.15010997997543027}


100%|██████████| 173/173 [00:18<00:00,  9.19it/s]


Val logs: {'f1': 0.8740996804540557, 'val_loss': 0.2358743138395982}


Epoch 16: 100%|██████████| 692/692 [04:31<00:00,  2.55it/s, loss=0.151]


Train logs: {'f1': 0.9200576086437082, 'train_loss': 0.15053081159302265}


100%|██████████| 173/173 [00:18<00:00,  9.34it/s]


Val logs: {'f1': 0.8744610954571321, 'val_loss': 0.23548302691795922}


Epoch 17: 100%|██████████| 692/692 [04:31<00:00,  2.55it/s, loss=0.148]


Train logs: {'f1': 0.9215223930474651, 'train_loss': 0.14769154718156494}


100%|██████████| 173/173 [00:18<00:00,  9.20it/s]


Val logs: {'f1': 0.874130931203765, 'val_loss': 0.23570475412931055}


Epoch 18: 100%|██████████| 692/692 [04:37<00:00,  2.49it/s, loss=0.144]


Train logs: {'f1': 0.9235046922816017, 'train_loss': 0.14391325517541412}


100%|██████████| 173/173 [00:19<00:00,  9.00it/s]


Val logs: {'f1': 0.8756222246010179, 'val_loss': 0.2330293896570371}


Epoch 19: 100%|██████████| 692/692 [04:37<00:00,  2.49it/s, loss=0.135]


Train logs: {'f1': 0.9284202815652582, 'train_loss': 0.13480502013870746}


100%|██████████| 173/173 [00:20<00:00,  8.30it/s]


Val logs: {'f1': 0.8772975247719385, 'val_loss': 0.22984730956182314}
New best model! Val f1: 0.8772975247719385


Epoch 20: 100%|██████████| 692/692 [04:34<00:00,  2.52it/s, loss=0.129]


Train logs: {'f1': 0.9313656214000172, 'train_loss': 0.12917389380449504}


100%|██████████| 173/173 [00:20<00:00,  8.61it/s]


Val logs: {'f1': 0.8774055297664135, 'val_loss': 0.22964163805018958}
New best model! Val f1: 0.8774055297664135


Epoch 21: 100%|██████████| 692/692 [04:39<00:00,  2.48it/s, loss=0.127]


Train logs: {'f1': 0.9325843870295265, 'train_loss': 0.12690743433602283}


100%|██████████| 173/173 [00:19<00:00,  8.87it/s]


Val logs: {'f1': 0.8742248374602698, 'val_loss': 0.23570029583969557}


Epoch 22: 100%|██████████| 692/692 [04:30<00:00,  2.56it/s, loss=0.128]


Train logs: {'f1': 0.9318946396684371, 'train_loss': 0.12821654932347337}


100%|██████████| 173/173 [00:18<00:00,  9.11it/s]


Val logs: {'f1': 0.8754431043746155, 'val_loss': 0.23353445219855778}


Epoch 23: 100%|██████████| 692/692 [04:39<00:00,  2.48it/s, loss=0.124]


Train logs: {'f1': 0.9342985629518598, 'train_loss': 0.12374379137003352}


100%|██████████| 173/173 [00:18<00:00,  9.19it/s]


Val logs: {'f1': 0.8754400728065843, 'val_loss': 0.2332454241080091}


Epoch 24: 100%|██████████| 692/692 [04:44<00:00,  2.44it/s, loss=0.123]


Train logs: {'f1': 0.9346461004949029, 'train_loss': 0.12307157595722662}


100%|██████████| 173/173 [00:19<00:00,  8.82it/s]


Val logs: {'f1': 0.8737506559818466, 'val_loss': 0.2362012687446065}


Epoch 25: 100%|██████████| 692/692 [04:47<00:00,  2.41it/s, loss=0.123]


Train logs: {'f1': 0.9346089558622052, 'train_loss': 0.12307887080776898}


100%|██████████| 173/173 [00:19<00:00,  8.87it/s]


Val logs: {'f1': 0.8760840779095027, 'val_loss': 0.23203207647180282}


Epoch 26:  10%|█         | 71/692 [00:29<04:17,  2.41it/s, loss=0.114]


KeyboardInterrupt: 

In [None]:
predictions = model_train.predict(dataloader_val)

In [None]:
def visualize(figsize=(20, 10), **images):
    """
    :param figsize: A tuple (width, height) specifying the size of the figure. Default is (20, 10).
    :param images: Keyword arguments where the key is the name of the image and the value is the corresponding image data.
    :return: None

    This method takes in multiple images and displays them in a grid layout using matplotlib. It supports numpy arrays and torch tensors as image data. The figsize parameter allows you to specify the size of the figure. Each image will be displayed with its corresponding name as the title.

    Examples:
        visualize(figsize=(10, 5), image1=image_data1, image2=image_data2)
    """
    n_images = len(images)
    plt.figure(figsize=figsize)
    for idx, (name, image) in enumerate(images.items()):

        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
            if image.shape[0] == 1:
                image = image[0]

            if len(image.shape) == 3:
                image = image.transpose(1, 2, 0)

        if not isinstance(image, np.ndarray):
            raise ValueError(f'Image must be numpy array or torch tensor. Got {type(image)}')
        if image.shape[-1] == 1:
            image = image[:, :, 0]

        plt.subplot(1, n_images, idx + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()

In [None]:
for i in range(30):
    params_vis = {}
    params_vis[f'image{i}'] = dataset_val.get_item(i)[2]
    params_vis[f'true_mask{i}'] = dataset_val.get_item(i)[1]
    params_vis[f'pred_mask{i}'] = predictions[i]
    visualize(figsize=(20, 10), **params_vis)