In [7]:
import torch
import torchvision
import matplotlib.pyplot as plt
from torch import nn,optim
import PIL
import numpy as np
%matplotlib auto


Using matplotlib backend: Qt5Agg


In [8]:
# Locations of the content image and the style image on the local disk.
# NOTE(review): these are absolute, machine-specific Windows paths; the
# notebook cannot run on another machine without editing these two lines.
content_path=r'F:\study\ml\LM\image\13\rainier.jpg'
style_path=r'F:\study\ml\LM\image\13\autumn-oak.jpg'

In [9]:
# Open both images with Pillow. They stay PIL images here and are only
# converted to tensors later, inside preprocess().
content_img=PIL.Image.open(content_path)
style_img=PIL.Image.open(style_path)

In [10]:
# Inspect the concrete Pillow class returned by PIL.Image.open.
type(content_img)

PIL.JpegImagePlugin.JpegImageFile

In [11]:
# Array shapes of the two images, as (height, width, channels).
np.array(content_img).shape,np.array(style_img).shape

((1365, 2048, 3), (1200, 1717, 3))

In [12]:
# Show the content image (left) and the style image (right) side by side
# in a single figure with a 1x2 grid of axes.
fig,(content_ax,style_ax)=plt.subplots(1,2)
content_ax.imshow(content_img)
content_ax.set_title('content')
style_ax.imshow(style_img)
style_ax.set_title('style')
plt.show()

In [13]:
# Per-channel (R, G, B) mean and standard deviation used by preprocess() to
# normalize images and by postprocess() to undo that normalization.
# NOTE(review): presumably the usual ImageNet statistics expected by
# torchvision's pretrained models — confirm.
rgb_mean=torch.tensor((0.485,0.456,0.406))
rgb_std=torch.tensor((0.229,0.224,0.225))

In [14]:
def preprocess(img,img_shape):
    """Turn an image into a normalized tensor with a batch dimension.

    The image is resized to ``img_shape``, converted with ``ToTensor`` and
    normalized per channel with the module-level ``rgb_mean`` / ``rgb_std``.

    Args:
        img: Image accepted by the torchvision transforms used here (the
            notebook passes PIL images).
        img_shape: Target size handed to ``torchvision.transforms.Resize``;
            the notebook passes ``(300, 450)`` and obtains a
            ``(3, 300, 450)`` tensor before batching.

    Returns:
        The transformed image with a leading dimension of size 1 added by
        ``unsqueeze(0)``.
    """
    transforms=torchvision.transforms.Compose([
        torchvision.transforms.Resize(img_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean,std=rgb_std)
    ])
    print("img  :" ,img)
    # Run the pipeline once and reuse the result for both the debug output
    # and the return value, so the image is not resized twice.
    img_tensor=transforms(img)
    print("transforms img shape : ",img_tensor.shape)
    return img_tensor.unsqueeze(0)

In [15]:
# Try preprocess() on the content image at 300x450.
# NOTE(review): as a bare expression this makes the notebook display the
# whole tensor; checking `.shape` would be a more compact sanity check.
preprocess(content_img,(300,450))

img  : <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2048x1365 at 0x1BA43C6DF70>
transforms img shape :  torch.Size([3, 300, 450])


