In [None]:
# Create the extraction targets. `-p` keeps the cell re-runnable: plain `mkdir`
# fails with "File exists" on the second execution.
!mkdir -p /content/train_folder
!mkdir -p /content/val_folder

# Unpack the training patches archive from Drive into the train folder (quiet mode).
!unzip -q /content/drive/MyDrive/igor_files/patches_train.zip -d /content/train_folder

In [None]:
# Unpack the validation patches archive from Drive into /content/val_folder (quiet mode).
!unzip -q /content/drive/MyDrive/igor_files/patches_val.zip -d /content/val_folder

In [None]:
# List the Drive folder that holds the patch archives.
!ls /content/drive/MyDrive/igor_files/

patches_train.zip  patches_val.zip


In [None]:
# `%pip` installs into the environment of the running kernel; `-q` keeps the log short.
%pip install -q wandb

In [None]:
from functools import lru_cache, partial
from itertools import repeat
from multiprocessing import Pool

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import transforms
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

import argparse
import logging
import os
import random
import sys
from pathlib import Path
from tqdm import tqdm
from glob import glob
import argparse
import cv2

from PIL import Image

import wandb

In [None]:
""" Parts of the U-Net model """

class DoubleConv(nn.Module):
    """Two stacked (3x3 convolution -> BatchNorm -> ReLU) stages.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second stage.
        mid_channels: channels between the two stages; any falsy value
            (``None`` or ``0``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        # Build the six layers in the same order (and under the same attribute
        # name) so parameter names in the state dict stay unchanged.
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pooling, convs)

    def forward(self, x):
        """Downsample ``x`` and run it through the convolution block."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample, align with the skip connection, then DoubleConv.

    Args:
        in_channels: channel count after concatenating skip and upsampled maps.
        out_channels: channels produced by the stage.
        bilinear: use bilinear ``nn.Upsample`` when true, otherwise a learned
            ``nn.ConvTranspose2d`` that also halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        half = in_channels // 2
        if not bilinear:
            upsampler = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            fuse = DoubleConv(in_channels, out_channels)
        else:
            upsampler = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            fuse = DoubleConv(in_channels, out_channels, half)
        self.up = upsampler
        self.conv = fuse

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Tensors are indexed as (N, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2

        # F.pad takes (left, right, top, bottom) for the two trailing dims.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip connection first, upsampled features second, along channels.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` without changing its spatial size."""
        projected = self.conv(x)
        return projected

In [None]:
""" Full assembly of the parts to form the complete network """

class UNet(nn.Module):
    """U-Net assembled from DoubleConv / Down / Up / OutConv blocks.

    Args:
        n_channels: number of channels of the input images.
        n_classes: number of output channels (logits per pixel).
        bilinear: use bilinear upsampling in the decoder instead of
            transposed convolutions.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Plain Python flag (not a parameter or buffer), so state-dict keys
        # are identical with and without checkpointing.
        self._checkpointing = False

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # With bilinear upsampling the decoder does not halve channels itself,
        # so the bottleneck and decoder outputs are built half as wide.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def _run(self, block, *inputs):
        """Call ``block(*inputs)``, through activation checkpointing if enabled."""
        if self._checkpointing and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint
            # Non-reentrant mode is required here: the network input does not
            # require grad, and the reentrant variant would then produce no
            # gradients for the parameters of the first stage.
            return checkpoint(block, *inputs, use_reentrant=False)
        return block(*inputs)

    def forward(self, x):
        """Return per-pixel logits with ``n_classes`` channels for ``x``."""
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Trade compute for memory by recomputing activations during backward.

        The previous implementation called the ``torch.utils.checkpoint``
        *module* on each sub-module, which raises ``TypeError``. The
        sub-modules are now left untouched and ``forward`` wraps each stage
        with ``torch.utils.checkpoint.checkpoint`` once this flag is set.
        """
        self._checkpointing = True

In [None]:
def plot_img_and_mask(img, mask):
    """Show the input image next to one binary panel per mask class.

    The number of classes is taken as ``mask.max() + 1``; panel ``k`` (titled
    with the 1-based class number) shows ``mask == k``.
    """
    n_classes = mask.max() + 1
    _, axes = plt.subplots(1, n_classes + 1)

    axes[0].set_title('Input image')
    axes[0].imshow(img)

    for class_idx, class_ax in enumerate(axes[1:]):
        class_ax.set_title(f'Mask (class {class_idx + 1})')
        class_ax.imshow(mask == class_idx)

    # pyplot-level calls: these act on the current (last created) axes only.
    plt.xticks([])
    plt.yticks([])
    plt.show()


def load_image(filename):
    """Load ``filename`` as a PIL image, choosing the reader by file extension.

    ``.npy`` files are read with ``np.load`` and ``.pt`` / ``.pth`` files with
    ``torch.load`` before being wrapped by ``Image.fromarray``; anything else
    is opened directly with ``Image.open``.
    """
    suffix = os.path.splitext(filename)[1]

    if suffix in ('.pt', '.pth'):
        return Image.fromarray(torch.load(filename).numpy())
    if suffix == '.npy':
        return Image.fromarray(np.load(filename))
    return Image.open(filename)


def unique_mask_values(idx, mask_dir, mask_suffix):
    """Return the distinct values found in the mask file belonging to ``idx``.

    The mask is the first match of ``idx + mask_suffix + '.*'`` inside
    ``mask_dir`` (an object providing ``glob``). For a 2-D mask the unique
    scalars are returned; for a 3-D mask the unique vectors along the last
    axis.

    Raises:
        ValueError: if the loaded mask is neither 2-D nor 3-D.
    """
    candidates = list(mask_dir.glob(idx + mask_suffix + '.*'))
    mask = np.asarray(load_image(candidates[0]))

    if mask.ndim not in (2, 3):
        raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}')
    if mask.ndim == 2:
        return np.unique(mask)

    # Collapse the spatial dims so every row is one pixel's value vector.
    pixels = mask.reshape(-1, mask.shape[-1])
    return np.unique(pixels, axis=0)


