In [None]:
!git clone https://github.com/LendelTheGreat/SCALAE.git

In [None]:
!gdown --id 1pqjjx8zRSPsTzPXVQiFf3PgmGnJpLSQn

In [None]:
import io
import os
import sys
from base64 import b64decode, b64encode

import numpy as np
import PIL
import PIL.ImageOps
import torch
from google.colab.output import eval_js
from IPython.display import HTML, Image, clear_output

# The clone cell runs in Colab's working directory (/content), so the repo
# lives in /content/SCALAE -- the same /content root the checkpoint path uses.
SCALAE_PATH = os.path.join('/content', 'SCALAE')
if SCALAE_PATH not in sys.path:
    # Prepend so the repo's generically named modules (model, defaults,
    # checkpointer) take precedence over any same-named installed packages.
    sys.path.insert(0, SCALAE_PATH)
# Later cells use repo-relative paths such as 'configs/popgan.yaml', so work
# from the repository root. The checkpoint is loaded by absolute path.
os.chdir(SCALAE_PATH)
from checkpointer import Checkpointer
from defaults import get_cfg_defaults
from model import Model

In [None]:
# Build the frozen config from the repo YAML file.
print('Loading config...')
config_file = 'configs/popgan.yaml'
cfg = get_cfg_defaults()
cfg.merge_from_file(config_file)
cfg.freeze()  # later cells only read from cfg

# Construct the model on GPU 0, switch it to eval mode and disable gradients
# for every parameter.
print('Initializing model...')
torch.cuda.set_device(0)
model_cfg = cfg.MODEL
layer_count = model_cfg.LAYER_COUNT
model = Model(
    startf=model_cfg.START_CHANNEL_COUNT,
    layer_count=layer_count,
    maxf=model_cfg.MAX_CHANNEL_COUNT,
    latent_size=model_cfg.LATENT_SPACE_SIZE,
    truncation_psi=model_cfg.TRUNCATIOM_PSI,  # (sic) key spelled this way in the config
    truncation_cutoff=model_cfg.TRUNCATIOM_CUTOFF,
    mapping_layers=model_cfg.MAPPING_LAYERS,
    channels=model_cfg.CHANNELS,
    generator=model_cfg.GENERATOR,
    encoder=model_cfg.ENCODER,
)
model.cuda()
model.eval()
model.requires_grad_(False)

# Presumably maps checkpoint entry names to the live sub-modules they are
# loaded into -- TODO confirm against Checkpointer.
print('Loading checkpoint...')
model_dict = dict(
    discriminator_s=model.encoder,
    generator_s=model.decoder,
    mapping_tl_s=model.mapping_tl,
    mapping_fl_s=model.mapping_fl,
    dlatent_avg=model.dlatent_avg,
)
checkpointer = Checkpointer(cfg, model_dict)
extra_checkpoint_data = checkpointer.load(file_name='/content/scalae_pop2sat_1024_e200.pth')

print('READY!')

In [None]:
def sample_z(seed):
    """Draw a reproducible latent vector ``z`` for ``seed``.

    Returns:
        A float32 CUDA tensor of shape (1, cfg.MODEL.LATENT_SPACE_SIZE).
    """
    latent_np = np.random.RandomState(seed).standard_normal(cfg.MODEL.LATENT_SPACE_SIZE)
    with torch.no_grad():
        latent = torch.from_numpy(latent_np).float().cuda()
    # Add the leading batch dimension.
    return latent.unsqueeze(0)

def z2w(z):
    """Map a latent ``z`` through ``model.mapping_fl`` without tracking gradients."""
    with torch.no_grad():
        return model.mapping_fl(z)

