In [29]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from PIL import Image
import glob
from tqdm import tqdm

In [2]:
# class CFG:
#     model_name    = 'Unet'
#     backbone      = 'efficientnet-b7'
#     ckpt_path     = '/mnt/SSD/workspace/roads_buildings/src/exps/1700856032.515359/best_epoch.bin'
#     img_size      = [1024, 1024]
#     num_classes   = 1
#     device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
def build_model(backbone, num_classes, device):
    model = smp.Unet(
        encoder_name=backbone,      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights=None,     # use `imagenet` pre-trained weights for encoder initialization
        in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=num_classes,        # model output channels (number of classes in your dataset)
        activation=None,
    )
    model.to(device)
    return model

def load_model(backbone, num_classes, device, path):
    model = build_model(backbone, num_classes, device)
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

In [4]:
def preprocess_array_img(img):
    img = img.astype('float32') # original is uint16
    img = img / 255
    # img = A.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    mx = np.max(img)
    if mx:
        img/=mx # scale image to [0, 1]
    return img

In [27]:
def load_img(path, size=None):
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)[:, :, :3]
    init_shape = img.shape[:2]
    if size:
        img = cv2.resize(img, size)
    img = img.astype('float32') # original is uint16
    img = img / 255
    # img = A.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    mx = np.max(img)
    if mx:
        img/=mx # scale image to [0, 1]
    return img, init_shape

In [6]:
def predict_full(img_path):
    threshold = 0.5
#     img_path = '/mnt/SSD/workspace/roads_buildings/train/train/images/train_image_014.png'
    image, shape = load_img(img_path, (1024, 1024))
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    image = image.cuda()
    with torch.no_grad():
        pred = model(image)
        pred = (nn.Sigmoid()(pred)>=threshold).double()
        pred = pred.cpu().numpy().astype(np.uint8)[0][0]
        pred = cv2.resize(pred, (shape[1], shape[0]), cv2.INTER_NEAREST)
    return pred

In [7]:
def prediction_patched(model, image, patch_size, step):
#     image = torch.nn.functional.pad(image, (step, step, step, step))
    _, input_h, input_w = image.shape
    
    segm_img = torch.zeros((input_h, input_w), dtype=torch.float32)
    patch_num=1
    for i in range(0, input_h, step):   #Steps of 256
        for j in range(0,input_w, step):  #Steps of 256
            
            input_image = torch.zeros((3, patch_size, patch_size))
            
            single_patch = image[:, i:i+patch_size, j:j+patch_size]
#             single_patch_norm = np.expand_dims(normalize(np.array(single_patch), axis=1),2)
            single_patch_shape = single_patch.shape[-2:]
#             single_patch_input = np.expand_dims(single_patch, 0)
#             print(single_patch_input.shape)
            
    
            input_image[:, :single_patch_shape[0], :single_patch_shape[1]] = single_patch
            
            with torch.no_grad():
                single_patch_prediction = model(input_image.cuda().unsqueeze(0))
#                 single_patch_prediction = nn.Sigmoid()(single_patch_prediction)
                single_patch_prediction = single_patch_prediction.cpu().numpy()
            
            result_image = single_patch_prediction[:, :, :single_patch_shape[0], :single_patch_shape[1]]
            
#             segm_img[i:i+single_patch_shape[0], j:j+single_patch_shape[1]] += cv2.resize(single_patch_prediction, single_patch_shape[::-1])
#             single_patch_prediction = np.expand_dims(np.expand_dims(single_patch_prediction, 0), 0)
            segm_img[i:i+single_patch_shape[0], j:j+single_patch_shape[1]] += result_image
          
            patch_num+=1
#     return segm_img.numpy()[step:-step, step:-step]
    return segm_img

In [8]:
import ttach as tta

transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        tta.Rotate90(angles=[0, 90, 180, 270]),
    ]
)

def do_tta(image):
    masks = []
    for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 

        # augment image
        augmented_image = transformer.augment_image(image.cuda())

        # pass to model
        with torch.no_grad():
            model_output = model(augmented_image).detach().cpu()

        # reverse augmentation for mask and label
        deaug_mask = transformer.deaugment_mask(model_output)

        # save results
        masks.append(deaug_mask)

    masks = torch.stack(masks)[:, 0, 0].sum(0)
    return masks

def prediction_patched_tta(model, image, patch_size, step):
#     image = torch.nn.functional.pad(image, (step, step, step, step))
    _, input_h, input_w = image.shape
    
    segm_img = torch.zeros((input_h, input_w), dtype=torch.float32)
    patch_num=1
    for i in range(0, input_h, step):   #Steps of 256
        for j in range(0,input_w, step):  #Steps of 256
            
            input_image = torch.zeros((3, patch_size, patch_size))
            
            single_patch = image[:, i:i+patch_size, j:j+patch_size]
#             single_patch_norm = np.expand_dims(normalize(np.array(single_patch), axis=1),2)
            single_patch_shape = single_patch.shape[-2:]
#             single_patch_input = np.expand_dims(single_patch, 0)
#             print(single_patch_input.shape)
            
    
            input_image[:, :single_patch_shape[0], :single_patch_shape[1]] = single_patch
            single_patch_prediction = do_tta(input_image.unsqueeze(0))
#             with torch.no_grad():
#                 single_patch_prediction = model(input_image.cuda().unsqueeze(0))
            single_patch_prediction = single_patch_prediction.cpu().unsqueeze(0).unsqueeze(0).numpy()
            
            result_image = single_patch_prediction[:, :, :single_patch_shape[0], :single_patch_shape[1]]
            
