### 以下定义生成器与判别器的网络结构

- 生成器和判别器均使用孪生网络  
- 生成器采用类似RNN的两步方式：第一个step生成局部（local）图像，第二个step生成全局图像。  
- 生成器step1和step2使用同一个骨干网络

In [209]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import copy

In [48]:
# Optional branch head applied to each local (part) output; the same block can
# also be used to branch the inputs if that is ever needed.
class outBlock(nn.Module):
    """Small two-branch conv head mapping ``in_channels`` to ``out_channels``.

    A 3x3 conv-BN-ReLU first maps the input to 32 channels. Two parallel
    branches then produce 64 channels each:

    * ``shortcut``: 1x1 conv-BN-ReLU;
    * ``block``: 1x1 conv-BN-ReLU followed by 3x3 conv-BN-ReLU.

    Their outputs are concatenated (128 channels) and projected to
    ``out_channels`` by the 3x3 conv ``sqush``. Every conv uses stride 1 and
    size-preserving padding, so the spatial size is unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(cin, cout, k, pad):
            # Modules are created conv -> BN -> ReLU; this order fixes both the
            # parameter-init order and the Sequential indices in state_dict.
            return [
                nn.Conv2d(cin, cout, kernel_size=k, stride=1, padding=pad),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            ]

        self.layer1 = nn.Sequential(*conv_bn_relu(in_channels, 32, 3, 1))
        self.shortcut = nn.Sequential(*conv_bn_relu(32, 64, 1, 0))
        self.block = nn.Sequential(
            *conv_bn_relu(32, 64, 1, 0),
            *conv_bn_relu(64, 64, 3, 1),
        )
        # Attribute names are part of the state_dict keys, so "sqush" stays.
        self.sqush = nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        trunk = self.layer1(x)
        branches = [self.shortcut(trunk), self.block(trunk)]
        return self.sqush(torch.cat(branches, dim=1))

In [76]:
# Forward-hook helper: registers a hook on a module and keeps its latest output.
class saveFeatures:
    """Record the most recent forward output of module ``m``.

    ``features`` is ``None`` until ``m`` runs forward for the first time.
    After that it holds the output of the latest call. ``remove()``
    unregisters the hook, after which ``features`` stops updating.
    """

    features = None

    def __init__(self, m):
        handle = m.register_forward_hook(self.hook_fn)
        self.hook = handle

    def hook_fn(self, module, input, output):
        # Overwritten on every forward call: only the latest activation is kept.
        self.features = output

    def remove(self):
        handle = self.hook
        handle.remove()

In [211]:
class unetUpSampleBlock(nn.Module):
    """U-Net decoder step: upsample the deeper map and fuse it with a skip map.

    ``tranConv`` doubles the spatial size of ``x`` with a stride-2 transposed
    convolution and maps it to ``out_channels``. ``conv`` (1x1) projects the
    skip ``features`` to ``out_channels`` without changing their size. The two
    are concatenated on the channel axis, so the block outputs
    ``2 * out_channels`` channels (``out_channels`` is half the final width).

    The x2 upsampling lines up exactly with the skip map only when the encoder
    halved an even size. For other sizes (e.g. an odd side rounded up by a
    stride-2 encoder), the upsampled map is resized to the skip map's spatial
    size with nearest-neighbour interpolation instead of failing in ``torch.cat``.

    Args:
        in_channels: channels of ``x``, the deeper decoder input.
        feature_channels: channels of the skip ``features``.
        out_channels: channels produced by each branch; the output has twice this.
        dp: if True, apply dropout after batch norm.
        ps: dropout probability used when ``dp`` is True.
    """

    def __init__(self, in_channels, feature_channels, out_channels, dp=False, ps=0.25):
        super(unetUpSampleBlock, self).__init__()
        # kernel 2 / stride 2: output is exactly twice the input size.
        self.tranConv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=False)
        # 1x1 projection of the skip features (spatial size unchanged).
        self.conv = nn.Conv2d(feature_channels, out_channels, 1, bias=False)
        # Normalizes the concatenation of both branches.
        self.bn = nn.BatchNorm2d(out_channels * 2)
        self.dp = dp
        if dp:
            self.dropout = nn.Dropout(ps, inplace=True)

    def forward(self, x, features):
        x1 = self.tranConv(x)
        x2 = self.conv(features)
        if x1.shape[-2:] != x2.shape[-2:]:
            # Odd-sized encoder maps: x2 upsampling overshoots/undershoots the
            # skip map by one pixel, so snap to the skip map's size.
            x1 = F.interpolate(x1, size=x2.shape[-2:], mode='nearest')
        x = torch.cat([x1, x2], dim=1)
        # Activation is applied before BatchNorm in this block.
        x = self.bn(F.relu(x))
        return self.dropout(x) if self.dp else x

In [261]:
class Generator(nn.Module):
    """U-Net generator whose encoder is four stages taken from a backbone.

    ``model.children()[4:-2]`` becomes the encoder. For a torchvision ResNet-50
    these are layer1..layer4, producing 256/512/1024/2048 channels, which is
    what ``up1``..``up4`` are sized for. The backbone modules are reused, not
    copied, so generators built from the same ``model`` share encoder weights.

    Forward hooks on the first three encoder stages capture their outputs as
    skip connections. The stride-1 stem (``layer1``, 64 channels) supplies the
    last skip. Each decoder step doubles the spatial size, so every encoder
    stage is presumably expected to halve it (the notebook patches ResNet-50's
    layer1 for that) — TODO confirm for other backbones.

    Args:
        model: backbone whose ``children()[4:-2]`` are the four encoder stages.
        in_channels: channels of each input image (the notebook demo uses 3).
        out_channels: channels of each generated image (the demo uses 1).
    """

    def __init__(self, model, in_channels, out_channels):
        super().__init__()
        # Full-resolution stem; its output is also the last skip connection.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        backbone_parts = list(model.children())
        self.downsample = nn.Sequential(*backbone_parts[4:-2])
        # Hooks on encoder stages 0..2; stage 3's output feeds up1 directly.
        encoder_stages = list(self.downsample.children())
        self.features = [saveFeatures(encoder_stages[i]) for i in range(3)]
        self.up1 = unetUpSampleBlock(2048, 1024, 512)  # skip: features[2]
        self.up2 = unetUpSampleBlock(1024, 512, 256)   # skip: features[1]
        self.up3 = unetUpSampleBlock(512, 256, 128)    # skip: features[0]
        self.up4 = unetUpSampleBlock(256, 64, 32)      # skip: stem output
        self.outlayer = nn.Conv2d(64, out_channels, 3, 1, 1)

    def forward(self, batch):
        """Run each tensor in ``batch`` independently; return a list of outputs."""
        return [self._generate(images) for images in batch]

    def _generate(self, images):
        """Encode and decode one tensor using the hook-captured skip maps."""
        stem = self.layer1(images)
        out = self.downsample(stem)
        # Hook outputs are read only after the encoder pass above refreshed them.
        skips = (
            self.features[2].features,
            self.features[1].features,
            self.features[0].features,
        )
        for up, skip in zip((self.up1, self.up2, self.up3), skips):
            out = up(out, skip)
        out = self.up4(out, stem)
        return self.outlayer(out)

In [262]:
m = models.resnet50()
bottleneck = m.layer1[0]
# Give layer1's first bottleneck stride 2 so the backbone's first stage also
# halves the resolution. The projection shortcut and the 3x3 main-path conv
# must both change, otherwise the residual sum sees mismatched shapes.
# Each conv is rebuilt with stride (2, 2) and keeps its existing weights.
old_conv = bottleneck.downsample[0]
new_conv = nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
new_conv.load_state_dict(old_conv.state_dict())
bottleneck.downsample[0] = new_conv

old_conv = bottleneck.conv2
new_conv = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
new_conv.load_state_dict(old_conv.state_dict())
bottleneck.conv2 = new_conv
del bottleneck, old_conv, new_conv

In [268]:
# Build a generator on the patched backbone. It gets its own name so that
# re-running this cell does not wrap an existing Generator again, and so that
# ``m`` keeps referring to the backbone. (Calling the bare ResNet-50 ``m`` on
# a Python list, as happened with this line commented out, fails.)
generator = Generator(m, 3, 1)
f = torch.ones(2, 3, 32, 16)
t = torch.ones(2, 3, 32, 16)
f = generator([f])

In [270]:
f[0].size()  # shape of the first generated output; recorded: torch.Size([2, 1, 32, 16])

torch.Size([2, 1, 32, 16])

### 以下定义数据读取

- 每个样本读取一张完整图片，以及从中裁剪出的头、胸、手、腿等局部区域  
- 全局图片为.resize((64,128))  
- 头为.resize((32,16))
- 胸部为.resize((64,64))  
- 手臂为.resize((64,64))  
- 腿部为.resize((64,128))  
- 坐标文件存储在images_NIR.yml和images_VIS.yml两个文件上

In [273]:
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
import yaml

In [271]:
def process(x1, y1, x2, y2):
    """Pack annotated box corners into the tuple passed to PIL's ``Image.crop``.

    The values are returned unchanged, in the order (x1, y1, x2, y2), which
    ``crop`` reads as (left, upper, right, lower). Presumably kept as a single
    hook for adjusting YAML coordinates later — currently a pass-through.
    """
    box = (x1, y1, x2, y2)
    return box

In [274]:
class CustomDatasets(Dataset):
    """Paired NIR/VIS dataset returning full images plus four part crops.

    Each item is a dict containing:

    * ``img_NIR`` (mode 'L') and ``img_VIS`` (mode 'RGB'): full images resized
      with PIL to (64, 128) (width, height);
    * ``id_NIR`` / ``id_VIS``: ints parsed from the key prefix before the
      first '_';
    * ``head_*`` (32, 16), ``chest_*`` (64, 64), ``thigh_*`` (64, 64) and
      ``leg_*`` (64, 128): crops of the corresponding resized image, each
      modality cropped with its own annotation boxes.

    Every image entry is converted with ``transforms.ToTensor``.

    Args:
        img_NIR_all, img_VIS_all: mappings (parsed from YAML) from image key
            to per-part boxes. Each part maps to a dict with keys x1, y1,
            x2, y2.
        img_train_NIR_list, img_train_VIS_list: the keys forming *this*
            split (train or test). Item ``idx`` pairs the idx-th NIR key with
            the idx-th VIS key.
        img_NIR_dir, img_VIS_dir: directories holding the image files.
    """

    # Parts to crop, in output order, with their PIL (width, height) target size.
    _PARTS = (('head', (32, 16)), ('chest', (64, 64)), ('thigh', (64, 64)), ('leg', (64, 128)))
    _ID_KEYS = ('id_NIR', 'id_VIS')

    def __init__(self, img_NIR_all, img_VIS_all, img_train_NIR_list, img_train_VIS_list, img_NIR_dir, img_VIS_dir):
        self.img_NIR_all = img_NIR_all
        self.img_VIS_all = img_VIS_all
        self.img_train_NIR_list = img_train_NIR_list
        self.img_train_VIS_list = img_train_VIS_list
        self.img_NIR_dir = img_NIR_dir
        self.img_VIS_dir = img_VIS_dir
        # Full key lists of both annotation files (kept as public attributes).
        self.NIR_key = list(img_NIR_all.keys())
        self.VIS_key = list(img_VIS_all.keys())
        self.totensor = transforms.ToTensor()

    def __len__(self):
        # createDatasets builds both lists with the same length; min() avoids
        # an IndexError if a caller ever passes lists of different lengths.
        return min(len(self.img_train_NIR_list), len(self.img_train_VIS_list))

    @staticmethod
    def _file_name(key):
        """Map an annotation key to its image file name.

        Drops the last two characters of the stem and keeps the extension,
        e.g. 'a_0.jpg' -> 'a.jpg'.
        """
        stem, ext = os.path.splitext(key)
        return stem[:-2] + ext

    @staticmethod
    def _load(path, mode):
        """Open ``path``, convert it to ``mode`` and resize it to (64, 128); closes the file."""
        with Image.open(path) as img:
            return img.convert(mode).resize((64, 128))

    def __getitem__(self, idx):
        # Index this split's own key lists (not the full NIR_key/VIS_key lists),
        # so train and test datasets return disjoint samples.
        nir_key = self.img_train_NIR_list[idx]
        vis_key = self.img_train_VIS_list[idx]
        nir_boxes = self.img_NIR_all[nir_key]
        vis_boxes = self.img_VIS_all[vis_key]

        batch = {}
        batch['img_NIR'] = self._load(os.path.join(self.img_NIR_dir, self._file_name(nir_key)), 'L')
        # To break the NIR/VIS pairing, pick a different (random) index for the VIS side.
        batch['img_VIS'] = self._load(os.path.join(self.img_VIS_dir, self._file_name(vis_key)), 'RGB')

        batch['id_NIR'] = int(nir_key.split('_')[0])
        batch['id_VIS'] = int(vis_key.split('_')[0])

        for part, size in self._PARTS:
            # Each modality is cropped with the boxes from its own annotation file.
            batch[part + '_NIR'] = batch['img_NIR'].crop(process(**nir_boxes[part])).resize(size)
            batch[part + '_VIS'] = batch['img_VIS'].crop(process(**vis_boxes[part])).resize(size)

        return {
            key: value if key in self._ID_KEYS else self.totensor(value)
            for key, value in batch.items()
        }

In [275]:
def createDatasets(yaml_NIR, yaml_VIS, img_NIR_dir, img_VIS_dir, p_test=0.1):
    """Load both annotation YAML files and split them into train/test datasets.

    Let ``n`` be the smaller of the two files' entry counts. The first
    ``int(n * p_test)`` keys of each file (in mapping order) form the test
    split, and the keys after that up to ``n`` form the train split. Pairing
    is by position, so the two files are assumed to list matching images in
    the same order — TODO confirm.

    Returns:
        (train_dataset, test_dataset), both ``CustomDatasets`` instances.
    """
    def _read_yaml(path):
        with open(path, 'r') as handle:
            return yaml.safe_load(handle.read())

    nir_all = _read_yaml(yaml_NIR)
    vis_all = _read_yaml(yaml_VIS)

    usable = min(len(nir_all), len(vis_all))
    cut = int(usable * p_test)
    nir_keys = list(nir_all.keys())
    vis_keys = list(vis_all.keys())

    train_set = CustomDatasets(nir_all, vis_all, nir_keys[cut:usable], vis_keys[cut:usable], img_NIR_dir, img_VIS_dir)
    test_set = CustomDatasets(nir_all, vis_all, nir_keys[:cut], vis_keys[:cut], img_NIR_dir, img_VIS_dir)
    return train_set, test_set

In [None]:
def concat_patch(batch):
    """Placeholder for merging generated local patches into one global image.

    The intended merge logic is not written yet (presumably it combines the
    generated whole-image and part outputs). Until then, fail loudly rather
    than leave a body-less ``def``, which is a syntax error.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("concat_patch is not implemented yet")

### 以下定义训练及测试

- 训练过程：先生成5个VIS图（全局图以及头、胸、大腿、腿部4个局部图），再把它们合并，生成更清晰的全局VIS图

In [None]:
def train_genernator(genernator_A,genernator_B,merge_A,mearge_B,discriminator_A,data_batch):
    fake_VIS_1 = genernator_A([data_batch['img_NIR'],data_batch['head_NIR'],data_batch['chest_NIR'],data_batch['thigh_NIR'],data_batch['leg_NIR']])
    