<a href="https://colab.research.google.com/github/ggyppsyy/colab_experiments/blob/master/Jax_DC_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#https://towardsdatascience.com/understanding-batch-normalization-with-examples-in-numpy-and-tensorflow-with-interactive-code-7f59bb126642
#https://wiseodd.github.io/techblog/2016/07/16/convnet-conv-layer/
#https://github.com/huyouare/CS231n/blob/master/assignment2/cs231n/im2col.py
#https://colab.research.google.com/github/google/jax/blob/master/notebooks/neural_network_and_data_loading.ipynb?authuser=1&hl=en#scrollTo=7APc6tD7TiuZ
#https://github.com/pytorch/examples/blob/master/dcgan/main.py
#https://arxiv.org/abs/1511.06434

In [0]:
#!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.12-cp36-none-linux_x86_64.whl
#!pip install --upgrade -q jax

In [0]:
import numpy as np

In [0]:
def leaky_relu(x):
    y1 = ((x > 0) * x)                                                 
    y2 = ((x <= 0) * x * 0.01)                                         
    return y1 + y2 

def sigmoid(x):
    return 1 / (1. + np.exp(-x))

In [0]:
def channel_normalization(b):
    batch_mean = np.mean(b)
    print(batch_mean)
    batch_var = np.sum((b-batch_mean) ** 2, axis=0) / b.shape[0]
    print(batch_var)
    return (b-batch_mean) / ( (batch_var + 1e-8) ** 0.5 )

def batch_normalization(batch):
    print(batch.shape)
    batch = batch.transpose(1, 0, 2, 3).astype(np.float)
    print(batch.shape)
    for i in range(batch.shape[0]):
        batch[i] = channel_normalization(batch[i])
    print(batch.shape)
    return batch.transpose(1, 0, 2, 3)

In [0]:
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    # First figure out what the size of the output should be
    N, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    assert (W + 2 * padding - field_height) % stride == 0
    out_height = (H + 2 * padding - field_height) / stride + 1
    out_width = (W + 2 * padding - field_width) / stride + 1

    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = np.tile(i0, C)
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width), out_height)
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)

    return (k, i, j)


def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """ An implementation of im2col based on some fancy indexing """
    # Zero-pad the input
    p = padding
    x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding,
                                 stride)

    cols = x_padded[:, k, i, j]
    C = x.shape[1]
    cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
    return cols

In [0]:
def conv_forward(X, W, b, stride=1, padding=1):
    #X = np.transpose(X, (0, 3, 1, 2))    
    #cache = W, b, stride, padding
    n_filters, d_filter, h_filter, w_filter = W.shape
    n_x, d_x, h_x, w_x = X.shape
    assert (h_x - h_filter + 2 * padding) % stride == 0
    assert (w_x - w_filter + 2 * padding) % stride == 0
    h_out = int((h_x - h_filter + 2 * padding) / stride + 1)
    w_out = int((w_x - w_filter + 2 * padding) / stride + 1)
    
    X_col = im2col_indices(X, h_filter, w_filter, padding=padding, stride=stride)
    W_col = W.reshape(n_filters, -1)
    
    out = W_col @ X_col + b
    out = out.reshape(n_filters, h_out, w_out, n_x)
    out = out.transpose(3, 0, 1, 2)

    cache = (X, W, b, stride, padding, X_col)

    return out, cache

In [0]:
def create_layer(h, w, key, channels=3, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (h, w)), scale * random.normal(b_key, (1,))

def create_conv_layer(num_filters, height, width, key, channels=3, scale=1e-2):
    filters = []
    for n in range(num_filters):
        filters.append(create_layer(height, width, key, channels=3, scale=1e-2))
    return np.stack(filters)

In [0]:
class Discriminator():
    def __init__(self):
        self.step_size = 0.0001
        self.image_shape = (224,224,3)
        self.params = []
        
        
    def predict():
        pass
    
    def batched_predict():
        pass
    
    def loss(params, images, targets):
        preds = batched_predict(params, images)
        return -np.sum(preds * targets)
    
    @jit
    def update(params, x, y):
        grads = grad(loss)(self.params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                for (w, b), (dw, db) in zip(self.params, grads)]

