In [1]:
from torchvision import transforms
from torchvision.datasets import Omniglot
import torch
from torch.utils.data import DataLoader

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

In [None]:
!pip install flax
import flax
import flax.linen as nn
import optax

In [None]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,),(0.5,),)])
background_set = Omniglot(root='./data', background=True, download=True, transform=transform)
evaluation_set = Omniglot(root='./data', background=False, download=True, transform=transform)

In [13]:
batch_size = 10
background_set_loader = DataLoader(background_set, batch_size = batch_size)
evaluation_set_loader = DataLoader(evaluation_set, batch_size = batch_size)

In [6]:
class SiameseNet(nn.Module):
  def setup(self):
    self.conv1 = nn.Conv(features=64, kernel_size=(10,10), name="Conv2D1")
    self.conv2 = nn.Conv(features=128, kernel_size=(7,7), name="Conv2D2")
    self.conv3 = nn.Conv(features=128, kernel_size=(4,4), name="Conv2D3")
    self.conv4 = nn.Conv(features=256, kernel_size=(4,4), name="Conv2D4")
    self.encoder = nn.Dense(4096, name="Encoder")
    self.liner = nn.Dense(1, name="1DLayer")

  def __call__(self,input1,input2):
    x1 = nn.relu(self.conv1(input1))
    x1 = nn.max_pool(x1, window_shape=(2, 2), strides=(2,2))
    x1 = nn.relu(self.conv2(x1))
    x1 = nn.max_pool(x1, window_shape=(2, 2), strides=(2,2))
    x1 = nn.relu(self.conv3(x1))
    x1 = nn.max_pool(x1, window_shape=(2, 2), strides=(2,2))
    x1 = nn.relu(self.conv4(x1))
    x1 = x1.reshape((x1.shape[0], -1)) 
    x1 = nn.sigmoid(self.encoder(x1)) 
 
    x2 = nn.relu(self.conv1(input2))
    x2 = nn.max_pool(x2, window_shape=(2, 2), strides=(2,2))
    x2 = nn.relu(self.conv2(x2))
    x2 = nn.max_pool(x2, window_shape=(2, 2), strides=(2,2))
    x2 = nn.relu(self.conv3(x2))
    x2 = nn.max_pool(x2, window_shape=(2, 2), strides=(2,2))
    x2 = nn.relu(self.conv4(x2))
    x2 = x2.reshape((x2.shape[0], -1)) 
    x2 = nn.sigmoid(self.encoder(x2)) 

    l1_dist = jnp.abs(x1 - x2) 
    return self.liner(l1_dist)

In [14]:
model = SiameseNet()

In [15]:
seed = 42
key = jax.random.PRNGKey(seed)
shape = [batch_size, 105, 105]
key, subkey = jax.random.split(key)
x1 = jax.random.normal(subkey, shape=shape)
 
key, subkey = jax.random.split(key)
x2 = jax.random.normal(subkey, shape=shape)
 
key, subkey = jax.random.split(key)

In [16]:
variables = model.init(subkey, x1, x2)

In [17]:
def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = model.apply(params, batch, batch)
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)
  return loss_value.mean()

In [18]:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(background_set_loader):
    batch = jnp.array(batch.reshape(batch_size,105,105))
    labels = jnp.array(labels)
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

In [None]:
optimizer = optax.adam(learning_rate=0.00001)
variables = fit(variables, optimizer)

step 0, loss: 6.931471824645996
step 100, loss: 6.300234794616699
step 200, loss: 4.072996139526367
step 300, loss: 0.21341514587402344
step 400, loss: -5.263023376464844
