In [None]:
%matplotlib inline

from beras.gan import GAN, upsample
from beras.models import AbstractModel
from beras.util import downsample, blur
import cairosvg
from beesgrid.generate_grids import BlackWhiteArtist, MASK, MASK_BLACK, \
    MASK_WHITE, GridGenerator, MaskGridArtist
import beesgrid.generate_grids as gen_grids
import h5py
import itertools
import keras
import keras.initializations
from keras.models import Sequential, Graph
from keras.layers.core import Dense, Dropout, Activation, Flatten, Reshape, Layer
from keras.layers.convolutional import Convolution2D, MaxPooling2D, UpSample2D

import os.path

import base64

import os
import xml.etree.ElementTree as et
import io
import scipy.misc
from beesgrid import NUM_CELLS
import sys
from keras.optimizers import SGD, Adam
from scipy.misc import imsave
import numpy as np
from theano.ifelse import ifelse
import theano
import theano.tensor as T
import theano.tensor.shared_randomstreams as T_random

import time
import math
import matplotlib.pyplot as plt
from beras.layers.attention import RotationTransformer

# Theano's configured float dtype; numpy inputs in later cells are cast to it with .astype(floatX).
floatX = theano.config.floatX

In [None]:
# Scratch check: per-sample variance over a boolean mask gathered with nonzero().
bs = 64
# Squared so every sample has a different middle-column variance. With a plain
# arange all samples share the same variance, so reading the wrong batch
# element would go unnoticed.
x = (T.arange(9*bs) ** 2).reshape((bs, 1, 3, 3))

# Mask selecting the middle column (last-axis index 1) of every sample.
idx = T.zeros_like(x)
idx = T.set_subtensor(idx[:, :, :, 1], T.ones((bs, 1, 3)))
# nonzero() on the slice idx[i:i+1] returns coordinates whose batch index is
# always 0, so they must be applied to the matching slice x[i:i+1], not to x.
ns = [idx[i:i+1].nonzero() for i in range(bs)]
print(ns[0][1].eval())
per_sample_var = T.stack([x[i:i+1][n].var() for i, n in enumerate(ns)])
reference_var = x[:, 0, :, 1].var(axis=1)
assert np.allclose(per_sample_var.eval(), reference_var.eval())
per_sample_var.eval()

In [None]:
def mask_loss(mask_image, image):
    """Mask-guided intensity loss built with zero-filled masked reductions.

    The pixels of ``image`` where ``mask_image`` is above / below
    ``MASK["IGNORE"]`` give a per-sample white / black mean. Two penalties
    are combined per sample:

    * a hinge pushing ``white_mean - black_mean`` up to at least 0.25;
    * for every part in ``MASK_BLACK`` / ``MASK_WHITE``, a pixel-count-weighted
      penalty on the squared gap between the part's mean and its colour mean,
      plus 7x the mean squared deviation from that colour mean. Parts that are
      absent (NaN mean) contribute zero.

    The summed part penalty is divided by the total count of non-ignored mask
    pixels over the whole tensor.

    Returns:
        ``(50 * mean(loss), 50 * loss)``, where ``loss`` holds one value per sample.
    """
    reduce_axes = [1, 2, 3]

    def masked_sum(values, selector):
        # Copy only the selected entries into a zero tensor, then reduce
        # over every axis except the first.
        picked = T.set_subtensor(T.zeros_like(values)[selector.nonzero()],
                                 values[selector.nonzero()])
        return T.sum(picked, reduce_axes)

    def masked_mean(values, selector):
        # Empty selections divide 0 by 0 and yield NaN (handled by the caller).
        return masked_sum(values, selector) / T.sum(selector, reduce_axes)

    def masked_var(values, center, selector):
        # Mean squared deviation of the selected entries from a per-sample centre.
        center_b = T.patternbroadcast(center.reshape((-1, 1, 1, 1)),
                                      [False, True, True, True])
        picked = T.set_subtensor(T.zeros_like(values)[selector.nonzero()],
                                 values[selector.nonzero()])
        return masked_mean((picked - center_b) ** 2, selector)

    white_mean = masked_mean(image, mask_image > MASK["IGNORE"])
    black_mean = masked_mean(image, mask_image < MASK["IGNORE"])
    target_gap = 0.25 * T.ones_like(black_mean)
    loss = (T.minimum(white_mean - black_mean, target_gap) - target_gap) ** 2

    def per_cell(part, reference_mean):
        selector = T.eq(mask_image, MASK[part])
        mean_here = masked_mean(image, selector)
        weight = T.sum(selector, reduce_axes)
        penalty = (reference_mean - mean_here) ** 2 + \
            7 * masked_var(image, reference_mean, selector)
        return T.switch(T.isnan(mean_here),
                        T.zeros_like(black_mean),
                        weight * penalty)

    # Black parts first, then white parts, compared against their colour mean.
    parts = [(p, black_mean) for p in MASK_BLACK] + \
        [(p, white_mean) for p in MASK_WHITE]
    cell_loss = T.zeros_like(loss)
    for part, reference_mean in parts:
        cell_loss = cell_loss + per_cell(part, reference_mean)

    # NOTE(review): this normalizer sums over the whole tensor, batch included,
    # not per sample; confirm that is intended.
    cell_loss = cell_loss / T.sum(T.neq(mask_image, MASK["IGNORE"]))
    loss = loss + 2 * cell_loss
    return 50 * T.mean(loss), 50 * loss

