In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from math import isnan

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Both convolutions use ``padding=1``, so the spatial size is unchanged and
    only the channel count moves:
    ``in_channels -> middle_channels -> out_channels``.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the second stage.
        middle_channels: Channel count between the two stages; when ``None``
            it falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, middle_channels=None):
        super().__init__()
        hidden = out_channels if middle_channels is None else middle_channels
        stages = [
            # Stage 1: in_channels -> hidden
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            # Stage 2: hidden -> out_channels
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The layers stay in a single Sequential called ``conv`` so parameter
        # names remain ``conv.0.*`` ... ``conv.4.*``.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv-norm-activation stages to ``x``."""
        return self.conv(x)


class Down(nn.Module):
    """Encoder stage: strided-convolution downsampling followed by a DoubleConv.

    The 4x4 convolution with stride 2 and padding 1 keeps ``in_channels`` and
    halves height and width (rounding down); the DoubleConv then maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count after the DoubleConv.
        middle_channels: Forwarded unchanged to ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels, middle_channels=None):
        super().__init__()
        shrink = nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)
        refine = DoubleConv(in_channels, out_channels, middle_channels)
        self.down = nn.Sequential(shrink, refine)

    def forward(self, x):
        """Downsample ``x`` and run it through the DoubleConv."""
        return self.down(x)


class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, DoubleConv.

    A 2x2 transposed convolution with stride 2 halves the channel count of the
    deeper feature map. The result is padded to the skip tensor's height and
    width, concatenated after the skip tensor along the channel axis, and
    passed through a DoubleConv that expects ``in_channels`` channels.

    Args:
        in_channels: Channel count of the deeper feature map (and of the
            concatenated tensor fed to the DoubleConv).
        out_channels: Channel count produced by the DoubleConv.
        middle_channels: Forwarded unchanged to ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels, middle_channels=None):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, middle_channels)

    def forward(self, x1, x2):
        """Merge the deeper feature map ``x1`` with the skip tensor ``x2``.

        ``x1`` comes from the downsampled (deeper) level, ``x2`` from the
        current level.
        """
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled map (dims 2 and 3).
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)

        # F.pad order is (left, right, top, bottom); any odd remainder goes to
        # the right / bottom side.
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)


class Out(nn.Module):
    """Output head: a single 1x1 convolution from ``in_channels`` to ``out_channels``.

    No activation is applied; the raw convolution output is returned.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.out = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels, keeping its spatial size."""
        projected = self.out(x)
        return projected


class UNet(nn.Module):
    """U-Net built from one input DoubleConv, four Down stages, four Up stages and an Out head.

    Encoder widths are 64 -> 128 -> 256 -> 512 -> 1024; the decoder walks the
    same widths back down, and each Up stage receives the encoder output of the
    matching level as its skip input.

    Args:
        in_channels: Channel count of the network input.
        out_channels: Channel count of the network output.
        middle_channels: Forwarded unchanged to every DoubleConv / Down / Up.
    """

    def __init__(self, in_channels, out_channels, middle_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        widths = (64, 128, 256, 512, 1024)

        # Submodules are registered in the order inc, down1..down4, up1..up4,
        # out, which fixes parameter order and state_dict key order.
        self.inc = DoubleConv(self.in_channels, widths[0], middle_channels)
        for level, (narrow, wide) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"down{level}", Down(narrow, wide, middle_channels))
        for level, (wide, narrow) in enumerate(zip(widths[::-1], widths[-2::-1]), start=1):
            setattr(self, f"up{level}", Up(wide, narrow, middle_channels))
        self.out = Out(widths[0], self.out_channels)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the output head."""
        # Encoder: keep every level's output so the decoder can reuse it.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        # Decoder: start at the deepest map and consume the skips deepest-first.
        x = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, features.pop())
        return self.out(x)

