In [None]:
pip install tensorflow keras keras-unet-collection scikit-image scikit-learn opencv-python matplotlib pandas

In [3]:
# ===========================================================
# Comparativo: U-Net variants (keras_unet_collection)
#   - unet_2d, unet_plus_2d, r2_unet_2d, resunet_a_2d
#   - att_unet_2d (Attention U-Net)
#   - vnet_2d (2D adaptado p/ 1 canal Sigmoid)
# Dataset: TC crânio (grayscale), máscara binária (0/1)
# ===========================================================
import os, glob, cv2, random
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix, roc_auc_score, average_precision_score,
    precision_score, recall_score, f1_score
)
import pandas as pd
from keras_unet_collection import models

# ----------------------------
# Configs gerais
# ----------------------------
SEED = 42
np.random.seed(SEED); tf.random.set_seed(SEED)

gpus = tf.config.list_physical_devices("GPU")
for g in gpus:
    try: tf.config.experimental.set_memory_growth(g, True)
    except: pass

IMG_SIZE     = (256, 256)
INPUT_SHAPE  = (256, 256, 1)
BATCH_SIZE   = 2
EPOCHS       = 50
LR           = 1e-4

PATH_IMG  = r"C:\Users\Daniel\Documents\Brain_Stroke_CT_Dataset\Ischemia\PNG"
PATH_MASK = r"C:\Users\Daniel\Documents\Brain_Stroke_CT_Dataset\Ischemia\MASKS"

os.makedirs("Models", exist_ok=True)

# ----------------------------
# Utils: parear, carregar, tf.data
# ----------------------------
def build_pairs(path_img, path_mask):
    exts = ("*.png","*.jpg","*.jpeg","*.bmp","*.tif","*.tiff")
    all_imgs = []
    for e in exts: all_imgs += glob.glob(os.path.join(path_img, e))
    all_imgs = sorted(all_imgs)
    img_files, mask_files = [], []
    for f in all_imgs:
        stem = os.path.splitext(os.path.basename(f))[0]
        found = None
        for e in exts:
            cand = glob.glob(os.path.join(path_mask, stem + e[1:]))
            if cand: found = cand[0]; break
        if not found:
            cand = glob.glob(os.path.join(path_mask, stem + ".*"))
            if cand: found = cand[0]
        if found:
            img_files.append(f); mask_files.append(found)
    return img_files, mask_files

def load_and_preprocess(img_paths, mask_paths, img_size):
    X, Y = [], []
    for ip, mp in zip(img_paths, mask_paths):
        im = cv2.imread(ip, cv2.IMREAD_GRAYSCALE)
        ms = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if im is None or ms is None: continue
        im = cv2.resize(im, img_size, interpolation=cv2.INTER_AREA)
        ms = cv2.resize(ms, img_size, interpolation=cv2.INTER_NEAREST)
        im = (im.astype(np.float32)/255.0)[..., None]
        m_bin = (ms > 0).astype(np.float32)[..., None]
        X.append(im); Y.append(m_bin)
    return np.array(X, np.float32), np.array(Y, np.float32)

def tf_augment(img, mask):
    if tf.random.uniform(()) > 0.5:
        img  = tf.image.flip_left_right(img)
        mask = tf.image.flip_left_right(mask)
    if tf.random.uniform(()) > 0.5:
        img  = tf.image.flip_up_down(img)
        mask = tf.image.flip_up_down(mask)
    k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
    img  = tf.image.rot90(img, k)
    mask = tf.image.rot90(mask, k)
    return img, mask