def sample_noise(seed, device='cuda'):
    """Sample a reproducible per-resolution noise list for the decoder.

    Args:
        seed: seed for ``np.random.RandomState``; equal seeds give equal noise.
        device: device for the noise tensors. Defaults to 'cuda' so the noise
            matches the device and dtype (float32) of the latents from
            ``sample_z``. Pass 'cpu' to use this without a GPU.

    Returns:
        A list of 11 entries. Index 0 is the literal ``0`` (presumably a
        placeholder the decoder skips -- TODO confirm). Index ``i`` (1..10) is a
        float32 tensor of shape (1, 1, 2**i, 2**i) on ``device``.
    """
    rng = np.random.RandomState(seed)
    noise = [0]
    for i in range(1, 11):
        sample = rng.standard_normal((1, 1, 2 ** i, 2 ** i))
        # standard_normal yields float64 on the CPU; convert to the model's
        # float32 and target device so the decoder does not mix devices/dtypes.
        noise.append(torch.from_numpy(sample).float().to(device))
    return noise

def decode(w, pop, noise):
    """Run ``model.decoder`` at layer index ``layer_count - 1`` and clamp to [-1, 1].

    The positional ``1`` is passed through unchanged (presumably a blend
    factor -- TODO confirm against the decoder signature).
    """
    raw = model.decoder(w, pop, layer_count - 1, 1, noise=noise)
    return raw.clamp(-1., 1.)

def numpy_to_tensor(img, device='cuda'):
    """Convert an (H, W) or (H, W, C) image to a float32 (C, H, W) tensor in [-1, 1].

    Args:
        img: numpy array of shape (H, W) or (H, W, C); values are assumed to
            lie in [0, 1].
        device: target device. Defaults to 'cuda' as before; pass 'cpu' to
            use the helper without a GPU.

    Returns:
        A float32 tensor of shape (C, H, W) (C == 1 for 2-D input) on ``device``,
        with values mapped linearly from [0, 1] to [-1, 1].
    """
    if img.ndim < 3:
        # Give single-channel maps an explicit channel axis.
        img = np.expand_dims(img, 2)
    chw = np.asarray(img.transpose((2, 0, 1)), dtype=np.float32)
    x = torch.tensor(chw, requires_grad=False, device=device)
    # Inverse of tensor_to_numpy's `x * 0.5 + 0.5`.
    return x / 0.5 - 1.

def tensor_to_numpy(img):
    """Convert a tensor in [-1, 1] back to a numpy image in [0, 1].

    Size-1 dimensions (e.g. batch, single channel) are squeezed away. If three
    dimensions remain, the channel axis is moved last (C, H, W -> H, W, C).
    """
    arr = (img * 0.5 + 0.5).cpu().squeeze().numpy()
    return np.moveaxis(arr, 0, 2) if arr.ndim == 3 else arr

def generate(w, np_pop, noise):
    """Render an image for latent ``w`` conditioned on a numpy map.

    Args:
        w: latent as returned by ``z2w``.
        np_pop: (H, W) or (H, W, C) numpy array, assumed to hold values in [0, 1].
        noise: noise list as returned by ``sample_noise``.

    Returns:
        A numpy image with values in [0, 1] (channels last when multi-channel).
    """
    pop = numpy_to_tensor(np_pop).unsqueeze(0)  # add batch dimension
    # Reuse decode() rather than duplicating its decoder call and clamp;
    # no_grad keeps this pure inference call from recording autograd state.
    with torch.no_grad():
        img = decode(w, pop, noise)
    return tensor_to_numpy(img)

def numpy_to_PIL(img, img_size=256):
    """Convert a float image in [0, 1] to a square 8-bit PIL image.

    Args:
        img: numpy image; values outside [0, 1] are clipped. PIL infers the
            image mode from the array shape (e.g. grayscale for 2-D input).
        img_size: side length, in pixels, of the returned (resized) image.

    Returns:
        A ``PIL.Image.Image`` of size (img_size, img_size).
    """
    img = np.clip(img, 0, 1)
    # Round to the nearest level: a bare astype() truncates, biasing every
    # pixel down by half a level (e.g. 0.5 -> 127 instead of 128).
    img = np.rint(img * 255).astype(np.uint8)
    img = PIL.Image.fromarray(img)
    return img.resize((img_size, img_size))

