In [3]:
import torch
import torchvision
import matplotlib.pyplot as plt
from torch import nn,optim
import PIL
%matplotlib auto

Using matplotlib backend: Qt5Agg


In [4]:
content_path=r'F:\study\ml\LM\image\13\rainier.jpg'
style_path=r'F:\study\ml\LM\image\13\autumn-oak.jpg'

In [5]:
content_img=PIL.Image.open(content_path)
style_img=PIL.Image.open(style_path)

In [6]:
plt.subplots(12)
plt.subplot(121)
plt.imshow(content_img)
plt.subplot(122)
plt.imshow(style_img)
plt.show()

In [7]:
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

In [8]:
def preprocess(img,img_shape):
    transforms=torchvision.transforms.Compose([
        torchvision.transforms.Resize(img_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean,std=rgb_std)])
    return transforms(img).unsqueeze(0)

In [9]:
preprocess(content_img,(300,450))

tensor([[[[ 1.9235,  1.9235,  1.9407,  ...,  0.0741,  0.0741,  0.1083],
          [ 1.9235,  1.9235,  1.9407,  ...,  0.1426,  0.1426,  0.1768],
          [ 1.9235,  1.9235,  1.9407,  ...,  0.2282,  0.2111,  0.2111],
          ...,
          [-1.8782, -1.8439, -1.8268,  ..., -1.6384, -0.8164, -0.8335],
          [-1.8782, -1.8439, -1.8268,  ..., -1.6384, -1.2445, -1.1247],
          [-1.8782, -1.8782, -1.8610,  ..., -1.8439, -1.7069, -1.5528]],

         [[ 1.9909,  1.9909,  1.9734,  ...,  1.2206,  1.2206,  1.2556],
          [ 1.9909,  1.9909,  1.9734,  ...,  1.2731,  1.2731,  1.3081],
          [ 1.9909,  1.9909,  1.9559,  ...,  1.3431,  1.3256,  1.3256],
          ...,
          [-1.8256, -1.7906, -1.7731,  ..., -1.6506, -1.0203, -1.1253],
          [-1.8256, -1.7906, -1.7731,  ..., -1.5630, -1.3704, -1.2654],
          [-1.8256, -1.8256, -1.8081,  ..., -1.7731, -1.6856, -1.5630]],

         [[ 2.2740,  2.2740,  2.2740,  ...,  1.9603,  1.9603,  1.9951],
          [ 2.2740,  2.2740,  

In [54]:
def postprocess(img):
    img=img[0].to(rgb_std.device)
    img=torch.clamp(img.permute(1,2,0)*rgb_std+rgb_mean,0,1)
    return torchvision.transforms.ToPILImage()(img.permute(2,0,1))

In [11]:
pretrained_net=torchvision.models.vgg19(pretrained=True)

In [12]:
style_layers,content_layers=[0,5,10,19,28],[25]

In [13]:
net=nn.Sequential(*[pretrained_net.features[i] for i in range(max(content_layers+
                                                                 style_layers)+1)])

In [14]:
def extract_features(X,content_layers,style_layers):
    contents=[]
    styles=[]
    for i in range(len(net)):
        X=net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents,styles

In [15]:
def get_contents(img_shape,device):
    content_X=preprocess(content_img,img_shape).to(device)
    contents_Y,_=extract_features(content_X,content_layers,style_layers)
    return content_X,contents_Y

In [16]:
def get_styles(img_shape,device):
    style_X=preprocess(style_img,img_shape).to(device)
    _,styles_Y=extract_features(style_X,content_layers,style_layers)
    return style_X,styles_Y

In [17]:
def content_loss(Y_hat,Y):
    return torch.square(Y_hat-Y.detach()).mean()

In [18]:
def gram(X):
    num_channels,n=X.shape[1],X.numel()//X.shape[1]
    X=X.reshape((num_channels,n))
    return torch.matmul(X,X.T)/(num_channels*n)

In [19]:
def style_loss(Y_hat,gram_Y):
    return torch.square(gram(Y_hat)-gram_Y.detach()).mean()

In [20]:
def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:,:,1:,:]-Y_hat[:,:,:-1,:]).mean()+
                 torch.abs(Y_hat[:,:,:,1:]-Y_hat[:,:,:,:-1]).mean())

In [21]:
content_weight,style_weight,tv_weight=1,1e3,10

In [32]:
def compute_loss(X,contents_Y_hat,styles_Y_hat,contents_Y,styles_Y_gram):
    contents_l=[content_loss(Y_hat,Y) * content_weight for Y_hat,Y in 
               zip(contents_Y_hat,contents_Y)]
    styles_l=[style_loss(Y_hat,Y) * style_weight for Y_hat,Y in 
             zip(styles_Y_hat,styles_Y_gram)]
    tv_l=tv_loss(X) * tv_weight
    l=sum(10*styles_l+contents_l+[tv_l])
    return contents_l,styles_l,tv_l,l

In [23]:
class SynthesizedImage(nn.Module):
    def __init__(self,img_shape,**kwargs):
        super().__init__(**kwargs)
        self.weight=nn.Parameter(torch.rand(*img_shape))
        
    def forward(self):
        return self.weight

In [41]:
def get_inits(X,device,lr,styles_Y):
    gen_img=SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer=torch.optim.Adam(gen_img.parameters(),lr=lr)
    styles_Y_gram=[gram(Y) for Y in styles_Y]
#     for i in styles_Y_gram:
#         print(" styles_Y_gram shape : ",i.shape)
    return gen_img(),styles_Y_gram,trainer

In [42]:
device,image_shape='cpu',(300,450)
content_X,contents_Y=get_contents(image_shape,device)
_,styles_Y=get_styles(image_shape,device)

In [43]:
for i in styles_Y:
    print("style Y shape : ",i.shape)

style Y shape :  torch.Size([1, 64, 300, 450])
style Y shape :  torch.Size([1, 128, 150, 225])
style Y shape :  torch.Size([1, 256, 75, 112])
style Y shape :  torch.Size([1, 512, 37, 56])
style Y shape :  torch.Size([1, 512, 18, 28])


In [44]:
X,styles_Y_gram,trainer=get_inits(content_X,'cpu',0.3,styles_Y)

In [45]:
contents_Y_hat,styles_Y_hat=extract_features(X,content_layers,style_layers)

In [46]:
for i in styles_Y_hat:
    print("styles_Y_hat shape : ",i.shape)

styles_Y_hat shape :  torch.Size([1, 64, 300, 450])
styles_Y_hat shape :  torch.Size([1, 128, 150, 225])
styles_Y_hat shape :  torch.Size([1, 256, 75, 112])
styles_Y_hat shape :  torch.Size([1, 512, 37, 56])
styles_Y_hat shape :  torch.Size([1, 512, 18, 28])


In [51]:
def train(X,contents_Y,styles_Y,device ,lr,num_epoch,lr_decay_epoch):
    X,styles_Y_gram,trainer=get_inits(X,device,lr,styles_Y)
    scheduler=torch.optim.lr_scheduler.StepLR(trainer,lr_decay_epoch,0.8)
    for epoch in range(num_epoch):
        trainer.zero_grad()
        contents_Y_hat,styles_Y_hat=extract_features(X,content_layers,style_layers)
        contents_l,styles_l,tv_l,l=compute_loss(X,contents_Y_hat,styles_Y_hat,contents_Y,styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if(epoch+1) %10 ==0:
            print("-------------- epoch -----------------",epoch)
            print("content loss :",float(sum(contents_l)))
            print("style loss   :",float(sum(styles_l)))
            print("tv loss      :",float(tv_l))
    return X
        

In [52]:
device,image_shape='cpu',(300,450)
net=net.to(device)
content_X,contents_Y=get_contents(image_shape,device)
_,styles_Y=get_styles(image_shape,device)
output=train(content_X,contents_Y,styles_Y,device,0.3,500,50)

-------------- epoch ----------------- 9
content loss : 1.886697769165039
style loss   : 1.7113851308822632
tv loss      : 7.836203098297119
-------------- epoch ----------------- 19
content loss : 1.6912285089492798
style loss   : 0.7593602538108826
tv loss      : 7.591408729553223
-------------- epoch ----------------- 29
content loss : 1.5549570322036743
style loss   : 0.2870299220085144
tv loss      : 6.251469612121582
-------------- epoch ----------------- 39
content loss : 1.450787901878357
style loss   : 0.21341782808303833
tv loss      : 5.155290603637695
-------------- epoch ----------------- 49
content loss : 1.3914440870285034
style loss   : 0.12216442078351974
tv loss      : 4.535834312438965
-------------- epoch ----------------- 59
content loss : 1.3003970384597778
style loss   : 0.07521913200616837
tv loss      : 3.760624408721924
-------------- epoch ----------------- 69
content loss : 1.1860100030899048
style loss   : 0.057811036705970764
tv loss      : 3.1427314281463

In [55]:
output1=postprocess(output)
plt.imshow(output1)
plt.show()