In [1]:
# Standard library
import itertools
import json
from pathlib import Path

# Third-party (original relative order kept to avoid changing import side effects)
import tensorflow as tf
import numpy as np
import cv2

import matplotlib.pyplot as plt

# from focal_loss import SparseCategoricalFocalLoss
from tensorflow.keras.losses import CategoricalCrossentropy

from sklearn.utils.class_weight import compute_class_weight

# Project
from cell_division.nets.transfer_learning import CNN
from cell_division.nets.custom_layers import w_cel_loss, focal_loss
from auxiliary.data.dataset_cell import CellDataset
from auxiliary import values as v
from auxiliary.utils.colors import bcolors as c

# GPU config
from auxiliary.utils.timer import LoadingBar
from auxiliary.gpu.gpu_tf import (
    increase_gpu_memory, 
    set_gpu_allocator, 
    clear_session
)

# Apply the project's GPU setup helpers once, in the first cell, so they run
# before any dataset or model is created further down the notebook.
# NOTE(review): what each helper changes (presumably TensorFlow memory growth and
# the GPU allocator) is defined in auxiliary.gpu.gpu_tf and is not visible here — confirm there.
increase_gpu_memory()
set_gpu_allocator()

In [2]:
# --- Model input settings -------------------------------------------------
# BATCH_SIZE is handed to both dataset generators and to `fit`;
# INPUT_SHAPE is the model input shape, and its first two entries are also
# used as the generators' resize target.
BATCH_SIZE = 64
INPUT_SHAPE = (50, 50, 3)

# --- Dataset locations (relative to the project-wide `v.data_path`) --------
# Folder holding the images.
img_dir = v.data_path + 'CellDivision/images/'

# Label files for each split. Despite the `_dir` suffix these are CSV file paths.
label_train_dir = v.data_path + 'CellDivision/train.csv'
label_val_dir = v.data_path + 'CellDivision/val.csv'
label_test_dir = v.data_path + 'CellDivision/test.csv'

Dataset (Generators)

Generators do not load the whole set of images into memory up front; instead, they load the images on the fly. This is useful when the dataset is too large to fit into memory.

In [3]:
# Build the training and validation generators with identical settings; only the
# label CSV differs. The training generator is constructed first, then validation.
train_generator, val_generator = (
    CellDataset(
        img_dir,
        labels_csv,
        batch_size=BATCH_SIZE,
        resize=INPUT_SHAPE[:2]
    )
    for labels_csv in (label_train_dir, label_val_dir)
)

Transfer Learning 

In [4]:
# Registry of candidate backbones: maps each architecture name to the
# constructor of the same name in `tf.keras.applications`.
base_models = {
    backbone: getattr(tf.keras.applications, backbone)
    for backbone in (
        'DenseNet121',
        'EfficientNetV2L',
        'EfficientNetV2M',
        'VGG16',
        'ResNet50',
        'InceptionV3',
        'MobileNetV2',
        'NASNetMobile',
    )
}


In [5]:
# Search space for the grid search: each key is one tunable and each list holds
# the candidate values that will be combined exhaustively.
param_grid = {
    # Names of all backbones in `base_models`.
    'base_model': [*base_models],
    # Learning rates passed to `compile`.
    'lr': [1e-3, 1e-2],
    # Forwarded to the CNN wrapper's `fine_tune` argument.
    'fine_tune': [True, False],
    # Loss callables produced by the project's factories (default arguments).
    'loss': [focal_loss(), w_cel_loss()],
    # Head variant, forwarded to `build_top(b_type=...)`.
    'top': ['CAM', 'Standard'],
    # 'class_weight': [None, 'balanced']
}

In [7]:
# Exhaustive grid search: train one model per combination in `param_grid` and
# collect its training history in `results`.
#
# `results` maps a human-readable run id (a plain string) to a dict
# {metric name: [value per epoch]} of Python floats, or to None when the run
# raised. Both key and value are JSON-serialisable, so the results can be
# written to disk without further conversion.
#
# NOTE(review): the saved output of an earlier execution shows a run whose
# softmax outputs became NaN during validation; such runs end up as None here.
# NOTE(review): some Keras applications enforce a minimum input size; with a
# 50x50 INPUT_SHAPE those combinations may fail and be recorded as None — confirm.

