In [17]:
%matplotlib inline
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import numpy as np
import json
import cv2 as cv
from tqdm import tqdm

In [10]:
# Number of segmentation classes; used as the channel count of the network's
# final 1x1 convolution and transposed convolution.
num_classes = 2
# Not referenced in the visible cells -- presumably the DataLoader batch size (TODO confirm).
batch_size = 32
# Not referenced in the visible cells -- presumably the (height, width) fed to the network (TODO confirm).
img_size = [512, 512]
# Not referenced in the visible cells -- presumably the optimizer learning rate (TODO confirm).
learning_rate = 0.0001
# Not referenced in the visible cells -- presumably the number of training epochs (TODO confirm).
num_epochs = 10
# Mask colour of each class, index-aligned with `classes`. label2mask paints
# polygons with these values and nlx_color2label decodes them back to indices.
# NOTE(review): label2mask starts from an all-zero mask and nlx_color2label from
# an all-zero table, so unannotated (0, 0, 0) pixels decode to class index 0,
# the same index as 'pool' -- confirm that this is intended.
colors = [(0, 255, 0), (255, 0, 0)]
# Label names looked up from the JSON annotations; the position in this list
# is the class index.
classes = ['pool', 'lack_of_fusion']

In [12]:
# Build a fully convolutional network on top of a pretrained ResNet-18.
# The last two children of the torchvision model are dropped (for ResNet-18
# these are the global average pooling and the fully connected classifier),
# leaving a backbone that outputs a 512-channel feature map.
pretrained_net = torchvision.models.resnet18(pretrained=True)
backbone_layers = list(pretrained_net.children())[:-2]
net = nn.Sequential(*backbone_layers)

# 1x1 convolution: map the 512 backbone channels to one channel per class.
net.add_module(
    'final_conv',
    nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=1),
)

# Transposed convolution with stride 32: upsamples the class scores by 32x
# (a 512x512 input comes back out as 512x512, see the shape check below).
net.add_module(
    'transpose_conv',
    nn.ConvTranspose2d(
        in_channels=num_classes,
        out_channels=num_classes,
        kernel_size=64,
        stride=32,
        padding=16,
    ),
)

