In [None]:
!nvidia-smi -L

In [None]:
!git clone https://github.com/benearnthof/StyleGan.git /content/sg

In [None]:
from google.colab import drive
drive.mount("/content/drive")
import os

In [None]:
os.mkdir("/content/sample/")
os.chdir("/content/sg")

In [None]:
# required for custom cuda extensions
!pip install Ninja

In [None]:
os.chdir("/content/sg")

In [None]:
# it appears the 20000 checkpoint for the retina model got corrupted because the server ran out of storage space during training. 
# unfortunately I only saved checkpoints every 10000 iterations

In [None]:
import torch
from torchvision import utils
from model2 import Generator
from tqdm import tqdm

In [None]:
def gen(npics, G, device, seeds = [1], nsample = 1, styledim = 512, truncation = 1.0, trunc_mean = 4096, outdir = "sample"):
  """Sample ``npics`` images from generator ``G`` and save each one as a PNG.

  Image ``i`` is generated right after ``torch.manual_seed(seeds[i])``, so a
  given seed always reproduces the same image. It is written to
  ``{outdir}/{i:06d}.png``.

  Args:
    npics: number of images to generate; must not exceed ``len(seeds)``.
    G: generator called as ``G([z], truncation=..., truncation_latent=...)``
      and returning an ``(images, _)`` pair.
    device: torch device on which the latent codes are drawn.
    seeds: indexable per-image RNG seeds (list, ``range``, ...).
    nsample: latent codes drawn per image; the resulting images are tiled
      into a single PNG, one per row.
    styledim: dimensionality of the latent codes.
    truncation: forwarded to ``G`` as ``truncation``.
    trunc_mean: forwarded to ``G`` as ``truncation_latent``.
      NOTE(review): the name suggests a mean-latent tensor is expected; the
      int default is presumably only harmless while ``truncation == 1.0`` --
      confirm against model2.Generator.
    outdir: output folder (relative paths resolve against the cwd); created
      if it does not exist.

  Raises:
    ValueError: if ``npics`` exceeds the number of seeds.
  """
  # Fail before doing any work instead of with an IndexError halfway through.
  if npics > len(seeds):
    raise ValueError(f"npics={npics} exceeds the {len(seeds)} seeds provided")
  os.makedirs(outdir, exist_ok = True)

  with torch.no_grad():
    G.eval()
    for i in tqdm(range(npics)):
      torch.manual_seed(seeds[i])
      sample_z = torch.randn(nsample, styledim, device = device)

      sample, _ = G(
          [sample_z], truncation = truncation, truncation_latent = trunc_mean
      )

      # torchvision replaced make_grid's ``range`` keyword with
      # ``value_range``; newer releases reject ``range``.
      utils.save_image(
          sample,
          os.path.join(outdir, f"{i:06d}.png"),
          nrow = 1,
          normalize = True,
          value_range = (-1, 1),
      )

In [None]:
device = "cuda"
G = Generator(
    size = 128, style_dim = 512, n_mlp = 8
).to(device)

In [None]:
checkpoint = torch.load("/content/drive/MyDrive/style-based-gan-pytorch/checkpoints_corgi_reg_aug/040000.pt")

In [None]:
G.load_state_dict(checkpoint["g_ema"], strict = False)

In [None]:
n = 25000
gen(npics = n, G = G, device = "cuda", seeds = range(n))

In [None]:
from google.colab import files
!zip -r /content/drive/MyDrive/corgiSample25k_reg_aug.zip /content/sg/sample
# files.download("/content/drive/MyDrive/retinaSample25k.zip")

In [None]:
def linterp(z, steps, endpoint = False):
  """Piecewise-linear interpolation through the keyframes in ``z``.

  Each consecutive pair ``(z[i], z[i+1])`` contributes ``steps`` points at
  ``t = 0, 1/steps, ..., (steps - 1)/steps``, so each segment starts exactly on
  its left keyframe and shared keyframes are not duplicated. As a consequence,
  by default the final keyframe ``z[-1]`` itself is never emitted and the
  result has ``(len(z) - 1) * steps`` entries.

  Args:
    z: indexable sequence of keyframes supporting ``* float`` and ``+``
      (floats, numpy arrays, torch tensors, ...).
    steps: points generated per segment.
    endpoint: if True, append ``z[-1]`` so the path actually ends on the last
      keyframe.

  Returns:
    List of interpolated values.
  """
  out = []
  for i in range(len(z) - 1):
    left, right = z[i], z[i + 1]
    for index in range(steps):
      t = index / float(steps)
      out.append(right * t + left * (1 - t))
  if endpoint and len(z) > 0:
    out.append(z[-1])
  return out

