In [1]:
#from absl import logging
import tensorflow as tf
#import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import collections
import numpy as np
import tensorflow.keras.layers as layers
from tqdm import tqdm
from itertools import cycle

In [2]:
"""Slot Attention model for object discovery and set prediction."""

class SlotAttention(layers.Layer):
  """Slot Attention module.

  Each call samples `num_slots` initial slots from a learned Gaussian and then,
  for `num_iterations` rounds, lets the slots attend over the layer-normalized
  inputs, aggregates the attended values with a weighted mean, and updates the
  slots with a GRU cell followed by a residual MLP.
  """

  def __init__(self, num_iterations, num_slots, slot_size, mlp_hidden_size,
               epsilon=1e-8):
    """Builds the Slot Attention module.

    Args:
      num_iterations: Number of attention/update rounds performed per call.
      num_slots: Number of slots.
      slot_size: Dimensionality of slot feature vectors; also the output size
        of the query, key and value projections and of the GRU cell.
      mlp_hidden_size: Hidden layer size of the residual slot-update MLP.
      epsilon: Offset added to the attention coefficients before they are
        renormalized into weighted-mean weights (keeps them strictly positive
        before the division).
    """
    super().__init__()
    self.num_iterations = num_iterations
    self.num_slots = num_slots
    self.slot_size = slot_size
    self.mlp_hidden_size = mlp_hidden_size
    self.epsilon = epsilon

    self.norm_inputs = layers.LayerNormalization()
    self.norm_slots = layers.LayerNormalization()
    self.norm_mlp = layers.LayerNormalization()

    # Learned mean and log standard deviation of the Gaussian from which the
    # initial slots are sampled in `call`; shared by all slots.
    self.slots_mu = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],  # Broadcasts over the batch and slot axes.
        dtype=tf.float32,
        name="slots_mu")
    self.slots_log_sigma = self.add_weight(
        initializer="glorot_uniform",
        shape=[1, 1, self.slot_size],
        dtype=tf.float32,
        name="slots_log_sigma")

    # Linear maps for the attention module.
    self.project_q = layers.Dense(self.slot_size, use_bias=False, name="q")
    self.project_k = layers.Dense(self.slot_size, use_bias=False, name="k")
    self.project_v = layers.Dense(self.slot_size, use_bias=False, name="v")

    # Slot update functions.
    self.gru = layers.GRUCell(self.slot_size)
    self.mlp = tf.keras.Sequential([
        layers.Dense(self.mlp_hidden_size, activation="relu"),
        layers.Dense(self.slot_size)
    ], name="mlp")

  def call(self, inputs):
    """Refines randomly initialized slots by iterative attention over `inputs`.

    Args:
      inputs: Tensor of shape [batch_size, num_inputs, inputs_size].

    Returns:
      Slots tensor of shape [batch_size, num_slots, slot_size].
    """
    inputs = self.norm_inputs(inputs)  # Apply layer norm to the input.
    k = self.project_k(inputs)  # Keys, shape: [batch_size, num_inputs, slot_size].
    v = self.project_v(inputs)  # Values, shape: [batch_size, num_inputs, slot_size].

    # Initialize the slots as mu + sigma * N(0, 1), sampled fresh on every call.
    # Shape: [batch_size, num_slots, slot_size].
    slots = self.slots_mu + tf.exp(self.slots_log_sigma) * tf.random.normal(
        [tf.shape(inputs)[0], self.num_slots, self.slot_size])

    # Multiple rounds of attention.
    for _ in range(self.num_iterations):
      slots_prev = slots
      slots = self.norm_slots(slots)

      # Attention.
      q = self.project_q(slots)  # Queries, shape: [batch_size, num_slots, slot_size].
      q *= self.slot_size ** -0.5  # Scale by 1/sqrt(slot_size).
      attn_logits = tf.keras.backend.batch_dot(k, q, axes=-1)  # Shape: [batch_size, num_inputs, num_slots].
      attn = tf.nn.softmax(attn_logits, axis=-1)
      # `attn` has shape: [batch_size, num_inputs, num_slots].
      # The softmax runs over the slot axis, so for every input feature the
      # slots compete for it.

      # Weighted mean.
      attn += self.epsilon
      # Normalize over the inputs axis (-2) so each slot's weights sum to 1.
      attn /= tf.reduce_sum(attn, axis=-2, keepdims=True)
      updates = tf.keras.backend.batch_dot(attn, v, axes=-2)
      # `updates` has shape: [batch_size, num_slots, slot_size].

      # Slot update: GRU step with the pre-normalization slots as the previous
      # state, followed by a residual MLP. Both keep the shape
      # [batch_size, num_slots, slot_size].
      slots, _ = self.gru(updates, [slots_prev])
      slots += self.mlp(self.norm_mlp(slots))

    return slots


