In [1]:
import os
import numpy as np

import tensorflow as tf
import tensorflow_addons as tfa
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
print(tf.__version__)
AUTOTUNE = tf.data.experimental.AUTOTUNE

import deepcell
# Changed from before due to new placement of Track, concat_tracks
from deepcell_tracking.utils import load_trks
from deepcell.data.tracking import Track, concat_tracks
##############
from sklearn.model_selection import train_test_split
from deepcell.utils.data_utils import reshape_movie
from deepcell.utils.transform_utils import erode_edges
from deepcell.data import split_dataset
from deepcell_toolbox.processing import normalize, histogram_normalization

import spektral

2.8.0


In [2]:
import json
def load_img_dict(file):
    f = open(file)
    d = json.load(f)
    d = {int(k1): {int(k2): {int(k3): v for k3, v in d[k1][k2].items()} for k2, d[k1][k2] in d[k1].items()} for k1, d[k1] in d.items()}
    return d
def load_img_idx_dict(file):
    f = open(file)
    d = json.load(f)
    d = {int(k): v for k, v in d.items()}
    return d

In [3]:
train_good_imgs = load_img_dict('../dataset_pruning/train_appearances_dict.json')
train_blank_imgs = load_img_dict('../dataset_pruning/train_blank_dict.json')
train_border_imgs = load_img_dict('../dataset_pruning/train_border_dict.json')
val_good_imgs = load_img_dict('../dataset_pruning/val_appearances_dict.json')
val_blank_imgs = load_img_dict('../dataset_pruning/val_blank_dict.json')
val_border_imgs = load_img_dict('../dataset_pruning/val_border_dict.json')

In [None]:
# Since the images are written to disk one at a time, we can modify the write function to write the image to disk
# only when the image is "good." Can pass in the dictionary of good images as an argument. The funciton uses 
# track.appearances, the NP array, to write the file. Since the dictionary corresponds with the indices of the 
# array, we can use this to determine which images to add.

# Based on this line in the Track class, 'appearances = np.zeros(batch_shape + appearance_shape, dtype='float32')',
# it seems the track.appearances object does not have any batch/cells/frames pattern, but we should check.

In [18]:
import argparse
import os

import numpy as np
import tensorflow as tf

from deepcell_tracking.trk_io import load_trks
from deepcell_tracking.utils import get_max_cells
from deepcell.data.tracking import Track
# Might want to import this just to get the functions it uses
from deepcell.utils.tfrecord_utils import write_tracking_dataset_to_tfr

def get_arg_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data-path',
                        default='/training/tracking-nuclear',
                        help='Path to the training data.')

    parser.add_argument('--appearance-dim', type=int, default=64)
    parser.add_argument('--distance-threshold', type=int, default=64)
    parser.add_argument('--crop-mode', type=str, default='fixed')

    return parser

In [19]:
import os
"create_tracking_example" in dir(os)

False

In [45]:
import csv
import os

from tensorflow.data import Dataset
from tensorflow.io import serialize_tensor
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.backend import is_sparse

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def create_tracking_example(track_dict):
    """Create a tf.train.Example for a single item
    of a tracking dataset
    Args:
        track_dict (dict): A dictionary with a single
            item of a tracking dataset
    """

    data = {}

    # Define the dictionary of our single example
    for key in track_dict:
        # WE DON'T NEED TO BOTHER WITH CHECKING FOR SPARSE KEYS (REMOVED THAT)
        data[key] = _bytes_feature(serialize_tensor(track_dict[key]))

        shapes = track_dict[key].shape

        # I DON'T REALLY GET WHAT THIS IS DOING
        for i in range(len(shapes)):
            shape_string = '{}_shape_{}'.format(key, i)
            data[shape_string] = _int64_feature(shapes[i])

    # Create an Example, wrapping the single features
    example = tf.train.Example(features=tf.train.Features(feature=data))

    return example

