In [41]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import layers, models, Input
from capsule_layers import CapsuleLayer, Mask, margin_loss, squash
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from PIL import Image
import os


# Root folders of the two image splits read by flow_from_directory below.
train_dir = './dataset/training'
test_dir = './dataset/testing'

# Random augmentation applied to training images only.
augmentation_options = dict(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Both generators multiply pixel values by 1/255; only the training one augments.
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation_options)
test_datagen = ImageDataGenerator(rescale=1./255)

# Iterator settings shared by the training and validation streams.
flow_options = dict(
    target_size=(128, 128),
    batch_size=32,
    class_mode='categorical',
    color_mode='grayscale',
)

train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)

# Note: the validation stream is read from the testing directory.
validation_generator = test_datagen.flow_from_directory(test_dir, **flow_options)

Found 5711 images belonging to 4 classes.
Found 1311 images belonging to 4 classes.


In [42]:
# Balanced class weights derived from the training labels, intended to
# compensate for class imbalance.
generator_labels = train_generator.classes
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(generator_labels),
    y=generator_labels,
)
# Map positional class index -> weight.
class_weights_dict = dict(enumerate(class_weights))

# Dataset Preparation


#### Adjusting CapsNet Architecture for Detection
##### Capsule Network Design

In [43]:
def CapsNet(input_shape, n_classes, routings):
    """Build and compile a capsule network with a reconstruction decoder.

    The model takes two inputs (an image and a label vector of length
    ``n_classes``) and produces two outputs: the ``digitcaps`` capsule layer
    output and the decoder's reconstruction of the input image.

    Args:
        input_shape: Shape of a single input image, e.g. ``(128, 128, 1)``.
        n_classes: Number of capsules in the ``digitcaps`` layer (one per class).
        routings: ``routings`` argument forwarded to the ``digitcaps`` CapsuleLayer.

    Returns:
        A compiled Keras ``Model`` mapping ``[image, label]`` to
        ``[digitcaps, reconstruction]``, trained with ``margin_loss`` on the
        first output and MSE (weight 0.392) on the second.
    """
    # Vector length of each class capsule; also fixes the decoder input width.
    digit_caps_dim = 16

    x = layers.Input(shape=input_shape)

    # Conv layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu')(x)
    conv1 = layers.BatchNormalization()(conv1)

    # PrimaryCaps layer
    primarycaps = CapsuleLayer(num_capsule=32, dim_capsule=8, routings=1)(conv1)

    # DigitCaps layer
    digitcaps = CapsuleLayer(num_capsule=n_classes, dim_capsule=digit_caps_dim,
                             routings=routings, name='digitcaps')(primarycaps)

    # Decoder network
    # NOTE(review): y is presumably the one-hot label used by Mask to select the
    # capsule to reconstruct from -- confirm against capsule_layers.Mask.
    y = layers.Input(shape=(n_classes,))
    masked = Mask()([digitcaps, y])

    # Flatten the masked output before feeding into the decoder
    flattened = layers.Flatten()(masked)

    decoder = models.Sequential([
        # Explicit Input layer instead of the deprecated `input_dim=` argument on Dense.
        layers.Input(shape=(n_classes * digit_caps_dim,)),
        layers.Dense(512, activation='relu'),
        layers.Dense(1024, activation='relu'),
        # int(...) so Dense receives a plain Python int rather than a numpy scalar.
        layers.Dense(int(np.prod(input_shape)), activation='sigmoid'),
        layers.Reshape(target_shape=input_shape)
    ], name='decoder')

    decoded = decoder(flattened)

    model = models.Model([x, y], [digitcaps, decoded])
    # NOTE(review): margin_loss is applied directly to the digitcaps output; this
    # assumes the custom layer/loss handle capsule lengths themselves -- confirm.
    model.compile(optimizer='adam',
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., 0.392],
                  # Metrics dict keys must be model output names; the capsule
                  # output is named 'digitcaps' (there is no 'capsnet' output).
                  metrics={'digitcaps': 'accuracy'})

    return model


In [31]:
# Single-channel 128x128 inputs, matching the generators' target_size and
# grayscale colour mode.
input_shape = (128, 128, 1)
n_classes = 4
routings = 3

model = CapsNet(input_shape, n_classes, routings)
# summary() prints the table itself and returns None; wrapping it in print()
# only adds a stray "None" line to the output.
model.summary()

epochs = 50

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


None


In [44]:
# Training callbacks: stop after 10 epochs without val_loss improvement and keep
# the checkpoint with the best validation accuracy.
# On a multi-output model Keras prefixes per-output metric names with the output
# layer's name; the class-capsule output is named 'digitcaps', so its validation
# accuracy is reported as 'val_digitcaps_accuracy' rather than 'val_accuracy'.
# NOTE(review): confirm against the keys shown in the training logs.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, verbose=1),
    ModelCheckpoint('best_model.keras', monitor='val_digitcaps_accuracy', mode='max',
                    save_best_only=True, verbose=1)
]

In [48]:
import numpy as np
from PIL import Image
import os

def load_images_and_labels(base_dir, target_size=(128, 128)):
    """Load every readable image under ``base_dir`` as a grayscale array.

    Each sub-directory of ``base_dir`` is treated as one class, and classes are
    numbered in sorted folder-name order. Entries directly inside ``base_dir``
    that are not directories are ignored. Files that cannot be opened as images
    are reported and skipped.

    Args:
        base_dir: Directory containing one sub-directory per class.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A tuple ``(images, labels, label_map)`` where ``images`` has shape
        ``(n, height, width, 1)`` with pixel values divided by 255, ``labels``
        is a 1-D array of class ids aligned with ``images``, and ``label_map``
        maps each class folder name to its id.
    """
    # PIL sizes are (width, height) while the resulting arrays are (height, width).
    width, height = target_size

    # Only directories define classes; a stray file (e.g. .DS_Store) must neither
    # consume a label id nor be passed to os.listdir.
    class_names = sorted(
        entry for entry in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, entry))
    )
    label_map = {name: label_id for label_id, name in enumerate(class_names)}

    images = []
    labels = []
    for subdir, label_id in label_map.items():
        current_dir = os.path.join(base_dir, subdir)
        # Sorted so the sample order is reproducible across runs and platforms.
        for filename in sorted(os.listdir(current_dir)):
            img_path = os.path.join(current_dir, filename)
            try:
                with Image.open(img_path) as img:
                    gray = img.convert('L').resize(target_size)  # grayscale, then resize
                    pixels = np.array(gray) / 255.0  # Normalize the image
            except IOError:
                print(f"Error opening image file {img_path}. Skipping...")
                continue
            images.append(pixels)
            labels.append(label_id)

    # Add the trailing channel axis; derived from target_size instead of a fixed 128x128.
    images = np.array(images).reshape(-1, height, width, 1)
    labels = np.array(labels)
    return images, labels, label_map

# Load the whole training split into memory, reusing the directory configured in
# train_dir instead of repeating the path literal.
train_images, train_labels, train_label_map = load_images_and_labels(train_dir)


In [40]:
from PIL import Image
# Preview one training image. Leaving the PIL image as the cell's last expression
# renders it inline in the notebook; img.show() would instead launch an external
# viewer, which does nothing useful on a headless or remote kernel.
img = Image.open('./dataset/training/glioma/Tr-glTr_0002.jpg')
img