<a href="https://colab.research.google.com/github/mpnsk/ivy_seminar/blob/main/my_ivy_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install ivy
!pip install dm-haiku

In [None]:
import ivy

In [None]:
# Backend that ivy dispatches to; swap in one of the commented alternatives
# to run the same notebook on a different framework.
my_backend = "torch"
# my_backend = "tensorflow"
# my_backend = "jax"

# Uncomment to place arrays on the CPU by default.
# ivy.set_default_device("cpu")

# When True, the model is compiled with ivy.compile for `my_backend`
# before training; otherwise the eager ivy.Module is used.
doCompile = True

if my_backend == "jax":
    import jax
    # JAX defaults to 32-bit dtypes; this enables 64-bit types.
    jax.config.update('jax_enable_x64', True)

ivy.set_backend(my_backend)

# as of the pytorch mnist example https://github.com/pytorch/examples/blob/main/mnist/main.py
class Net(ivy.Module):
    """Small CNN classifier modelled on the PyTorch MNIST example.

    The model is exercised in this notebook with NHWC input of shape
    (batch, 28, 28, 1). It returns a (batch, 10) array of class
    probabilities, computed by a softmax over the last axis.
    """

    def __init__(self):
        # Conv2D(input_channels, output_channels, filter_shape, strides, padding)
        self.conv1 = ivy.Conv2D(1, 32, [3, 3], 1, "VALID")
        self.conv2 = ivy.Conv2D(32, 64, [3, 3], 1, "VALID")
        # NOTE: despite the names, conv1_drop is applied after max-pooling and
        # conv2_drop after fc1 (dropout1/dropout2 in the PyTorch example).
        # The names are kept so existing references keep working.
        self.conv1_drop = ivy.Dropout(0.25)
        self.conv2_drop = ivy.Dropout(0.5)
        # 9216 = 12 * 12 * 64: the flattened size after two VALID 3x3 convs
        # (28 -> 26 -> 24) and a 2x2 max-pool (24 -> 12) on 28x28 input.
        self.fc1 = ivy.Linear(9216, 128)
        self.fc2 = ivy.Linear(128, 10)
        # Sub-layers are assigned before super().__init__(); keep this order.
        # TODO confirm ivy.Module needs them present when it builds self.v.
        super().__init__()

    def _forward(self, x):
        # Shape comments assume (N, 28, 28, 1) NHWC input.
        # Activations use the functional form instead of in-place `out=`
        # writes, which are fragile under functional backends (e.g. JAX)
        # and graph compilation.
        x = ivy.relu(self.conv1(x))             # (N, 26, 26, 32)
        x = ivy.relu(self.conv2(x))             # (N, 24, 24, 64)
        x = ivy.max_pool2d(x, 2, 2, 'VALID')    # (N, 12, 12, 64)
        x = self.conv1_drop(x)
        x = ivy.flatten(x, start_dim=1)         # (N, 9216)
        x = ivy.relu(self.fc1(x))               # (N, 128)
        x = self.conv2_drop(x)
        x = self.fc2(x)                         # (N, 10)
        # Class probabilities, which loss_fn feeds to ivy.sparse_cross_entropy.
        # The axis is explicit so the result never depends on a backend's
        # default softmax axis.
        return ivy.softmax(x, axis=-1)


def loss_fn(v, x, y):
    """Return the mean sparse cross-entropy of the global `model` on one batch.

    Args:
        v: variable container to evaluate the model with. This is the
            argument ivy.execute_with_gradients differentiates with respect to.
        x: input batch.
        y: integer class labels for the batch.
    """
    # Evaluate the model with the supplied variables rather than its stored
    # model.v, so the loss is actually a function of `v`. Functional backends
    # such as JAX only differentiate through values derived from the argument.
    # NOTE(review): confirm the compiled model (doCompile=True) accepts `v=`.
    entropy = ivy.sparse_cross_entropy(y, model(x, v=v))
    return entropy.mean()