def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build transposed-convolution weights that perform bilinear upsampling.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels. The filter is written
            at the paired positions ``(i, i)``, so this must equal
            ``in_channels`` for the indexed assignment to succeed.
        kernel_size (int): Side length of the square kernel.

    Returns:
        torch.Tensor: Tensor of shape
        ``(in_channels, out_channels, kernel_size, kernel_size)`` that holds
        the 2-D bilinear filter at ``[i, i]`` and zeros everywhere else.
    """
    factor = (kernel_size + 1) // 2
    # Odd kernels peak on a pixel, even kernels peak between two pixels.
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5

    positions = torch.arange(kernel_size)
    rows = positions.reshape(-1, 1)
    cols = positions.reshape(1, -1)

    # Separable triangular profile; broadcasting forms the 2-D outer product.
    row_profile = 1 - torch.abs(rows - center) / factor
    col_profile = 1 - torch.abs(cols - center) / factor
    filt = row_profile * col_profile

    weight = torch.zeros(
        (in_channels, out_channels, kernel_size, kernel_size))
    # Each channel is upsampled independently: only the diagonal is filled.
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight

# Stand-alone 2x bilinear upsampling layer for 3-channel images; it is not
# referenced again in the visible cells.
conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,
                                bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4));

# Initialise the 32x upsampling head of `net` with a bilinear kernel.
W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W);

# Shape sanity check on a random batch. Run it in eval mode and without
# gradients: in training mode the forward pass would update the pretrained
# BatchNorm running statistics with random noise and build an autograd graph.
X = torch.randn(size=(1, 3, 512, 512))
_was_training = net.training
net.eval()
with torch.no_grad():
    out_shape = net(X).shape
net.train(_was_training)  # restore whatever mode the network was in
out_shape



torch.Size([1, 2, 512, 512])

In [13]:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Display a list of images on a grid of subplots.

    Args:
        imgs (list): Images to display. A 3-D tensor is treated as
            (channels, height, width) and permuted to (height, width,
            channels); any other input is handed to ``imshow`` unchanged.
        num_rows (int): Number of rows
        num_cols (int): Number of columns
        titles (list, optional): List of titles, one per image. Defaults to None.
        scale (float, optional): Size of each subplot cell. Defaults to 1.5.

    Returns:
        numpy.ndarray: Flattened array of the ``num_rows * num_cols`` axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always returns a 2-D array of axes; without it a 1x1 grid
    # returns a bare Axes object and the flatten() below fails.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Detach / move to CPU so tensors with grad or on GPU can be drawn.
            img = img.detach().cpu()
            if img.ndim == 3:
                img = img.permute(1, 2, 0)
            img = img.numpy()
        ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

# Convert and Read Dataset

The dataset folder is expected to have the following structure:

```
Dataset Root
    ├── bmp: .bmp images. Original images
    ├── jsons: .json files. Labels
    ├── images (Optional): .jpg images. The .bmp images are converted to .jpg and saved here
    └── segm (Optional): .png images. Segmentation masks
```

In [25]:
def label2mask(json_file, root='mydata', show=False):
    """Render the polygons of one JSON annotation file into a colour mask.

    The mask is saved as ``<root>/segm/<name>.png`` where ``<name>`` is the
    part of the JSON file name before its first dot.

    Args:
        json_file (str): Path to the annotation file. It must provide
            ``imageHeight``, ``imageWidth`` and ``shapes``; every shape needs
            ``points`` and a ``label`` that is listed in ``classes``.
        root (str, optional): Dataset root directory. Defaults to 'mydata'.
        show (bool, optional): Also draw the mask with matplotlib.

    Raises:
        ValueError: If a shape label is not in ``classes``.
        OSError: If the mask cannot be written.
    """
    os.makedirs(os.path.join(root, 'segm'), exist_ok=True)
    # JSON is normally UTF-8; only fall back to the platform default encoding
    # when the bytes are not valid UTF-8 (instead of a bare ``except``).
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except UnicodeDecodeError:
        with open(json_file, 'r') as f:
            data = json.load(f)
    img = np.zeros(
        (data['imageHeight'], data['imageWidth'], 3), dtype=np.uint8)
    save_path = os.path.join(root, 'segm', os.path.basename(
        json_file).split('.')[0] + '.png')
    for label in data['shapes']:
        # Polygon vertices are truncated to integer pixel coordinates.
        points = np.array(label['points'], dtype=np.int32)
        cv.fillPoly(img, [points], colors[classes.index(label['label'])])
    # `img` holds the values of `colors` in RGB order (it is shown with
    # plt.imshow below and decoded as RGB by imgread / nlx_color2label), but
    # cv.imwrite interprets arrays as BGR. Convert first so the colours read
    # back from disk match `colors`.
    # NOTE(review): masks written by the previous version have red and blue
    # swapped and should be regenerated.
    if not cv.imwrite(save_path, cv.cvtColor(img, cv.COLOR_RGB2BGR)):
        raise OSError(f'cv.imwrite failed to write mask: {save_path}')
    if show:
        plt.imshow(img)
        plt.axis('off')


def convert_img(img_path, root='mydata', show=False):
    """Re-encode one image as ``<root>/images/<name>.jpg``.

    ``<name>`` is the part of the source file name before its first dot.

    Args:
        img_path (str): Path of the source image (e.g. a .bmp file).
        root (str, optional): Dataset root directory. Defaults to 'mydata'.
        show (bool, optional): Also draw the image with matplotlib.

    Raises:
        FileNotFoundError: If ``img_path`` does not exist.
        OSError: If the image cannot be decoded or the .jpg cannot be written.
    """
    os.makedirs(os.path.join(root, 'images'), exist_ok=True)
    img = cv.imread(img_path)
    # cv.imread signals failure by returning None instead of raising.
    if img is None:
        if not os.path.exists(img_path):
            raise FileNotFoundError(f'image not found: {img_path}')
        raise OSError(f'cv.imread could not decode image: {img_path}')
    save_path = os.path.join(root, 'images', os.path.basename(img_path).split('.')[0] + '.jpg')
    # cv.imwrite signals failure by returning False instead of raising.
    if not cv.imwrite(save_path, img):
        raise OSError(f'cv.imwrite failed to write image: {save_path}')
    if show:
        # OpenCV decodes to BGR while matplotlib expects RGB.
        plt.imshow(cv.cvtColor(img, cv.COLOR_BGR2RGB))
        plt.axis('off')

In [65]:
def nlx_rand_crop(feature, label, height=360, width=360):
    """Randomly crop the feature image and the label image with one window.

    The crop window is sampled once from ``feature`` and then applied to both
    inputs, so the two crops stay pixel-aligned.

    Args:
        feature: Input image.
        label: Label image to crop with the same window.
        height (int, optional): Crop height. Defaults to 360.
        width (int, optional): Crop width. Defaults to 360.

    Returns:
        tuple: ``(cropped_feature, cropped_label)``.
    """
    crop = torchvision.transforms.functional.crop
    top, left, crop_h, crop_w = torchvision.transforms.RandomCrop.get_params(
        feature, (height, width))
    cropped_feature = crop(feature, top, left, crop_h, crop_w)
    cropped_label = crop(label, top, left, crop_h, crop_w)
    return cropped_feature, cropped_label


def nlx_color2label():
    """Build a lookup table from a packed colour value to a class index.

    A colour ``(c0, c1, c2)`` is packed as ``(c0 * 256 + c1) * 256 + c2``.
    Every entry of ``colors`` is mapped to its position in that list.

    NOTE(review): the table is zero-initialised, so any colour that is not in
    ``colors`` (including (0, 0, 0)) also maps to class index 0 -- confirm
    that this is intended.

    Returns:
        torch.Tensor: 1-D ``torch.long`` tensor with ``256 ** 3`` entries.
    """
    lookup = torch.zeros(256 ** 3, dtype=torch.long)
    for class_idx, color in enumerate(colors):
        c0, c1, c2 = color[0], color[1], color[2]
        packed = (c0 * 256 + c1) * 256 + c2
        lookup[packed] = class_idx
    return lookup


def nlx_label_indices(colormap, colormap2label):
    colormap = colormap.permute(1, 2, 0).numpy().astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256 +
           colormap[:, :, 2])
    return colormap2label[idx]


def imgread(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path (str): Path of the image file.

    Returns:
        numpy.ndarray: The decoded image converted from BGR to RGB.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        OSError: If the file exists but cannot be decoded.
    """
    img = cv.imread(path)
    # cv.imread returns None on failure; without this check the error would
    # only surface as an opaque assertion inside cv.cvtColor.
    if img is None:
        if not os.path.exists(path):
            raise FileNotFoundError(f'image not found: {path}')
        raise OSError(f'cv.imread could not decode image: {path}')
    return cv.cvtColor(img, cv.COLOR_BGR2RGB)


