### 单图片生成

In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import copy

In [5]:
# Training hyper-parameters: number of passes over the training set and
# the mini-batch size used by the training DataLoader.
EPOCHS, BATCH_SIZE = 50, 8

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
class saveFeatures():
    """Record the most recent output of a module via a forward hook.

    After every forward pass of the hooked module, its output is stored in
    ``self.features``. Until the module has run once, ``features`` is the
    class-level default ``None``.
    """

    # Latest output of the hooked module (None until the first forward pass).
    features = None

    def __init__(self, m):
        # Keep the handle so the hook can be detached later with remove().
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: remember the module's output."""
        self.features = output

    def remove(self):
        """Detach the hook from the module; ``features`` keeps its last value."""
        self.hook.remove()

In [7]:
class unetUpSampleBlock(nn.Module):
    """
    One up-sampling stage of the U-Net decoder (the right-hand side).

    ``self.tranConv`` up-samples the previous decoder output with a transposed
    convolution (kernel 2, stride 2), doubling its spatial size.
    ``self.conv`` applies a 1x1 convolution to the encoder feature map to
    reduce its channel count; the spatial size is unchanged.
    The two results are concatenated along the channel axis, so they must have
    the same spatial size. The original author notes this relies on the image
    size (128x128 in their setup); otherwise up-sampling is not a clean x2.

    Args:
        in_channels: channels of the tensor to up-sample.
        feature_channels: channels of the encoder (skip) feature map.
        out_channels: channels produced by EACH branch; the block's output
            therefore has ``2 * out_channels`` channels.
        dp: if True, apply dropout to the block output.
        ps: dropout probability (only used when ``dp`` is True).
    """
    def __init__(self,in_channels,feature_channels,out_channels,dp=False,ps=0.25):
        super().__init__()
        # Transposed conv: output spatial size is exactly twice the input's.
        self.tranConv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2, bias=False)
        # 1x1 conv squeezing the skip feature map down to out_channels.
        self.conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1, bias=False)
        # Normalises the concatenation of both branches.
        self.bn = nn.BatchNorm2d(2 * out_channels)
        self.dp = dp
        if dp:
            self.dropout = nn.Dropout(p=ps, inplace=True)

    def forward(self,x,features):
        """Up-sample ``x``, fuse it with ``features`` and normalise the result."""
        upsampled = self.tranConv(x)
        skip = self.conv(features)
        merged = torch.cat((upsampled, skip), dim=1)
        normed = self.bn(F.relu(merged))
        if self.dp:
            return self.dropout(normed)
        return normed

In [8]:
class Generator(nn.Module):
    """U-Net style image-to-image generator with a ResNet encoder.

    The encoder is ``list(model.children())[4:-2]``; the channel sizes wired
    into the decoder (256/512/1024/2048) match the four residual stages of a
    torchvision ResNet-50. The decoder is four ``unetUpSampleBlock`` stages,
    each doubling the spatial size, followed by a 3x3 output convolution.

    The output keeps the input's spatial size only if the encoder reduces it
    by a total factor of 16, i.e. every one of the four stages has stride 2
    and the input height/width are multiples of 16. The caller is expected to
    pass a model whose first stage has been changed to stride 2.

    Args:
        model: backbone network providing the encoder stages. The stages are
            used by reference (not copied), so two Generators built from the
            same ``model`` object share encoder weights.
        in_channels: channels of the input image.
        out_channels: channels of the generated image.

    NOTE(review): skip connections are read from forward hooks that store
    their value on plain Python objects; this looks unsafe when the module is
    replicated across several GPUs by nn.DataParallel -- verify before
    training on more than one GPU.
    """
    def __init__(self,model,in_channels,out_channels):
        super(Generator,self).__init__()
        # Full-resolution stem; its output is also the skip input of up4.
        self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels,64,kernel_size=3,stride=1,padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace = True)
        )
        # Encoder: the backbone's children from index 4 up to (excluding) the last two.
        self.downsample = nn.Sequential(*list(model.children())[4:-2])
        # Hooks capturing the outputs of the first three encoder stages (skip connections).
        self.features = [saveFeatures(list(self.downsample.children())[i]) for i in range(3)]
        self.up1 = unetUpSampleBlock(2048,1024,512) # skip: self.features[2]
        self.up2 = unetUpSampleBlock(1024,512,256)  # skip: self.features[1]
        self.up3 = unetUpSampleBlock(512,256,128)   # skip: self.features[0]
        self.up4 = unetUpSampleBlock(256,64,32)     # skip: output of self.layer1
        # Maps the 64 decoder channels to the requested number of image channels.
        self.outlayer = nn.Conv2d(64,out_channels,3,1,1)

    def forward(self,i):
        """Translate image batch ``i`` into a batch with ``out_channels`` channels."""
        x1 = self.layer1(i)
        # Running the encoder also fills self.features[k].features via the hooks.
        x = self.downsample(x1)
        x = self.up1(x,self.features[2].features)
        x = self.up2(x,self.features[1].features)
        x = self.up3(x,self.features[0].features)
        x = self.up4(x,x1)
        # Fix: the output convolution was defined but never applied, so the
        # network returned 64 channels instead of out_channels.
        return self.outlayer(x)

