<a href="https://colab.research.google.com/github/jianlgler/IST_labiagi/blob/main/PyTorch_AdaIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [73]:
import os
import shutil
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from google.colab import drive
from ipywidgets import interact, interactive
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

In [74]:
# Mount Google Drive so the weights, images and helper code are reachable
# under ./data/MyDrive (relative to the Colab working directory).
drive.mount("/content/data")

# Copy the helper package used below (`from Utils import net, utils`) next to
# the notebook. copytree(..., dirs_exist_ok=True) refreshes an existing ./Utils
# in place; the former `cp -r src Utils` nested a second copy at Utils/Utils on
# re-run and left the importable files stale.
shutil.copytree("data/MyDrive/Utils", "Utils", dirs_exist_ok=True)

Drive already mounted at /content/data; to attempt to forcibly remount, call drive.mount("/content/data", force_remount=True).


In [75]:
from Utils import net, utils

# Prefer the GPU when the runtime has one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [76]:
# Path configuration: content/style images live under the mounted Drive folder.
# The images folder holds data, not Python modules, so it is deliberately NOT
# appended to sys.path.
path = "./data/MyDrive/images"

style_dir_raw = "style"
content_dir_raw = "content"

style_dir = os.path.join(path, style_dir_raw)
content_dir = os.path.join(path, content_dir_raw)

In [77]:
def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    """Run AdaIN style transfer and decode the result back to image space.

    Args:
        vgg: Encoder applied to both ``content`` and ``style``.
        decoder: Network mapping the (blended) features back to images.
        content: Content batch; presumably (N, 3, H, W) images — TODO confirm.
        style: Style batch, encoded with the same ``vgg``.
        alpha: Stylization strength in [0, 1]. 0 keeps the content features,
            1 uses the fully re-normalized (AdaIN) features.
        interpolation_weights: Optional sequence of per-item weights. When
            given and non-empty, the AdaIN features of batch item ``i`` are
            mixed with weight ``interpolation_weights[i]`` into one output, and
            only the first content item is used for the alpha blend.

    Returns:
        Whatever ``decoder`` returns for the blended features.

    Raises:
        AssertionError: if ``alpha`` is outside [0, 1].
        ValueError: if the number of interpolation weights differs from the
            batch size of the AdaIN features.
    """
    assert 0.0 <= alpha <= 1.0, f"alpha must be in [0, 1], got {alpha}"

    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        weights = list(interpolation_weights)
        base_feat = utils.ada_in(content_f, style_f)
        if len(weights) != base_feat.size(0):
            # A longer weight list used to slice an empty tensor and silently
            # broadcast the whole result to an empty batch.
            raise ValueError(
                f"got {len(weights)} interpolation weights for a batch of "
                f"{base_feat.size(0)} feature maps")
        # Accumulate on the features' own device/dtype instead of relying on
        # the notebook-global `device` and the legacy FloatTensor constructor.
        feat = torch.zeros_like(base_feat[0:1])
        for w, item in zip(weights, base_feat.split(1)):
            feat = feat + w * item
        content_f = content_f[0:1]
    else:
        feat = utils.ada_in(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)

In [83]:
# Preprocessing: resize (shorter side to 512) and convert to a [0, 1] tensor.
transform = transforms.Compose([transforms.Resize(512), transforms.ToTensor()])

# Set up the networks from the helper module and load the pretrained weights.
decoder = net.decoder
vgg = net.vgg

# map_location="cpu" lets checkpoints saved from a GPU load on a CPU-only
# runtime; both modules are moved to `device` below.
decoder.load_state_dict(torch.load("./data/MyDrive/decoder.pth", map_location="cpu"))
vgg.load_state_dict(torch.load("./data/MyDrive/vgg_normalised.pth", map_location="cpu"))

# Keep only the first 31 encoder layers (presumably up to relu4_1 as in the
# AdaIN paper — confirm against Utils/net.py). This builds a NEW nn.Sequential
# wrapper that starts in training mode, so eval() is called after truncation.
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.to(device)
decoder.to(device)
vgg.eval()
decoder.eval()

# NOTE: style_transfer supports interpolation_weights, but stylize does not
# expose them in the UI.
print("Done!")


Done!


In [84]:
def stylize(content, c_path, style, s_path, preserve_color, alpha):
    """Stylize one content image with one style image, then save and show it.

    Built to be driven by ``interact``: every argument maps to a widget. Reads
    the notebook globals ``transform``, ``vgg``, ``decoder``, ``device`` and
    ``path`` defined in earlier cells.

    Args:
        content: File name of the content image inside ``c_path``.
        c_path: Directory holding the content images.
        style: File name of the style image inside ``s_path``.
        s_path: Directory holding the style images.
        preserve_color: If truthy, the style image is passed through
            ``utils.preserve_color`` together with the content image first.
        alpha: Stylization strength forwarded to ``style_transfer``.

    Side effects:
        Prints the output file name, writes a JPEG into ``<path>/output``
        (creating the folder if needed) and calls ``utils.display`` on it.
    """
    # Load both images as 3-channel RGB. Without convert(), RGBA PNGs or
    # grayscale/CMYK JPEGs yield tensors with the wrong channel count for VGG.
    # The `with` blocks close the underlying files once the pixels are read.
    with Image.open(os.path.join(c_path, content)) as img:
        content_img = transform(img.convert("RGB"))
    with Image.open(os.path.join(s_path, style)) as img:
        style_img = transform(img.convert("RGB"))

    if preserve_color:
        style_img = utils.preserve_color(style_img, content_img)

    # Add the batch dimension expected by the networks.
    style_img = style_img.to(device).unsqueeze(0)
    content_img = content_img.to(device).unsqueeze(0)

    with torch.no_grad():
        output = style_transfer(vgg, decoder, content_img, style_img, alpha=alpha)
    output = output.cpu()

    color_add = "preserved" if preserve_color else ""
    output_name = '{:s}_stylized_{:s}_{:s}_{:s}'.format(
        os.path.splitext(content)[0], os.path.splitext(style)[0], str(alpha), color_add)
    output_name += ".jpg"

    print(output_name)

    output_dir = os.path.join(path, "output")
    # save_image fails if the folder does not exist yet (e.g. on a fresh Drive).
    os.makedirs(output_dir, exist_ok=True)

    save_image(output, os.path.join(output_dir, output_name))
    utils.display(output_name, output_dir)
  

In [None]:
# Browse the available content and style images with a file dropdown each.
# The directory is passed via `widgets.fixed` so interact does not render it as
# an editable text box; listings are sorted because os.listdir order is arbitrary.
# The trailing `;` suppresses the echo of the returned function.
interact(utils.display, x=sorted(os.listdir(content_dir)), path=widgets.fixed(content_dir))
interact(utils.display, x=sorted(os.listdir(style_dir)), path=widgets.fixed(style_dir));

In [None]:
# Interactive stylization UI. Directories are fixed (not user-editable text
# boxes), file lists are sorted, and the alpha slider only fires on release so
# dragging it does not launch a full style transfer and file write per tick.
interact(stylize,
         content=sorted(os.listdir(content_dir)),
         c_path=widgets.fixed(content_dir),
         style=sorted(os.listdir(style_dir)),
         s_path=widgets.fixed(style_dir),
         preserve_color=False,
         alpha=widgets.FloatSlider(min=0, max=1.0, step=0.01, value=1,
                                   continuous_update=False));