In [5]:
# imports
import io
import base64
import pathlib
import datetime
import random
import numpy as np
import json
import torch
import torch.nn as nn
import PIL

from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
from model import NCA_model

In [6]:
# utility functions

# creates a circle mask given a size, radius and position
def create_circle_mask(size, radius, pos):
    pos = pos * size
    Y, X = np.ogrid[:size, :size]
    dist_from_center = np.sqrt((X - pos[0])**2 + (Y-pos[1])**2)
    mask = dist_from_center >= radius
    return mask

# Loads an image from a specified path and converts to torch.Tensor
def load_image(path, size):
    img = Image.open(path)
    img = img.resize((size, size), Image.LANCZOS)
    img = np.float32(img) / 255.0
    img[..., :3] *= img[..., 3:]
    return torch.from_numpy(img).permute(2, 0, 1)[None, ...]

# converts an RGBA image to a RGB image
def to_rgb(img_rgba):
    rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
    return torch.clamp(1.0 - a + rgb, 0, 1)

# Create a starting tensor for training
# Only the active pixels are goin to be in the middle
def make_seed(size, n_channels):
    x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
    x[:, 3:, size // 2, size // 2] = 1
    return x

In [7]:
# image and video manipulation functions

def np2pil(a):
  if a.dtype in [np.float32, np.float64]:
    a = np.uint8(np.clip(a, 0, 1)*255)
  return PIL.Image.fromarray(a)

def imwrite(f, a, fmt=None):
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    f = open(f, 'wb')
  np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  a = np.asarray(a)
  if len(a.shape) == 3 and a.shape[-1] == 4:
    fmt = 'png'
  f = io.BytesIO()
  imwrite(f, a, fmt)
  return f.getvalue()

def im2url(a, fmt='jpeg'):
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string

def imshow(a, fmt='jpeg'):
  display(Image(data=imencode(a, fmt)))

def tile2d(a, w=None):
  a = np.asarray(a)
  if w is None:
    w = int(np.ceil(np.sqrt(len(a))))
  th, tw = a.shape[1:3]
  pad = (w-len(a))%w
  a = np.pad(a, [(0, pad)]+[(0, 0)]*(a.ndim-1), 'constant')
  h = len(a)//w
  a = a.reshape([h, w]+list(a.shape[1:]))
  a = np.rollaxis(a, 2, 1).reshape([th*h, tw*w]+list(a.shape[4:]))
  return a

def zoom(img, scale=4):
  img = np.repeat(img, scale, 0)
  img = np.repeat(img, scale, 1)
  return img

In [8]:
# training parameters
img = 'imgs\pup.png'
name = 'pup64'
save_model = True

size = 64
pad = 16
n_channels = 16
n_train_iter = 5000
batch_size = 8
pool_size = 1024
n_damage = 3
device = 'cuda'

eval_freq = 500
eval_iter = 300


log_dir = 'logs'
model_dir = 'models'

In [16]:
# misc
full_size = size + (2 * pad)
device = torch.device(device)

# create log
log_path = pathlib.Path(log_dir)
log_path.mkdir(parents=True, exist_ok=True)

# target image
loaded_img = load_image(img, size)
target_img_ = nn.functional.pad(loaded_img, (pad, pad, pad, pad), 'constant', 0)
target_img = target_img_.to(device)
target_img = target_img.repeat(batch_size, 1, 1, 1)

imshow(loaded_img)

TypeError: Cannot handle this data type: (1, 1, 64, 64), |u1

In [None]:


# model and optimizer
model = NCA_model(_n_channels=args.n_channels, _device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

# pool init
seed = make_seed(args.size, args.n_channels).to(device)
seed = nn.functional.pad(seed, (p, p, p, p), 'constant', 0)
pool = seed.clone().repeat(args.pool_size, 1, 1, 1)

# training loop
for it in tqdm(range(args.n_batches)):
    batch_ixs = np.random.choice(
        args.pool_size, args.batch_size, replace=False
    ).tolist()
    
    # get training batch
    x = pool[batch_ixs]
    
    # damage examples in batch
    if args.damage > 0:
        radius = random.uniform(args.size*0.1, args.size*0.4)
        u = random.uniform(0, 1) * args.size + p
        v = random.uniform(0, 1) * args.size + p
        mask = create_erase_mask(full_size, radius, [u, v])
        x[-args.damage:] *= torch.tensor(mask).to(device)
    
    # forward pass
    for i in range(np.random.randint(64, 96)):
        x = model(x)
    
    loss_batch = ((target_img - x[:, :4, ...]) ** 2).mean(dim=[1, 2, 3])
    loss = loss_batch.mean()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    writer.add_scalar('train/loss', loss, it)
    
    # find best in batch
    argmax_batch = loss_batch.argmax().item()
    argmax_pool = batch_ixs[argmax_batch]
    remaining_batch = [i for i in range(args.batch_size) if i != argmax_batch]
    remaining_pool = [i for i in batch_ixs if i != argmax_pool]
    
    pool[argmax_pool] = seed.clone()
    pool[remaining_pool] = x[remaining_batch].detach()
    
    if it % args.eval_frequency == 0:
        x_eval = seed.clone()
        eval_video = torch.empty(1, args.eval_iterations, 3, *x_eval.shape[2:])
        for it_eval in range(args.eval_iterations):
            x_eval = model(x_eval)
            x_eval_out = to_rgb(x_eval[:, :4].detach().cpu())
            eval_video[0, it_eval] = x_eval_out
            
        writer.add_video('eval', eval_video, it, fps=60)
        
# save model
if args.save_model:
    model_path = pathlib.Path(args.modeldir)
    model_path.mkdir(parents=True, exist_ok=True)
    if args.name == None:
        ts = str(datetime.datetime.now()).replace(' ', '_').replace(':', '-').replace('.', '-')
        args.name = 'model_' + ts
    torch.save(model, args.modeldir + '\\' + args.name + '.pt')
    
    # save model arguments
    dict = vars(args)
    json_object = json.dumps(dict, indent=4)
    with open(args.modeldir + '\\' + args.name + '_params.json', 'w') as outfile:
        outfile.write(json_object)