In [None]:
import cv2 as cv
import numpy as np
import json
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


In [None]:
# Per-class mask colors, aligned with CLASSES by index: pool -> green,
# lack_of_fusion -> red. label2mask draws with OpenCV, which uses BGR order;
# read_images loads masks back in RGB mode. The RGB table is therefore just
# the BGR table with each triple reversed.
CLASSES = ['pool', 'lack_of_fusion']
COLORS_BGR = [(0, 255, 0), (0, 0, 255)]
COLORS_RGB = [tuple(reversed(bgr)) for bgr in COLORS_BGR]

In [None]:
def _load_json(json_file):
    """Load a JSON file, preferring UTF-8 (with or without BOM).

    Falls back to the platform's default encoding only when the bytes are not
    valid UTF-8; any other error (missing file, malformed JSON) propagates.
    """
    try:
        with open(json_file, 'r', encoding='utf-8-sig') as f:
            return json.load(f)
    except UnicodeDecodeError:
        with open(json_file, 'r') as f:
            return json.load(f)


def label2mask(json_file, root='mydata', show=False):
    """Rasterize a JSON polygon annotation into a color mask image.

    The JSON must provide ``imageHeight``, ``imageWidth`` and ``shapes`` (each
    with ``label`` and ``points``). It is presumably produced by labelme, but
    only these keys are read. Each polygon is filled with the BGR color of its
    class (``COLORS_BGR[CLASSES.index(label)]``) on a black canvas.

    The mask is saved as ``{root}/seg/{stem}.jpg``, where ``stem`` is the JSON
    file name without its final extension.

    NOTE(review): JPEG is lossy, so pixels can drift away from the exact class
    colors, particularly at polygon edges. Downstream code maps colors to
    classes by exact match. A lossless format (PNG) would be safer, but
    ``read_images`` currently expects the mask to have the same file name as
    the ``.jpg`` image.

    Args:
        json_file (str): Path to the JSON annotation file.
        root (str, optional): Dataset root; masks are written to
            ``{root}/seg``. Defaults to 'mydata'.
        show (bool, optional): Also display the mask with matplotlib.
            Defaults to False.

    Raises:
        ValueError: If a shape's label is not listed in ``CLASSES``.
    """
    seg_dir = os.path.join(root, 'seg')
    os.makedirs(seg_dir, exist_ok=True)
    data = _load_json(json_file)
    mask = np.zeros((data['imageHeight'], data['imageWidth'], 3),
                    dtype=np.uint8)
    for shape in data['shapes']:
        name = shape['label']
        if name not in CLASSES:
            raise ValueError(
                f'Unknown label {name!r} in {json_file}; '
                f'expected one of {CLASSES}')
        points = np.array(shape['points'], dtype=np.int32)
        cv.fillPoly(mask, [points], COLORS_BGR[CLASSES.index(name)])
    # splitext keeps dotted stems intact ("a.1.json" -> "a.1"), unlike split('.').
    stem = os.path.splitext(os.path.basename(json_file))[0]
    # Maximum JPEG quality reduces (but does not eliminate) compression
    # artifacts that would corrupt the exact class colors.
    cv.imwrite(os.path.join(seg_dir, stem + '.jpg'), mask,
               [cv.IMWRITE_JPEG_QUALITY, 100])
    if show:
        # OpenCV arrays are BGR; matplotlib interprets channels as RGB.
        plt.imshow(mask[:, :, ::-1])
        plt.show()


def convert_img(root='mydata'):
    """Convert every image in ``{root}/bmp`` to JPEG in ``{root}/Images``.

    Each output keeps the source file name with its final extension replaced
    by ``.jpg``. This naming matches the mask names produced by ``label2mask``.
    Files OpenCV cannot decode are skipped with a message instead of aborting
    the whole conversion.

    Args:
        root (str, optional): Dataset root folder. Defaults to 'mydata'.
    """
    src_dir = os.path.join(root, 'bmp')
    dst_dir = os.path.join(root, 'Images')
    os.makedirs(dst_dir, exist_ok=True)
    for img_file in tqdm(sorted(os.listdir(src_dir))):
        img = cv.imread(os.path.join(src_dir, img_file))
        if img is None:
            # cv.imread signals unreadable/non-image files by returning None.
            tqdm.write(f'Skipping unreadable file: {img_file}')
            continue
        stem = os.path.splitext(img_file)[0]
        cv.imwrite(os.path.join(dst_dir, stem + '.jpg'), img)


# for json_file in tqdm(os.listdir('mydata/Json')):
#     label2mask(os.path.join('mydata/Json', json_file))
# convert_img()