In [None]:
class PatchGANDiscriminator(nn.Module):
    """Convolutional patch discriminator producing a grid of per-patch probabilities.

    Four 4x4 stride-2 convolutions take the channel count from
    ``input_channels`` to ``num_filters * 8`` while halving the resolution each
    time (H/2, H/4, H/8, H/16). A final 4x4 stride-1 convolution with padding 1
    reduces to one channel and shrinks each spatial side by one more pixel, and
    a sigmoid turns every position into a value in (0, 1).

    Args:
        input_channels: Channel count of the images being judged.
        num_filters: Channel count after the first convolution; later
            convolutions use 2x, 4x and 8x this value.
    """

    def __init__(self, input_channels=3, num_filters=64):
        super().__init__()

        # First block: no batch norm, just conv + LeakyReLU.
        layers = [
            nn.Conv2d(input_channels, num_filters, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three further stride-2 blocks, each doubling the channel count.
        width = num_filters
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width = width * 2

        # Head: one score per patch, squashed to a probability.
        layers += [
            nn.Conv2d(width, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-patch probability map for ``x``."""
        return self.model(x)

In [None]:
def get_d_loss(real, fake):
    """Discriminator loss: BCE of ``real`` against ones plus BCE of ``fake`` against zeros.

    Args:
        real: Discriminator outputs for real samples, as probabilities in [0, 1].
        fake: Discriminator outputs for generated samples, as probabilities in [0, 1].

    Returns:
        Scalar tensor holding the sum of the two mean-reduced BCE terms.
    """
    loss_on_real = F.binary_cross_entropy(real, torch.ones_like(real))
    loss_on_fake = F.binary_cross_entropy(fake, torch.zeros_like(fake))
    return loss_on_real + loss_on_fake

def get_g_loss(fake):
    """Generator loss: BCE of the discriminator's outputs on generated samples against ones.

    Args:
        fake: Discriminator outputs for generated samples, as probabilities in [0, 1].

    Returns:
        Scalar tensor holding the mean-reduced BCE.
    """
    wanted = torch.ones_like(fake)
    return F.binary_cross_entropy(fake, wanted)

In [None]:
# Loss histories shared at notebook level; train() appends one averaged value
# per epoch to each list.
g_losses, d_losses = [], []

def train(epochs, train_loader, g_optimizer, d_optimizer, generator, discriminator):
    """Adversarially train ``generator`` against ``discriminator``.

    The generator is updated on every batch and the discriminator on every
    second batch. After each epoch one averaged value is appended to the
    notebook-level ``g_losses`` and ``d_losses`` lists: the generator loss is
    averaged over the batches of the epoch, the discriminator loss over the
    discriminator updates actually performed in that epoch.

    Args:
        epochs: Number of passes over ``train_loader``.
        train_loader: Sized iterable yielding ``(x, y)`` batches; ``x`` is fed
            to the generator and ``y`` to the discriminator as the real sample.
        g_optimizer: Optimizer stepping the generator's parameters.
        d_optimizer: Optimizer stepping the discriminator's parameters.
        generator: Module mapping ``x`` to a generated sample.
        discriminator: Module whose output is passed to ``get_d_loss`` /
            ``get_g_loss``.

    Returns:
        None. Results are reported via ``print`` and the two loss lists.
    """
    generator.train()
    discriminator.train()
    length = len(train_loader)
    for epoch in range(epochs):
        g_running_loss = 0.0
        d_running_loss = 0.0
        # The discriminator only steps on every second batch, so its running
        # loss must be averaged over its own step count, not the batch count.
        d_steps = 0
        for i, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            # Train the discriminator.
            if (i + 1) % 2 == 0:
                d_optimizer.zero_grad()
                d_real = discriminator(y)
                # Detach the generator OUTPUT: the discriminator update must
                # not backpropagate through the generator. Detaching the input
                # ``x`` would not cut that path.
                d_fake = discriminator(generator(x).detach())
                d_loss = get_d_loss(d_real, d_fake)
                d_loss.backward()
                d_running_loss += d_loss.item()
                d_steps += 1
                d_optimizer.step()

            # Train the generator.
            g_optimizer.zero_grad()
            g_loss = get_g_loss(discriminator(generator(x)))
            g_loss.backward()
            g_running_loss += g_loss.item()
            g_optimizer.step()

            if (i + 1) % 10 == 0:
                print(
                    f"Epoch : [{epoch + 1}/{epochs}]\tIter : [{i + 1}/{length}]\t"
                    f"Generator Loss: {g_running_loss / (i + 1):.3f}\t"
                    f"Discriminator Loss: {d_running_loss / max(d_steps, 1):.3f}",
                    '\n'
                )

        # max(..., 1) keeps an empty loader (or an epoch with no discriminator
        # step) from raising ZeroDivisionError; such an epoch reports 0.0.
        d_epoch_loss = d_running_loss / max(d_steps, 1)
        g_epoch_loss = g_running_loss / max(length, 1)
        d_losses.append(d_epoch_loss)
        g_losses.append(g_epoch_loss)

        print(
            f"Epoch : [{epoch + 1}/{epochs}]\t"
            f"Generator Loss: {g_epoch_loss:.3f}\t"
            f"Discriminator Loss: {d_epoch_loss:.3f}",
        )