In [1]:
from absl import logging
import tensorflow as tf
#import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import collections
import numpy as np
import tensorflow.keras.layers as layers
from tqdm import tqdm

In [2]:
"""Slot Attention model for object discovery and set prediction."""

class SlotAttention(layers.Layer):
  """Slot Attention module."""

  def __init__(self, num_iterations, num_slots, slot_size, mlp_hidden_size,
               epsilon=1e-8):
    """Builds the Slot Attention module.
    Args:
      num_iterations: Number of iterations.
      num_slots: Number of slots.
      slot_size: Dimensionality of slot feature vectors.
      mlp_hidden_size: Hidden layer size of MLP.
      epsilon: Offset for attention coefficients before normalization.
    """
    super().__init__()
    self.num_iterations = num_iterations
    self.num_slots = num_slots
    self.slot_size = slot_size
    self.mlp_hidden_size = mlp_hidden_size
    self.epsilon = epsilon

    self.norm_inputs = layers.LayerNormalization()
    self.norm_slots = layers.LayerNormalization()
    self.norm_mlp = layers.LayerNormalization()

    # Parameters for Gaussian init (shared by all slots).   # Intialize slots randomly at first 
    self.slots_mu = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],   # slot_size: Dimensionality of slot feature vectors.
        dtype=tf.float32,
        name="slots_mu")
    self.slots_log_sigma = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],
        dtype=tf.float32,
        name="slots_log_sigma")

    # Linear maps for the attention module.
    self.project_q = layers.Dense(self.slot_size, use_bias=False, name="q")
    self.project_k = layers.Dense(self.slot_size, use_bias=False, name="k")
    self.project_v = layers.Dense(self.slot_size, use_bias=False, name="v")

    # Slot update functions.
    self.gru = layers.GRUCell(self.slot_size)
    self.mlp = tf.keras.Sequential([
        layers.Dense(self.mlp_hidden_size, activation="relu"),
        layers.Dense(self.slot_size)
    ], name="mlp")

  def call(self, inputs):
    # `inputs` has shape [batch_size, num_inputs, inputs_size].
    inputs = self.norm_inputs(inputs)  # Apply layer norm to the input.
    k = self.project_k(inputs)  # Shape: [batch_size, num_inputs, slot_size].  # create key vectors (based on inputs)
    v = self.project_v(inputs)  # Shape: [batch_size, num_inputs, slot_size].  # create value vectors (based on inputs)

    # Initialize the slots. Shape: [batch_size, num_slots, slot_size].
    slots = self.slots_mu + tf.exp(self.slots_log_sigma) * tf.random.normal(
        [tf.shape(inputs)[0], self.num_slots, self.slot_size])  # size: [batch_size, num_slots, slot_size]

    # Multiple rounds of attention.
    for _ in range(self.num_iterations):
      slots_prev = slots
      slots = self.norm_slots(slots)

      # Attention.
      q = self.project_q(slots)  # Shape: [batch_size, num_slots, slot_size].  # create query vectors (based on slots)
      q *= self.slot_size ** -0.5  # Normalization.
      attn_logits = tf.keras.backend.batch_dot(k, q, axes=-1) # Batchwise dot product.
      attn = tf.nn.softmax(attn_logits, axis=-1)
      # `attn` has shape: [batch_size, num_inputs, num_slots]. 
      # attn represents how much attention each slot should pay to the features 

      # Weigted mean.
      attn += self.epsilon
      attn /= tf.reduce_sum(attn, axis=-2, keepdims=True) # summation; sum across the batch_size 
      updates = tf.keras.backend.batch_dot(attn, v, axes=-2)
      # `updates` has shape: [batch_size, num_slots, slot_size].

      # Slot update.
      slots, _ = self.gru(updates, [slots_prev])   # output after gru has shape: [batch_size, num_slots, slot_size]
      slots += self.mlp(self.norm_mlp(slots))      # # output after mlp has shape: [batch_size, num_slots, slot_size]

    return slots


