In [0]:
# Install the third-party packages this notebook imports later:
#   - pretrainedmodels: used below to build the SE-ResNet-50 loss network
#   - CapsuleLayer (installed from its master branch): provides the `capsule_layer` module
#   - pytorchviz (installed from GitHub)
# NOTE(review): none of these installs is pinned to a version or commit.
!pip install pretrainedmodels
!pip install git+https://github.com/leftthomas/CapsuleLayer.git@master
!pip install git+https://github.com/szagoruyko/pytorchviz

In [0]:
# Mount Google Drive
# Install google-drive-ocamlfuse (from the alessandro-strada PPA) together with fuse.
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse
# Authenticate the Colab user and fetch the application-default credentials.
from google.colab import auth
auth.authenticate_user()
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
# First headless run: only the line containing "URL" is kept (grep), i.e. the
# authorisation link the user has to open.
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
# Read the verification code without echoing it in the notebook output.
vcode = getpass.getpass()
# Second headless run: feed the verification code on stdin.
# NOTE(review): the code and the client secret are interpolated into a shell
# command line here — confirm this is acceptable for the environment.
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

In [0]:
# Change the working directory to the project folder on the mounted drive
# Create the mount point (no error if it already exists) and mount Drive onto it.
!mkdir -p drive
!google-drive-ocamlfuse drive
import os
# All relative paths used later in the notebook are resolved from this folder.
os.chdir("drive/AI/styleTransfer/pytorch_perceptual")

fuse: mountpoint is not empty
fuse: if you are sure this is safe, use the 'nonempty' mount option


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import Optimizer
import torchvision
import torchvision.transforms as Transforms
import torchvision.datasets as Datasets

from PIL import Image
import numpy as np
import datetime
import math
import pretrainedmodels
from capsule_layer import CapsuleConv2d
from capsule_layer import CapsuleConvTranspose2d
from capsule_layer.optim import MultiStepRI

# Set random number seed
# The same constant seeds PyTorch (CPU generator), PyTorch CUDA and NumPy.
SEED = 66
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

In [0]:
# "Fixing Weight Decay Regularization in Adam" https://arxiv.org/abs/1711.05101
class AdamW(Optimizer):
    """Implements Adam with decoupled weight decay (AdamW).

    The moment estimates and the bias-corrected step are those of Adam. The
    weight decay is *not* added to the gradient: after the Adam update,
    ``weight_decay * p`` (computed from the pre-update weights) is subtracted
    from the parameter directly. Note that this term is not scaled by ``lr``.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay factor applied
            directly to the weights on every step (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            ICLR 2018 https://openreview.net/forum?id=ryQu7f-RZ

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or a beta is outside
            ``[0, 1)``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # Lazy state initialization on the first step for this parameter.
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Update the running first and second moments in place.
                # Keyword scalars are used instead of the deprecated
                # positional-scalar overloads (add_(s, t), addcmul_(s, t, t)).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    # The decay term is taken from the weights *before* the
                    # Adam update and applied afterwards (decoupled decay).
                    decayed_weights = torch.mul(p.data, group['weight_decay'])
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.sub_(decayed_weights)
                else:
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

class ResidualBlock(nn.Module):
    """Pre-activation bottleneck residual block.

    The main path is 1x1 conv -> 3x3 conv (carrying ``stride``) -> 1x1 conv,
    with the two inner convolutions working on ``output_channels / 4``
    channels. A 1x1 projection is used on the shortcut whenever the channel
    count or the spatial stride changes.

    Args:
        input_channels: number of channels of the input tensor.
        output_channels: number of channels of the output tensor.
        stride: stride of the 3x3 convolution (and of the shortcut projection).
        norm_type: ``'batch'`` selects ``nn.BatchNorm2d``; any other value
            selects ``nn.InstanceNorm2d``.
        norm_first: apply normalisation + ReLU to the input before the first
            convolution.
        norm_last: apply normalisation + ReLU to the summed output.
    """

    def __init__(self, input_channels, output_channels, stride=1, norm_type='instance', norm_first=True, norm_last=False):
        super(ResidualBlock, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        self.norm_first = norm_first
        self.norm_last = norm_last

        # Compare by value: `is` on strings only works by accident of interning.
        if norm_type == 'batch':
            self.norm = nn.BatchNorm2d
        else:
            self.norm = nn.InstanceNorm2d

        mid_channels = int(output_channels / 4)

        # Layers are created in the same order as before so that seeded
        # initialisation and state_dict keys are unchanged.
        if self.norm_first:
            self.bn_first = self.norm(input_channels, affine=True)
            self.relu_first = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, mid_channels, 1, 1, bias=False)
        self.bn1 = self.norm(mid_channels, affine=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride, padding=1, bias=False)
        self.bn2 = self.norm(mid_channels, affine=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(mid_channels, output_channels, 1, 1, bias=False)
        if self._needs_projection():
            self.conv4 = nn.Conv2d(input_channels, output_channels, 1, stride, bias=False)
        if self.norm_last:
            # Applied to the block output, so it must be sized for
            # output_channels (the original used input_channels, which fails
            # whenever the two differ).
            self.bn_last = self.norm(output_channels, affine=True)
            self.relu_last = nn.ReLU(inplace=True)

    def _needs_projection(self):
        """True when the shortcut cannot be the identity."""
        return (self.input_channels != self.output_channels) or (self.stride != 1)

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``, optionally normalised."""
        residual = x
        if self.norm_first:
            out = self.bn_first(x)
            out1 = self.relu_first(out)
        else:
            out1 = residual
        out = self.conv1(out1)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv3(out)
        if self._needs_projection():
            # The projection consumes the (possibly pre-activated) input.
            residual = self.conv4(out1)
        out += residual
        if self.norm_last:
            out = self.bn_last(out)
            out = self.relu_last(out)
        return out
      
