# Abstractor Dev

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, Model
import sklearn.metrics
from sklearn.model_selection import train_test_split

import sys; sys.path.append('../..')
import utils

2023-07-12 15:45:07.157056: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-07-12 15:45:07.217831: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
#%env "WANDB_NOTEBOOK_NAME" "abstractor_SCAN.ipynb"
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mawni00[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
def create_callbacks(monitor='loss'):
    callbacks = [wandb.keras.WandbMetricsLogger(log_freq='epoch'),]
    return callbacks

## Dataset

In [4]:
import urllib.request 
from collections import Counter
data_url = 'https://raw.githubusercontent.com/brendenlake/SCAN/master/tasks.txt'
data = urllib.request.urlopen(data_url)

input_voc = Counter()
input_voc.update(['<pad>'])
output_voc = Counter()
output_voc.update(['<PAD>'])
input_len = output_len = input_max = output_max = 0

records=0
for line in data: 
    command = line.decode("utf-8") 
    _, input, output = command.replace('OUT:', 'IN:').split('IN:')
    input_voc.update(input.split())
    output_voc.update(output.split())
    input_max = max([input_max, len(input.split())])
    output_max = max([output_max, len(output.split())])
    input_len = input_len + len(input.split())
    output_len = output_len + len(output.split())
    records = records + 1

print('Number of records: %d' % records)
print('Maximum lengths: %d, %d' % (input_max, output_max))
print('Average lengths: %.2f, %.2f' % (input_len/records, output_len/records))
print(input_voc)
print(output_voc)

words = [w for w in input_voc]
commands = [c for c in output_voc]
word2index = {words[i]:i for i in range(len(words))}
command2index = {commands[i]:i for i in range(len(commands))}

print('Input vocabulary size: %d' % len(words))
print(words)
print('Output vocabulary size: %d' % len(commands))
print(commands)

Number of records: 20910
Maximum lengths: 9, 48
Average lengths: 7.25, 14.32
Counter({'right': 18405, 'left': 18405, 'thrice': 13906, 'twice': 13906, 'opposite': 12270, 'around': 12270, 'after': 10404, 'and': 10404, 'walk': 8589, 'run': 8589, 'look': 8589, 'jump': 8589, 'turn': 7362, '<pad>': 1})
Counter({'I_TURN_RIGHT': 85890, 'I_TURN_LEFT': 85890, 'I_RUN': 31902, 'I_WALK': 31902, 'I_LOOK': 31902, 'I_JUMP': 31902, '<PAD>': 1})
Input vocabulary size: 14
['<pad>', 'walk', 'opposite', 'right', 'thrice', 'after', 'run', 'turn', 'around', 'twice', 'look', 'left', 'and', 'jump']
Output vocabulary size: 7
['<PAD>', 'I_TURN_RIGHT', 'I_RUN', 'I_WALK', 'I_TURN_LEFT', 'I_LOOK', 'I_JUMP']


In [5]:
n = records
length_max = 30

BEGIN_INPUT = len(input_voc) # token for 'beginning of input'
END_INPUT = len(input_voc)+1 # token for 'end of input'
BEGIN_COMMAND = len(output_voc) # token for 'beginning of command'
END_COMMAND = len(output_voc)+1 # token for 'end of command'

sources = np.array(n*(length_max+2)*[0]).reshape(n, length_max+2)
targets = np.array(n*(length_max+2)*[0]).reshape(n, length_max+2)
data = urllib.request.urlopen(data_url)

rec = 0
longones = 0
for line in data: 
    command = line.decode("utf-8") 
    _, input, output = command.replace('OUT:', 'IN:').split('IN:')
    input_words = input.split()
    output_commands = output.split()
    source = [word2index[w] for w in input_words]
    source.insert(0, BEGIN_INPUT)
    source.append(END_INPUT)
    target = [command2index[c] for c in output_commands]
    if len(target) > length_max:
        longones = longones + 1
        continue
    target.insert(0, BEGIN_COMMAND)
    target.append(END_COMMAND)
    for j in range(len(source)):
        sources[rec, j] = source[j]
    for j in range(len(target)):
        targets[rec, j] = target[j]
    rec = rec + 1

print('processed %d commands' % rec)
print('processed %d (%.2f%%) over length %d' % (longones, 100*longones/(rec+longones), length_max))

sources = sources[:rec,:]
targets = targets[:rec,:]


processed 19758 commands
processed 1152 (5.51%) over length 30


In [6]:
def display_sentence(input_seq):
    sentence = ''
    for i in range(len(input_seq)):
        if input_seq[i] == BEGIN_INPUT:
            sentence = sentence + 'BEGIN_INPUT'
        elif input_seq[i] == END_INPUT:
            sentence = sentence + ' END_INPUT'
            break
        else: 
            sentence = sentence + ' ' + words[input_seq[i]]
    return sentence

def display_command(output_seq):
    command = ''
    for i in range(len(output_seq)):
        if output_seq[i] == BEGIN_COMMAND:
            command = command + 'BEGIN_COMMAND'
        elif output_seq[i] == END_COMMAND:
            command = command + ' END_COMMAND'
            break
        else: 
            command = command + ' ' + commands[output_seq[i]]
    return command



In [7]:
print(display_sentence(sources[42,]))
print(display_command(targets[42,]))
print(sources[42,])
print(targets[42,])

BEGIN_INPUT run opposite left thrice and jump thrice END_INPUT
BEGIN_COMMAND I_TURN_LEFT I_TURN_LEFT I_RUN I_TURN_LEFT I_TURN_LEFT I_RUN I_TURN_LEFT I_TURN_LEFT I_RUN I_JUMP I_JUMP I_JUMP END_COMMAND
[14  6  2 11  4 12 13  4 15  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0]
[7 4 4 2 4 4 2 4 4 2 6 6 6 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [8]:
sentences_train, sentences_test, commands_train, commands_test = train_test_split(sources, targets, test_size=0.25)

source_train = sentences_train[:,:-1]
target_train = commands_train[:,:-1]
labels_train = commands_train[:,1:]

source_test = sentences_test[:,:-1]
target_test = commands_test[:,:-1]
labels_test = commands_test[:,1:]

In [9]:
print(display_sentence(source_test[100,]))
print(display_command(target_test[100,]))
print(display_command(labels_test[100,]))

BEGIN_INPUT look opposite left thrice after turn opposite right END_INPUT
BEGIN_COMMAND I_TURN_RIGHT I_TURN_RIGHT I_TURN_LEFT I_TURN_LEFT I_LOOK I_TURN_LEFT I_TURN_LEFT I_LOOK I_TURN_LEFT I_TURN_LEFT I_LOOK END_COMMAND
 I_TURN_RIGHT I_TURN_RIGHT I_TURN_LEFT I_TURN_LEFT I_LOOK I_TURN_LEFT I_TURN_LEFT I_LOOK I_TURN_LEFT I_TURN_LEFT I_LOOK END_COMMAND


## Transformer

In [10]:
import tensorflow_models as tfm


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [11]:
dropout_rate = 0.1
num_layers = 2
num_attention_heads = 4
intermediate_ff_size = 64
embedding_dim = 32

source_embedder = tf.keras.layers.Embedding(len(word2index), embedding_dim, name='source_embedder')
target_embedder = tf.keras.layers.Embedding(len(command2index), embedding_dim, name='target_embedder')

encoder = tfm.nlp.models.TransformerEncoder(
    num_layers=num_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_ff_size,
    activation='relu',
    dropout_rate=dropout_rate,
    attention_dropout_rate=0.0,
    use_bias=False,
    norm_first=False,
    norm_epsilon=1e-06,
    intermediate_dropout=0.0,
    name='encoder'
)

decoder = tfm.nlp.models.TransformerDecoder(
    num_layers=num_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_ff_size,
    activation='relu',
    dropout_rate=dropout_rate,
    attention_dropout_rate=0.0,
    use_bias=False,
    norm_first=False,
    norm_epsilon=1e-06,
    intermediate_dropout=0.0,
    name='decoder'
)

final_layer = layers.Dense(len(command2index), name='final_layer')

input_seq = layers.Input(shape=source_train.shape[1:])
target_seq = layers.Input(shape=target_train.shape[1:])

x = source_embedder(input_seq)
y = target_embedder(target_seq)

E = encoder(x)
D = decoder(y, E)

pred = final_layer(D)

transformer = tf.keras.Model(inputs=[input_seq, target_seq], outputs=pred, name='transformer')


2023-07-12 15:43:57.350874: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38217 MB memory:  -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:5e:00.0, compute capability: 8.0


In [12]:
transformer.summary()

Model: "transformer"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 31)]         0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 31)]         0           []                               
                                                                                                  
 source_embedder (Embedding)    (None, 31, 32)       448         ['input_1[0][0]']                
                                                                                                  
 target_embedder (Embedding)    (None, 31, 32)       224         ['input_2[0][0]']                
                                                                                        