In [51]:
def write_tracking_dataset_to_tfr(track,
                                  filename,
                                  good_imgs,
                                  target_max_cells=168,
                                  verbose=True):

    filename_tfr = filename + '.tfrecord'
    filename_csv = filename + '.csv'

    count = 0

    writer = tf.io.TFRecordWriter(filename_tfr)

    # Get features to add
    # WE PROBABLY ONLY CARE ABOUT APP
    app = track.appearances

    # Pad cells - we need to do this to use validation data
    # during training

    # TARGET MAX CELLS WILL BE THE MAXIMUM 'CELLS' DIMENSION BETWEEN TRAIN AND VAL,
    # AND THE OTHER ONE WILL BE PADDED TO ACHIEVE THAT. I PROBABLY DON'T NEED TO
    # PAD ANYTHING, SINCE MY DATA IS IN THE FORMAT (num_imgs, dim, dim, 1), NOT
    # (batches, frames, cells, dim, dim, 1).
    
    # Iterate over all batches
    # THIS SHOULD PROBABLY BE THE CELLS THEMSELVES IN MY CASE
    for b in range(app.shape[0]):
        for f in range(app.shape[1]):
            for c in range(app.shape[2]):
                if good_imgs[b][c][f] != -1:
                    img = app[b, f, c]
                    track_dict = {'app': img}

                    example = create_tracking_example(track_dict)

                    if example is not None:
                        writer.write(example.SerializeToString())
                        count += 1

    writer.close()

    if verbose:
        print(f'Wrote {count} elements to TFRecord')

    # WE'LL WORRY ABOUT THE CSV WRITER FOR METADATA AT THE END, IF NECESSARY
    # OKAY WE MIGHT NEED IT TO PARSE THE DATA
    # Save dataset metadata
    # THIS SHOULD BE OKAY--WE JUST HAVE ONE KEY RATHER THAN A BUNCH, NOW
    dataset_keys = track_dict.keys()
    dataset_dims = [len(track_dict[k].shape) for k in dataset_keys]

    with open(filename_csv, 'w') as f:
        writer = csv.writer(f)
        rows = [[k, dims] for k, dims in zip(dataset_keys, dataset_dims)]
        writer.writerows(rows)
        
        # SHOULDN'T NEED ROWS FOR adj_shape AND temp_adj_shape, SINCE WE'RE NOT
        # WRITING THESE

    return count

In [14]:
# Probably no reaon to use argument parser--only there because they wanted to run it from the command line
args = get_arg_parser().parse_args([])

train_trks = load_trks(os.path.join(args.data_path, 'train.trks'))
val_trks = load_trks(os.path.join('/training/tracking-nuclear', 'val.trks'))

# max_cells = max([get_max_cells(train_trks['y']), get_max_cells(val_trks['y'])])

In [48]:
with tf.device('/cpu:0'):
    val_tracks = Track(tracked_data=val_trks,
                   appearance_dim=args.appearance_dim,
                   distance_threshold=args.distance_threshold,
                   crop_mode=args.crop_mode)

    write_tracking_dataset_to_tfr(val_tracks, filename='val', good_imgs=val_good_imgs)

100%|███████████████████████████████████████████| 27/27 [02:45<00:00,  6.12s/it]
100%|███████████████████████████████████████████| 27/27 [04:08<00:00,  9.19s/it]


Wrote 89436 elements to TFRecord


In [52]:
with tf.device('/cpu:0'):
    write_tracking_dataset_to_tfr(val_tracks, filename='val', good_imgs=val_good_imgs)

Wrote 89436 elements to TFRecord


In [49]:
with tf.device('/cpu:0'):
    train_tracks = Track(tracked_data=train_trks,
                   appearance_dim=args.appearance_dim,
                   distance_threshold=args.distance_threshold,
                   crop_mode=args.crop_mode)

    write_tracking_dataset_to_tfr(train_tracks, filename='train', good_imgs=train_good_imgs)

100%|███████████████████████████████████████████| 91/91 [12:54<00:00,  8.52s/it]
100%|███████████████████████████████████████████| 91/91 [20:18<00:00, 13.38s/it]


Wrote 383800 elements to TFRecord


In [53]:
with tf.device('/cpu:0'):
    write_tracking_dataset_to_tfr(train_tracks, filename='train', good_imgs=train_good_imgs)

Wrote 383800 elements to TFRecord


In [50]:
type(val_tracks)

deepcell.data.tracking.Track