def make_ds(X, Y, batch_size, augment=False, shuffle=True):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    if shuffle: ds = ds.shuffle(len(X), seed=SEED, reshuffle_each_iteration=True)
    if augment: ds = ds.map(tf_augment, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# ----------------------------
# Métricas/Loss binárias
# ----------------------------
def dice_coef(y_true, y_pred, smooth=1e-6):
    y_true_f = K.flatten(tf.cast(y_true, tf.float32))
    y_pred_f = K.flatten(tf.cast(y_pred, tf.float32))
    inter = K.sum(y_true_f * y_pred_f)
    return (2.*inter + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def iou_coef(y_true, y_pred, smooth=1e-6):
    y_true_f = K.flatten(tf.cast(y_true, tf.float32))
    y_pred_f = K.flatten(tf.cast(y_pred, tf.float32))
    inter = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - inter
    return (inter + smooth) / (union + smooth)

def dice_loss(y_true, y_pred): return 1.0 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred, bce_w=1.0, dice_w=1.0):
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dl  = dice_loss(y_true, y_pred)
    return bce_w*bce + dice_w*dl

# ----------------------------
# Construtores dos 6 modelos
# (Atenção: nomes de parâmetros podem variar por versão da lib.
#  Se aparecer erro de assinatura, me mande a mensagem que eu ajusto.)
# ----------------------------
def build_unet_2d():
    return models.unet_2d(
        input_size=INPUT_SHAPE,
        filter_num=[64, 128, 256, 512],
        n_labels=1,
        stack_num_down=2,
        stack_num_up=2,
        activation='ReLU',
        output_activation='Sigmoid',
        batch_norm=True,
        pool=False,
        unpool='bilinear',
        name='unet2d'
    )

def build_unet_plus_2d():
    return models.unet_plus_2d(
        input_size=INPUT_SHAPE,
        filter_num=[64, 128, 256, 512],
        n_labels=1,
        stack_num_down=2,
        stack_num_up=2,
        activation='ReLU',
        output_activation='Sigmoid',
        batch_norm=True,
        pool=False,
        unpool='bilinear',
        deep_supervision=False,
        name='unetpp'
    )

def build_r2_unet_2d():
    return models.r2_unet_2d(
        input_size=INPUT_SHAPE,
        filter_num=[64, 128, 256, 512],
        n_labels=1,
        stack_num_down=2,
        stack_num_up=2,
        activation='ReLU',
        output_activation='Sigmoid',
        batch_norm=True,
        pool=False,
        unpool='bilinear',
        recur_num=2,   # se der erro, troque para re_num=2
        name='r2unet'
    )

def build_resunet_a_2d():
    return models.resunet_a_2d(
        input_size=INPUT_SHAPE,
        filter_num=[64, 128, 256],      # depth=3
        dilation_num=[1, 2, 4],         # << ADICIONADO
        n_labels=1,
        activation='ReLU',
        output_activation='Sigmoid',
        batch_norm=True,
        pool=False,
        unpool='bilinear',
        name='resuneta'
    )

def build_att_unet_2d():
    return models.att_unet_2d(
        input_size=INPUT_SHAPE,
        filter_num=[64, 128, 256, 512],
        n_labels=1,
        stack_num_down=2,
        stack_num_up=2,
        activation='ReLU',
        atten_activation='ReLU',
        attention='add',
        output_activation='Sigmoid',
        batch_norm=True,
        pool=False,
        unpool='bilinear',
        name='attunet'
    )

def build_vnet_2d():
    # V-Net original usa PReLU; aqui mantemos PReLU, mas com saída 1 canal Sigmoid para comparação justa.
    return models.vnet_2d(
        input_size=INPUT_SHAPE,
        filter_num=[16, 32, 64, 128, 256],
        n_labels=1,                # << 1 canal (binário)
        res_num_ini=1, res_num_max=3,
        activation='PReLU',
        output_activation='Sigmoid',
        batch_norm=True,
        pool=False,                # down por conv stride
        unpool=False,              # up por conv transpose
        name='vnet'
    )
    
def build_swin_unet_2d():
    return models.swin_unet_2d(
        input_size=INPUT_SHAPE,
        filter_num_begin=64,       # número inicial de filtros
        n_labels=1,                # binário (1 classe)
        depth=4,                   # número de estágios do encoder/decoder
        stack_num_down=2,          # convs por bloco no encoder
        stack_num_up=2,            # convs por bloco no decoder
        patch_size=(2, 2),         # tamanho do patch de entrada
        num_heads=[4, 8, 16, 32],  # cabeças de atenção por estágio
        window_size=(7, 7),        # tamanho da janela de self-attention
        num_mlp=512,               # tamanho do MLP na atenção
        output_activation='Sigmoid',  # saída binária
        shift_window=True,         # janela deslocada (como no paper Swin)
        name='swin_unet'
    )

BUILDERS = {
    "resunet_a_2d" : build_resunet_a_2d,
    "vnet_2d"      : build_vnet_2d,
    "att_unet_2d"  : build_att_unet_2d,
    "swin_unet_2d" : build_swin_unet_2d,
    "unet_2d"      : build_unet_2d,
    "unet_plus_2d" : build_unet_plus_2d,
    "r2_unet_2d"   : build_r2_unet_2d,
}

# ----------------------------
# Treino + avaliação + plots
# ----------------------------
def train_and_eval(model_name, X_train, y_train, X_val, y_val, X_test, y_test):
    print(f"\n==================== {model_name} ====================")
    model = BUILDERS[model_name]()
    model.compile(optimizer=Adam(learning_rate=LR),
                  loss=bce_dice_loss,
                  metrics=['accuracy', dice_coef, iou_coef])
    model.summary()

    ckpt_path = os.path.join("Models", f"{model_name}_best.keras")
    cb = [
        ModelCheckpoint(ckpt_path, monitor="val_loss", save_best_only=True, mode="min"),
        EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
        ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=4, verbose=1, min_lr=1e-6),
    ]

    train_ds = make_ds(X_train, y_train, BATCH_SIZE, augment=True,  shuffle=True)
    val_ds   = make_ds(X_val,   y_val,   BATCH_SIZE, augment=False, shuffle=False)

    history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=cb, verbose=1)

    # Curvas
    plt.figure(figsize=(12,4))
    plt.suptitle(model_name)
    plt.subplot(1,3,1); plt.plot(history.history['loss'], label='train'); plt.plot(history.history['val_loss'], label='val'); plt.title('Loss'); plt.legend()
    plt.subplot(1,3,2); plt.plot(history.history['dice_coef'], label='train'); plt.plot(history.history['val_dice_coef'], label='val'); plt.title('Dice'); plt.legend()
    plt.subplot(1,3,3); plt.plot(history.history['iou_coef'], label='train'); plt.plot(history.history['val_iou_coef'], label='val'); plt.title('IoU'); plt.legend()
    plt.tight_layout(); plt.show()

    # Teste
    print("📊 Avaliação no TESTE...")
    y_prob = model.predict(X_test, batch_size=BATCH_SIZE, verbose=0)     # (N,H,W,1)
    y_pred = (y_prob > 0.5).astype(np.uint8)
    y_true = (y_test > 0.5).astype(np.uint8)

    yt  = y_true.flatten().astype(np.uint8)
    yp  = y_pred.flatten().astype(np.uint8)
    ypf = y_prob.flatten().astype(np.float32)

    tn, fp, fn, tp = confusion_matrix(yt, yp).ravel()
    accuracy    = (tp + tn) / (tp + tn + fp + fn + 1e-7)
    precision   = precision_score(yt, yp, zero_division=0)
    sensitivity = recall_score(yt, yp, zero_division=0)
    specificity = tn / (tn + fp + 1e-7)
    f1          = f1_score(yt, yp, zero_division=0)
    try:
        auc = roc_auc_score(yt, ypf)
        ap  = average_precision_score(yt, ypf)
    except Exception:
        auc, ap = float('nan'), float('nan')

    inter = np.sum((yt == 1) & (yp == 1))
    dice = (2*inter + 1e-7) / (np.sum(yt) + np.sum(yp) + 1e-7)
    union = (np.sum(yt) + np.sum(yp) - inter)
    iou  = (inter + 1e-7) / (union + 1e-7)

    print(f"\n🔎 {model_name} - Métricas no TESTE:")
    print(f"Dice        = {dice:.4f}")
    print(f"IoU         = {iou:.4f}")
    print(f"Accuracy    = {accuracy:.4f}")
    print(f"Precision   = {precision:.4f}")
    print(f"Sensitivity = {sensitivity:.4f}")
    print(f"Specificity = {specificity:.4f}")
    print(f"F1-Score    = {f1:.4f}")
    print(f"AUC         = {auc:.4f}")
    print(f"AP          = {ap:.4f}")

    # Overlays
    def show_examples(num=4):
        num = min(num, len(X_test))
        for i in range(num):
            img = X_test[i, ..., 0]
            gt  = y_true[i, ..., 0]
            pr  = y_pred[i, ..., 0]
            plt.figure(figsize=(12,4))
            plt.subplot(1,3,1); plt.imshow(img, cmap='gray'); plt.title('Imagem'); plt.axis('off')
            plt.subplot(1,3,2); plt.imshow(gt,  cmap='gray');  plt.title('GT'); plt.axis('off')
            plt.subplot(1,3,3); plt.imshow(img, cmap='gray'); plt.imshow(pr, cmap='jet', alpha=0.5); plt.title('Pred'); plt.axis('off')
            plt.tight_layout(); plt.show()
    print("\n🖼️ Exemplos de predição (overlays):")
    show_examples(4)

    # Salvar
    model.save(os.path.join("Models", f"{model_name}_final.keras"), save_traces=True)
    model.save_weights(os.path.join("Models", f"{model_name}_weights.h5"))

    return {
        "name": model_name,
        "Dice": dice, "IoU": iou, "Accuracy": accuracy, "Precision": precision,
        "Sensitivity": sensitivity, "Specificity": specificity, "F1": f1, "AUC": auc, "AP": ap
    }

