<a href="https://colab.research.google.com/github/founderlin/PaperReview1/blob/master/x12_U_Net_Gan02_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import cv2
import glob

from torch.optim import optimizer
from torch.utils.data import Dataset
import random
import torch
import torch.nn as nn
from torch import optim, tensor
import torch.nn.functional as F
from torchvision import transforms
import numpy as np

Define the dataset that loads image/label pairs

In [None]:
class DataLoader(Dataset):
    """Paired image/mask dataset for segmentation training.

    Reads ``<data_path>/image/*.png``. For each image, the mask path is the
    same path relative to ``data_path`` with ``image`` replaced by ``label``
    (e.g. ``<data_path>/label/0.png``). Both are converted to grayscale and
    returned as ``(image, label)`` numpy arrays of shape ``(1, H, W)``. The
    image stays uint8. The label is divided by 255 (becoming float) when its
    maximum exceeds 1. A random flip is applied identically to both as
    augmentation.

    Note: despite its name this is a ``torch.utils.data.Dataset``; wrap it in
    ``torch.utils.data.DataLoader`` for batching.
    """

    def __init__(self, data_path):
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))

    def augment(self, image, flipCode):
        """Flip ``image`` over its last two (spatial) axes.

        Uses the ``cv2.flip`` flip-code convention:
        - ``0`` flips vertically;
        - a positive code flips horizontally;
        - a negative code flips both ways.

        Working on the trailing axes makes this correct for both ``(H, W)`` and
        channel-first ``(C, H, W)`` arrays. ``cv2.flip`` would treat a
        ``(1, H, W)`` array as a one-row image with ``W`` channels.

        Returns a C-contiguous copy: ``np.flip`` yields negative strides,
        which PyTorch cannot convert to a tensor when batching.
        """
        if flipCode == 0:
            axes = (-2,)
        elif flipCode > 0:
            axes = (-1,)
        else:
            axes = (-2, -1)
        return np.ascontiguousarray(np.flip(image, axis=axes))

    def __getitem__(self, index):
        image_path = self.imgs_path[index]

        # Swap 'image' for 'label' only below data_path, so a dataset root
        # whose own path contains "image" is not rewritten as well.
        rel_path = os.path.relpath(image_path, self.data_path)
        label_path = os.path.join(self.data_path, rel_path.replace('image', 'label'))

        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        if label is None:
            raise FileNotFoundError(f"could not read label: {label_path}")

        # cv2.imread returns BGR; collapse to a single grayscale channel.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)

        # Augment while the arrays are still 2-D (H, W): -1/0/1 pick a flip
        # direction, 2 means "no flip".
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)

        # Add the leading channel axis -> (1, H, W).
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])

        # Masks stored as 0/255 are rescaled to 0/1 for the BCE loss.
        if label.max() > 1:
            label = label / 255

        return image, label

    def __len__(self):
        """Number of images found under ``<data_path>/image``."""
        return len(self.imgs_path)

Define the U-Net building blocks

In [None]:
class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 conv -> BatchNorm -> ReLU.

    Padding of 1 keeps the spatial size unchanged. The first stage maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # The attribute name and layer order fix the state_dict keys
        # (double_conv.0 ... double_conv.5) that saved checkpoints rely on.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Downsample by 2 with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Both stages live in one Sequential named `maxpool_conv` so the
        # state_dict keys (maxpool_conv.1.double_conv...) stay the same.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Upsample ``x1``, concatenate it with the skip tensor ``x2``, then DoubleConv.

    Args:
        in_channels: channel count after concatenation, i.e. the upsampled
            ``x1`` channels plus the ``x2`` channels.
        out_channels: channels produced by the DoubleConv.
        bilinear: if True, upsample with parameter-free bilinear
            interpolation (channels unchanged). Otherwise use a learned
            ``ConvTranspose2d`` that expects ``in_channels // 2`` input channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Tensors are (N, C, H, W). After pooling an odd-sized map, the
        # upsampled x1 can be a row/column smaller than x2; pad it to match.
        # Plain Python ints: F.pad expects an int sequence, and wrapping the
        # amounts in tensors allocated new tensors on every forward call.
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)

        # Pad order is (left, right, top, bottom) for the last two dims.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to per-class output maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        scores = self.conv(x)
        return scores

