In [None]:
# Required imports
from torch_snippets import *
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.models import vgg19  # pretrained model
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
# Transforms that prepare images for VGG19, and the inverse used to turn a
# processed tensor back into displayable pixel values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> float tensor (C, H, W) in [0, 1] -> ImageNet-normalized -> scaled by 255.
preprocess = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    T.Lambda(lambda x: x.mul(255)),
])

# Inverse of the Normalize + x255 steps above. Normalize computes (x - m) / s, so it
# is undone by a second Normalize with mean -m/s and std 1/s. Deriving both from the
# same constants keeps the two transforms consistent channel by channel.
# The multiply is out-of-place so postprocess never modifies the tensor it is given
# (e.g. a detached view of the image being optimized).
postprocess = T.Compose([
    T.Lambda(lambda x: x.mul(1. / 255)),
    T.Normalize(
        mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        std=[1 / s for s in IMAGENET_STD],
    ),
])

In [None]:
# Gram matrix of a batch of feature maps (channel-by-channel inner products).
class GramMatrix(nn.Module):
    """Compute the spatially normalized Gram matrix of feature maps.

    For an input of shape (b, c, h, w) the result has shape (b, c, c) with
    ``G[n, i, j] = sum_k F[n, i, k] * F[n, j, k] / (h * w)``, where ``F`` is the
    input flattened over its two spatial dimensions.
    """

    def forward(self, input):
        b, c, h, w = input.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        feat = input.reshape(b, c, h * w)
        gram = feat @ feat.transpose(1, 2)
        # Out-of-place division: normalize by the number of spatial positions.
        return gram / (h * w)


class GramMSELoss(nn.Module):
    """Mean-squared error between the Gram matrix of ``input`` and a target Gram matrix."""

    def forward(self, input, target):
        input_gram = GramMatrix()(input)
        return F.mse_loss(input_gram, target)

