In [None]:
import tensorflow as tf
from tensorflow.keras import models, layers
import pandas as pd
import os
import io
import PIL
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras import mixed_precision

In [None]:
# Pick a distribution strategy: use the attached TPU if there is one, otherwise
# fall back to the default (single-device GPU/CPU) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
except ValueError:
    # TPUClusterResolver raises ValueError when no TPU can be found. Only the
    # resolver is guarded, so genuine connection/initialisation failures on a
    # real TPU surface instead of silently degrading to CPU/GPU.
    tpu = None

if tpu is not None:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Stable (non-experimental) TPUStrategy API.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    print('No TPU found; using the default strategy.')
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
# Let tf.data choose parallelism / prefetch depth at runtime
# (tf.data.experimental.AUTOTUNE is the deprecated alias of this constant).
AUTOTUNE = tf.data.AUTOTUNE
# Global batch size: 64 examples per replica, so it scales with the number of
# replicas the strategy runs in sync (64 on a single device).
BATCH_SIZE = 64 * strategy.num_replicas_in_sync
# One entry per label id 0..4 in train.csv's "label" column.
CLASSES = ['0', '1', '2', '3', '4']
EPOCHS = 25

print(BATCH_SIZE)

In [None]:
# Read the competition's training index: one row per image with its file name
# ("image_id") and integer class id ("label").
df = pd.read_csv("/kaggle/input/cassava-leaf-disease-classification/train.csv")

# Parallel Python lists (same order as the CSV rows) used to build tf.data datasets.
paths = df["image_id"].tolist()
labels = df["label"].tolist()

print(paths[:10])
print(labels[:10])
df.head()

In [None]:
# Create a dataset of file paths (one scalar string per train.csv row)
image_dataset = tf.data.Dataset.from_tensor_slices(paths)


def load_and_preprocess(path):
    """Read one training JPEG from disk and resize it for the model.

    Args:
        path: scalar string tensor with an image file name relative to the
            competition's ``train_images/`` directory.

    Returns:
        float32 tensor of shape (512, 512, 3). Pixel values stay in the
        [0, 255] range; ``tf.image.resize`` returns float32 but does not rescale.
    """
    image = tf.io.read_file("/kaggle/input/cassava-leaf-disease-classification/train_images/" + path)
    image = tf.image.decode_jpeg(image, channels=3)
    # 512x512 matches the input_shape given to the ResNet50 backbone.
    image = tf.image.resize(image, [512, 512])
    return image


# Decode images in parallel; tf.data keeps element order deterministic by
# default, so images stay aligned with their labels in the zip below.
image_dataset = image_dataset.map(load_and_preprocess, num_parallel_calls=AUTOTUNE)

# Create dataset for labels
label_dataset = tf.data.Dataset.from_tensor_slices(labels)

# (image, label) pairs in train.csv row order.
dataset_no_tfrecord = tf.data.Dataset.zip((image_dataset, label_dataset))

train_split = 0.9
val_split = 0.1  # informational: validation receives every example not used for training

train_num = int(len(labels) * train_split)
# Give validation the remainder so int() rounding never drops an example.
val_num = len(labels) - train_num
assert train_num + val_num == len(labels), "train/val split does not cover the dataset"
print(f"train_num={train_num}, val_num={val_num}")

# NOTE(review): the split follows train.csv row order; shuffle the rows with a
# fixed seed first if the CSV is not already in random order.
train_dataset_no_tfrecord = dataset_no_tfrecord.take(train_num)
val_dataset_no_tfrecord = dataset_no_tfrecord.skip(train_num)

In [None]:
def augment(image, label):
    """Randomly perturb one (image, label) training example.

    Args:
        image: image tensor, either a single (H, W, C) image or a batch.
        label: passed through unchanged.

    Returns:
        ``(image, label)`` where the image has a random contrast factor drawn
        from [0.3, 0.5) applied and is rotated by a random multiple of 90
        degrees. ``tf.image.rot90`` takes a single scalar ``k``, so if a whole
        batch is passed every image in it gets the same rotation; map this
        function per example (before batching).
    """
    # Contrast factor in [0.3, 0.5): this always lowers contrast.
    image = tf.image.random_contrast(image, lower=0.3, upper=0.5)

    # k in {0, 1, 2, 3} quarter turns.
    image = tf.image.rot90(image, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32))

    return image, label


