In [5]:
# https://www.tensorflow.org/install/source#gpu
# Shell escape: print the version of the installed CUDA compiler.
! nvcc  --version

import os 
# Restrict the process to GPU 0. This is set before `import tensorflow` below.
# NOTE(review): the recorded cell output reports 4 GPUs, which suggests the output
# predates this setting — re-run on a fresh kernel to confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import math
import copy
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa

# Report how many GPUs TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

from src.utils import apply_circular_mask, cartesian_to_polar_grid, polar_transform, polar_transform_inv
from src.model import toeplitz_extractor
from src.visu import add_grad_coloring

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
Num GPUs Available:  4


In [2]:
import wandb
from datetime import datetime

# Experiment hyper-parameters.
dataset_name = 'imagenet'
image_size = 224
batch_size = 128
n_epochs = 10
n_channels = 3
n_classes = 1000

# see polar_transform
# `radius` is half the diagonal of a square image_size x image_size image.
radius = (np.array([image_size, image_size])**2).sum()**0.5/2
# Beam length: the radius rounded to the nearest integer.
len_beam = int(round(radius))
# Number of beams: the circumference 2*pi*radius rounded to the nearest integer.
n_beams = int(round(radius*2*np.pi))
print(len_beam, n_beams)

# Record every hyper-parameter with the run so that W&B runs can be compared
# and reproduced (an empty config stores nothing).
config = {
    "dataset_name": dataset_name,
    "image_size": image_size,
    "batch_size": batch_size,
    "n_epochs": n_epochs,
    "n_channels": n_channels,
    "n_classes": n_classes,
    "len_beam": len_beam,
    "n_beams": n_beams,
}

wandb.init(project="RadialBeams", config=config, group=dataset_name, name=datetime.now().strftime("%m%d-%H%M"))