def spatial_broadcast(slots, resolution):
  """Tile every slot vector over a 2D grid, folding the slot axis into the batch.

  Args:
    slots: tensor of shape [batch_size, num_slots, slot_size].
    resolution: pair of ints giving the grid size along output axes 1 and 2.

  Returns:
    Tensor of shape [batch_size*num_slots, resolution[0], resolution[1], slot_size].
  """
  feature_dim = slots.shape[-1]
  per_slot = tf.reshape(slots, [-1, feature_dim])        # [batch*num_slots, slot_size]
  per_slot = per_slot[:, tf.newaxis, tf.newaxis, :]      # [batch*num_slots, 1, 1, slot_size]
  tile_multiples = [1, resolution[0], resolution[1], 1]  # replicate along the two grid axes
  return tf.tile(per_slot, tile_multiples)


def spatial_flatten(x):
  """Merge axes 1 and 2 of `x` into a single axis.

  A [batch, d1, d2, channels] tensor becomes [batch, d1*d2, channels]; d1, d2
  and the last dimension must be statically known.
  """
  rows, cols, channels = x.shape[1], x.shape[2], x.shape[-1]
  return tf.reshape(x, [-1, rows * cols, channels])


def unstack_and_split(x, batch_size, num_channels=3):
  """Restore the batch axis folded into `x` and separate colour from alpha.

  Args:
    x: tensor whose leading axis is batch_size*num_slots and whose last axis
      holds num_channels colour channels followed by one alpha channel.
    batch_size: leading batch dimension to restore.
    num_channels: number of colour channels before the alpha channel.

  Returns:
    Tuple (channels, masks) of shapes [batch_size, num_slots, ..., num_channels]
    and [batch_size, num_slots, ..., 1].
  """
  trailing_dims = x.shape.as_list()[1:]
  regrouped = tf.reshape(x, [batch_size, -1, *trailing_dims])
  colour, alpha = tf.split(regrouped, [num_channels, 1], axis=-1)
  return colour, alpha
    

def build_grid(resolution):
  """Build a [1, H, W, 4] float32 coordinate grid for soft position embeddings.

  The last axis holds (u, v, 1 - u, 1 - v), where u and v run linearly from 0
  to 1 along the first and second grid axes respectively.
  """
  axis_coords = [np.linspace(0., 1., num=n) for n in resolution]
  mesh = np.stack(np.meshgrid(*axis_coords, sparse=False, indexing="ij"), axis=-1)
  mesh = mesh.reshape(resolution[0], resolution[1], -1)[np.newaxis].astype(np.float32)
  return np.concatenate([mesh, 1.0 - mesh], axis=-1)


class SoftPositionEmbed(layers.Layer):
  """Adds a learnable projection of a fixed coordinate grid to its inputs."""

  def __init__(self, hidden_size, resolution):
    """Create the projection layer and precompute the coordinate grid.

    Args:
      hidden_size: Size of input feature dimension (output width of the
        dense projection of the grid).
      resolution: Tuple of integers specifying width and height of grid.
    """
    super(SoftPositionEmbed, self).__init__()
    self.dense = layers.Dense(units=hidden_size, use_bias=True)
    self.grid = build_grid(resolution)  # numpy, shape [1, res0, res1, 4]

  def call(self, inputs):
    # Projected grid has shape [1, res0, res1, hidden_size] and broadcasts
    # over the batch axis of `inputs`.
    position_embedding = self.dense(self.grid)
    return inputs + position_embedding

In [3]:
# Model hyperparameters used to construct the layers below.
resolution = (256, 256)
num_slots = 7
num_iterations = 3

# Four 5x5 "SAME"-padded convolutions: spatial size is preserved.
encoder_cnn = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(64, kernel_size=5, padding="SAME", activation="relu")
     for _ in range(4)],
    name="encoder_cnn")

# The decoder starts from an 8x8 broadcast grid; five stride-2 transposed
# convolutions upsample it by 2**5 = 32 (to 256x256), and a final 3x3 layer
# emits 4 channels (3 colour channels + 1 alpha mask logit).
decoder_initial_size = (8, 8)
decoder_cnn = tf.keras.Sequential(
    [tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2), padding="SAME", activation="relu")
     for _ in range(5)]
    + [tf.keras.layers.Conv2DTranspose(4, 3, strides=(1, 1), padding="SAME", activation=None)],
    name="decoder_cnn")

