# In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.models import vgg19, VGG19_Weights
from PIL import Image
import matplotlib.pyplot as plt
from torch.optim import Adam
from torchvision.utils import save_image

# ImageNet channel statistics used to normalize inputs for the pre-trained VGG19 weights.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Child-module index inside ``vgg19().features`` -> conventional layer name.
_VGG_LAYERS = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1',
               '19': 'conv4_1', '21': 'conv4_2', '28': 'conv5_1'}
# Layer whose activations define the content loss.
_CONTENT_LAYER = 'conv4_2'
# Layers whose Gram matrices define the style loss (the content layer is deliberately excluded).
_STYLE_LAYERS = ('conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1')


def _load_image(img_path, device, size=512, scale=None):
    """Load an image as a normalized ``(1, 3, side, side)`` float tensor on *device*.

    If *scale* is truthy, the square side becomes ``int(scale * min(width, height))``
    and *size* is ignored; otherwise the side is *size*.
    """
    # Close the file handle; convert() returns an independent copy of the pixels.
    with Image.open(img_path) as img:
        image = img.convert('RGB')
    if scale:
        size = int(scale * min(image.size))
    loader = transforms.Compose([
        transforms.Resize((size, size)),  # square resize; aspect ratio is not preserved
        transforms.ToTensor(),  # PIL -> float tensor in [0, 1], shape (C, H, W)
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    return loader(image).unsqueeze(0).to(device, torch.float)


def _get_features(image, model, layers=None):
    """Run *image* through *model* child by child and collect named activations.

    Args:
        image: Input batch fed to the first child of *model*.
        model: A sequential-style module whose children are named by index.
        layers: Mapping of child name -> feature name; defaults to ``_VGG_LAYERS``.

    Returns:
        Dict of feature name -> activation tensor.
    """
    if layers is None:
        layers = _VGG_LAYERS
    features = {}
    x = image
    # NOTE(review): torchvision's VGG uses in-place ReLUs, so the stored conv outputs are
    # likely overwritten with post-ReLU values by the next child — confirm. For that reason
    # the loop does not stop early after the last requested layer.
    for name, layer in model.named_children():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features


def _gram_matrix(tensor):
    """Return the unnormalized ``(C, C)`` Gram matrix of a ``(1, C, H, W)`` feature map."""
    _, d, h, w = tensor.size()
    flat = tensor.view(d, h * w)  # assumes batch size 1, as produced by _load_image
    return flat @ flat.t()


def _denormalize(tensor):
    """Undo the ImageNet normalization of a ``(1, 3, H, W)`` tensor and clamp it to [0, 1]."""
    mean = torch.tensor(_IMAGENET_MEAN, device=tensor.device, dtype=tensor.dtype).view(1, 3, 1, 1)
    std = torch.tensor(_IMAGENET_STD, device=tensor.device, dtype=tensor.dtype).view(1, 3, 1, 1)
    return (tensor.detach() * std + mean).clamp(0, 1)


def style_transfer(content_img_path, style_img_path, stylized_name, num_steps=300,
                   content_weight=1e5, style_weight=1e10, *, lr=0.003, model=None):
    """Apply neural style transfer to an image and save the result.

    Starting from a copy of the content image, Adam optimizes the image pixels directly.
    The objective is a weighted sum of two losses:

    * content loss: mean squared error between ``conv4_2`` activations;
    * style loss: mean squared error between Gram matrices at ``conv1_1`` … ``conv5_1``,
      with each layer's term divided by that layer's ``C*H*W``.

    Args:
        content_img_path: Path of the content image (resized to 512x512).
        style_img_path: Path of the style image (resized to a square whose side is half
            of the image's shorter side).
        stylized_name: Output path for the stylized image.
        num_steps: Number of optimizer steps.
        content_weight: Weight of the content loss.
        style_weight: Weight of the style loss.
        lr: Adam learning rate (keyword-only).
        model: Optional preloaded feature extractor laid out like ``vgg19().features``
            (keyword-only). Defaults to the ImageNet-pretrained VGG19 features. The model
            is moved to the selected device, set to eval mode and frozen in place.

    Returns:
        The stylized image as a PIL image (the same pixels that were saved).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_img = _load_image(content_img_path, device)
    style_img = _load_image(style_img_path, device, scale=0.5)

    if model is None:
        model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
    vgg = model.to(device).eval()
    for param in vgg.parameters():
        param.requires_grad_(False)

    # The content/style targets are fixed, so no autograd graph is needed for them.
    with torch.no_grad():
        content_features = _get_features(content_img, vgg)
        style_features = _get_features(style_img, vgg)
    content_target = content_features[_CONTENT_LAYER]
    style_grams = {layer: _gram_matrix(style_features[layer]) for layer in _STYLE_LAYERS}

    # The image being optimized: a leaf tensor initialized from the content image.
    target = content_img.clone().requires_grad_(True)
    optimizer = Adam([target], lr=lr)

    for step in range(1, num_steps + 1):
        target_features = _get_features(target, vgg)
        content_loss = torch.mean((target_features[_CONTENT_LAYER] - content_target) ** 2)

        style_loss = 0
        for layer, style_gram in style_grams.items():
            target_feature = target_features[layer]
            _, d, h, w = target_feature.shape
            target_gram = _gram_matrix(target_feature)
            # Divide by the feature-map size so that layers of different sizes contribute comparably.
            style_loss += torch.mean((target_gram - style_gram) ** 2) / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if step % 50 == 0:
            print(f'Step {step}, Total loss: {total_loss.item()}')

    # Map back from normalized space to [0, 1] before quantizing to 8-bit pixels;
    # otherwise negative values wrap around and the colours are garbled.
    output = _denormalize(target).cpu().squeeze(0)
    final_img = transforms.ToPILImage()(output)
    final_img.save(stylized_name)

    return final_img

# Example usage. The guard keeps importing this module from running the (slow) transfer;
# inside a notebook __name__ is "__main__", so the cell still runs as before.
if __name__ == "__main__":
    content_path = 'raw_data/images/content/astronaut.png'
    style_path = 'raw_data/images/style/The_Scream_S.jpg'
    output_path = 'raw_data/images/stylized/astronaut_stylized.jpg'
    final_img = style_transfer(content_path, style_path, output_path)
    final_img.show()