# Style Transfer

In [None]:
# Download the example image archive (images.zip) from Google Drive by its file ID.
!gdown 16pB7xg-XYJAtL6g4BqNy-ervrKnEnnz0

Downloading...
From: https://drive.google.com/uc?id=16pB7xg-XYJAtL6g4BqNy-ervrKnEnnz0
To: /content/images.zip
100% 7.75M/7.75M [00:00<00:00, 25.1MB/s]


In [2]:
!unzip /content/images.zip -d ./

Archive:  /content/images.zip
   creating: ./images/output/
  inflating: ./images/content/content.jpg  
  inflating: ./images/style/style.jpg  


In [3]:
import sys, os, distutils.core

!git clone 'https://github.com/facebookresearch/detectron2'
dist = distutils.core.run_setup("./detectron2/setup.py")
!python -m pip install {' '.join([f"'{x}'" for x in dist.install_requires])}
sys.path.insert(0, os.path.abspath('./detectron2'))

Cloning into 'detectron2'...
remote: Enumerating objects: 15806, done.[K
remote: Counting objects: 100% (63/63), done.[K
remote: Compressing objects: 100% (54/54), done.[K
remote: Total 15806 (delta 22), reused 38 (delta 9), pack-reused 15743 (from 1)[K
Receiving objects: 100% (15806/15806), 6.38 MiB | 5.63 MiB/s, done.
Resolving deltas: 100% (11516/11516), done.
Ignoring dataclasses: markers 'python_version < "3.7"' don't match your environment
Collecting yacs>=0.1.8
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting fvcore<0.1.6,>=0.1.5
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting iopath<0.1.10,>=0.1.7
  Downloading iopath-0.1.9-py3-none-any.whl.metadata (370 bytes)
Collecting omegaconf<2.4,>=2.1
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)


## Setup

In [4]:
import os
import json
from cycler import cycler
import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog

import numpy as np
import json
from cv2 import imread, imwrite
import matplotlib.pyplot as plt

In [5]:
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

## Helper functions

In [6]:
# Per-channel mean / standard deviation used to normalize images in
# `preprocess` and to undo that normalization in `deprocess`.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

def preprocess(img, size=512):
    """Turn a loaded image into a normalized, batched tensor for the backbone.

    Args:
        img: Image accepted by ``T.ToTensor`` (the notebook passes an RGB
            array obtained from ``imread``).
        size: Target size handed to ``T.Resize``.

    Returns:
        Tensor with a leading batch dimension of 1, normalized with
        ``IMAGENET_MEAN`` / ``IMAGENET_STD`` and moved to the global
        ``device`` / ``dtype``.
    """
    steps = [
        T.ToTensor(),
        T.Resize(size),
        T.Normalize(mean=IMAGENET_MEAN.tolist(), std=IMAGENET_STD.tolist()),
    ]
    normalized = T.Compose(steps)(img)
    # Add the batch axis expected by the network.
    batched = normalized.unsqueeze(0)
    return batched.to(dtype=dtype, device=device)

def deprocess(img):
    """Invert `preprocess`: batched normalized tensor -> displayable uint8 image.

    Args:
        img: Tensor with a leading batch dimension; only the first element of
            the batch is converted.

    Returns:
        ``numpy`` uint8 array in (H, W, C) layout with values in [0, 255].
    """
    # Undo the normalization in two steps: first multiply by the std ...
    undo_std = T.Normalize(mean=[0, 0, 0], std=[1.0 / s for s in IMAGENET_STD.tolist()])
    # ... then add the mean back.
    undo_mean = T.Normalize(mean=[-m for m in IMAGENET_MEAN.tolist()], std=[1, 1, 1])

    first = img[0]
    restored = undo_mean(undo_std(first))
    return to_np_uint8(clamp(restored))

def clamp(x):
    """Return a detached copy of ``x`` with values limited to [0, 1].

    The input tensor is left untouched (the previous in-place ``clamp_`` on
    ``x.data`` silently modified the caller's tensor).

    Args:
        x: Input tensor.

    Returns:
        New tensor, detached from the autograd graph, clamped to [0.0, 1.0].
    """
    return x.detach().clamp(0.0, 1.0)

def to_np_uint8(x):
    """Convert a (C, H, W) tensor with values in [0, 1] to an (H, W, C) uint8 array.

    Args:
        x: Tensor whose first axis is the channel axis.

    Returns:
        ``numpy`` uint8 array: values scaled by 255, rounded, channels last.
    """
    arr = x.cpu().detach().numpy()
    scaled = np.round(arr * 255.0).astype('uint8')
    # Channels-first -> channels-last.
    return scaled.transpose(1, 2, 0)

def rescale(x):
    """Linearly map the values of ``x`` onto [0, 1].

    Args:
        x: Array or tensor providing ``min()`` and ``max()``.

    Returns:
        ``(x - min) / (max - min)``. When every element is equal (zero span)
        an all-zero result is returned instead of dividing by zero.
    """
    low, high = x.min(), x.max()
    span = high - low
    if span == 0:
        # Constant input: avoid 0/0 -> NaN; multiplying by 0.0 keeps the
        # container type and yields floating-point zeros.
        return (x - low) * 0.0
    return (x - low) / span

## Device

In [7]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# All tensors and the backbone are kept in single precision.
dtype = torch.float32

print('Current setting for torch:', device)

Current setting for torch: cuda


In [8]:
# Display the installed PyTorch version.
torch.__version__

'2.5.1+cu121'

## Load model

In [9]:
# `pretrained=True` is deprecated in torchvision; request the same ImageNet
# weights explicitly through the `weights` argument.
backbone = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1)
# The network is only used as a fixed feature extractor, so keep it in inference mode.
backbone.eval()

train_nodes, eval_nodes = get_graph_node_names(backbone)

# https://pytorch.org/vision/stable/models.html
from pprint import pprint
print(eval_nodes)
print(backbone.features)

# Keep every graph node that belongs to the convolutional `features` part;
# their activations are what the content/style losses are computed on.
target_name = 'features'
return_nodes = [node for node in eval_nodes if target_name in node]

print(f'There are the {len(return_nodes)} intermediate [{target_name}] layers.')

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:07<00:00, 75.4MB/s]


['x', 'features.0', 'features.1', 'features.2', 'features.3', 'features.4', 'features.5', 'features.6', 'features.7', 'features.8', 'features.9', 'features.10', 'features.11', 'features.12', 'features.13', 'features.14', 'features.15', 'features.16', 'features.17', 'features.18', 'features.19', 'features.20', 'features.21', 'features.22', 'features.23', 'features.24', 'features.25', 'features.26', 'features.27', 'features.28', 'features.29', 'features.30', 'features.31', 'features.32', 'features.33', 'features.34', 'features.35', 'features.36', 'avgpool', 'flatten', 'classifier.0', 'classifier.1', 'classifier.2', 'classifier.3', 'classifier.4', 'classifier.5', 'classifier.6']
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, k

In [10]:
# Wrap the pre-trained model so that a forward pass returns the activation of
# every node listed in `return_nodes`, then move it to the target device/dtype.
backbone = create_feature_extractor(backbone, return_nodes=return_nodes).to(device=device, dtype=dtype)

# Turn off gradient tracking for the network parameters to save computation.
backbone.requires_grad_(False)

def extract_features(x, backbone):
    """Run ``backbone`` on ``x`` and return its output.

    Args:
        x: Input tensor; it is moved to the global ``device`` / ``dtype`` first.
        backbone: Callable model applied to the moved tensor.

    Returns:
        Whatever ``backbone`` returns for the input.
    """
    inputs = x.to(device=device, dtype=dtype)
    features = backbone(inputs)
    return features

In [11]:
# Two-way lookup between a position in `return_nodes` and the node name.
node_idx_to_name = dict(enumerate(return_nodes))
node_name_to_idx = {name: idx for idx, name in node_idx_to_name.items()}

## Load panoptic segmentation model

In [12]:
def conduct_segmentation(style_image):
    """Run COCO panoptic segmentation (Panoptic FPN R101 3x) on an image file.

    Args:
        style_image: Path of the image file to segment.

    Returns:
        dict with
        - ``'seg_out'``: the panoptic segmentation map moved to the CPU;
        - ``'seg_info'``: the list of segment dictionaries returned by the
          predictor, each extended with a ``'class_name'`` entry.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail early, before the model is built.
    im = imread(style_image)
    if im is None:
        raise FileNotFoundError(f'Could not read image: {style_image}')

    # Inference with a panoptic segmentation model
    config_name = "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    if not torch.cuda.is_available():
        # detectron2 defaults to CUDA; fall back to the CPU when no GPU is available.
        cfg.MODEL.DEVICE = 'cpu'
    segmentator = DefaultPredictor(cfg)

    panoptic_seg, segments_info = segmentator(im)["panoptic_seg"]

    # Look the dataset metadata up once instead of once per segment.
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
    for seg_in in segments_info:
        class_names = metadata.thing_classes if seg_in['isthing'] else metadata.stuff_classes
        seg_in['class_name'] = class_names[seg_in['category_id']]

    return {'seg_out': panoptic_seg.to("cpu"), 'seg_info': segments_info}

## Computing Loss

### Content loss

In [13]:
def content_loss(feats, content_layer_indices, content_targets, content_weights):
    """Weighted sum of per-layer mean squared feature differences.

    Args:
        feats: Mapping from node name to the feature tensor of the current image.
        content_layer_indices: Layer indices (keys of ``node_idx_to_name``).
        content_targets: Target feature tensor for each listed layer.
        content_weights: Scalar weight for each listed layer.

    Returns:
        One-element tensor holding the total content loss.
    """
    total = torch.Tensor([0.0]).to(device=device, dtype=dtype)

    per_layer = zip(content_layer_indices, content_targets, content_weights)
    for layer_id, target, weight in per_layer:
        current = feats[node_idx_to_name[layer_id]]
        squared_error = torch.sum((current - target)**2)
        total += weight * squared_error / current.numel()

    return total

### Style loss

In [14]:
def gram_matrix(features, normalize=True):
    """Compute the Gram matrix of a batch of feature maps.

    Args:
        features: Tensor of shape (N, C, H, W).
        normalize: If truthy, divide the Gram matrix by ``N * C * H * W``.

    Returns:
        Tensor of shape (C, C): channel-by-channel inner products accumulated
        over the batch and all spatial positions.
    """
    N, C, H, W = features.shape
    # Bring channels to the front and flatten everything else: (C, N*H*W).
    feat = features.transpose(0, 1).reshape(C, -1)
    gram = feat.mm(feat.t())

    if normalize:
        gram /= N * C * H * W

    return gram

def gram_matrix_weighted(features, weights, normalize=True):
    """Gram matrix of feature maps after spatial re-weighting.

    Args:
        features: Tensor of shape (N, C, H, W).
        weights: Spatial weight map; it is tiled with ``repeat(N, C, 1, 1)``
            and multiplied element-wise with ``features``, so its trailing
            dimensions must match (H, W).
        normalize: If truthy, divide the Gram matrix by the sum of the
            squared tiled weights.

    Returns:
        Tensor of shape (C, C). An all-zero weight map yields an all-zero
        matrix rather than NaN.
    """
    N, C, H, W = features.shape

    wei = weights.repeat(N, C, 1, 1).to(dtype=torch.float32)
    features = features * wei

    # Bring channels to the front and flatten everything else: (C, N*H*W).
    feat = features.transpose(0, 1).reshape(C, -1)
    gram = feat.mm(feat.t())

    if normalize:
        norm = (wei ** 2).sum()
        # Guard against a fully masked image: 0 / 0 would turn the whole
        # target into NaN. For any non-degenerate mask the clamp is a no-op.
        gram /= norm.clamp_min(torch.finfo(norm.dtype).tiny)

    return gram

def style_loss(feats, style_layer_indices, style_targets, style_weights):
    """Weighted sum of squared Gram-matrix differences over the style layers.

    Args:
        feats: Mapping from node name to the feature tensor of the current image.
        style_layer_indices: Layer indices (keys of ``node_idx_to_name``).
        style_targets: Target Gram matrix for each listed layer.
        style_weights: Scalar weight for each listed layer.

    Returns:
        One-element tensor holding the total style loss.
    """
    total = torch.Tensor([0.0]).to(device=device, dtype=dtype)

    per_layer = zip(style_layer_indices, style_targets, style_weights)
    for layer_id, target_gram, weight in per_layer:
        layer_name = node_idx_to_name[layer_id]
        current_gram = gram_matrix(feats[layer_name])
        total += weight * torch.sum((target_gram - current_gram)**2)

    return total

### Total-variation regularization

In [15]:
def tv_loss(img, tv_weight):
    """Total-variation regularizer: mean squared difference of neighbouring pixels.

    The differences are taken by slicing instead of convolving with a
    hand-built 3-channel filter, so no filter tensor is rebuilt and uploaded on
    every call, the result does not depend on the global ``device``/``dtype``,
    and any number of channels is supported.

    Args:
        img: Tensor of shape (N, C, H, W).
        tv_weight: Scalar multiplier applied to the loss.

    Returns:
        Zero-dimensional tensor:
        ``tv_weight * (mean(vertical_diff**2) + mean(horizontal_diff**2))``.
        A direction with a single row/column has no neighbours and
        contributes zero.
    """
    def _mean_square(diff):
        # mean() of an empty tensor is NaN; treat "no neighbours" as zero variation.
        return torch.mean(diff ** 2) if diff.numel() else img.new_zeros(())

    h_diff = img[:, :, 1:, :] - img[:, :, :-1, :]   # vertical neighbours
    w_diff = img[:, :, :, 1:] - img[:, :, :, :-1]   # horizontal neighbours
    return tv_weight * (_mean_square(h_diff) + _mean_square(w_diff))

## Style transfer function

In [16]:
def style_transfer(content_image, style_image,
                   image_size, style_size,
                   content_layer_indices, content_weights,
                   style_layer_indices, style_weights,
                   tv_weight, init='random', start_lr=3.0, clamp_every=2700,
                   n_iters=5000, decay_every=900, decay_ratio=0.3, print_every=1000, save_folder=None, visualize=True):
    """Optimize an image to match the content of one image and the style of another.

    The style targets are Gram matrices weighted by a mask derived from a
    panoptic segmentation of the style image: segments of selected classes get
    weight 0, pixels with segment id 0 get weight 0.3, everything else 1.

    Args:
        content_image: Path of the content image file.
        style_image: Path of the style image file.
        image_size: Size the content image (and the generated image) is resized to.
        style_size: Size the style image is resized to before its features are taken.
        content_layer_indices: Indices into ``return_nodes`` used for the content loss.
        content_weights: One weight per entry of ``content_layer_indices``.
        style_layer_indices: Indices into ``return_nodes`` used for the style loss.
        style_weights: One weight per entry of ``style_layer_indices``.
        tv_weight: Weight of the total-variation regularizer.
        init: ``'random'`` (uniform noise) or ``'content'`` (copy of the content image).
        start_lr: Initial learning rate of the Adam optimizer.
        clamp_every: Period, in iterations, at which the image values are clamped.
        n_iters: Number of optimization iterations.
        decay_every: Period, in iterations, of the learning-rate decay.
        decay_ratio: Multiplicative learning-rate decay factor.
        print_every: Period, in iterations, at which the intermediate image is
            shown and/or saved.
        save_folder: Folder for intermediate and final images; nothing is
            saved when it is falsy.
        visualize: Whether to display figures with matplotlib.

    Raises:
        ValueError: if ``init`` is not one of the supported values.
    """
    # Extract features for the content image
    # (channel order is reversed after imread: BGR -> RGB)
    content_img = imread(content_image)[:, :, ::-1].copy()

    content_img = preprocess(content_img, size=image_size)
    feats = extract_features(content_img, backbone)
    content_targets = []
    for idx in content_layer_indices:
        content_targets.append(feats[node_idx_to_name[idx]].clone())

    # Extract features for the style image
    style_img = imread(style_image)[:, :, ::-1].copy()
    style_img = preprocess(style_img, size=style_size)
    feats = extract_features(style_img, backbone)

    # Exploit panoptic segmentation result
    panoptic_out = conduct_segmentation(style_image)
    seg_out = panoptic_out['seg_out'].reshape(1, 1, *panoptic_out['seg_out'].shape)
    weights = torch.ones_like(seg_out).to(dtype=torch.float32)

    # Segments of these classes do not contribute to the style statistics.
    for seg_info in panoptic_out['seg_info']:
        if seg_info['class_name'] in ['person', 'sky', 'horse', 'backpack', 'umbrella', ]:
            weights[seg_out == seg_info['id']] = 0
    # Pixels with segment id 0 contribute with a reduced weight.
    weights[seg_out == 0] = 0.3

    if visualize:
        f, axarr = plt.subplots(1, 1, figsize=(5.0, 5.0), tight_layout=True)
        axarr.axis('off')
        axarr.imshow((weights.reshape(weights.shape[2], weights.shape[3], 1).repeat(1, 1, 3) * 255).to(dtype=torch.int32).cpu().numpy())
        plt.show()

    style_targets = []
    for idx in style_layer_indices:
        # Bring the weight mask to the spatial size of this layer's feature map.
        _, _, H, W = feats[node_idx_to_name[idx]].shape
        resize = torchvision.transforms.Resize((H, W), torchvision.transforms.InterpolationMode.NEAREST)
        style_targets.append(gram_matrix_weighted(feats[node_idx_to_name[idx]].clone(), resize(weights.clone()).to(device)))
        # style_targets.append(gram_matrix(feats[node_idx_to_name[idx]].clone()))

    # Initialize output image to content image or noise
    if init == 'random':
        img = torch.Tensor(content_img.size()).uniform_(0, 1)
    elif init in ['content', 'contents']:
        img = content_img.clone()
    else:
        raise ValueError("style_transfer(init) takes as input among ['random', 'content']")

    img = img.to(dtype=dtype, device=device)

    # Turn on the computation graph for gradient calculation
    img.requires_grad_()

    # Note that we are optimizing the pixel values of the image by passing
    # in the img Torch tensor, whose requires_grad flag is set to True
    optimizer = torch.optim.Adam([img], lr=start_lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, decay_every, decay_ratio)

    if visualize:
        f, axarr = plt.subplots(1, 2, figsize=(10.0, 5.0), tight_layout=True)
        axarr[0].axis('off')
        axarr[1].axis('off')
        axarr[0].set_title('Content Source Img.')
        axarr[1].set_title('Style Source Img.')
        axarr[0].imshow(deprocess(content_img))
        axarr[1].imshow(deprocess(style_img))
        plt.show()

    if save_folder:
        # imwrite fails silently when the target directory does not exist.
        os.makedirs(save_folder, exist_ok=True)

    for t in range(n_iters):
        # clean graph
        optimizer.zero_grad()

        # Compute loss
        feats = extract_features(img, backbone)
        c_loss = content_loss(feats, content_layer_indices, content_targets, content_weights)
        s_loss = style_loss(feats, style_layer_indices, style_targets, style_weights)
        t_loss = tv_loss(img, tv_weight)
        loss = c_loss + s_loss + t_loss

        # Perform gradient descents on our image values
        loss.backward()
        optimizer.step()
        scheduler.step()

        # display
        if t % print_every == 0:
            img_print = deprocess(img)

            if visualize:
                f, axarr = plt.subplots(1, 1, figsize=(5.0, 5.0), tight_layout=True)
                print('Iteration {}'.format(t + 1))
                axarr.axis('off')
                axarr.imshow(img_print)
                plt.show()

            if save_folder:
                imwrite(os.path.join(save_folder, f'i{t:05d}.jpg'), img_print[:, :, ::-1].copy())

        # clip too large values once every `clamp_every` iterations
        # (the previous test `t % clamp_every` was inverted: it clamped on
        # every iteration except the multiples of `clamp_every`)
        if t % round(clamp_every) == 0:
            img.data.clamp_(-5.0, 5.0)

    img_print = deprocess(img)

    if visualize:
        f, axarr = plt.subplots(1, 1, figsize=(5.0, 5.0), tight_layout=True)
        axarr.axis('off')
        axarr.imshow(img_print)
        plt.show()

    if save_folder:
        # The loop variable is unbound when n_iters == 0, so derive the index
        # of the last iteration from n_iters (same file name as before otherwise).
        last_iter = max(n_iters - 1, 0)
        imwrite(os.path.join(save_folder, f'{last_iter:05d}.jpg'), img_print[:, :, ::-1].copy())

- `style_transfer()` 함수의 입력 인자 설명

| params                | 설명                                                                         |
|-----------------------|------------------------------------------------------------------------------|
| content_image         | Content image 파일 경로 및 이름                                              |
| style_image           | Style image 파일 경로 및 이름                                                |
| image_size            | 생성할 Image size                                                            |
| style_size            | Style을 계산할 Image size (Style에서 가져오고 싶은 주요 Pattern 크기에 영향) |
| content_layer_indices | Content loss를 적용할 계층 번호 [0, len(return_nodes) - 1] 범위의 정수       |
| content_weights       | content_layer_indices에 언급한 계층별로 적용할 Weight 강도                   |
| style_layer_indices   | Style loss를 적용할 계층 번호 [0, len(return_nodes) - 1] 범위의 정수         |
| style_weights         | style_layer_indices에 언급한 계층별로 적용할 Weight 강도                     |
| tv_weight             | Total variation 강도                                                         |
| init                  | 초기화 방법 (random 또는 content)                                            |
| start_lr              | 초기 Learning rate                                                           |
| n_iters               | 반복 횟수                                                                    |
| decay_every           | Learning rate 감소 주기                                                      |
| decay_ratio           | Learning rate 감소율                                                         |
| clamp_every           | 영상 범위 제한 적용 주기 (픽셀 표현 범위를 벗어나는 값 억제)                 |
| print_every           | 출력 주기                                                                    |
| save_folder           | 결과 영상을 저장할 폴더 경로 (None이면 저장하지 않음)                        |
| visualize             | 중간 및 최종 결과를 화면에 표시할지 여부                                     |

## Run neural style transfer!

In [18]:
stating_index = 0  # TODO: increase this to skip already finished combinations when resuming

# glob returns paths in arbitrary order; sort them so that the combination
# index (and therefore `stating_index`) refers to the same run every time.
content_img_cycler = cycler(content_image=sorted(glob.glob(r'/content/images/content/*.jpg')))
style_img_cycler = cycler(style_image=sorted(glob.glob(r'/content/images/style/*.jpg')))
style_cycler = cycler(style_weights=[(1e+3, 1e+3, 1e+3, 1e+3, 1e+3), ])
content_cycler = cycler(content_weights=[(3e+2, 3e+2, 0, 0, 0)])
                                        #  (1e+3, 1e+3, 0, 0, 0),
                                        #  (3e+3, 3e+3, 0, 0, 0),
                                        #  (0, 3e+2, 3e+2, 0, 0),
                                        #  (0, 1e+3, 1e+3, 0, 0),
                                        #  (0, 3e+3, 3e+3, 0, 0)])
tv_cycler = cycler(tv_weight=[1e+3])
lr_cycler = cycler(start_lr=[5.0])


# Cartesian product of all settings: one style-transfer run per combination.
total_cycler = content_img_cycler * style_img_cycler * style_cycler * content_cycler * tv_cycler * lr_cycler

save_folder_base = r'/content/images/output'

for i, c in tqdm(enumerate(total_cycler), total=len(total_cycler)):
    if i < stating_index:
        continue

    params = {
        'image_size' : 192,
        'style_size' : 192,
        'content_layer_indices' : (0, 5, 10, 14, 19),     # conv_1, 3, 5 7, 9
        'style_layer_indices' : (0, 5, 10, 14, 19),       # conv_1, 3, 5 7, 9
        'init': 'content',
        'n_iters': 10000,
        'decay_every': 2300,
        'decay_ratio': 0.3,
        'clamp_every': 6800,
        'print_every': 2000,
        **c
    }

    # Pick the first numbered output folder (00001, 00002, ...) that does not exist yet.
    folder_num = 1
    while os.path.exists(os.path.join(save_folder_base, f'{folder_num:05d}')):
        folder_num += 1
    save_folder = os.path.join(save_folder_base, f'{folder_num:05d}')

    # Store the exact settings next to the images produced by this run.
    os.makedirs(save_folder, exist_ok=True)
    with open(os.path.join(save_folder, 'setting.json'), 'w') as fp:
        json.dump({'index': i, **params}, fp)

    style_transfer(**params, save_folder=save_folder, visualize=False)

  0%|          | 0/1 [00:00<?, ?it/s]

[11/18 14:09:35 d2.checkpoint.detection_checkpoint]: [DetectionCheckpointer] Loading from https://dl.fbaipublicfiles.com/detectron2/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x/139514519/model_final_cafdb1.pkl ...



model_final_cafdb1.pkl: 0.00B [00:00, ?B/s][A
model_final_cafdb1.pkl:   0%|          | 8.19k/261M [00:00<1:55:20, 37.7kB/s][A
model_final_cafdb1.pkl:   0%|          | 647k/261M [00:00<01:42, 2.53MB/s]   [A
model_final_cafdb1.pkl:   3%|▎         | 8.64M/261M [00:00<00:08, 30.1MB/s][A
model_final_cafdb1.pkl:   9%|▉         | 24.0M/261M [00:00<00:03, 71.8MB/s][A
model_final_cafdb1.pkl:  18%|█▊        | 47.2M/261M [00:00<00:02, 105MB/s] [A
model_final_cafdb1.pkl:  24%|██▍       | 62.9M/261M [00:00<00:01, 117MB/s][A
model_final_cafdb1.pkl:  29%|██▊       | 74.5M/261M [00:00<00:01, 116MB/s][A
model_final_cafdb1.pkl:  32%|███▏      | 84.1M/261M [00:00<00:01, 111MB/s][A
model_final_cafdb1.pkl:  36%|███▌      | 94.4M/261M [00:01<00:01, 108MB/s][A
model_final_cafdb1.pkl:  40%|███▉      | 104M/261M [00:01<00:01, 101MB/s] [A
model_final_cafdb1.pkl:  44%|████▎     | 114M/261M [00:01<00:01, 102MB/s][A
model_final_cafdb1.pkl:  48%|████▊     | 126M/261M [00:01<00:01, 106MB/s][A
model_fin