23 142


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjohschm[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
%matplotlib inline
fig, axs = plt.subplots(3,3)

batch_idx = 1

data = next(iter(test_dataset))
image = data['image'][batch_idx]

rot_image = tfa.image.rotate(image, 90., interpolation='bilinear')
rot_image = tf.cast(rot_image, tf.float32)

axs[0, 0].imshow(image)
axs[1, 0].imshow(rot_image)
axs[2, 0].imshow(image)

polar = data['polar'][batch_idx]
rot_polar = polar_transform(rot_image)
shifted_polar = tf.roll(polar, shift=50, axis=1)

axs[0, 1].imshow(polar)
axs[1, 1].imshow(rot_polar)
axs[2, 1].imshow(shifted_polar)

polar_inv = polar_transform_inv(polar.numpy(), output=(image_size, image_size))
rot_polar_inv = polar_transform_inv(rot_polar.numpy(), output=(image_size, image_size))
shifted_polar_inv = polar_transform_inv(shifted_polar.numpy(), output=(image_size, image_size))

axs[0, 2].imshow(polar_inv)
axs[1, 2].imshow(rot_polar_inv)
axs[2, 2].imshow(rot_polar_inv)

plt.show()

In [None]:
from tensorflow.keras import layers, models

def cyclic_beam_padding(x, padding=1):
    """Wrap-pad ``x`` along axis 2 by ``padding`` entries on each side.

    The last ``padding`` entries of axis 2 are prepended and the first
    ``padding`` entries are appended. The pieces are concatenated along
    ``axis=-2``, which coincides with axis 2 for a rank-4 input.
    """
    axis_len = x.shape[2]
    wrapped_tail = x[:, :, axis_len - padding:]
    wrapped_head = x[:, :, :padding]
    return tf.concat([wrapped_tail, x, wrapped_head], axis=-2)
    
class PolarRegressor(tf.keras.Model):
    """Convolutional model mapping a polar image to one unit-norm score vector per sample.

    The input has shape ``(batch, len_beam, n_beams, n_channels)`` (see the
    ``InputLayer`` of ``encoder``, which additionally accounts for the cyclic
    padding). The output has one value per beam, L2-normalised over the beam
    axis. ``len_beam``, ``n_beams`` and ``n_channels`` are read from the
    notebook's global configuration.
    """

    def __init__(self):
        super().__init__()
        # Number of beams wrapped around on each side before the convolutions.
        # The six stacked 3x3 'same' convolutions reach 6 beams to either side,
        # so the wrapped values (not the zero padding) are seen at the borders.
        self.padding = 16
        # Channel count produced by the last convolution of `encoder`.
        n_filters_out = 32
        self.encoder = models.Sequential([
            layers.InputLayer(input_shape=(len_beam, n_beams+2*self.padding, n_channels)),

            layers.Conv2D(32, (3, 3), activation='gelu', padding='same'),
            layers.Conv2D(64, (3, 3), activation='gelu', padding='same'),
            layers.LayerNormalization(),

            layers.Conv2D(128, (3, 3), activation='gelu', padding='same'),
            layers.Conv2D(128, (3, 3), activation='gelu', padding='same'),
            layers.LayerNormalization(),

            layers.Conv2D(64, (3, 3), activation='gelu', padding='same'),
            layers.Conv2D(n_filters_out, (3, 3), activation='gelu', padding='same'),
            layers.LayerNormalization(),
        ])
        self.encoder2 = models.Sequential([
            # Each beam is described by all radial positions times all channels.
            # Deriving the width from len_beam keeps the model valid for any
            # image size (a literal 2*368 only matches len_beam == 23).
            layers.InputLayer(input_shape=(n_beams, len_beam * n_filters_out)),
            layers.Dense(64, activation='gelu'),
            layers.Dense(1),
            layers.Flatten()
        ])

    def call(self, x):
        """Return an L2-normalised ``(batch, n_beams)`` score tensor for polar input ``x``."""
        if self.padding > 0:
            x = cyclic_beam_padding(x, padding=self.padding)
        z = self.encoder(x)
        # (batch, len_beam, beams, channels) -> (batch, beams, len_beam, channels)
        z = tf.transpose(z, (0, 2, 1, 3))
        # Merge radial and channel axes; -1 keeps this valid for an unknown batch size.
        z = tf.reshape(z, (-1, z.shape[1], z.shape[2] * z.shape[3]))
        if self.padding > 0:
            # NOTE(review): the window starts at padding//2 although `padding` beams
            # were prepended, so output index j corresponds to input beam
            # (j - padding//2) mod n_beams for the even padding used here. This is a
            # constant cyclic offset — confirm it is intended.
            z = z[:, self.padding//2:n_beams+self.padding//2]
        z = self.encoder2(z)
        # z = tf.nn.softmax(z, axis=-1)
        # l2_normalize clamps the squared norm from below, so an all-zero output
        # yields zeros instead of NaN (plain z / ||z|| would divide by zero).
        return tf.math.l2_normalize(z, axis=-1)

# Build the regressor for un-padded polar batches and print its layer summary.
polar_input_shape = (batch_size, len_beam, n_beams, n_channels)
model = PolarRegressor()
model.build(input_shape=polar_input_shape)
model.summary()
# Smoke test: forward an all-zero batch; the output shape is the cell's displayed value.
model(tf.zeros(polar_input_shape)).shape

In [None]:
def roll_batch(tensor, roll_values, axis=2):
    """Roll every sample of ``tensor`` by its own shift and re-stack the results.

    Sample ``i`` (``tensor[i]``) is rolled by ``roll_values[i]``. ``axis`` is
    given in batch coordinates, so the roll on a single sample is applied
    along ``axis - 1``.
    """
    sample_axis = axis - 1
    rolled_samples = []
    for idx in range(tensor.shape[0]):
        rolled_samples.append(tf.roll(tensor[idx], roll_values[idx], axis=sample_axis))
    return tf.stack(rolled_samples)

def cyclic_beam_error(k_pred, k_target):
    """Shortest cyclic distance, in beams, between predicted and target beam indices.

    Both arguments are integer tensors. They are compared modulo ``n_beams``, so
    targets outside ``[0, n_beams)`` wrap around and index 0 is one beam away
    from index ``n_beams - 1``.
    """
    diff = tf.math.floormod(k_pred - k_target, n_beams)
    return tf.minimum(diff, n_beams - diff)


# Training target: a one-hot column vector with its peak at the centre beam.
# A single column suffices and avoids depending on a leftover `image` variable
# from another cell.
target_beam = n_beams // 2
true = np.zeros((n_beams, 1))
true[target_beam] = 1
true = tf.convert_to_tensor(true, dtype=tf.float32)

# Fixed roll (in beams) applied to each training batch for the equivariance check.
equivariance_shift = 100
# NOTE(review): `train_dataset` / `test_dataset` are not defined in the cells
# shown here; they must already exist in the kernel.
steps_per_epoch = len(train_dataset)

optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001)
for e in tqdm(range(n_epochs)):
    # training loop
    for s, sample in enumerate(train_dataset):
        image, label = sample['polar'], sample['label']
        with tf.GradientTape() as tape:
            k_distribution = model(image)
            # cosine of hyper-unit-vectors -> normalization not needed
            loss = 1. - (k_distribution @ true)
        # Gradients are taken after leaving the tape context so the backward
        # pass itself is not recorded.
        grad = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))
        step = e * steps_per_epoch + s
        wandb.log({"training loss": np.mean(loss.numpy())}, step=step)
        # equivariance test loop
        image = tf.roll(image, shift=equivariance_shift, axis=2) # see example above shift to right
        k_true = tf.cast(tf.fill((image.shape[0], ), equivariance_shift), tf.int64)
        # The expected peak is the training target moved by the applied roll,
        # compared cyclically because the sum can exceed n_beams - 1.
        equivariance_error = cyclic_beam_error(tf.argmax(model(image), axis=-1), k_true + target_beam)
        wandb.log({"equivariance test": np.mean(equivariance_error.numpy())}, step=step)
    # pseudo-equivariant test loop
    test_loss = []
    for s, sample in enumerate(test_dataset):
        image, label = sample['polar'], sample['label']
        # One random roll per sample; for odd n_beams, -n_beams//2 floors to
        # -(n_beams//2) - 1, so the range still spans n_beams distinct shifts.
        k = tf.random.uniform(shape=(image.shape[0],), minval=-n_beams//2, maxval=n_beams//2, dtype=tf.int64)
        image = roll_batch(image, k, axis=2)
        k_distribution = model(image)
        test_loss.append(cyclic_beam_error(tf.argmax(k_distribution, axis=-1), k + target_beam).numpy())
    mean_test_error = np.concatenate(test_loss, axis=0).mean()
    wandb.log({"testing loss": mean_test_error})
    print(mean_test_error)
# todo implement load and save of models 

In [None]:
# Visual check on the first training batch only (the loop breaks after one
# iteration): show a sample, roll it, predict a beam index with `model`, and
# try to undo the roll using that prediction.
# NOTE(review): `train_dataset` is not defined in the cells shown here.
fig, axs = plt.subplots(1,4)
for s, sample in enumerate(train_dataset):
    image, label = sample['polar'], sample['label']
    # Panel 0: first sample of the batch passed through polar_transform_inv.
    axs[0].imshow(polar_transform_inv(image[0].numpy()))
    # Roll the whole batch by 50 positions along axis 2.
    image = tf.roll(image, shift=50, axis=2)
    # Panel 1: the rolled first sample.
    axs[1].imshow(polar_transform_inv(image[0].numpy()))
    k_distribution = model(image)
    # Index of the largest model output for each sample.
    k = tf.argmax(k_distribution, axis=-1)
    # Panels 2 and 3 roll the single sample (axis 1) by 93 + k and 93 - k.
    # NOTE(review): 93 is a hard-coded offset whose origin is not visible in
    # this notebook — confirm how it relates to n_beams and the shift above.
    image_undo = tf.roll(image[0], shift=93+k[0], axis=1)
    axs[2].imshow(polar_transform_inv(image_undo.numpy()))
    # The title is the predicted index of the first sample.
    axs[2].set_title(k[0])
    image_undo = tf.roll(image[0], shift=93-k[0], axis=1)
    axs[3].imshow(polar_transform_inv(image_undo.numpy()))
    plt.show()
    break