<a href="https://colab.research.google.com/github/ggyppsyy/colab_experiments/blob/master/Jax_DC_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#Sources
#https://towardsdatascience.com/understanding-batch-normalization-with-examples-in-numpy-and-tensorflow-with-interactive-code-7f59bb126642
#https://wiseodd.github.io/techblog/2016/07/16/convnet-conv-layer/
#https://github.com/huyouare/CS231n/blob/master/assignment2/cs231n/im2col.py
#https://colab.research.google.com/github/google/jax/blob/master/notebooks/neural_network_and_data_loading.ipynb?authuser=1&hl=en#scrollTo=7APc6tD7TiuZ
#https://github.com/pytorch/examples/blob/master/dcgan/main.py
#https://arxiv.org/abs/1511.06434

In [2]:
!pip install --upgrade -q git+https://github.com/google/jax.git
!pip install --upgrade -q jaxlib

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

  Building wheel for jax (setup.py) ... [?25l[?25hdone


In [0]:
import numpy as onp
from torch.utils import data
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import time

In [0]:
def leaky_relu(x):
    y1 = ((x > 0) * x)                                                 
    y2 = ((x <= 0) * x * 0.01)                                         
    return y1 + y2 

def sigmoid(x):
    return 1 / (1. + np.exp(-1 * x))

In [0]:
def channel_normalization(b):
    batch_mean = np.mean(b)
    print(batch_mean)
    batch_var = np.sum((b-batch_mean) ** 2, axis=0) / b.shape[0]
    print(batch_var)
    return (b-batch_mean) / ( (batch_var + 1e-8) ** 0.5 )

def batch_normalization(batch):
    print(batch.shape)
    norm_batch = []
    batch = batch.transpose((1, 0, 2, 3)).astype(onp.float)
    print(batch.shape)
    for i in range(batch.shape[0]):
        norm_batch.append(channel_normalization(batch[i]))
    norm_batch = np.stack(norm_batch)
    print(batch.shape)
    return norm_batch.transpose((1, 0, 2, 3))

In [0]:
def tile(array, height):
    arrays = array
    if (type(height)==int):
        height = [height]
    height = np.array(height)
    if (len(height.shape)==0):
        height = np.expand_dims(height, 0)
    height = list(height)
    shape_h = len(height)
    for p in range(shape_h-1,-1,-1):
        height_p = int(height[p])
        for i in range(height_p-1):
            arrays = np.concatenate([arrays, array], axis=0)
        arrays = np.expand_dims(arrays, axis=0)    
        array = arrays
    return np.squeeze(np.array(arrays),axis=0)

def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    # First figure out what the size of the output should be
    N, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    assert (W + 2 * padding - field_height) % stride == 0
    out_height = int((H + 2 * padding - field_height) / stride + 1)
    out_width = int((W + 2 * padding - field_width) / stride + 1)
    
    
    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = tile(i0, C)
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    j0 = tile(np.arange(field_width), field_height * C)
    j1 = stride * tile(np.arange(out_width), out_height)
    i = i0.reshape((-1, 1)) + i1.reshape((1, -1))
    j = j0.reshape((-1, 1)) + j1.reshape((1, -1))

    k = np.repeat(np.arange(C), field_height * field_width).reshape((-1, 1))

    return (k, i, j)


def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """ An implementation of im2col based on some fancy indexing """
    # Zero-pad the input
    p = padding
    x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant',constant_values=0.0)

    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding,
                                 stride)

    cols = x_padded[:, k, i, j]
    C = x.shape[1]
    print((field_height, field_width, C))
    cols = cols.transpose((1, 2, 0)).reshape((field_height * field_width * C, -1))
    return cols

In [0]:
def conv_forward(X, W, b, stride=1, padding=1):
    #X = np.transpose(X, (0, 3, 1, 2))    
    #cache = W, b, stride, padding
    n_filters, d_filter, h_filter, w_filter = W.shape
    n_x, d_x, h_x, w_x = X.shape
    print((n_filters, d_filter, h_filter, w_filter))
    print((n_x, d_x, h_x, w_x))
    assert (h_x - h_filter + 2 * padding) % stride == 0
    assert (w_x - w_filter + 2 * padding) % stride == 0
    h_out = int((h_x - h_filter + 2 * padding) / stride + 1)
    w_out = int((w_x - w_filter + 2 * padding) / stride + 1)
    
    X_col = im2col_indices(X, h_filter, w_filter, padding=padding, stride=stride)
    W_col = W.reshape((n_filters, -1))
    
    print(X_col.shape)
    print(W_col.shape)
    print(b.shape)
    out = W_col @ X_col + b

    out = out.reshape((n_filters, h_out, w_out, n_x))
    out = out.transpose((3, 0, 1, 2))

    cache = (X, W, b, stride, padding, X_col)

    return [out, cache]

In [0]:
def create_layer(h, w, c, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (c, h, w)), scale * random.normal(b_key, (1,))
    
    
#simplifiy
def create_conv_layer(channels, num_filters, height, width, key, scale=1e-2):
    weights = []
    biases = []
    for n in range(num_filters):
        W,b = create_layer(height, width, channels, key, scale=1e-2)
        weights.append(np.array(W))
        biases.append(np.array(b))
    print("%d filter layer done!" % num_filters)
    return np.stack(weights),np.stack(biases)


def _old_create_layer(h, w, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (h, w)), scale * random.normal(b_key, (1,))
    
    
#simplifiy
def _old_create_conv_layer(channels, num_filters, height, width, key, scale=1e-2):
    weights = []
    biases = []
    for n in range(num_filters):
        c_w = []
        c_b = []
        for c in range(channels):
            W,b = create_layer(height, width, key, scale=1e-2)
            c_w.append(W)
            c_b.append(b)
        weights.append(np.stack(c_w))
        biases.append(np.stack(c_b))
    print("%d filter layer done!" % num_filters)
    return np.stack(weights),np.stack(biases)

In [0]:
def create_discriminator_params(image_shape=(256,256,3),num_filters=8,filter_size=4):
    factors = 0
    assert(image_shape[1]%4==0)
    assert(image_shape[1]==image_shape[2])
    x = int(image_shape[1] / 4)
    for i in range(1, x + 1):
        if x % i == 0: factors+=1
    num_layers = factors-1
    print("Number of layers: %d" % num_layers)
    assert(num_layers>=0)
    params = []
    c = image_shape[0]
    params.append(create_conv_layer(c, 
                                    num_filters, 
                                    filter_size, 
                                    filter_size, 
                                    random.PRNGKey(0)))
    
    for l in range(1, num_layers):
        params.append(create_conv_layer(num_filters*2**(l-1), 
                                        num_filters*2**l, 
                                        filter_size,   
                                        filter_size, 
                                        random.PRNGKey(0)))
        
    params.append(create_conv_layer(num_filters*2**(num_layers-1), 
                                    1, 
                                    filter_size, 
                                    filter_size, 
                                    random.PRNGKey(0)))
    
    print(len(params))
    for i in range(len(params)):
        for j in range(len(params[i])):
            print("%d.%d.%d" % (i,j,len(params[i][j])))
            print(params[i][j].shape)
    return params
    
def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs,_ = conv_forward(activations,w,b,stride=2)
        outputs = batch_normalization(outputs)
        activations = leaky_relu(outputs)
    final_w, final_b = params[-1]
    final_output,_ = conv_forward(activations,final_w,final_b,padding=0)
    return sigmoid(np.squeeze(final_output))
    
shape = [None,256,256,3]
batched_predict = vmap(predict, in_axes=(None, 0))
    
def loss(params, images, targets):
    preds = predict(params, images)
    print(preds.shape)
    print(targets.shape)
    return -np.sum(preds * targets)
    
def accuracy(images, targets):
    predicted_class = np.round(np.ravel(batched_predict(images)))
    return np.mean(predicted_class == target_class)
    
@jit
def update(params, x, y, step_size=0.0001):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

In [10]:
img_s = (1,64,64)
p = transforms.Compose([transforms.Resize((img_s[1],img_s[2])),transforms.ToTensor()])

train_dataset = MNIST('/tmp/mnist/', train=True, download=True, transform=p)
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
train_dataset_array = next(iter(train_loader))[0].numpy().reshape((len(train_dataset.data), img_s[1],img_s[2],img_s[0])).astype(np.float32)
train_dataset_labels = onp.array(next(iter(train_loader))[1])

'''test_dataset = MNIST('/tmp/mnist/', train=False, download=True, transform=p)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
test_dataset_array = next(iter(train_loader))[0].numpy().reshape((len(test_dataset.data), img_s[0],img_s[1],img_s[2])).astype(np.float32)
test_dataset_labels = onp.array(next(iter(train_loader))[1])'''

train_dataset_array = np.transpose(train_dataset_array, (0, 3, 1, 2))
#test_dataset_array = np.transpose(test_dataset_array, (0, 3, 1, 2))

split = int(-0.2*train_dataset_array.shape[0])
test_images = train_dataset_array[split:]
test_labels = train_dataset_labels[split:]
train_images = train_dataset_array[:split]
train_labels = train_dataset_labels[:split]

print(test_images.shape)
print(train_images.shape)

def simple_data_generator(images, labels, batch_size):
    batch = onp.random.randint(images.shape[0], size=batch_size)
    batch_data = images[batch]
    batch_labels = labels[batch]
    return batch_data,batch_labels



(12000, 1, 64, 64)
(48000, 1, 64, 64)


In [11]:
num_epochs = 5
batch_size = 64
steps_per_epoch = train_images.shape[0] // batch_size
params = create_discriminator_params(image_shape=img_s)

Number of layers: 4
8 filter layer done!
16 filter layer done!
32 filter layer done!
64 filter layer done!
1 filter layer done!
5
0.0.8
(8, 1, 4, 4)
0.1.8
(8, 1)
1.0.16
(16, 8, 4, 4)
1.1.16
(16, 1)
2.0.32
(32, 16, 4, 4)
2.1.32
(32, 1)
3.0.64
(64, 32, 4, 4)
3.1.64
(64, 1)
4.0.1
(1, 64, 4, 4)
4.1.1
(1, 1)


In [0]:
print("lets-a-go!")
for epoch in range(num_epochs):
    start_time = time.time()
    for step in range(steps_per_epoch):
        x, y = simple_data_generator(train_images, train_labels, batch_size)
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = discrim.accuracy(simple_data_generator(train_images,train_labels,batch_size))
    test_acc = discrim.accuracy(simple_data_generator(test_images,test_labels,batch_size))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

lets-a-go!
(8, 1, 4, 4)
(64, 1, 64, 64)
(4, 4, 1)
(16, 65536)
(8, 16)
(8, 1)
(64, 8, 32, 32)
(8, 64, 32, 32)
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[32,32])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[32,32])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[32,32])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[32,32])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[32,32])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[32,32])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[32,32])>with<JVPTrace(level=1/1)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=1/1)>
Traced<ShapedArr