# Either compile the network for the selected backend with ivy.compile,
# using a random example input of shape (1, 28, 28, 1), or use it eagerly.
if doCompile:
    print("starting compilation")
    model = ivy.compile(Net(), to=my_backend, args=(ivy.random_normal(shape=(1, 28, 28, 1)),))
else:
    model = Net()
print("done")

starting compilation
done


In [None]:
print('loading mnist')
from keras.datasets import mnist

# Only the training split is used for now; the test split is discarded.
(train_X_np, train_y_np), (test_X, test_y) = mnist.load_data()
del test_X, test_y

# Train on a small prefix of the 60,000 samples while prototyping.
limit = 320
train_X_np = train_X_np[:limit]
train_y_np = train_y_np[:limit]

# Mini-batch settings; a trailing partial batch, if any, is dropped.
batch_size = 32
num_batches = limit // batch_size

print('starting to batch')
# Convert to ivy arrays, scale grey levels from 0-255 to 0.0-1.0, and add a
# trailing channel axis (NHW -> NHWC).
train_X_ivy = ivy.expand_dims(ivy.array(train_X_np) / 255, axis=-1)
train_y_ivy = ivy.array(train_y_np)

# Split the full arrays into consecutive batches of `batch_size` samples.
train_X_batches = [
    train_X_ivy[start:start + batch_size]
    for start in range(0, num_batches * batch_size, batch_size)
]
train_y_batches = [
    train_y_ivy[start:start + batch_size]
    for start in range(0, num_batches * batch_size, batch_size)
]
print('done')

loading mnist
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
starting to batch
done


In [None]:
# Step size for the plain gradient-descent update below.
learning_rate = 0.005

for idx in range(num_batches):
    # Loss on this batch and its gradients with respect to the model variables.
    loss, grads = ivy.execute_with_gradients(
        lambda v: loss_fn(v, train_X_batches[idx], train_y_batches[idx]),
        model.v,
    )
    print(f'{idx=}, {loss=}')
    # Stop if every gradient is identically zero, i.e. nothing flowed back to
    # the variables. `any()` is False for a leaf only when all its entries are
    # zero. The previous `all()` check also fired whenever each leaf merely
    # contained a single zero entry.
    if grads.any().cont_all_false():
        print("ERROR!")
        print(f'{grads.conv1.w=}')
        print(f'{grads.conv1.b=}')
        print(f'{grads.conv2.w=}')
        print(f'{grads.conv2.b=}')
        print(f'{grads.fc1.w=}')
        print(f'{grads.fc1.b=}')
        print(f'{grads.fc2.w=}')
        print(f'{grads.fc2.b=}')
        break
    model.v = model.v - grads * learning_rate


# Spot-check one fixed training sample (not a held-out evaluation).
# NOTE(review): dropout layers may still be active here; confirm whether the
# model needs switching to inference mode for predictions.
random_index = 3

x = train_X_ivy[random_index]
# expand hwc to nhwc
x = ivy.expand_dims(x, axis=0)
prediction = model(x)
print(f'{prediction=}')
print(f'{ivy.argmax(prediction)=}')
print(f'{train_y_ivy[random_index]=}')

idx=0, loss=ivy.array(3.5656419)
idx=1, loss=ivy.array(4.7978497)
idx=2, loss=ivy.array(3.5605125)
idx=3, loss=ivy.array(3.1029253)
idx=4, loss=ivy.array(3.1441267)
idx=5, loss=ivy.array(2.9305246)
idx=6, loss=ivy.array(2.9057436)
idx=7, loss=ivy.array(1.7976189)
idx=8, loss=ivy.array(1.1552026)
idx=9, loss=ivy.array(1.3361251)
prediction=tensor([[0.0169, 0.1191, 0.0124, 0.0245, 0.0169, 0.0825, 0.0679, 0.0223, 0.6305,
         0.0070]], grad_fn=<SoftmaxBackward0>)
ivy.argmax(prediction)=ivy.array(8)
train_y_ivy[random_index]=ivy.array(1)