def _to_resnet_input(image, label):
    """Apply the ImageNet preprocessing expected by Keras' ResNet50 weights."""
    # Keras' ResNet50 expects caffe-style input (RGB->BGR, per-channel ImageNet
    # mean subtraction, no scaling), not pixels divided by 255.
    return tf.keras.applications.resnet50.preprocess_input(image), label


# Training pipeline: shuffle individual examples (not whole batches) and
# augment each example independently before batching.
# NOTE(review): cache() keeps every decoded 512x512 float32 image (~3 MB each)
# in host memory; if that exceeds RAM, pass a filename to cache() or drop it.
# NOTE: this cell rebinds the *_no_tfrecord names. Re-run the split cell before
# re-running this one, or batch()/map() are applied a second time.
train_dataset_no_tfrecord = (
    train_dataset_no_tfrecord
    .cache()
    .shuffle(500)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=BATCH_SIZE)
    .map(_to_resnet_input, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
# Validation pipeline: no augmentation and no shuffling, so every epoch is
# evaluated on the same unperturbed images.
val_dataset_no_tfrecord = (
    val_dataset_no_tfrecord
    .cache()
    .batch(batch_size=BATCH_SIZE)
    .map(_to_resnet_input, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)

print("\nTraining data shapes:", train_dataset_no_tfrecord)
print("\nValidation data shapes:", val_dataset_no_tfrecord)


In [None]:
# ImageNet-pretrained ResNet50, used as a frozen feature extractor.
# Built inside strategy.scope() so its variables are created by the distribution
# strategy; under TPUStrategy, Keras refuses to compile a model whose variables
# were created outside the strategy scope.
with strategy.scope():
    pretrained_model = tf.keras.applications.ResNet50(
        include_top=False,            # drop the 1000-class ImageNet head
        input_shape=(512, 512, 3),    # matches the 512x512 RGB images from the input pipeline
        pooling='avg',                # global average pooling -> 2D (batch, channels) output
        weights='imagenet',
    )

# Freeze the backbone so that only the new classification head is trained.
pretrained_model.trainable = False

In [None]:
# Mixed-precision policy matched to the hardware: bfloat16 on TPU, float16 on
# GPU, plain float32 on CPU (where mixed precision only slows training down).
# NOTE: Keras layers capture the global policy when they are constructed.
# pretrained_model is built in an earlier cell, before this policy is set, so
# it keeps running in float32 and only the new head below uses mixed precision.
if tf.config.list_logical_devices('TPU'):
    mixed_precision.set_global_policy('mixed_bfloat16')
elif tf.config.list_physical_devices('GPU'):
    mixed_precision.set_global_policy('mixed_float16')
else:
    mixed_precision.set_global_policy('float32')

with strategy.scope():
    # pooling='avg' already yields a 2D (batch, features) tensor, so no Flatten is needed.
    model = models.Sequential([
        pretrained_model,
        layers.Dense(128),
        layers.Dropout(0.2),
        layers.Activation('relu'),
        layers.Dense(len(CLASSES)),
        # Compute the softmax in float32: float16/bfloat16 probabilities are
        # numerically fragile inside the cross-entropy loss.
        layers.Activation('softmax', dtype='float32'),
    ])

    # Labels are integer class ids, hence the *sparse* categorical cross-entropy.
    model.compile(optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"])

In [None]:
# Print the layer-by-layer architecture and parameter counts; the frozen
# ResNet50 backbone is reported under non-trainable parameters.
model.summary()

In [None]:
# Train the model for EPOCHS epochs, evaluating on the held-out split after each
# epoch. The returned History object keeps per-epoch loss/accuracy curves.
history = model.fit(
    x=train_dataset_no_tfrecord,
    validation_data=val_dataset_no_tfrecord,
    epochs=EPOCHS,
)