tensor([[[[ 1.9235,  1.9235,  1.9407,  ...,  0.0741,  0.0741,  0.1083],
          [ 1.9235,  1.9235,  1.9407,  ...,  0.1426,  0.1426,  0.1768],
          [ 1.9235,  1.9235,  1.9407,  ...,  0.2282,  0.2111,  0.2111],
          ...,
          [-1.8782, -1.8439, -1.8268,  ..., -1.6384, -0.8164, -0.8335],
          [-1.8782, -1.8439, -1.8268,  ..., -1.6384, -1.2445, -1.1247],
          [-1.8782, -1.8782, -1.8610,  ..., -1.8439, -1.7069, -1.5528]],

         [[ 1.9909,  1.9909,  1.9734,  ...,  1.2206,  1.2206,  1.2556],
          [ 1.9909,  1.9909,  1.9734,  ...,  1.2731,  1.2731,  1.3081],
          [ 1.9909,  1.9909,  1.9559,  ...,  1.3431,  1.3256,  1.3256],
          ...,
          [-1.8256, -1.7906, -1.7731,  ..., -1.6506, -1.0203, -1.1253],
          [-1.8256, -1.7906, -1.7731,  ..., -1.5630, -1.3704, -1.2654],
          [-1.8256, -1.8256, -1.8081,  ..., -1.7731, -1.6856, -1.5630]],

         [[ 2.2740,  2.2740,  2.2740,  ...,  1.9603,  1.9603,  1.9951],
          [ 2.2740,  2.2740,  

In [16]:
def postprocess(img):
    """Convert a normalized, batched image tensor back into a PIL image.

    Takes the first element of ``img``, undoes the ``rgb_mean`` / ``rgb_std``
    normalization, clamps the values to ``[0, 1]`` and hands the result to
    ``ToPILImage``.

    Args:
        img: Batched image tensor; assumes the (batch, 3, H, W) layout
            produced by preprocess() — TODO confirm for other callers.

    Returns:
        The PIL image built from the de-normalized first batch element.
    """
    first=img[0].to(rgb_std.device)
    # Move channels to the last axis so the length-3 statistics broadcast
    # over them, then restore the original value range.
    channels_last=first.permute(1,2,0)
    restored=(channels_last*rgb_std+rgb_mean).clamp(0,1)
    to_pil=torchvision.transforms.ToPILImage()
    # ToPILImage is given channels-first data again.
    return to_pil(restored.permute(2,0,1))

In [17]:
# VGG-19 from torchvision with pretrained weights; only its `features`
# stack is used below to build `net`.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument — confirm against the installed version.
pretrained_net=torchvision.models.vgg19(pretrained=True)

In [18]:
# Indices into the VGG feature stack at which extract_features() records
# activations: five taps for style, one for content.
style_layers=[0,5,10,19,28]
content_layers=[25]

In [19]:
# Keep only the leading VGG feature layers, up to and including the deepest
# layer that is tapped for content or style features.
net=nn.Sequential(*(
    pretrained_net.features[layer_idx]
    for layer_idx in range(1+max(style_layers+content_layers))
))

In [38]:
# Show the layers kept in the truncated network.
net

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [20]:
def extract_features(X,content_layers,style_layers):
    """Run ``X`` through the module-level ``net`` and collect tapped outputs.

    Args:
        X: Input fed to the first layer of ``net``.
        content_layers: Layer indices whose output is recorded as a content
            feature.
        style_layers: Layer indices whose output is recorded as a style
            feature.

    Returns:
        ``(contents, styles)``: two lists of layer outputs, each in layer
        order.
    """
    contents,styles=[],[]
    activation=X
    for layer_idx,layer in enumerate(net):
        activation=layer(activation)
        # NOTE(review): the printed `net` shows ReLU(inplace=True) right
        # after the convolutions, so a tensor stored here may later be
        # modified in place by the next layer — confirm this is intended.
        if layer_idx in style_layers:
            styles.append(activation)
        if layer_idx in content_layers:
            contents.append(activation)
    return contents,styles

In [21]:
def get_contents(img_shape,device):
    """Preprocess the content image and extract its content features.

    Args:
        img_shape: Size passed on to preprocess().
        device: Device the preprocessed image is moved to.

    Returns:
        ``(content_X, contents_Y)``: the preprocessed content image and the
        list of content features that extract_features() records for it.
    """
    content_X=preprocess(content_img,img_shape).to(device)
    # Only the content half of the (contents, styles) pair is needed here.
    contents_Y,_unused_styles=extract_features(
        content_X,content_layers,style_layers)
    return content_X,contents_Y

In [22]:
def get_styles(img_shape,device):
    """Preprocess the style image and extract its style features.

    Args:
        img_shape: Size passed on to preprocess().
        device: Device the preprocessed image is moved to.

    Returns:
        ``(style_X, styles_Y)``: the preprocessed style image and the list
        of style features that extract_features() records for it.
    """
    style_X=preprocess(style_img,img_shape).to(device)
    # Only the style half of the (contents, styles) pair is needed here.
    _unused_contents,styles_Y=extract_features(
        style_X,content_layers,style_layers)
    return style_X,styles_Y

In [23]:
def content_loss(Y_hat,Y):
    """Mean squared difference between generated and target features.

    Args:
        Y_hat: Features of the image being optimized.
        Y: Target features; detached, so no gradient flows into them.

    Returns:
        Scalar tensor holding the mean of the element-wise squared
        difference.
    """
    target=Y.detach()
    difference=Y_hat-target
    return difference.square().mean()

In [24]:
def gram(X):
    """Normalized Gram matrix of a feature tensor.

    ``X`` is flattened to ``(num_channels, n)``, where ``num_channels`` is
    ``X.shape[1]`` and ``n`` is the number of remaining elements, and the
    product ``X @ X.T`` is divided by ``num_channels * n``.

    Args:
        X: Feature tensor with channels on dim 1. Assumes a batch size of 1
            as produced by preprocess(); with a larger batch the reshape
            would mix samples into the channel rows — TODO confirm.

    Returns:
        Tensor of shape ``(num_channels, num_channels)``.
    """
    num_channels=X.shape[1]
    # Everything that is not the channel axis is flattened into one axis.
    n=X.numel()//num_channels
    X=X.reshape((num_channels,n))
    return torch.matmul(X,X.T)/(num_channels*n)

In [25]:
def style_loss(Y_hat,gram_Y):
    """Mean squared difference between two Gram matrices.

    Args:
        Y_hat: Features of the image being optimized; their Gram matrix is
            computed here with gram().
        gram_Y: Precomputed target Gram matrix; detached, so no gradient
            flows into it.

    Returns:
        Scalar tensor holding the mean of the element-wise squared
        difference.
    """
    generated_gram=gram(Y_hat)
    target_gram=gram_Y.detach()
    return (generated_gram-target_gram).square().mean()

In [26]:
def tv_loss(Y_hat):
    """Total-variation style smoothness penalty on a 4-D tensor.

    Averages the absolute difference between neighbouring elements along
    dim 2 and along dim 3 (rows and columns for an NCHW image) and returns
    half the sum of the two means.

    Args:
        Y_hat: 4-D tensor; the notebook passes the image being optimized.

    Returns:
        Scalar tensor with the penalty.
    """
    along_dim2=torch.abs(Y_hat[:,:,1:,:]-Y_hat[:,:,:-1,:]).mean()
    along_dim3=torch.abs(Y_hat[:,:,:,1:]-Y_hat[:,:,:,:-1]).mean()
    return 0.5*(along_dim2+along_dim3)

In [27]:
# Weights applied in computer_loss() to the content, style and
# total-variation loss terms respectively.
content_weight=1
style_weight=1e3
tv_weight=10

In [28]:
def computer_loss(X,contents_Y_hat,styles_Y_hat,contents_Y,styles_Y_gram):
    """Compute the weighted content, style and total-variation losses.

    Args:
        X: Image being optimized; passed to tv_loss().
        contents_Y_hat: Content features of ``X``.
        styles_Y_hat: Style features of ``X``.
        contents_Y: Target content features, paired with ``contents_Y_hat``.
        styles_Y_gram: Target Gram matrices, paired with ``styles_Y_hat``.

    Returns:
        ``(contents_l, styles_l, tv_l, l)``: the list of weighted content
        losses, the list of weighted style losses, the weighted
        total-variation loss, and their combined sum.
    """
    contents_l=[]
    for generated,target in zip(contents_Y_hat,contents_Y):
        contents_l.append(content_loss(generated,target)*content_weight)

    styles_l=[]
    for generated,target_gram in zip(styles_Y_hat,styles_Y_gram):
        styles_l.append(style_loss(generated,target_gram)*style_weight)

    tv_l=tv_loss(X)*tv_weight

    # `10*styles_l` is list repetition, not scaling: every style term is
    # added to the total ten times, on top of `style_weight`.
    total=sum(10*styles_l+contents_l+[tv_l])
    return contents_l,styles_l,tv_l,total

In [30]:
class SynthesizedImage(nn.Module):
    """Module whose only parameter, ``weight``, is the image being optimized."""

    def __init__(self,img_shape,**kwargs):
        """Create ``weight`` with shape ``img_shape`` from ``torch.rand``.

        Args:
            img_shape: Iterable of dimension sizes for the image parameter.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        initial_pixels=torch.rand(*img_shape)
        self.weight=nn.Parameter(initial_pixels)

    def forward(self):
        """Return the image parameter itself; the module takes no input."""
        return self.weight

In [31]:
def get_inits(X,device,lr,styles_Y):
    """Set up the image to optimize, the style targets and the optimizer.

    Args:
        X: Tensor whose shape and values initialize the synthesized image.
        device: Device the SynthesizedImage module is moved to.
        lr: Learning rate for the Adam optimizer.
        styles_Y: Style features whose Gram matrices become the style
            targets.

    Returns:
        ``(image, styles_Y_gram, trainer)``: the parameter returned by
        calling the SynthesizedImage module, the list of Gram matrices of
        ``styles_Y``, and the Adam optimizer over the module's parameters.
    """
    synthesized=SynthesizedImage(X.shape).to(device)
    # Replace the random initialization with the values of X.
    synthesized.weight.data.copy_(X.data)
    optimizer=torch.optim.Adam(synthesized.parameters(),lr=lr)
    target_grams=list(map(gram,styles_Y))
    return synthesized(),target_grams,optimizer

In [36]:
# Work on the CPU at 300x450 and precompute the optimization targets:
# the content features of the content image and the style features of the
# style image.
device='cpu'
image_shape=(300,450)
content_X,contents_Y=get_contents(image_shape,device)
_,styles_Y=get_styles(image_shape,device)

img  : <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2048x1365 at 0x1BA43C6DF70>
transforms img shape :  torch.Size([3, 300, 450])
img  : <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1717x1200 at 0x1BA3FA1C5B0>
transforms img shape :  torch.Size([3, 300, 450])


In [37]:
# Number of recorded style feature maps and the shape of the first one.
len(styles_Y),styles_Y[0].shape

(5, torch.Size([1, 64, 300, 450]))