In [None]:
def gen_linterp_z(G, device, nsteps = 5, seeds = [0, 2], styledim = 512, truncation = 1.0, trunc_mean = 4096, outdir = "sample"):
  """Render a straight-line walk in Z space between two seeded latent codes.

  ``start`` and ``end`` are drawn after ``torch.manual_seed(seeds[0])`` and
  ``torch.manual_seed(seeds[1])``. ``linterp`` then produces the frames moving
  from ``start`` towards ``end``. Frame ``i`` is saved as
  ``{outdir}/{i:04d}.png``.

  Args:
    G: generator called as ``G([z], truncation=..., truncation_latent=...)``
      and returning an ``(images, _)`` pair.
    device: torch device on which the latent codes are drawn.
    nsteps: number of frames rendered.
    seeds: the two RNG seeds for the endpoints.
    styledim: dimensionality of the latent codes.
    truncation: forwarded to ``G`` as ``truncation``.
    trunc_mean: forwarded to ``G`` as ``truncation_latent``.
      NOTE(review): presumably only harmless as an int while
      ``truncation == 1.0`` -- confirm.
    outdir: output folder (relative to the cwd); created if missing.
      NOTE(review): the default is shared with ``gen``'s 6-digit samples.
  """
  with torch.no_grad():
    G.eval()
    torch.manual_seed(seeds[0])
    start = torch.randn(1, styledim, device = device)
    torch.manual_seed(seeds[1])
    end = torch.randn(1, styledim, device = device)

    zs = linterp([start, end], steps = nsteps)
    os.makedirs(outdir, exist_ok = True)

    for i in tqdm(range(nsteps)):
      sample, _ = G(
          [zs[i]], truncation = truncation, truncation_latent = trunc_mean
      )

      # One image per call (batch of 1), so nrow has no effect. ``value_range``
      # replaces the ``range`` keyword that newer torchvision rejects.
      utils.save_image(
          sample,
          os.path.join(outdir, f"{i:04d}.png"),
          normalize = True,
          value_range = (-1, 1),
      )

In [None]:
gen_linterp_z(G = G, device = "cuda", nsteps = 25)

In [None]:
import glob
from PIL import Image

# filepaths
fp_in = "/content/sg/sample/*.png"
fp_out = "/content/sg/sample/linterp.gif"

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
img.save(fp=fp_out, format='GIF', append_images=imgs,
         save_all=True, duration=200, loop=0)

In [None]:
z = start = torch.randn(1, 512, device = device)
w = G.style(z)
z[0,0], w[0,0]

smp, _ = G(w, input_is_latent = True)
utils.save_image(
          smp, 
          f"sample/test.png",
          nrow = 1,
          normalize = True, 
          range = (-1, 1),
      )


In [None]:
# Interpolation in W: each frame is rendered with G([w], input_is_latent = True).
def gen_linterp_w(G, device, nsteps = 5, seeds = [0, 2], styledim = 512, truncation = 1.0, trunc_mean = 4096, outdir = "sample_w"):
  """Render a straight-line walk in W space between two seeded latent codes.

  ``start``/``end`` z codes are drawn after ``torch.manual_seed(seeds[0])`` /
  ``torch.manual_seed(seeds[1])`` and mapped through ``G.style``. The W codes
  are interpolated with ``linterp``, and frame ``i`` is saved as
  ``{outdir}/{i:04d}.png``.

  Args:
    G: generator exposing ``G.style`` and called as
      ``G([w], truncation=..., truncation_latent=..., input_is_latent=True)``.
    device: torch device on which the latent codes are drawn.
    nsteps: number of frames rendered.
    seeds: the two RNG seeds for the endpoints.
    styledim: dimensionality of the z codes.
    truncation: forwarded to ``G`` as ``truncation``.
    trunc_mean: forwarded to ``G`` as ``truncation_latent``.
      NOTE(review): presumably only harmless as an int while
      ``truncation == 1.0`` -- confirm.
    outdir: output folder (relative to the cwd); created if missing.
  """
  with torch.no_grad():
    G.eval()
    torch.manual_seed(seeds[0])
    start = torch.randn(1, styledim, device = device)
    torch.manual_seed(seeds[1])
    end = torch.randn(1, styledim, device = device)

    # pass through the style network
    start_w = G.style(start)
    end_w = G.style(end)

    ws = linterp([start_w, end_w], steps = nsteps)
    os.makedirs(outdir, exist_ok = True)

    for i in tqdm(range(nsteps)):
      sample, _ = G(
          [ws[i]],
          truncation = truncation,
          truncation_latent = trunc_mean,
          input_is_latent = True
      )

      # Batch of 1, so no nrow; ``value_range`` replaces the ``range`` keyword
      # that newer torchvision rejects.
      utils.save_image(
          sample,
          os.path.join(outdir, f"{i:04d}.png"),
          normalize = True,
          value_range = (-1, 1),
      )

