<a href="https://colab.research.google.com/github/lifeisbeautifu1/deep-learning/blob/main/High_Resolution_Image_Inpainting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

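Source: Yang et al., "High-Resolution Image Inpainting using Multi-Scale Neural Patch Synthesis" (arXiv:1611.09969) — https://arxiv.org/pdf/1611.09969.pdf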

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ContentLoss(nn.Module):
    """
    Pass-through layer that measures how far its input is from a fixed target.

    The layer never alters the tensor flowing through it; it only stores the
    mean-squared error against ``target`` in ``self.loss`` so that the caller
    can read it after a forward pass.
    """

    def __init__(self, target):
        super().__init__()
        # No forward pass has happened yet, so there is nothing to report.
        self.loss = None
        self.update(target)

    def update(self, target):
        """
        Replace the reference tensor of this layer.

        :param target: new reference; it is detached so that no gradient
            flows back into whatever produced it
        :return: None
        """
        self.target = target.detach()

    def forward(self, input):
        """
        Record the MSE between ``input`` and the target.

        :param input: tensor to compare with the stored target
        :return: ``input`` itself, unchanged
        """
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input

In [None]:
class TVLoss(nn.Module):
    """
    Pass-through layer that records a squared total-variation penalty.

    ``forward`` returns its input untouched and stores the penalty in
    ``self.loss``. ``weight`` is kept on the instance but is not applied
    inside ``forward``.
    """

    # Per-channel (offset, scale) pairs applied as (value + offset) / scale
    # before the differences are taken. Numerically this is the inverse of
    # normalising with mean (0.485, 0.456, 0.406) / std (0.229, 0.224, 0.225).
    _CHANNEL_RESCALE = ((2.12, 4.37), (2.04, 4.46), (1.80, 4.44))

    def __init__(self, weight = 1):
        super().__init__()
        self.loss = None
        self.weight = weight

    def forward(self, input):
        """
        Compute the penalty for ``input`` and hand ``input`` back.

        ``input.squeeze()`` must leave a (channels, height, width) tensor, so
        in practice a single image per call is expected. Only the first three
        channels are read.
        """
        hwc = input.squeeze().permute([1, 2, 0])
        planes = [
            (hwc[:, :, channel] + offset) / scale
            for channel, (offset, scale) in enumerate(self._CHANNEL_RESCALE)
        ]
        rgb = torch.stack(planes, dim=2)

        # Shift by one pixel along each spatial axis; the last row / column is
        # repeated so the difference on the border is exactly zero.
        shifted_rows = torch.cat((rgb[1:, :, :], rgb[-1:, :, :]), dim=0)
        shifted_cols = torch.cat((rgb[:, 1:, :], rgb[:, -1:, :]), dim=1)
        row_diff = shifted_rows - rgb
        col_diff = shifted_cols - rgb

        self.loss = torch.mean(torch.pow(row_diff, 2)) + torch.mean(torch.pow(col_diff, 2))
        return input

In [None]:
class StyleLoss(nn.Module):
    """
    Pass-through layer that records a patch-matching style loss.

    Patches are cut from a target feature map (skipping the central region,
    see ``patches_sampling``). On every forward pass each patch of the input
    is paired with the target patch that gives the highest normalised
    cross-correlation, and the mean squared difference between the pairs is
    stored in ``self.loss``. The input itself is returned unchanged.

    :param target: target feature map, indexed as (batch, channel, height, width)
    :param patch_size: side length of the square patches
    :param mrf_style_stride: sampling stride on the target feature map
    :param mrf_synthesis_stride: sampling stride on the input feature map
    :param gpu_chunck_size: number of patches processed per chunk
    :param device: device on which the patch tensors are kept
    """

    def __init__(self, target, patch_size, mrf_style_stride, mrf_synthesis_stride, gpu_chunck_size, device):
        super(StyleLoss, self).__init__()
        self.patch_size = patch_size
        self.mrf_style_stride = mrf_style_stride
        self.mrf_synthesis_stride = mrf_synthesis_stride
        self.gpu_chunck_size = gpu_chunck_size
        self.device = device
        self.loss = None

        self.style_patches = None
        self.style_patches_norm = None
        # Construction and a later refresh share exactly the same logic.
        self.update(target)

    def update(self, target):
        """
        Re-sample the style patches (and their norms) from a new target.

        :param target: new target feature map; it is detached before use
        :return: None
        """
        self.style_patches = self.patches_sampling(target.detach(), patch_size=self.patch_size,
                                                   stride=self.mrf_style_stride)
        # Shaped (N, 1, 1) so it broadcasts over the (N, H', W') response map.
        self.style_patches_norm = self.cal_patches_norm().view(-1, 1, 1)

    def forward(self, input):
        """
        Record the style loss for ``input`` and return ``input`` unchanged.

        ``input`` is expected to hold a single image: the leading dimension
        of the correlation response is squeezed away below.
        """
        synthesis_patches = self.content_patches_sampling(input, patch_size=self.patch_size,
                                                          stride=self.mrf_synthesis_stride)
        chunk = self.gpu_chunck_size

        # Cross-correlate the input with every style patch, chunk by chunk to
        # bound the size of each convolution.
        responses = []
        for start in range(0, self.style_patches.shape[0], chunk):
            kernels = self.style_patches[start:start + chunk]
            response = F.conv2d(input, kernels, stride=self.mrf_synthesis_stride)
            responses.append(response.squeeze(dim=0))
        response_map = torch.cat(responses, dim=0).div(self.style_patches_norm)

        # Index of the best style patch for every synthesis patch, flattened in
        # the same row-major order as ``content_patches_sampling``. reshape(-1)
        # keeps a 1-D tensor even when there is only one synthesis patch.
        nearest = torch.argmax(response_map, dim=0).reshape(-1)
        num_synthesis = nearest.shape[0]

        loss = 0
        for start in range(0, num_synthesis, chunk):
            stop = min(start + chunk, num_synthesis)
            matched = self.style_patches[nearest[start:stop]]
            difference = synthesis_patches[start:stop] - matched
            loss = loss + torch.sum(torch.mean(torch.pow(difference, 2), dim=[1, 2, 3]))

        self.loss = loss / num_synthesis

        return input

    def patches_sampling(self, image, patch_size, stride):
        """
        Cut square patches from ``image`` whose centre lies outside the
        central region.

        A patch is skipped when its centre falls strictly inside the middle
        half of the image along both spatial axes.

        :param image: tensor indexed as (batch, channel, height, width)
        :param patch_size: side length of each patch
        :param stride: step between neighbouring patches
        :return: kept patches concatenated along dim 0, on ``self.device``
        :raises RuntimeError: if no patch survives the filter
        """
        h, w = image.shape[2:4]
        half_patch = patch_size / 2
        patches = []

        for i in range(0, h - patch_size + 1, stride):
            center_row = i + half_patch
            row_in_center = h / 4 < center_row < h * 3 / 4
            for j in range(0, w - patch_size + 1, stride):
                center_col = j + half_patch
                in_center = row_in_center and (w / 4 < center_col < w * 3 / 4)
                if not in_center:
                    patches.append(image[:, :, i:i + patch_size, j:j + patch_size])

        if not patches:
            raise RuntimeError(
                "no style patch lies outside the central region for a {}x{} map "
                "with patch_size={} and stride={}".format(h, w, patch_size, stride))
        return torch.cat(patches, dim=0).to(self.device)

    def content_patches_sampling(self, image, patch_size, stride):
        """
        Cut every square patch from ``image`` in row-major order.

        :param image: tensor indexed as (batch, channel, height, width)
        :param patch_size: side length of each patch
        :param stride: step between neighbouring patches
        :return: patches concatenated along dim 0, on ``self.device``
        """
        h, w = image.shape[2:4]
        patches = [
            image[:, :, i:i + patch_size, j:j + patch_size]
            for i in range(0, h - patch_size + 1, stride)
            for j in range(0, w - patch_size + 1, stride)
        ]
        return torch.cat(patches, dim=0).to(self.device)

    def cal_patches_norm(self):
        """
        Return the Euclidean norm of every style patch as a 1-D tensor.

        Computed in one vectorised reduction on the patches' own device
        instead of a per-patch Python loop.
        """
        return self.style_patches.pow(2).sum(dim=(1, 2, 3)).sqrt().to(self.device)

In [None]:
class CNNMRF(nn.Module):
    """
    Combined content / style / total-variation objective built on VGG19 features.

    The wrapped ``self.model`` is a truncated copy of ``vgg19.features`` with
    pass-through loss layers spliced in: a ``TVLoss`` at the very front, a
    ``StyleLoss`` after every index in ``self.style_layers`` and a
    ``ContentLoss`` after every index in ``self.content_layers``. Calling the
    module runs the synthesis image through that stack and returns the
    weighted sum of the recorded losses.

    :param style_image: image whose features become the style targets
    :param content_image: image whose features become the content targets
    :param device: device the model and targets live on
    :param content_weight: multiplier of the summed content losses
    :param style_weight: multiplier of the summed style losses
    :param tv_weight: multiplier of the total-variation loss
    :param gpu_chunck_size: chunk size forwarded to every ``StyleLoss``
    :param mrf_style_stride: patch stride on the style features
    :param mrf_synthesis_stride: patch stride on the synthesis features
    """

    def __init__(self, style_image, content_image, device, content_weight, style_weight, tv_weight, gpu_chunck_size=256,
                 mrf_style_stride = 2, mrf_synthesis_stride = 2):
        super(CNNMRF, self).__init__()
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.tv_weight = tv_weight
        self.patch_size = 3
        self.device = device
        self.gpu_chunck_size = gpu_chunck_size
        self.mrf_style_stride = mrf_style_stride
        self.mrf_synthesis_stride = mrf_synthesis_stride
        # Indices into ``vgg19.features`` after which a loss layer is inserted.
        self.style_layers = [12, 21]
        self.content_layers = [22]
        self.model, self.content_losses, self.style_losses, self.tv_loss = self.get_model_and_losses(
            style_image=style_image, content_image=content_image)

    def forward(self, input):
        """
        Return the total loss for the synthesis image ``input``.

        The forward pass through ``self.model`` is run only for its side
        effect: every loss layer stores its value in its ``loss`` attribute.
        """
        self.model(input)

        style_score = sum(layer.loss for layer in self.style_losses)
        content_score = sum(layer.loss for layer in self.content_losses)
        tv_score = self.tv_loss.loss

        # Every term is scaled by its weight. The TV weight used to be *added*
        # to the loss, which left the TV term unweighted.
        return (self.style_weight * style_score
                + self.content_weight * content_score
                + self.tv_weight * tv_score)

    def get_model_and_losses(self, style_image, content_image):
        """
        Build the truncated VGG19 stack with the loss layers spliced in.

        :return: ``(model, content_losses, style_losses, tv_loss)``
        """
        # Tensor.to() is not in-place; keep the moved tensors.
        style_image = style_image.to(self.device)
        content_image = content_image.to(self.device)

        vgg = models.vgg19(pretrained=True).to(self.device)
        # Only the synthesis image is optimised, so the VGG weights need no gradients.
        for parameter in vgg.parameters():
            parameter.requires_grad_(False)

        model = nn.Sequential().to(self.device)
        content_losses = []
        style_losses = []
        # add tv loss layer
        tv_loss = TVLoss().to(self.device)
        model.add_module('tv_loss', tv_loss)

        next_content_idx = 0
        next_style_idx = 0

        for i, layer in enumerate(vgg.features):
            # Stop as soon as every requested loss layer has been inserted.
            if next_content_idx >= len(self.content_layers) and next_style_idx >= len(self.style_layers):
                break
            # add layer of vgg19
            model.add_module(str(i), layer.to(self.device))

            # add content loss layer
            if i in self.content_layers:
                target = model(content_image).detach()
                content_loss = ContentLoss(target)
                model.add_module("content_loss_{}".format(next_content_idx), content_loss)
                content_losses.append(content_loss)
                next_content_idx += 1

            # add style loss layer
            if i in self.style_layers:
                target_feature = model(style_image).detach()
                style_loss = StyleLoss(target_feature, patch_size=self.patch_size,
                                       mrf_style_stride=self.mrf_style_stride,
                                       mrf_synthesis_stride=self.mrf_synthesis_stride,
                                       gpu_chunck_size=self.gpu_chunck_size, device=self.device)

                model.add_module("style_loss_{}".format(next_style_idx), style_loss)
                style_losses.append(style_loss)
                next_style_idx += 1

        return model, content_losses, style_losses, tv_loss

    def _refresh_targets(self, image, loss_layers, layer_ids):
        """Push ``image`` through the VGG layers and update ``loss_layers`` at ``layer_ids``."""
        x = image.clone().to(self.device)
        next_idx = 0
        vgg_idx = 0
        # Targets are detached by the loss layers anyway, so no graph is needed.
        with torch.no_grad():
            for layer in self.model:
                # Loss layers are pass-through; skipping them keeps vgg_idx
                # aligned with the index inside vgg19.features.
                if isinstance(layer, (TVLoss, ContentLoss, StyleLoss)):
                    continue
                if next_idx >= len(loss_layers):
                    break
                x = layer(x)
                if vgg_idx in layer_ids:
                    loss_layers[next_idx].update(x)
                    next_idx += 1
                vgg_idx += 1

    def update_style_and_content_image(self, style_image, content_image):
        """
        Recompute every style and content target from a new pair of images.

        :param style_image: image whose features become the new style targets
        :param content_image: image whose features become the new content targets
        :return: None
        """
        self._refresh_targets(style_image, self.style_losses, self.style_layers)
        self._refresh_targets(content_image, self.content_losses, self.content_layers)

In [None]:
class ContentNet(nn.Module):
    """
    Convolutional encoder / decoder that predicts the missing image centre.

    The encoder halves the spatial size five times with stride-2
    convolutions, a kernel-4 convolution then produces the bottleneck, and
    the decoder grows it back with transposed convolutions to half the input
    resolution, ending in ``Tanh``.

    ``option`` must provide ``channal``, ``filter1``, ``bottleneck`` and
    ``deFilter``. The layers are registered in a fixed order so that the
    ``net.<index>`` names of the parameters stay stable.
    """

    def __init__(self, option):
        super().__init__()
        enc = option.filter1
        dec = option.deFilter

        def elu():
            return nn.ELU(alpha=0.2, inplace=True)

        # Encoder entry: image -> enc channels, no batch norm on the first layer.
        layers = [nn.Conv2d(option.channal, enc, 4, 2, 1, bias=False), elu()]

        # Four more stride-2 encoder stages.
        for c_in, c_out in ((enc, enc), (enc, enc * 2), (enc * 2, enc * 4), (enc * 4, enc * 8)):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False), nn.BatchNorm2d(c_out), elu()]

        # Bottleneck: unpadded, stride-1, kernel-4 convolution.
        layers += [nn.Conv2d(enc * 8, option.bottleneck, 4, bias=False),
                   nn.BatchNorm2d(option.bottleneck), elu()]

        # Decoder entry: expand the bottleneck with a stride-1 transposed convolution.
        layers += [nn.ConvTranspose2d(option.bottleneck, dec * 8, 4, 1, 0, bias=False),
                   nn.BatchNorm2d(dec * 8), elu()]

        # Three stride-2 decoder stages.
        for c_in, c_out in ((dec * 8, dec * 4), (dec * 4, dec * 2), (dec * 2, dec)):
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False), nn.BatchNorm2d(c_out), elu()]

        # Output layer: back to image channels, squashed into [-1, 1].
        layers += [nn.ConvTranspose2d(dec, option.channal, 4, 2, 1, bias=False), nn.Tanh()]

        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the encoder / decoder stack."""
        return self.net(input)

In [None]:
class Discriminator(nn.Module):
    """
    Convolutional real / fake classifier.

    Four stride-2 convolutions shrink the input by a factor of 16, a final
    unpadded kernel-4 convolution reduces it to one value per image (for the
    64 x 64 inputs used in this notebook) and ``Sigmoid`` maps that to a
    probability. ``opt`` must provide ``channal`` and ``DFilter``. Layers are
    registered in a fixed order so the ``net.<index>`` parameter names stay
    stable.
    """

    def __init__(self, opt):
        super().__init__()
        width = opt.DFilter

        def elu():
            return nn.ELU(0.2, inplace=True)

        # First stage has no batch norm: image -> width channels.
        stack = [nn.Conv2d(opt.channal, width, 4, 2, 1, bias=False), elu()]

        # Three stages that double the channel count while halving the size.
        for multiplier in (1, 2, 4):
            c_in = width * multiplier
            c_out = c_in * 2
            stack.extend([nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False), nn.BatchNorm2d(c_out), elu()])

        # Collapse the remaining 4 x 4 map to a single score.
        stack.extend([nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()])

        self.net = nn.Sequential(*stack)

    def forward(self, input):
        """Return one probability per row, shaped ``(-1, 1)``."""
        scores = self.net(input)
        return scores.view(-1, 1)

# Train the content network

In [None]:
import os
import argparse
import random

import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

# Training configuration. An empty argument list is parsed so that ``opt`` is
# a plain namespace that can be filled in by hand inside the notebook.
parser = argparse.ArgumentParser()
opt = parser.parse_args(args=[])
opt.bottleneck = 4000
opt.channal = 3
opt.DFilter = 64
opt.filter1 = 64
opt.deFilter = 64
opt.imageSize = 128
opt.dataroot = "/content/drive/MyDrive/data"
opt.batchSize = 32
opt.contentNet = "/content/drive/MyDrive/model/contentNet_cifar10.pth"
opt.discriNet = "/content/drive/MyDrive/model/discriNet_cifar10.pth"
opt.overlapPred = 4
opt.lr = 0.0002
opt.beta1 = 0.5
opt.niter = 90
opt.wtl2 = 0.998

# Create every output folder independently. With a single try/except around
# all three calls, the first folder that already existed stopped the
# remaining ones from being created.
for sample_dir in ("result/train/cropped", "result/train/real", "result/train/recon"):
    os.makedirs(sample_dir, exist_ok=True)

# random seed set
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

cudnn.benchmark = True

# Resize + centre-crop to a square of opt.imageSize, then map pixels to [-1, 1].
transform = transforms.Compose([transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = dset.ImageFolder(root=opt.dataroot, transform=transform)
# NOTE(review): batches are drawn in dataset order (shuffle=False) - confirm
# this is intended for training.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, shuffle=False, )


# custom weights initialization
def weights_init(m):
    """
    Re-initialise one module in place; meant to be used with ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Anything else is left untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Epoch at which training starts when no checkpoint overrides it.
# NOTE(review): with no checkpoint on disk this skips straight to epoch 80 of
# opt.niter - confirm that a fresh run should not start from 0 instead.
resume_epoch = 80

# Fall back to the CPU so the cell also runs on a machine without CUDA.
train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

contentNet = ContentNet(opt)
contentNet.apply(weights_init)
if opt.contentNet != '' and os.path.exists(opt.contentNet):
    # Read the file once, onto the CPU, and use it for both weights and epoch.
    content_checkpoint = torch.load(opt.contentNet, map_location=lambda storage, location: storage)
    contentNet.load_state_dict(content_checkpoint['state_dict'])
    print("Loaded weights for ContentNet successfully")
    resume_epoch = content_checkpoint['epoch']
print(contentNet)

discriNet = Discriminator(opt)
discriNet.apply(weights_init)
if opt.discriNet != '' and os.path.exists(opt.discriNet):
    discri_checkpoint = torch.load(opt.discriNet, map_location=lambda storage, location: storage)
    discriNet.load_state_dict(discri_checkpoint['state_dict'])
    print("Loaded weights for Discrimanator successfully")
    # When both checkpoints exist, the discriminator's epoch wins.
    resume_epoch = discri_checkpoint['epoch']
print(discriNet)

criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()

# Reusable buffers; the training loop resizes and overwrites them per batch.
input_real = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
input_cropped = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
label = torch.FloatTensor(opt.batchSize)

real_label = 1
fake_label = 0

# Allocate by *shape*. Passing the sizes as a list built a 4-element tensor
# holding those numbers instead of a (batch, 3, H/2, W/2) buffer.
real_center = torch.FloatTensor(opt.batchSize, 3, opt.imageSize // 2, opt.imageSize // 2)

# move everything to the training device
contentNet.to(train_device)
discriNet.to(train_device)
criterion.to(train_device)
criterionMSE.to(train_device)
# Tensors are used directly; wrapping them in autograd.Variable is no longer needed.
input_real = input_real.to(train_device)
input_cropped = input_cropped.to(train_device)
label = label.to(train_device)
real_center = real_center.to(train_device)

# setup optimizer
optimizerD = optim.Adam(discriNet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(contentNet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# Slice bounds do not change between batches, so they are computed once.
# center_*: the centre square the generator has to predict.
center_start = int(opt.imageSize / 4)
center_stop = int(opt.imageSize / 4) + int(opt.imageSize / 2)
# hole_*: the part of the centre that is blanked out in the generator input;
# it is smaller than the centre by opt.overlapPred pixels on every side.
hole_start = int(opt.imageSize / 4 + opt.overlapPred)
hole_stop = int(opt.imageSize / 4 + opt.imageSize / 2 - opt.overlapPred)
# inner_*: the same blanked area expressed in centre-crop coordinates.
inner_start = int(opt.overlapPred)
inner_stop = int(opt.imageSize / 2 - opt.overlapPred)
# recon_stop: upper bound used when pasting the prediction back for preview.
recon_stop = int(opt.imageSize / 4 + opt.imageSize / 2)

# Per-channel fill values for the hole, mapped from 0..255 into [-1, 1].
hole_fill = (2 * 117.0 / 255.0 - 1.0, 2 * 104.0 / 255.0 - 1.0, 2 * 123.0 / 255.0 - 1.0)

# Save checkpoints where the resume logic looks for them; fall back to the
# previous relative location when no path is configured.
content_checkpoint_path = opt.contentNet or './drive/MyDrive/model/contentNet_cifar10.pth'
discri_checkpoint_path = opt.discriNet or './drive/MyDrive/model/discriNet_cifar10.pth'

for epoch in range(resume_epoch, opt.niter):
    for i, data in enumerate(dataloader, 0):
        real_raw, _ = data
        real_center_raw = real_raw[:, :, center_start:center_stop, center_start:center_stop]
        batch_size = real_raw.size(0)

        # Copy the batch into the preallocated buffers (the last batch may be smaller).
        input_real.resize_(real_raw.size()).copy_(real_raw)
        input_cropped.resize_(real_raw.size()).copy_(real_raw)
        real_center.resize_(real_center_raw.size()).copy_(real_center_raw)

        # Blank out the hole with a constant colour, channel by channel.
        for channel, fill_value in enumerate(hole_fill):
            input_cropped.data[:, channel, hole_start:hole_stop, hole_start:hole_stop] = fill_value

        # ---- discriminator update: real centres ----
        discriNet.zero_grad()
        label.resize_(batch_size).fill_(real_label)

        output = discriNet(real_center)
        err_real_D = criterion(output, torch.unsqueeze(label, 1))
        err_real_D.backward()
        D_x = output.data.mean()

        # ---- discriminator update: generated centres ----
        fake = contentNet(input_cropped)
        label.data.fill_(fake_label)
        # detach() keeps this backward pass out of the generator.
        output = discriNet(fake.detach())
        err_fake_D = criterion(output, torch.unsqueeze(label, 1))
        err_fake_D.backward()
        D_G_1 = output.data.mean()
        errD = err_real_D + err_fake_D
        optimizerD.step()

        # ---- generator update: maximize log(D(G(z))) ----
        contentNet.zero_grad()
        # The generator wants its output to be classified as real.
        label.data.fill_(real_label)
        output = discriNet(fake)
        err_G_D = criterion(output, torch.unsqueeze(label, 1))

        # Weighted L2 reconstruction loss: the border band of width
        # opt.overlapPred is weighted 10x more than the interior. (A plain
        # MSE used to be computed here first and was immediately discarded.)
        wtl2Matrix = real_center.clone()
        wtl2Matrix.data.fill_(float(opt.wtl2) * 10)
        wtl2Matrix.data[:, :, inner_start:inner_stop, inner_start:inner_stop] = float(opt.wtl2)

        errG_l2 = (fake - real_center).pow(2)
        errG_l2 = errG_l2 * wtl2Matrix
        errG_l2 = errG_l2.mean()

        errG = (1 - float(opt.wtl2)) * err_G_D + float(opt.wtl2) * errG_l2

        errG.backward()

        # Discriminator score on fakes while updating the generator; kept for
        # inspection, the log line below reports D_G_1.
        D_G_z2 = output.data.mean()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), err_G_D.item(), errG_l2.item(), D_x, D_G_1,))

        if i % 100 == 0:
            # Dump preview images: ground truth, masked input, and the input
            # with the predicted centre pasted back in.
            vutils.save_image(real_raw,
                              'result/train/real/real_samples_epoch_%03d_%03d.png' % (epoch, i))
            vutils.save_image(input_cropped.data,
                              'result/train/cropped/cropped_samples_epoch_%03d_%03d.png' % (epoch, i))
            recon_image = input_cropped.clone()
            recon_image.data[:, :, center_start:recon_stop, center_start:recon_stop] = fake.data
            vutils.save_image(recon_image.data,
                              'result/train/recon/recon_center_samples_epoch_%03d_%03d.png' % (epoch, i))

            # do checkpointing
            torch.save({'epoch': epoch,
                        'state_dict': contentNet.state_dict()},
                       content_checkpoint_path)
            torch.save({'epoch': epoch,
                        'state_dict': discriNet.state_dict()},
                       discri_checkpoint_path)
            print("第{}已保存".format(i))

Loaded weights for ContentNet successfully
ContentNet(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ELU(alpha=0.2, inplace=True)
    (2): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ELU(alpha=0.2, inplace=True)
    (5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ELU(alpha=0.2, inplace=True)
    (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ELU(alpha=0.2, inplace=True)
    (11): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=Tru

In [None]:
import cv2
import torchvision

def get_synthesis_image(synthesis, denorm, device):
    """
    Turn the optimised tensor into a displayable image.

    A copy of ``synthesis`` is squeezed, moved to the CPU and passed through
    ``denorm``; the result is moved to ``device`` and clamped to [0, 1].
    ``synthesis`` itself is left untouched.
    """
    host_copy = synthesis.clone().squeeze().to(torch.device('cpu'))
    restored = denorm(host_copy).to(device)
    return restored.clamp_(0, 1)

def unsample_synthesis(height, width, synthesis, device):
    """
    Bilinearly resize ``synthesis`` and return it as a fresh optimisable tensor.

    :param height: target height in pixels
    :param width: target width in pixels
    :param synthesis: 4-D tensor to resize
    :param device: device the returned tensor lives on
    :return: leaf tensor of spatial size ``(height, width)`` with ``requires_grad=True``
    """
    # align_corners=False is the default behaviour; spelling it out silences
    # the warning without changing the result.
    resized = F.interpolate(synthesis, size=[height, width], mode='bilinear', align_corners=False)
    # Move to the device *before* enabling gradients. A device copy made
    # afterwards is recorded by autograd and yields a non-leaf tensor, which
    # an optimiser refuses to update.
    return resized.detach().clone().to(device).requires_grad_(True)

def main(config, cropped, synthesis_in, dir):
    """
    Refine an inpainting result over a coarse-to-fine image pyramid.

    At each of ``config.num_res`` resolutions a ``CNNMRF`` objective is
    minimised with L-BFGS over the synthesis image. The coarsest level starts
    from the down-scaled ``cropped`` image; each finer level starts from the
    up-sampled result of the previous one.

    :param config: namespace with ``num_res``, ``max_iter``, ``sample_step``,
        ``content_weight``, ``style_weight``, ``tv_weight``,
        ``gpu_chunck_size``, ``mrf_style_stride`` and ``mrf_synthesis_stride``
    :param cropped: content image (the image with the hole), 4-D tensor
    :param synthesis_in: style image (the content network's prediction), 4-D tensor
    :param dir: output directory; accepted for interface compatibility and
        not used here (progress images are written to the working directory)
    :return: de-normalised result clamped to [0, 1], resized to the finest
        pyramid level, shaped (1, C, H, W)
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # Inverse of the VGG / ImageNet normalisation
    # (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
    denorm_transform = transforms.Normalize(mean=(-2.12, -2.04, -1.80), std=(4.37, 4.46, 4.44))

    # The style image is brought to a fixed working size before the pyramid is built.
    size = 256
    synthesis_in = F.interpolate(synthesis_in, [size, size], mode="bilinear")

    # Tensor.to() is not in-place; keep the moved tensors.
    cropped = cropped.to(device)
    synthesis_in = synthesis_in.to(device)

    # Build both pyramids, coarsest level first; the last level is full size.
    pyramid_content_image = []
    pyramid_style_image = []
    for level in range(config.num_res):
        scale = 1 / pow(2, config.num_res - 1 - level)
        pyramid_content_image.append(F.interpolate(cropped, scale_factor=scale, mode='bilinear').to(device))
        pyramid_style_image.append(F.interpolate(synthesis_in, scale_factor=scale, mode='bilinear').to(device))

    # Closure-call counter. It lives in this function instead of a module
    # level ``global iter`` so the builtin ``iter`` is no longer overwritten.
    step = 0
    synthesis = None

    # create cnnmrf model
    cnnmrf = CNNMRF(style_image=pyramid_style_image[0], content_image=pyramid_content_image[0], device=device,
                    content_weight=config.content_weight, style_weight=config.style_weight, tv_weight=config.tv_weight,
                    gpu_chunck_size=config.gpu_chunck_size, mrf_synthesis_stride=config.mrf_synthesis_stride,
                    mrf_style_stride=config.mrf_style_stride).to(device)

    # Sets the module in training mode.
    cnnmrf.train()
    for level in range(config.num_res):
        if level == 0:
            # Coarsest level: start from the resized content image. The device
            # move happens before requires_grad_ so the tensor stays a leaf.
            synthesis = pyramid_content_image[0].detach().clone().to(device).requires_grad_(True)
        else:
            # Finer levels: start from the up-sampled previous result and
            # refresh the loss targets for the new resolution.
            synthesis = unsample_synthesis(pyramid_content_image[level].shape[2],
                                           pyramid_content_image[level].shape[3],
                                           synthesis, device)
            cnnmrf.update_style_and_content_image(style_image=pyramid_style_image[level],
                                                  content_image=pyramid_content_image[level])

        # max_iter (int): maximal number of iterations per optimization step
        optimizer = optim.LBFGS([synthesis], lr=1, max_iter=config.max_iter)

        def closure():
            nonlocal step
            optimizer.zero_grad()
            loss = cnnmrf(synthesis)
            loss.backward(retain_graph=True)
            # print loss
            if (step + 1) % 10 == 0:
                print('res-%d-iteration-%d: %f' % (level + 1, step + 1, loss.item()))
            # save image
            if (step + 1) % config.sample_step == 0 or step + 1 == config.max_iter:
                image = get_synthesis_image(synthesis, denorm_transform, device)
                image = F.interpolate(image.unsqueeze(0), size=pyramid_content_image[level].shape[2:4],
                                      mode='bilinear')
                torchvision.utils.save_image(image.squeeze(), 'res-%d-result-%d.jpg' % (level + 1, step + 1))
                print('save image: res-%d-result-%d.jpg' % (level + 1, step + 1))
            step += 1
            # Wrap around so the numbering restarts for the next level.
            if step == config.max_iter:
                step = 0
            return loss

        optimizer.step(closure)

    image = get_synthesis_image(synthesis, denorm_transform, device)
    # Resize to the finest level. Index -1 replaces the hard-coded 2, which
    # was only right for num_res == 3.
    image = F.interpolate(image.unsqueeze(0), size=pyramid_content_image[-1].shape[2:4], mode='bilinear')
    return image

def texture(cropped, synthesis, dir):
    """
    Run the texture refinement with a fixed set of hyper-parameters.

    The settings are appended to ``<dir>/setting.txt`` (one per line) and the
    work is delegated to ``main``.

    :param cropped: content image passed through to ``main``
    :param synthesis: style image passed through to ``main``
    :param dir: existing directory that receives ``setting.txt``
    :return: whatever ``main`` returns
    """
    parser = argparse.ArgumentParser()
    config = parser.parse_args(args=[])
    config.max_iter = 50 # (!)
    config.sample_step = 50
    config.content_weight = 1
    config.style_weight = 0.6
    config.tv_weight = 0.35
    config.num_res = 3
    config.gpu_chunck_size = 256
    config.mrf_style_stride = 2
    config.mrf_synthesis_stride = 2

    # str.replace returns a new string; the result has to be kept, otherwise
    # the settings end up on a single line.
    setting = str(config).replace(', ', '\n')

    with open(dir + '/setting.txt', 'a') as file_handle:
        file_handle.write(setting)  # write the settings
        file_handle.write('\n')

    return main(config, cropped, synthesis, dir)

In [None]:
from PIL import Image

# Inference configuration (same hand-filled namespace pattern as for training).
parser = argparse.ArgumentParser()
opt = parser.parse_args(args=[])
opt.bottleneck = 4000
opt.channal = 3
opt.deFilter = 64
opt.DFilter = 64
opt.filter1 = 64
opt.imageSize = 128
opt.imageSize_raw = 512
opt.dataroot = "/content/drive/MyDrive/data"
opt.batchSize = 32
# NOTE(review): these paths point at ".../models/" while the training cell
# reads and writes ".../model/" - confirm which folder is the right one.
opt.contentNet = "/content/drive/MyDrive/models/contentNet_cifar10.pth"
opt.discriNet = "/content/drive/MyDrive/models/discriNet_cifar10.pth"
opt.content_path = "/content/drive/MyDrive/lincoln.png"
opt.overlapPred = 4
opt.lr = 0.0002
opt.beta1 = 0.5
opt.niter = 25
opt.wtl2 = 0.998

pic = opt.content_path
dir='pic_result'

os.makedirs(dir, exist_ok=True)


def _load_rgb_image(path):
    """Read ``path`` with OpenCV and return it as an RGB PIL image."""
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if bgr is None:
        raise FileNotFoundError("could not read image: {}".format(path))
    return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))


# cut edge
# transform: low-resolution input of the content network, scaled to [-1, 1].
transform = transforms.Compose([transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# transform1: full-resolution image with VGG / ImageNet normalisation.
transform1 = transforms.Compose([transforms.Resize(opt.imageSize_raw),
                                 transforms.CenterCrop(opt.imageSize_raw),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])

# Inverses of the two normalisations above, plus the forward VGG normalisation.
denorm_transform = transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2))
denorm_transform1 = transforms.Normalize(mean=(-2.12, -2.04, -1.80), std=(4.37, 4.46, 4.44))
transform3 = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

content_image_ori_PIL = _load_rgb_image(pic)
content_image = transform(content_image_ori_PIL).unsqueeze(0)
content_images = content_image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
content_512 = transform1(content_image_ori_PIL).unsqueeze(0).to(device)
vutils.save_image(denorm_transform1(content_512[0]), dir + '/real.jpg')

# Bounds of the centre hole in the full-resolution image.
raw_start = int(opt.imageSize_raw / 4)
raw_stop = int(opt.imageSize_raw / 4 + opt.imageSize_raw / 2)

# Zero the hole first, then fill it with the image mean. The order matters:
# the mean is taken over the image with the hole already zeroed.
content_512[:, :, raw_start:raw_stop, raw_start:raw_stop] = 0.0
content_512[:, :, raw_start:raw_stop, raw_start:raw_stop] = torch.mean(content_512)

vutils.save_image(denorm_transform1(content_512[0]), dir + '/cropped.jpg')

# Extra images appended to the batch after the image of interest (index 0).
photos = ["/content/drive/MyDrive/data/humans/1 (1).jpeg"]

for path in photos:
    content_image1 = transform(_load_rgb_image(path)).unsqueeze(0)
    content_images = torch.cat((content_images, content_image1), 0)

input_cropped = content_images.clone()

# Blank out the hole of the low-resolution input with a constant colour.
hole_start = int(opt.imageSize / 4 + opt.overlapPred)
hole_stop = int(opt.imageSize / 4 + opt.imageSize / 2 - opt.overlapPred)
for channel, mean_pixel in enumerate((117.0, 104.0, 123.0)):
    input_cropped[:, channel, hole_start:hole_stop, hole_start:hole_stop] = 2 * mean_pixel / 255.0 - 1.0

contentNet = ContentNet(opt)
if opt.contentNet != '' and os.path.exists(opt.contentNet):
    print("Loading weights...")
    # Read the checkpoint once, onto the CPU.
    checkpoint = torch.load(opt.contentNet, map_location=lambda storage, location: storage)
    contentNet.load_state_dict(checkpoint['state_dict'])
    resume_epoch = checkpoint['epoch']
else:
    # Make the fallback visible instead of silently predicting with random weights.
    print("WARNING: no ContentNet checkpoint at {!r}; using untrained weights".format(opt.contentNet))
print(contentNet)

# NOTE(review): the network is left in training mode, so BatchNorm uses the
# statistics of this small batch - confirm whether eval() is wanted here.
with torch.no_grad():
    synthesis = contentNet(input_cropped)

center_start = int(opt.imageSize / 4)
center_stop = int(opt.imageSize / 4 + opt.imageSize / 2)
recon_image = input_cropped.clone()
recon_image.data[:, :, center_start:center_stop, center_start:center_stop] = synthesis.data

vutils.save_image(denorm_transform(recon_image.data[0]), dir + '/input.jpg')
vutils.save_image(denorm_transform(synthesis.data[0]), dir + '/output.jpg')

# Convert the prediction from [-1, 1] to VGG normalisation for the texture stage.
content_result = denorm_transform(synthesis.data[0])
content_result = transform3(content_result)

content_result = content_result.unsqueeze(0)
# Tensor.cuda()/.to() return a new tensor; keep it, and honour the CPU fallback.
content_result = content_result.to(device)

result = texture(content_512, content_result, dir)
content_512 = denorm_transform1(content_512.data[0])

# ``result`` covers the whole full-resolution image, so only its centre is
# pasted into the hole. Assigning the full planes into the smaller hole slice
# fails with a shape mismatch.
# NOTE(review): the alternative would be to shrink ``result`` to the hole
# size before pasting - confirm which behaviour is intended.
content_512[:, raw_start:raw_stop, raw_start:raw_stop] = result.data[0][:, raw_start:raw_stop, raw_start:raw_stop]

vutils.save_image(content_512, dir + '/result.jpg')

Loading weights...
ContentNet(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ELU(alpha=0.2, inplace=True)
    (2): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ELU(alpha=0.2, inplace=True)
    (5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ELU(alpha=0.2, inplace=True)
    (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ELU(alpha=0.2, inplace=True)
    (11): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T