In [0]:
class Generator():
    def __init__(self):
        self.step_size = 0.0001
        
        
    def predict():
        pass
    
    def batched_predict():
        pass
    
    def loss(params, images, targets):
        preds = batched_predict(params, images)
        return -np.sum(preds * targets)
    
    @jit
    def update(params, x, y):
        grads = grad(loss)(params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

In [0]:
n_filters = 2
field_height=2
field_width=2
padding=1
stride=1
N, C, H, W = (3,1,3,3)
assert (H + 2 * padding - field_height) % stride == 0
assert (W + 2 * padding - field_height) % stride == 0
out_height = int((H + 2 * padding - field_height) / stride + 1)
out_width = int((W + 2 * padding - field_width) / stride + 1)

print(out_height)
print(out_width)

y = np.repeat(np.ones(n_filters*field_height*field_width), C).reshape((n_filters,field_height,field_width,C)).astype(np.int)
y = y.reshape(n_filters, -1)
x = np.repeat(np.ones(N*H*W), C).reshape((N,H,W,C)).astype(np.int)
x = np.transpose(x, (0, 3, 1, 2))
p = padding
x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
i0 = np.repeat(np.arange(field_height), field_width)
i0 = np.tile(i0, C)
i1 = stride * np.repeat(np.arange(out_height), out_width)
j0 = np.tile(np.arange(field_width), field_height * C)
j1 = stride * np.tile(np.arange(out_width), out_height)
i = i0.reshape(-1, 1) + i1.reshape(1, -1)
j = j0.reshape(-1, 1) + j1.reshape(1, -1)
k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
cols = x_padded[:, k, i, j]
C = x.shape[1]
cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
out = y @ cols

print(i0)
print(i1)
print(j0)
print(j1)
print("-------")
print(k)
print(i)
print(j)
print(C)
print(cols)
print(cols.shape)
print(y.shape)
print(out.shape)
print(out)
print("-------")
out = out.reshape(n_filters, out_height, out_width, N)
print(out)
print("-------")
out = out.transpose(3, 0, 1, 2)
print(out)

In [0]:
x = np.arange(12).reshape((2,2,3))
print(x)
print(x.shape)
y = np.transpose(x, (2, 0, 1))
print(y)
print(y.shape)

x = np.arange(12).reshape((2,2,3))
print(x)
print(x.shape)
y = np.moveaxis(x, 2, 0)
print(y)
print(y.shape)

[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]
(2, 2, 3)
[[[ 0  3]
  [ 6  9]]

 [[ 1  4]
  [ 7 10]]

 [[ 2  5]
  [ 8 11]]]
(3, 2, 2)
[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]
(2, 2, 3)
[[[ 0  3]
  [ 6  9]]

 [[ 1  4]
  [ 7 10]]

 [[ 2  5]
  [ 8 11]]]
(3, 2, 2)


In [39]:
i = np.random.randint(5, size=(2, 1, 3, 3))
print(i)
#print(i.transpose(1, 0, 2, 3))
p = batch_normalization(i)
print(p.shape)
print(p)
#print(p.transpose(1, 0, 2, 3))


[[[[2 2 3]
   [4 3 4]
   [4 0 2]]]


 [[[1 4 0]
   [2 4 0]
   [2 2 4]]]]
(2, 1, 3, 3)
(1, 2, 3, 3)
2.388888888888889
[[1.04012346 1.37345679 3.04012346]
 [1.37345679 1.4845679  4.15123457]
 [1.37345679 2.92901235 1.37345679]]
(1, 2, 3, 3)
(2, 1, 3, 3)
[[[[-0.3813143  -0.33183182  0.35048914]
   [ 1.37473184  0.50155683  0.79074573]
   [ 1.37473184 -1.39583906 -0.33183182]]]


 [[[-1.36183677  1.37473184 -1.37009392]
   [-0.33183182  1.32228618 -1.17248505]
   [-0.33183182 -0.22722962  1.37473184]]]]