# "Residual Attention Network for Image Classification" https://arxiv.org/pdf/1704.06904.pdf
class AttentionModule_stage1(nn.Module):
    """Attention module with a trunk branch and a three-level soft-mask branch.

    The input first goes through one residual block. The trunk branch then
    applies two residual blocks, while the mask branch downsamples three times
    with max pooling, upsamples back with ``F.interpolate`` (adding skip
    connections on the way up) and ends in a sigmoid. The output is
    ``(1 + mask) * trunk`` followed by a final residual block.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the first residual block and used
            by every later layer.
        norm_type: ``'batch'`` or anything else (instance norm). It only
            selects the normalisation layers inside ``softmax6_blocks``; the
            residual blocks are always built with ``norm_type='instance'``,
            as in the original code.
            NOTE(review): confirm whether the residual blocks were meant to
            follow ``norm_type`` as well.
    """

    def __init__(self, in_channels, out_channels, norm_type='instance'):
        super(AttentionModule_stage1, self).__init__()
        # Only this block sees `in_channels`; everything after it consumes its
        # `out_channels` output. The original passed `in_channels` to every
        # block, which only worked when in_channels == out_channels.
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels, norm_type='instance', norm_first=False)

        self.trunk_branches = nn.Sequential(
            ResidualBlock(out_channels, out_channels, norm_type='instance'),
            ResidualBlock(out_channels, out_channels, norm_type='instance')
        )

        # Compare by value: `is` on strings only works by accident of interning.
        if norm_type == 'batch':
            self.norm = nn.BatchNorm2d
        else:
            self.norm = nn.InstanceNorm2d

        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = ResidualBlock(out_channels, out_channels, norm_type='instance')
        self.skip1_connection_residual_block = ResidualBlock(out_channels, out_channels, norm_type='instance')
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax2_blocks = ResidualBlock(out_channels, out_channels, norm_type='instance')
        self.skip2_connection_residual_block = ResidualBlock(out_channels, out_channels, norm_type='instance')
        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax3_blocks = nn.Sequential(
            ResidualBlock(out_channels, out_channels, norm_type='instance'),
            ResidualBlock(out_channels, out_channels, norm_type='instance')
        )

        self.softmax4_blocks = ResidualBlock(out_channels, out_channels, norm_type='instance')

        self.softmax5_blocks = ResidualBlock(out_channels, out_channels, norm_type='instance')

        # Two 1x1 convolutions ending in a sigmoid: produces the soft mask.
        self.softmax6_blocks = nn.Sequential(
            self.norm(out_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
            self.norm(out_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid()
        )

        self.last_blocks = ResidualBlock(out_channels, out_channels, norm_type='instance', norm_last=True)

    def forward(self, x):
        """Return the attention-weighted trunk features for ``x``."""
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)

        # Mask branch, downsampling path.
        out_mpool1 = self.mpool1(x)
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
        out_mpool2 = self.mpool2(out_softmax1)
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
        out_mpool3 = self.mpool3(out_softmax2)
        out_softmax3 = self.softmax3_blocks(out_mpool3)

        # Mask branch, upsampling path: each level is resized to the spatial
        # size of the matching downsampling level before being summed with it.
        out_interp3 = F.interpolate(out_softmax3, (out_softmax2.shape[2], out_softmax2.shape[3])) + out_softmax2
        out = out_interp3 + out_skip2_connection
        out_softmax4 = self.softmax4_blocks(out)
        out_interp2 = F.interpolate(out_softmax4, (out_softmax1.shape[2], out_softmax1.shape[3])) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax5 = self.softmax5_blocks(out)
        out_interp1 = F.interpolate(out_softmax5, (out_trunk.shape[2], out_trunk.shape[3])) + out_trunk
        out_softmax6 = self.softmax6_blocks(out_interp1)

        # Residual attention: the mask scales the trunk but cannot zero it out.
        out = (1 + out_softmax6) * out_trunk
        out_last = self.last_blocks(out)

        return out_last

# "Squeeze-and-Excitation Networks" https://arxiv.org/abs/1709.01507
class SE_Resnet50(nn.Module):
    """Feature extractor built from the first stages of a pretrained SE-ResNet-50.

    ``forward`` returns the activations after each of the three blocks of
    ``layer1`` and after the first block of ``layer2``.
    """

    def __init__(self):
        super().__init__()
        print('Preparing pretrained SE Resnet 50 ...')
        builder = pretrainedmodels.__dict__['se_resnet50']
        # The backbone is switched to eval mode right after construction.
        self.se_resnet50 = builder(num_classes=1000, pretrained='imagenet').eval()
        backbone = self.se_resnet50
        self.layer0 = backbone.layer0
        first_stage = backbone.layer1
        self.layer1_1 = first_stage[0]
        self.layer1_2 = first_stage[1]
        self.layer1_3 = first_stage[2]
        self.layer2_1 = backbone.layer2[0]

    def forward(self, x):
        """Return ``[out_1, out_2, out_3, out_4]``, one activation per tapped block."""
        features = []
        hidden = self.layer0(x)
        for stage in (self.layer1_1, self.layer1_2, self.layer1_3, self.layer2_1):
            hidden = stage(hidden)
            features.append(hidden)
        return features

# Load image with size of parameter size
def load_img(path, scale=None, size=None):
    """Load an image file as a ``(1, 3, H, W)`` float tensor.

    Args:
        path: path of the image file.
        scale: optional factor applied to both the height and the width of the
            image. Ignored when ``size`` is given.
        size: optional ``(height, width)`` pair to resize to. Takes precedence
            over ``scale``.

    Returns:
        The image converted to RGB, optionally resized, passed through
        ``ToTensor`` and given a leading batch dimension.
    """
    # Image.open keeps the file handle open; convert() returns an independent
    # in-memory copy, so the file can be closed as soon as it is done.
    with Image.open(path) as source:
        img = source.convert('RGB')

    transform_list = []
    if size is not None:
        transform_list.append(Transforms.Resize((int(size[0]), int(size[1]))))
    elif scale is not None:
        # PIL reports size as (width, height); Resize expects (height, width).
        width, height = img.size
        transform_list.append(Transforms.Resize((int(height * scale), int(width * scale))))

    transform_list.append(Transforms.ToTensor())
    transform = Transforms.Compose(transform_list)

    return transform(img).unsqueeze(dim=0)

# Make image with shape of [content | result | style] and save it
def save_img(img_name, content, style, result):
    """Save ``[content | result | style]`` side by side as one image file.

    Args:
        img_name: destination path passed to ``torchvision.utils.save_image``.
        content, style, result: 3-D image tensors of identical shape (they are
            stacked along a new leading dimension).

    Raises:
        ValueError: if ``content`` is not a 3-D tensor.
    """
    # The original unpacked content.size() into three (otherwise unused)
    # names, which rejected non 3-D input; keep that check explicitly.
    if content.dim() != 3:
        raise ValueError('expected a 3-D image tensor, got %d dimensions' % content.dim())
    img = torch.stack([content, result, style], dim=0)
    torchvision.utils.save_image(img, img_name, nrow=3)

# Load pretrained weight
def load_weight(model, path):
    """Load a saved ``state_dict`` from ``path`` into ``model`` and return it.

    The checkpoint is deserialised onto the CPU; ``load_state_dict`` then
    copies the values into the model's existing parameters, wherever they
    live. This lets weights saved on a GPU be loaded on a CPU-only machine.

    NOTE: ``torch.load`` unpickles the file — only load checkpoints from
    sources you trust.
    """
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state)
    return model

