In [101]:
import tensorflow as tf
import tensorflow3d as t3d
import numpy as np
from dataclasses import dataclass, asdict, field
from tqdm import trange
import datetime

In [102]:
def build_dataset(path, batch=True, batch_size=32, cache=True, ordered=False, shuffle=False, test=False):
    """Build a ``tf.data`` pipeline over every ``*.tfrecords`` file in ``path``.

    Each record is decoded into a ``(pc1, pc2, flow, m1)`` tuple: the
    ``point_cloud1`` / ``point_cloud2`` / ``ground_truth`` features parsed as
    float32 tensors and ``mask1`` parsed as a bool tensor. The ``num``,
    ``color1`` and ``color2`` features are declared in the schema but are not
    returned.

    Args:
        path: Directory containing the ``*.tfrecords`` files.
        batch: If True, group examples into batches of ``batch_size``; the
            final batch may be smaller (``drop_remainder=False``).
        batch_size: Examples per batch when ``batch`` is True.
        cache: If True, cache decoded examples after the first full pass.
        ordered: If False, allow tf.data to yield elements
            non-deterministically (``experimental_deterministic = False``).
        shuffle: If True, shuffle examples with a 1024-element buffer.
        test: Unused; kept so existing call sites keep working.

    Returns:
        A prefetched ``tf.data.Dataset``.

    Raises:
        FileNotFoundError: If ``path`` contains no ``*.tfrecords`` files.
            Without this check the dataset would silently be empty.
    """
    files = tf.io.gfile.glob(f"{path}/*.tfrecords")
    if not files:
        raise FileNotFoundError(f"No *.tfrecords files found in {path!r}")
    dataset = tf.data.TFRecordDataset(files)

    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    dataset = dataset.with_options(options)

    feature_description = {
        'ground_truth': tf.io.FixedLenFeature([], tf.string),
        'num': tf.io.FixedLenFeature([], tf.int64),
        'point_cloud1': tf.io.FixedLenFeature([], tf.string),
        'point_cloud2': tf.io.FixedLenFeature([], tf.string),
        'color1': tf.io.FixedLenFeature([], tf.string),
        'color2': tf.io.FixedLenFeature([], tf.string),
        'mask1': tf.io.FixedLenFeature([], tf.string),
    }

    def _parse_image_function(example_proto):
        """Decode one serialized example into ``(pc1, pc2, flow, m1)``."""
        content = tf.io.parse_single_example(example_proto, feature_description)
        pc1 = tf.io.parse_tensor(content['point_cloud1'], tf.float32)
        pc2 = tf.io.parse_tensor(content['point_cloud2'], tf.float32)
        flow = tf.io.parse_tensor(content['ground_truth'], tf.float32)
        m1 = tf.io.parse_tensor(content['mask1'], tf.bool)
        return (pc1, pc2, flow, m1)

    dataset = dataset.map(_parse_image_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Cache BEFORE shuffling/batching. Caching after a shuffle stores the first
    # epoch's order and replays it every epoch, so shuffling would happen once.
    if cache:
        dataset = dataset.cache()

    if shuffle:
        dataset = dataset.shuffle(1024)

    if batch:
        dataset = dataset.batch(batch_size, drop_remainder=False)

    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [103]:
# Show the working directory; the relative './../assets/...' dataset paths used for training resolve against it.
!pwd

/custom-op/tensorflow3d


In [104]:
# Build FlowNet3D with two point-cloud inputs, each declared with shape (None, 3)
# per example (presumably a variable number of xyz points), then print the layers.
flownet3d = (
    t3d.models.FlowNet3D(name='flownet3d')
    .build(input_shape1=(None, 3), input_shape2=(None, 3))
)
flownet3d.summary()

Model: "functional_11"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, None, 3)]    0                                            
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, None, 3)]    0                                            
__________________________________________________________________________________________________
set_conv_30 (SetConv)           ((None, 1024, 3), (N 3808        input_11[0][0]                   
__________________________________________________________________________________________________
set_conv_32 (SetConv)           ((None, 1024, 3), (N 3808        input_12[0][0]                   
______________________________________________________________________________________

In [105]:
@dataclass
class Settings:
    """Hyper-parameters for one training run.

    ``dict()`` returns every field converted to ``str``, keyed by field name,
    in declaration order.
    """

    name: str  # filename the trained model is saved under
    lr: float = 0.001  # initial Adam learning rate
    lr_decay: float = 0.8  # multiplicative factor applied when patience runs out
    patience: int = 12  # epochs without a new best train loss before decaying lr
    epochs: int = 1000  # number of training epochs
    batch_size: int = 32  # examples per batch

    def dict(self):
        """Return all fields as a ``{field_name: str(value)}`` mapping."""
        as_text = {}
        for key, value in asdict(self).items():
            as_text[key] = str(value)
        return as_text

In [106]:
# One timestamped run directory with separate train/test sub-directories.
# Note: tf.summary.scalar only records to one of these writers while it is
# made the default via `writer.as_default()`.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir, test_log_dir = (
    f'logs/gradient_tape/{current_time}/{split}' for split in ('train', 'test')
)
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

In [107]:
def compute_loss(pred, gt, mask):
    """Mean squared error between predicted and ground-truth flow.

    Args:
        pred: Flow predicted by the model.
        gt: Ground-truth flow.
        mask: Point mask from the dataset. NOTE(review): currently ignored,
            so every point contributes to the loss. No EPE or
            variational-energy term is computed here.

    Returns:
        ``tf.keras.metrics.mean_squared_error`` of the two tensors: the squared
        difference averaged over the last axis (presumably one value per point,
        if the last axis holds the flow components).
    """
    # Keras' signature is (y_true, y_pred). MSE is symmetric, so the value is
    # unchanged, but this order matches the API and avoids confusion.
    return tf.keras.metrics.mean_squared_error(gt, pred)
    
    

def _mean_epoch_loss(dataset, model, optimiser, train):
    """Run ``train_step`` on every batch of ``dataset``; return the mean loss (NaN if empty)."""
    batch_losses = [train_step(batch, model, optimiser, train=train) for batch in dataset]
    if not batch_losses:
        return float('nan')
    return tf.reduce_mean(batch_losses).numpy()


def train(data, model, settings):
    """Train ``model`` with Adam and a plateau-based learning-rate decay.

    Each epoch runs one optimisation pass over ``data[0]``. When ``data[1]``
    is not None, it also runs one evaluation pass (no weight updates) over
    ``data[1]``.

    The model is saved to ``ckpt.h5`` whenever an epoch's training loss ties
    or beats every previous epoch. After ``settings.patience`` consecutive
    epochs without that, the learning rate is multiplied by
    ``settings.lr_decay``. Setting ``patience`` to None disables the decay.

    Losses are also passed to ``tf.summary.scalar`` (tags ``train_loss`` and
    ``val_loss``). They are only recorded if a default summary writer is
    active, e.g. when called inside ``with writer.as_default():``.

    Args:
        data: ``(train_dataset, val_dataset)``; ``val_dataset`` may be None.
        model: Model to optimise; it is passed through to ``train_step``.
        settings: Provides ``lr``, ``lr_decay``, ``patience`` and ``epochs``.

    Returns:
        ``(model, history)``. ``history`` holds per-epoch ``train``/``val``
        mse lists and the learning rate at the start of each epoch.
        Validation entries are NaN when ``data[1]`` is None.
    """
    train_data, val_data = data[0], data[1]
    pbar = trange(settings.epochs)
    optimiser = tf.keras.optimizers.Adam(settings.lr)
    patience = settings.patience
    lr_decay = settings.lr_decay

    counter = 0
    history = {'train': {'mse': []}, 'val': {'mse': []}, 'lr': []}
    for i in pbar:
        history['lr'].append(optimiser.lr.numpy())

        epoch_loss = _mean_epoch_loss(train_data, model, optimiser, train=True)
        tf.summary.scalar('train_loss', epoch_loss, step=i)
        history['train']['mse'].append(epoch_loss)

        # min() includes this epoch, so the else-branch runs when this epoch
        # ties or beats every earlier one.
        if min(history['train']['mse']) < epoch_loss:
            counter += 1
        else:
            counter = 0
            model.save('ckpt.h5')

        if patience is not None and counter == patience:
            counter = 0
            optimiser.lr.assign(optimiser.lr.numpy() * lr_decay)
            print(f"Learning rate decayed to: {optimiser.lr.numpy()}, Minimum (mse) was: ({min(history['train']['mse'])})")

        # Validation uses its own accumulator so train batches are not mixed in.
        if val_data is not None:
            val_epoch_loss = _mean_epoch_loss(val_data, model, optimiser, train=False)
            tf.summary.scalar('val_loss', val_epoch_loss, step=i)
        else:
            val_epoch_loss = float('nan')
        history['val']['mse'].append(val_epoch_loss)
        pbar.set_description(f"train_mse:  {epoch_loss}, val_mse: {val_epoch_loss}, patience_count: {counter}")

    return model, history


In [108]:
@tf.function
def train_step(batch, model, optimiser, train=True):
    """Run one forward pass and, when ``train`` is True, one optimiser update.

    Args:
        batch: ``(pc1, pc2, flow, mask)`` tuple; ``batch[0]`` and ``batch[1]``
            are fed to the model, ``batch[2]`` and ``batch[3]`` to
            ``compute_loss``.
        model: Called as ``model((pc1, pc2), training=train)``.
        optimiser: Keras optimiser; only used when ``train`` is True.
        train: True for a training step; False only evaluates the loss.

    Returns:
        Scalar mean of ``compute_loss`` over the batch.
    """
    with tf.GradientTape() as tape:
        # Pass `training` explicitly. Without it, Keras does not run layers
        # with train/inference differences (e.g. BatchNormalization, Dropout,
        # if the model has any) in training mode, even during optimisation.
        pred = model((batch[0], batch[1]), training=train)
        loss = compute_loss(pred, batch[2], batch[3])

    if train:
        grads = tape.gradient(loss, model.trainable_variables)
        optimiser.apply_gradients(zip(grads, model.trainable_variables), experimental_aggregate_gradients=False)

    return tf.reduce_mean(loss)

In [109]:
# Configuration for this run; batch_size keeps its Settings default.
settings = Settings(
    name='model.h5',
    epochs=151,
    lr=1e-3,  # 0.0008 noted as an alternative value
    lr_decay=0.8,
    patience=12,
)

In [110]:
# Build the datasets in this cell so it does not depend on variables left over
# from deleted cells. Paths are relative to the working directory (see `!pwd`).
train_dataset = build_dataset('./../assets/train', batch_size=settings.batch_size)
val_dataset = build_dataset('./../assets/test', batch_size=settings.batch_size)

# tf.summary.scalar calls inside train() only record while a default writer is active.
with train_summary_writer.as_default():
    model, history = train((train_dataset, val_dataset), flownet3d, settings)
model.save(settings.name)

  0%|          | 0/151 [01:12<?, ?it/s]


KeyboardInterrupt: 

In [122]:
# Start (or reuse an already running) TensorBoard on the summaries under logs/gradient_tape.
%reload_ext tensorboard
%tensorboard --logdir logs/gradient_tape --host 127.0.0.1

Reusing TensorBoard on port 6007 (pid 11927), started 0:00:55 ago. (Use '!kill 11927' to kill it.)

In [123]:
# NOTE(review): hard-coded PID copied from an earlier TensorBoard message; it is only
# valid for that session (the recorded output shows the process had already exited).
!kill 11927

/bin/sh: 1: kill: No such process