In [None]:
batch_size = 128
def mask_loss_new(mask_image, image, batch_half=None):
    """Per-sample-unrolled formulation of ``mask_loss``.

    Instead of zero-filling with ``set_subtensor``, each sample's selected
    pixels are gathered with ``nonzero()``, and one subgraph is built per
    sample. The number of samples is therefore fixed when the graph is built.
    This is presumably an alternative to ``mask_loss`` that must compute the
    same values; both are compiled and timed on the same inputs later in this
    notebook.

    Args:
        mask_image: symbolic mask tensor, compared against ``MASK`` values.
        image: symbolic image tensor, indexed with coordinates derived from
            ``mask_image``. It is assumed to have the same shape as
            ``mask_image`` (TODO confirm).
        batch_half: number of leading samples to unroll. Defaults to
            ``batch_size // 2``, read from the module global at call time.

    Returns:
        ``(50 * mean(loss), 50 * loss)``, where ``loss`` holds one value per sample.
    """
    axis = [1, 2, 3]
    if batch_half is None:
        batch_half = batch_size // 2

    def channel_nonzeros(idx):
        return [idx[i:i+1].nonzero() for i in range(batch_half)]

    def per_sample_values(nz_idx):
        # nonzero() of the slice idx[i:i+1] yields batch index 0. The
        # coordinates must therefore index image[i:i+1]; indexing image
        # directly would read sample 0's pixels for every sample.
        return [image[i:i+1][nz] for i, nz in enumerate(nz_idx)]

    def stacked_mean(values):
        # Empty selections give NaN, which cell_loss_fn masks out.
        return T.stack([v.mean() for v in values])

    def stacked_sq_dev(values, center):
        # Mean squared deviation from the per-sample colour mean ``center``.
        # This matches mask_loss, whose variance term is taken around the
        # colour mean rather than around the cell's own mean.
        return T.stack([((v - center[i]) ** 2).mean()
                        for i, v in enumerate(values)])

    white_mean = stacked_mean(
        per_sample_values(channel_nonzeros(mask_image > MASK["IGNORE"])))
    black_mean = stacked_mean(
        per_sample_values(channel_nonzeros(mask_image < MASK["IGNORE"])))
    min_distance = 0.25 * T.ones_like(black_mean)
    distance = T.minimum(white_mean - black_mean, min_distance)
    loss = (distance - min_distance)**2
    cell_loss = T.zeros_like(loss)

    def cell_loss_fn(mask_color, color_mean):
        cell_idx = T.eq(mask_image, MASK[mask_color])
        cell_values = per_sample_values(channel_nonzeros(cell_idx))
        cell_mean = stacked_mean(cell_values)
        cell_weight = T.sum(cell_idx, axis)
        return T.switch(T.isnan(cell_mean),
                        T.zeros_like(black_mean),
                        cell_weight * (
                            (color_mean - cell_mean)**2 +
                            7*stacked_sq_dev(cell_values, color_mean)
                        ))
    for black_parts in MASK_BLACK:
        cell_loss += cell_loss_fn(black_parts, black_mean)
    for white_parts in MASK_WHITE:
        cell_loss += cell_loss_fn(white_parts, white_mean)

    # Same whole-tensor normalizer as mask_loss, kept for parity.
    cell_loss /= T.sum(T.neq(mask_image, MASK["IGNORE"]))
    loss += 2*cell_loss
    return 50*T.mean(loss), 50*loss

In [None]:
# Symbolic 4-D inputs for compiling the losses: m is the mask, img is the image.
m, img = T.tensor4(), T.tensor4()

In [None]:
# Compile the unrolled variant; it builds one subgraph per sample, so %time mostly measures graph compilation.
%time new_loss = theano.function([m, img], mask_loss_new(m, img))

In [None]:
# Compile the reference set_subtensor-based loss, for comparison with the cell above.
%time old_loss = theano.function([m, img], mask_loss(m, img))

In [None]:
def masks(batch_size, multiple=64):
    """Yield batches of beesgrid masks cast to ``floatX``.

    ``batch_size`` is rounded up to the nearest multiple of ``multiple``,
    with a minimum of one multiple. Why a multiple of 64 is needed is not
    visible here; NOTE(review): confirm against ``gen_grids.batches``.

    Yields:
        The first element of each batch from ``gen_grids.batches``,
        presumably the mask array (TODO confirm).
    """
    # Round up without adding a spurious extra multiple when batch_size is
    # already aligned (the previous `+= 64 - n % 64` turned 64 into 128).
    batch_size = max(1, -(-batch_size // multiple)) * multiple
    generator = GridGenerator()
    artist = MaskGridArtist()
    for batch in gen_grids.batches(batch_size, generator, artist=artist,
                                   scales=[1.]):
        yield batch[0].astype(floatX)

mask_idx = next(masks(1))
# Fixed seed so the benchmark inputs are reproducible. Assumed to match
# mask_idx's shape, since the losses index image with mask-derived coordinates.
images = np.random.RandomState(0).random_sample((64, 1, 64, 64)).astype(floatX)

In [None]:
# Time the compiled reference loss on the 64-sample random batch.
%timeit old_loss(mask_idx, images)

In [None]:
# Time the compiled unrolled loss on the same inputs as the reference.
%timeit new_loss(mask_idx, images)