# Every combination, in the same order the original nested loops visited them
# (base_model outermost, top innermost). The loss is paired with its index in
# the grid so two losses sharing a function name still get distinct run ids.
search_space = list(itertools.product(
    param_grid['base_model'],
    param_grid['lr'],
    param_grid['fine_tune'],
    enumerate(param_grid['loss']),
    param_grid['top'],
))

bar = LoadingBar(len(search_space))

results = {}

for base_model, lr, fine_tune, (loss_idx, loss), top in search_space:
    # Use the callable's name rather than its repr: the repr embeds a memory
    # address, which is unreadable and changes on every kernel restart.
    loss_name = getattr(loss, '__name__', type(loss).__name__)
    run_id = (
        f'{base_model}|lr={lr}|fine_tune={fine_tune}'
        f'|loss={loss_idx}:{loss_name}|top={top}'
    )
    print(f'{c.OKGREEN}Model: {base_model} - LR: {lr} - Fine Tune: {fine_tune} - Loss: {loss_name} - Top: {top}{c.ENDC}')

    try:
        model = CNN(
            base=base_models[base_model],
            n_classes=3,
            input_shape=INPUT_SHAPE,
            fine_tune=fine_tune
        )
        model.build_top(activation='softmax', b_type=top)
        model.compile(lr=lr, loss=loss)
        model.fit(
            train_generator,
            val_generator,
            epochs=100,
            batch_size=BATCH_SIZE,
            save=False,
            verbose=1
        )

        # `model.model.history` is presumably the Keras History callback; its
        # `.history` attribute holds the per-epoch metrics. Fall back to the
        # object itself if it is already a plain mapping.
        history = model.model.history
        history = getattr(history, 'history', history)
        # Cast to built-in floats: logged values may be numpy scalars, which
        # the json module cannot serialise.
        results[run_id] = {
            metric: [float(value) for value in values]
            for metric, values in history.items()
        }
    except Exception as e:
        # Deliberate best-effort: one failing configuration must not abort the
        # whole search. The failure is reported and recorded as None.
        print(f'{c.FAIL}Error: {e}{c.ENDC}')
        results[run_id] = None

    # Release the finished run's resources before building the next model.
    clear_session()
    bar.update()

bar.end()

