In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Later cells import `configs` / `deepsvg` and load "./pretrained/..." relative to
# the working directory, which must therefore be the repository root (the notebook
# presumably lives one directory below it). Only step up when the `configs` package
# is not already visible, so re-running this cell does not keep climbing the tree.
if not os.path.isdir("configs"):
    os.chdir("..")

In [3]:
from deepsvg.svglib.svg import SVG

from deepsvg import utils
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.utils import to_gif
from deepsvg.svglib.geom import Bbox
from deepsvg.svgtensor_dataset import SVGTensorDataset, load_dataset
from deepsvg.utils.utils import batchify, linear

import torch

# DeepSVG interpolation between pairs of icons

In [4]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Load the pretrained model and dataset

In [5]:
from configs.deepsvg.hierarchical_ordered import Config

# Checkpoint of the hierarchical "ordered" configuration, relative to the repo root.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

cfg = Config()
model = cfg.make_model().to(device)
utils.load_model(pretrained_path, model, device)
model.eval();  # evaluation mode for inference; trailing `;` hides the module repr

In [6]:
# `load_dataset` returns two splits; judging by the printed output, the first is
# the training split and the second the test split. Only the first is used below.
# The second is bound to a real name rather than `_`: assigning `_` shadows
# IPython's last-output variable and stops IPython from updating it this session.
dataset, dataset_test = load_dataset(cfg)

Number of train SVGs: 23391
First SVG in train:5272 - baby - toys
Number of test SVGs: 2599
First SVG in train:41146 - network - clouds


In [7]:
def easein_easeout(t):
    """Ease-in/ease-out re-timing of an interpolation weight ``t`` in [0, 1].

    Maps 0 -> 0, 0.5 -> 0.5 and 1 -> 1 with zero slope at both ends, and is
    symmetric: ``f(t) + f(1 - t) == 1``. An animation driven by the result
    therefore starts and stops gently. Works element-wise on Python floats and
    on tensors.

    The denominator ``2t^2 - 2t + 1`` is always >= 0.5, so it never divides by zero.
    """
    return t * t / (2. * (t * t - t) + 1.)

def interpolate(z1, z2, n=25, filename=None, ease=True, do_display=True, frame_duration=1/12):
    """Render a GIF that interpolates linearly in latent space from `z1` to `z2` and back.

    Args:
        z1, z2: latent codes to blend, e.g. as returned by `encode_icon`.
        n: number of frames from `z1` to `z2`, both endpoints included.
        filename: optional output path, forwarded to `to_gif` as `file_path`.
        ease: if True, re-time the blend weights with `easein_easeout` so the
            morph slows down near both icons.
        do_display: forwarded to `to_gif`.
        frame_duration: per-frame duration forwarded to `to_gif`
            (default 1/12, presumably seconds per frame).
    """
    alphas = torch.linspace(0., 1., n)
    if ease:
        alphas = easein_easeout(alphas)
    z_list = [(1 - a) * z1 + a * z2 for a in alphas]

    # Decode every blended code to a PNG frame without displaying each one.
    img_list = [decode(z, do_display=False, return_png=True) for z in z_list]
    # Forward then backward, so the animation ends on the frame it started with.
    to_gif(img_list + img_list[::-1], file_path=filename,
           frame_duration=frame_duration, do_display=do_display)

In [8]:
def encode_icon(idx):
    """Encode the dataset icon whose id is `idx` into a latent code.

    The icon is fetched with `random_aug=False`. The model inputs named in
    `cfg.model_args` are batched onto `device`, and the model is run with
    `encode_mode=True` while gradient tracking is disabled.
    """
    sample = dataset.get(id=idx, random_aug=False)
    batched_inputs = batchify((sample[arg_name] for arg_name in cfg.model_args), device)
    with torch.no_grad():
        return model(*batched_inputs, encode_mode=True)

def decode(z, do_display=True, return_svg=False, return_png=False):
    """Greedily decode latent code `z` into an SVG and optionally draw it.

    Args:
        z: latent code, e.g. from `encode_icon` or a blend of two such codes.
        do_display: forwarded to `SVG.draw`.
        return_svg: if True, return the decoded `SVG` object without drawing it.
        return_png: forwarded to `SVG.draw` (`interpolate` uses it to collect frames).

    Returns:
        The `SVG` when `return_svg` is True, otherwise the result of `SVG.draw`.
    """
    # Sampling is inference-only. Without no_grad, autograd would record a graph
    # through the model parameters for every decoded frame.
    with torch.no_grad():
        commands_y, args_y = model.greedy_sample(z=z)

    # Take the first item of the batch, rebuild the SVG on a 256-unit viewbox,
    # then normalize it, split it into paths and colour the paths randomly.
    tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())
    svg_path_sample = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True).normalize().split_paths().set_color("random")

    if return_svg:
        return svg_path_sample

    return svg_path_sample.draw(do_display=do_display, return_png=return_png)

In [9]:
def interpolate_icons(idx1=None, idx2=None, n=25, *args, **kwargs):
    """Encode dataset icons `idx1` and `idx2` and render the GIF morphing between them.

    Extra positional arguments are forwarded to `interpolate` after `n`, i.e.
    `filename`, `ease`, `do_display`, ... in that order. Extra keyword arguments
    are forwarded as keyword arguments.
    """
    z1, z2 = encode_icon(idx1), encode_icon(idx2)
    # Pass `n` positionally. Combining `n=n` with `*args` would bind the first
    # extra positional argument to `n` as well and raise a TypeError.
    interpolate(z1, z2, n, *args, **kwargs)

## Random interpolation
Display an interpolation between two random icons of the dataset!

In [10]:
# Draw two ids independently; re-run this cell to see a different pair.
id1 = dataset.random_id()
id2 = dataset.random_id()
interpolate_icons(id1, id2)

## Interpolations shown in the paper
Or display the interpolations shown in the paper!

In [11]:
# (idx1, idx2) dataset-id pairs of the interpolations shown in the paper.
paper_pairs = [
    ("76528", "31079"),
    ("20623", "42458"),
    ("20567", "30915"),
    ("20567", "42345"),
    ("36873", "16710"),
    ("1942", "19527"),
    ("30053", "25925"),
    ("76522", "82571"),
    ("42297", "17591"),
]

for icon_a, icon_b in paper_pairs:
    # Label each GIF, since all of them are rendered from this single cell.
    print(f"{icon_a} -> {icon_b}")
    interpolate_icons(icon_a, icon_b)