encoder_pos = SoftPositionEmbed(64, resolution)
decoder_pos = SoftPositionEmbed(64, decoder_initial_size)

layer_norm = tf.keras.layers.LayerNormalization()
mlp = tf.keras.Sequential(
    [layers.Dense(64, activation="relu"), layers.Dense(64)],
    name="encoded_feedforward")

slot_attention = SlotAttention(
    num_iterations=num_iterations,
    num_slots=num_slots,
    slot_size=64,
    mlp_hidden_size=128)

2022-11-11 17:34:41.555745: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-11 17:34:41.563314: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [4]:
# Convolutional encoder with position embedding.
inputs = tf.keras.Input(shape=(256,256,3,))
x = encoder_cnn(inputs)  # CNN Backbone.
x = encoder_pos(x)  # Add positional embeddings to x
x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
x = mlp(layer_norm(x))  # Feedforward network on set.
# `x` has shape: [batch_size, width*height, input_size(64)].

# Slot Attention module.
slots = slot_attention(x)
# `slots` has shape: [batch_size, num_slots, slot_size].

# Spatial broadcast decoder.
x = spatial_broadcast(slots, decoder_initial_size)
# `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
x = decoder_pos(x)
x = decoder_cnn(x)
# `x` has shape: [batch_size*num_slots, width, height, num_channels+1].

# Undo combination of slot and batch dimension; split alpha masks.
# The batch size is read from `inputs` at run time instead of being hard-coded
# to 64: the Input layer accepts any batch size, and a fixed 64 made this
# reshape fail on smaller batches (e.g. the final, partial batch of a split).
recons, masks = unstack_and_split(x, batch_size=tf.shape(inputs)[0])
# `recons` has shape: [batch_size, num_slots, width, height, num_channels].
# `masks` has shape: [batch_size, num_slots, width, height, 1].

# Normalize alpha masks over slots.
masks = tf.nn.softmax(masks, axis=1)
recon_combined = tf.reduce_sum(recons * masks, axis=1)  # Recombine image.
# `recon_combined` has shape: [batch_size, width, height, num_channels].

outputs = recon_combined, recons, masks, slots

slot_attention_ae = tf.keras.Model(inputs = inputs, outputs = outputs, name="Slot_Attention_AutoEnconder")
slot_attention_ae.summary()

Model: "Slot_Attention_AutoEnconder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 encoder_cnn (Sequential)       (None, 256, 256, 64  312256      ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 soft_position_embed (SoftPosit  (None, 256, 256, 64  320        ['encoder_cnn[0][0]']            
 ionEmbed)                      )                                       

In [5]:
import io
import tifffile
import quilt3 as q3
import matplotlib.pyplot as plt
import tensorflow as tf
import collections
import numpy as np
from tqdm import tqdm
from aicsimageio import AICSImage #=> this package was really difficult to install, maybe using an automated yaml would be good
from PIL import Image
import os
from urllib.parse import urlparse, unquote
from os import listdir
from os.path import join
import math

def fetch_data():
    """Download the AllenCell 2D single-cell images.

    Browses the ``aics/pipeline_integrated_single_cell`` quilt package in the
    ``s3://allencell`` registry and fetches its ``cell_images_2d`` entry into
    ``./AllenCell/cell_images_2d/``. Returns None.
    """
    source_package = q3.Package.browse("aics/pipeline_integrated_single_cell", registry="s3://allencell")
    source_package["cell_images_2d"].fetch("./AllenCell/cell_images_2d/")