In [None]:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Display images in a ``num_rows x num_cols`` grid.

    Args:
        imgs (list): Channel-first images (objects with ``.permute``, e.g.
            torch tensors shaped (C, H, W)). Each is shown as (H, W, C).
        num_rows (int): Number of grid rows.
        num_cols (int): Number of grid columns.
        titles (list, optional): Per-image titles. If the list is shorter
            than ``imgs``, the remaining images get no title. Defaults to None.
        scale (float, optional): Figure inches per grid cell. Defaults to 1.5.

    Returns:
        numpy.ndarray: Flat array of all ``num_rows * num_cols`` axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid
    # returns a bare Axes, which has no .flatten().
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        ax.imshow(img.permute(1, 2, 0))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles and i < len(titles):
            ax.set_title(titles[i])
    return axes

In [None]:
import torch
import torchvision
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from torchvision import transforms

In [None]:
# Use the default CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def read_images(root='mydata'):
    """Read all image/mask pairs from ``root``.

    The root folder should look like this (only ``Images`` and ``seg`` are
    read here):

        root
        ├── Images: images .jpg/.png
        ├── Json: json files
        ├── seg: mask files .jpg/.png, same file names as in Images
        ├── Annotations: annotations files (Needed to create seg/)
        ├── labels: labels files (optional)
        ├── classes.txt: classes files

    For every file name in ``Images``, the mask with the same name is read from
    ``seg``. Names are processed in sorted order, so the result is reproducible
    (``os.listdir`` order is arbitrary).

    Args:
        root (str, optional): Dataset root folder. Defaults to 'mydata'.

    Returns:
        tuple[list, list]: ``(images, labels)``, tensors read with
        ``torchvision.io.read_image`` in RGB mode and aligned by index.
    """
    mode = torchvision.io.image.ImageReadMode.RGB
    names = sorted(os.listdir(os.path.join(root, 'Images')))
    img_paths = [os.path.join(root, 'Images', name) for name in names]
    label_paths = [os.path.join(root, 'seg', name) for name in names]
    print('Total images:', len(img_paths))
    # Report masks that actually exist. A missing mask still fails when read
    # below, but the count now shows the mismatch up front.
    print('Total labels:', sum(os.path.isfile(p) for p in label_paths))
    images, labels = [], []
    for img_path, label_path in zip(img_paths, label_paths):
        images.append(torchvision.io.read_image(img_path, mode))
        labels.append(torchvision.io.read_image(label_path, mode))
    return images, labels


def voc_colormap2label(colors=None, default=0):
    """Build a lookup table from packed RGB values to class indices.

    An RGB triple (r, g, b) is packed as ``(r * 256 + g) * 256 + b``, and
    ``table[packed]`` is that color's index in ``colors``.

    NOTE(review): any color not in ``colors`` maps to ``default``. This
    includes the black background of the masks drawn by ``label2mask`` and any
    JPEG-shifted pixels. With the default of 0, those pixels are
    indistinguishable from class 0 ('pool'). Pass e.g.
    ``default=len(colors)`` for a separate background class, or ``-1``
    together with ``CrossEntropyLoss(ignore_index=-1)``.

    Args:
        colors (list[tuple[int, int, int]], optional): RGB color per class
            index. ``None`` means ``COLORS_RGB``.
        default (int, optional): Index assigned to every other color.
            Defaults to 0 (the original behavior).

    Returns:
        torch.Tensor: int64 tensor of length ``256 ** 3``.
    """
    if colors is None:
        colors = COLORS_RGB
    colormap2label = torch.full((256 ** 3,), default, dtype=torch.long)
    for i, (r, g, b) in enumerate(colors):
        colormap2label[(r * 256 + g) * 256 + b] = i
    return colormap2label


def voc_label_indices(colormap, colormap2label):
    """Map an RGB label image to per-pixel class indices.

    Args:
        colormap (torch.Tensor): Channel-first label image; channels 0, 1, 2
            are read as R, G, B.
        colormap2label (torch.Tensor): Lookup table indexed by the packed
            value ``(r * 256 + g) * 256 + b``.

    Returns:
        torch.Tensor: (H, W) tensor of entries looked up in ``colormap2label``.
    """
    # Widen first: uint8 arithmetic would overflow when multiplying by 256.
    rgb = colormap.permute(1, 2, 0).to(torch.int64)
    red, green, blue = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    packed = (red * 256 + green) * 256 + blue
    return colormap2label[packed]


def voc_rand_crop(feature, label, height, width):
    """Randomly crop an image and its mask with the same window.

    The window is sampled from ``feature``'s size only. ``label`` is presumably
    the same spatial size; this is not checked here.

    Args:
        feature (torch.Tensor): Image tensor (..., H, W).
        label (torch.Tensor): Mask tensor cropped with the same coordinates.
        height (int): Crop height.
        width (int): Crop width.

    Returns:
        tuple: ``(cropped_feature, cropped_label)``.
    """
    top, left, crop_h, crop_w = torchvision.transforms.RandomCrop.get_params(
        feature, (height, width))
    crop = torchvision.transforms.functional.crop
    return (crop(feature, top, left, crop_h, crop_w),
            crop(label, top, left, crop_h, crop_w))