# Normalization with mean and std
def norm(var):
    """Normalise a batch of RGB images channel-wise: ``(var - mean) / std``.

    The per-channel constants are mean ``(0.485, 0.456, 0.406)`` and std
    ``(0.229, 0.224, 0.225)``.

    Args:
        var: tensor of shape ``(N, 3, H, W)``.

    Returns:
        A new tensor of the same shape, on the same device as ``var``.

    Raises:
        ValueError: if ``var`` is not 4-D with exactly 3 channels.
    """
    if var.dim() != 4 or var.size(1) != 3:
        raise ValueError('expected a (N, 3, H, W) tensor, got shape %s' % (tuple(var.size()),))

    # Build the constants from `var` itself so they always share its device
    # and dtype (the original picked CUDA whenever it was available, even for
    # CPU input), and let broadcasting replace the full-size buffers.
    mean = var.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = var.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    return var.sub(mean).div(std)

# Denormalization with mean and std
def denorm(var):
    """Undo the channel-wise normalisation: ``var * std + mean``.

    The per-channel constants are mean ``(0.485, 0.456, 0.406)`` and std
    ``(0.229, 0.224, 0.225)``.

    Args:
        var: tensor of shape ``(N, 3, H, W)``.

    Returns:
        A new tensor of the same shape, on the same device as ``var``.

    Raises:
        ValueError: if ``var`` is not 4-D with exactly 3 channels.
    """
    if var.dim() != 4 or var.size(1) != 3:
        raise ValueError('expected a (N, 3, H, W) tensor, got shape %s' % (tuple(var.size()),))

    # Build the constants from `var` itself so they always share its device
    # and dtype (the original picked CUDA whenever it was available, even for
    # CPU input), and let broadcasting replace the full-size buffers.
    mean = var.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = var.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    return var.mul(std).add(mean)

# Get gram matrix
def gram(var_list):
    """Return one Gram matrix per feature map in ``var_list``.

    Each ``(N, C, H, W)`` feature map is flattened to ``(N, C, H*W)`` and
    multiplied by its own transpose, giving ``(N, C, C)``; the product is
    divided by ``C * H * W``.
    """
    def _gram_matrix(feature):
        batch, channels, height, width = feature.size()
        flat = feature.view(batch, channels, height * width)
        return torch.bmm(flat, flat.transpose(1, 2)) / (channels * height * width)

    return [_gram_matrix(var_list[idx]) for idx in range(len(var_list))]

# Get doubleGram matrix
def firstGram(tensor1, tensor2):
    """Concatenate ``tensor2[i] * tensor1`` for every ``i`` along dim 0.

    ``i`` runs over ``range(tensor1.shape[0])``. For two 1-D tensors of equal
    length this is their outer product, flattened row by row.
    """
    scaled_copies = [tensor2[idx] * tensor1 for idx in range(tensor1.shape[0])]
    return torch.cat(scaled_copies, 0)

def secondGram(tensor1):
    """Return all pairwise products between the rows of ``tensor1``, flattened.

    For a ``(C, L)`` input the result has ``C * C * L * L`` entries ordered by
    ``(i, j, a, b)`` with value ``tensor1[j, a] * tensor1[i, b]``. These are the
    same values, in the same order, as concatenating the flattened outer
    product of every ordered row pair ``(i, j)``.

    Implemented as one broadcast multiplication instead of ``C * C * L``
    Python-level products followed by ``torch.cat``.

    Raises:
        IndexError: if ``tensor1`` has fewer than two dimensions.
    """
    # [1, C, L, 1, ...] * [C, 1, 1, L, ...] -> [C(i), C(j), L(a), L(b), ...]
    pairwise = tensor1[None, :, :, None] * tensor1[:, None, None, :]
    rows, length = tensor1.shape[0], tensor1.shape[1]
    # Any dimensions beyond the first two are kept as trailing dimensions.
    return pairwise.reshape((rows * rows * length * length,) + tuple(tensor1.shape[2:]))

def doubleGram(tensor1):
    """Return, per sample, every pairwise product between channel rows.

    For a ``(N, C, L)`` input the result has shape ``(N, C * C * L * L)``.
    Within a sample the entries are ordered by ``(i, j, a, b)`` with value
    ``tensor1[n, j, a] * tensor1[n, i, b]``.

    Implemented as one broadcast multiplication instead of nested Python loops
    and ``torch.cat``; values and ordering are unchanged. The output grows
    with ``C**2 * L**2`` per sample.

    Raises:
        IndexError: if ``tensor1`` has fewer than three dimensions.
    """
    # [N, 1, C, L, 1, ...] * [N, C, 1, 1, L, ...] -> [N, C(i), C(j), L(a), L(b), ...]
    pairwise = tensor1[:, None, :, :, None] * tensor1[:, :, None, None, :]
    batch, rows, length = tensor1.shape[0], tensor1.shape[1], tensor1.shape[2]
    # Any dimensions beyond the first three are kept as trailing dimensions.
    return pairwise.reshape((batch, rows * rows * length * length) + tuple(tensor1.shape[3:]))