In [54]:
# Maybe when we parse, we can return a tuple (not dictionary) with two images (the same) of shape (1, 64, 64, 1)
# We probably don't need the CSV file of dimensions, can just put dimensions like the examples do in the Examples
def parse_tracking_example(example, dataset_ndims,
                           dtype=tf.float32):
    """Parse a tracking example
    Args:
        example (tf.train.Example): The tracking example to be parsed
        dataset_ndims (dict): Dictionary of dataset metadata
        dtype (tf dtype): Dtype of training data
    """
    # WE MIGHT NEED THE METADATA NOW, TO PARSE IT

    # WHAT IS THE DIFFERENCE BETWEEN X AND y?
    X_names = ['app']

    full_name_dict = {'app': 'appearances'}

    # Recreate the example structure
    data = {}
    shape_strings_dict = {}
    shapes_dict = {}

    for key in dataset_ndims:
        if 'shape' in key:
            new_key = '_'.join(key.split('_')[0:-1])
            shapes_dict[new_key] = dataset_ndims[key]

    for key in shapes_dict:
        dataset_ndims.pop('{}_shape'.format(key))

    for key in dataset_ndims:
        # NO SUCH THING AS sparse_names ANYMORE
        data[key] = tf.io.FixedLenFeature([], tf.string)

        shape_strings = ['{}_shape_{}'.format(key, i)
                         for i in range(dataset_ndims[key])]
        shape_strings_dict[key] = shape_strings

        for ss in shape_strings:
            data[ss] = tf.io.FixedLenFeature([], tf.int64)

    # Get data
    content = tf.io.parse_single_example(example, data)

    X_dict = {}

    for key in dataset_ndims:

        # Get the feature and reshape
        # AGAIN NO NEED TO CHECK FOR SPARSENESS
        shape = [content[ss] for ss in shape_strings_dict[key]]
        value = content[key]
        value = tf.io.parse_tensor(value, out_type=dtype)
        value = tf.reshape(value, shape=shape)

        X_dict[full_name_dict[key]] = value

    return X_dict

In [55]:
def get_dataset(filename, parse_fn=None, **kwargs):
    """Get a TFRecord Dataset
    Args:
        filename (str): The base filename of the dataset to be
            loaded. The filetype (e.g., .tfrecord) should not
            be included
        parse_fn (python function): The function for parsing
            tf.train.Example examples in the the dataset
    """

    # Define tfrecord and csv file
    filename_tfrecord = filename + '.tfrecord'
    filename_csv = filename + '.csv'

    # Load the csv
    dataset_ndims = {}
    shapes = {}

    with open(filename_csv) as f:
        reader = csv.reader(f)
        for row in reader:
            if 'shape' in row[0]:
                dataset_ndims[row[0]] = [int(i) for i in row[1:]]
            else:
                dataset_ndims[row[0]] = int(row[1])

    # Create the dataset
    dataset = tf.data.TFRecordDataset(filename_tfrecord)

    # Pass each feature through the mapping function
    def parse_func(example):
        return parse_fn(example,
                        dataset_ndims=dataset_ndims,
                        **kwargs)

    dataset = dataset.map(parse_func)

    return dataset

In [56]:
def get_dataset(filename, parse_fn=None, **kwargs):
    """Get a TFRecord Dataset
    Args:
        filename (str): The base filename of the dataset to be
            loaded. The filetype (e.g., .tfrecord) should not
            be included
        parse_fn (python function): The function for parsing
            tf.train.Example examples in the the dataset
    """

    # Define tfrecord and csv file
    filename_tfrecord = filename + '.tfrecord'
    filename_csv = filename + '.csv'

    # Load the csv
    dataset_ndims = {}
    shapes = {}

    with open(filename_csv) as f:
        reader = csv.reader(f)
        for row in reader:
            if 'shape' in row[0]:
                dataset_ndims[row[0]] = [int(i) for i in row[1:]]
            else:
                dataset_ndims[row[0]] = int(row[1])

    # Create the dataset
    # MIGHT NEED TO DO from_tensor_slices AND REPEAT THE DATA
    dataset = tf.data.TFRecordDataset(filename_tfrecord)

    # Pass each feature through the mapping function
    def parse_func(example):
        return parse_fn(example,
                        dataset_ndims=dataset_ndims,
                        **kwargs)

    dataset = dataset.map(parse_func)

    return dataset

In [57]:
train_dataset = get_dataset('train', parse_fn=parse_tracking_example)

In [58]:
val_dataset = get_dataset('val', parse_fn=parse_tracking_example)

In [68]:
for sample in train_dataset.take(3):
    print(sample['input_1'].shape)

(64, 64, 1)
(64, 64, 1)
(64, 64, 1)


In [59]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
tfk = tf.keras
tfkl = tf.keras.layers

