In [None]:
# Required libraries

import os
import shutil

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

In [None]:
# Location of the test images. Point ROOT at the directory holding the dataset.

from pathlib import Path

ROOT = Path("dataset-here")  # dataset root (placeholder; edit before running)
test_image_path = Path(ROOT, "test/images/")

In [None]:
# Test-time dataset class.

class NiiasDatasetSampleSolution(Dataset):
    """Inference dataset over the image files listed in ``df['img_name']``.

    Args:
        df: DataFrame with an ``img_name`` column holding file names.
        folder_path: directory that contains those files.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=array)`` and expected to return a dict whose
            ``"image"`` entry is the transformed HWC array.

    Each item is a dict with:
        ``image``: float tensor in CHW layout, pixel values divided by 255;
        ``name``: the file name;
        ``img_h`` / ``img_w``: image size *before* the transform, so the
        prediction can be resized back to the original resolution.
    """

    def __init__(self, df, folder_path, transform=None):
        self.df = df
        self.folder_path = folder_path
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        file_name = self.df.iloc[index]['img_name']
        full_path = os.path.join(self.folder_path, file_name)
        pixels = np.array(Image.open(full_path).convert("RGB"))

        # Record the original size before any resize happens in the transform.
        original_h, original_w = pixels.shape[:2]

        if self.transform is not None:
            pixels = self.transform(image=pixels)["image"]

        # HWC -> CHW, then to float scaled by 1/255.
        tensor = torch.from_numpy(pixels.transpose(2, 0, 1)).float() / 255

        return {
            'image': tensor,
            'name': file_name,
            'img_h': original_h,
            'img_w': original_w,
        }

In [None]:
# Network input resolution; test images are resized to this size before inference.
GLOBAL_PARAMETERS = dict(
    IMAGE_HEIGHT=512,
    IMAGE_WIDTH=1024,
)

In [None]:
# Test-time preprocessing: only a resize to the network input resolution.
solution_transforms = A.Compose([
    A.Resize(
        height=GLOBAL_PARAMETERS['IMAGE_HEIGHT'],
        width=GLOBAL_PARAMETERS['IMAGE_WIDTH'],
    ),
])

In [None]:
# Table of test image file names (sorted for a stable order); it drives the dataset.
test_path = sorted(test_image_path.glob("*.png"))
solution_names = [path.name for path in test_path]
# Fail loudly: an empty list (e.g. ROOT still a placeholder) would otherwise
# produce an empty submission without any error.
assert solution_names, f"no *.png files found in {test_image_path}"

solution_df = pd.DataFrame(solution_names, columns=['img_name'])
solution_df.head()

In [None]:
# Test dataset and loader: one image per batch, kept in file order.

solution_ds = NiiasDatasetSampleSolution(
    df=solution_df,
    folder_path=test_image_path,
    transform=solution_transforms,
)

solution_loader = DataLoader(
    solution_ds,
    shuffle=False,  # preserve the order of solution_df
    batch_size=1,
    num_workers=1,
)

In [None]:
# Ensemble architectures: one U-Net per encoder, 4 output classes, no output activation.
# encoder_weights=None: ImageNet weights would only be downloaded and then discarded,
# because the model-loading cell takes every model's parameters from its checkpoint.
models = [
    smp.Unet(
        encoder_name,
        encoder_weights=None,
        classes=4,
        activation=None,
        encoder_depth=5,
        decoder_channels=[256, 128, 64, 32, 16],
    )
    for encoder_name in ('resnet34', 'resnext50_32x4d', 'efficientnet-b2')
]

In [None]:
# Checkpoint files, in the same order as the architectures in `models`.

checkpoints = [
    f'../models/unet-{backbone}.pt'
    for backbone in ('resnet34', 'resnext50', 'efficientnet-b2')
]


In [None]:
# Load every ensemble member onto the inference device and switch it to eval mode.

# Fall back to CPU so the notebook also runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

loaded_models = []

for architecture, checkpoint_path in zip(models, checkpoints):
    # NOTE(review): torch.load unpickles the file and can execute arbitrary code,
    # so only load trusted checkpoints. torch >= 2.6 defaults to weights_only=True,
    # which rejects pickled full-model objects -- confirm the intended torch version.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict):
        # Bare state_dict: load it into the matching architecture from `models`.
        architecture.load_state_dict(checkpoint)
        checkpoint = architecture
    # Otherwise the file holds the whole pickled model object.
    model = checkpoint.to(device)
    model.eval()  # must be called; a bare `model.eval` is a no-op attribute access
    loaded_models.append(model)

In [None]:
# Ensemble prediction: weighted sum of the models' raw outputs (presumably logits,
# given activation=None in the architecture definitions), per-pixel argmax, class
# indices remapped to submission label values, mask resized back to the original
# image size and saved as a PNG named after the input file.

OUTPUT_DIR = "sample_solution"
# Submission label for each predicted class index: 0->0, 1->6, 2->7, 3->10.
CLASS_TO_LABEL = np.array([0, 6, 7, 10], dtype=np.uint8)

assert loaded_models, "no models loaded"
# Equal weights (previously a hard-coded 0.33 per model). A common positive scale
# does not change the argmax, and the list always matches the number of models.
ensemble_weights = [1 / len(loaded_models)] * len(loaded_models)

# Start from an empty output directory (portable replacement for `rm -rf` + `mkdir`).
shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
os.makedirs(OUTPUT_DIR)

with torch.no_grad():
    for batch in tqdm(solution_loader):
        images = batch['image'].to(device)
        # Accumulate on the device; the shape follows the model output rather than
        # a hard-coded (1, 4, 512, 1024) buffer.
        ensemble_output = sum(
            weight * model.predict(images)
            for model, weight in zip(loaded_models, ensemble_weights)
        )
        class_maps = torch.argmax(ensemble_output, dim=1).cpu().numpy()

        # Walk the batch item by item so any DataLoader batch_size works.
        for class_map, name, img_h, img_w in zip(
            class_maps, batch['name'], batch['img_h'], batch['img_w']
        ):
            mask = Image.fromarray(CLASS_TO_LABEL[class_map])
            # img_h / img_w come out of the default collate as tensors; PIL needs ints.
            mask = mask.resize((int(img_w), int(img_h)), Image.NEAREST)
            mask.save(os.path.join(OUTPUT_DIR, name))

In [None]:
# Collect the prediction masks written to ./sample_solution/ and display their count.
mask_test_path = sorted(Path('./sample_solution/').glob("*.png"))
len(mask_test_path)

In [None]:
# Show one randomly chosen predicted mask.

import random
import matplotlib.pyplot as plt

assert mask_test_path, "no predicted masks found in ./sample_solution/"

random_path = str(random.choice(mask_test_path))
# Read as a single-channel label map. Scaling the uint8 labels for display
# (e.g. 10 * 30 = 300) would overflow and wrap to 44, so use a colormap instead.
random_mask = cv2.imread(random_path, cv2.IMREAD_GRAYSCALE)

fig, ax = plt.subplots(figsize=(15, 20))
# Fixed vmin/vmax keeps label colors consistent across images (labels are 0..10).
ax.imshow(random_mask, cmap='tab10', vmin=0, vmax=10, interpolation='nearest')
ax.set_title(os.path.basename(random_path))
ax.axis('off')
plt.show()