def gramP(var_list):
    """Return one ``doubleGram`` statistic per feature map in ``var_list``.

    Each ``(N, C, H, W)`` feature map is flattened to ``(N, C, H*W)``, passed
    to ``doubleGram`` and the result is divided by ``C * H * W * H * W``.
    """
    def _scaled_double_gram(feature):
        batch, channels, height, width = feature.size()
        flat = feature.view(batch, channels, height * width)
        return doubleGram(flat) / (channels * height * width * height * width)

    return [_scaled_double_gram(var_list[idx]) for idx in range(len(var_list))]

# Create data loader
def data_loader(root, batch_size=4, size=256):
    """Build a shuffled ``DataLoader`` over the ``ImageFolder`` at ``root``.

    Every image is resized with ``Resize(size)``, centre-cropped to ``size``
    and converted to a tensor.

    Returns:
        ``(loader, length)`` where ``length`` is the number of images in the
        dataset.
    """
    preprocessing = Transforms.Compose([
        Transforms.Resize(size),
        Transforms.CenterCrop(size),
        Transforms.ToTensor(),
    ])
    image_folder = Datasets.ImageFolder(root, preprocessing)
    batches = torch.utils.data.DataLoader(image_folder, batch_size=batch_size, shuffle=True)

    return batches, len(image_folder)

class ConvBlock(nn.Module):
    """Reflection padding -> convolution -> optional norm -> optional activation.

    Args:
        in_c, out_c: input / output channels of the convolution.
        kernel, stride: convolution kernel size and stride.
        pad: reflection padding applied before the (unpadded) convolution.
        norm_type: ``'batch'``, ``'instance'`` or ``None``. Any other value
            adds no normalisation layer.
        act_type: ``'relu'``, ``'tanh'`` or ``None``. Any other value adds no
            activation layer.
    """

    def __init__(self, in_c, out_c, kernel=3, stride=2, pad=1, norm_type='batch', act_type='relu'):
        super(ConvBlock, self).__init__()
        layers = [nn.ReflectionPad2d(pad),
                  nn.Conv2d(in_c, out_c, kernel_size=kernel, stride=stride, padding=0)]

        # Strings are compared by value; `is` only matched interned literals,
        # so a name read from a config or built at runtime was silently ignored.
        if norm_type == 'batch':
            layers.append(nn.BatchNorm2d(out_c, affine=True))
        elif norm_type == 'instance':
            layers.append(nn.InstanceNorm2d(out_c, affine=True))

        if act_type == 'relu':
            layers.append(nn.ReLU())
        elif act_type == 'tanh':
            layers.append(nn.Tanh())

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)
    
class ResBlock(nn.Module):
    """Two 3x3, stride-1 ``ConvBlock``s with an identity skip connection.

    The first block ends in a ReLU; the second has no activation. The channel
    count and spatial size are preserved, so the input can be added back.
    """

    def __init__(self, channels, norm_type):
        super().__init__()
        activated = ConvBlock(channels, channels, kernel=3, stride=1, pad=1,
                              norm_type=norm_type, act_type='relu')
        linear = ConvBlock(channels, channels, kernel=3, stride=1, pad=1,
                           norm_type=norm_type, act_type=None)
        self.block = nn.Sequential(activated, linear)

    def forward(self, x):
        """Return ``block(x) + x``."""
        transformed = self.block(x)
        return transformed + x

class ConvTransBlock(nn.Module):
    """Transposed convolution -> optional norm -> ReLU.

    Args:
        in_c, out_c: input / output channels.
        kernel, stride, pad, out_pad: forwarded to ``nn.ConvTranspose2d`` as
            ``kernel_size``, ``stride``, ``padding`` and ``output_padding``.
        norm_type: ``'batch'`` or ``'instance'``. Any other value adds no
            normalisation layer.
    """

    def __init__(self, in_c, out_c, kernel=3, stride=2, pad=1, out_pad=1, norm_type='batch'):
        super(ConvTransBlock, self).__init__()

        # Conv transpose layer
        layers = [nn.ConvTranspose2d(in_c, out_c, kernel_size=kernel, stride=stride, padding=pad,
                                     output_padding=out_pad)]

        # Normalization layer. Strings are compared by value; `is` only
        # matched interned literals.
        if norm_type == 'batch':
            layers.append(nn.BatchNorm2d(out_c, affine=True))
        elif norm_type == 'instance':
            layers.append(nn.InstanceNorm2d(out_c, affine=True))

        # Activation layer
        layers.append(nn.ReLU())

        self.conv_trans_block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.conv_trans_block(x)
    
class ImageTransformNet(nn.Module):
    """Convolutional image transformation network.

    Three ``ConvBlock``s (the last two with stride 2) are followed by
    ``res_block_num`` residual blocks at 128 channels, two stride-2
    ``ConvTransBlock``s and a final 9x9 ``ConvBlock`` with tanh. The tanh
    output is mapped from ``[-1, 1]`` to ``[0, 1]``.
    """

    def __init__(self, res_block_num=5, norm_type='batch'):
        super().__init__()

        # Downsampling blocks
        encoder = [
            ConvBlock(3, 32, kernel=9, stride=1, pad=4, norm_type=norm_type, act_type='relu'),
            ConvBlock(32, 64, kernel=3, stride=2, pad=1, norm_type=norm_type, act_type='relu'),
            ConvBlock(64, 128, kernel=3, stride=2, pad=1, norm_type=norm_type, act_type='relu'),
        ]
        self.downsamples = nn.Sequential(*encoder)

        # Residual blocks
        self.residuals = nn.Sequential(*[ResBlock(128, norm_type) for _ in range(res_block_num)])

        # Upsampling blocks
        decoder = [
            ConvTransBlock(128, 64, kernel=3, stride=2, pad=1, out_pad=1, norm_type=norm_type),
            ConvTransBlock(64, 32, kernel=3, stride=2, pad=1, out_pad=1, norm_type=norm_type),
            ConvBlock(32, 3, kernel=9, stride=1, pad=4, norm_type=None, act_type='tanh'),
        ]
        self.upsamples = nn.Sequential(*decoder)

    def forward(self, x):
        """Return the transformed image, rescaled to ``[0, 1]``."""
        hidden = self.downsamples(x)
        hidden = self.residuals(hidden)
        squashed = self.upsamples(hidden)
        return (squashed + 1) / 2