# ----------------------------
# Carregar dados (uma vez)
# ----------------------------
imgs, masks = build_pairs(PATH_IMG, PATH_MASK)
print(f"🔎 Pares encontrados: {len(imgs)}")
train_img, temp_img, train_msk, temp_msk = train_test_split(imgs, masks, test_size=0.30, random_state=SEED, shuffle=True)
val_img,   test_img,  val_msk,   test_msk = train_test_split(temp_img, temp_msk, test_size=0.50, random_state=SEED, shuffle=True)

X_train, y_train = load_and_preprocess(train_img, train_msk, IMG_SIZE)
X_val,   y_val   = load_and_preprocess(val_img,   val_msk,   IMG_SIZE)
X_test,  y_test  = load_and_preprocess(test_img,  test_msk,  IMG_SIZE)

print("Shapes ->", X_train.shape, y_train.shape, "|", X_val.shape, y_val.shape, "|", X_test.shape, y_test.shape)

# ----------------------------
# Execute os 6 modelos
# ----------------------------
to_run = ["resunet_a_2d", "vnet_2d", "att_unet_2d", "swin_unet_2d", "unet_2d", "unet_plus_2d", "r2_unet_2d", "resunet_a_2d"]
results = []
for name in to_run:
    results.append(train_and_eval(name, X_train, y_train, X_val, y_val, X_test, y_test))