def spatial_broadcast(slots, resolution):
  """Broadcast slot features to a 2D grid and collapse slot dimension."""
  # `slots` has shape: [batch_size, num_slots, slot_size].
  slots = tf.reshape(slots, [-1, slots.shape[-1]])[:, None, None, :]
  grid = tf.tile(slots, [1, resolution[0], resolution[1], 1])   # this operation creates a new tensor by replicating input multiples times
  # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
  return grid


def spatial_flatten(x):
  return tf.reshape(x, [-1, x.shape[1] * x.shape[2], x.shape[-1]])


def unstack_and_split(x, batch_size, num_channels=3):
  """Unstack batch dimension and split into channels and alpha mask."""
  unstacked = tf.reshape(x, [batch_size, -1] + x.shape.as_list()[1:])
  channels, masks = tf.split(unstacked, [num_channels, 1], axis=-1)
  return channels, masks
    

def build_grid(resolution):
  ranges = [np.linspace(0., 1., num=res) for res in resolution]
  grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
  grid = np.stack(grid, axis=-1)
  grid = np.reshape(grid, [resolution[0], resolution[1], -1])
  grid = np.expand_dims(grid, axis=0)
  grid = grid.astype(np.float32)
  return np.concatenate([grid, 1.0 - grid], axis=-1)


class SoftPositionEmbed(layers.Layer):
  """Adds soft positional embedding with learnable projection."""

  def __init__(self, hidden_size, resolution):
    """Builds the soft position embedding layer.
    Args:
      hidden_size: Size of input feature dimension.
      resolution: Tuple of integers specifying width and height of grid.
    """
    super().__init__()
    self.dense = layers.Dense(hidden_size, use_bias=True)
    self.grid = build_grid(resolution)

  def call(self, inputs):
    return inputs + self.dense(self.grid)

In [3]:
resolution = (256,256)
num_slots = 7
num_iterations = 3

encoder_cnn = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
    tf.keras.layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
    tf.keras.layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu"),
    tf.keras.layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu")
], name="encoder_cnn")

decoder_initial_size = (8, 8)
decoder_cnn = tf.keras.Sequential([
    tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2), padding="SAME", activation="relu"),  
    tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(4, 3, strides=(1, 1), padding="SAME", activation=None)
], name="decoder_cnn")

encoder_pos = SoftPositionEmbed(64, resolution)
decoder_pos = SoftPositionEmbed(64, decoder_initial_size)

layer_norm = tf.keras.layers.LayerNormalization()
mlp = tf.keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(64)
], name="encoded_feedforward")

slot_attention = SlotAttention(num_iterations=num_iterations, num_slots=num_slots, slot_size=64, mlp_hidden_size=128)

Metal device set to: Apple M1 Pro


2022-10-18 14:35:42.441400: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-10-18 14:35:42.441786: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [4]:
# `image` has shape: [batch_size, width, height, num_channels].

# Convolutional encoder with position embedding.
inputs = tf.keras.Input(shape=(256,256,3,))
x = encoder_cnn(inputs)  # CNN Backbone.
x = encoder_pos(x)  # Add positional embeddings to x
x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
x = mlp(layer_norm(x))  # Feedforward network on set.
# `x` has shape: [batch_size, width*height, input_size(64)].

# Slot Attention module.
slots = slot_attention(x)
# `slots` has shape: [batch_size, num_slots, slot_size].

# Spatial broadcast decoder.
x = spatial_broadcast(slots, decoder_initial_size)
# `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
x = decoder_pos(x)
x = decoder_cnn(x)
# `x` has shape: [batch_size*num_slots, width, height, num_channels+1].

# Undo combination of slot and batch dimension; split alpha masks.
recons, masks = unstack_and_split(x, batch_size=64)
# `recons` has shape: [batch_size, num_slots, width, height, num_channels].
# `masks` has shape: [batch_size, num_slots, width, height, 1].

# Normalize alpha masks over slots.
masks = tf.nn.softmax(masks, axis=1)
recon_combined = tf.reduce_sum(recons * masks, axis=1)  # Recombine image.
# `recon_combined` has shape: [batch_size, width, height, num_channels].

outputs = recon_combined, recons, masks, slots

slot_attention_ae = tf.keras.Model(inputs = inputs, outputs = outputs, name="Slot_Attention_AutoEnconder")
slot_attention_ae.summary()

