In [1]:
import gym
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from collections import namedtuple, deque

import math
import random
import copy

import cv2

device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
class StyleNet(nn.Module):
    def __init__(self):
        super(StyleNet, self).__init__()
        vgg = torchvision.models.vgg11(pretrained=True)
        self.net = nn.Sequential()
        self.net.add_module("feature1", vgg.features[:3])
        self.net.add_module("feature2", vgg.features[3:6])
        self.net.add_module("feature3", vgg.features[6:11])
        self.net.add_module("feature4", vgg.features[11:16])
        self.net.add_module("feature5", vgg.features[16:])
        
    def forward(self, im):
        f1 = self.net.feature1(im)
        f2 = self.net.feature2(f1)
        f3 = self.net.feature3(f2)
        f4 = self.net.feature4(f3)
        f5 = self.net.feature5(f4)
        return f1, f2, f3, f4, f5

In [3]:
net = StyleNet().to(device)
# net = net.requires_grad_(False)

In [4]:
from PIL import Image

In [5]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

content_img = cv2.imread("./content.jpg")
content = cv2.resize(content_img, dsize=(512, 512), interpolation=cv2.INTER_CUBIC)[:,:,::-1] /255
content = (content - mean) / std

#https://lonelyplanet.co.kr/magazine/articles/articleLoad/AI_00002078?page=&keyType=

style_img = cv2.imread("./starry-night.jpg")
style = cv2.resize(style_img, dsize=(512, 512), interpolation=cv2.INTER_CUBIC)[:,:,::-1] /255
style = (style - mean) / std
#https://pixabay.com/illustrations/starry-night-vincent-van-gough-1093721/

style_target = net(torch.FloatTensor(style).permute(2,0,1).unsqueeze(0).to(device))

ss = []
for s in style_target:
    s = s.detach()
    channel = s.shape[1]
    s = s.squeeze().view(channel, -1).contiguous()
    s /= s.shape[1]
    ss.append(s.matmul(s.T))

style_target = ss
    
content_target = net(torch.FloatTensor(content).permute(2,0,1).unsqueeze(0).to(device))    
for c in content_target:
    c = c.detach()
c1,c2,c3,c4,c5 = content_target

In [6]:
# img = torch.ones(torch.FloatTensor(content).permute(2,0,1).unsqueeze(0).shape)
# img.uniform_()
img = torch.FloatTensor(content)
img = torch.autograd.Variable(img, requires_grad=True)

EPOCH = 2000
LR = 0.1
alph = 1e-2
beta = 1e1
opt = optim.Adam([img], LR)

content_layer_index = 4

In [None]:
for epoch in range(EPOCH):
    feature_map = net(torch.FloatTensor(img).permute(2,0,1).unsqueeze(0).to(device))

    style_loss = 0
    for i, (x, s) in enumerate(zip(feature_map,style_target)):
        if i<2:
            continue
        channel = s.shape[1]
        x_ = x.squeeze().view(channel, -1).contiguous()
        x_ = x_/x_.shape[1]
        x_ = x_.matmul(x_.T)
        l = (channel**4)*((x_ - s) ** 2).mean() / 4
        style_loss += l

    style_loss *= beta
    content_loss = alph * ((feature_map[content_layer_index] - content_target[content_layer_index])**2).mean() / 2

    # loss = content_loss + style_loss
    loss =  content_loss + style_loss
    print(epoch, "%.10f %.10f"%(content_loss.cpu().item(), style_loss.cpu().item()))

    opt.zero_grad()
    loss.backward(retain_graph=True)
    opt.step()

0 0.0000000000 347347.2500000000
1 0.0027988318 256391.6875000000
2 0.0062368554 252887.6875000000
3 0.0031354078 231165.3437500000
4 0.0040318649 179812.5000000000
5 0.0069386843 161736.9531250000
6 0.0044806837 143054.7031250000
7 0.0053065503 110124.8281250000
8 0.0070891259 101148.3750000000
9 0.0054587936 84020.8750000000
10 0.0057142102 71924.9843750000
11 0.0065580225 65452.5742187500
12 0.0061211395 56039.3437500000
13 0.0059915110 52481.5703125000
14 0.0063636336 47297.6484375000
15 0.0063012801 43914.4882812500
16 0.0061083566 40915.3867187500
17 0.0062983399 38218.9765625000
18 0.0063438532 36213.5820312500
19 0.0061952374 34129.7890625000
20 0.0063624182 32354.7558593750
21 0.0063106422 30696.0195312500
22 0.0062381616 29391.2109375000
23 0.0063980254 28164.1992187500
24 0.0062955087 26880.4863281250
25 0.0063584135 25853.5000000000
26 0.0063999412 24919.8046875000
27 0.0063210055 24169.9121093750
28 0.0064473092 23351.2460937500
29 0.0063450509 22600.7734375000
30 0.006441

