In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import vgg19
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Ruta a la imagen que desea transformar
target_image_path = '/content/drive/MyDrive/Transfer style con Pytorch/miFoto.jpg'
# Ruta de acceso a la imagen de estilo
style_reference_image_path = '/content/drive/MyDrive/Transfer style con Pytorch/styles/Van gogh.jpg'
# Dimensiones de la imagen generada
width, height = Image.open(target_image_path).size
img_height = 400
img_width = int(width * img_height / height)

In [3]:
# Preprocesamiento y desprocesamiento
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path)
    image = transform(image).unsqueeze(0)
    return image

def deprocess_image(tensor):
    tensor = tensor.clone().detach().cpu()
    tensor = tensor.numpy().squeeze()
    tensor = tensor.transpose(1, 2, 0)
    tensor[:, :, 0] += 0.485 * 255
    tensor[:, :, 1] += 0.456 * 255
    tensor[:, :, 2] += 0.406 * 255
    tensor = tensor[:, :, ::-1]
    tensor = np.clip(tensor, 0, 255).astype('uint8')
    return tensor

# Cargar imágenes
target_image = preprocess_image(target_image_path)
style_reference_image = preprocess_image(style_reference_image_path)
combination_image = target_image.clone().requires_grad_(True)

# Modelo VGG19
cnn = vgg19(pretrained=True).features.eval()

# Función de pérdida de contenido
def content_loss(base, combination):
    return torch.mean((combination - base) ** 2)

# Función de pérdida de estilo
def gram_matrix(tensor):
    b, c, h, w = tensor.size()
    features = tensor.view(b * c, h * w)
    G = torch.mm(features, features.t())
    return G.div(b * c * h * w)

def style_loss(style, combination):
    S = gram_matrix(style)
    C = gram_matrix(combination)
    return torch.mean((S - C) ** 2)

# Función de pérdida de variación total
def total_variation_loss(x):
    b, c, h, w = x.size()
    a = torch.mean(torch.abs(x[:, :, :-1, :-1] - x[:, :, 1:, :-1]))
    b = torch.mean(torch.abs(x[:, :, :-1, :-1] - x[:, :, :-1, 1:]))
    return a + b

# Capa utilizada para la pérdida de contenido
content_layer = '21'
# Capas utilizadas para la pérdida de estilo
style_layers = ['0', '5', '10', '19', '28']
# Ponderaciones en la media ponderada de los componentes de la pérdida
total_variation_weight = 1e-4
style_weight = 1.0
content_weight = 0.025

# Extraer características
def get_features(image, model, layers=None):
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[name] = x
    return features

target_features = get_features(target_image, cnn, layers={content_layer})
style_features = get_features(style_reference_image, cnn, layers=style_layers)

# Optimizador
optimizer = optim.LBFGS([combination_image])

# Función de actualización
def closure():
    combination_features = get_features(combination_image, cnn, layers={content_layer} | set(style_layers))

    loss = content_weight * content_loss(target_features[content_layer], combination_features[content_layer])
    for layer in style_layers:
        loss += (style_weight / len(style_layers)) * style_loss(style_features[layer], combination_features[layer])
    loss += total_variation_weight * total_variation_loss(combination_image)

    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    return loss

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:08<00:00, 68.0MB/s]


In [4]:
# Iteraciones de optimización
iterations = 5
for i in range(iterations):
    optimizer.step(closure)
    print(f'Iteration {i+1} completed')

# Guardar la imagen generada
final_img = deprocess_image(combination_image)
plt.imsave('result.png', final_img)

Iteration 1 completed
Iteration 2 completed
Iteration 3 completed
Iteration 4 completed
Iteration 5 completed


In [5]:
# Guardar Modelo
import torch
from torchvision.models import vgg19
import torch.optim as optim

# Definir el modelo y el optimizador
cnn = vgg19(pretrained=True).features.eval()
optimizer = optim.LBFGS([combination_image])

# Guardar el estado del modelo y el optimizador
checkpoint = {
    'model_state_dict': cnn.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'target_image': target_image,
    'style_reference_image': style_reference_image,
    'combination_image': combination_image,
}

torch.save(checkpoint, 'style_transfer_checkpoint.pth')


In [6]:
#Cargar Modelo
import torch
from torchvision.models import vgg19
import torch.optim as optim

# Cargar el punto de control
checkpoint = torch.load('style_transfer_checkpoint.pth')

# Definir el modelo y el optimizador
cnn = vgg19(pretrained=True).features.eval()
optimizer = optim.LBFGS([checkpoint['combination_image']])

# Restaurar los estados
cnn.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Restaurar las imágenes
target_image = checkpoint['target_image']
style_reference_image = checkpoint['style_reference_image']
combination_image = checkpoint['combination_image']

# No olvides poner el modelo en modo de evaluación si solo vas a inferir
cnn.eval()


Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo