In [173]:
import jax.scipy as jsp
import jax.numpy as jnp
from jax import random
from jax import lax
import jax
import numpy as np
from functools import partial
from jax import make_jaxpr

import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = ['svg']
plt.style.use('fivethirtyeight')

from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Compose, Resize, RandomHorizontalFlip, Normalize, RandomCrop

import warnings
warnings.filterwarnings("ignore")

## Datasets

In [63]:


# CIFAR-10
class CIFAR10:
    def __init__(self, batch_size=64) -> None:
        # Get datasets
        train_transforms = Compose([RandomCrop(32, padding=4), RandomHorizontalFlip(), ToTensor(), Normalize(mean=(0.4246, 0.4149, 0.3839), std=(0.2828,0.2779, 0.2844))])
        valid_transforms = Compose([ToTensor(), Normalize(mean=(0.4942, 0.4851, 0.4504), std=(0.2467,0.2429, 0.2616))])
        self.train = datasets.CIFAR10(root='data', train=True, transform=train_transforms)
        self.val = datasets.CIFAR10(root='data', train=False, transform=valid_transforms)

        # Create data loaders
        self.train_dl = DataLoader(self.train, batch_size = batch_size, shuffle=True)
        self.val_dl = DataLoader(self.val, batch_size = batch_size, shuffle=True)


In [65]:
# Check out dataset
def plot_data_check(data, labels, title):
    plt.figure(figsize=(10, 10))
    x, y = next(iter(data.train_dl))
    print (x)

    for i in range(64):

        plt.subplot(8, 8, i+1)
        img = x.permute(0, 2, 3, 1)
        plt.imshow(img[i])
        plt.title(labels[y[i].item()], fontsize=9)
        plt.axis('off')

# CIFAR-10
#data = CIFAR10()
data = CIFAR10(batch_size=64)

labels = ['plane', 'car', 'bird', 'cat', 'deer',
          'dog', 'frog', 'horse', 'ship', 'truck']
len(data.train), len(data.val), len(data.train_dl), len(data.val_dl)
#plot_data_check(data, labels, "MNIST")

data = CIFAR10(batch_size=4096*4*2*2)


for x, y in data.train_dl:
    print (x[:,0,:,:].mean())
    print (x[:,1,:,:].mean())
    print (x[:,2,:,:].mean())
    print (x[:,0,:,:].std())
    print (x[:,1,:,:].std())
    print (x[:,2,:,:].std())

for x, y in data.val_dl:
    print (x[:,0,:,:].mean())
    print (x[:,1,:,:].mean())
    print (x[:,2,:,:].mean())
    print (x[:,0,:,:].std())
    print (x[:,1,:,:].std())
    print (x[:,2,:,:].std())



tensor(-0.0001)
tensor(-0.0003)
tensor(-0.0003)
tensor(0.9999)
tensor(0.9997)
tensor(1.0001)
tensor(5.7946e-05)
tensor(0.0001)
tensor(3.4803e-05)
tensor(0.9998)
tensor(1.0000)
tensor(1.0000)


In [64]:
data = CIFAR10(batch_size=64)
x, y = next(iter(data.train_dl))

In [47]:
x.shape

(64,)

## Jax Classes

In [295]:
class Module:
    def __init__(self) -> None:
        self.state = {}
        self.params = {}

    def init(self, key, x):
        raise NotImplementedError

    def forward(self, params, x):
        raise NotImplementedError

    def __call__(self, params, x):
        return self.forward(params, x)


class Linear(Module):
    def __init__(self, out_features, bias=True) -> None:
        super().__init__()
        self.out_features = out_features
        self.bias = bias

    def init(self, key, x):
        # Get keys
        self.in_features = x.shape[-1]
        key, w_key, b_key = jax.random.split(key, num=3)
        self.params['weights'] = jax.random.normal(
            w_key, ( self.in_features, self.out_features))/self.in_features**0.5

        if self.bias:
            self.params['bias'] = jax.random.normal(
                b_key, (1, self.out_features))/self.in_features**0.5

        return self.params, key

    @partial(jax.jit, static_argnames=['self'])
    def forward(self, params, x):

        out = x @ params['weights']
        if self.bias:
            out += params['bias']
            return out

        return out

    @partial(jax.jit, static_argnames=['self'])
    def loss_fn(self, params, x, y):
        logits = self(params, x)
        return -jax.nn.log_softmax(logits)[jnp.arange(logits.shape[0]), y].mean()