In [None]:
# HTML/JS template for the interactive drawing widget. run() fills it as
# `canvas_html % (width, height, line_width, png_base64)`:
#   - a width x height canvas pre-painted with the base64 PNG (the current map),
#   - white strokes of `line_width` px drawn while the mouse button is held,
#   - a button whose click resolves the JS promise `data` with the canvas
#     contents as a "data:image/png;base64,..." URL, which run() awaits via
#     eval_js("data").
# NOTE(review): pointer coordinates use e.pageX - this.offsetLeft, which
# assumes the canvas's offsetParent is the page; strokes may be offset
# otherwise. onmouseup is only attached to the canvas, so releasing the
# button outside it leaves the paint listener active -- confirm whether
# that is intended.
canvas_html = """
<canvas width=%d height=%d></canvas>
<button>Generate Fake Satellite Image</button>
<script>
var canvas = document.querySelector('canvas')
var ctx = canvas.getContext('2d')
ctx.lineWidth = %d
ctx.strokeStyle = '#ffffff';
base_image = new Image();
base_image.src = 'data:image/png;base64,%s';
base_image.onload = function(){
  ctx.drawImage(base_image, 0, 0);
}
var button = document.querySelector('button')
var mouse = {x: 0, y: 0}
canvas.addEventListener('mousemove', function(e) {
  mouse.x = e.pageX - this.offsetLeft
  mouse.y = e.pageY - this.offsetTop
})
canvas.onmousedown = ()=>{
  ctx.beginPath()
  ctx.moveTo(mouse.x, mouse.y)
  canvas.addEventListener('mousemove', onPaint)
}
canvas.onmouseup = ()=>{
  canvas.removeEventListener('mousemove', onPaint)
}
var onPaint = ()=>{
  ctx.lineTo(mouse.x, mouse.y)
  ctx.stroke()
}
var data = new Promise(resolve=>{
  button.onclick = ()=>{
    resolve(canvas.toDataURL('image/png'))
  }
})
</script>
"""

def imgfile_to_string(file):
    """Read an image file and return its contents re-encoded as PNG.

    Despite the name, the result is ``bytes`` (run() base64-encodes it).

    Args:
        file: path or file object accepted by ``PIL.Image.open``.

    Returns:
        The image encoded as PNG bytes.
    """
    # Context managers release the source file handle promptly; run() reopens
    # and rewrites pop.png on every loop iteration.
    with PIL.Image.open(file) as an_image, io.BytesIO() as output:
        an_image.save(output, format="png")
        return output.getvalue()

def run(img_size=256, seed=0, line_width=1):
    """Interactive loop: draw a map on a canvas and get a generated image back.

    Each iteration shows the current map (pop.png) on an HTML canvas together
    with the last generated image (img.png). It then blocks until the canvas
    button is clicked, reads the drawn canvas back as the new map, and
    regenerates it with the same latent and noise. Loops until the cell is
    interrupted.

    Args:
        img_size: canvas and output image side length in pixels.
        seed: seed used for both the latent (sample_z) and the noise (sample_noise).
        line_width: brush stroke width in pixels.
    """
    pop = np.zeros((img_size, img_size))
    numpy_to_PIL(pop, img_size).save("pop.png")

    z = sample_z(seed)
    w = z2w(z)
    noise = sample_noise(seed)

    img = generate(w, pop, noise)
    numpy_to_PIL(img, img_size).save("img.png")

    while True:
        clear_output()
        pop_str = b64encode(imgfile_to_string('pop.png')).decode("utf-8")
        display(HTML(canvas_html % (img_size, img_size, line_width, pop_str)))
        display(Image('img.png'))

        # Blocks until the button resolves the JS `data` promise with a
        # "data:image/png;base64,<payload>" URL; keep only the payload.
        data = eval_js("data").split(',')[1]
        with open("pop.png", 'wb') as f:
            f.write(b64decode(data))

        with PIL.Image.open("pop.png") as drawn:
            # The canvas PNG is multi-channel; the map is taken from channel 0.
            pop = np.array(drawn)[:, :, 0] / 255

        # Regenerate with the same latent and noise as the initial render
        # (the previous code referenced an undefined `noise_seed` here).
        img = generate(w, pop, noise)
        numpy_to_PIL(img, img_size).save("img.png")


In [None]:
# Start the interactive drawing loop (interrupt the cell to stop it).
run(img_size=256, seed=0)