In [None]:
os.getcwd()

In [None]:
os.mkdir("/content/sg/sample_w/")

In [None]:
gen_linterp_w(G = G, device = "cuda", nsteps = 25)

In [None]:
# generate gif
fp_in = "/content/sg/sample_w/*.png"
fp_out = "/content/sg/sample_w/linterp_w.gif"

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
img.save(fp=fp_out, format='GIF', append_images=imgs,
         save_all=True, duration=200, loop=0)

In [None]:
torch.manual_seed(0)
start = torch.randn(1, 2, device = "cpu")
torch.manual_seed(2)
end = torch.randn(1, 2, device = "cpu")

In [None]:
def spherical_interp(steps, start, end):
  """Spherical linear interpolation (slerp) from ``start`` to ``end``.

  Returns ``steps`` tensors at ``t = i / (steps - 1)``. Both endpoints are
  included: the first element is ``start`` and the last is ``end``. The angle
  ``omega`` is measured between the normalized inputs, while the unnormalized
  tensors are blended:
  ``sin((1-t)*omega)/sin(omega) * start + sin(t*omega)/sin(omega) * end``.
  When the inputs are (anti)parallel, ``sin(omega)`` is ~0 and slerp is
  undefined, so plain linear blending is used instead.

  Args:
    steps: number of points; ``0`` gives ``[]`` and ``1`` gives ``[start]``.
    start, end: same-shaped floating-point tensors of any rank.

  Returns:
    List of ``steps`` tensors shaped like ``start``.
  """
  if steps <= 0:
    return []
  if steps == 1:
    return [start]

  start_unit = start / torch.linalg.norm(start)
  end_unit = end / torch.linalg.norm(end)
  # Elementwise product + sum is the dot product for any rank; rounding can push
  # it slightly outside [-1, 1], where arccos would return NaN.
  cos_omega = torch.clamp((start_unit * end_unit).sum(), -1.0, 1.0)
  omega = torch.arccos(cos_omega)
  sin_omega = torch.sin(omega)
  degenerate = torch.allclose(start, end) or bool(sin_omega.abs() < 1e-6)

  out = []
  for i in range(steps):
    t = i / (steps - 1)
    if i == 0:
      out.append(start)
    elif i == steps - 1:
      out.append(end)
    elif degenerate:
      out.append((1.0 - t) * start + t * end)
    else:
      out.append(torch.sin((1.0 - t) * omega) / sin_omega * start
                 + torch.sin(t * omega) / sin_omega * end)
  return out

In [None]:
# Peek at the first slerp point between the 2-D toy endpoints.
spherical_interp(steps = 10, start = start, end = end)[0]

