In [None]:
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [None]:
# --- Configuration: everything a reader might want to tune lives here. ---
H, W = 256, 256            # network input height and width in pixels
image_size = [H, W]        # same size as a list; VGG16's input_shape is image_size + [c]
c = 3                      # colour channels (images are read with cv2.IMREAD_COLOR)
# Class labels are taken from each image's parent folder name, so these must match
# the class sub-folder names exactly; list order defines the integer class index.
class_name = ["brain_glioma", "brain_menin", "brain_tumor"]
lr = 1e-4                  # initial Adam learning rate
# Where ModelCheckpoint writes the best model; an empty path is not a valid save target.
model_path = "brain_cancer_vgg16.keras"

In [None]:
# Dataset root: the "Brain Cancer" folder of the Kaggle "Multi Cancer" input.
# load_data globs path/*/*, i.e. expects one sub-folder per class holding the images.
path = "/kaggle/input/multi-cancer/Multi Cancer/Brain Cancer"

In [None]:
def load_data(path, split=0.1, random_state=42):
    """Split the image files under ``path`` into train / validation / test path lists.

    Files are collected with the pattern ``path/*/*``, i.e. one level of
    sub-directories (one per class) with the images directly inside them.

    Args:
        path: dataset root directory.
        split: fraction of *all* files put in each of the validation and test
            sets (0.1 gives roughly an 80/10/10 split). Must be in (0, 0.5).
        random_state: seed forwarded to ``train_test_split`` so the split is the
            same after a kernel restart; pass ``None`` for a fresh random split.

    Returns:
        Tuple ``(train, valid, test)`` of lists of file paths.

    Raises:
        FileNotFoundError: if nothing matches ``path/*/*``.
        ValueError: if ``split`` is out of range or there are too few files
            for three non-empty splits.
    """
    if not 0 < split < 0.5:
        raise ValueError(f"split must be in (0, 0.5), got {split!r}")

    # glob order depends on the filesystem; sort so the split depends only on random_state.
    files = sorted(glob(os.path.join(path, "*", "*")))
    if not files:
        raise FileNotFoundError(f"No files found matching {os.path.join(path, '*', '*')!r}")

    # int() truncates to 0 on small datasets, which train_test_split rejects; keep >= 1.
    split_rate = max(1, int(len(files) * split))
    if len(files) - 2 * split_rate < 1:
        raise ValueError(
            f"Need at least {2 * split_rate + 1} files for a train/valid/test split, "
            f"found {len(files)}"
        )

    # Both held-out sets have split_rate files, measured against the full file count.
    train, valid = train_test_split(files, test_size=split_rate, random_state=random_state)
    train, test = train_test_split(train, test_size=split_rate, random_state=random_state)
    return train, valid, test

In [None]:
# load_data returns three lists; unpack them so the train/valid/test pipelines below exist.
train, valid, test = load_data(path)
# Every path in one flat list, used for the single-sample check and the whole-dataset preview.
files = train + valid + test
print(f"train: {len(train)}  valid: {len(valid)}  test: {len(test)}")

In [None]:
def preprocess_data(image):
    """Read one image file and return it together with its integer class index.

    Args:
        image: path to an image file, as ``str`` or ``bytes`` (``tf.numpy_function``
            hands string tensors to Python as ``bytes``). The name of the file's
            parent directory is the class label and must be an entry of ``class_name``.

    Returns:
        ``(img, class_idx)`` where ``img`` is a float32 RGB array of shape
        ``(H, W, 3)`` scaled to [0, 1], and ``class_idx`` is the label's position
        in ``class_name`` as ``np.int32`` (matching the ``tf.int32`` output dtype
        declared in ``parse``).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
        ValueError: if the parent directory name is not in ``class_name``.
    """
    if isinstance(image, bytes):
        image = image.decode("utf-8")

    img = cv2.imread(image, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image}")
    # OpenCV decodes to BGR; convert so colours are right when shown with matplotlib.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize takes the target size as (width, height).
    img = cv2.resize(img, (W, H))
    img = (img / 255.0).astype(np.float32)

    label = os.path.basename(os.path.dirname(image))
    class_idx = class_name.index(label)  # ValueError if the folder is not a known class
    return img, np.int32(class_idx)

In [None]:
# Smoke-test the loader on the first file before wiring it into tf.data.
img, class_idx = preprocess_data(image=files[0])

In [None]:
# 0-d float32 array holding the sample's class index.
# NOTE(review): `classes` is not referenced by any later visible cell — candidate for removal.
classes = np.asarray(class_idx, dtype=np.float32)

In [None]:
def parse(path):
    """Map one file-path tensor to an ``(image, one_hot_label)`` pair for ``tf.data``.

    The image is decoded in Python via ``preprocess_data`` inside
    ``tf.numpy_function``; static shapes are then restored so downstream Keras
    layers see ``(H, W, c)`` images and ``(len(class_name),)`` labels.

    Args:
        path: scalar string tensor holding an image file path.

    Returns:
        ``(image, label)``: float32 image tensor and float32 one-hot label tensor.
    """
    def _load(raw_path):
        # tf.numpy_function passes string tensors in as bytes (np.bytes_ subclasses bytes).
        if isinstance(raw_path, bytes):
            raw_path = raw_path.decode("utf-8")
        image, label = preprocess_data(raw_path)
        # Cast explicitly: numpy_function errors if outputs don't match Tout exactly.
        return np.asarray(image, dtype=np.float32), np.int32(label)

    image, label = tf.numpy_function(_load, [path], [tf.float32, tf.int32])
    num_classes = len(class_name)
    label = tf.one_hot(label, num_classes)
    # numpy_function outputs have unknown shape; pin them from the config constants.
    image.set_shape([H, W, c])
    label.set_shape([num_classes])
    return image, label

