-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAINT: create utils, fix validation (#78)
* MAINT: add imageio to requirements * MAINT: remove char array from train and train_disc * MAINT: refactor rasterize * BUG: fix bugs with utils * STY: yapf * MAINT: move model loading/saving into utils, not fridge * MAINT: use all data while training gan * BLD: write infer.py * BUG: remove fridge * BLD: generate after every epoch * BUG: troubleshooting * BUG: more troubleshooting * BUG: final troubleshooting * MAINT: infer every 5000 epochs * DOC: docstring for save_images * EXP: infer every 2000 backprops * BUG: use save_images properly * EXP: clip gradients at 3 * EXP: make generator smaller * BUG: remove extra underscore * MAINT: move output to top-level dir
- Loading branch information
1 parent
579090d
commit c3447f6
Showing
9 changed files
with
288 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
absl-py==0.7.0 | ||
imageio==2.5.0 | ||
matplotlib==3.0.2 | ||
numpy==1.16.0 | ||
protobuf==3.6.1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from absl import flags, app | ||
import torch | ||
from utils import CHARACTERS, rasterize, save_images, load_model | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_string("checkpoint", None, "Path to checkpoint file.") | ||
flags.DEFINE_integer("numfonts", 16, | ||
"Number of fonts to generate at inference time.") | ||
flags.DEFINE_integer('styledim', 100, | ||
"Dimensionality of the latent style space.") | ||
flags.DEFINE_integer('resolution', 64, | ||
"Resolution of rasters at inference time.") | ||
flags.DEFINE_string('device', 'cuda:0', "Device to train and infer on.") | ||
|
||
|
||
def infer(gen, num_fonts, path, style_dim=100, resolution=64, device='cuda'): | ||
""" | ||
Runs generator at inference time to generate a batch of fonts. | ||
Parameters | ||
---------- | ||
gen : torch.nn.Module | ||
PyTorch generator object. | ||
num_fonts : int | ||
Number of fonts to generate. | ||
style_dim : int | ||
Dimensionality of style space (i.e. output of style network). Defaults | ||
to 100. | ||
resolution : int | ||
Resolution of rasters. Defaults to 64. | ||
device : one of 'cuda' or 'cpu' | ||
Device to infer on. Defaults to 'cuda'. | ||
Returns | ||
------- | ||
""" | ||
# Create random dense style vector | ||
# style_vector = torch.rand([num_fonts * len(CHARACTERS), style_dim], device=FLAGS.device) | ||
style_vector = torch.rand([16, style_dim], device=FLAGS.device) | ||
|
||
# Create random one-hot character vector | ||
# char_vector = torch.eye(len(CHARACTERS), device=FLAGS.device).repeat([num_fonts, 1]) | ||
char_vector = torch.zeros([16, len(CHARACTERS)], device=FLAGS.device) | ||
char_vector[range(16), range(16)] = 1 | ||
|
||
# Generate a batch of fake characters | ||
fake_chars = gen(char_vector, style_vector) | ||
fake_chars = torch.reshape(fake_chars, [num_fonts, -1, 2]) | ||
rasters = rasterize(fake_chars) | ||
|
||
# save_images(rasters, [num_fonts, len(CHARACTERS)], path + '_raster.png') | ||
save_images(rasters, [4, 4], path + '_raster.png') | ||
torch.save(fake_chars, path + '_bezier.pt') | ||
|
||
|
||
def main(argv): | ||
if FLAGS.checkpoint is None: | ||
raise ValueError('No checkpoint file supplied.') | ||
|
||
gen = load_model(FLAGS.checkpoint) | ||
infer(gen, FLAGS.numfonts, FLAGS.path, FLAGS.styledim, FLAGS.resolution, | ||
FLAGS.device) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.