#             segm_img[i:i+single_patch_shape[0], j:j+single_patch_shape[1]] += cv2.resize(single_patch_prediction, single_patch_shape[::-1])
#             single_patch_prediction = np.expand_dims(np.expand_dims(single_patch_prediction, 0), 0)
            segm_img[i:i+single_patch_shape[0], j:j+single_patch_shape[1]] += result_image
          
            patch_num+=1
#     return segm_img.numpy()[step:-step, step:-step]
    return segm_img

In [9]:
def predict_patched(img_path, step, use_tta=False):
    image = torch.from_numpy(load_img(img_path, None)[0]).permute(2, 0, 1)
    if use_tta:
        mask_patched = prediction_patched_tta(model, image, patch_size=1024, step=step)
    else:
        mask_patched = prediction_patched(model, image, patch_size=1024, step=step)
    return mask_patched


In [10]:
models = [
    {
        'path': 'exps/1700856032.515359',
        'backbone': 'efficientnet-b7',
        'comment': 'b7, imagenet',
        'score': 0.6813924908638
    },
    {
        'path': 'exps/1700845040.797431',
        'backbone': 'timm-efficientnet-b7',
        'comment': 'timm-b7, imagenet, opensource pretrained',
        'score': 0.6732556819915771
    },
    {
        'path': 'exps/1700823919.5695016',
        'backbone': 'timm-efficientnet-b7',
        'comment': 'timm-b7, noisy-student',
        'score': 0.6628191471099854
    },
    {
        'path': 'exps/1700884866.665121',
        'backbone': 'efficientnet-b7',
        'comment': 'efficientnet-b7, our dataset',
        'score': 0.6754
    },
]

In [38]:
models_path = '/mnt/SSD/workspace/roads_buildings/src/'
model_idx = -1
model = load_model(models[model_idx]['backbone'], 1, 'cuda', models_path+models[model_idx]['path']+'/best_epoch.bin')

In [39]:
# # использовать для инференса!!!
# img_path = '/mnt/SSD/workspace/roads_buildings/train/train/images/train_image_006.png'
# pred_patched = predict_patched(img_path, step=256, use_tta=False)
# pred = (nn.Sigmoid()(pred_patched)>0.5).numpy().astype(np.uint8)

In [None]:
for img_path in tqdm(glob.glob('/mnt/SSD/workspace/roads_buildings/test_dataset_test/images/*')):
    pred_patched = predict_patched(img_path, step=1024, use_tta=True)
    pred = (nn.Sigmoid()(pred_patched)>0.5).numpy().astype(np.uint8)
    mask_path = img_path.replace('image', 'mask')
    print(mask_path)
    Image.fromarray(pred).save(mask_path)

  0%|                                                                                                                                                                                                                         | 0/8 [00:00<?, ?it/s]

/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_006.png


 12%|██████████████████████████▏                                                                                                                                                                                      | 1/8 [00:53<06:16, 53.72s/it]

/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_001.png


 38%|██████████████████████████████████████████████████████████████████████████████▍                                                                                                                                  | 3/8 [03:01<04:47, 57.51s/it]

/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_003.png


 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                        | 4/8 [03:17<02:44, 41.21s/it]

/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_007.png


In [20]:
# img_path = '/mnt/SSD/workspace/roads_buildings/train/train/images/train_image_006.png'
# pred_patched = predict_patched(img_path, step=1024, use_tta=True)
# pred = (nn.Sigmoid()(pred_patched)>0.5).numpy().astype(np.uint8)

In [37]:
test_files = tqdm(glob.glob('/mnt/SSD/workspace/roads_buildings/test_dataset_test/images/*'))
for img_path in test_files:
    masks = []
    for  model_dict in models:
        print(model_dict['comment'], model_dict['score'])
        model = load_model(model_dict['backbone'], 1, 'cuda', models_path+model_dict['path']+'/best_epoch.bin')
        mask_path = img_path.replace('image', 'mask')
        pred_patched = predict_patched(img_path, step=256, use_tta=False)
        masks.append(pred_patched)
    
    pred = torch.stack(masks).numpy().sum(0)
    pred = (nn.Sigmoid()(torch.from_numpy(pred))>0.5).numpy().astype(np.uint8)
    mask_path = img_path.replace('image', 'mask')
    print(mask_path)
    Image.fromarray(pred).save(mask_path)

  0%|                                                                                                                                                                                                                         | 0/8 [00:00<?, ?it/s]

b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754
/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_006.png


 12%|██████████████████████████                                                                                                                                                                                      | 1/8 [03:19<23:16, 199.44s/it]

b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754
/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_001.png


 25%|████████████████████████████████████████████████████                                                                                                                                                            | 2/8 [08:24<26:09, 261.60s/it]

b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754


 38%|██████████████████████████████████████████████████████████████████████████████                                                                                                                                  | 3/8 [10:09<15:49, 189.86s/it]

/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_003.png
b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754


 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                        | 4/8 [11:00<09:00, 135.06s/it]

/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_007.png
b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754
/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_000.png


 62%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                              | 5/8 [17:29<11:20, 226.73s/it]

b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754


 75%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 6/8 [19:02<06:02, 181.39s/it]

/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_002.png
b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754
/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_005.png


 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                          | 7/8 [22:10<03:03, 183.31s/it]

b7, imagenet 0.6813924908638
timm-b7, imagenet, opensource pretrained 0.6732556819915771
timm-b7, noisy-student 0.6628191471099854
efficientnet-b7, our dataset 0.6754
/mnt/SSD/workspace/roads_buildings/test_dataset_test/masks/test_mask_004.png


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [24:31<00:00, 183.97s/it]