def allen_cell_dataset(download_data = False, batch_size = 64, seed = None):
    """Load the AllenCell 2D cell images and split them into batched datasets.

    Every ``.png`` file in ``./AllenCell/cell_images_2d/`` is opened with
    ``AICSImage``, center-cropped/zero-padded to 256x256 and collected into a
    single ``tf.data.Dataset``. The dataset is shuffled once and then split into
    ceil(70%) train, floor(10%) test and the remainder (~20%) validation.

    Args:
        download_data: if True, call ``fetch_data()`` before loading.
        batch_size: batch size applied to all three splits (the final batch of
            each split may be smaller).
        seed: optional int seed for the one-off shuffle that assigns images to
            splits; pass a value to make the split reproducible across runs.

    Returns:
        ``(train, test, val)`` batched datasets -- note test comes before val.

    Raises:
        Exception: if the image directory contains no ``.png`` files.
    """
    image_dir = "./AllenCell/cell_images_2d/"

    if download_data:
        fetch_data()

    def convert_to_padded_tensor(img):
        # Index [0][0] of the AICSImage data array -- presumably the first time
        # point and first channel of its dimension ordering; confirm for PNGs.
        image_tensor = tf.convert_to_tensor(img.data[0][0], dtype=tf.float32)
        # Center crop / zero-pad height and width to 256x256. Pixel values are
        # NOT rescaled.
        padded_tensor = tf.image.resize_with_crop_or_pad(image_tensor, 256, 256)
        return padded_tensor

    file_names = [join(image_dir, f) for f in listdir(image_dir) if f.endswith(".png")]

    if len(file_names) == 0:
        raise Exception("No .png Files in the AllenCell directory.")

    imgs = []
    for file_name in tqdm(file_names, desc="loading data"):
        tensor = convert_to_padded_tensor(AICSImage(file_name))
        # Keep the first entry of the leading axis (assumed to be a singleton
        # axis from the image data -- confirm).
        imgs.append(tensor[0])

    print(f"num images: {len(imgs)}")

    ## split into train validation test
    dataset = tf.data.Dataset.from_tensor_slices(imgs)

    num_sample = len(dataset)
    print(f"length of dataset: {num_sample}")
    # Shuffle exactly once. With the default reshuffle_each_iteration=True the
    # order changes every time the dataset is iterated, so the take()/skip()
    # splits below would each see a different permutation and overlap
    # (test/val images leaking into train).
    dataset = dataset.shuffle(buffer_size=num_sample, seed=seed, reshuffle_each_iteration=False)

    num_train = math.ceil(num_sample * 0.7)
    print(f"num train: {num_train}")
    # Informational only: val receives everything left after train and test.
    num_val = math.floor(num_sample * 0.2)
    print(f"num val: {num_val}")
    num_test = math.floor(num_sample * 0.1)
    print(f"num test: {num_test}")

    train = dataset.take(num_train)
    print(f"train unbatched: {len(train)}")
    train = train.batch(batch_size)
    print(f" batched: {len(train)}")
    test_val = dataset.skip(num_train)

    test = test_val.take(num_test)
    print(f"test unbatched: {len(test)}")
    test = test.batch(batch_size)
    print(f" batched: {len(test)}")

    val = test_val.skip(num_test)
    print(f"val unbatched: {len(val)}")
    val = val.batch(batch_size)
    print(f" batched: {len(val)}")

    return train, test, val

In [6]:
#fetch_data()
#do not run this again, all data has been fetched

In [7]:
"""Training loop for object discovery with Slot Attention."""

# We use `tf.function` compilation to speed up execution. For debugging,
# consider commenting out the `@tf.function` decorator.


def l2_loss(prediction, target):
  """Mean squared error between `prediction` and `target` over all elements."""
  squared_error = tf.square(prediction - target)
  return tf.reduce_mean(squared_error)


@tf.function
def train_step(batch, model, optimizer):
  """Run one optimization step on `batch` and return the reconstruction loss.

  Only the model's first output (the combined reconstruction) enters the loss,
  which compares it with the input batch via `l2_loss`; the per-slot
  reconstructions, masks and slots are ignored.
  """
  with tf.GradientTape() as tape:
    recon_combined = model(batch, training=True)[0]
    loss_value = l2_loss(recon_combined, batch)

  variables = model.trainable_weights
  grads = tape.gradient(loss_value, variables)
  optimizer.apply_gradients(zip(grads, variables))
  return loss_value

In [8]:
#train_step(next(train_iterator), slot_attention_ae, tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08))

In [9]:
def visualize_loss(losses): 
    """
    Uses Matplotlib to visualize the losses of our model.
    :param losses: list of loss data stored from train. Can use the model's loss_list 
    field 

    NOTE: DO NOT EDIT

    :return: doesn't return anything, a plot should pop-up 
    """
    # NOTE(review): the x-axis is simply the index into `losses`. The training
    # loop in this notebook appends one loss per optimizer step, so the
    # "epoch" labels below actually count training steps.
    x = [i for i in range(len(losses))]
    plt.plot(x, losses)
    plt.title('Loss per epoch')
    plt.xlabel('Training Epoch')
    plt.ylabel('Loss')
    plt.show() 

In [None]:
# Build dataset iterators.
# `batch_size` must be defined here: this cell runs before the hyperparameter
# cell below (which sets the same value), so a fresh kernel would otherwise
# raise NameError.
batch_size = 64
train_iterator, test_iterator, val_iterator = allen_cell_dataset(download_data=False, batch_size=batch_size)
# Materialize each split once and cycle over it forever so `next()` never
# raises StopIteration during training.
train_iterator = cycle(list(train_iterator))
test_iterator = cycle(list(test_iterator))
val_iterator = cycle(list(val_iterator))