In [None]:
def tf_datasets(images, batch_size=8, shuffle=False):
    """Build a batched, prefetched ``tf.data`` pipeline from a list of image paths.

    Args:
        images: sequence of image file paths.
        batch_size: number of examples per batch.
        shuffle: if True, reshuffle the paths every epoch (use for training data).

    Returns:
        ``tf.data.Dataset`` yielding ``(images, one_hot_labels)`` batches via ``parse``.
    """
    dataset = tf.data.Dataset.from_tensor_slices(images)
    if shuffle:
        # Buffer over all paths (strings only, so cheap); shuffle() needs a positive buffer.
        dataset = dataset.shuffle(max(1, len(images)), reshuffle_each_iteration=True)
    # Decoding happens in Python per file; let tf.data run several decodes in parallel.
    dataset = dataset.map(parse, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
# One input pipeline per split, built in train → valid → test order.
train_ds, valid_ds, test_ds = (tf_datasets(split_files) for split_files in (train, valid, test))

In [None]:
# Pipeline over every file in the dataset, used only for inspection/preview.
image_ds = tf_datasets(images=files)

In [None]:
# Pull a single batch to confirm the pipeline's output shapes.
for batch_images, batch_labels in image_ds.take(1):
    print("images:", batch_images.shape, "labels:", batch_labels.shape)

In [None]:
def PlotPipeImg(img_arr, max_images=10):
    """Show up to ``max_images`` images from ``img_arr`` side by side in one row.

    Args:
        img_arr: sequence or array of images (e.g. one batch from a dataset);
            must support ``len()`` and yield items accepted by ``imshow``.
        max_images: maximum number of panels to draw.
    """
    num_images = min(len(img_arr), max_images)
    if num_images == 0:
        return
    # squeeze=False keeps a 2-D axes array even for one panel, so flatten() always works.
    fig, axes = plt.subplots(1, num_images, figsize=(2 * num_images, 2), squeeze=False)
    for img, axis in zip(img_arr, axes.flatten()):
        axis.imshow(img)
        axis.axis("off")
    # Lay out and show once, after all panels are drawn (showing inside the loop
    # displayed only the first image).
    fig.tight_layout()
    plt.show()

In [None]:
# Take one batch from the whole-dataset pipeline and preview its images
# (descriptive names avoid clobbering the single-sample `img` above).
preview_images, preview_labels = next(iter(image_ds))
PlotPipeImg(preview_images.numpy())

In [None]:
# ImageNet-pretrained VGG16 convolutional base without its classifier head,
# sized for (H, W, c) inputs.
model = VGG16(
    include_top=False,
    weights="imagenet",
    input_shape=[*image_size, c],
)

In [None]:
# Print the layer table and parameter counts of the current `model`
# (the VGG16 base when run in order; the full classifier if run after the head cell).
model.summary()

In [None]:
# Freeze every layer of the pretrained base so only the new head is trained.
# NOTE(review): the head cell later rebinds `model` to the full classifier; re-running
# this cell after that would also freeze the new Dense head.
for layer in model.layers:
    layer.trainable = False

In [None]:
# Flatten the base model's final feature maps into one vector per image for the Dense head.
x = layers.Flatten()(model.output)

In [None]:
# Classification head: one softmax unit per class on top of the flattened VGG16 features.
last_layer = layers.Dense(len(class_name), activation="softmax")(x)
# tf.keras.Model is used directly (a bare `Model` name was never imported).
# NOTE: this rebinds `model` from the VGG16 base to the full classifier, so the
# base-model cells above must not be re-run after this one.
model = tf.keras.Model(inputs=model.input, outputs=last_layer)

In [None]:
callback = [
    # Save the model whenever validation loss (ModelCheckpoint's default monitor) improves;
    # fall back to a valid filename if model_path was left empty.
    ModelCheckpoint(model_path or "model.keras", verbose=1, save_best_only=True),
    # Keras metric names are case-sensitive: the logged key is "val_loss", not "val_Loss".
    ReduceLROnPlateau(monitor="val_loss", patience=5, min_lr=1e-5, factor=0.1, verbose=1),
]

In [None]:
# Labels are one-hot encoded (tf.one_hot in parse), so use categorical cross-entropy;
# an empty loss string is not a valid Keras loss.
model.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(learning_rate=lr),
    metrics=["accuracy"],
)

In [None]:
# The datasets already yield (images, labels), so validation data goes through
# validation_data (a second positional argument would be interpreted as `y`).
history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=20,
    callbacks=callback,
)

In [None]:
# Report held-out test loss and accuracy (the loss and metrics configured in compile).
# NOTE(review): this uses the in-memory weights from the last epoch, not necessarily the
# best checkpoint written by ModelCheckpoint — reload it from model_path if that is wanted.
model.evaluate(test_ds)