<a href="https://colab.research.google.com/github/ggyppsyy/colab_experiments/blob/master/Jax_DC_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Optional Colab setup (disabled): install a jaxlib wheel built for the runtime's
# CUDA version (derived from $CUDA_VERSION), then upgrade jax itself.
# Left commented out; the cells that follow in this part of the notebook only use NumPy.
#!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.12-cp36-none-linux_x86_64.whl
#!pip install --upgrade -q jax

In [0]:
import numpy as np

In [0]:
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """Build the index arrays that gather every receptive field of a padded input.

    Args:
        x_shape: 4-tuple ``(N, C, H, W)`` describing the *unpadded* input.
        field_height: height of the sliding window.
        field_width: width of the sliding window.
        padding: number of zero rows/columns added on each spatial side.
        stride: step of the sliding window along both spatial axes.

    Returns:
        Tuple ``(k, i, j)`` of integer arrays:
        ``k`` has shape ``(C * field_height * field_width, 1)`` (channel index),
        ``i`` and ``j`` have shape
        ``(C * field_height * field_width, out_height * out_width)`` (row and
        column index into the padded input).

    Raises:
        AssertionError: if the window does not tile the padded input exactly
            for the given stride.
    """
    _, C, H, W = x_shape
    # Each spatial axis must be checked against its own window extent.
    assert (H + 2 * padding - field_height) % stride == 0
    assert (W + 2 * padding - field_width) % stride == 0
    # Integer division keeps the sizes as ints: they are used as repeat counts
    # and end up inside index arrays, neither of which accepts floats.
    out_height = (H + 2 * padding - field_height) // stride + 1
    out_width = (W + 2 * padding - field_width) // stride + 1

    # Row offsets inside one window, repeated for every channel.
    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = np.tile(i0, C)
    # Row of the top-left corner of each window position.
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    # Column offsets inside one window, repeated for every channel.
    j0 = np.tile(np.arange(field_width), field_height * C)
    # Column of the top-left corner of each window position.
    j1 = stride * np.tile(np.arange(out_width), out_height)
    # Broadcast (offset within window) + (window origin) -> absolute indices.
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    # Channel index for every row of the column matrix.
    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)

    return (k, i, j)


def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """Unfold the receptive fields of ``x`` into columns using fancy indexing.

    ``x`` is treated as a 4-D array whose last two axes are zero-padded by
    ``padding`` on each side. The result has
    ``field_height * field_width * x.shape[1]`` rows; the number of columns is
    inferred by the final reshape.
    """
    # Pad only the two trailing (spatial) axes with zeros.
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    padded = np.pad(x, pad_spec, mode='constant')

    chan_idx, row_idx, col_idx = get_im2col_indices(
        x.shape, field_height, field_width, padding, stride
    )

    # One gather pulls every window out of every sample at once.
    patches = padded[:, chan_idx, row_idx, col_idx]
    n_channels = x.shape[1]
    n_rows = field_height * field_width * n_channels
    # Move the sample axis last, then flatten it into the column axis.
    return patches.transpose(1, 2, 0).reshape(n_rows, -1)

In [0]:
def conv_forward(X, W, b, stride=1, padding=1):
    """Forward pass of a 2-D convolution implemented as one matrix product.

    Args:
        X: 4-D input; its shape is unpacked as ``(n_x, d_x, h_x, w_x)``.
        W: 4-D filter bank; its shape is unpacked as
            ``(n_filters, d_filter, h_filter, w_filter)``.
        b: bias added to the ``(n_filters, columns)`` product, so it must
            broadcast against that shape (e.g. ``(n_filters, 1)``).
        stride: step of the sliding window.
        padding: zero padding applied on each spatial side.

    Returns:
        ``(out, cache)`` where ``out`` has shape
        ``(n_x, n_filters, h_out, w_out)`` and ``cache`` is the tuple
        ``(X, W, b, stride, padding, X_col)``.

    Raises:
        AssertionError: if the filter does not tile the padded input exactly
            for the given stride.
    """
    n_filters, _, h_filter, w_filter = W.shape
    n_x, _, h_x, w_x = X.shape

    # Padded extent left over after placing the first window, per axis.
    h_span = h_x - h_filter + 2 * padding
    w_span = w_x - w_filter + 2 * padding
    assert h_span % stride == 0
    assert w_span % stride == 0
    h_out = int(h_span / stride + 1)
    w_out = int(w_span / stride + 1)

    # Unfold the input into columns and flatten each filter into a row.
    X_col = im2col_indices(X, h_filter, w_filter, padding=padding, stride=stride)
    W_col = W.reshape(n_filters, -1)

    flat = W_col @ X_col + b
    # Columns are laid out as (h_out, w_out, n_x); bring the sample axis first.
    out = flat.reshape(n_filters, h_out, w_out, n_x).transpose(3, 0, 1, 2)

    return out, (X, W, b, stride, padding, X_col)