In [13]:
loss = tf.keras.losses.SparseCategoricalCrossentropy()
create_optimizer = lambda: tf.keras.optimizers.Adam()

transformer.compile(loss=loss, optimizer=create_optimizer())

In [14]:
transformer.fit((source_train, target_train), labels_train)

2023-07-12 15:42:52.183546: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:637] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-07-12 15:42:55.346584: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x2ac1c02687b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-07-12 15:42:55.346651: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB, Compute Capability 8.0
2023-07-12 15:42:55.413421: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-07-12 15:42:56.151647: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.1.1 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade

InternalError: Graph execution error:

Detected at node 'StatefulPartitionedCall_67' defined at (most recent call last):
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/tmp.9nY6gh1B1h/ipykernel_21791/756575719.py", line 1, in <module>
      transformer.fit((source_train, target_train), labels_train)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1685, in fit
      tmp_logs = self.train_function(iterator)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1284, in train_function
      return step_function(self, iterator)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1268, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1249, in run_step
      outputs = model.train_step(data)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1054, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 543, in minimize
      self.apply_gradients(grads_and_vars)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1174, in apply_gradients
      return super().apply_gradients(grads_and_vars, name=name)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 650, in apply_gradients
      iteration = self._internal_apply_gradients(grads_and_vars)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1200, in _internal_apply_gradients
      return tf.__internal__.distribute.interim.maybe_merge_call(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1250, in _distributed_apply_gradients_fn
      distribution.extended.update(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1245, in apply_grad_to_update_var
      return self._update_step_xla(grad, var, id(self._var_key(var)))
Node: 'StatefulPartitionedCall_67'
RET_CHECK failure (tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:618) dnn != nullptr 
	 [[{{node StatefulPartitionedCall_67}}]] [Op:__inference_train_function_9928]

In [10]:
model = tf.keras.Sequential([layers.Dense(1)])
X = np.random.random((59, 1))
y = 2*X + 1
y = np.squeeze(y)

model(X);

model.compile(loss='mse', optimizer='adam')
model.fit(X, y)

2023-07-12 15:45:26.799388: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38217 MB memory:  -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:5e:00.0, compute capability: 8.0
2023-07-12 15:45:28.601630: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x2ba8a00096b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-07-12 15:45:28.601679: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB, Compute Capability 8.0
2023-07-12 15:45:28.612395: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-07-12 15:45:29.317095: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.1.1 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major

InternalError: Graph execution error:

Detected at node 'StatefulPartitionedCall' defined at (most recent call last):
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/tmp.9nY6gh1B1h/ipykernel_22775/735454358.py", line 9, in <module>
      model.fit(X, y)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1685, in fit
      tmp_logs = self.train_function(iterator)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1284, in train_function
      return step_function(self, iterator)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1268, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1249, in run_step
      outputs = model.train_step(data)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/engine/training.py", line 1054, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 543, in minimize
      self.apply_gradients(grads_and_vars)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1174, in apply_gradients
      return super().apply_gradients(grads_and_vars, name=name)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 650, in apply_gradients
      iteration = self._internal_apply_gradients(grads_and_vars)
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1200, in _internal_apply_gradients
      return tf.__internal__.distribute.interim.maybe_merge_call(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1250, in _distributed_apply_gradients_fn
      distribution.extended.update(
    File "/gpfs/gibbs/project/lafferty/ma2393/conda_envs/relml/lib/python3.8/site-packages/keras/optimizers/optimizer.py", line 1245, in apply_grad_to_update_var
      return self._update_step_xla(grad, var, id(self._var_key(var)))
Node: 'StatefulPartitionedCall'
RET_CHECK failure (tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:618) dnn != nullptr 
	 [[{{node StatefulPartitionedCall}}]] [Op:__inference_train_function_517]