In [9]:
# ImageNet-pretrained ResNet-50 used as the encoder backbone of the generators.
m = models.resnet50(pretrained=True)

# Rebuild the two convolutions of the first bottleneck of layer1 with
# stride 2 (so that this stage also halves the resolution), re-using the
# pretrained weights: the kernel shapes are unchanged, only the stride differs.

# 1) The 1x1 projection on the shortcut branch.
tem_paras = copy.deepcopy(m.layer1[0].downsample[0].state_dict())
strided_conv = nn.Conv2d(64, 256, kernel_size=1, stride=2, bias=False)
strided_conv.load_state_dict(tem_paras)
m.layer1[0].downsample[0] = strided_conv

# 2) The 3x3 convolution on the main branch.
tem_paras = copy.deepcopy(m.layer1[0].conv2.state_dict())
strided_conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
strided_conv.load_state_dict(tem_paras)
m.layer1[0].conv2 = strided_conv

# Drop the temporaries so they do not linger in the notebook namespace.
del tem_paras, strided_conv

In [10]:
# Generator takes the encoder stages from the model it is given by reference.
# Passing the same `m` to both generators made them share encoder weights and
# forward hooks, so give each generator its own copy of the backbone.
genernator_VIS2NIR = Generator(copy.deepcopy(m),3,1)  # 3-channel VIS -> 1-channel NIR
genernator_NIR2VIS = Generator(copy.deepcopy(m),1,3)  # 1-channel NIR -> 3-channel VIS
# One ImageNet-pretrained ResNet-34 per image domain, used as discriminators.
discriminator_A_NIR = models.resnet34(pretrained=True)
discriminator_B_VIS = models.resnet34(pretrained=True)

In [11]:
# Adapt both pretrained ResNet-34 networks for use as discriminators.
for _disc in (discriminator_A_NIR, discriminator_B_VIS):
    # Replace the classifier head with a single-score output.
    _disc.fc = nn.Linear(in_features=512, out_features=1, bias=True)
    # ResNet down-samples too aggressively for these small images, so the
    # stem max-pool is swapped for a stride-1 version (one fewer reduction).
    _disc.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
del _disc

In [None]:
# Wrap every network in DataParallel and move it to the selected device.
# Using `device` instead of a hard-coded .cuda() keeps the notebook runnable
# on CPU-only machines; on a CUDA machine the behaviour is unchanged.
genernator_VIS2NIR = nn.DataParallel(genernator_VIS2NIR).to(device)
genernator_NIR2VIS = nn.DataParallel(genernator_NIR2VIS).to(device)
discriminator_A_NIR = nn.DataParallel(discriminator_A_NIR).to(device)
discriminator_B_VIS = nn.DataParallel(discriminator_B_VIS).to(device)

### 以下读取数据

