# Training the model

The goal of this notebook is to train a model that classifies chest X-ray images into one of two classes: `normal` or `pneumonia`.



## Setup

In [1]:
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd

# TensorFlow
import tensorflow as tf
from tensorflow import keras
from keras import utils, layers, optimizers, callbacks, metrics, Sequential

# Hugging Face
import datasets
from transformers import TFViTForImageClassification

In [2]:
# Dataset identifier passed to `datasets.load_dataset` when the splits are loaded.
DATASET_NAME = "mmenendezg/pneumonia_x_ray"
# Pretrained checkpoint identifier.
# NOTE(review): presumably consumed by TFViTForImageClassification further down;
# the loading code is not visible in this part of the notebook — confirm.
MODEL_NAME = "google/vit-large-patch32-384"

# Number of examples per batch in the tf.data pipelines.
BATCH_SIZE = 32
# Seed handed to `tf.random.set_seed` before training.
SEED = 92
# Lets tf.data choose parallelism / prefetch buffer sizes dynamically.
AUTOTUNE = tf.data.AUTOTUNE

## Load the dataset

In [3]:
def format_dataset(example: dict):
    """Turn one dataset example into an ``(image, label)`` pair.

    Args:
        example: Mapping that contains at least the keys ``"image"`` and
            ``"label"``. Any other keys are ignored.

    Returns:
        A 2-tuple ``(example["image"], example["label"])``.

    Raises:
        KeyError: If either key is missing from a plain ``dict`` input.
    """
    image = example["image"]
    label = example["label"]
    return image, label


def convert_tf_dataset(dataset: datasets.Dataset, shuffle: bool = False):
    """Convert a Hugging Face dataset split into a batched ``tf.data.Dataset``.

    The split is exported to TensorFlow, flattened to individual examples,
    mapped to ``(image, label)`` pairs, optionally shuffled, and finally
    batched with ``BATCH_SIZE`` and prefetched.

    Args:
        dataset: The Hugging Face dataset split to convert.
        shuffle: When ``True``, shuffle the individual examples (before
            batching) using a buffer that spans the whole split.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(image, label)``.
    """
    tf_dataset = (
        # Export with batch_size=1 and unbatch right away so the following
        # steps operate on single examples rather than on batches.
        dataset.to_tf_dataset(batch_size=1)
        .unbatch()
        .map(format_dataset, num_parallel_calls=AUTOTUNE)
    )
    if shuffle:
        # `Dataset.shuffle` returns a NEW dataset and does not modify the
        # receiver, so the result has to be re-assigned; otherwise the data
        # is silently left in its original order.
        # The explicit seed keeps the shuffle order reproducible even though
        # this runs before the notebook calls `tf.random.set_seed`.
        tf_dataset = tf_dataset.shuffle(dataset.num_rows, seed=SEED)

    # Shuffle first, batch afterwards: examples (not whole batches) get mixed.
    return tf_dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)


def load_dataset():
    """Load the dataset named by ``DATASET_NAME`` and build its TF pipelines.

    Returns:
        A 3-tuple ``(train_ds, valid_ds, test_ds)`` holding the converted
        ``"train"``, ``"validation"`` and ``"test"`` splits, in that order.
    """
    splits = datasets.load_dataset(DATASET_NAME)

    # Shuffling is requested for the training split only; validation and
    # test use the converter's default (no shuffle).
    return (
        convert_tf_dataset(splits["train"], shuffle=True),
        convert_tf_dataset(splits["validation"]),
        convert_tf_dataset(splits["test"]),
    )


In [4]:
# Build the train / validation / test tf.data pipelines
# (shuffling is requested for the training split only).
train_ds, valid_ds, test_ds = load_dataset()


Found cached dataset parquet (/Users/mmenendezg/.cache/huggingface/datasets/mmenendezg___parquet/mmenendezg--pneumonia_x_ray-052017b06aabdb98/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

### Data Augmentation

In [None]:
# Augmentation pipeline consisting of a single random horizontal flip layer.
# NOTE(review): not applied anywhere in the cells visible here — presumably
# wired into the model or the input pipeline later; confirm.
data_augmentation = Sequential([
    layers.RandomFlip("horizontal")
])

## Train the model

In [None]:
# Clear the TensorFlow/Keras session so that re-running this cell starts
# from a clean global state.
tf.keras.backend.clear_session()
# Set TensorFlow's global random seed to the notebook-wide SEED constant.
tf.random.set_seed(SEED)