In [None]:
class VOCSegDataset(torch.utils.data.Dataset):
    """Segmentation dataset over ``{voc_dir}/Images`` and ``{voc_dir}/seg``.

    Images are scaled to [0, 1] and normalized with the ImageNet mean/std.
    Color masks are converted to class indices on access. Every
    ``__getitem__`` call draws a fresh random crop of size ``crop_size``.
    """

    def __init__(self, crop_size, voc_dir):
        """
        Args:
            crop_size (tuple[int, int]): (height, width) of the random crops.
                Image/mask pairs smaller than this are dropped.
            voc_dir (str): Dataset root passed to ``read_images``.
        """
        self.transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.crop_size = crop_size
        features, labels = read_images(voc_dir)
        # Filter image/mask pairs together. Filtering each list separately
        # could drop an item from one list but not the other, misaligning
        # every following index.
        kept_features, kept_labels = self._filter_pairs(features, labels)
        self.features = [self.normalize_image(feature)
                         for feature in kept_features]
        self.labels = kept_labels
        self.colormap2label = voc_colormap2label()
        print('read ' + str(len(self.features)) + ' examples')

    def normalize_image(self, img):
        """Scale an 8-bit image to [0, 1] and apply ``self.transform``."""
        return self.transform(img.float() / 255)

    def _is_large_enough(self, img):
        # Channel-first tensor: dim 1 is height, dim 2 is width.
        return (img.shape[1] >= self.crop_size[0] and
                img.shape[2] >= self.crop_size[1])

    def filter(self, imgs):
        """Keep only images at least ``crop_size`` in both dimensions."""
        return [img for img in imgs if self._is_large_enough(img)]

    def _filter_pairs(self, features, labels):
        # Keep a pair only if both the image and its mask can be cropped.
        kept = [(feature, label) for feature, label in zip(features, labels)
                if self._is_large_enough(feature)
                and self._is_large_enough(label)]
        return [f for f, _ in kept], [l for _, l in kept]

    def __getitem__(self, idx):
        """Return ``(normalized_image_crop, class_index_mask_crop)``."""
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)
        return (feature, voc_label_indices(label, self.colormap2label))

    def __len__(self):
        return len(self.features)

In [None]:
# Model

# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor
# of `weights=`; kept here for compatibility with the installed version.
pretrained_net = torchvision.models.resnet18(pretrained=True)
# Drop the last two children (global pooling and the fully-connected head).
# This keeps the convolutional feature extractor, which outputs 512 channels.
net = nn.Sequential(*list(pretrained_net.children())[:-2])

# One output channel per class index produced by voc_colormap2label().
# The previous value, 21, was Pascal VOC's class count and did not match
# CLASSES. NOTE(review): if background gets its own index, add 1 here.
num_classes = len(CLASSES)
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
# kernel 64 / stride 32 / padding 16 gives output size 32 * in_size, undoing
# the backbone's 32x downsampling for inputs whose sides are multiples of 32.
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,
                                                    kernel_size=64, padding=16, stride=32))


def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build transposed-convolution weights that perform bilinear upsampling.

    Channel ``c`` maps to channel ``c`` with a separable triangular (bilinear)
    filter; all cross-channel weights are zero. The diagonal assignment pairs
    ``range(in_channels)`` with ``range(out_channels)``, so the two counts must
    be equal.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Spatial size of the square kernel.

    Returns:
        torch.Tensor: Weights shaped
        (in_channels, out_channels, kernel_size, kernel_size).
    """
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    taps = torch.arange(kernel_size)
    # 1-D triangular profile; the 2-D filter is its outer product.
    profile = 1 - torch.abs(taps - center) / factor
    filt = profile.reshape(-1, 1) * profile.reshape(1, -1)
    weight = torch.zeros((in_channels, out_channels, kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels)] = filt
    return weight


# Standalone demo layer: a 3-channel, stride-2 transposed conv initialized as
# bilinear upsampling. It is not part of `net`.
conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,
                                bias=False)
with torch.no_grad():
    conv_trans.weight.copy_(bilinear_kernel(3, 3, 4))

# Initialize the FCN's upsampling layer as bilinear interpolation.
W = bilinear_kernel(num_classes, num_classes, 64)
with torch.no_grad():
    net.transpose_conv.weight.copy_(W)

net = net.to(device)

In [None]:
# Resize to 512x512, then convert to a float tensor.
trans = transforms.Compose([
    transforms.Resize(size=(512, 512)),
    transforms.ToTensor(),
])

In [None]:
# Random 512x512 crops from ./mydata; images smaller than that are dropped.
dataset = VOCSegDataset(crop_size=(512, 512), voc_dir='mydata')