# Resumo comparativo
df = pd.DataFrame(results).set_index("name").sort_values("Dice", ascending=False)
print("\n🏁 Resultado comparativo (ordenado por Dice):")
print(df.round(4))


🔎 Pares encontrados: 1130
Shapes -> (791, 256, 256, 1) (791, 256, 256, 1) | (169, 256, 256, 1) (169, 256, 256, 1) | (170, 256, 256, 1) (170, 256, 256, 1)

Received dilation rates: [1, 2, 4]
Received dilation rates are not defined on a per downsampling level basis.
Automated determinations are applied with the following details:
	depth-0, dilation_rate = [1, 2, 4]
	depth-1, dilation_rate = [1, 2, 4]
	depth-2, dilation_rate = [1]
Model: "resuneta_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, 256, 256, 1  0           []                               
                                )]                                                                
                                                                                                  
 resuneta_input_mapping (Conv2D  (None, 256, 256,

ResourceExhaustedError: Graph execution error:

Detected at node 'resuneta_model/resuneta_aspp_out_sepconv_r12_0_depthwise/depthwise' defined at (most recent call last):
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
      app.start()
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
      self.io_loop.start()
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\tornado\platform\asyncio.py", line 205, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\asyncio\base_events.py", line 570, in run_forever
      self._run_once()
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\asyncio\base_events.py", line 1859, in _run_once
      handle._run()
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\asyncio\events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
      await result
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
      await super().execute_request(stream, ident, parent)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
      res = shell.run_cell(
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3009, in run_cell
      result = self._run_cell(
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3064, in _run_cell
      result = runner(coro)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3269, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3448, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\Daniel\AppData\Local\Temp\ipykernel_65324\1746007110.py", line 360, in <module>
      results.append(train_and_eval(name, X_train, y_train, X_val, y_val, X_test, y_test))
    File "C:\Users\Daniel\AppData\Local\Temp\ipykernel_65324\1746007110.py", line 267, in train_and_eval
      history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=cb, verbose=1)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\layers\convolutional\depthwise_conv2d.py", line 161, in call
      outputs = backend.depthwise_conv2d(
    File "C:\Users\Daniel\miniconda3\envs\tf-gpu\lib\site-packages\keras\backend.py", line 6315, in depthwise_conv2d
      x = tf.compat.v1.nn.depthwise_conv2d(
Node: 'resuneta_model/resuneta_aspp_out_sepconv_r12_0_depthwise/depthwise'
OOM when allocating tensor with shape[288,22,22,128] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node resuneta_model/resuneta_aspp_out_sepconv_r12_0_depthwise/depthwise}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_243533]