In [296]:
batch_size = 512
data = CIFAR10(batch_size=batch_size)
max_epochs = 25

model = Linear(10)
x, y = next(iter(data.train_dl))
x, y = jnp.array(x).reshape(x.shape[0], -1), jnp.array(y)
key = random.PRNGKey(1701)
params, key = model.init(key, x)

for epoch in range(1,max_epochs+1):
    for step, batch in enumerate(data.train_dl):
        x, y = batch
        x, y = jnp.array(x).reshape(x.shape[0], -1), jnp.array(y)

        loss, grads =  jax.value_and_grad(model.loss_fn)(params, x, y)

        params = jax.tree_map(lambda p, g: p1 * g, params, grads)

    if epoch % 1 == 0:
        print (epoch, loss)


    



1 2.2513404
2 2.0467536
3 2.1063917
4 2.036028
5 1.966125
6 1.9665849
7 1.9464259
8 1.8994087
9 2.0194066
10 2.0090256
11 1.9160752
12 1.9348012
13 1.9261287
14 1.9602569
15 1.9714416
16 1.9760994
17 1.8773366
18 1.934239
19 1.9697044
20 1.9714535
21 1.8590908
22 1.940966
23 1.9363087
24 1.899105
25 1.8722675


Array(20.937744, dtype=float32)

In [250]:
logits.shape

(16, 10)

In [196]:
batch_size = 64
data = CIFAR10(batch_size=batch_size)

lin1 = Linear(10, bias=True)
x, y = next(iter(data.train_dl))
x = jnp.array(x).reshape(batch_size, -1)

key = random.PRNGKey(1701)
params, key = lin1.init(key, x)
#x = jnp.array(x)

#x.shape

In [198]:
out = lin1(params,  x)
out[0]

Array([  9.400408 ,  26.034424 ,  31.0158   , -18.16154  ,  61.345306 ,
       -26.92185  , -48.455643 ,  53.08465  ,  -4.6999598,  30.587025 ],      dtype=float32)

In [102]:
x.reshape(64,-1).shape

(64, 3072)

In [199]:
print(make_jaxpr(lin1.forward)(params, x))

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,10][39m b[35m:f32[3072,10][39m c[35m:f32[64,3072][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[64,10][39m = xla_call[
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; e[35m:f32[1,10][39m f[35m:f32[3072,10][39m g[35m:f32[64,3072][39m. [34m[22m[1mlet
          [39m[22m[22mh[35m:f32[64,10][39m = dot_general[
            dimension_numbers=(((1,), (0,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] g f
          i[35m:f32[64,10][39m = add h e
        [34m[22m[1min [39m[22m[22m(i,) }
      name=forward
    ] a b c
  [34m[22m[1min [39m[22m[22m(d,) }


In [200]:
print(make_jaxpr(lin1.forward)(params,  x))

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1,10][39m b[35m:f32[3072,10][39m c[35m:f32[64,3072][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[64,10][39m = xla_call[
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; e[35m:f32[1,10][39m f[35m:f32[3072,10][39m g[35m:f32[64,3072][39m. [34m[22m[1mlet
          [39m[22m[22mh[35m:f32[64,10][39m = dot_general[
            dimension_numbers=(((1,), (0,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] g f
          i[35m:f32[64,10][39m = add h e
        [34m[22m[1min [39m[22m[22m(i,) }
      name=forward
    ] a b c
  [34m[22m[1min [39m[22m[22m(d,) }


In [184]:
for i in range(100000):
    out = lin1(params, state, x)


In [188]:
for i in range(100000):
    out = lin1(params, state, x)

In [208]:
x, y = next(iter(data.train_dl))


In [210]:
x.shape[0]

64