# Feature extractor returning activations of selected layers of VGG19's conv stack.
class vgg19_modified(nn.Module):
    """Frozen VGG19 ``features`` stack that can return intermediate activations.

    Layer indices are positions in ``vgg19().features`` (every conv, ReLU and
    pooling module counts as one position).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): newer torchvision prefers `weights=...` over the deprecated
        # `pretrained=True`; kept as-is for compatibility with older versions.
        features = list(vgg19(pretrained=True).features)
        self.features = nn.ModuleList(features).eval()
        # Used only as a fixed feature extractor: freeze the weights so backward()
        # produces gradients for the input image only.
        for param in self.features.parameters():
            param.requires_grad_(False)

    def forward(self, x, layers=None):
        """Run ``x`` through the feature stack.

        Args:
            x: input batch fed to the first layer.
            layers: indices into ``self.features`` whose outputs are wanted.
                The requested order is preserved and duplicates are allowed.

        Returns:
            A list with one activation per entry of ``layers``, in that order.
            When ``layers`` is empty or None, the output of the last layer.

        Raises:
            IndexError: if a requested index is outside the feature stack.
        """
        if not layers:
            for module in self.features:
                x = module(x)
            return x

        n_layers = len(self.features)
        bad = [ix for ix in layers if not 0 <= ix < n_layers]
        if bad:
            raise IndexError(f"layer indices {bad} out of range for {n_layers} feature layers")

        wanted = set(layers)
        deepest = max(wanted)
        captured = {}
        for ix, module in enumerate(self.features):
            x = module(x)
            if ix in wanted:
                captured[ix] = x
            if ix == deepest:
                break  # no deeper activation is needed
        return [captured[ix] for ix in layers]

In [None]:
# Build the VGG19 feature extractor, then move it to the selected device.
vgg = vgg19_modified()
vgg = vgg.to(device)

In [None]:
!wget https://easydrawingguides.com/wp-content/uploads/2016/10/how-to-draw-an-elephant-featured-image-1200.png
!wget https://www.neh.gov/sites/default/files/2022-09/Fall_2022_web-images_Picasso_32.jpg

--2025-07-29 16:45:34--  https://easydrawingguides.com/wp-content/uploads/2016/10/how-to-draw-an-elephant-featured-image-1200.png
Resolvendo easydrawingguides.com (easydrawingguides.com)... 2606:4700::6810:966c, 2606:4700::6810:976c, 104.16.150.108, ...
Conectando-se a easydrawingguides.com (easydrawingguides.com)|2606:4700::6810:966c|:443... conectado.
A requisição HTTP foi enviada, aguardando resposta... 200 OK
Tamanho: 56936 (56K) [image/png]
Salvando em: ‘how-to-draw-an-elephant-featured-image-1200.png’


2025-07-29 16:45:35 (132 MB/s) - ‘how-to-draw-an-elephant-featured-image-1200.png’ salvo [56936/56936]

--2025-07-29 16:45:35--  https://www.neh.gov/sites/default/files/2022-09/Fall_2022_web-images_Picasso_32.jpg
Resolvendo www.neh.gov (www.neh.gov)... 23.21.228.79
Conectando-se a www.neh.gov (www.neh.gov)|23.21.228.79|:443... conectado.
A requisição HTTP foi enviada, aguardando resposta... 200 OK
Tamanho: 5309491 (5,1M) [image/jpeg]
Salvando em: ‘Fall_2022_web-images_Picasso_32.jpg’


2025-07-29 16:45:37 (4,83 MB/s) - ‘Fall_2022_web-images_Picasso_32.jpg’ salvo [5309491/5309491]

In [None]:
# Load the style image (picasso.jpg) and the content image (elephant.png) as RGB,
# resize them to IMG_SIZE and turn each into a 1-image batch on `device`.
IMG_SIZE = (512, 512)  # (width, height) passed to PIL's resize


def load_rgb(path, size=IMG_SIZE):
    """Open ``path``, convert it to RGB, resize it to ``size`` and close the file.

    Converting before resizing means resampling happens in RGB. Pillow always
    uses nearest-neighbour resampling for palette ('P') images, so resizing a
    palette PNG first would degrade it.
    """
    with Image.open(path) as src:
        return src.convert('RGB').resize(size)


imgs = [load_rgb(path) for path in ['picasso.jpg', 'elephant.png']]  # style first, then content
style_image, content_image = [preprocess(img).to(device)[None] for img in imgs]  # [None] adds the batch dim

In [None]:
# Start the optimization from a copy of the content image. detach() (instead of
# the legacy .data) yields a tensor outside the autograd graph, and clone() gives
# it its own storage, so updating it never touches content_image.
opt_img = content_image.detach().clone().requires_grad_(True)  # the tensor the optimizer updates

In [None]:
# Indices into vgg19().features whose activations drive the losses.
# In torchvision's VGG19 layout, 0/5/10/19/28 are conv1_1..conv5_1 (style)
# and 21 is conv4_2 (content).
style_layers = [0, 5, 10, 19, 28]
content_layers = [21]
# Style layers first, then content layers.
loss_layers = [*style_layers, *content_layers]

In [None]:
# One loss module per entry of loss_layers: Gram-matrix MSE for every style layer,
# plain MSE for every content layer (the same instance is reused within each group).
style_loss_fn = GramMSELoss().to(device)
content_loss_fn = nn.MSELoss().to(device)
loss_fns = [style_loss_fn] * len(style_layers) + [content_loss_fn] * len(content_layers)

In [None]:
# Per-layer loss weights. Each style term is scaled by 1000 / C**2, where C is the
# channel count of the matching style layer (presumably so the larger Gram matrices
# of deeper layers do not dominate). The content term keeps weight 1.
style_channels = [64, 128, 256, 512, 512]
style_weights = [1000 / (channels ** 2) for channels in style_channels]
content_weights = [1]
weights = [*style_weights, *content_weights]

In [None]:
# Reference values the generated image is optimized towards: Gram matrices of the
# style image at the style layers, and raw activations of the content image at the
# content layers. They are constants, so they are computed without autograd; no
# graph is built through VGG, which saves memory.
with torch.no_grad():
    style_target = [GramMatrix()(feat) for feat in vgg(style_image, style_layers)]
    content_targets = list(vgg(content_image, content_layers))
targets = style_target + content_targets  # style first, then content (same order as loss_layers)

In [None]:
import torch.optim as optim
# Optimizer setup: L-BFGS over the image tensor only. max_iters is the budget of
# loss evaluations that the optimization loop checks between steps.
max_iters = 500
optimizer = optim.LBFGS(params=[opt_img])

In [None]:
# Main optimization loop. L-BFGS may call `closure` several times per step, so
# `iters` counts loss evaluations rather than optimizer steps.
LOG_EVERY = 50  # print progress every this many loss evaluations
iters = 0
loss_history = []  # total loss of every evaluation, for later inspection


def closure():
    """Compute the weighted style + content loss for opt_img and backpropagate it."""
    global iters
    iters += 1
    optimizer.zero_grad()
    features = vgg(opt_img, loss_layers)
    # weights, loss_fns, features and targets are aligned layer by layer.
    loss = sum(w * loss_fn(feat, target)
               for w, loss_fn, feat, target in zip(weights, loss_fns, features, targets))
    loss.backward()
    loss_history.append(loss.item())
    if iters % LOG_EVERY == 0:
        print(f'iteration {iters}: loss = {loss_history[-1]:.4f}')
    return loss


while iters < max_iters:
    optimizer.step(closure)