Skip to content

Commit

Permalink
Add generate_gif for PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
zsyzzsoft committed Mar 23, 2021
1 parent c8ae41b commit 0ef06b2
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 2 deletions.
4 changes: 2 additions & 2 deletions DiffAugment-stylegan2-pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ The following commands are an example of generating images with our pre-trained
```bash
python generate.py --outdir=out --seeds=1-16 --network=https://hanlab.mit.edu/projects/data-efficient-gans/models/DiffAugment-stylegan2-100-shot-obama.pkl

python style_mixing.py --outdir=out --rows=1-3 --cols=5-12 --network=https://hanlab.mit.edu/projects/data-efficient-gans/models/DiffAugment-stylegan2-100-shot-obama.pkl
python generate_gif.py --output=obama.gif --seed=0 --num-rows=1 --num-cols=8 --network=https://hanlab.mit.edu/projects/data-efficient-gans/models/DiffAugment-stylegan2-100-shot-obama.pkl
```

<img src="../imgs/style-mixing-grid.jpg" width="1000px"/>
<img src="../imgs/obama.gif" width="1000px"/>

## Other Usages

Expand Down
91 changes: 91 additions & 0 deletions DiffAugment-stylegan2-pytorch/generate_gif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Generate GIF using pretrained network pickle."""

import os

import click
import dnnlib
import numpy as np
from PIL import Image
import torch

import legacy

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seed', help='Random seed', default=0, type=int)
@click.option('--num-rows', help='Number of rows', default=1, type=int)
@click.option('--num-cols', help='Number of columns', default=8, type=int)
@click.option('--resolution', help='Resolution of the output images', default=128, type=int)
@click.option('--num-phases', help='Number of phases', default=5, type=int)
@click.option('--transition-frames', help='Number of transition frames per phase', default=20, type=int)
@click.option('--static-frames', help='Number of static frames per phase', default=5, type=int)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--output', type=str, required=True)
def generate_gif(
network_pkl: str,
seed: int,
num_rows: int,
num_cols: int,
resolution: int,
num_phases: int,
transition_frames: int,
static_frames: int,
truncation_psi: float,
noise_mode: str,
output: str
):
"""Generate gif using pretrained network pickle.
Examples:
\b
python generate_gif.py --output=obama.gif --seed=0 --num-rows=1 --num-cols=8 \\
--network=https://hanlab.mit.edu/projects/data-efficient-gans/models/DiffAugment-stylegan2-100-shot-obama.pkl
"""
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

os.makedirs(os.path.dirname(output), exist_ok=True)

np.random.seed(seed)

output_seq = []
batch_size = num_rows * num_cols
latent_size = G.z_dim
latents = [np.random.randn(batch_size, latent_size) for _ in range(num_phases)]

def to_image_grid(outputs):
outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
outputs = np.concatenate(outputs, axis=1)
outputs = np.concatenate(outputs, axis=1)
return Image.fromarray(outputs).resize((resolution * num_cols, resolution * num_rows), Image.ANTIALIAS)

def generate(dlatents):
images = G.synthesis(dlatents, noise_mode=noise_mode)
images = (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
return to_image_grid(images)

for i in range(num_phases):
dlatents0 = G.mapping(torch.from_numpy(latents[i - 1]).to(device), None)
dlatents1 = G.mapping(torch.from_numpy(latents[i]).to(device), None)
for j in range(transition_frames):
dlatents = (dlatents0 * (transition_frames - j) + dlatents1 * j) / transition_frames
output_seq.append(generate(dlatents))
output_seq.extend([generate(dlatents1)] * static_frames)

if not output.endswith('.gif'):
output += '.gif'
output_seq[0].save(output, save_all=True, append_images=output_seq[1:], optimize=False, duration=50, loop=0)


#----------------------------------------------------------------------------

if __name__ == "__main__":
generate_gif() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
Binary file modified imgs/interp.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/obama.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed imgs/style-mixing-grid.jpg
Binary file not shown.

0 comments on commit 0ef06b2

Please sign in to comment.