# SCALAE

Interactive demo of [SCALAE](https://github.com/LendelTheGreat/SCALAE): paint a population-density map and generate a matching fake satellite image.

Run all cells (a GPU runtime is required) and scroll to the bottom!

In [None]:
#@title Code and model downloading
!pip install yacs
!git clone https://github.com/LendelTheGreat/SCALAE.git
!gdown --id 1pqjjx8zRSPsTzPXVQiFf3PgmGnJpLSQn

In [None]:
#@title Model setup
import os
import sys
import numpy as np
import torch

from IPython.display import HTML, Image, clear_output
from google.colab.output import eval_js
from base64 import b64decode, b64encode
import PIL
import PIL.ImageOps
import io

# Expected location of the repository cloned by the first cell (assumes Colab's
# default working directory, /content). It is appended to sys.path so that the
# repo's top-level modules imported below can be found.
SCALAE_PATH = '/content/SCALAE'
if SCALAE_PATH not in sys.path:
    sys.path.append(SCALAE_PATH)
from checkpointer import Checkpointer
from defaults import get_cfg_defaults
from model import Model


print('Loading config...')
# Start from the repo's default config, overlay the popgan.yaml settings, then
# freeze the result so it is not modified afterwards.
config_file=f'{SCALAE_PATH}/configs/popgan.yaml'
cfg = get_cfg_defaults()
cfg.merge_from_file(config_file)
cfg.freeze()

print('Initializing model...')
# Requires a CUDA GPU: device 0 is selected here and the model is moved to it below.
torch.cuda.set_device(0)
# Module-level global: the helper functions later in the notebook pass
# `layer_count - 1` to the decoder.
layer_count = cfg.MODEL.LAYER_COUNT
# NOTE(review): the `TRUNCATIOM_*` spelling presumably mirrors the key names in
# the repo's defaults.py; check there before "correcting" it here.
model = Model(
    startf=cfg.MODEL.START_CHANNEL_COUNT,
    layer_count=cfg.MODEL.LAYER_COUNT,
    maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
    latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
    truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
    truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
    mapping_layers=cfg.MODEL.MAPPING_LAYERS,
    channels=cfg.MODEL.CHANNELS,
    generator=cfg.MODEL.GENERATOR,
    encoder=cfg.MODEL.ENCODER)
# The model is only used for inference in this notebook: move it to the GPU,
# switch to evaluation mode and disable gradients on all parameters.
model.cuda()
model.eval()
model.requires_grad_(False)

print('Loading checkpoint...')
# Names -> model sub-modules handed to the Checkpointer. Presumably these keys
# match the entries stored in the checkpoint file; verify against checkpointer.py.
model_dict = {
    'discriminator_s': model.encoder,
    'generator_s': model.decoder,
    'mapping_tl_s': model.mapping_tl,
    'mapping_fl_s': model.mapping_fl,
    'dlatent_avg': model.dlatent_avg
}
checkpointer = Checkpointer(cfg, model_dict)
# Presumably the file fetched by gdown in the first cell. The returned value is
# not used anywhere else in this notebook.
extra_checkpoint_data = checkpointer.load(file_name='/content/scalae_pop2sat_1024_e200.pth')

print('READY!')

In [None]:
#@title Helper functions
def sample_z(seed):
  """Draw a reproducible latent vector z from a standard normal distribution.

  Args:
    seed: seed for a dedicated numpy RandomState, so equal seeds give equal z.

  Returns:
    float32 CUDA tensor with a leading batch dimension of 1, i.e. shape
    (1, cfg.MODEL.LATENT_SPACE_SIZE) assuming that setting is an int.
  """
  with torch.no_grad():
    samples = np.random.RandomState(seed).standard_normal(cfg.MODEL.LATENT_SPACE_SIZE)
    batched = torch.from_numpy(samples).float().cuda()[None]
  return batched

def z2w(z):
  """Map latent z to the intermediate latent w with the model's `mapping_fl` network.

  Runs without autograd tracking and returns whatever `model.mapping_fl` returns.
  """
  with torch.no_grad():
    return model.mapping_fl(z)

def sample_noise(seed):
  """Build the reproducible list of noise maps passed to the decoder.

  Element 0 is the placeholder integer 0. Element i (for i = 1..10) is a
  float32 CPU tensor of shape (1, 1, 2**i, 2**i) filled with standard-normal
  samples (presumably one noise map per decoder resolution). All maps come from
  a single RandomState seeded with `seed`, drawn in increasing size order.
  """
  generator = np.random.RandomState(seed)
  noise_maps = [0]
  for log2_side in range(1, 11):
    side = 2 ** log2_side
    draws = generator.standard_normal((1, 1, side, side)).astype(np.float32)
    noise_maps.append(torch.from_numpy(draws))
  return noise_maps

def decode(w, pop, noise):
  """Run the generator on latent `w` and population tensor `pop`, clamped to [-1, 1].

  The decoder is called with `layer_count - 1` and a blend factor of 1
  (presumably the final, highest-resolution level of detail).
  """
  lod = layer_count - 1
  raw = model.decoder(w, pop, lod, 1, noise=noise)
  return raw.clamp(-1., 1.)

def numpy_to_tensor(img):
  """Convert a numpy image with values in [0, 1] to a CUDA tensor in [-1, 1].

  A 2-D array gets a trailing channel axis first. The channel axis is then
  moved to the front ((H, W, C) -> (C, H, W)) and the data cast to float32.
  """
  if img.ndim < 3:
    img = np.expand_dims(img, 2)
  channels_first = np.asarray(img.transpose((2, 0, 1)), dtype=np.float32)
  on_gpu = torch.tensor(channels_first, requires_grad=False).cuda()
  # x / 0.5 - 1 maps the range [0, 1] onto [-1, 1].
  return on_gpu / 0.5 - 1.

def tensor_to_numpy(img):
  """Convert a tensor in [-1, 1] back to a numpy image in [0, 1].

  The tensor is moved to CPU and every singleton dimension is squeezed away.
  If three dimensions remain, the first axis is moved last (presumably
  (C, H, W) -> (H, W, C)); otherwise the squeezed array is returned as is.
  """
  arr = (img * 0.5 + 0.5).cpu().squeeze().numpy()
  return np.moveaxis(arr, 0, 2) if arr.ndim == 3 else arr

def generate(w, np_pop, noise):
  """Generate an image from latent `w` and a numpy population-density map.

  Args:
    w: latent produced by `z2w`.
    np_pop: numpy population map with values in [0, 1] (2-D, or with a
      trailing channel axis), converted via `numpy_to_tensor`.
    noise: noise list produced by `sample_noise`.

  Returns:
    numpy image with values in [0, 1], as produced by `tensor_to_numpy`.
  """
  # Add the batch dimension expected by the decoder.
  pop = numpy_to_tensor(np_pop).unsqueeze(0)
  # Inference only: skip autograd bookkeeping, matching sample_z/z2w. Reuse
  # decode() instead of duplicating the decoder call and clamp.
  with torch.no_grad():
    img = decode(w, pop, noise)
  return tensor_to_numpy(img)

def numpy_to_PIL(img, img_size=256):
  """Convert a float numpy image (nominally in [0, 1]) to a square PIL image.

  Values are clipped to [0, 1], scaled to 0-255 and truncated to uint8. The
  image is then resized to (img_size, img_size) with PIL's default filter.
  """
  as_bytes = (np.clip(img, 0, 1) * 255).astype(np.uint8)
  return PIL.Image.fromarray(as_bytes).resize((img_size, img_size))

canvas_html = """
<div class="slidecontainer">
  <label>Brush strengt (corresponds to population density)</label>
  <input style="width:%dpx" type="range" min="0" max="10" value="10" class="slider" id="pop_slider">
</div>
<div class="slidecontainer">
  <label>Brush width</label>
  <input style="width:%dpx" type="range" min="5" max="100" value="20" class="slider" id="line_slider">
</div>
<div>
  <canvas width=%d height=%d></canvas>
  <img src='data:image/jpeg;base64,%s'/>
</div>
<button style="height:20px;width:%dpx" id="1">Generate Fake Satellite Image</button>
<label></label>
<script>
var label =  document.querySelector('label')
var canvas = document.querySelector('canvas')
var ctx = canvas.getContext('2d')
var line_slider = document.getElementById("line_slider");
ctx.lineWidth = line_slider.value
line_slider.oninput = function() {
  ctx.lineWidth = this.value;
}
var pop_slider = document.getElementById("pop_slider");
ctx.strokeStyle = "#" + (pop_slider.value * 25).toString(16) + (pop_slider.value * 25).toString(16) + (pop_slider.value * 25).toString(16);
pop_slider.oninput = function() {
  ctx.strokeStyle = "#" + (this.value * 25).toString(16) + (this.value * 25).toString(16) + (this.value * 25).toString(16);
}
pop_image = new Image();
pop_image.src = 'data:image/png;base64,%s';
pop_image.onload = function(){
  ctx.drawImage(pop_image, 0, 0);
}
var button_generate = document.getElementById('1')
var mouse = {x: 0, y: 0}
canvas.addEventListener('mousemove', function(e) {
  mouse.x = e.pageX - this.offsetLeft
  mouse.y = e.pageY - this.offsetTop
})
canvas.onmousedown = ()=>{
  ctx.beginPath()
  ctx.moveTo(mouse.x, mouse.y)
  canvas.addEventListener('mousemove', onPaint)
}
canvas.onmouseup = ()=>{
  canvas.removeEventListener('mousemove', onPaint)
}
var onPaint = ()=>{
  ctx.lineTo(mouse.x, mouse.y)
  ctx.stroke()
}
var data = new Promise(resolve=>{
  button_generate.onclick = ()=>{
    resolve(canvas.toDataURL('image/png'))
  }
})
</script>
"""

def setup_html(img_size, pop_file, gen_file):
  """Build the interactive widget as an IPython HTML object.

  `pop_file` is embedded as the image drawn onto the canvas at start-up, and
  `gen_file` as the generated image shown beside it. Both files are re-encoded
  to PNG and base64-embedded. `img_size` sets the slider, canvas and button sizes.
  """
  def as_base64(path):
    return b64encode(imgfile_to_string(path)).decode("utf-8")

  pop_b64 = as_base64(pop_file)
  gen_b64 = as_base64(gen_file)
  fields = (img_size, img_size, img_size, img_size, gen_b64, img_size, pop_b64)
  return HTML(canvas_html % fields)

def imgfile_to_string(file):
  """Read an image file and return its contents re-encoded as PNG bytes.

  Despite the name, the return value is `bytes` (the raw PNG stream), ready
  for base64 encoding.

  Args:
    file: path (or file object) accepted by `PIL.Image.open`.
  """
  output = io.BytesIO()
  # The context manager closes the underlying file deterministically. Without
  # it, the code relies on Pillow's lazy close-after-load. This function runs on
  # every iteration of run()'s loop.
  with PIL.Image.open(file) as an_image:
    an_image.save(output, format="png")
  return output.getvalue()

def run(img_size=256, seed=0):
  """Run the interactive paint-and-generate loop.

  Starts from an empty (all-zero) population map. Each iteration shows the
  canvas widget, waits for the Generate button, stores the painted canvas as
  pop.png, regenerates img.png from it, and repeats. The loop never returns;
  interrupt the cell to stop.

  Args:
    img_size: side length in pixels of the canvas and the displayed image.
    seed: seed for both the latent vector and the decoder noise, so the same
      painting always produces the same image for a given seed.
  """
  model_res = 1024  # side length of the population map fed to the model

  pop = np.zeros((model_res, model_res))
  numpy_to_PIL(pop, img_size).save("pop.png")

  z = sample_z(seed)
  w = z2w(z)
  noise = sample_noise(seed)

  _render_to_file(w, pop, noise, img_size, "img.png")

  while True:
    clear_output()
    display(setup_html(img_size, "pop.png", "img.png"))

    # Blocks until the widget's `data` promise resolves, i.e. until Generate is
    # clicked. The value is a data URL "data:image/png;base64,<payload>".
    data = eval_js("data").split(',')[1]
    with open("pop.png", 'wb') as f:
      f.write(b64decode(data))

    pop = _load_pop_map("pop.png", model_res)
    _render_to_file(w, pop, noise, img_size, "img.png")


def _load_pop_map(path, size):
  """Read the saved canvas PNG as a (size, size) float map in [0, 1] (first channel)."""
  with PIL.Image.open(path) as painted:
    pop = np.array(painted.resize((size, size))) / 255
  # The canvas export has several channels, all holding the same grey level.
  return pop[:, :, 0]


def _render_to_file(w, pop, noise, img_size, path):
  """Generate an image from the population map and save it resized to `path`."""
  numpy_to_PIL(generate(w, pop, noise), img_size).save(path)

# Instructions

On the left is a black canvas you can paint on; it is the population-density input to the model. On the right is the image the model generates from it.

Click the `Generate Fake Satellite Image` button at the bottom to generate an image.

Use the sliders at the top to adjust the brush strength (population density) and the brush width. The demo keeps running until you interrupt the cell.

In [None]:
# Launch the demo with a 512px canvas and preview, using seed 42 for the latent
# and the noise. run() loops until the cell is interrupted.
run(seed=42, img_size=512)