Skip to content

Commit

Permalink
Use PIL for resizing float32 images
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed Feb 6, 2017
1 parent a42278a commit 3721a9a
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions style_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import argparse
from collections import namedtuple, OrderedDict
from concurrent.futures import ThreadPoolExecutor
from fractions import Fraction
from functools import partial
import io
Expand Down Expand Up @@ -88,13 +89,21 @@ def normalize(arr):
return arr


def resize(arr, size, order=1):
"""Resamples a CxHxW NumPy float array to a different HxW shape."""
h, w = size
arr = np.float32(arr)
resized_arr = zoom(arr, (1, h/arr.shape[1], w/arr.shape[2]), order=order, mode='wrap')
assert resized_arr.shape[1:] == size
return resized_arr
def resize(a, hw, method=Image.LANCZOS):
"""Resamples an image array in CxHxW format to a new HxW size. The interpolation is performed
in floating point and the result dtype is numpy.float32."""
def _resize(a, b):
b[:] = Image.fromarray(a).resize((hw[1], hw[0]), method)

a = np.float32(a)
ch = a.shape[0]
b = np.zeros((ch, hw[0], hw[1]), np.float32)

with ThreadPoolExecutor(max_workers=os.cpu_count()) as ex:
futs = [ex.submit(_resize, a[i], b[i]) for i in range(ch)]
_ = [fut.result() for fut in futs]

return b


def roll_by_1(arr, shift, axis):
Expand Down Expand Up @@ -268,7 +277,7 @@ def set_params(self, last_iterate):
self.params = last_iterate
hw = self.params.shape[-2:]
self.g1 = resize(self.g1, hw)
self.g2 = resize(self.g2, hw, order=0)
self.g2 = np.maximum(0, resize(self.g2, hw, method=Image.BILINEAR))
self.p1 = resize(self.p1, hw)

def restore_state(self, optimizer):
Expand Down Expand Up @@ -536,7 +545,8 @@ def set_params(self, last_iterate):
self.grad = None
xy = self.params.shape[-2:]
self.g1 = resize(self.g1, xy)
self.g2 = np.maximum(resize(self.g2, xy), EPS) * (self.g2.size / last_iterate.size)
self.g2 = np.maximum(resize(self.g2, xy, method=Image.BILINEAR), EPS)
self.g2 *= self.g2.size / last_iterate.size
self.p1 = np.zeros_like(last_iterate)
self.sk = []
self.yk = []
Expand Down

0 comments on commit 3721a9a

Please sign in to comment.