In [162]:
import warnings

# Import libraries
import numpy as np
import pandas as pd
from keras.layers import *
from keras.models import *
from keras.preprocessing.image import ImageDataGenerator
from tensorflow import keras
from sklearn.preprocessing import MultiLabelBinarizer

warnings.filterwarnings('ignore')

%matplotlib inline

In [163]:
# Load the annotation table: each row names an image file, its space-separated
# class names, and one 0/1 column per class. Then preview the first rows.
# NOTE(review): 'Unnamed: 0' looks like a pandas index written by to_csv;
# pd.read_csv(..., index_col=0) would drop it — confirm nothing downstream uses it.
labels_csv = '../data/multilabel_modified/multilabel_classification_clean.csv'
pd_data = pd.read_csv(labels_csv)
pd_data.head()

Unnamed: 0,Image_Name,Classes,motorcycle,truck,boat,bus,cycle,person,desert,mountains,sea,sunset,trees,sitar,ektara,flutes,tabla,harmonium
0,image1.jpg,bus person,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0
1,image2.jpg,sitar,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0
2,image3.jpg,flutes,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
3,image4.jpg,bus trees,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0
4,image5.jpg,bus,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0


In [164]:
# Root folders holding the image files for each data split.
data_root = "../data/multilabel_modified"
train_dir = f"{data_root}/train"
val_dir = f"{data_root}/validation"
test_dir = f"{data_root}/test"

def split_and_trim(labels):
    """Split a whitespace-separated label string into a list of class names.

    ``str.split()`` with no separator collapses runs of whitespace (spaces,
    tabs) and ignores leading/trailing whitespace. As a result, an input such as
    ``"bus  person "`` yields ``['bus', 'person']`` instead of picking up
    empty-string labels, which would otherwise become a bogus ``''`` class when
    multi-hot encoded.

    Parameters
    ----------
    labels : str
        Whitespace-separated class names, e.g. ``"bus person"``.

    Returns
    -------
    list of str
        The individual class names in their original order (empty list for a
        blank string).
    """
    return labels.split()

# Turn each 'Classes' string into a list of class names. Only values that are
# still strings are split. Values that are already lists (e.g. when this cell is
# re-run) are kept as-is, so the cell is idempotent. Without this guard, a re-run
# fails with AttributeError on list.split.
pd_data['Classes'] = pd_data['Classes'].apply(
    lambda labels: split_and_trim(labels) if isinstance(labels, str) else labels
)

# Multi-hot encode the label lists. mlb.classes_ holds the class names (sorted,
# since no explicit `classes` are given), and column j of multi_hot_labels is 1
# where an image carries classes_[j].
mlb = MultiLabelBinarizer()
multi_hot_labels = mlb.fit_transform(pd_data['Classes'])

# Write one 0/1 column per class into pd_data. This overwrites the
# same-named columns already present in the CSV.
label_columns = mlb.classes_
for col_idx, label in enumerate(label_columns):
    pd_data[label] = multi_hot_labels[:, col_idx]

# Build Keras image iterators for the train and validation splits. Both read file
# names from pd_data['Image_Name'] and use the multi-hot label columns as targets.
# NOTE(review): the same pd_data (all rows) is passed for both splits. Each
# iterator keeps only the rows whose image it finds in its directory ("Found N
# validated image filenames"). Confirm that no file name exists in both the
# train and validation folders, otherwise validation leaks training images.
def flow_multilabel(datagen, dataframe, directory, label_cols, shuffle=True,
                    target_size=(150, 150), batch_size=32):
    """Create a batch iterator over the images in `directory` listed in `dataframe`.

    Parameters
    ----------
    datagen : ImageDataGenerator
        Generator whose preprocessing (here, pixel rescaling) is applied to every image.
    dataframe : pandas.DataFrame
        Must contain an 'Image_Name' column plus every column named in `label_cols`.
    directory : str
        Folder containing the image files.
    label_cols : sequence of str
        Target columns. With class_mode="raw", each batch's y holds these columns' values.
    shuffle : bool
        Whether to shuffle sample order each epoch.
    target_size : tuple of int
        (height, width) that images are resized to.
    batch_size : int
        Number of images per batch.

    Returns
    -------
    The iterator returned by ``datagen.flow_from_dataframe``.
    """
    return datagen.flow_from_dataframe(
        dataframe=dataframe,
        directory=directory,
        x_col='Image_Name',
        y_col=[str(col) for col in label_cols],
        color_mode="rgb",
        class_mode="raw",
        target_size=target_size,
        batch_size=batch_size,
        shuffle=shuffle,
    )


train_gen = ImageDataGenerator(rescale=1.0 / 255.0)  # Normalise pixel values to [0, 1]
train_image_generator = flow_multilabel(train_gen, pd_data, train_dir, label_columns)

val_gen = ImageDataGenerator(rescale=1.0 / 255.0)  # Normalise pixel values to [0, 1]
# Validation order is fixed so every epoch is scored on the same images in the same order.
val_image_generator = flow_multilabel(val_gen, pd_data, val_dir, label_columns, shuffle=False)


Found 4000 validated image filenames.
Found 1999 validated image filenames.


In [166]:
# Define a simple CNN: two conv + max-pool stages, then a dense head with one
# sigmoid output per label, so each class is scored independently (multi-label).
# Input is 150x150 RGB, matching the generators' target_size and color_mode.
model = Sequential()
model.add(Conv2D(4, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(8, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
# Flatten yields 36*36*8 = 10368 features, so this layer carries most of the
# model's weights (10368*512 + 512).
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(len(label_columns), activation='sigmoid'))

In [167]:
# Compile the model. binary_crossentropy treats each sigmoid output as an
# independent yes/no decision, which matches the multi-label setup.
# With this loss, 'accuracy' resolves to per-label binary accuracy. Because each
# multi-hot target is mostly zeros (one or two positives among all classes), a
# model that predicts "no label" everywhere already scores high on it.
# Threshold-free AUC (averaged per label) plus precision/recall give a more
# informative picture. 'accuracy' is kept so existing history keys still exist.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        'accuracy',
        keras.metrics.AUC(multi_label=True, name='auc'),
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
    ],
)

# Print the model summary
model.summary()

Model: "sequential_22"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_50 (Conv2D)          (None, 148, 148, 4)       40        
                                                                 
 max_pooling2d_50 (MaxPooli  (None, 74, 74, 4)         0         
 ng2D)                                                           
                                                                 
 conv2d_51 (Conv2D)          (None, 72, 72, 8)         296       
                                                                 
 max_pooling2d_51 (MaxPooli  (None, 36, 36, 8)         0         
 ng2D)                                                           
                                                                 
 flatten_22 (Flatten)        (None, 10368)             0         
                                                                 
 dense_58 (Dense)            (None, 512)             

In [168]:
# Fit the model using the image generators.
# EarlyStopping watches val_loss and stops after `patience` epochs without
# improvement. With EPOCHS = 1 it can never trigger, so raise EPOCHS for a real
# training run. restore_best_weights returns the best epoch's weights when it stops early.
EPOCHS = 1
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True
)

# len() of a Keras image iterator is its number of batches per pass, rounded up.
# As a result, the final partial batch is included and the step counts follow
# the data instead of hard-coded dataset sizes.
hist = model.fit(
    train_image_generator,
    epochs=EPOCHS,
    verbose=1,
    validation_data=val_image_generator,
    steps_per_epoch=len(train_image_generator),
    validation_steps=len(val_image_generator),
    callbacks=[early_stopping],
)