[92mModel: DenseNet121 - LR: 0.001 - Fine Tune: True - Loss: <function focal_loss.<locals>.focal_loss_fixed at 0x762a9229c280> - Top: CAM[0m
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 5: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 8: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 11: ReduceLROnPlateau reducing learning rate to 1.0000000656873453e-06.
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 14: early stopping
[                                                  ] 0.38%[92mModel: DenseNet121 - LR: 0.001 - Fine Tune: True - Loss: <function focal_loss.<locals>.focal_loss_fixed at 0x762a9229c280> - Top: Standard[0m
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 6: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 9: ReduceLROnPlateau reducing 

InvalidArgumentError: Graph execution error:

Detected at node 'assert_greater_equal/Assert/AssertGuard/Assert' defined at (most recent call last):
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
      app.start()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
      self.io_loop.start()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
      self.asyncio_loop.run_forever()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
      self._run_once()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
      handle._run()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
      await result
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
      await super().execute_request(stream, ident, parent)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
      res = shell.run_cell(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
      result = self._run_cell(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
      result = runner(coro)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
      coro.send(None)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_46368/11548542.py", line 22, in <module>
      model.fit(
    File "/home/imarcoss/ht_morphogenesis/cell_division/nets/transfer_learning.py", line 93, in fit
      )
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1420, in fit
      val_logs = self.evaluate(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1716, in evaluate
      tmp_logs = self.test_function(iterator)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1525, in test_function
      return step_function(self, iterator)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1514, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1507, in run_step
      outputs = model.test_step(data)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1474, in test_step
      return self.compute_metrics(x, y, y_pred, sample_weight)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 957, in compute_metrics
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 459, in update_state
      metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/metrics_utils.py", line 70, in decorated
      update_op = update_state_fn(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/metrics.py", line 178, in update_state_fn
      return ag_update_state(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/metrics.py", line 2347, in update_state
      return metrics_utils.update_confusion_matrix_variables(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/metrics_utils.py", line 602, in update_confusion_matrix_variables
      tf.compat.v1.assert_greater_equal(
Node: 'assert_greater_equal/Assert/AssertGuard/Assert'
Detected at node 'assert_greater_equal/Assert/AssertGuard/Assert' defined at (most recent call last):
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
      app.start()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
      self.io_loop.start()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
      self.asyncio_loop.run_forever()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
      self._run_once()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
      handle._run()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
      await result
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
      await super().execute_request(stream, ident, parent)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
      res = shell.run_cell(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
      result = self._run_cell(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
      result = runner(coro)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
      coro.send(None)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_46368/11548542.py", line 22, in <module>
      model.fit(
    File "/home/imarcoss/ht_morphogenesis/cell_division/nets/transfer_learning.py", line 93, in fit
      )
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1420, in fit
      val_logs = self.evaluate(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1716, in evaluate
      tmp_logs = self.test_function(iterator)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1525, in test_function
      return step_function(self, iterator)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1514, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1507, in run_step
      outputs = model.test_step(data)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 1474, in test_step
      return self.compute_metrics(x, y, y_pred, sample_weight)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/training.py", line 957, in compute_metrics
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 459, in update_state
      metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/metrics_utils.py", line 70, in decorated
      update_op = update_state_fn(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/metrics.py", line 178, in update_state_fn
      return ag_update_state(*args, **kwargs)
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/metrics.py", line 2347, in update_state
      return metrics_utils.update_confusion_matrix_variables(
    File "/home/imarcoss/mambaforge/envs/py310ml/lib/python3.10/site-packages/keras/utils/metrics_utils.py", line 602, in update_confusion_matrix_variables
      tf.compat.v1.assert_greater_equal(
Node: 'assert_greater_equal/Assert/AssertGuard/Assert'
2 root error(s) found.
  (0) INVALID_ARGUMENT:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (model/prediction_layer/Softmax:0) = ] [[nan nan nan]...] [y (Cast_3/x:0) = ] [0]
	 [[{{node assert_greater_equal/Assert/AssertGuard/Assert}}]]
	 [[assert_less_equal/Assert/AssertGuard/pivot_f/_13/_33]]
  (1) INVALID_ARGUMENT:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (model/prediction_layer/Softmax:0) = ] [[nan nan nan]...] [y (Cast_3/x:0) = ] [0]
	 [[{{node assert_greater_equal/Assert/AssertGuard/Assert}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_2645114]

In [None]:
# Print the grid-search results and persist them as JSON.
#
# The json module only accepts str/int/float/bool/None dictionary keys and
# built-in value types, so `results` is normalised first:
#   * tuple keys (possibly containing loss callables) become ' | '-joined strings;
#   * values exposing a `.history` attribute are replaced by that attribute;
#   * remaining numpy scalars are converted through `default=float`.
# Diverged runs may contain NaN, which json writes as the non-standard token `NaN`.
def _json_key(key):
    """Return a JSON-compatible string for a `results` key.

    Strings are returned unchanged. Tuples are joined with ' | ', using each
    part's `__qualname__` when it has one (e.g. functions) and `str()` otherwise.
    """
    if isinstance(key, str):
        return key
    parts = key if isinstance(key, tuple) else (key,)
    return ' | '.join(getattr(part, '__qualname__', str(part)) for part in parts)


results_serializable = {
    _json_key(key): getattr(history, 'history', history)
    for key, history in results.items()
}
# Guard against two different runs collapsing onto the same string key.
assert len(results_serializable) == len(results), 'distinct runs mapped to the same JSON key'

print(json.dumps(results_serializable, indent=4, default=float))

# Save
results_path = Path('../cell_division/results/grid_search_cnn.json')
# Create the output folder if needed so a long search is not lost to a missing directory.
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, 'w') as f:
    json.dump(results_serializable, f, default=float)