def get_idx_from_filename(filename):
    """Parse the integer between the last ``_`` and the last ``.`` of ``filename``.

    Example: ``'train_image_007.png'`` -> ``7``.
    """
    start = filename.rfind('_') + 1
    stop = filename.rfind('.')
    digits = filename[start:stop]
    return int(digits)


def get_dict_with_images_from_folder(folder_path):
    """Map the numeric index of each file in ``folder_path`` to its full path.

    A falsy ``folder_path`` (``None`` or ``''``) yields an empty dict. The
    index is extracted from each file name by ``get_idx_from_filename``.
    """
    if not folder_path:
        return {}

    names = os.listdir(folder_path)
    # Parse every name up front so a malformed one fails before the dict is built.
    indices = [get_idx_from_filename(name) for name in names]
    return {
        index: os.path.join(folder_path, name)
        for index, name in zip(indices, names)
    }

In [None]:
# Make /content the working directory; relative paths such as BASE_PATH = '.' resolve against it.
%cd /content

/content


In [None]:
# Side length, in pixels, of the square patches cut by data_processing.
PATCH_SIZE = 256
# Root directory against which data_processing resolves its source and output folders.
BASE_PATH = '.'
# Seed passed to np.random.seed in data_processing (controls the train/val assignment).
SEED = 123

# NOTE(review): not referenced anywhere in the visible notebook — presumably a leftover.
BLACK_PICTURES_LIST = []