In [None]:
def gen_slerp_z(G, device, nsteps = 5, seeds = [0, 2], styledim = 512, truncation = 1.0, trunc_mean = 4096, outdir = "sample_spherical"):
  """Render a spherical (slerp) walk in Z space between two seeded latents.

  ``start``/``end`` are drawn after ``torch.manual_seed(seeds[0])`` /
  ``torch.manual_seed(seeds[1])`` and interpolated with ``spherical_interp``.
  Frame ``i`` is saved as ``{outdir}/{i:04d}.png``.

  Args:
    G: generator called as ``G([z], truncation=..., truncation_latent=...)``
      and returning an ``(images, _)`` pair.
    device: torch device for the latent codes. The interpolated codes are
      moved back here (previously hard-coded to ``'cuda'``).
    nsteps: number of frames rendered.
    seeds: the two RNG seeds for the endpoints.
    styledim: dimensionality of the latent codes.
    truncation: forwarded to ``G`` as ``truncation``.
    trunc_mean: forwarded to ``G`` as ``truncation_latent``.
      NOTE(review): presumably only harmless as an int while
      ``truncation == 1.0`` -- confirm.
    outdir: output folder (relative to the cwd); created if missing.
  """
  with torch.no_grad():
    G.eval()
    torch.manual_seed(seeds[0])
    start = torch.randn(1, styledim, device = device)
    torch.manual_seed(seeds[1])
    end = torch.randn(1, styledim, device = device)

    # Interpolate on CPU copies, then stack to (frames, 1, styledim) on `device`.
    zs = spherical_interp(steps = nsteps, start = start.cpu(), end = end.cpu())
    zs = torch.stack(zs).to(device)
    os.makedirs(outdir, exist_ok = True)

    for i in tqdm(range(nsteps)):
      sample, _ = G(
          [zs[i]], truncation = truncation, truncation_latent = trunc_mean
      )

      # Batch of 1, so no nrow; ``value_range`` replaces the ``range`` keyword
      # that newer torchvision rejects.
      utils.save_image(
          sample,
          os.path.join(outdir, f"{i:04d}.png"),
          normalize = True,
          value_range = (-1, 1),
      )

In [None]:
# Slerp in W: each frame is rendered with G([w], input_is_latent = True).
def gen_slerp_w(G, device, nsteps = 5, seeds = [0, 2], styledim = 512, truncation = 1.0, trunc_mean = 4096, outdir = "sample_spherical_w"):
  """Render a spherical (slerp) walk in W space between two seeded latents.

  ``start``/``end`` z codes are drawn after ``torch.manual_seed(seeds[0])`` /
  ``torch.manual_seed(seeds[1])`` and mapped through ``G.style``. The W codes
  are interpolated with ``spherical_interp``, and frame ``i`` is saved as
  ``{outdir}/{i:04d}.png``.

  Args:
    G: generator exposing ``G.style`` and called as
      ``G([w], truncation=..., truncation_latent=..., input_is_latent=True)``.
    device: torch device for the latent codes. The interpolated codes are
      moved back here (previously hard-coded to ``'cuda'``).
    nsteps: number of frames rendered.
    seeds: the two RNG seeds for the endpoints.
    styledim: dimensionality of the z codes.
    truncation: forwarded to ``G`` as ``truncation``.
    trunc_mean: forwarded to ``G`` as ``truncation_latent``.
      NOTE(review): presumably only harmless as an int while
      ``truncation == 1.0`` -- confirm.
    outdir: output folder (relative to the cwd); created if missing.
  """
  with torch.no_grad():
    G.eval()
    torch.manual_seed(seeds[0])
    start = torch.randn(1, styledim, device = device)
    torch.manual_seed(seeds[1])
    end = torch.randn(1, styledim, device = device)

    # pass through the style network
    start_w = G.style(start)
    end_w = G.style(end)

    # Interpolate on CPU copies, then stack to (frames, 1, dim) on `device`.
    ws = spherical_interp(steps = nsteps, start = start_w.cpu(), end = end_w.cpu())
    ws = torch.stack(ws).to(device)
    os.makedirs(outdir, exist_ok = True)

    for i in tqdm(range(nsteps)):
      sample, _ = G(
          [ws[i]],
          truncation = truncation,
          truncation_latent = trunc_mean,
          input_is_latent = True
      )

      # Batch of 1, so no nrow; ``value_range`` replaces the ``range`` keyword
      # that newer torchvision rejects.
      utils.save_image(
          sample,
          os.path.join(outdir, f"{i:04d}.png"),
          normalize = True,
          value_range = (-1, 1),
      )

In [None]:
if not os.path.exists("/content/sg/sample_spherical/"):
  os.mkdir("/content/sg/sample_spherical/")
if not os.path.exists("/content/sg/sample_spherical_w/"):
  os.mkdir("/content/sg/sample_spherical_w/")

In [None]:
gen_slerp_z(G = G, device = "cuda", nsteps = 25)

In [None]:
gen_slerp_w(G = G, device = "cuda", nsteps = 25)

In [None]:
# generate gif
fp_in = "/content/sg/sample_spherical/*.png"
fp_out = "/content/sg/sample_spherical/slerp.gif"

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
img.save(fp=fp_out, format='GIF', append_images=imgs,
         save_all=True, duration=200, loop=0)

