In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
from torch import nn,optim
%matplotlib auto

Using matplotlib backend: Qt5Agg


In [2]:
# Locations of the content image and the style image on disk.
# NOTE(review): these are absolute, machine-specific Windows paths; the
# notebook will only run where these files exist.
content_path = r'E:\Study\ml\dataset\LM\13\rainier.jpg'
style_path = r'E:\Study\ml\dataset\LM\13\autumn-oak.jpg'

In [3]:
# Read the content image first, then the style image, via matplotlib.
content_img, style_img = (
    plt.imread(path) for path in (content_path, style_path)
)

In [4]:
# Show the content image and the style image side by side.
# A single 1x2 grid is created explicitly; the previous `plt.subplots(12)`
# built a 12-row axes grid that the two images were then drawn on top of.
fig, (content_ax, style_ax) = plt.subplots(1, 2)
content_ax.imshow(content_img)
content_ax.set_title('content')
style_ax.imshow(style_img)
style_ax.set_title('style')
plt.show()


In [5]:
# Per-channel mean and standard deviation (three values each) used by
# preprocess() to normalise an image and by postprocess() to undo it.
rgb_mean = torch.tensor((0.485, 0.456, 0.406))
rgb_std = torch.tensor((0.229, 0.224, 0.225))

In [6]:
def preprocess(img, image_shape):
    """Convert an image into a normalised tensor with a leading batch axis.

    Args:
        img: The image to convert. Either a NumPy array (what ``plt.imread``
            returns) or a PIL image.
        image_shape: Target size forwarded to ``torchvision.transforms.Resize``.

    Returns:
        The resized image as a tensor, normalised with ``rgb_mean`` /
        ``rgb_std`` and with a new dimension of size 1 inserted at position 0.
    """
    # Local import keeps this cell self-contained; NumPy is not imported at
    # the top of the notebook.
    import numpy as np

    steps = []
    if isinstance(img, np.ndarray):
        # Resize only accepts PIL images or tensors, so an ndarray coming
        # from plt.imread has to be turned into a PIL image first.
        steps.append(torchvision.transforms.ToPILImage())
    steps.extend([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    pipeline = torchvision.transforms.Compose(steps)
    return pipeline(img).unsqueeze(0)

In [7]:
def postprocess(img):
    """Undo the normalisation of the first image in a batch and return a PIL image.

    Args:
        img: Batched image tensor; only element 0 along the first axis is used.

    Returns:
        A PIL image built from the de-normalised values clamped to [0, 1].
    """
    first = img[0].to(rgb_std.device)
    # Move the first axis last so that rgb_std / rgb_mean (three values each)
    # broadcast along it.
    channels_last = first.permute(1, 2, 0)
    restored = torch.clamp(channels_last * rgb_std + rgb_mean, 0, 1)
    to_pil = torchvision.transforms.ToPILImage()
    # Restore the original axis order before the PIL conversion.
    return to_pil(restored.permute(2, 0, 1))

In [8]:
# Build torchvision's VGG-19 and request its pretrained weights.
# NOTE(review): the `pretrained=` keyword is deprecated in newer torchvision
# releases in favour of `weights=` -- confirm against the installed version.
pretrained_net=torchvision.models.vgg19(pretrained=True)



In [9]:
# Positions within `pretrained_net.features` whose outputs are collected as
# style features and as content features respectively.
style_layers = [0, 5, 10, 19, 28]
content_layers = [25]

In [10]:
# Keep only the feature layers up to (and including) the deepest layer that
# is tapped for content or style; later layers are never evaluated.
net = nn.Sequential(*(
    pretrained_net.features[idx]
    for idx in range(1 + max(style_layers + content_layers))
))

In [15]:
def extract_features(X, content_layers, style_layers):
    """Run ``X`` through ``net`` layer by layer and collect selected outputs.

    Args:
        X: Input passed to the first layer of the module-level ``net``.
        content_layers: Layer indices whose outputs are kept as content features.
        style_layers: Layer indices whose outputs are kept as style features.

    Returns:
        ``(contents, styles)``: two lists of layer outputs, each in layer order.
    """
    content_feats, style_feats = [], []
    for idx, layer in enumerate(net):
        X = layer(X)
        if idx in style_layers:
            style_feats.append(X)
        if idx in content_layers:
            content_feats.append(X)
    return content_feats, style_feats

In [16]:
def get_contents(image_shape, device):
    """Preprocess the content image and extract its content features.

    Args:
        image_shape: Target size forwarded to ``preprocess``.
        device: Device the preprocessed image is moved to.

    Returns:
        ``(content_X, contents_Y)``: the preprocessed content image and the
        list of content features returned by ``extract_features``.
    """
    content_X = preprocess(content_img, image_shape).to(device)
    features = extract_features(content_X, content_layers, style_layers)
    # extract_features returns (contents, styles); only the contents are needed.
    return content_X, features[0]
    

In [17]:
def get_styles(image_shape, device):
    """Preprocess the style image and extract its style features.

    Args:
        image_shape: Target size forwarded to ``preprocess``.
        device: Device the preprocessed image is moved to.

    Returns:
        ``(style_X, styles_Y)``: the preprocessed style image and the list of
        style features returned by ``extract_features``.
    """
    style_X = preprocess(style_img, image_shape).to(device)
    features = extract_features(style_X, content_layers, style_layers)
    # extract_features returns (contents, styles); only the styles are needed.
    return style_X, features[1]

#### content loss

In [18]:
def content_loss(Y_hat, Y):
    """Mean squared difference between ``Y_hat`` and the fixed target ``Y``.

    ``Y`` is detached first, so no gradient flows into the target.
    """
    target = Y.detach()
    return torch.square(Y_hat - target).mean()

#### style loss

In [19]:
def gram(X):
    """Return the normalised Gram matrix of a feature map.

    ``X`` is reshaped to ``(c, n)`` with ``c = X.shape[1]`` and
    ``n = X.numel() // c``, and the result is ``X @ X.T / (c * n)``.

    NOTE(review): the reshape folds every axis other than axis 1 into ``n``,
    so this presumably expects a batch of size 1 -- confirm against callers.

    This helper is referenced by ``style_loss`` and ``get_inits`` but was not
    defined anywhere in the notebook, so a fresh kernel raised ``NameError``.
    """
    num_channels = X.shape[1]
    n = X.numel() // num_channels
    flat = X.reshape((num_channels, n))
    return torch.matmul(flat, flat.T) / (num_channels * n)


def style_loss(Y_hat, gram_Y):
    """Mean squared difference between the Gram matrix of ``Y_hat`` and ``gram_Y``.

    Args:
        Y_hat: Feature map of the synthesised image; its Gram matrix is
            computed here.
        gram_Y: Precomputed Gram matrix of the style target. It is detached,
            so no gradient flows into it.

    Returns:
        A scalar tensor.
    """
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

In [20]:
def tv_loss(Y_hat):
    """Total-variation loss: mean absolute difference between adjacent pixels.

    The differences are taken along axis 2 and axis 3 of ``Y_hat`` (which
    must therefore have at least four axes); the two means are averaged.

    The previous implementation indexed ``Y_hat[:, :, -1, :]`` (the last row
    only) instead of slicing ``[:, :, :-1, :]``, so it compared every row with
    the last row and failed to broadcast for most input shapes.
    """
    vertical = torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean()
    horizontal = torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean()
    return 0.5 * (vertical + horizontal)

In [21]:
# Multipliers applied to the content, style and total-variation loss terms
# in compute_loss().
content_weight = 1
style_weight = 1e3
tv_weight = 10

In [22]:
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    """Compute the weighted content, style and total-variation losses.

    Args:
        X: The synthesised image; used for the total-variation term.
        contents_Y_hat: Content features of the synthesised image.
        styles_Y_hat: Style features of the synthesised image.
        contents_Y: Target content features, paired with ``contents_Y_hat``.
        styles_Y_gram: Target style Gram matrices, paired with ``styles_Y_hat``.

    Returns:
        ``(contents_l, styles_l, tv_l, l)``: the list of weighted content
        losses, the list of weighted style losses, the weighted
        total-variation loss and the overall sum.
    """
    contents_l = []
    for Y_hat, Y in zip(contents_Y_hat, contents_Y):
        contents_l.append(content_loss(Y_hat, Y) * content_weight)

    styles_l = []
    for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram):
        styles_l.append(style_loss(Y_hat, Y) * style_weight)

    tv_l = tv_loss(X) * tv_weight

    # `10 * styles_l` is list repetition, not scalar multiplication: every
    # style term appears ten times in the list that is summed.
    terms = 10 * styles_l + contents_l + [tv_l]
    l = sum(terms)
    return contents_l, styles_l, tv_l, l

In [23]:
class SynthesizedImage(nn.Module):
    """Module whose single parameter is the image being optimised.

    The parameter ``weight`` has shape ``img_shape`` and starts out filled
    with uniform random values from ``torch.rand``.
    """

    def __init__(self, img_shape, **kwargs):
        super().__init__(**kwargs)
        initial = torch.rand(*img_shape)
        self.weight = nn.Parameter(initial)

    def forward(self):
        """Return the image parameter itself (no computation is performed)."""
        return self.weight

In [24]:
def get_inits(X, device, lr, styles_Y):
    """Create the trainable image, the style targets and the optimiser.

    Args:
        X: Tensor whose values (and shape) initialise the synthesised image.
        device: Device the ``SynthesizedImage`` module is moved to.
        lr: Learning rate for the Adam optimiser.
        styles_Y: Style features; each one is turned into a Gram matrix.

    Returns:
        ``(image, styles_Y_gram, trainer)``: the image parameter returned by
        the module's forward pass, the list of Gram matrices and the optimiser.
    """
    synthesized = SynthesizedImage(X.shape).to(device)
    # Overwrite the random initialisation with the values of X.
    synthesized.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(synthesized.parameters(), lr=lr)
    styles_Y_gram = []
    for Y in styles_Y:
        styles_Y_gram.append(gram(Y))
    return synthesized(), styles_Y_gram, trainer

In [26]:
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    """Optimise the synthesised image against the content and style targets.

    Args:
        X: Tensor used to initialise the synthesised image.
        contents_Y: Target content features.
        styles_Y: Target style features (converted to Gram matrices once).
        device: Device on which the synthesised image lives.
        lr: Initial learning rate for Adam.
        num_epochs: Number of optimisation steps.
        lr_decay_epoch: Step size of the ``StepLR`` scheduler; the learning
            rate is multiplied by 0.8 every ``lr_decay_epoch`` steps.

    Returns:
        The optimised image parameter.
    """
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        # Fixed: the original referenced the undefined names `contens_Y` and
        # `style_Y_gram` here.
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            # Fixed: `printf` is not a Python builtin; use print with
            # %-formatting instead.
            print("contents loss: %.2f" % float(sum(contents_l)))
            print("styles loss: %.2f" % float(sum(styles_l)))
            # tv_l is a single scalar tensor, not a list, so it is not summed.
            print("tv loss: %.2f" % float(tv_l))
    return X