248 0.0066797775 5596.1396484375
249 0.0066591105 5490.5395507812
250 0.0067246007 5447.8564453125
251 0.0066527114 5508.9824218750
252 0.0067666802 5425.4804687500
253 0.0066063581 5521.9130859375
254 0.0067822337 5399.8085937500
255 0.0066075134 5537.6118164062
256 0.0067973910 5469.9169921875
257 0.0065611508 5534.5078125000
258 0.0068336600 5505.5278320312
259 0.0065981694 5406.6870117188
260 0.0067640669 5299.2924804688
261 0.0066766357 5216.0615234375
262 0.0066931047 5138.2734375000
263 0.0067289243 5170.2529296875
264 0.0066326833 5196.3339843750
265 0.0067927381 5232.4863281250
266 0.0065797055 5337.8330078125
267 0.0068380348 5360.2763671875
268 0.0065507363 5549.7656250000
269 0.0068483469 5352.9418945312
270 0.0066046142 5210.8520507812
271 0.0067481613 5066.0351562500
272 0.0067404965 5035.7836914062
273 0.0066287993 5095.5610351562
274 0.0068012807 5124.6328125000
275 0.0066153109 5199.6684570312
276 0.0068102130 5176.2939453125
277 0.0066170529 5183.8066406250
278 0.0067

496 0.0067929230 3747.1411132812
497 0.0067492360 3720.1508789062
498 0.0067919684 3696.3378906250
499 0.0067599951 3727.4157714844
500 0.0068096127 3719.0063476562
501 0.0067102173 3721.1130371094
502 0.0068809269 3848.7846679688
503 0.0065948614 4263.3925781250
504 0.0070199808 4670.2509765625
505 0.0064293691 5940.9516601562
506 0.0070624510 5259.6757812500
507 0.0065791546 4853.8173828125
508 0.0067903348 4493.6567382812
509 0.0068931356 5164.4414062500
510 0.0066356361 5169.8364257812
511 0.0067626536 4192.5834960938
512 0.0068748188 4834.4746093750
513 0.0066856048 4524.2900390625
514 0.0067470209 4276.9770507812
515 0.0067850924 4631.9926757812
516 0.0068252729 4206.8999023438
517 0.0065719243 4715.9433593750
518 0.0070024417 5021.8027343750
519 0.0064287810 7296.6162109375
520 0.0070914412 6866.9287109375
521 0.0064276690 6971.1450195312
522 0.0068840776 4786.0488281250
523 0.0069489651 5971.1665039062
524 0.0064688283 6073.4497070312
525 0.0070058452 6625.3691406250
526 0.0065

743 0.0060772686 35229.2343750000
744 0.0066245538 30444.6523437500
745 0.0068998002 26023.0429687500
746 0.0064622699 26886.4648437500
747 0.0068794345 24243.4414062500
748 0.0065924097 22123.9726562500
749 0.0065506515 21500.0175781250
750 0.0067340462 20536.0585937500
751 0.0064262073 19811.5878906250
752 0.0070028245 20607.9492187500
753 0.0060924441 22942.7734375000
754 0.0070401095 20162.8496093750
755 0.0064124158 18024.6894531250
756 0.0066686198 16315.9570312500
757 0.0069103166 16476.7949218750
758 0.0063560209 16533.3281250000
759 0.0067478744 14715.7763671875
760 0.0067988900 14168.5527343750
761 0.0064888881 14213.1142578125
762 0.0068314993 13439.3457031250
763 0.0066038808 12747.4238281250
764 0.0066686179 12149.2187500000
765 0.0067526218 11859.7060546875
766 0.0065604532 11714.6376953125
767 0.0068128346 11304.5781250000
768 0.0065195304 11216.3906250000
769 0.0067950306 10610.0439453125
770 0.0066193198 10252.0361328125
771 0.0067062951 9995.5019531250
772 0.006666249

990 0.0067500290 4175.2592773438
991 0.0068320832 4144.2065429688
992 0.0067502330 4100.3774414062
993 0.0068455911 4065.0566406250
994 0.0067832307 4019.7138671875
995 0.0068018129 3998.1838378906
996 0.0067970892 3963.4108886719
997 0.0068054805 3947.9433593750
998 0.0068069929 3916.2749023438
999 0.0067827837 3906.6840820312
1000 0.0068239104 3882.2490234375
1001 0.0067755328 3882.7551269531
1002 0.0068520759 3884.9736328125
1003 0.0067172442 3956.7812500000
1004 0.0069313035 4085.5693359375
1005 0.0066178548 4449.5209960938
1006 0.0070019397 4475.8427734375
1007 0.0066205161 4436.1879882812
1008 0.0068704002 3865.5080566406
1009 0.0069017485 3940.7734375000
1010 0.0066398531 4265.4687500000
1011 0.0068974942 3920.2355957031
1012 0.0068486375 3791.7260742188
1013 0.0066854986 4004.9614257812
1014 0.0068879561 3848.0593261719
1015 0.0068223113 3735.9877929688
1016 0.0067220200 3838.8293457031
1017 0.0068766689 3790.5371093750
1018 0.0068022786 3700.0454101562
1019 0.0067648622 3725.8