In [None]:
# Hyperparameters of the model.
# NOTE: num_slots, num_iterations and resolution do not affect the
# already-built `slot_attention_ae`; they only mirror the values used when its
# layers were constructed.
batch_size = 64
num_slots = 7
num_iterations = 3
base_learning_rate = 0.0004
num_train_steps = 500
warmup_steps = 5
decay_rate = 0.5
decay_steps = 100000
#tf.random.set_seed(0)
resolution = (256, 256)

# Build the optimizer (the model is `slot_attention_ae`, built above).
optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

# Step counter driving the warm-up / decay learning-rate schedule.
global_step = tf.Variable(0, trainable=False, name="global_step", dtype=tf.int64)

losses = []
val_losses = []

for _ in tqdm(range(num_train_steps), desc='Training Epochs'):
    batch = next(train_iterator)
    val_batch = next(val_iterator)
    if val_batch.shape[0] != batch_size:
        # A split's final batch can be smaller than `batch_size`; skip it (the
        # cycled iterator then normally restarts at a full batch) so the model
        # is only evaluated on full batches.
        val_batch = next(val_iterator)

    # Learning rate warm-up: ramp linearly from 0 to base_learning_rate.
    if global_step < warmup_steps:
        learning_rate = base_learning_rate * tf.cast(global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
    else:
        learning_rate = base_learning_rate

    # Exponential decay: multiply by decay_rate once every decay_steps steps.
    learning_rate = learning_rate * (decay_rate ** (tf.cast(global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
    optimizer.lr = learning_rate.numpy()

    loss_value = train_step(batch, slot_attention_ae, optimizer)
    losses.append(float(loss_value))

    # Validation loss: reconstruction error on the held-out batch. (Previously
    # the model's output tuple for the *training* batch was appended here.)
    val_recon_combined = slot_attention_ae(val_batch, training=False)[0]
    val_losses.append(float(l2_loss(val_recon_combined, val_batch)))

    global_step.assign_add(1)

visualize_loss(losses)

loading data: 100%|███████████████████████████| 49325/49325 [06:48<00:00, 120.84it/s]


num images: 49325


In [None]:
def renormalize(x):
  """Map values from [-1, 1] to [0, 1] via the affine map x/2 + 1/2."""
  return x * 0.5 + 0.5

def get_prediction(model, batch, idx=0):
  """Run `model` on `batch` and select example `idx` for display.

  Returns:
    (image, recon_combined, recons, masks, slots): the renormalized input
    image, combined reconstruction and per-slot reconstructions of example
    `idx`, that example's masks, and the slots for the WHOLE batch.
  """
  recon_combined, recons, masks, slots = model(batch)
  return (
      renormalize(batch)[idx],
      renormalize(recon_combined)[idx],
      renormalize(recons)[idx],
      masks[idx],
      slots,
  )

In [None]:
# Restate the visualization settings and grab one batch from the test split.
batch_size, resolution = 64, (256, 256)

batch = next(test_iterator)

In [None]:
# Visualize the first image of the test batch.
# NOTE(review): `renormalize` assumes values in [-1, 1], but allen_cell_dataset
# converts pixels to float32 without rescaling -- confirm the data range.
plt.imshow(renormalize(batch)[0])
plt.title('Test image 0')
plt.axis('off')
plt.show()

In [None]:
# `model` is never defined in this notebook (its build call is commented out);
# the trained network is `slot_attention_ae`.
image, recon_combined, recons, masks, slots = get_prediction(slot_attention_ae, batch)

In [None]:
# Visualize: input image, combined reconstruction, then each slot's
# reconstruction alpha-blended onto a white background by its mask.
num_slots = len(masks)
fig, ax = plt.subplots(1, num_slots + 2, figsize=(15, 2))
ax[0].imshow(image)
ax[0].set_title('Image')
ax[1].imshow(recon_combined)
ax[1].set_title('Recon.')
for slot_idx in range(num_slots):
  slot_recon, slot_mask = recons[slot_idx], masks[slot_idx]
  panel = ax[slot_idx + 2]
  panel.imshow(slot_recon * slot_mask + (1 - slot_mask))
  panel.set_title(f'Slot {slot_idx + 1}')
for panel in ax:
  panel.grid(False)
  panel.axis('off')