Define the U-Net generator and the discriminator

In [None]:
class UNet(nn.Module):
    """Four-level U-Net encoder/decoder with skip connections.

    ``forward`` returns the output of a final 1x1 convolution with
    ``n_classes`` channels. No activation is applied, so the values are raw
    logits at the input's spatial size.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output channels.
        bilinear: passed to every ``Up`` block (bilinear vs. transposed-conv upsampling).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        # Encoder: 64 -> 128 -> 256 -> 512 -> 512 channels.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: each Up's first argument is (upsampled + skip) channels.
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder pass, keeping every level's output for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder pass: start at the bottleneck, consume skips deepest-first.
        y = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            y = up(y, skips.pop())
        return self.outc(y)


class Discriminator(nn.Module):
    """MLP discriminator applied along the last dimension of its input.

    ``nn.Linear`` acts on the trailing axis, so an input of shape
    ``(..., in_features)`` produces ``(..., 1)``. For a ``(B, 1, H, W)``
    image batch, each row of ``W`` pixels gets its own score. The final
    ``Sigmoid`` makes the outputs probabilities, not logits, so pair them with
    ``nn.BCELoss`` rather than ``nn.BCEWithLogitsLoss``.

    Args:
        in_features: size of the input's last dimension (the image width in
            this notebook). Defaults to 512, the previously hard-coded value.
    """

    def __init__(self, in_features=512):
        super().__init__()
        # Layer order (and hence state_dict keys model.0 ... model.15) is
        # unchanged from the original fixed-width version.
        self.model = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)


Train the model

In [None]:
def train_net(netG, netD, device, data_path, epochs=50, batch_size=4, lr=1e-5,
              save_path='/content/drive/MyDrive/Unet GAN model/best_model_Gan02.pth'):
    """Train the generator (supervised) and the discriminator side by side.

    For every batch:
      1. The discriminator is updated to score the input images as 1 and the
         generator's outputs as 0.
      2. The generator is updated with pixel-wise BCE against the masks.

    After each epoch the generator's ``state_dict`` is written to
    ``save_path`` if its mean loss over that epoch is the lowest so far.

    Args:
        netG: generator returning per-pixel logits (the UNet in this notebook).
        netD: discriminator ending in a Sigmoid (as ``Discriminator`` does),
            i.e. returning probabilities.
        device: torch device the networks already live on.
        data_path: directory containing ``image/*.png`` and matching ``label/`` files.
        epochs: number of passes over the data set.
        batch_size: mini-batch size.
        lr: learning rate for both RMSprop optimizers.
        save_path: file the best generator weights are written to.

    Raises:
        FileNotFoundError: if no training images are found.

    NOTE(review): the discriminator's output never enters the generator
    loss. It also compares raw input images with generator logits rather
    than real vs. predicted masks. This is therefore effectively plain
    supervised U-Net training; revisit if an adversarial objective is intended.
    """
    # Load the training set (the project Dataset class defined above).
    isbi_dataset = DataLoader(data_path)
    if len(isbi_dataset) == 0:
        raise FileNotFoundError(
            f"no training images found in {os.path.join(data_path, 'image')}")
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    optimizer_D = optim.RMSprop(netD.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    optimizer_G = optim.RMSprop(netG.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)

    # netD already ends in a Sigmoid, so it needs plain BCE; feeding its
    # probabilities into BCEWithLogitsLoss would apply the sigmoid twice.
    criterion_D = nn.BCELoss()
    # The generator outputs raw logits, so it keeps BCEWithLogitsLoss.
    criterion_G = nn.BCEWithLogitsLoss()

    best_loss = float('inf')
    for epoch in range(epochs):
        netG.train()
        netD.train()
        sum_loss_D = 0.0
        sum_loss_G = 0.0
        n_batches = 0

        for real_image, real_label in train_loader:
            real_image = real_image.to(device=device, dtype=torch.float32)
            real_label = real_label.to(device=device, dtype=torch.float32)
            # One generator forward pass serves both updates below.
            made_image = netG(real_image)

            # --- Discriminator step ---
            optimizer_D.zero_grad()
            # detach(): the discriminator loss must not backpropagate into
            # (and free the graph of) the generator.
            all_image = torch.cat((real_image, made_image.detach()))
            out_netD = netD(all_image)
            # Build targets from the actual output shape. A smaller final batch
            # or a different image size then no longer causes a size mismatch.
            n_real = real_image.size(0)
            all_label_D = torch.cat((torch.ones_like(out_netD[:n_real]),
                                     torch.zeros_like(out_netD[n_real:])))
            loss_D = criterion_D(out_netD, all_label_D)
            loss_D.backward()
            optimizer_D.step()

            # --- Generator step ---
            optimizer_G.zero_grad()
            loss_G = criterion_G(made_image, real_label)
            loss_G.backward()
            optimizer_G.step()

            # .item() stores plain floats instead of graph-holding tensors.
            sum_loss_D += loss_D.item()
            sum_loss_G += loss_G.item()
            n_batches += 1

        mean_loss_D = sum_loss_D / n_batches
        mean_loss_G = sum_loss_G / n_batches
        print(f"Epoch {epoch}: loss D {mean_loss_D:.4f}, loss G {mean_loss_G:.4f}")

        # Select the checkpoint on the whole-epoch mean rather than a single
        # (noisy) batch, and write it at most once per epoch.
        if mean_loss_G < best_loss:
            best_loss = mean_loss_G
            torch.save(netG.state_dict(), save_path)

In [None]:

# Pick the device: CUDA if available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Generator: U-Net with a single-channel input and one output class.
generator = UNet(n_channels=1, n_classes=1).to(device=device)
# A single .to() is enough; the original moved the module twice.
discriminator = Discriminator().to(device=device)
# Training-set location; start training.
data_path = "/content/drive/MyDrive/Colab Notebooks/dataA/train"
train_net(generator, discriminator, device, data_path)

Test: segment new images with the trained generator

In [None]:
    # Run the trained generator over every PNG in the test folder and write a
    # binary mask next to each input as "<name>_res.png".
    # Pick the device: CUDA if available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input, one output class.
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)
    # Load the weights saved by train_net.
    net.load_state_dict(torch.load('/content/drive/MyDrive/Unet GAN model/best_model_Gan02.pth', map_location=device))
    # Evaluation mode: BatchNorm uses its running statistics.
    net.eval()
    tests_path = glob.glob('/content/drive/MyDrive/Colab Notebooks/dataA/Gan02/*.png')
    for test_path in tests_path:
        # Skip masks written by a previous run of this cell; otherwise they
        # would be segmented again into "<name>_res_res.png".
        if test_path.endswith('_res.png'):
            continue
        # splitext strips only the extension, even if a directory name has a '.'.
        save_res_path = os.path.splitext(test_path)[0] + '_res.png'
        img = cv2.imread(test_path)
        # cv2.imread signals an unreadable file by returning None.
        if img is None:
            raise FileNotFoundError(f"could not read image: {test_path}")
        # cv2.imread returns BGR; use the same conversion as the training loader.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Reshape to (batch=1, channel=1, H, W).
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        img_tensor = torch.from_numpy(img).to(device=device, dtype=torch.float32)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            logits = net(img_tensor)
        # The network was trained with BCEWithLogitsLoss, so it outputs logits;
        # convert to probabilities before applying the 0.5 threshold.
        prob = torch.sigmoid(logits)[0, 0].cpu().numpy()
        pred = np.where(prob >= 0.5, 255, 0).astype(np.uint8)
        cv2.imwrite(save_res_path, pred)