class CapsBlock(nn.Module):
    """Reflection padding followed by a dynamically routed ``CapsuleConv2d``.

    ``in_length`` / ``out_length`` are forwarded unchanged to
    ``CapsuleConv2d`` (presumably the capsule vector lengths — confirm against
    the capsule_layer documentation).
    """

    def __init__(self, in_c, out_c, in_length, out_length, kernel=3, stride=2, pad=1):
        super().__init__()
        reflect = nn.ReflectionPad2d(pad)
        # Padding is handled by the reflection layer, hence padding=0 here.
        capsule_conv = CapsuleConv2d(in_c, out_c, in_length=in_length, out_length=out_length,
                                     kernel_size=kernel, stride=stride, padding=0,
                                     routing_type='dynamic', bias=True)
        self.block = nn.Sequential(reflect, capsule_conv)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)
    
class CapsResBlock(nn.Module):
    """Two 3x3, stride-1 ``CapsBlock``s with an identity skip connection.

    Both blocks use ``vlength`` for their input and output length arguments.
    """

    def __init__(self, channels, vlength=8):
        super().__init__()
        stages = [CapsBlock(channels, channels, vlength, vlength, kernel=3, stride=1, pad=1)
                  for _ in range(2)]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``block(x) + x``."""
        transformed = self.block(x)
        return transformed + x
    
class CapsTransBlock(nn.Module):
    """A single ``CapsuleConvTranspose2d`` layer wrapped in ``nn.Sequential``.

    ``kernel``, ``stride``, ``pad`` and ``out_pad`` are forwarded as
    ``kernel_size``, ``stride``, ``padding`` and ``output_padding``;
    ``in_length`` / ``out_length`` are forwarded unchanged.
    """

    def __init__(self, in_c, out_c, in_length, out_length, kernel=3, stride=2, pad=1, out_pad=1):
        super().__init__()

        # Conv transpose layer
        capsule_deconv = CapsuleConvTranspose2d(in_c, out_c, in_length=in_length, out_length=out_length,
                                                kernel_size=kernel, stride=stride, padding=pad,
                                                output_padding=out_pad, bias=True)

        self.conv_trans_block = nn.Sequential(capsule_deconv)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.conv_trans_block(x)

class ImageTransformCapsNet(nn.Module):
    """Image transformation network with capsule transposed-conv upsampling.

    The encoder and residual stack match ``ImageTransformNet``; the two
    upsampling stages are ``CapsTransBlock``s, followed by a 9x9 ``ConvBlock``
    with tanh. The tanh output is mapped from ``[-1, 1]`` to ``[0, 1]``.

    NOTE(review): ``self.attention`` is constructed (so its parameters are
    part of the state_dict) but ``forward`` never calls it — confirm whether
    it was meant to be applied after the residual blocks.
    """

    def __init__(self, res_block_num=5, norm_type='batch'):
        super().__init__()

        # Downsampling blocks
        encoder = [
            ConvBlock(3, 32, kernel=9, stride=1, pad=4, norm_type=norm_type, act_type='relu'),
            ConvBlock(32, 64, kernel=3, stride=2, pad=1, norm_type=norm_type, act_type='relu'),
            ConvBlock(64, 128, kernel=3, stride=2, pad=1, norm_type=norm_type, act_type='relu'),
        ]
        self.downsamples = nn.Sequential(*encoder)

        # Residual blocks
        self.residuals = nn.Sequential(*[ResBlock(128, norm_type) for _ in range(res_block_num)])

        # Attention module (built but not used in forward, see class note)
        self.attention = AttentionModule_stage1(128, 128, norm_type=norm_type)

        # Upsampling blocks
        # (a capsule output layer, CapsBlock(32, 3, 8, 1, kernel=9, stride=1, pad=4),
        #  was left disabled in favour of the plain ConvBlock below)
        decoder = [
            CapsTransBlock(128, 64, 64, 16, kernel=3, stride=2, pad=1, out_pad=1),
            CapsTransBlock(64, 32, 16, 32, kernel=3, stride=2, pad=1, out_pad=1),
            ConvBlock(32, 3, kernel=9, stride=1, pad=4, norm_type=None, act_type='tanh'),
        ]
        self.upsamples = nn.Sequential(*decoder)

    def forward(self, x):
        """Return the transformed image, rescaled to ``[0, 1]``."""
        hidden = self.downsamples(x)
        hidden = self.residuals(hidden)
        squashed = self.upsamples(hidden)
        return (squashed + 1) / 2

class VGG16(nn.Module):
    """Loss network: four consecutive slices of pretrained VGG-16 ``features``.

    The slices cover layer indices ``[0:4]``, ``[4:9]``, ``[9:16]`` and
    ``[16:23]``; ``forward`` returns the activation at the end of each slice
    (named after relu1_2, relu2_2, relu3_3 and relu4_3).
    """

    def __init__(self):
        super().__init__()
        print('Preparing pretrained VGG 16 ...')
        self.vgg_16 = torchvision.models.vgg16(pretrained=True).features

        layers = list(self.vgg_16.children())
        self.relu_1_2 = nn.Sequential(*layers[0:4])
        self.relu_2_2 = nn.Sequential(*layers[4:9])
        self.relu_3_3 = nn.Sequential(*layers[9:16])
        self.relu_4_3 = nn.Sequential(*layers[16:23])

    def forward(self, x):
        """Return ``[out_1_2, out_2_2, out_3_3, out_4_3]``."""
        activations = []
        hidden = x
        # Each slice consumes the output of the previous one.
        for stage in (self.relu_1_2, self.relu_2_2, self.relu_3_3, self.relu_4_3):
            hidden = stage(hidden)
            activations.append(hidden)
        return activations
    
class Solver():
    """Trains an image transformation network for a single style image.

    The content loss is the MSE between one loss-network activation of the
    generated image and of the input image; the style loss is the sum of MSEs
    between Gram matrices of the generated image and of the style image.

    Args:
        trn_dir: root folder of the training images (``ImageFolder`` layout).
        style_path: path of the style image; its base name (without extension)
            names the result sub-folder and the saved weight file.
        result_dir: folder for the periodic ``[content | result | style]`` previews.
        weight_dir: folder where ``<style_name>.weight`` is saved after training.
        process_dir, process_image, process_scale, process_number: when
            ``process_number > 0`` (and both an image and a folder are given),
            ``process_image`` is stylised about ``process_number`` times over
            the whole run and saved into ``process_dir``.
        record_number: approximate number of training-loss samples to record.
        num_epoch, batch_size, lr: usual training hyper-parameters.
        record_name: if set, records are written to ``<record_name>_train.txt``
            (and ``<record_name>_test.txt`` when ``test_dir`` is given).
        content_loss_pos: index of the loss-network activation used for the
            content loss.
        lambda_c, lambda_s: weights of the content and style losses.
        show_every: print losses and save a preview every this many iterations
            of an epoch.
        save_every: stored but not used by this class.
        pretrain: optional path of a state_dict to initialise the network from.
        lossNet: ``'senet50'`` for ``SE_Resnet50`` (224 px inputs); anything
            else uses ``VGG16`` (256 px inputs).
        test_number, test_dir: when both are set, the network is evaluated on
            ``test_dir`` about ``test_number`` times over the whole run.
        transNet: ``'capsnet'`` for ``ImageTransformCapsNet``; anything else
            uses ``ImageTransformNet``.
        opti: ``'adam'``, ``'sgd'`` or anything else for ``AdamW``.
        norm_type: forwarded to the transformation network.
        gram: stored as ``gram_type`` but NOT applied — see the note below.

    NOTE(review): the original ``train`` tested ``gram == 'gramP'`` where
    ``gram`` resolved to the module-level *function* ``gram``; that comparison
    is always False, so plain Gram matrices were always used regardless of
    this argument. That effective behaviour is kept here: ``gramP`` builds
    ``C**2 * (H*W)**2`` values per sample, so silently enabling the default
    ``'gramP'`` would change (and most likely break) existing runs. Confirm
    the intent before wiring ``gram_type`` into the loss.
    """

    def __init__(self, trn_dir, style_path, result_dir, weight_dir, process_dir=None, process_image=None, process_scale=None, process_number=0, record_number=0, num_epoch=2, 
                 batch_size=4, record_name=None, content_loss_pos=2, lr=1e-3, lambda_c=1, lambda_s=5e+5, show_every=1000, save_every=5000, pretrain=None, lossNet='vgg',
                 test_number=0, test_dir=None, transNet='capsnet', opti='adamw', norm_type='batch', gram='gramP'):

        if torch.cuda.is_available():
            self.dtype = torch.cuda.FloatTensor
        else:
            self.dtype = torch.FloatTensor

        self.style_path = style_path
        self.result_dir = result_dir
        self.weight_dir = weight_dir
        self.process_dir = process_dir
        self.process_image = process_image
        self.process_number = process_number
        self.process_scale = process_scale
        self.record_number = record_number
        self.record_name = record_name
        self.test_number = test_number
        self.test_dir = test_dir
        self.transNet = transNet
        self.norm_type = norm_type
        # Recorded for inspection only; see the class-level NOTE(review).
        self.gram_type = gram

        # Models
        if self.transNet == 'capsnet':
            self.trans_net = ImageTransformCapsNet(norm_type=norm_type).type(self.dtype)
            self.router = MultiStepRI(self.trans_net, milestones=[3], addition=3, verbose=True)
        else:
            self.trans_net = ImageTransformNet(norm_type=norm_type).type(self.dtype)

        if pretrain is not None:
            # Deserialise on the CPU; load_state_dict copies the values onto
            # whatever device the network already lives on.
            self.trans_net.load_state_dict(torch.load(pretrain, map_location='cpu'))

        if lossNet == 'senet50':
            self.lossnet = SE_Resnet50().type(self.dtype)
            self.size = 224
        else:
            self.lossnet = VGG16().type(self.dtype)
            self.size = 256

        # Dataloader
        self.dloader, total_num = data_loader(root=trn_dir, batch_size=batch_size, size=self.size)
        self.total_iter = int(total_num / batch_size) + 1
        if self.test_dir is not None:
            self.test_dloader, test_total_num = data_loader(root=test_dir, batch_size=batch_size, size=self.size)
            self.test_total_iter = int(test_total_num / batch_size) + 1

        # Loss function and optimizer
        self.mse_loss = nn.MSELoss()
        if opti == 'adam':
            self.optimizer = optim.Adam(self.trans_net.parameters(), lr=lr, weight_decay=1e-5)
        elif opti == 'sgd':
            self.optimizer = optim.SGD(self.trans_net.parameters(), lr=lr, weight_decay=1e-5)
        else:
            self.optimizer = AdamW(self.trans_net.parameters(), lr=lr, weight_decay=1e-5)

        # Hyperparameters
        self.content_loss_pos = content_loss_pos
        self.lambda_c = lambda_c
        self.lambda_s = lambda_s
        self.show_every = show_every
        self.save_every = save_every
        self.num_epoch = num_epoch

    @staticmethod
    def _interval(total_steps, number):
        """Return the step gap for ``number`` evenly spaced events, or -1 if disabled.

        The gap is never smaller than 1, so asking for more events than there
        are steps yields one event per step instead of a zero interval.
        """
        if number is None or number <= 0:
            return -1
        return max(1, int(total_steps / number))

    def _compute_losses(self, images, gram_target):
        """Run one batch through both networks and return its losses.

        Returns:
            ``(content_loss, style_loss, content_img, out_img)`` where the two
            images are the normalised input and the normalised network output.
        """
        net_input = norm(images.type(self.dtype))
        out_img = norm(self.trans_net(net_input))
        content_img = net_input.detach().clone()

        # The target activations are constants: no graph is needed for them.
        with torch.no_grad():
            relu_target = self.lossnet(content_img)
        relu_out = self.lossnet(out_img)

        # Content loss
        feature_y = relu_out[self.content_loss_pos]
        feature_t = relu_target[self.content_loss_pos].detach()
        content_loss = self.lambda_c * self.mse_loss(feature_y, feature_t)

        # Style loss over every tapped activation (plain Gram matrices, see
        # the class-level NOTE(review)).
        gram_out = gram(relu_out)
        style_loss = 0
        for i in range(len(gram_target)):
            gram_y = gram_out[i]
            # The single style Gram matrix is broadcast over the batch.
            gram_t = gram_target[i].expand_as(gram_y).detach()
            style_loss += self.lambda_s * self.mse_loss(gram_y, gram_t)

        return content_loss, style_loss, content_img, out_img

    def _save_process_image(self, style_name, count):
        """Stylise ``process_image`` with the current weights and save it."""
        content_img = load_img(path=self.process_image, scale=self.process_scale)
        content_img = norm(content_img.type(self.dtype))
        # The network mode is left untouched, as in the original code.
        with torch.no_grad():
            result_img = self.trans_net(content_img)

        os.makedirs(self.process_dir, exist_ok=True)
        torchvision.utils.save_image(result_img.data, os.path.join(self.process_dir, style_name + '_' + str(count) + '.png'), nrow=1)

    def _evaluate(self, gram_target):
        """Return the mean content and style loss over the test loader."""
        closs = 0
        sloss = 0
        self.trans_net.eval()
        with torch.no_grad():
            for test_img, _ in self.test_dloader:
                content_loss, style_loss, _, _ = self._compute_losses(test_img, gram_target)
                closs += content_loss.item()
                sloss += style_loss.item()

        if self.transNet == 'capsnet':
            self.router.step()
        self.trans_net.train()

        # Guard against an empty test set.
        num_batches = max(1, len(self.test_dloader))
        return closs / num_batches, sloss / num_batches

    def train(self):
        """Train the transformation network, then save its weights and the loss records."""
        # Derived once up front; the original only bound this name inside
        # conditional branches of the loop.
        style_name = os.path.splitext(os.path.basename(self.style_path))[0]

        # Process on style image. Only needs to be done once.
        style_img = load_img(self.style_path, size=(self.size, self.size)).type(self.dtype)
        _style_img = style_img.clone()
        with torch.no_grad():
            gram_target = gram(self.lossnet(norm(style_img)))

        # Write records
        count = 0
        train_record = 'iter,train_content_loss,train_style_loss\n'
        test_record = 'iter,test_content_loss,test_style_loss\n'
        total_steps = self.num_epoch * len(self.dloader)
        test_interval = self._interval(total_steps, self.test_number)
        process_interval = self._interval(total_steps, self.process_number)
        record_interval = self._interval(total_steps, self.record_number)
        print('Start traing [test_interval %d, process_interval %d, record_interval %d]' % (test_interval, process_interval, record_interval))

        for epoch in range(self.num_epoch):
            for iters, (trn_img, _) in enumerate(self.dloader):
                content_loss, style_loss, content_img, out_img = self._compute_losses(trn_img, gram_target)
                loss = content_loss + style_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                count += 1

                if self.show_every > 0 and iters % self.show_every == 0:
                    print('[Epoch : (%d / %d), Iters : (%d / %d), Count : %d] => Content : %f, Style : %f' \
                          %(epoch + 1, self.num_epoch, iters, self.total_iter, count, content_loss.item(), style_loss.item()))

                    result_dir = os.path.join(self.result_dir, style_name)
                    os.makedirs(result_dir, exist_ok=True)
                    file_name = os.path.join(result_dir, str(epoch) + '_' + str(iters) + '.png')

                    # Denorm the images to get displayable values back.
                    save_img(file_name, denorm(content_img).data[0], _style_img[0], denorm(out_img).data[0])

                # Fires on steps 1, 1 + interval, 1 + 2 * interval, ...
                if record_interval > 0 and (count - 1) % record_interval == 0:
                    train_record += str(count) + ',' + str(content_loss.item()) + ',' + str(style_loss.item()) + '\n'

                if (process_interval > 0 and (count - 1) % process_interval == 0
                        and self.process_image is not None and self.process_dir is not None):
                    self._save_process_image(style_name, count)

                if test_interval > 0 and count % test_interval == 0 and self.test_dir is not None:
                    print('[Epoch : (%d / %d), Iters : (%d / %d), Count : %d] => Content : %f, Style : %f' \
                          %(epoch + 1, self.num_epoch, iters, self.total_iter, count, content_loss.item(), style_loss.item()))
                    mean_closs, mean_sloss = self._evaluate(gram_target)
                    test_record += str(count) + ',' + str(mean_closs) + ',' + str(mean_sloss) + '\n'
                    print('[Epoch : (%d / %d) => Test_Content : %f, Test_Style : %f' \
                          %(epoch + 1, self.num_epoch, mean_closs, mean_sloss))

        if self.weight_dir and not os.path.isdir(self.weight_dir):
            os.makedirs(self.weight_dir)
        weight_path = os.path.join(self.weight_dir, style_name + '.weight')
        torch.save(self.trans_net.state_dict(), weight_path)

        if self.record_name is not None:
            with open(self.record_name + '_train.txt', "w") as files:
                files.write(train_record)
            if self.test_dir is not None:
                with open(self.record_name + '_test.txt', "w") as files:
                    files.write(test_record)
        
def test(weight_path, content_path, output_path, scale=None, transNet='capsnet', norm_type = 'batch'):
    """Stylize a single image with a trained transform network and save it.

    Args:
        weight_path: Path passed to ``load_weight`` to restore the network
            parameters.
        content_path: Path of the input image, read through ``load_img``.
        output_path: Destination file for the stylized image. A missing parent
            directory is created.
        scale: Forwarded unchanged to ``load_img`` (presumably a resize
            factor -- TODO confirm against ``load_img``).
        transNet: ``'capsnet'`` builds ``ImageTransformCapsNet``; any other
            value builds ``ImageTransformNet``.
        norm_type: Forwarded to the network constructor (the notebook passes
            ``'batch'`` or ``'instance'``).

    Returns:
        None. The result is written to ``output_path``.
    """
    # Run on the GPU when one is available, otherwise stay on the CPU.
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    print('Loading the model...')
    if transNet == 'capsnet':
        trans_net = ImageTransformCapsNet(norm_type=norm_type).type(dtype)
    else:
        trans_net = ImageTransformNet(norm_type=norm_type).type(dtype)
    trans_net = load_weight(model=trans_net, path=weight_path)
    # Inference only: put mode-dependent layers (e.g. batch norm) into
    # evaluation mode, matching what the training loop does for its test pass.
    trans_net.eval()
    print('Loading the model is done!')

    # The original author notes the input as (1, 3, 256, 256); the real size
    # depends on ``load_img`` and ``scale`` -- TODO confirm.
    content_img = load_img(path=content_path, scale=scale)
    content_img = Variable(content_img.type(dtype))
    content_img = norm(content_img)

    # No gradients are needed here, so skip building the autograd graph.
    with torch.no_grad():
        result_img = trans_net(content_img)

    # Create the output directory if the path has one; exist_ok avoids the
    # separate existence check.
    out_dir, _ = os.path.split(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    torchvision.utils.save_image(result_img.data, output_path, nrow=1)
    print('Saved image : ' + output_path)

In [0]:
# Train a style-transfer network.
# Keyword arguments are grouped by apparent purpose (judged from their names);
# the Solver class defines their exact meaning.
s = Solver(
    # Input locations
    trn_dir='../Perceptual/pytorch_v/data',
    test_dir='../Perceptual/pytorch_v/valid',
    style_path='style/abs.jpg',

    # Output locations
    weight_dir='./',
    result_dir='check',
    process_dir='process',
    record_name='abstract_1_caps_record',

    # Model selection
    transNet='capsnet',      # options: capsnet, cnn
    lossNet='vgg',           # options: vgg, senet50
    norm_type='instance',    # options: batch, instance
    gram='gram',             # options: gram, gramP (Double Gram)
    pretrain=None,

    # Optimisation
    opti='adamw',            # options: adam, adamw, sgd
    lr=1e-3,
    num_epoch=3,
    batch_size=5,

    # Loss terms
    content_loss_pos=1,
    lambda_c=1,
    lambda_s=5e4,            # values noted by the author: 5e4, 1e6

    # Logging, snapshots and periodic evaluation
    show_every=20,
    save_every=5000,
    record_number=600,
    test_number=5,
    process_image='content/ybh.jpg',
    process_scale=0.3,
    process_number=20,
)

s.train()

Preparing pretrained VGG 16 ...


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.torch/models/vgg16-397923af.pth
100%|██████████| 553433881/553433881 [00:05<00:00, 93298593.56it/s]


Start traing [test_interval 597, process_interval 149, record_interval 4]
[Epoch : (1 / 3), Iters : (0 / 995), Count : 1] => Content : 8.577991, Style : 81.841286
[Epoch : (1 / 3), Iters : (20 / 995), Count : 21] => Content : 6.189469, Style : 72.253922
[Epoch : (1 / 3), Iters : (40 / 995), Count : 41] => Content : 12.414554, Style : 37.415215
[Epoch : (1 / 3), Iters : (60 / 995), Count : 61] => Content : 10.282390, Style : 31.595238
[Epoch : (1 / 3), Iters : (80 / 995), Count : 81] => Content : 12.468610, Style : 24.234875
[Epoch : (1 / 3), Iters : (100 / 995), Count : 101] => Content : 13.713984, Style : 19.657730
[Epoch : (1 / 3), Iters : (120 / 995), Count : 121] => Content : 11.664084, Style : 17.933031
[Epoch : (1 / 3), Iters : (140 / 995), Count : 141] => Content : 15.917862, Style : 13.623446
[Epoch : (1 / 3), Iters : (160 / 995), Count : 161] => Content : 12.860255, Style : 13.660542
[Epoch : (1 / 3), Iters : (180 / 995), Count : 181] => Content : 13.364314, Style : 12.096566


RuntimeError: ignored

In [0]:
# Stylize one content image with previously saved weights.
content_name = 'tp.jpg'
test(
    transNet='capsnet',
    norm_type='instance',  # options: batch, instance
    weight_path='new_weight/udnie.weight',
    content_path='content/%s' % content_name,
    # Output file is named after the part of the name before the first dot.
    output_path='fantasy_%s.png' % content_name.partition('.')[0],
    scale=0.9,
)

Loading the model...
Loading the model is done!
Saved image : fantasy_tp.png


In [0]:
# Print the layer summaries of every network, then render the autograd graph
# of the capsule transform network.
print(
    *[
        SE_Resnet50(),
        VGG16(),
        ImageTransformCapsNet(norm_type='instance'),
        ImageTransformNet(norm_type='instance'),
    ],
    sep=' \n\n '
)
from torchviz import make_dot, make_dot_from_trace
# Dummy all-ones input used only to trace one forward pass.
x = torch.ones(1, 3, 224, 224)
model = ImageTransformCapsNet(norm_type='instance')
# Left as the last expression so the notebook displays the graph.
make_dot(model(x), params=dict(model.named_parameters()))