In [0]:
class Generator():
    def __init__(self):
        self.step_size = 0.0001
        self.image_shape = (224,224,3)
        self.params = []
        
    def predict():
        pass
    
    def batched_predict():
        pass
    
    def loss(params, images, targets):
        preds = batched_predict(params, images)
        return -np.sum(preds * targets)
    
    @jit
    def update(params, x, y):
        grads = grad(loss)(params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

In [0]:
import numpy as np
n_filters = 2
field_height=2
field_width=2
padding=1
stride=1
N, C, H, W = (3,6,3,3)
assert (H + 2 * padding - field_height) % stride == 0
assert (W + 2 * padding - field_height) % stride == 0
out_height = int((H + 2 * padding - field_height) / stride + 1)
out_width = int((W + 2 * padding - field_width) / stride + 1)

print(out_height)
print(out_width)

y = np.repeat(np.ones(n_filters*field_height*field_width), C).reshape((n_filters,field_height,field_width,C)).astype(np.int)
y = y.reshape(n_filters, -1)
x = np.repeat(np.ones(N*H*W), C).reshape((N,H,W,C)).astype(np.int)
x = np.transpose(x, (0, 3, 1, 2))
p = padding
x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
i0 = np.repeat(np.arange(field_height), field_width)
i0 = np.tile(i0, C)
i1 = stride * np.repeat(np.arange(out_height), out_width)
j0 = np.tile(np.arange(field_width), field_height * C)
j1 = stride * np.tile(np.arange(out_width), out_height)
i = i0.reshape(-1, 1) + i1.reshape(1, -1)
j = j0.reshape(-1, 1) + j1.reshape(1, -1)
k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
cols = x_padded[:, k, i, j]
C = x.shape[1]
cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
out = y @ cols

print(i0)
print(i1)
print(j0)
print(j1)
print("-------")
print(k)
print(i)
print(j)
print(C)
print(cols)
print(cols.shape)
print(y.shape)
print(out.shape)
print(out)
print("-------")
out = out.reshape(n_filters, out_height, out_width, N)
print(out)
print("-------")
out = out.transpose(3, 0, 1, 2)
print(out)

In [0]:
import numpy as np
n = [8,7,6]
s = 0.5
print(np.tile(n,s))
print(" ")

def tile(array, height):
    arrays = array
    if (type(height)==int):
        height = [height]
    elif (type(height)==np.ndarray):
        height = height.tolist()
    assert(type(height)==list)
    shape_h = len(height)
    for p in range(shape_h-1,-1,-1):
        for i in range(height[p]-1):
            arrays = np.concatenate([arrays, array], axis=0)
        arrays = np.expand_dims(arrays, axis=0)    
        array = arrays
    return np.squeeze(np.array(arrays),axis=0)

print(tile(n,s).astype(int))

In [0]:
x = np.arange(12).reshape((2,2,3))
print(x)
print(x.shape)
y = np.transpose(x, (2, 0, 1))
print(y)
print(y.shape)

x = np.arange(12).reshape((2,2,3))
print(x)
print(x.shape)
y = np.moveaxis(x, 2, 0)
print(y)
print(y.shape)

In [0]:
i = np.random.randint(5, size=(2, 1, 3, 3))
print(i)
#print(i.transpose(1, 0, 2, 3))
p = batch_normalization(i)
print(p.shape)
print(p)
#print(p.transpose(1, 0, 2, 3))


In [0]:
height = np.array([5,0,0,0]).astype(np.int64)
height = 5
if (type(height)==int):
    print("here")
    height = [5]
y = np.array(height)
print(type(y))
print(len(y.shape))
y = np.expand_dims(y, 0)
print(y)