In [12]:
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
import yaml
import torch.utils.data as data

In [13]:
class CustomDatasets(Dataset):
    """Dataset yielding (VIS, NIR) image tensors taken index-by-index from two file lists.

    Args:
        img_NIR_dir: directory containing the NIR images.
        img_VIS_dir: directory containing the VIS images.
        NIR_list: file names (relative to ``img_NIR_dir``) of the NIR images.
        VIS_list: file names (relative to ``img_VIS_dir``) of the VIS images.
    """
    def __init__(self,img_NIR_dir,img_VIS_dir,NIR_list,VIS_list):
        self.img_NIR_dir = img_NIR_dir
        self.img_VIS_dir = img_VIS_dir
        self.NIR_list = NIR_list
        self.VIS_list = VIS_list

    def __len__(self):
        # Every index is looked up in BOTH lists, so only indices valid for
        # the shorter list are usable (previously only NIR_list was counted,
        # which raised IndexError when VIS_list was shorter).
        return min(len(self.NIR_list), len(self.VIS_list))

    def __getitem__(self,idx):
        """Return ``(VIS, NIR)`` for position ``idx`` as float tensors.

        NIR is converted to single-channel grayscale and VIS to RGB. Both are
        resized with PIL's ``resize((64, 128))``; PIL sizes are
        (width, height), so the tensors are C x 128 x 64.
        """
        nir_path = os.path.join(self.img_NIR_dir,self.NIR_list[idx])
        vis_path = os.path.join(self.img_VIS_dir,self.VIS_list[idx])
        # Context managers close the underlying files as soon as the pixel
        # data has been converted, instead of waiting for garbage collection.
        with Image.open(nir_path) as img:
            NIR = img.convert('L').resize((64,128))
        with Image.open(vis_path) as img:
            VIS = img.convert('RGB').resize((64,128))

        totensor = transforms.ToTensor()
        return totensor(VIS),totensor(NIR)

In [14]:
def createDataset(img_NIR_dir,img_VIS_dir,p = 0.1):
    """Split the images of two directories into a training and a test dataset.

    Args:
        img_NIR_dir: directory containing the NIR images.
        img_VIS_dir: directory containing the VIS images.
        p: fraction of the samples reserved for the test set (0 <= p <= 1).

    Returns:
        ``(train_set, test_set)``, both ``CustomDatasets`` instances. The
        first ``int(n * (1 - p))`` entries of each listing go to the training
        set and the remainder to the test set, where ``n`` is the size of the
        smaller directory.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """
    if not 0 <= p <= 1:
        raise ValueError('p must be between 0 and 1, got {!r}'.format(p))
    # os.listdir returns entries in arbitrary order; sorting makes the
    # train/test split reproducible between runs and machines.
    NIR_list = sorted(os.listdir(img_NIR_dir))
    VIS_list = sorted(os.listdir(img_VIS_dir))
    # Derive the sample count from the data instead of a hard-coded 15513,
    # and cap it at the smaller directory so both lists stay index-aligned.
    total = min(len(NIR_list), len(VIS_list))
    l = int(total*(1-p))
    train_set = CustomDatasets(img_NIR_dir,img_VIS_dir,NIR_list[:l],VIS_list[:l])
    test_set = CustomDatasets(img_NIR_dir,img_VIS_dir,NIR_list[l:total],VIS_list[l:total])
    return train_set,test_set

