# BreakHis Image Classification with 🤗 Vision Transformers and `TensorFlow`

### Quick intro: Vision Transformer (ViT) by Google Brain
The Vision Transformer (ViT) is essentially a BERT-style Transformer encoder applied to images. It attains excellent results compared to state-of-the-art convolutional networks. To feed an image to the model, the image is split into a sequence of fixed-size patches (typically 16x16 or 32x32 pixels), each of which is linearly embedded. A [CLS] token is prepended to the sequence for classification, absolute position embeddings are added, and the resulting sequence is passed to the Transformer encoder.

* [Original paper](https://arxiv.org/abs/2010.11929)
* [Official repo (in JAX)](https://github.com/google-research/vision_transformer)
* [🤗 Vision Transformer](https://huggingface.co/docs/transformers/model_doc/vit)
* [Pre-trained model](https://huggingface.co/google/vit-base-patch16-224-in21k)
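
As a rough illustration of the patch sequence described above, the following NumPy sketch splits a 224x224 RGB image into 16x16 patches and counts the resulting tokens. The function name `image_to_patches` is made up for this example and is not part of any library. The real model additionally applies a learned linear projection and position embeddings, which are omitted here.

```python
import numpy as np


def image_to_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Hypothetical helper for illustration only, not a library function.
    """
    height, width, channels = image.shape
    assert height % patch_size == 0 and width % patch_size == 0
    grid_h, grid_w = height // patch_size, width // patch_size
    # (grid_h, patch, grid_w, patch, C) -> (grid_h, grid_w, patch, patch, C)
    patches = image.reshape(grid_h, patch_size, grid_w, patch_size, channels)
    patches = patches.transpose(0, 2, 1, 3, 4)
    # One row per patch, each row holding patch_size * patch_size * C values
    return patches.reshape(grid_h * grid_w, patch_size * patch_size * channels)


image = np.zeros((224, 224, 3), dtype=np.float32)
patches = image_to_patches(image, patch_size=16)

# 14 x 14 = 196 patches, plus one [CLS] token prepended by the model
num_tokens = patches.shape[0] + 1
print(patches.shape, num_tokens)
```

With 16x16 patches, each flattened patch has 16 * 16 * 3 = 768 values, which is why patch size directly controls the sequence length the encoder has to process.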

## Installation

In [1]:
# !pip install transformers datasets tensorflow-addons --upgrade

In [2]:
# !pip show tensorflow

## Setup & Configuration

In this step, we define the global configuration and parameters used throughout the end-to-end fine-tuning process, e.g. the `image processor` (formerly called `feature extractor`) and the `model` we will use.

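This example was written around [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k), a Vision Transformer (ViT) pre-trained on ImageNet-21k (14 million images, 21,843 classes) at resolution 224x224.
There are also [large](https://huggingface.co/google/vit-large-patch16-224-in21k) and [huge](https://huggingface.co/google/vit-huge-patch14-224-in21k) flavors of the original ViT.

Note that the configuration cell below sets `model_id` to `facebook/convnext-base-224-22k` and `model_arch` to `TFConvNextForImageClassification`, so the run recorded in this notebook fine-tunes a ConvNeXt checkpoint rather than ViT.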

In [3]:
from transformers import TFAutoModelForImageClassification, AutoImageProcessor, TFConvNextForImageClassification

model_id = "facebook/convnext-base-224-22k"
# model_id = "microsoft/swin-base-patch4-window7-224-in22k"

model_arch = TFConvNextForImageClassification
image_processor = AutoImageProcessor.from_pretrained(model_id)

zoom = 400

image_processor

2023-12-03 18:33:44.523010: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


ConvNextImageProcessor {
  "crop_pct": 0.875,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ConvNextImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 224
  }
}
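
To make the printed configuration more concrete, here is a minimal NumPy sketch of the `do_rescale` and `do_normalize` steps, using the `rescale_factor`, `image_mean` and `image_std` values shown above. It deliberately leaves out resizing and cropping, and the function name `rescale_and_normalize` is hypothetical. The channels-first output layout is an assumption based on the `(None, 3, 224, 224)` tensor spec printed later in this notebook.

```python
import numpy as np

# Values copied from the image processor configuration printed above
IMAGE_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGE_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
RESCALE_FACTOR = 1.0 / 255.0  # 0.00392156862745098


def rescale_and_normalize(pixels_uint8):
    """Map an (H, W, 3) uint8 image to a normalized (3, H, W) float array.

    Illustrative only: the real processor also resizes and crops the image.
    """
    x = pixels_uint8.astype(np.float32) * RESCALE_FACTOR  # [0, 255] -> [0, 1]
    x = (x - IMAGE_MEAN) / IMAGE_STD                      # per-channel standardization
    return np.transpose(x, (2, 0, 1))                     # channels-last -> channels-first


white = np.full((224, 224, 3), 255, dtype=np.uint8)
pixel_values = rescale_and_normalize(white)
print(pixel_values.shape)
```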

In [4]:
from datasets import load_dataset
from datetime import datetime
import json
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
from PIL import Image
import shutil

import tensorflow as tf
import tensorflow_addons as tfa
from transformers import create_optimizer, DefaultDataCollator, ViTImageProcessor



TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



## Dataset & Pre-processing

- **Data Source:** https://www.kaggle.com/code/nasrulhakim86/breast-cancer-histopathology-images-classification/data
- The Breast Cancer Histopathological Image Classification (BreakHis) dataset is composed of 7,909 microscopic images of breast tumor tissue collected from 82 patients.
- The images are collected using different magnifying factors (40X, 100X, 200X, and 400X). 
- To date, it contains 2,480 benign and 5,429 malignant samples (700X460 pixels, 3-channel RGB, 8-bit depth in each channel, PNG format).
- This database has been built in collaboration with the P&D Laboratory – Pathological Anatomy and Cytopathology, Parana, Brazil (http://www.prevencaoediagnose.com.br). 
- Each image filename stores information about the image itself: method of procedure biopsy, tumor class, tumor type, patient identification, and magnification factor. 
- For example, `SOB_B_TA-14-4659-40-001.png` is image 1, at magnification factor 40X, of a benign tumor of type tubular adenoma, originating from slide 14-4659, which was collected by procedure SOB.
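
The naming convention above can be decoded with a small regular expression. This is a sketch, assuming the three leading codes are separated by underscores in the actual file names, i.e. `SOB_B_TA-14-4659-40-001.png`. The pattern, the field names and the helper `parse_breakhis_filename` are illustrative choices, not part of the dataset's tooling.

```python
import re

# Assumed layout: <procedure>_<class>_<type>-<slide id>-<magnification>-<image no>.png
FILENAME_RE = re.compile(
    r"^(?P<procedure>[A-Z]+)_(?P<tumor_class>[BM])_(?P<tumor_type>[A-Z]+)"
    r"-(?P<slide>\d+-[0-9A-Za-z]+)-(?P<magnification>\d+)-(?P<image_no>\d+)\.png$"
)


def parse_breakhis_filename(name):
    """Return the fields encoded in a BreakHis file name as a dict."""
    match = FILENAME_RE.match(name)
    if match is None:
        raise ValueError(f"Unexpected file name: {name!r}")
    info = match.groupdict()
    info["magnification"] = int(info["magnification"])
    info["image_no"] = int(info["image_no"])
    return info


info = parse_breakhis_filename("SOB_B_TA-14-4659-40-001.png")
print(info)
```

Parsing the slide identifier this way is useful when building folds, because images from the same patient should not be spread across the training and validation splits.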

`BreakHis` is not yet available as a dataset in the `datasets` library. To create a `Dataset` instance, we write a small helper function that loads the data from the filesystem so that it can be used later for training.

This notebook assumes that the dataset is available in a directory next to this file, named `breakhis_400x`, which contains the per-fold `train_<fold>.csv` and `val_<fold>.csv` files read by `load_data`.

In [5]:
cwd = Path().absolute()
input_path = cwd / f'breakhis_{zoom}x'

In [6]:
tf.debugging.disable_traceback_filtering()

def process_example(image):
    inputs = image_processor(image, return_tensors='tf')
    return inputs['pixel_values']


def process_dataset(example):
    example['pixel_values'] = process_example(Image.open(example['file_loc']).convert("RGB"))

    example['label'] = to_categorical(example['label'], num_classes=2)
    return example

def load_data(fold_idx):
    train_csv = str(input_path / f"train_{fold_idx}.csv")
    val_csv = str(input_path / f"val_{fold_idx}.csv")
    dataset = load_dataset(
        'csv', data_files={'train': train_csv, 'val': val_csv})

    dataset = dataset.map(process_dataset, with_indices=False, num_proc=4)

    print(f"Loaded {fold_idx} dataset: {dataset}")

    return dataset


## Fine-tuning the model using `Keras`

Now that our `dataset` is processed, we can download the pretrained model and fine-tune it. Before we can do this, we need to convert our Hugging Face `datasets` Dataset into a `tf.data.Dataset`. For this, we use the `.to_tf_dataset` method together with a `data collator` (data collators are objects that form a batch from a list of dataset elements).
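
To illustrate what "forming a batch from a list of dataset elements" means, here is a simplified NumPy sketch. It is not the implementation of `DefaultDataCollator`; the function `squeeze_and_collate` is hypothetical and also folds in the removal of the leading dimension of size 1 that `remove_extra_dim` performs later in this notebook.

```python
import numpy as np


def squeeze_and_collate(examples):
    """Stack a list of example dicts into one dict of batched arrays.

    Simplified stand-in for a data collator, for illustration only.
    """
    batch = {}
    for key in examples[0]:
        arrays = [np.asarray(example[key], dtype=np.float32) for example in examples]
        if key == "pixel_values":
            # The image processor returns (1, 3, H, W); drop the leading 1
            arrays = [a[0] if a.ndim == 4 else a for a in arrays]
        batch[key] = np.stack(arrays)
    return batch


examples = [
    {"pixel_values": np.zeros((1, 3, 224, 224)), "label": [1.0, 0.0]},
    {"pixel_values": np.ones((1, 3, 224, 224)), "label": [0.0, 1.0]},
]
batch = squeeze_and_collate(examples)
print(batch["pixel_values"].shape, batch["label"].shape)
```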




## Hyperparameters

In [7]:
id2label = {"0": "benign", "1": "malignant"}
label2id = {v: k for k, v in id2label.items()}

num_train_epochs = 150
batch_size = 10
num_warmup_steps = 0
fp16 = True

# Train in mixed-precision float16.
# Set fp16 = False above if you're using a GPU that will not benefit from this.
if fp16:
    tf.keras.mixed_precision.set_global_policy("mixed_float16")


INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 3070 Laptop GPU, compute capability 8.6


2023-12-03 18:33:50.505097: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.


### Download the pretrained transformer model and fine-tune it. 

In [8]:
def get_loss():
    return tf.keras.losses.BinaryCrossentropy(from_logits=True)


def get_metrics():
    return [
        tf.keras.metrics.BinaryAccuracy(name="accuracy"),
        tf.keras.metrics.AUC(name='auc', from_logits=True),
        # tf.keras.metrics.AUC(name='auc_multi', from_logits=True,
                            #  num_labels=2, multi_label=True),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.Precision(name='precision'),
        tfa.metrics.F1Score(name='f1_score', num_classes=2, threshold=0.5),
    ]


def get_callbacks(output_path, fold_idx):
    return [
        EarlyStopping(monitor="val_loss", patience=3),
        CSVLogger(output_path / f'train_metrics_{fold_idx}.csv')
    ]


def get_optimizer(learning_rate, weight_decay_rate, num_warmup_steps, num_train_steps):
    optimizer, _ = create_optimizer(
        init_lr=learning_rate,
        num_train_steps=num_train_steps,
        weight_decay_rate=weight_decay_rate,
        num_warmup_steps=num_warmup_steps,
    )

    return optimizer


num_train_steps_list = []
def train_model(fold_idx, train, val, learning_rate, weight_decay_rate, output_path):
    num_train_steps = len(train) * num_train_epochs
    num_train_steps_list.append(num_train_steps)
    print(f"num_train_steps = {num_train_steps}")
    optimizer = get_optimizer(
        learning_rate, weight_decay_rate, num_warmup_steps, num_train_steps)

    # load the pre-trained model selected by model_arch / model_id
    model = model_arch.from_pretrained(
        model_id,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes = True
    )

    # compile model
    model.compile(optimizer=optimizer, loss=get_loss(), metrics=get_metrics())
    
    model.summary()  # prints the table itself and returns None, so it is not wrapped in print()
    
    history = model.fit(
        train,
        validation_data=val,
        callbacks=get_callbacks(output_path, fold_idx),
        epochs=num_train_epochs,
    )

    return model, history
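
The step count computed in `train_model` can be checked by hand. The sketch below uses the 1,077 training rows reported in the run output further down, together with the batch size and epoch count from the hyperparameter cell. The linear decay function is an assumption about how the schedule behaves with `num_warmup_steps = 0` (my reading of the `create_optimizer` defaults), not a copy of the library code, and `linear_decay_lr` is a hypothetical name.

```python
import math

num_rows = 1077          # training rows reported for fold 0
batch_size = 10
num_train_epochs = 150

# len() of the batched tf.data.Dataset is the number of batches per epoch
steps_per_epoch = math.ceil(num_rows / batch_size)
num_train_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, num_train_steps)


def linear_decay_lr(step, init_lr=3e-5, total_steps=num_train_steps):
    """Assumed schedule: linear decay from init_lr to 0 without warmup."""
    remaining = max(0.0, 1.0 - step / total_steps)
    return init_lr * remaining


# Learning rate at the start, midpoint and end of the planned schedule
for step in (0, num_train_steps // 2, num_train_steps):
    print(step, linear_decay_lr(step))
```

Because early stopping usually ends training long before epoch 150, the learning rate would stay close to its initial value for the whole run under this assumed schedule.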


In [9]:
def remove_extra_dim(example):
    example['pixel_values'] = np.squeeze(example['pixel_values'], axis=0)
    return example

def save_model(idx, model, output_path):
    model.save_pretrained(output_path / f'model_{idx}', from_tf=True)
    
def save_history(idx, history, output_path):
    np.save(output_path / f'train_history_{idx}.npy', history.history)

In [10]:
def intersection(lst1, lst2):
    return list(set(lst1) & set(lst2))


def run_fold(fold_idx, learning_rate, weight_decay_rate, output_path):
    tf.keras.backend.clear_session()
    dataset = load_data(fold_idx)

    # Drop the leading batch dimension that the image processor adds to each example
    train_dataset = dataset["train"].map(remove_extra_dim)
    val_dataset = dataset["val"].map(remove_extra_dim)

    # Create datasets and train model
    data_collator = DefaultDataCollator(return_tensors="tf")

    train_dataset_tf = train_dataset.to_tf_dataset(
        columns=['pixel_values'],
        label_cols=['label'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=data_collator
    )

    val_dataset_tf = val_dataset.to_tf_dataset(
        columns=['pixel_values'],
        label_cols=['label'],
        shuffle=False,  # validation data does not need shuffling
        batch_size=batch_size,
        collate_fn=data_collator
    )
    print(train_dataset_tf)
    print(val_dataset_tf)

    model, history = train_model(fold_idx, train_dataset_tf, val_dataset_tf, learning_rate, weight_decay_rate, output_path)
    save_model(fold_idx, model, output_path)
    save_history(fold_idx, history, output_path)

    print(f'Fold {fold_idx} finished')
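
The `intersection` helper defined above is not called anywhere in `run_fold`. A minimal sketch of how it could be used to verify that no patient appears in both splits is shown below. The function `assert_no_patient_leakage` and the sample ids are hypothetical; in the notebook, the inputs would presumably come from the `patient_id` column, e.g. `dataset["train"]["patient_id"]`.

```python
def intersection(lst1, lst2):
    # Sorted for a stable, readable error message
    return sorted(set(lst1) & set(lst2))


def assert_no_patient_leakage(train_patient_ids, val_patient_ids):
    """Raise if any patient id occurs in both the training and validation split."""
    shared = intersection(train_patient_ids, val_patient_ids)
    if shared:
        raise ValueError(f"Patients present in both splits: {shared}")


# Hypothetical ids, for illustration only
train_ids = ["14-4659", "14-4659", "13-13412"]
val_ids = ["15-190", "15-190"]
assert_no_patient_leakage(train_ids, val_ids)
print("No patient overlap between train and val")
```

Checking this per fold matters for histopathology data because several images come from the same patient, and overlap between splits would inflate the validation metrics.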


In [11]:
def save_model_info(output_path, fold_idx, learning_rate, weight_decay_rate):
    model_info = {"idx": fold_idx,
                    "model_id": model_id,
                    "zoom": zoom,
                    "n_splits": 5,
                    "num_train_epochs": num_train_epochs,
                    "batch_size": batch_size,
                    "learning_rate": learning_rate,
                    "weight_decay_rate": weight_decay_rate,
                    "num_warmup_steps": num_warmup_steps,
                    "num_train_steps": num_train_steps_list[0]}

    with open(output_path / f'model_info_{fold_idx}.json', 'w') as f:
        json.dump(model_info, f, indent=4)

    print(json.dumps(model_info, indent=4))

In [13]:
# import os
# os.environ['XLA_FLAGS'] = '--xla_gpu_strict_conv_algorithm_picker=false'

experiment_id = "convnextttttt"
fold_idx = 0
learning_rate = 3e-5
# learning_rate = 1e-4
# weight_decay_rate = 0.01
weight_decay_rate = 0.005

output_path = cwd / 'results' / f'{zoom}x_{experiment_id}'

# shutil.rmtree(output_path, ignore_errors=True)
os.makedirs(output_path)

run_fold(fold_idx, learning_rate, weight_decay_rate, output_path)
save_model_info(output_path, fold_idx, learning_rate, weight_decay_rate)

Loaded 0 dataset: DatasetDict({
    train: Dataset({
        features: ['file_loc', 'label', 'label_str', 'patient_id', 'pixel_values'],
        num_rows: 1077
    })
    val: Dataset({
        features: ['file_loc', 'label', 'label_str', 'patient_id', 'pixel_values'],
        num_rows: 354
    })
})


Old behaviour: columns=['a'], labels=['labels'] -> (tf.Tensor, tf.Tensor)  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor)  
New behaviour: columns=['a'],labels=['labels'] -> ({'a': tf.Tensor}, {'labels': tf.Tensor})  
             : columns='a', labels='labels' -> (tf.Tensor, tf.Tensor) 

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 3, 224, 224), dtype=tf.float32, name=None), TensorSpec(shape=(None, 2), dtype=tf.float32, name=None))>
<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 3, 224, 224), dtype=tf.float32, name=None), TensorSpec(shape=(None, 2), dtype=tf.float32, name=None))>
num_train_steps = 16200


2023-12-03 18:34:04.424004: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8800
2023-12-03 18:34:04.942202: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x806b9d10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-12-03 18:34:04.942256: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): NVIDIA GeForce RTX 3070 Laptop GPU, Compute Capability 8.6
2023-12-03 18:34:05.561471: I ./tensorflow/compiler/jit/device_compiler.h:180] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.




All model checkpoint layers were used when initializing TFConvNextForImageClassification.

Some weights of TFConvNextForImageClassification were not initialized from the model checkpoint at facebook/convnext-base-224-22k and are newly initialized because the shapes did not match:
- classifier/kernel:0: found shape (1024, 21841) in the checkpoint and (1024, 2) in the model instantiated
- classifier/bias:0: found shape (21841,) in the checkpoint and (2,) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model: "tf_conv_next_for_image_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 convnext (TFConvNextMainLay  multiple                 87566464  
 er)                                                             
                                                                 
 classifier (Dense)          multiple                  2050      
                                                                 
Total params: 87,568,514
Trainable params: 87,568,514
Non-trainable params: 0
_________________________________________________________________
MODEL SUMMARY: None
Epoch 1/150

UnknownError: Graph execution error:

Detected at node 'AdamWeightDecay/gradients/PartitionedCall' defined at (most recent call last):
    File "<frozen runpy>", line 198, in _run_module_as_main
    File "<frozen runpy>", line 88, in _run_code
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/traitlets/config/application.py", line 1053, in launch_instance
      app.start()
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 736, in start
      self.io_loop.start()
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
      self._run_once()
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
      handle._run()
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue
      await self.process_one()
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 505, in process_one
      await dispatch(*args)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell
      await result
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 740, in execute_request
      reply_content = await reply_content
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 546, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3024, in run_cell
      result = self._run_cell(
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell
      result = runner(coro)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_1364/3407642219.py", line 16, in <module>
      run_fold(fold_idx, learning_rate, weight_decay_rate, output_path)
    File "/tmp/ipykernel_1364/2688380364.py", line 34, in run_fold
      model, history = train_model(fold_idx, train_dataset_tf, val_dataset_tf, learning_rate, weight_decay_rate, output_path)
    File "/tmp/ipykernel_1364/4123974316.py", line 57, in train_model
      history = model.fit(
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/utils/traceback_utils.py", line 61, in error_handler
      return fn(*args, **kwargs)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/engine/training.py", line 1685, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/engine/training.py", line 1284, in train_function
      return step_function(self, iterator)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/engine/training.py", line 1268, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/engine/training.py", line 1249, in run_step
      outputs = model.train_step(data)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/transformers/modeling_tf_utils.py", line 1675, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/optimizers/legacy/optimizer_v2.py", line 585, in minimize
      grads_and_vars = self._compute_gradients(
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/mixed_precision/loss_scale_optimizer.py", line 744, in _compute_gradients
      grads_and_vars = self._optimizer._compute_gradients(
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/optimizers/legacy/optimizer_v2.py", line 643, in _compute_gradients
      grads_and_vars = self._get_gradients(
    File "/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/optimizers/legacy/optimizer_v2.py", line 519, in _get_gradients
      grads = tape.gradient(loss, var_list, grad_loss)
Node: 'AdamWeightDecay/gradients/PartitionedCall'
Failed to determine best cudnn convolution algorithm for:
%cudnn-conv.4 = (f16[1,7,7,1024]{3,2,1,0}, u8[0]{0}) custom-call(f16[1,7,7,7168]{3,2,1,0} %bitcast.40, f16[1024,7,7,8]{3,2,1,0} %transpose.3), window={size=7x7 pad=3_3x3_3}, dim_labels=b01f_o01i->b01f, feature_group_count=1024, custom_call_target="__cudnn$convForward", metadata={op_type="Conv2DBackpropFilter" op_name="gradients/Conv2D_grad/Conv2DBackpropFilter" source_file="/home/miki/miniconda3/envs/tf3/lib/python3.11/site-packages/keras/optimizers/legacy/optimizer_v2.py" source_line=519}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNKNOWN: CUDNN_STATUS_BAD_PARAM
in tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc(3588): 'op' CUDNN_BACKEND_OPERATION: cudnnFinalize Failed

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.
	 [[{{node AdamWeightDecay/gradients/PartitionedCall}}]] [Op:__inference_train_function_46170]
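
The error message above suggests a fallback via `XLA_FLAGS`, which is also what the commented-out lines at the top of the failing cell would do. A minimal sketch of applying it is given below. It assumes the flag has to be in the environment before TensorFlow initializes XLA, so the safest option is probably to restart the kernel and run this before `import tensorflow`. As the message itself notes, the fallback algorithm may have suboptimal performance, and this works around the autotuning failure rather than fixing its root cause.

```python
import os

FLAG = "--xla_gpu_strict_conv_algorithm_picker=false"

# Append the flag without discarding any XLA flags that are already set
existing = os.environ.get("XLA_FLAGS", "")
if FLAG not in existing.split():
    os.environ["XLA_FLAGS"] = f"{existing} {FLAG}".strip()

print(os.environ["XLA_FLAGS"])

# import tensorflow as tf  # import TensorFlow only after the flag is set
```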

In [None]:
# import argparse
# from pathlib import Path
# import os

# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--experiment_id', default='224x224frames_convnext', type=str)
#     parser.add_argument('--fold_idx', default=0, type=int)
#     parser.add_argument('--weight_decay_rate', default=0.005, type=float)
#     parser.add_argument('--learning_rate', default=3e-5, type=float)
#     args = parser.parse_args()

#     experiment_id = args.experiment_id
#     fold_idx = args.fold_idx
#     weight_decay_rate = args.weight_decay_rate
#     learning_rate = args.learning_rate
    
#     cwd = Path.cwd()
#     output_path = cwd / 'results' / f'{experiment_id}'
    
#     os.makedirs(output_path, exist_ok=True)
    
#     run_fold(fold_idx, learning_rate, weight_decay_rate, output_path)
#     save_model_info(output_path, fold_idx, learning_rate, weight_decay_rate)

# if __name__ == '__main__':
#     main()