def imgwrite(img, path):
    """Write an RGB image to ``path`` with OpenCV.

    Args:
        img (numpy.ndarray): Image in RGB channel order.
        path (str): Destination file; the extension selects the format.

    Raises:
        OSError: If OpenCV reports that the file was not written (for
            example because the target directory does not exist).
    """
    bgr = cv.cvtColor(img, cv.COLOR_RGB2BGR)  # OpenCV writes BGR data
    # cv.imwrite reports failure through its return value, not an exception.
    if not cv.imwrite(path, bgr):
        raise OSError(f'cv.imwrite failed to write image: {path}')


def list_imgs(root='mydata', ratio=(0.6, 0.2, 0.2)):
    """Split the image names under ``<root>/images`` into train/val/test.

    Names are the part of each file name before its first dot. They are
    sorted before splitting because ``os.listdir`` returns entries in
    arbitrary order, which would otherwise make the split differ between
    runs and machines.

    Args:
        root (str, optional): Dataset root directory. Defaults to 'mydata'.
        ratio (sequence of float, optional): Fractions for train and val
            (``ratio[0]``, ``ratio[1]``); the test split receives whatever
            remains. Defaults to (0.6, 0.2, 0.2).

    Returns:
        tuple: ``(train_imgs, val_imgs, test_imgs)``, three lists of names.
    """
    all_imgs = sorted(
        x.split('.')[0]
        for x in os.listdir(os.path.join(root, 'images'))
        # Hidden entries such as .ipynb_checkpoints or .DS_Store are not
        # images and would yield an empty name.
        if not x.startswith('.')
    )
    total = len(all_imgs)
    train_end = int(ratio[0] * total)
    val_end = int((ratio[0] + ratio[1]) * total)
    train_imgs = all_imgs[:train_end]
    val_imgs = all_imgs[train_end:val_end]
    test_imgs = all_imgs[val_end:]
    return train_imgs, val_imgs, test_imgs

def read_imgs(root='mydata'):
    """Placeholder for the image-loading step of the dataset pipeline.

    The original definition had neither a colon nor a body, which is a
    SyntaxError and prevents every other function in this cell from being
    defined. It is kept as an explicit stub until it is implemented.

    Args:
        root (str, optional): Dataset root directory. Defaults to 'mydata'.

    Raises:
        NotImplementedError: Always.
    """
    # TODO: implement -- presumably loads the images (and masks) found under
    # `root`; confirm the intended return value before filling this in.
    raise NotImplementedError('read_imgs is not implemented yet')


In [None]:
class NLX2024Dataset(torch.utils.data.Dataset):

    def __init__(self, root, crop_size):
        """Set up the normaliser, the split name lists and the colour table.

        Args:
            root: Dataset root directory, passed on to ``list_imgs``.
            crop_size: Stored as ``self.crop_size``; not used in this method.
        """
        # Per-channel mean/std; presumably the ImageNet statistics that match
        # the pretrained backbone (TODO confirm).
        normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform = normalize
        self.crop_size = crop_size

        # Only the image names of each split are kept here.
        self.train_imgs, self.val_imgs, self.test_imgs = list_imgs(root)
        self.color2label = nlx_color2label()

        total = sum(
            len(split)
            for split in (self.train_imgs, self.val_imgs, self.test_imgs))
        print('read ' + str(total) + ' images')
        
        