In [1]:
import os

import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
print(physical_devices)
assert len(physical_devices) == 3
# to allow other tensorflow processes to use the gpu
# https://stackoverflow.com/a/60699372/7989988
tf.config.experimental.set_memory_growth(physical_devices[0], True)
tf.config.experimental.set_memory_growth(physical_devices[1], True)
tf.config.experimental.set_memory_growth(physical_devices[2], True)

import numpy as np
from tensorflow import keras
from tensorflow.keras import Model, Input, layers
from IPython.display import display
import tensorflow_datasets as tfds
import time
import matplotlib.pyplot as plt
import enlighten
import tensorflow_probability as tfp
from dotmap import DotMap


import models
import training
import datasets
import vizualization
import schedules

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU')]


In [2]:
config = DotMap({
    'ds': 'mnist_4',
    'distributed': False,
    'minibatch_size': 32,
    'n_steps': 20000,
    'test_size': 300,
    'test_minibatch_size': 100,
    'test_interval': 100,
    'test_n_shuf': [1, 8, 16, 32],
    'test_n_seq': [1, 16, 32, 48],
    'test_autoregressive': False,
    'noise_fraction': 0,
    'display_images': False,
    'display_image_interval': 2000,
    'dont_display_until_loss': 0.45,
    'bg_color': [1.0, 0.4, 0.6],
    'lr_schedule': None,
    'lr_warmup_steps': 300,
    'max_lr': 0.0001,
    'min_lr': 0.0001,
    'grad_accum_steps': None, #['exponential', 1, 4],
    'max_accum_steps': 4,
    'use_wandb': True,
    'wandb_log_interval': 10,
    'loss_window_size': 500,
})

# need to change for multiworkerstrategy
if config.distributed:
    config.num_devices = len(physical_devices)
else:
    config.num_devices = 1
config.global_batch_size = config.minibatch_size * config.num_devices

In [3]:
ds_default = DotMap({
    'buffer_size': 60000,
    'name': 'mnist',
    'n_colors': 4,
    'image_size': (28, 28),
})
    
ds_configs = DotMap({
    'mnist_4': {
        'buffer_size': 60000,
        'name': 'mnist',
        'n_colors': 4,
        'image_size': (28, 28),
    },
    'mnist_4_7x7': {
        'buffer_size': 60000,
        'name': 'mnist',
        'n_colors': 4,
        'n_color_dims': 1,
        'rescale': (7, 7),
    },
    'mnist_binary_7x7': {
        'buffer_size': 60000,
        'name': 'mnist',
        'n_colors': 2,
        'rescale': (7, 7),
    },
})

config.dataset = ds_configs[config.ds]

dataset, metadata = tfds.load(config.dataset.name, with_info=True, as_supervised=True)

ds_train_original = dataset['train']
ds_test_original = dataset['test']

centroids = datasets.find_centroids(config, ds_train_original)
gamma_dist, gamma_name = datasets.gamma_distribution_7x7()
ds = datasets.Datasets(config, ds_train_original, ds_test_original, centroids, gamma_dist)
viz = vizualization.Viz(config, ds, centroids)
ds_train, ds_test = ds.make_datasets()

# model config
config.model = DotMap({
    'n_colors': config.dataset.n_colors,
    'n_enc_a_layers': 3,
    'n_enc_b_layers': 3,
    'ffl_dim': 128,
    'embd_dim': 512,
    'n_dec_layers': 3,
    'dec_dim': 600,
    'n_heads': 4,
    'dropout_rate': 0.1,
    'use_idxs_input': True,
    'architecture': 'anp',
    'position_embedding': 'pos_enc',
    'activation': 'swish',
})

if config.distributed == False:
    strategy = tf.distribute.get_strategy()
else:
    strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = models.transformer(config.model)
    # Adam params taken from the linked notebook
    optimizer = keras.optimizers.Adam(learning_rate=config.max_lr)

ds_train_dist = strategy.experimental_distribute_dataset(ds_train)

config.training_mode = 'combination'


centroids.shape (4,)


ValueError: in user code:

    /home/clarkemaxw/msc/cgt-mnist/datasets.py:97 quantize  *
        d = squared_euclidean_distance(sequence, self.centroids) # (height * width, centroids)
    /home/clarkemaxw/msc/cgt-mnist/datasets.py:77 squared_euclidean_distance  *
        ab = tf.linalg.matmul(a, b)
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper  **
        return target(*args, **kwargs)
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py:3654 matmul
        return gen_math_ops.mat_mul(
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py:5712 mat_mul
        _, _, _op, _outputs = _op_def_library._apply_op_helper(
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py:748 _apply_op_helper
        op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:599 _create_op_internal
        return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3561 _create_op_internal
        ret = Operation(
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:2041 __init__
        self._c_op = _create_c_op(self._graph, node_def, inputs,
    /home/clarkemaxw/.cache/pypoetry/virtualenvs/msc-r6Gz9mJU-py3.8/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1883 _create_c_op
        raise ValueError(str(e))

    ValueError: Shape must be rank 2 but is rank 1 for '{{node MatMul}} = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false](args_0, transpose)' with input shapes: [784,1], [4].


In [None]:

model_name = 'models/cuda10-noise-0'
model = keras.models.load_model(model_name)

evalr = training.Evaluator(config, model, optimizer, viz, ds, ds_train_dist, ds_test)

evalr.process_batch()