In [4]:
# Step-by-step walk-through of the im2col convolution on a tiny all-ones input,
# printing every intermediate index array so the layout can be inspected.
n_filters = 2
field_height = 2
field_width = 2
padding = 1
stride = 1
N, C, H, W = (3, 1, 3, 3)
# Each spatial axis is checked against its own window extent.
assert (H + 2 * padding - field_height) % stride == 0
assert (W + 2 * padding - field_width) % stride == 0
out_height = (H + 2 * padding - field_height) // stride + 1
out_width = (W + 2 * padding - field_width) // stride + 1

print(out_height)
print(out_width)

# All-ones filters, built channels-last and then flattened to one row per filter.
# The builtin int replaces the np.int alias, which NumPy removed in 1.24.
y = np.ones((n_filters, field_height, field_width, C), dtype=int)
y = y.reshape(n_filters, -1)
# All-ones images, built channels-last and moved to (N, C, H, W).
x = np.ones((N, H, W, C), dtype=int)
x = np.transpose(x, (0, 3, 1, 2))
p = padding
x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
# Row/column offsets inside one window (i0, j0) and window origins (i1, j1).
i0 = np.repeat(np.arange(field_height), field_width)
i0 = np.tile(i0, C)
i1 = stride * np.repeat(np.arange(out_height), out_width)
j0 = np.tile(np.arange(field_width), field_height * C)
j1 = stride * np.tile(np.arange(out_width), out_height)
i = i0.reshape(-1, 1) + i1.reshape(1, -1)
j = j0.reshape(-1, 1) + j1.reshape(1, -1)
k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
# Gather every window of every image, then flatten to (window size, positions * N).
cols = x_padded[:, k, i, j]
C = x.shape[1]
cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
out = y @ cols

print(i0)
print(i1)
print(j0)
print(j1)
print("-------")
print(k)
print(i)
print(j)
print(C)
print(cols)
print(cols.shape)
print(y.shape)
print(out.shape)
print(out)
print("-------")
# Columns are ordered (out_height, out_width, N); un-flatten accordingly.
out = out.reshape(n_filters, out_height, out_width, N)
print(out)
print("-------")
# Bring the sample axis first: (N, n_filters, out_height, out_width).
out = out.transpose(3, 0, 1, 2)
print(out)

4
4
[0 0 1 1]
[0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]
[0 1 0 1]
[0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3]
-------
[[0]
 [0]
 [0]
 [0]]
[[0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]
 [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]
 [1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4]
 [1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4]]
[[0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3]
 [1 2 3 4 1 2 3 4 1 2 3 4 1 2 3 4]
 [0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3]
 [1 2 3 4 1 2 3 4 1 2 3 4 1 2 3 4]]
1
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1
  0 0 0 1 1 1 1 1 1 1 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0
  1 1 1 1 1 1 1 1 1 0 0 0]
 [0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1
  0 0 0 0 0 0 0 0 0 0 0 0]
 [1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0]]
(4, 48)
(2, 4)
(2, 48)
[[1 1 1 2 2 2 2 2 2 1 1 1 2 2 2 4 4 4 4 4 4 2 2 2 2 2 2 4 4 4 4 4 4 2 2 2
  1 1 1 2 2 2 2 2 2 1 1 1]
 [1 1 1 2 2 2 2 2 2 1 1 1 2 2 2 4 4 4 4 4 4 2 2 2 2 2 2 4 4 4 4 4 

In [5]:
# Demo: moving the last axis to the front, first with an explicit axis
# permutation and then with np.moveaxis; both print the same result.
x = np.arange(12).reshape(2, 2, 3)
print(x)
print(x.shape)
y = x.transpose(2, 0, 1)
print(y)
print(y.shape)

# Same input again, this time relocating axis 2 to position 0.
x = np.arange(12).reshape(2, 2, 3)
print(x)
print(x.shape)
y = np.moveaxis(x, source=2, destination=0)
print(y)
print(y.shape)

[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]
(2, 2, 3)
[[[ 0  3]
  [ 6  9]]

 [[ 1  4]
  [ 7 10]]

 [[ 2  5]
  [ 8 11]]]
(3, 2, 2)
[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]
(2, 2, 3)
[[[ 0  3]
  [ 6  9]]

 [[ 1  4]
  [ 7 10]]

 [[ 2  5]
  [ 8 11]]]
(3, 2, 2)
