# BreakHis Image Classification with 🤗 Vision Transformers and `TensorFlow`

### Quick intro: Vision Transformer (ViT) by Google Brain
The Vision Transformer (ViT) is essentially BERT applied to images: a standard Transformer encoder that processes a sequence of image patches instead of word tokens. When pre-trained on large datasets such as ImageNet-21k, it attains excellent results compared to state-of-the-art convolutional networks. To feed an image to the model, it is split into a sequence of fixed-size, non-overlapping patches (typically 16x16 or 32x32 pixels), which are flattened and linearly embedded. A learnable [CLS] token is prepended to the sequence, and its final representation is used to classify the image. Finally, absolute position embeddings are added, and the resulting sequence is fed to the Transformer encoder.
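For a 224x224 input, the patch arithmetic works out as follows. This NumPy sketch (hypothetical helpers `vit_sequence_length` and `patchify`) only illustrates the patch split and the resulting sequence length; the learned linear embedding is omitted. In the 🤗 implementation, the split and projection are done together by a convolution whose kernel size and stride equal the patch size.

```python
import numpy as np


def vit_sequence_length(image_size=224, patch_size=16):
    """Tokens seen by the encoder: one per patch plus the prepended [CLS] token."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


def patchify(image, patch_size=16):
    """Split an (H, W, C) image into flattened, non-overlapping patches in row-major order."""
    h, w, c = image.shape
    grid = image.reshape(h // patch_size, patch_size, w // patch_size, patch_size, c)
    grid = grid.transpose(0, 2, 1, 3, 4)  # (rows, cols, patch_size, patch_size, C)
    return grid.reshape(-1, patch_size * patch_size * c)


print(vit_sequence_length(224, 16))  # 197 = 14 * 14 patches + [CLS]
print(vit_sequence_length(224, 32))  # 50 = 7 * 7 patches + [CLS]
print(patchify(np.zeros((224, 224, 3))).shape)  # (196, 768)
```

With `patch16` at 224x224, each flattened patch holds 16 * 16 * 3 = 768 values before the linear embedding.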

* [Original paper](https://arxiv.org/abs/2010.11929)
* [Official repo (in JAX)](https://github.com/google-research/vision_transformer)
* [🤗 Vision Transformer](https://huggingface.co/docs/transformers/model_doc/vit)
* [Pre-trained model](https://huggingface.co/google/vit-base-patch16-224-in21k)

## Installation

In [None]:
# !pip install transformers datasets "tensorflow>=2.7,<2.16" tensorflow-addons --upgrade

## Setup & Configuration

In this step, we define global configuration values used throughout the end-to-end fine-tuning process: the model checkpoint (`model_id`, which also determines the image processor), the model class (`model_arch`), and the BreakHis magnification factor (`zoom`).

In this example we fine-tune [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k), a Vision Transformer (ViT) pre-trained on ImageNet-21k (14 million images, 21,843 classes) at resolution 224x224.
There are also [large](https://huggingface.co/google/vit-large-patch16-224-in21k) and [huge](https://huggingface.co/google/vit-huge-patch14-224-in21k) flavors of the original ViT.

In [None]:
from transformers import TFViTForImageClassification, TFResNetForImageClassification, TFConvNextForImageClassification, TFAutoModelForImageClassification, AutoConfig

# Active configuration: ViT-Base loaded through the generic auto class
model_arch = TFAutoModelForImageClassification
model_id = "google/vit-base-patch16-224-in21k"

# Alternatives (uncomment the matching lines to switch):
# model_arch = TFViTForImageClassification
# model_id = "google/vit-large-patch16-224-in21k" # OOM

# model_arch = TFResNetForImageClassification
# model_id = "microsoft/resnet-101"

# model_arch = TFConvNextForImageClassification
# model_id = "facebook/convnext-base-224"
# model_id = "facebook/convnext-large-224"

# BreakHis magnification factor to train on (40, 100, 200 or 400)
zoom = 400

In [None]:
from datasets import load_dataset
from datetime import datetime
import json
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
from PIL import Image
import shutil

import tensorflow as tf
import tensorflow_addons as tfa
from transformers import create_optimizer, DefaultDataCollator, ViTImageProcessor


## Dataset & Pre-processing

- **Data Source:** https://www.kaggle.com/code/nasrulhakim86/breast-cancer-histopathology-images-classification/data
- The Breast Cancer Histopathological Image Classification (BreakHis) dataset is composed of 7,909 microscopic images of breast tumor tissue collected from 82 patients.
- The images were acquired at four magnification factors (40X, 100X, 200X, and 400X).
- It contains 2,480 benign and 5,429 malignant samples (700x460 pixels, 3-channel RGB, 8-bit depth per channel, PNG format).
- This database has been built in collaboration with the P&D Laboratory – Pathological Anatomy and Cytopathology, Parana, Brazil (http://www.prevencaoediagnose.com.br). 
- Each image filename encodes information about the image itself: biopsy procedure, tumor class, tumor type, patient identification, and magnification factor.
- For example, `SOB_B_TA-14-4659-40-001.png` is image 1, at magnification factor 40X, of a benign tumor of type tubular adenoma, originating from slide 14-4659, which was collected by the SOB (surgical open biopsy) procedure.
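The naming scheme can be parsed with a regular expression. This sketch assumes the layout used by the BreakHis files, `<procedure>_<class>_<type>-<slide>-<magnification>-<sequence>.png` (e.g. `SOB_B_TA-14-4659-40-001.png`), and that slide IDs may end in letters (such as `14-22549AB`). The helper `parse_breakhis_filename` is hypothetical and not part of this notebook; it maps the class letter to the same 0 = benign / 1 = malignant encoding used for training.

```python
import re

# Assumed layout: <procedure>_<class>_<type>-<slide id>-<magnification>-<sequence>.png
FILENAME_RE = re.compile(
    r"^(?P<procedure>[A-Z]+)_(?P<tumor_class>[BM])_(?P<tumor_type>[A-Z]+)"
    r"-(?P<slide>\d+-[0-9A-Z]+)-(?P<magnification>\d+)-(?P<seq>\d+)\.png$"
)


def parse_breakhis_filename(name):
    """Split a BreakHis filename into its metadata fields."""
    match = FILENAME_RE.match(name)
    if match is None:
        raise ValueError(f"Unexpected BreakHis filename: {name}")
    info = match.groupdict()
    info["magnification"] = int(info["magnification"])  # e.g. 40 for 40X
    info["seq"] = int(info["seq"])                        # image number on the slide
    info["label"] = 1 if info["tumor_class"] == "M" else 0  # 0 = benign, 1 = malignant
    return info


info = parse_breakhis_filename("SOB_B_TA-14-4659-40-001.png")
print(info["slide"], info["magnification"], info["label"])  # 14-4659 40 0
```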

`BreakHis` is not available as a ready-made dataset in the `datasets` library. To create a `Dataset` instance, we write a small helper function that loads the CSV split files for a given fold from the filesystem and preprocesses every image, producing the instance we use later for training.

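This notebook assumes that the dataset is available in a directory named `breakhis_{zoom}x` (e.g. `breakhis_400x`) next to this file. For every fold `i`, that directory must contain `train_i.csv` and `val_i.csv`, each with a `file_loc` column (path to the image, absolute or relative to the working directory) and a `label` column (0 = benign, 1 = malignant).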
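The fold CSVs are expected to exist already. One possible way to create them is sketched below with a hypothetical `write_fold_csvs` helper. It keeps all images of one patient in the same fold, treating the slide ID in the filename (e.g. `14-4659`) as the patient identifier (an assumption), and writes the `file_loc` and `label` columns that `load_data` and `process_dataset` read. Patients are shuffled and assigned to folds round-robin; the split is not stratified by class, so per-fold class balance is not guaranteed.

```python
import csv
import random
import tempfile
from collections import defaultdict
from pathlib import Path


def slide_id(path):
    # "SOB_B_TA-14-4659-40-001.png" -> "14-4659" (assumed to identify the patient)
    parts = Path(path).name.split("-")
    return f"{parts[1]}-{parts[2]}"


def label_of(path):
    # class letter after the first underscore: B -> 0 (benign), M -> 1 (malignant)
    return 1 if Path(path).name.split("_")[1] == "M" else 0


def write_fold_csvs(image_paths, out_dir, n_splits=5, seed=0):
    by_slide = defaultdict(list)
    for p in image_paths:
        by_slide[slide_id(p)].append(str(p))
    slides = sorted(by_slide)
    random.Random(seed).shuffle(slides)
    # round-robin over shuffled patients: every image of a patient lands in one fold
    fold_of = {s: i % n_splits for i, s in enumerate(slides)}
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for k in range(n_splits):
        for split in ("train", "val"):
            want_val = split == "val"
            rows = [(p, label_of(p)) for s in slides if (fold_of[s] == k) == want_val for p in by_slide[s]]
            with open(out_dir / f"{split}_{k}.csv", "w", newline="") as fh:
                writer = csv.writer(fh)
                writer.writerow(["file_loc", "label"])
                writer.writerows(rows)


# Usage sketch with made-up filenames: 10 hypothetical patients, 3 images each
fake = [
    f"SOB_{'M_DC' if i % 2 else 'B_TA'}-14-{1000 + i}-400-{j:03d}.png"
    for i in range(10)
    for j in range(3)
]
with tempfile.TemporaryDirectory() as tmp:
    write_fold_csvs(fake, tmp, n_splits=5)
    written = sorted(p.name for p in Path(tmp).iterdir())
print(written)
```

On real data this could be called as, for example, `write_fold_csvs(sorted(input_path.rglob("*.png")), input_path)`, assuming the PNG files sit somewhere below `input_path`.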

In [None]:
cwd = Path().absolute()
input_path = cwd / f'breakhis_{zoom}x'


In [None]:
tf.debugging.disable_traceback_filtering()


image_processor = ViTImageProcessor.from_pretrained(model_id)

def process_example(image):
    inputs = image_processor(image, return_tensors='tf')
    return inputs['pixel_values']


def process_dataset(example):
    example['pixel_values'] = process_example(Image.open(example['file_loc']).convert("RGB"))

    example['label'] = to_categorical(example['label'], num_classes=2)
    return example

def load_data(fold_idx):
    train_csv = str(input_path / f"train_{fold_idx}.csv")
    val_csv = str(input_path / f"val_{fold_idx}.csv")
    dataset = load_dataset(
        'csv', data_files={'train': train_csv, 'val': val_csv})

    dataset = dataset.map(process_dataset, with_indices=False, num_proc=1)

    print(f"Loaded fold {fold_idx} dataset: {dataset}")

    return dataset


## Fine-tuning the model using `Keras`

Now that our `dataset` is processed, we can download the pretrained model and fine-tune it. Before we can do this, we need to convert our Hugging Face `datasets` Dataset into a `tf.data.Dataset`. For this, we use the `.to_tf_dataset` method together with a data collator, an object that forms a batch from a list of dataset elements.
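As a rough illustration of what a collator does here, the simplified stand-in below (a hypothetical `simple_collate`, NumPy only) stacks per-example arrays into batch arrays. It is not the `transformers` implementation: `DefaultDataCollator` additionally handles details such as renaming a `label` key to `labels` and returning framework tensors. The per-example shape `(3, 224, 224)` matches channels-first `pixel_values` from `ViTImageProcessor` once the extra leading dimension has been removed.

```python
import numpy as np


def simple_collate(features):
    """Stack a list of per-example dicts into one dict of batched arrays.

    Simplified stand-in for illustration only; assumes every example has
    the same keys and equally shaped values.
    """
    return {key: np.stack([np.asarray(f[key]) for f in features]) for key in features[0]}


examples = [
    {"pixel_values": np.zeros((3, 224, 224), dtype=np.float32), "label": np.array([1.0, 0.0])},
    {"pixel_values": np.ones((3, 224, 224), dtype=np.float32), "label": np.array([0.0, 1.0])},
]
batch = simple_collate(examples)
print(batch["pixel_values"].shape, batch["label"].shape)  # (2, 3, 224, 224) (2, 2)
```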




## Hyperparameters

In [None]:
id2label = {0: "benign", 1: "malignant"}
label2id = {v: k for k, v in id2label.items()}

num_train_epochs = 100
batch_size = 40
num_warmup_steps = 0
fp16 = True

# Train in mixed-precision float16.
# Set fp16 = False if your GPU will not benefit from this (e.g. one without Tensor Cores).
if fp16:
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
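`train_model` sets `num_train_steps = len(train) * num_train_epochs`, where `len(train)` is the number of batches per epoch (a final partial batch counts, assuming the default `drop_remainder=False` in `to_tf_dataset`). Assuming `create_optimizer`'s defaults (linear decay from `init_lr` to zero, and no warmup since `num_warmup_steps = 0`), the learning rate after a given step can be estimated as below. The helper names and the 1,000-image training set are hypothetical, for illustration only. Note that if `EarlyStopping` ends training after, say, 10 of 100 epochs, the learning rate has decayed by only about 10%.

```python
import math


def count_train_steps(num_examples, batch_size=40, num_epochs=100):
    # len() of the batched tf.data.Dataset = number of batches; a final partial batch counts too
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


def linear_decay_lr(step, init_lr, total_steps):
    # linear decay from init_lr to 0 over total_steps, no warmup (assumed defaults)
    step = min(step, total_steps)
    return init_lr * (1.0 - step / total_steps)


total = count_train_steps(1000)  # hypothetical 1,000 training images
print(total)  # 2500
print(linear_decay_lr(10 * 25, 3e-5, total))  # lr after 10 of 100 epochs, about 2.7e-05
```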


### Download the pretrained transformer model and fine-tune it. 

In [None]:
def get_loss():
    return tf.keras.losses.BinaryCrossentropy(from_logits=True)


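def with_sigmoid(metric_cls):
    """Wrap a metric class so it receives sigmoid probabilities instead of raw logits."""
    class SigmoidMetric(metric_cls):
        def update_state(self, y_true, y_pred, sample_weight=None):
            return super().update_state(y_true, tf.sigmoid(y_pred), sample_weight)
    return SigmoidMetric


def get_metrics():
    # The model outputs raw logits: AUC handles them via from_logits=True,
    # while the threshold-based metrics are fed sigmoid probabilities instead.
    return [
        with_sigmoid(tf.keras.metrics.BinaryAccuracy)(name="accuracy"),
        tf.keras.metrics.AUC(name='auc', from_logits=True),
        # tf.keras.metrics.AUC(name='auc_multi', from_logits=True,
                            #  num_labels=2, multi_label=True),
        with_sigmoid(tf.keras.metrics.Recall)(name='recall'),
        with_sigmoid(tf.keras.metrics.Precision)(name='precision'),
        with_sigmoid(tfa.metrics.F1Score)(name='f1_score', num_classes=2, threshold=0.5),
    ]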


def get_callbacks(output_path, fold_idx):
    return [
        EarlyStopping(monitor="val_loss", patience=3),
        CSVLogger(output_path / f'train_metrics_{fold_idx}.csv')
    ]


def get_optimizer(learning_rate, weight_decay_rate, num_warmup_steps, num_train_steps):
    optimizer, _ = create_optimizer(
        init_lr=learning_rate,
        num_train_steps=num_train_steps,
        weight_decay_rate=weight_decay_rate,
        num_warmup_steps=num_warmup_steps,
    )

    return optimizer


num_train_steps_list = []
def train_model(fold_idx, train, val, learning_rate, weight_decay_rate, output_path):
    num_train_steps = len(train) * num_train_epochs
    num_train_steps_list.append(num_train_steps)
    print(f"num_train_steps = {num_train_steps}")
    optimizer = get_optimizer(
        learning_rate, weight_decay_rate, num_warmup_steps, num_train_steps)

    # load the pre-trained checkpoint (ViT by default) with a 2-label classification head
    model = model_arch.from_pretrained(
        model_id,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
    )

    # compile model
    model.compile(optimizer=optimizer, loss=get_loss(), metrics=get_metrics())
    history = model.fit(
        train,
        validation_data=val,
        callbacks=get_callbacks(output_path, fold_idx),
        epochs=num_train_epochs,
    )

    return model, history
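Because the classification head outputs two raw logits per image and the labels are one-hot vectors, `BinaryCrossentropy(from_logits=True)` treats each of the two columns as an independent sigmoid (binary) problem and averages over them. The NumPy sketch below reproduces that computation with the numerically stable formula documented for `tf.nn.sigmoid_cross_entropy_with_logits`; the helper names and example values are made up. It also shows that a probability threshold of 0.5 corresponds to a logit threshold of 0.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bce_with_logits(labels, logits):
    """Binary cross-entropy on raw logits, averaged over label columns and then the batch."""
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    # stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    per_entry = np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))
    return float(per_entry.mean(axis=-1).mean())


labels = np.array([[0.0, 1.0], [1.0, 0.0]])  # one-hot: malignant, benign
logits = np.array([[-2.0, 3.0], [0.5, -1.0]])
print(bce_with_logits(labels, logits))

# a probability threshold of 0.5 is the same decision as a logit threshold of 0
x = np.array([-2.0, -0.1, 0.3, 4.0])
print(np.array_equal(sigmoid(x) > 0.5, x > 0))  # True
```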


In [None]:
def remove_extra_dim(example):
    example['pixel_values'] = np.squeeze(example['pixel_values'], axis=0)
    return example

def save_model(idx, model, output_path):
    model.save_pretrained(output_path / f'model_{idx}')
    
def save_history(idx, history, output_path):
    np.save(output_path / f'train_history_{idx}.npy', history.history)
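Because `np.save` pickles the `history.history` dict into a 0-d object array, reading it back needs `allow_pickle=True` and `.item()`. A minimal sketch with a hypothetical `load_history` helper and made-up values; only unpickle files you trust.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_history(path):
    """Load a Keras History.history dict saved with np.save."""
    # np.save wrapped the dict in a 0-d object array, so pickle must be allowed
    # and .item() unwraps the dict again.
    return np.load(path, allow_pickle=True).item()


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "train_history_0.npy"
    np.save(path, {"loss": [0.52, 0.41], "val_loss": [0.47, 0.45]})
    history = load_history(path)

print(history["val_loss"])  # [0.47, 0.45]
```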

In [None]:
def intersection(lst1, lst2):
    return list(set(lst1) & set(lst2))


def run_fold(fold_idx, learning_rate, weight_decay_rate, output_path):
    tf.keras.backend.clear_session()
    dataset = load_data(fold_idx)

    # Remove the leading batch dimension added by the image processor
    train_dataset = dataset["train"].map(remove_extra_dim)
    val_dataset = dataset["val"].map(remove_extra_dim)

    # Create datasets and train model
    data_collator = DefaultDataCollator(return_tensors="tf")

    train_dataset_tf = train_dataset.to_tf_dataset(
        columns=['pixel_values'],
        label_cols=['label'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=data_collator
    )

    val_dataset_tf = val_dataset.to_tf_dataset(
        columns=['pixel_values'],
        label_cols=['label'],
        shuffle=False,  # validation order does not affect the metrics
        batch_size=batch_size,
        collate_fn=data_collator
    )
    print(train_dataset_tf)
    print(val_dataset_tf)

    model, history = train_model(fold_idx, train_dataset_tf, val_dataset_tf, learning_rate, weight_decay_rate, output_path)
    save_model(fold_idx, model, output_path)
    save_history(fold_idx, history, output_path)

    print(f'Fold {fold_idx} finished')
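The notebook defines an `intersection` helper that is never called, so patient overlap between the training and validation split is not checked. A sketch of such a check, assuming the BreakHis naming scheme and treating the slide ID (e.g. `14-4659`) as the patient identifier; `patient_id` and `check_no_patient_overlap` are hypothetical names.

```python
from pathlib import Path


def patient_id(file_loc):
    # "images/SOB_B_TA-14-4659-40-001.png" -> "14-4659"; assumes the BreakHis naming scheme
    parts = Path(file_loc).name.split("-")
    return f"{parts[1]}-{parts[2]}"


def check_no_patient_overlap(train_locs, val_locs):
    """Raise if any patient contributes images to both the training and validation split."""
    shared = {patient_id(p) for p in train_locs} & {patient_id(p) for p in val_locs}
    if shared:
        raise ValueError(f"Patients present in both splits: {sorted(shared)}")
    return True
```

It could be called near the start of `run_fold`, for example as `check_no_patient_overlap(dataset["train"]["file_loc"], dataset["val"]["file_loc"])`.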


In [None]:
def save_model_info(output_path, fold_idx, learning_rate, weight_decay_rate):
    model_info = {
        "idx": fold_idx,
        "model_id": model_id,
        "zoom": zoom,
        "n_splits": 5,
        "num_train_epochs": num_train_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "weight_decay_rate": weight_decay_rate,
        "num_warmup_steps": num_warmup_steps,
        # steps of the most recently trained fold
        "num_train_steps": num_train_steps_list[-1],
    }

    with open(output_path / f'model_info_{fold_idx}.json', 'w') as f:
        json.dump(model_info, f, indent=4)

    print(json.dumps(model_info, indent=4))

In [None]:
experiment_id = "testtest"
fold_idx = 0
learning_rate = 3e-5
# learning_rate = 1e-4
# weight_decay_rate = 0.01
weight_decay_rate = 0.005

output_path = cwd / 'results' / f'{zoom}x_{experiment_id}'

# shutil.rmtree(output_path, ignore_errors=True)
os.makedirs(output_path)

run_fold(fold_idx, learning_rate, weight_decay_rate, output_path)
save_model_info(output_path, fold_idx, learning_rate, weight_decay_rate)
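`save_model_info` records `n_splits = 5`, and each fold writes its per-epoch metrics to `train_metrics_{fold_idx}.csv` through `CSVLogger`, with validation metrics prefixed `val_` by Keras. Once several folds have been trained, a hypothetical `summarize_folds` helper like the one below could report a metric at each fold's lowest-`val_loss` epoch. It is a sketch; the fold logs in the usage example are made up.

```python
import csv
import statistics
import tempfile
from pathlib import Path


def summarize_folds(output_path, n_splits=5, metric="val_auc"):
    """Report `metric` at the epoch with the lowest val_loss, for each fold log found."""
    scores = []
    for k in range(n_splits):
        log = Path(output_path) / f"train_metrics_{k}.csv"
        if not log.exists():
            continue  # this fold has not been trained yet
        with open(log, newline="") as fh:
            rows = list(csv.DictReader(fh))
        best = min(rows, key=lambda row: float(row["val_loss"]))
        scores.append(float(best[metric]))
    if not scores:
        return None
    spread = statistics.stdev(scores) if len(scores) > 1 else 0.0
    return statistics.mean(scores), spread, scores


# Usage sketch on two made-up fold logs, rows are (epoch, val_loss, val_auc)
fake_logs = [
    [(0, 0.60, 0.80), (1, 0.50, 0.90), (2, 0.55, 0.92)],
    [(0, 0.70, 0.70), (1, 0.65, 0.75)],
]
with tempfile.TemporaryDirectory() as tmp:
    for k, rows in enumerate(fake_logs):
        with open(Path(tmp) / f"train_metrics_{k}.csv", "w", newline="") as fh:
            writer = csv.writer(fh)
            writer.writerow(["epoch", "val_loss", "val_auc"])
            writer.writerows(rows)
    summary = summarize_folds(tmp)

print(summary)
```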

In [None]:
# import argparse

# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument('exid', type=str, help='Experiment ID')
#     parser.add_argument('idx', type=int, help='Fold index')
#     parser.add_argument('lr', type=float, help='Learning rate')
#     parser.add_argument('wdr', type=float, help='Weight decay rate')

#     args = parser.parse_args()

#     output_path = cwd / 'results' / f'{zoom}x_{args.exid}'

#     if not os.path.exists(output_path):
#         os.makedirs(output_path)
#     else:
#         print(f"Directory {output_path} already exists. Skipping creation.")

#     print(f"Starting experiment {args.exid} with fold {args.idx}. Hyperparams:")
#     print(f"Learning rate: {args.lr}")
#     print(f"Weight decay rate: {args.wdr}")
#     run_fold(args.idx, args.lr, args.wdr, output_path)
#     save_model_info(output_path, args.idx, args.lr, args.wdr)