def data_processing(base_path: str = BASE_PATH,
                    train_data_dir_name: str = 'drive/MyDrive/Olymps/leadersofdigital/msk/skol_data/train_updated_titiles',
                    val_part: float = 0.2):
    """Cut the source images and masks into overlapping patches on disk.

    Reads ``train_image_XXX.png`` / ``train_mask_XXX.png`` pairs from
    ``<base_path>/<train_data_dir_name>/{images,masks}``, slides a
    ``PATCH_SIZE`` window with a stride of half a patch over each pair and
    writes every kept patch either to ``<base_path>/patches_train`` or, with
    probability ``val_part``, to ``<base_path>/patches_val``. Every second
    patch whose mask is (almost) empty is dropped.

    Args:
        base_path: root directory for both input and output folders.
        train_data_dir_name: source folder, relative to ``base_path``.
        val_part: probability of sending a patch to the validation split.
    """
    np.random.seed(SEED)
    source_img_dir = os.path.join(base_path, train_data_dir_name, 'images')
    source_mask_dir = os.path.join(base_path, train_data_dir_name, 'masks')

    new_save_dir = os.path.join(base_path, 'patches_train')
    save_img_dir = os.path.join(new_save_dir, 'images')
    save_mask_dir = os.path.join(new_save_dir, 'masks')

    def _to_val(path):
        # Validation paths mirror the training ones with 'train' -> 'val'
        # applied to the part of the path below base_path.
        return os.path.join(
            base_path,
            os.path.relpath(path, base_path).replace('train', 'val')
        )

    for directory in (new_save_dir, save_img_dir, save_mask_dir,
                      _to_val(save_img_dir), _to_val(save_mask_dir)):
        os.makedirs(directory, exist_ok=True)

    # Neighbouring patches overlap by half a patch.
    stride = PATCH_SIZE // 2

    img_train_id = 0
    img_val_id = 0
    cnt_zero_pics = 0
    for image_id in range(len(glob(source_img_dir + '/*'))):
        print(f'---- id: {image_id} ----')
        img_path = f'{source_img_dir}/train_image_{image_id:03}.png'
        mask_path = f'{source_mask_dir}/train_mask_{image_id:03}.png'
        # Keep OpenCV's native BGR order: cv2.imwrite expects BGR as well, so
        # converting to RGB before writing (as was done before) stored the
        # patches with red and blue swapped.
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path)
        print(mask_path.split('/')[-1][:-4])
        if img is None or mask is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print(f'Skipping id {image_id}: could not read {img_path} or {mask_path}')
            continue
        if img.shape != mask.shape:
            continue
        # Strip the extension and the trailing '_XXX' id; a fresh running id
        # is appended per patch below.
        img_name = os.path.join(save_img_dir, img_path.split('/')[-1][:-4])
        img_name = img_name[:img_name.rfind('_')]
        mask_name = os.path.join(save_mask_dir, mask_path.split('/')[-1][:-4])
        mask_name = mask_name[:mask_name.rfind('_')]
        # NOTE(review): the last window of each row/column can reach past the
        # image border, which yields patches smaller than PATCH_SIZE — confirm
        # this is intended.
        for i in tqdm(range(0, img.shape[0] // PATCH_SIZE * PATCH_SIZE, stride)):
            for j in range(0, img.shape[1] // PATCH_SIZE * PATCH_SIZE, stride):
                patch_img = img[i:i + PATCH_SIZE, j:j + PATCH_SIZE]
                patch_mask = mask[i:i + PATCH_SIZE, j:j + PATCH_SIZE]
                if patch_mask.sum() <= 4:
                    # Keep only every other (nearly) empty patch.
                    cnt_zero_pics += 1
                    if cnt_zero_pics % 2 == 0:
                        continue
                if np.random.uniform(0, 1) < val_part:
                    img_name_cur = _to_val(img_name)
                    mask_name_cur = _to_val(mask_name)
                    cur_id = img_val_id
                    img_val_id += 1
                else:
                    img_name_cur = img_name
                    mask_name_cur = mask_name
                    cur_id = img_train_id
                    img_train_id += 1
                cv2.imwrite(f"{img_name_cur}_{cur_id:03}.png", patch_img)
                cv2.imwrite(f"{mask_name_cur}_{cur_id:03}.png", patch_mask)


In [None]:
class ImageSegmentationDataset(Dataset):
    """Dataset of image / mask pairs matched by the numeric index in the file name.

    Args:
        images_dir: folder with the input images.
        mask_dir: folder with the masks; may be falsy when no masks exist.
        transform: optional callable applied to the image (and mask) array.
        is_train_or_val: when true every image must have a mask and samples
            contain both; when false only images are loaded.
    """

    def __init__(self, images_dir: str, mask_dir: str, transform=None, is_train_or_val: bool = True):
        self.images_dir = images_dir
        self.mask_dir = mask_dir
        self.is_train_or_val = is_train_or_val
        self.transform = transform
        self.idx2img_filepath = get_dict_with_images_from_folder(self.images_dir)
        self.idx2mask_filepath = get_dict_with_images_from_folder(self.mask_dir)

        # With masks, both folders must cover exactly the same set of indices.
        assert not self.is_train_or_val or self.idx2img_filepath.keys() == self.idx2mask_filepath.keys()

        self.ids = sorted(list(self.idx2img_filepath.keys()))

    def __len__(self):
        return len(self.idx2img_filepath)

    def preprocess(self, pil_img, is_mask):
        """Convert ``pil_img`` to an array, apply the transform, rescale masks.

        Masks are multiplied by 255 in place.
        NOTE(review): presumably this undoes a 1/255 scaling done by the
        transform — confirm against the transform and the stored mask values.
        """
        #### TODO: correct for our preprocess
        img = np.asarray(pil_img)

        if self.transform:
            img = self.transform(img)
        if is_mask:
            img *= 255
        return img

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'mask': ...}``; ``'mask'`` only exists with masks.

        Without masks (``is_train_or_val=False``) the mask lookup used to
        raise ``KeyError``, and a ``None`` mask cannot be batched by the
        default collate function, so the key is simply left out in that mode.
        """
        idx = self.ids[idx]
        img_file = self.idx2img_filepath[idx]
        img = load_image(img_file)

        if not self.is_train_or_val:
            return {'image': self.preprocess(img, is_mask=False)}

        mask_file = self.idx2mask_filepath[idx]
        mask = load_image(mask_file)

        assert img.size == mask.size, \
            f'Image {img_file} and mask {mask_file} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(img, is_mask=False)
        # Keep only the first slice along the leading axis of the processed mask.
        mask = self.preprocess(mask, is_mask=True)[0]

        return {
            'image': img,
            'mask': mask
        }

In [None]:
# Preprocessing shared by images and masks: convert to a tensor, then resize to 256x256.
transform = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Lambda(lambda x: x / 255.),
    transforms.Resize((256, 256))
])
# Module-level dataset over the extracted training patches. train_model builds its
# own datasets, so this instance looks like it is only for interactive inspection.
train_dataset = ImageSegmentationDataset('/content/train_folder/patches_train/images',
                                         '/content/train_folder/patches_train/masks', transform=transform)

In [None]:
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Mean Dice coefficient between ``input`` and ``target``.

    With ``reduce_batch_first`` a 3-D input is reduced over all three
    dimensions at once; otherwise only the last two dimensions are reduced and
    the per-item scores are averaged. Items where both tensors sum to zero
    score 1.
    """
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    if reduce_batch_first and input.dim() != 2:
        reduce_dims = (-1, -2, -3)
    else:
        reduce_dims = (-1, -2)

    overlap = 2 * (input * target).sum(dim=reduce_dims)
    total = input.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)
    # When both tensors are empty, reuse the (zero) overlap as denominator so
    # the ratio below becomes epsilon / epsilon == 1.
    total = torch.where(total == 0, overlap, total)

    scores = (overlap + epsilon) / (total + epsilon)
    return scores.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Dice coefficient averaged over all classes.

    The first two dimensions of both tensors are merged so that ``dice_coeff``
    treats every (item, class) pair as a separate mask.
    """
    merged_input = input.flatten(0, 1)
    merged_target = target.flatten(0, 1)
    return dice_coeff(merged_input, merged_target, reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Dice loss: one minus the batch-reduced Dice coefficient.

    ``multiclass`` selects ``multiclass_dice_coeff`` instead of ``dice_coeff``.
    """
    if multiclass:
        score = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        score = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - score

In [None]:
@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    """Return the mean Dice score of ``net`` over all batches of ``dataloader``.

    The network is put into eval mode for the pass and switched back to train
    mode afterwards. For a single output channel, predictions are thresholded
    at 0.5 after a sigmoid; otherwise both prediction and target are one-hot
    encoded and the background channel is ignored.

    Args:
        net: model exposing ``n_classes``.
        dataloader: iterable of dicts with ``'image'`` and ``'mask'`` entries.
        device: torch device to run on.
        amp: enable autocast mixed precision.
    """
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # autocast has no 'mps' backend here, so fall back to 'cpu' for it.
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        # Iterate through tqdm itself so the bar actually advances; previously
        # the bar was created but never updated.
        progress = tqdm(dataloader, total=num_val_batches, desc='Validation round ', unit='batch', leave=False)
        for batch in progress:
            image, mask_true = batch['image'], batch['mask']

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
                mask_pred = (F.sigmoid(mask_pred.squeeze(1)) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
                # convert to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)

    net.train()
    # max(..., 1) avoids a division by zero for an empty dataloader.
    return dice_score / max(num_val_batches, 1)

In [None]:
# Drive root; checkpoints are written below it.
MAIN_DATA_FOLDER = '/content/drive/MyDrive'
# Image / mask folders of the extracted training and validation patch archives.
DIR_TRAIN_IMG = '/content/train_folder/patches_train/images'
DIR_TRAIN_MASK = '/content/train_folder/patches_train/masks'
DIR_VAL_IMG = '/content/val_folder/patches_val/images'
DIR_VAL_MASK = '/content/val_folder/patches_val/masks'
# Folder where train_model stores its .pth checkpoints.
CHECKPOINTS_DIR = os.path.join(MAIN_DATA_FOLDER, 'checkpoints/')


def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 8,
        learning_rate: float = 1e-4,
        save_checkpoint: bool = True,
        weight_decay: float = 1e-4,
        amp: bool = False,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
        data_type: str = 'start data'
):
    """Train ``model`` on the patch datasets and log the run to wandb.

    The loss is BCE-with-logits plus Dice loss on the single output channel.
    About ten times per epoch the model is evaluated on the validation set,
    the LR scheduler is stepped on the Dice score, histograms and sample
    images are logged and (optionally) a checkpoint is written to
    ``CHECKPOINTS_DIR``.

    Args:
        model: network exposing ``n_channels``.
        device: torch device to train on.
        epochs: number of passes over the training set.
        batch_size: batch size of both data loaders.
        learning_rate: initial AdamW learning rate.
        save_checkpoint: write a state dict after every evaluation round.
        weight_decay: AdamW weight decay.
        amp: enable autocast / gradient scaling.
        momentum: accepted for interface compatibility; not used by AdamW here.
        gradient_clipping: max gradient norm.
        data_type: free-form tag stored in the wandb config.
    """
    # 1. Create dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Lambda(lambda x: x / 255.),
        transforms.Resize((256, 256))
    ])
    train_dataset = ImageSegmentationDataset(DIR_TRAIN_IMG, DIR_TRAIN_MASK, transform=transform, is_train_or_val=True)
    val_dataset = ImageSegmentationDataset(DIR_VAL_IMG, DIR_VAL_MASK, transform=transform, is_train_or_val=True)

    n_train = len(train_dataset)
    n_val = len(val_dataset)

    # 2. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count())
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='Updated U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            n_train=n_train,
            n_val=n_val,
            data_type=data_type,
            amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        AMP:             {amp}
    ''') #         Validation size: {n_val}

    # 3. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.AdamW(model.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'max',
        patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.BCEWithLogitsLoss()
    global_step = 0

    # Number of optimisation steps between evaluation rounds (~10 per epoch).
    # It is 0 when the training set has fewer than 10 batches; evaluation is
    # skipped in that case instead of dividing by zero.
    division_step = (n_train // (10 * batch_size))

    # 4. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)

                    loss = criterion(masks_pred.squeeze(1), true_masks.float())
                    loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale at this
                # point; unscale them first so the clipping threshold applies
                # to their true norm.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in model.named_parameters():
                            tag = tag.replace('/', '.')
                            if not (torch.isinf(value) | torch.isnan(value)).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            # Parameters that took no part in the backward pass have no gradient.
                            if value.grad is not None and not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(model, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))
                        try:
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation Dice': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu() * 255),
                                    'pred': wandb.Image(F.sigmoid(masks_pred[0][0].squeeze(1)).float().cpu() * 255),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
                        except Exception as exc:
                            # Logging stays best-effort, but failures are reported
                            # instead of being swallowed silently (and Ctrl-C is
                            # no longer caught).
                            logging.warning(f'Could not log the validation round to wandb: {exc}')

                        if save_checkpoint:
                            Path(CHECKPOINTS_DIR).mkdir(parents=True, exist_ok=True)
                            state_dict = model.state_dict()
                            torch.save(state_dict, os.path.join(CHECKPOINTS_DIR, f'checkpoint_epoch{epoch}_{global_step}.pth'))
                            logging.info(f'Checkpoint {epoch} saved!')



In [None]:
# Hyper-parameters of the training run, keyed by the long option names of the training CLI.
TRAIN_CONFIG = {
    'epochs': 10,
    'batch-size': 16,
    'learning-rate': 0.00001,
    'load': False
}
# Flatten the config into an argv-style list ('--key', 'value', ...) and switch on mixed precision.
# Every value is stringified, so 'load': False is passed on as the (truthy) string 'False'.
TRAIN_ARGS = sum([[f'--{key}', str(value)] for key, value in TRAIN_CONFIG.items()], []) + \
    ['--amp']

In [None]:
TRAIN_ARGS

['--epochs',
 '10',
 '--batch-size',
 '16',
 '--learning-rate',
 '1e-05',
 '--load',
 'False',
 '--amp']

In [None]:
def get_args(args=TRAIN_ARGS):
    """Parse the training options from ``args`` (defaults to ``TRAIN_ARGS``).

    Returns an ``argparse.Namespace`` with ``epochs``, ``batch_size``, ``lr``,
    ``load``, ``val``, ``amp`` and ``bilinear``.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')

    # (flags, keyword arguments) per option, registered in this order.
    option_specs = [
        (('--epochs', '-e'),
         dict(metavar='E', type=int, default=5, help='Number of epochs')),
        (('--batch-size', '-b'),
         dict(dest='batch_size', metavar='B', type=int, default=8, help='Batch size')),
        (('--learning-rate', '-l'),
         dict(metavar='LR', type=float, default=1e-5, help='Learning rate', dest='lr')),
        (('--load', '-f'),
         dict(type=str, default=False, help='Load model from a .pth file')),
        (('--validation', '-v'),
         dict(dest='val', type=float, default=10.0,
              help='Percent of the data that is used as validation (0-100)')),
        (('--amp',),
         dict(action='store_true', default=False, help='Use mixed precision')),
        (('--bilinear',),
         dict(action='store_true', default=False, help='Use bilinear upsampling')),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args(args=args)


# Training entry point: parse the options, build the model and train it,
# retrying once with activation checkpointing if the GPU runs out of memory.
if __name__ == '__main__':
    args = get_args()
    print(args)

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    model = UNet(n_channels=3, n_classes=1, bilinear=args.bilinear)
    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    # if args.load:
    #     state_dict = torch.load(args.load, map_location=device)
    #     del state_dict['mask_values']
    #     model.load_state_dict(state_dict)
    #     logging.info(f'Model loaded from {args.load}')

    model.to(device=device)
    try:
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            # val_percent=args.val / 100,
            amp=args.amp
        )
    except torch.cuda.OutOfMemoryError:
        logging.error('Detected OutOfMemoryError! '
                      'Enabling checkpointing to reduce memory usage, but this slows down training. '
                      'Consider enabling AMP (--amp) for fast and memory efficient training')
        torch.cuda.empty_cache()
        model.use_checkpointing()
        # Same arguments as the first attempt. The former `img_scale=args.scale`
        # was dropped: train_model has no such parameter and the parser defines
        # no --scale option, so the retry could never start.
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            # val_percent=args.val / 100,
            amp=args.amp
        )

Namespace(epochs=10, batch_size=16, lr=1e-05, load='False', val=10.0, amp=True, bilinear=False)


[34m[1mwandb[0m: Currently logged in as: [33migshchukin[0m. Use [1m`wandb login --relogin`[0m to force relogin


  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Epoch 1/10:  10%|▉         | 3072/30876 [01:09<09:30, 48.76img/s, loss (batch)=1.33]
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Epoch 1/10:  10%|▉         | 3072/30876 [01:19<09:30, 48.76img/s, loss (batch)=1.33]
Epoch 1/10:  20%|█▉        | 6144/30876 [03:21<08:36, 47.89img/s, loss (batch)=1.09]
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Epoch 1/10:  20%|█▉        | 6144/30876 [03:40<08:36, 47.89img/s, loss (batch)=1.09]
Epoch 1/10:  30%|██▉       | 9216/30876 [05:28<07:48, 46.26img/s, loss (batch)=1.24]
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Epoch 1/10:  30%|██▉       | 9216/30876 [05:40<07:48, 46.26img/

In [None]:
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Predict a segmentation mask for one image.

    Args:
        net: model exposing ``n_classes``.
        full_img: image convertible with ``np.asarray`` whose ``size``
            attribute is ``(width, height)`` (e.g. a PIL image).
        device: torch device to run the model on.
        scale_factor: spatial rescaling applied to the input before the
            forward pass; the output is resized back to the original size.
        out_threshold: probability threshold for the single-channel case.

    Returns:
        Integer numpy array of shape ``(height, width)``: class indices for
        ``n_classes > 1``, otherwise a 0/1 mask.
    """
    net.eval()

    # The preprocessing is done inline; it used to go through `BasicDataset`,
    # a class that is not defined in this notebook (NameError at call time).
    pixels = np.asarray(full_img)
    if pixels.ndim == 2:
        # Single-channel image: add the channel axis expected below.
        pixels = pixels[..., np.newaxis]
    # HWC -> CHW; the contiguous copy is writable, which torch.from_numpy expects.
    img = torch.from_numpy(np.ascontiguousarray(pixels.transpose(2, 0, 1))).to(torch.float32)
    if pixels.dtype == np.uint8:
        # Bring 8-bit images into [0, 1].
        img = img / 255.0
    img = img.unsqueeze(0)
    if scale_factor != 1:
        img = F.interpolate(img, scale_factor=scale_factor, mode='bilinear', align_corners=False)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()
        # full_img.size is (width, height); interpolate expects (height, width).
        output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
        if net.n_classes > 1:
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().numpy()



In [None]:
def get_args(args=None):
    """Parse the prediction options from ``args`` (``None`` lets argparse use ``sys.argv``).

    Returns an ``argparse.Namespace`` with ``model``, ``input``, ``output``,
    ``viz``, ``no_save``, ``mask_threshold``, ``scale``, ``bilinear`` and
    ``classes``.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images')

    # (flags, keyword arguments) per option, registered in this order.
    option_specs = [
        (('--model', '-m'),
         dict(default='MODEL.pth', metavar='FILE',
              help='Specify the file in which the model is stored')),
        (('--input', '-i'),
         dict(metavar='INPUT', nargs='+', help='Filenames of input images', required=True)),
        (('--output', '-o'),
         dict(metavar='OUTPUT', nargs='+', help='Filenames of output images')),
        (('--viz', '-v'),
         dict(action='store_true', help='Visualize the images as they are processed')),
        (('--no-save', '-n'),
         dict(action='store_true', help='Do not save the output masks')),
        (('--mask-threshold', '-t'),
         dict(type=float, default=0.5,
              help='Minimum probability value to consider a mask pixel white')),
        (('--scale', '-s'),
         dict(type=float, default=0.5, help='Scale factor for the input images')),
        (('--bilinear',),
         dict(action='store_true', default=False, help='Use bilinear upsampling')),
        (('--classes', '-c'),
         dict(type=int, default=2, help='Number of classes')),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args(args=args)


def get_output_filenames(args):
    """Return the output file names for ``args.input``.

    Explicit ``args.output`` wins when it is non-empty; otherwise each input
    path gets its extension replaced by ``_OUT.png``.
    """
    if args.output:
        return args.output
    return [f'{os.path.splitext(path)[0]}_OUT.png' for path in args.input]


def mask_to_image(mask: np.ndarray, mask_values):
    """Turn a class-index mask into a PIL image using ``mask_values`` as palette.

    ``mask_values[i]`` is the pixel value written for class ``i``. A list as
    first value produces a multi-channel uint8 image, the palette ``[0, 1]``
    a boolean image and anything else a single-channel uint8 image. A 3-D
    ``mask`` is reduced with ``argmax`` over its first axis beforehand.
    """
    height, width = mask.shape[-2], mask.shape[-1]

    if isinstance(mask_values[0], list):
        canvas = np.zeros((height, width, len(mask_values[0])), dtype=np.uint8)
    else:
        pixel_type = bool if mask_values == [0, 1] else np.uint8
        canvas = np.zeros((height, width), dtype=pixel_type)

    labels = np.argmax(mask, axis=0) if mask.ndim == 3 else mask

    for class_index, value in enumerate(mask_values):
        canvas[labels == class_index] = value

    return Image.fromarray(canvas)


# Prediction entry point: load a trained model and write (and/or show) a mask
# for every input image.
if __name__ == '__main__':
    # Called without an explicit list, so the options come from the process arguments.
    # NOTE(review): inside a notebook these are the kernel's own arguments and
    # --input is required — presumably an explicit list has to be passed here; confirm.
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    in_files = args.input
    out_files = get_output_filenames(args)

    # NOTE(review): --classes defaults to 2 while the training cell builds the
    # network with n_classes=1 — verify the value matches the checkpoint.
    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    state_dict = torch.load(args.model, map_location=device)
    # 'mask_values' is an optional extra entry of the checkpoint; it is removed
    # before loading the weights and defaults to the binary palette [0, 1].
    mask_values = state_dict.pop('mask_values', [0, 1])
    net.load_state_dict(state_dict)

    logging.info('Model loaded!')

    for i, filename in enumerate(in_files):
        logging.info(f'Predicting image {filename} ...')
        img = Image.open(filename)

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        if not args.no_save:
            # Output names are matched to inputs by position.
            out_filename = out_files[i]
            result = mask_to_image(mask, mask_values)
            result.save(out_filename)
            logging.info(f'Mask saved to {out_filename}')

        if args.viz:
            logging.info(f'Visualizing results for image {filename}, close to continue...')
            plot_img_and_mask(img, mask)