In [6]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os

In [7]:
# Paths to the dataset
# The three split folders share one root, so relocating the data only
# requires editing DATA_DIR.
DATA_DIR = "/content/drive/MyDrive/Data"
train_path = f"{DATA_DIR}/train"
valid_path = f"{DATA_DIR}/valid"
test_path = f"{DATA_DIR}/test"


In [8]:
# Training configuration shared by the data iterators, the model and the optimizer.
IMAGE_SIZE = 224        # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32         # images per iterator batch
N_CLASSES = 4           # width of the softmax output layer
EPOCHS = 50             # maximum number of training epochs
LEARNING_RATE = 1e-4    # learning rate handed to Adam

In [9]:
# Training-time input pipeline: ResNet50 preprocessing plus random geometric
# augmentation (rotation, shifts, shear, zoom and flips along both axes).
augmentation_options = dict(
    rotation_range=15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)
train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet50.preprocess_input,
    **augmentation_options,
)

In [10]:
# Validation and test images receive the ResNet50 preprocessing only (no augmentation).
resnet_preprocess = tf.keras.applications.resnet50.preprocess_input
valid_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)
test_datagen = ImageDataGenerator(preprocessing_function=resnet_preprocess)

# Settings common to all three directory iterators.
flow_options = dict(
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

train_generator = train_datagen.flow_from_directory(train_path, **flow_options)
valid_generator = valid_datagen.flow_from_directory(valid_path, **flow_options)

# Shuffling is switched off for the test iterator only.
test_generator = test_datagen.flow_from_directory(test_path, shuffle=False, **flow_options)

Found 763 images belonging to 4 classes.
Found 144 images belonging to 4 classes.
Found 515 images belonging to 4 classes.


In [None]:
# Backbone: ResNet50 initialised with ImageNet weights, built without its
# classification top, for IMAGE_SIZE x IMAGE_SIZE three-channel inputs.
base_model = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


In [None]:
# Mark every backbone layer as non-trainable so its pretrained weights stay fixed.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

In [None]:
# Classification head stacked on the backbone output: global average pooling,
# then two Dense -> BatchNormalization -> Dropout stages, then a softmax layer.
x = GlobalAveragePooling2D()(base_model.output)
for units, drop_rate in ((512, 0.5), (256, 0.3)):
    x = Dense(units, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(drop_rate)(x)
output = Dense(N_CLASSES, activation='softmax')(x)

In [None]:
# Assemble the end-to-end network: backbone input through to the softmax head.
model = Model(base_model.input, output)

In [None]:
# Adam optimizer with categorical cross-entropy loss; accuracy is the tracked metric.
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)


In [None]:
# Callbacks
# Both callbacks monitor val_loss so that the checkpoint written to disk and
# the weights restored by early stopping are selected by the same criterion
# (previously the checkpoint tracked val_accuracy while early stopping
# tracked val_loss, so the two "best" models could come from different epochs).
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_model.keras',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    restore_best_weights=True,
)


In [None]:
# Fit on the training iterator with the validation iterator as validation data.
# The callbacks are passed in the order: checkpoint, then early stopping.
training_callbacks = [checkpoint, early_stopping]
history = model.fit(
    x=train_generator,
    validation_data=valid_generator,
    epochs=EPOCHS,
    callbacks=training_callbacks,
)

  self._warn_if_super_not_called()