1231 0.0068926942 3308.1469726562
1232 0.0069572534 3741.5219726562
1233 0.0066361972 4128.2465820312
1234 0.0069194506 3509.4921875000
1235 0.0068277875 4021.5698242188
1236 0.0067892512 3394.9047851562
1237 0.0068462281 3540.8752441406
1238 0.0067767096 3805.1616210938
1239 0.0069018309 3317.3579101562
1240 0.0067401067 3751.2766113281
1241 0.0068657417 3237.1252441406
1242 0.0068614837 3434.6489257812
1243 0.0067679533 3337.4670410156
1244 0.0068854159 3337.2993164062
1245 0.0067622769 3481.2521972656
1246 0.0069154599 3310.2690429688
1247 0.0067165145 3481.9941406250
1248 0.0069100815 3222.1850585938
1249 0.0068065012 3273.9775390625
1250 0.0068282047 3121.0493164062
1251 0.0068547567 3181.3203125000
1252 0.0067877495 3162.6787109375
1253 0.0068976069 3183.9755859375
1254 0.0067456984 3201.3964843750
1255 0.0069037364 3151.0793457031
1256 0.0067902897 3128.8359375000
1257 0.0068568718 3047.6843261719
1258 0.0068241172 3045.5285644531
1259 0.0068115005 3018.0195312500
1260 0.0068794

1471 0.0068566343 2617.2675781250
1472 0.0068551549 2620.1030273438
1473 0.0068442067 2606.7778320312
1474 0.0068592909 2595.3259277344
1475 0.0068513993 2594.5942382812
1476 0.0068559838 2591.5673828125
1477 0.0068427059 2582.5410156250
1478 0.0068759564 2591.5678710938
1479 0.0068197786 2604.6123046875
1480 0.0069048414 2612.8830566406
1481 0.0067804810 2703.2304687500
1482 0.0069655911 2837.7551269531
1483 0.0066833841 3131.8774414062
1484 0.0070482623 3176.7753906250
1485 0.0066550439 3265.8427734375
1486 0.0069578909 2867.1845703125
1487 0.0068715643 2924.5837402344
1488 0.0067347507 3157.8291015625
1489 0.0069716368 3036.9746093750
1490 0.0067566107 2955.0507812500
1491 0.0068718791 2897.2231445312
1492 0.0068728980 2886.2844238281
1493 0.0068076318 2775.9736328125
1494 0.0068784915 2793.6181640625
1495 0.0068326211 2850.1713867188
1496 0.0068547428 2684.7128906250
1497 0.0068472954 2733.3508300781
1498 0.0068501616 2773.5981445312
1499 0.0068416712 2656.8930664062
1500 0.0068622

1711 0.0068156063 3672.5781250000
1712 0.0068970714 3674.8171386719
1713 0.0067613381 3643.5639648438
1714 0.0069788336 3812.7519531250
1715 0.0066282735 4091.2585449219
1716 0.0070667248 4214.6757812500
1717 0.0066155195 4598.4189453125
1718 0.0069919792 3696.2512207031
1719 0.0068517169 3493.4057617188
1720 0.0067380858 3547.5444335938
1721 0.0070022563 3716.7744140625
1722 0.0067234165 3545.2607421875
1723 0.0068864333 3337.9877929688
1724 0.0068913051 3259.0971679688
1725 0.0067627761 3330.3210449219
1726 0.0069231857 3233.9985351562
1727 0.0068078088 3209.2998046875
1728 0.0068685249 3092.1015625000
1729 0.0068710842 3060.3198242188
1730 0.0068207155 3041.8310546875
1731 0.0069151972 3046.2878417969
1732 0.0067970911 3012.9208984375
1733 0.0069082105 2965.4235839844
1734 0.0068229614 2934.6684570312
1735 0.0068741320 2882.0861816406
1736 0.0068441485 2864.2802734375
1737 0.0068525481 2823.9794921875
1738 0.0068709622 2825.1840820312
1739 0.0068333731 2794.1118164062
1740 0.0068969

In [None]:
plt.figure(figsize=(15,15))
plt.imshow(img.detach().cpu().numpy() * std + mean)

In [None]:
plt.figure(figsize=(15,15))
plt.imshow(content*std + mean)