Model: "Slot_Attention_AutoEnconder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 encoder_cnn (Sequential)       (None, 256, 256, 64  312256      ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 soft_position_embed (SoftPosit  (None, 256, 256, 64  320        ['encoder_cnn[0][0]']            
 ionEmbed)                      )                                       

In [13]:
import io
import tifffile
import quilt3 as q3
import matplotlib.pyplot as plt
import tensorflow as tf
import collections
import numpy as np
from tqdm import tqdm
from aicsimageio import AICSImage #=> this package was really difficult to install, maybe using an automated yaml would be good
from PIL import Image
import os
from urllib.parse import urlparse, unquote

pipeline = q3.Package.browse(
    "aics/pipeline_integrated_single_cell",
    registry="s3://allencell"
)

def convert_to_padded_tensor(img):
    image_tensor = tf.convert_to_tensor(img.data[0][0])
    padded_tensor = tf.image.resize_with_crop_or_pad(image_tensor, 256, 256)
    return padded_tensor

batch = None

imgs = []

for ind, file_name in enumerate(pipeline["cell_images_2d"]):
    if (ind < 64):
        entry = pipeline["cell_images_2d"][file_name].fetch(f"./AllenCell/{file_name}")
        uri_file_path = entry.get()
        file_path = unquote(urlparse(uri_file_path).path) # => this is stupid because you literally define the path in the line above         img = AICSImage(file_path)
        img = AICSImage(file_path)
        #print(tf.convert_to_tensor(img.data[0][0]).shape)
        
        #imgs.append(img.data[0][0])
        tensor = convert_to_padded_tensor(img)
        imgs.append(tensor[0])
        
        #if batch == None:
            #batch = tensor
        #else:
            #batch = tf.concat([batch, tensor], axis=0)
        
#print(batch.shape)

dataset = tf.data.Dataset.from_tensor_slices(imgs)

Loading manifest: 100%|███████████████████| 179067/179067 [00:02<00:00, 60.1k/s]
100%|██████████████████████████████████████| 9.34k/9.34k [00:02<00:00, 3.94kB/s]
100%|██████████████████████████████████████| 11.9k/11.9k [00:03<00:00, 3.83kB/s]
100%|██████████████████████████████████████| 11.5k/11.5k [00:02<00:00, 4.96kB/s]
100%|██████████████████████████████████████| 6.92k/6.92k [00:02<00:00, 2.96kB/s]
100%|██████████████████████████████████████| 11.6k/11.6k [00:02<00:00, 4.81kB/s]
100%|██████████████████████████████████████| 13.1k/13.1k [00:02<00:00, 5.15kB/s]
100%|██████████████████████████████████████| 10.2k/10.2k [00:02<00:00, 3.98kB/s]
100%|██████████████████████████████████████| 10.8k/10.8k [00:02<00:00, 4.70kB/s]
100%|██████████████████████████████████████| 4.26k/4.26k [00:02<00:00, 1.66kB/s]
100%|██████████████████████████████████████| 7.30k/7.30k [00:02<00:00, 3.17kB/s]
100%|██████████████████████████████████████| 8.37k/8.37k [00:02<00:00, 3.51kB/s]
100%|███████████████████████

In [12]:
len(imgs)

64

In [15]:
dataset = dataset.batch(64)
list(dataset.as_numpy_iterator())

