# Train the CNN with the binary images
## Note: do not forget to include a neutral class


In [None]:
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import ModelCheckpoint

# Training configuration shared by the data generators and the model.
batch_size = 32  # number of images per generated batch
target_size_custom = (256, 256)  # size passed as target_size when images are loaded
data_dir = r'malware_Image_richer_class'  # root folder handed to flow_from_directory

# IMPORTANT
# The GitHub example does not provide enough images per label, so the
# effectiveness of the model should be tested using a complete dataset such as
# https://www.kaggle.com/datasets/ikrambenabd/malimg-original
# IMPORTANT

# Base name of the checkpoint file written during training.
save_checkpoints_path = 'malware_model_checkpoints.h5'

# One generator object serves both subsets: it rescales pixel values by 1/255
# and holds back 20% of the images as the validation split.
datagen = ImageDataGenerator(validation_split=0.2, rescale=1 / 255.0)

# Keyword arguments common to the training and validation iterators.
_flow_options = dict(
    directory=data_dir,
    target_size=target_size_custom,
    batch_size=batch_size,
    class_mode="categorical",
    seed=42,
)

# Training batches are shuffled; validation batches keep a fixed order.
train_gen = datagen.flow_from_directory(
    subset="training", shuffle=True, **_flow_options
)
val_gen = datagen.flow_from_directory(
    subset="validation", shuffle=False, **_flow_options
)

# The class-to-index mapping reported by the training iterator determines the
# number of output units of the network.
classes = train_gen.class_indices
num_classes = len(classes)

def malware_model(input_shape=None, n_classes=None):
    """Build and compile the CNN used to classify the malware images.

    The network stacks four Conv2D/MaxPooling2D stages (64, 32, 32 and 16
    filters, 3x3 kernels, 2x2 pooling), followed by dropout, a flatten step,
    two dense layers (128 and 50 units) with dropout, and a softmax output.

    Args:
        input_shape: Shape of one input image as a 3-tuple. When omitted it
            is derived from the module-level ``target_size_custom`` with
            3 channels, which is what the original zero-argument call used.
        n_classes: Number of units in the softmax output layer. When omitted
            the module-level ``num_classes`` is used.

    Returns:
        A ``Sequential`` model compiled with categorical cross-entropy loss,
        the Adam optimizer and the accuracy metric.
    """
    # Fall back to the notebook-level settings so existing calls keep working.
    if input_shape is None:
        input_shape = (target_size_custom[0], target_size_custom[1], 3)
    if n_classes is None:
        n_classes = num_classes

    model = Sequential([
        # Convolutional feature extractor.
        Conv2D(64, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(32, kernel_size=(3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(32, kernel_size=(3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(16, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.25),
        # Classifier head.
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.25),
        Dense(50, activation='relu'),
        Dropout(0.5),
        Dense(n_classes, activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=["accuracy"])
    return model

# Instantiate the network and print its layer-by-layer summary.
Malware_model = malware_model()
Malware_model.summary()

# Keep only the weights of the epoch with the best validation accuracy.
# NOTE(review): '.keras' is appended to save_checkpoints_path as-is, so if that
# name already carries an extension the file ends up with two suffixes. Confirm
# whether any later loading code relies on the resulting name before changing it.
checkpoint_file = save_checkpoints_path + '.keras'
cp_callback = ModelCheckpoint(
    filepath=checkpoint_file,
    monitor="val_accuracy",
    save_best_only=True,
    verbose=1,
)

# Train for 10 epochs, evaluating on the validation iterator after each one.
history = Malware_model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=10,
    callbacks=[cp_callback],
)