In [None]:
# generate gif (should probably wrap this in function at this point)
fp_in = "/content/sg/sample_spherical_w/*.png"
fp_out = "/content/sg/sample_spherical_w/slerp_w.gif"

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
img.save(fp=fp_out, format='GIF', append_images=imgs,
         save_all=True, duration=200, loop=0)

In [None]:
os.mkdir("/content/sg/zippedfiles")

In [None]:
import sys
from PIL import Image

def imgcombine(path):
  """Stitch the PNG frames in ``path`` left-to-right into ``path/combined.png``.

  Every ``*.png`` in ``path`` is used, including the first one, in sorted
  filename order. An existing ``combined.png`` is skipped, so re-running does
  not nest the previous strip. Frames are top-aligned on an RGB canvas as wide
  as all frames together and as tall as the tallest frame.

  Raises:
    ValueError: if ``path`` contains no PNG frames.
  """
  fp_out = os.path.join(path, "combined.png")
  frame_paths = [
      f for f in sorted(glob.glob(os.path.join(path, "*.png")))
      if os.path.abspath(f) != os.path.abspath(fp_out)
  ]
  if not frame_paths:
    raise ValueError(f"no .png frames found in {path!r}")

  frames = []
  for f in frame_paths:
    # convert() returns an independent, fully loaded copy, so the file can be
    # closed immediately instead of leaking one handle per frame.
    with Image.open(f) as im:
      frames.append(im.convert("RGB"))

  total_width = sum(im.width for im in frames)
  max_height = max(im.height for im in frames)
  new_im = Image.new('RGB', (total_width, max_height))

  x_offset = 0
  for im in frames:
    new_im.paste(im, (x_offset, 0))
    x_offset += im.width

  new_im.save(fp_out)

In [None]:
imgcombine("/content/sg/sample_spherical_w/")
imgcombine("/content/sg/sample_spherical/")
imgcombine("/content/sg/sample/")
imgcombine("/content/sg/sample_w/")

In [None]:
# imgcombine("/content/sg/sample_spherical_w/")

In [None]:
# zip and download folders 
!zip -r /content/sg/zippedfiles/sample.zip /content/sg/sample/
!zip -r /content/sg/zippedfiles/sample_w.zip /content/sg/sample_w/
!zip -r /content/sg/zippedfiles/sample_spherical.zip /content/sg/sample_spherical//
!zip -r /content/sg/zippedfiles/sample_spherical_w.zip /content/sg/sample_spherical_w/

In [None]:
from google.colab import files
files.download("/content/sg/zippedfiles/sample.zip")
files.download("/content/sg/zippedfiles/sample_w.zip")
files.download("/content/sg/zippedfiles/sample_spherical.zip")
files.download("/content/sg/zippedfiles/sample_spherical_w.zip")

In [None]:
# Stitch interpolation frames into one horizontal strip.
# NOTE(review): this redefines, and therefore shadows, the identical imgcombine
# defined earlier in the notebook; one of the two should be removed.
def imgcombine(path):
  """Stitch the PNG frames in ``path`` left-to-right into ``path/combined.png``.

  Every ``*.png`` in ``path`` is used, including the first one, in sorted
  filename order. An existing ``combined.png`` is skipped, so re-running does
  not nest the previous strip. Frames are top-aligned on an RGB canvas as wide
  as all frames together and as tall as the tallest frame.

  Raises:
    ValueError: if ``path`` contains no PNG frames.
  """
  fp_out = os.path.join(path, "combined.png")
  frame_paths = [
      f for f in sorted(glob.glob(os.path.join(path, "*.png")))
      if os.path.abspath(f) != os.path.abspath(fp_out)
  ]
  if not frame_paths:
    raise ValueError(f"no .png frames found in {path!r}")

  frames = []
  for f in frame_paths:
    # convert() returns an independent, fully loaded copy, so the file can be
    # closed immediately instead of leaking one handle per frame.
    with Image.open(f) as im:
      frames.append(im.convert("RGB"))

  total_width = sum(im.width for im in frames)
  max_height = max(im.height for im in frames)
  new_im = Image.new('RGB', (total_width, max_height))

  x_offset = 0
  for im in frames:
    new_im.paste(im, (x_offset, 0))
    x_offset += im.width

  new_im.save(fp_out)