[array([[[[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          ...,
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
 
         [[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          ...,
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
 
         [[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          ...,
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
 
         ...,
 
         [[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          ...,
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
 
         [[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          ...,
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
 
         [[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          ...,
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]]],
 
 
        [[[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          ...,
          [0, 0, 0],
          [0, 0, 0

In [16]:
for b in dataset:
    slot_attention_ae(b)

In [4]:
"""Training loop for object discovery with Slot Attention."""

# We use `tf.function` compilation to speed up execution. For debugging,
# consider commenting out the `@tf.function` decorator.


def l2_loss(prediction, target):
  return tf.reduce_mean(tf.math.squared_difference(prediction, target))


@tf.function
def train_step(batch, model, optimizer):
  """Perform a single training step."""

  # Get the prediction of the models and compute the loss.
  with tf.GradientTape() as tape:
    preds = model(batch["image"], training=True)
    recon_combined, recons, masks, slots = preds
    loss_value = l2_loss(recon_combined, batch["image"])
    del recons, masks, slots  # Unused.

  # Get and apply gradients.
  gradients = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))   

  return loss_value

In [5]:
def visualize_loss(losses): 
    """
    Uses Matplotlib to visualize the losses of our model.
    :param losses: list of loss data stored from train. Can use the model's loss_list 
    field 

    NOTE: DO NOT EDIT

    :return: doesn't return anything, a plot should pop-up 
    """
    x = [i for i in range(len(losses))]
    plt.plot(x, losses)
    plt.title('Loss per epoch')
    plt.xlabel('Training Epoch')
    plt.ylabel('Loss')
    plt.show() 

In [None]:
# Hyperparameters of the model.
batch_size = 64
num_slots = 7
num_iterations = 3
base_learning_rate = 0.0004
num_train_steps = 5000
warmup_steps = 5
decay_rate = 0.5
decay_steps = 100000
tf.random.set_seed(0)
resolution = (128, 128)

# Build dataset iterators, optimizers and model.
data_iterator = build_clevr_iterator(
    batch_size, split="train", resolution=resolution, shuffle=True,
    max_n_objects=6, get_properties=False, apply_crop=True)

optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

model = build_model(resolution, batch_size, num_slots,
                    num_iterations, model_type="object_discovery")
  
# Prepare checkpoint manager.
global_step = tf.Variable(
    0, trainable=False, name="global_step", dtype=tf.int64)

losses = []

for _ in tqdm(range(num_train_steps), desc='Training Epochs'):
    batch = next(data_iterator)

    # Learning rate warm-up.
    if global_step < warmup_steps:
      learning_rate = base_learning_rate * tf.cast(
          global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
    else:
      learning_rate = base_learning_rate
    
    learning_rate = learning_rate * (decay_rate ** (
        tf.cast(global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
    optimizer.lr = learning_rate.numpy()

    loss_value = train_step(batch, model, optimizer)
    losses.append(loss_value)

    # Update the global step. We update it before logging the loss and saving
    # the model so that the last checkpoint is saved at the last iteration.
    global_step.assign_add(1)
    
visualize_loss(losses)

2022-09-16 15:47:29.353231: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-09-16 15:47:29.353342: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Metal device set to: Apple M1 Pro


2022-09-16 15:47:29.580398: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Training Epochs:   0%|                                 | 0/5000 [00:00<?, ?it/s]2022-09-16 15:47:31.000186: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
Training Epochs:   0%|                     | 24/5000 [03:07<11:04:06,  8.01s/it]

In [None]:
def renormalize(x):
  """Renormalize from [-1, 1] to [0, 1]."""
  return x / 2. + 0.5

def get_prediction(model, batch, idx=0):
  recon_combined, recons, masks, slots = model(batch["image"])
  image = renormalize(batch["image"])[idx]
  recon_combined = renormalize(recon_combined)[idx]
  recons = renormalize(recons)[idx]
  masks = masks[idx]
  return image, recon_combined, recons, masks, slots

In [None]:
batch_size = 64
resolution = (128,128)
data_iterator = build_clevr_iterator(
    batch_size, split="validation", resolution=resolution, shuffle=True,
    max_n_objects=6, get_properties=False, apply_crop=True)

batch = next(data_iterator)

In [None]:
# Visualize.
plt.imshow(renormalize(batch["image"])[0])

In [None]:
image, recon_combined, recons, masks, slots = get_prediction(model, batch)

In [None]:
# Visualize.
num_slots = len(masks)
fig, ax = plt.subplots(1, num_slots + 2, figsize=(15, 2))
ax[0].imshow(image)
ax[0].set_title('Image')
ax[1].imshow(recon_combined)
ax[1].set_title('Recon.')
for i in range(num_slots):
  ax[i + 2].imshow(recons[i] * masks[i] + (1 - masks[i]))
  ax[i + 2].set_title('Slot %s' % str(i + 1))
for i in range(len(ax)):
  ax[i].grid(False)
  ax[i].axis('off')

In [None]:
plt.imshow(recon_combined)

In [None]:
plt.imshow(masks[1])