In [60]:
class VAE:
    
    def __init__(self, dim_z, kl_weight, learning_rate):
        # change dim from (28, 28, 1)
        self.dim_x = (64, 64, 1)
        self.dim_z = dim_z
        self.kl_weight = kl_weight
        self.learning_rate = learning_rate

    # Sequential API encoder
    def encoder_z(self):
        # define prior distribution for the code, which is an isotropic Gaussian
        prior = tfd.Independent(tfd.Normal(loc=tf.zeros(self.dim_z), scale=1.), 
                                reinterpreted_batch_ndims=1)
        # build layers argument for tfk.Sequential()
        input_shape = self.dim_x
        layers = [tfkl.InputLayer(input_shape=input_shape)]
        layers.append(tfkl.Conv2D(filters=32, kernel_size=3, strides=(2,2), 
                                  padding='valid', activation='relu'))
        layers.append(tfkl.Conv2D(filters=64, kernel_size=3, strides=(2,2), 
                                  padding='valid', activation='relu'))
        layers.append(tfkl.Flatten())
        # the following two lines set the output to be a probabilistic distribution
        layers.append(tfkl.Dense(tfpl.IndependentNormal.params_size(self.dim_z), 
                                 activation=None, name='z_params'))
        layers.append(tfpl.IndependentNormal(self.dim_z, 
            convert_to_tensor_fn=tfd.Distribution.sample, 
            activity_regularizer=tfpl.KLDivergenceRegularizer(prior, weight=self.kl_weight), 
            name='z_layer'))
        return tfk.Sequential(layers, name='encoder')
    
    # Sequential API decoder
    def decoder_x(self):
        layers = [tfkl.InputLayer(input_shape=self.dim_z)]
        # probably 7 before since 28/2/2 = 7, so changing to 32/2/2 = 8
        layers.append(tfkl.Dense(16*16*32, activation=None))
        layers.append(tfkl.Reshape((16,16,32)))
        layers.append(tfkl.Conv2DTranspose(filters=64, kernel_size=3, strides=2, 
                                           padding='same', activation='relu'))
        layers.append(tfkl.Conv2DTranspose(filters=32, kernel_size=3, strides=2, 
                                           padding='same', activation='relu'))
        layers.append(tfkl.Conv2DTranspose(filters=1, kernel_size=3, strides=1, 
                                           padding='same'))
        layers.append(tfkl.Flatten())
        # note that here we don't need 
        # `tfkl.Dense(tfpl.IndependentBernoulli.params_size(self.dim_x))` because 
        # we've restored the desired input shape with the last Conv2DTranspose layer
        layers.append(tfkl.Dense(tfpl.IndependentNormal.params_size(self.dim_x), 
                                 activation=None, name='x_params'))
        layers.append(tfpl.IndependentNormal(self.dim_x,
            name='x_layer'))
        return tfk.Sequential(layers, name='decoder')
    
    def build_vae_keras_model(self):
        x_input = tfk.Input(shape=self.dim_x)
        encoder = self.encoder_z()
        decoder = self.decoder_x()
        z = encoder(x_input)

        # compile VAE model
        model = tfk.Model(inputs=x_input, outputs=decoder(z))
        model.compile(loss=negative_log_likelihood, 
                      optimizer=tfk.optimizers.Adam(self.learning_rate))
        return model

# the negative of log-likelihood for probabilistic output
negative_log_likelihood = lambda x, rv_x: -rv_x.log_prob(x)

In [61]:
vae = VAE(1024, 1, 1e-3)
AE = vae.build_vae_keras_model()

In [62]:
from tensorflow_addons.optimizers import RectifiedAdam as RAdam
from deepcell import train_utils

# steps_per_epoch = 3838
# validation_steps = 895
n_epochs = 1
model_path = '../models/first_64_64'

train_callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(
      
        monitor='val_loss', factor=0.5, verbose=1,
        patience=3, min_lr=1e-7),
    tf.keras.callbacks.ModelCheckpoint(
        model_path, monitor='val_loss',
        save_best_only=True, verbose=1,
        save_weights_only=True)
]

loss_history = AE.fit(
    train_dataset,
#     steps_per_epoch=steps_per_epoch,
    validation_data=val_dataset,
#     validation_steps=validation_steps,
    epochs=n_epochs,
    verbose=1,
    callbacks=train_callbacks
)

ValueError: in user code:

    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 859, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/input_spec.py", line 183, in assert_input_compatibility
        raise ValueError(f'Missing data for input "{name}". '

    ValueError: Missing data for input "input_7". You passed a data dictionary with keys ['input_1']. Expected the following keys: ['input_7']


In [None]:
# Firstly, what is the form of the .tfrecord file?
# I think we need an image, recon pair (two keys) for each feature, like they did image, label in the example

In [69]:
type(val_tracks.appearances)

numpy.ndarray

In [71]:
val_tracks.appearances.shape

(27, 71, 277, 64, 64, 1)