Skip to content

Commit

Permalink
MAINT: create utils, fix validation (#78)
Browse files Browse the repository at this point in the history
* MAINT: add imageio to requirements

* MAINT: remove char array from train and train_disc

* MAINT: refactor rasterize

* BUG: fix bugs with utils

* STY: yapf

* MAINT: move model loading/saving into utils, not fridge

* MAINT: use all data while training gan

* BLD: write infer.py

* BUG: remove fridge

* BLD: generate after every epoch

* BUG: troubleshooting

* BUG: more troubleshooting

* BUG: final troubleshooting

* MAINT: infer every 5000 epochs

* DOC: docstring for save_images

* EXP: infer every 2000 backprops

* BUG: use save_images properly

* EXP: clip gradients at 3

* EXP: make generator smaller

* BUG: remove extra underscore

* MAINT: move output to top-level dir
  • Loading branch information
eigenfoo authored and ArianaFreitag committed Mar 2, 2019
1 parent 579090d commit c3447f6
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 193 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ SHELL = bash
init:
find .git/hooks -type l -exec rm {} \;
find .githooks -type f -exec ln -sf ../../{} .git/hooks/ \;
mkdir output/
mkdir output/checkpoints/
mkdir output/fonts/

venv:
( \
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
absl-py==0.7.0
imageio==2.5.0
matplotlib==3.0.2
numpy==1.16.0
protobuf==3.6.1
Expand Down
54 changes: 2 additions & 52 deletions src/eggtart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.optim as optim
from coordconv import CoordConv
from resnet import resnet_small
from utils import rasterize


class Eggtart(nn.Module):
Expand All @@ -15,64 +16,13 @@ def __init__(self, device, resolution=64, sigma=0.01):
self.sigma = sigma
self.device = device

# Padding constants chosen by looking at empirical distribution of
# coordinates of Bezier control point from real fonts
left_pad = 0.25 * resolution
right_pad = 1.25 * resolution
up_pad = 0.8 * resolution
down_pad = 0.4 * resolution
mesh_lr = np.linspace(
-left_pad, resolution + right_pad, num=resolution, endpoint=False)
mesh_ud = np.linspace(
-down_pad, resolution + up_pad, num=resolution, endpoint=False)
XX, YY = np.meshgrid(mesh_lr, mesh_ud)
YY = np.flip(YY)
XX_expanded = XX[:, :, np.newaxis]
YY_expanded = YY[:, :, np.newaxis]
self.x_meshgrid = torch.Tensor(XX_expanded / resolution).to(device)
self.y_meshgrid = torch.Tensor(YY_expanded / resolution).to(device)

def rasterize(self, x):
'''
Simple rasterization: drop a single Gaussian at every control point.
Parameters
----------
x : [batch_size, num_control_points, 2]
Control points of glyphs.
Notes
-----
The num_contours and num_beziers dimensions have been collapsed into one
num_control_points dimension.
Also, we can pad with sufficiently large coordinates (e.g. 999) to
indicate that there are no more control points: this places a Gaussian
off-raster, which minimally affects the raster.
'''
batch_size = x.size()[0]
num_samples = x.size()[1]
x_samples = x[:, :, 0].unsqueeze(1).unsqueeze(1)
y_samples = x[:, :, 1].unsqueeze(1).unsqueeze(1)

x_meshgrid_expanded = self.x_meshgrid.expand(
batch_size, self.resolution, self.resolution, num_samples)
y_meshgrid_expanded = self.y_meshgrid.expand(
batch_size, self.resolution, self.resolution, num_samples)

raster = torch.exp(
(-(x_samples - x_meshgrid_expanded)**2 -
(y_samples - y_meshgrid_expanded)**2) / (2 * self.sigma**2))
raster = raster.sum(dim=3)
return raster

def forward(self, x):
# x.shape = [batch_size, 20, 30, 3, 2]
# `forward` must return shape [batch_size, 70]
x = x.squeeze()
batch_size = x.shape[0]
x = x.view(batch_size, -1, 2)
x = self.rasterize(x)
x = rasterize(x, device=self.device)
x = x.unsqueeze(1)
x = self.resnet(x)
return x
Expand Down
42 changes: 0 additions & 42 deletions src/fridge.py

This file was deleted.

67 changes: 67 additions & 0 deletions src/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from absl import flags, app
import torch
from utils import CHARACTERS, rasterize, save_images, load_model

FLAGS = flags.FLAGS

flags.DEFINE_string("checkpoint", None, "Path to checkpoint file.")
flags.DEFINE_integer("numfonts", 16,
"Number of fonts to generate at inference time.")
flags.DEFINE_integer('styledim', 100,
"Dimensionality of the latent style space.")
flags.DEFINE_integer('resolution', 64,
"Resolution of rasters at inference time.")
flags.DEFINE_string('device', 'cuda:0', "Device to train and infer on.")


def infer(gen, num_fonts, path, style_dim=100, resolution=64, device='cuda'):
"""
Runs generator at inference time to generate a batch of fonts.
Parameters
----------
gen : torch.nn.Module
PyTorch generator object.
num_fonts : int
Number of fonts to generate.
style_dim : int
Dimensionality of style space (i.e. output of style network). Defaults
to 100.
resolution : int
Resolution of rasters. Defaults to 64.
device : one of 'cuda' or 'cpu'
Device to infer on. Defaults to 'cuda'.
Returns
-------
"""
# Create random dense style vector
# style_vector = torch.rand([num_fonts * len(CHARACTERS), style_dim], device=FLAGS.device)
style_vector = torch.rand([16, style_dim], device=FLAGS.device)

# Create random one-hot character vector
# char_vector = torch.eye(len(CHARACTERS), device=FLAGS.device).repeat([num_fonts, 1])
char_vector = torch.zeros([16, len(CHARACTERS)], device=FLAGS.device)
char_vector[range(16), range(16)] = 1

# Generate a batch of fake characters
fake_chars = gen(char_vector, style_vector)
fake_chars = torch.reshape(fake_chars, [num_fonts, -1, 2])
rasters = rasterize(fake_chars)

# save_images(rasters, [num_fonts, len(CHARACTERS)], path + '_raster.png')
save_images(rasters, [4, 4], path + '_raster.png')
torch.save(fake_chars, path + '_bezier.pt')


def main(argv):
if FLAGS.checkpoint is None:
raise ValueError('No checkpoint file supplied.')

gen = load_model(FLAGS.checkpoint)
infer(gen, FLAGS.numfonts, FLAGS.path, FLAGS.styledim, FLAGS.resolution,
FLAGS.device)


if __name__ == '__main__':
app.run(main)
2 changes: 1 addition & 1 deletion src/matzah.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ def forward(self, styleVec, classVec):

def matzah_optimizer(net):
'''Returns optimizer and number of epochs, in that order.'''
return optim.Adam(net.parameters(), lr=0.00001)
return optim.Adam(net.parameters(), lr=0.002)

0 comments on commit c3447f6

Please sign in to comment.