In [15]:
# NIR images live in ./data/trainB/, VIS images in ./data/trainA/.
trainSet, testSet = createDataset(img_NIR_dir='./data/trainB/', img_VIS_dir='./data/trainA/')
# Shuffled mini-batches for training; single shuffled samples for previews.
train_loader = data.DataLoader(dataset=trainSet, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(dataset=testSet, batch_size=1, shuffle=True)

### 以下定义训练过程

In [16]:
# A plain L1 norm makes the loss very large, so the mean absolute
# difference of each sample is used instead.
def similarity_loss(real,fake):
    """Sum, over paired samples, of the mean absolute difference.

    ``real`` and ``fake`` are iterated in lock-step (extra items in the
    longer one are ignored). Returns the int ``0`` when there are no pairs.
    """
    per_sample = (torch.mean(torch.abs(r - f)) for r, f in zip(real, fake))
    return sum(per_sample)

In [17]:
def score_loss(discrinminator,fake):
    """Least-squares adversarial loss for the generator.

    Pushes the discriminator's score for every generated image towards 1.

    Args:
        discrinminator: callable mapping an (N, 3, H, W) batch to one score
            per sample.
        fake: generated images, shape (N, C, H, W) with C equal to 1 or 3;
            a single (C, H, W) image is also accepted.

    Returns:
        Scalar tensor: the sum over the batch of ``(D(fake) - 1) ** 2``.
    """
    # Accept a single image by adding the missing batch dimension.
    if fake.dim() == 3:
        fake = fake.unsqueeze(0)
    # The discriminator expects 3 channels; expand() repeats a single
    # channel without copying and is a no-op for 3-channel input.
    # Fix: the old per-sample loop called expand() with four sizes on 3-D
    # tensors, which raises; scoring the whole batch at once avoids that.
    scores = discrinminator(fake.expand(-1,3,-1,-1))
    return torch.pow(scores-1,2).sum()

In [18]:
def genernator_train(genernator,discriminator,optim,data):
    """Run one generator update in each direction (VIS->NIR and NIR->VIS).

    Args:
        genernator: ``[VIS->NIR generator, NIR->VIS generator]``.
        discriminator: ``[NIR discriminator, VIS discriminator]``.
        optim: ``[optimizer of genernator[0], optimizer of genernator[1]]``.
        data: ``[VIS batch, NIR batch]``, already on the models' device.

    Returns:
        ``(simil_loss_A, dis_loss_A, simil_loss_B, dis_loss_B)``: the cycle
        reconstruction loss and adversarial loss of the VIS->NIR->VIS pass,
        followed by those of the NIR->VIS->NIR pass.
    """
    genernator[0].train()
    genernator[1].train()
    # Fix: .eval() was called on the list itself instead of on each model.
    discriminator[0].eval()
    discriminator[1].eval()

    # ---- Direction A: VIS -> fake NIR -> reconstructed VIS ----
    VIS2NIR_A_fake = genernator[0](data[0])
    NIR2VIS_B_fake = genernator[1](VIS2NIR_A_fake)
    simil_loss_A = similarity_loss(data[0],NIR2VIS_B_fake)
    dis_loss_A = score_loss(discriminator[0],VIS2NIR_A_fake)
    loss = simil_loss_A+dis_loss_A
    # backward() deposits gradients in both generators, so clear both even
    # though only genernator[0] is stepped here.
    optim[0].zero_grad()
    optim[1].zero_grad()
    loss.backward()
    optim[0].step()
    del VIS2NIR_A_fake,NIR2VIS_B_fake

    # ---- Direction B: NIR -> fake VIS -> reconstructed NIR ----
    NIR2VIS_B_fake = genernator[1](data[1])
    VIS2NIR_A_fake = genernator[0](NIR2VIS_B_fake)
    # Fix: compare the real NIR with its reconstruction, not with the fake
    # VIS image (which has a different channel count).
    simil_loss_B = similarity_loss(data[1],VIS2NIR_A_fake)
    dis_loss_B = score_loss(discriminator[1],NIR2VIS_B_fake)
    loss = simil_loss_B+dis_loss_B
    optim[0].zero_grad()
    optim[1].zero_grad()
    loss.backward()
    # Fix: step the second generator's optimizer (was `optim.step()` on the list).
    optim[1].step()

    return simil_loss_A,dis_loss_A,simil_loss_B,dis_loss_B

In [19]:
def discriminator_loss(discriminator,fake,real):
    """Least-squares adversarial loss for a discriminator.

    Pushes the score of real images towards 1 and of generated images
    towards 0.

    Args:
        discriminator: callable mapping an (N, 3, H, W) batch to one score
            per sample.
        fake: generated images, shape (N, C, H, W) with C equal to 1 or 3.
        real: real images of the same domain, same layout as ``fake``.
            A single (C, H, W) image is accepted for either argument.

    Returns:
        Scalar tensor: the sum over the batch of
        ``(D(real) - 1) ** 2 + D(fake) ** 2``.
    """
    # Accept single images by adding the missing batch dimension.
    if fake.dim() == 3:
        fake = fake.unsqueeze(0)
    if real.dim() == 3:
        real = real.unsqueeze(0)
    # expand() repeats a single channel to the 3 channels the discriminator
    # expects and is a no-op for 3-channel input.
    # Fix: the old per-sample loop called expand() with four sizes on 3-D
    # tensors, which raises; scoring whole batches avoids that.
    real_scores = discriminator(real.expand(-1,3,-1,-1))
    fake_scores = discriminator(fake.expand(-1,3,-1,-1))
    return (torch.pow(real_scores-1,2)+torch.pow(fake_scores,2)).sum()

In [20]:
def discriminator_train(genernator,discriminator,optim,data):
    """Run one update of each discriminator.

    Args:
        genernator: ``[VIS->NIR generator, NIR->VIS generator]``.
        discriminator: ``[NIR discriminator, VIS discriminator]``.
        optim: ``[optimizer of discriminator[0], optimizer of discriminator[1]]``.
        data: ``[VIS batch, NIR batch]``, already on the models' device.

    Returns:
        ``(loss_A, loss_B)``: losses of the NIR and the VIS discriminator.
    """
    discriminator[0].train()
    discriminator[1].train()
    genernator[0].eval()
    genernator[1].eval()

    # The generators are not trained here, so skip building their graph.
    with torch.no_grad():
        VIS2NIR_fake = genernator[0](data[0])

    # Fix: the NIR discriminator must see real NIR images (data[1]) as the
    # positive examples; it was previously given the VIS batch.
    loss_A = discriminator_loss(discriminator[0],VIS2NIR_fake,data[1])
    del VIS2NIR_fake
    optim[0].zero_grad()
    loss_A.backward()
    optim[0].step()

    with torch.no_grad():
        NIR2VIS_fake = genernator[1](data[1])
    # Fix: `NIR2VIS_fake.data[1]` was a typo for two separate arguments; the
    # VIS discriminator compares fake VIS against the real VIS batch (data[0]).
    loss_B = discriminator_loss(discriminator[1],NIR2VIS_fake,data[0])
    del NIR2VIS_fake
    optim[1].zero_grad()
    loss_B.backward()
    optim[1].step()

    return loss_A,loss_B

In [21]:
import matplotlib.pyplot as plt

In [22]:
def _plot_image_row(tensors, path):
    """Draw the given image tensors side by side, save the figure to ``path`` and show it."""
    to_pil = transforms.ToPILImage()
    fig = plt.figure(figsize=(16, 4))
    rows, columns = 1, len(tensors)
    for position, tensor in enumerate(tensors, start=1):
        tensor = tensor.detach().cpu()
        # ToPILImage needs a single C x H x W image: take the first sample
        # of a batch.
        if tensor.dim() == 4:
            tensor = tensor[0]
        # The generator output is unbounded; clamp to the displayable range.
        image = to_pil(tensor.clamp(0, 1))
        fig.add_subplot(rows, columns, position)
        plt.imshow(image, cmap='gray' if image.mode == 'L' else None)
    plt.tight_layout()
    plt.savefig(path)
    plt.show()
    # Release the figure; otherwise two figures per epoch pile up in memory.
    plt.close(fig)


def test(genernator,data,epoch):
    """Visualise both translation cycles for one sample and save them as images.

    Two figures are written to ``./process_image/``:
    ``VIS2NIR_A_<epoch+1>.jpg`` shows real VIS, fake NIR, reconstructed VIS,
    real NIR; ``NIR2VIS_B_<epoch+1>.jpg`` shows real NIR, fake VIS,
    reconstructed NIR, real VIS.

    Args:
        genernator: ``[VIS->NIR generator, NIR->VIS generator]``.
        data: ``[VIS batch, NIR batch]``; only the first sample is drawn.
        epoch: zero-based epoch index, used in the output file names.
    """
    genernator[0].eval()
    genernator[1].eval()

    # Run on whatever device the generators live on, so batches coming
    # straight from a DataLoader (CPU) also work.
    model_device = next(genernator[0].parameters()).device
    real_VIS = data[0].to(model_device)
    real_NIR = data[1].to(model_device)

    with torch.no_grad():
        VIS2NIR_fake = genernator[0](real_VIS)
        VIS_rebuilt = genernator[1](VIS2NIR_fake)
        NIR2VIS_fake = genernator[1](real_NIR)
        # Fix: reconstruct NIR from the fake VIS (a real cycle) instead of
        # re-translating the real VIS image.
        NIR_rebuilt = genernator[0](NIR2VIS_fake)

    # savefig does not create missing directories.
    os.makedirs('./process_image', exist_ok=True)
    _plot_image_row([real_VIS, VIS2NIR_fake, VIS_rebuilt, real_NIR],
                    './process_image/VIS2NIR_A_%d.jpg'%(epoch+1))
    _plot_image_row([real_NIR, NIR2VIS_fake, NIR_rebuilt, real_VIS],
                    './process_image/NIR2VIS_B_%d.jpg'%(epoch+1))

In [23]:
import torch.optim as optim

In [24]:
# Generators are trained with RMSprop at a small learning rate ...
optimzer_gen_A_VIS2NIR = optim.RMSprop(params=genernator_VIS2NIR.parameters(), lr=2e-4)
optimzer_gen_B_NIR2VIS = optim.RMSprop(params=genernator_NIR2VIS.parameters(), lr=2e-4)
# ... and the discriminators with Adam at a larger one.
optimzer_dis_A = optim.Adam(params=discriminator_A_NIR.parameters(), lr=1e-3)
optimzer_dis_B = optim.Adam(params=discriminator_B_VIS.parameters(), lr=1e-3)

In [None]:
# Main training loop: preview the current generators, then alternate one
# generator update and one discriminator update per mini-batch.
# Fix: EPOCHS is an int, so it must be wrapped in range() to be iterated.
for epoch in range(EPOCHS):
    # Preview on one test sample, moved to the same device as the models.
    preview_batch = [t.to(device) for t in next(iter(test_loader))]
    test([genernator_VIS2NIR,genernator_NIR2VIS],preview_batch,epoch)
    # The loop variable is called `batch` so that it does not overwrite the
    # `data` alias of torch.utils.data imported earlier in the notebook.
    for batch in train_loader:
        batch = [t.to(device) for t in batch]
        simil_loss_A,dis_loss_A,simil_loss_B,dis_loss_B = genernator_train(
            [genernator_VIS2NIR,genernator_NIR2VIS],
            [discriminator_A_NIR,discriminator_B_VIS],
            [optimzer_gen_A_VIS2NIR,optimzer_gen_B_NIR2VIS],
            batch)
        # Fix: the discriminator step must use the discriminator optimizers;
        # it was previously handed the generator optimizers.
        loss_A,loss_B = discriminator_train(
            [genernator_VIS2NIR,genernator_NIR2VIS],
            [discriminator_A_NIR,discriminator_B_VIS],
            [optimzer_dis_A,optimzer_dis_B],
            batch)
    # Report the losses of the last mini-batch of the epoch as plain floats.
    print('epoch: {}/{},loss_A_consistency: {},loss_A_discriminator: {},loss_B_consistency: {},loss_B_discriminator: {}'.format(epoch+1,EPOCHS,float(simil_loss_A),dis_loss_A.item(),float(simil_loss_B),dis_loss_B.item()))
    print('discriminator_A_VIS_loss:{},discriminator_B_NIR_loss{}'.format(loss_A.item(),loss_B.item()))