Epoch 1/50
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m466s[0m 16s/step - accuracy: 0.3372 - loss: 1.8823 - val_accuracy: 0.3763 - val_loss: 2.2234
Epoch 2/50
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m233s[0m 8s/step - accuracy: 0.4990 - loss: 1.3570 - val_accuracy: 0.3978 - val_loss: 1.8699
Epoch 3/50
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m264s[0m 8s/step - accuracy: 0.5136 - loss: 1.2886 - val_accuracy: 0.3978 - val_loss: 1.4702
Epoch 4/50
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m269s[0m 8s/step - accuracy: 0.5840 - loss: 1.1401 - val_accuracy: 0.4731 - val_loss: 1.2668
Epoch 5/50
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m257s[0m 8s/step - accuracy: 0.6024 - loss: 1.0726 - val_accuracy: 0.4946 - val_loss: 1.2441
Epoch 6/50
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 8s/step - accuracy: 0.6147 - loss: 1.0228 - val_accuracy: 0.5806 - val_loss: 1.0856
Epoch 7/50
[1m29/29[0m [32m━━━

KeyboardInterrupt: 

In [None]:
# Score the model on the test iterator and report accuracy as a percentage.
test_metrics = model.evaluate(test_generator)
test_loss, test_acc = test_metrics
print(f"Test Accuracy: {test_acc * 100:.2f}%")

In [None]:
# Plot Accuracy and Loss Curves
# One panel per metric (training vs. validation), drawn through the explicit
# Figure/Axes interface so every panel carries its own title and axis labels.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))
for ax, (metric_key, metric_label) in zip(axes, panels):
    ax.plot(history.history[metric_key], label=f'Train {metric_label}')
    ax.plot(history.history[f'val_{metric_key}'], label=f'Validation {metric_label}')
    ax.set_title(f'Training vs. Validation {metric_label}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_label)
    ax.legend()
    ax.grid()

fig.tight_layout()  # keep the two panels' labels from overlapping
plt.show()

1. Use a Pretrained Model Instead of a Custom CNN
2. Increase Dataset Size or Apply Stronger Data Augmentation
3. Use a Higher Learning Rate & Fine-Tune the Model
4. Use Class Weights if Data is Imbalanced
5. Ensure Correct Input Shape for Pretrained Models
6. Use an Advanced Optimizer

In [None]:
# -*- coding: utf-8 -*-
"""Lung Cancer Image Classification using Swin Transformer"""

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# NOTE(review): the Keras Applications catalogue does not appear to include a
# Swin Transformer, so this import presumably raises ImportError on a stock
# TensorFlow install — confirm against the Keras Applications documentation
# and source the backbone from a library that actually ships Swin weights.
from tensorflow.keras.applications import SwinTransformerV2T
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt

# Paths to dataset
# The three split folders share one root, so relocating the data only
# requires editing DATA_DIR.
DATA_DIR = "/content/drive/MyDrive/Data-20250319T093650Z-001/Data"
train_path = f"{DATA_DIR}/train"
valid_path = f"{DATA_DIR}/valid"
test_path = f"{DATA_DIR}/test"

# Run configuration: input resolution, batching, label count and optimisation budget.
IMAGE_SIZE = 224        # side length used for target_size and the model input
BATCH_SIZE = 32         # batch size of every directory iterator
N_CLASSES = 4           # units in the final softmax layer
EPOCHS = 50             # epoch count passed to fit
LEARNING_RATE = 1e-4    # Adam learning rate

# Data Augmentation
# Training images are passed through a preprocessing function and randomly
# rotated, shifted, sheared, zoomed and flipped along both axes.
# NOTE(review): tf.keras.applications does not appear to expose a
# `swin_transformer` module, so this attribute lookup presumably raises
# AttributeError — confirm, and use the preprocessing that matches whichever
# Swin implementation ends up being used.
train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.swin_transformer.preprocess_input,
    rotation_range=15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True
)

# Validation and test images get the preprocessing function only (no augmentation).
# NOTE(review): `tf.keras.applications.swin_transformer` presumably does not
# exist in stock TensorFlow — confirm before relying on these two generators.
valid_datagen = ImageDataGenerator(preprocessing_function=tf.keras.applications.swin_transformer.preprocess_input)
test_datagen = ImageDataGenerator(preprocessing_function=tf.keras.applications.swin_transformer.preprocess_input)

# Directory iterators for the three splits; they share resize, batch and label settings.
flow_options = dict(
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

train_generator = train_datagen.flow_from_directory(train_path, **flow_options)

valid_generator = valid_datagen.flow_from_directory(valid_path, **flow_options)

# Shuffling is switched off for the test iterator only.
test_generator = test_datagen.flow_from_directory(test_path, shuffle=False, **flow_options)

# Load Pretrained Swin Transformer
# Builds the backbone without a classification top for IMAGE_SIZE x IMAGE_SIZE
# three-channel inputs, requesting 'imagenet' weights.
# NOTE(review): SwinTransformerV2T is imported from tensorflow.keras.applications
# at the top of this cell, which does not appear to provide it — confirm the
# constructor (and its weights/include_top/input_shape arguments) against the
# library that actually supplies the model.
base_model = SwinTransformerV2T(weights='imagenet', include_top=False, input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

# Turn off training for each layer of the backbone.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

# Classification head: global average pooling, two Dense -> BatchNormalization
# -> Dropout stages (512 units / 0.5, then 256 units / 0.3), and a softmax layer.
# NOTE(review): GlobalAveragePooling2D presumably needs a 4-D feature map from
# the backbone — confirm the backbone's output rank.
x = GlobalAveragePooling2D()(base_model.output)
for units, drop_rate in ((512, 0.5), (256, 0.3)):
    x = Dense(units, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(drop_rate)(x)
output = Dense(N_CLASSES, activation='softmax')(x)

# Wire the backbone input to the new head to obtain the full classifier.
model = Model(base_model.input, output)

# Configure training: Adam at LEARNING_RATE, categorical cross-entropy, accuracy metric.
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Callbacks
# Both callbacks monitor val_loss so that the checkpoint written to disk and
# the weights restored by early stopping are selected by the same criterion
# (previously the checkpoint tracked val_accuracy while early stopping
# tracked val_loss, so the two "best" models could come from different epochs).
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'best_swin_model.keras',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    restore_best_weights=True,
)

# Run training for up to EPOCHS epochs, validating against the validation
# iterator; callbacks are supplied as checkpoint first, early stopping second.
training_callbacks = [checkpoint, early_stopping]
history = model.fit(
    x=train_generator,
    validation_data=valid_generator,
    epochs=EPOCHS,
    callbacks=training_callbacks,
)

# Evaluate on the test iterator; the two returned values are loss and accuracy.
test_metrics = model.evaluate(test_generator)
test_loss, test_acc = test_metrics
print(f"Test Accuracy: {test_acc * 100:.2f}%")

# Plot Accuracy and Loss Curves
# One panel per metric (training vs. validation), drawn through the explicit
# Figure/Axes interface so every panel carries its own title and axis labels.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))
for ax, (metric_key, metric_label) in zip(axes, panels):
    ax.plot(history.history[metric_key], label=f'Train {metric_label}')
    ax.plot(history.history[f'val_{metric_key}'], label=f'Validation {metric_label}')
    ax.set_title(f'Training vs. Validation {metric_label}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_label)
    ax.legend()
    ax.grid()

fig.tight_layout()  # keep the two panels' labels from overlapping
plt.show()
