In [13]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import os
from tqdm import tqdm
from keras_preprocessing.image import img_to_array, load_img

In [14]:
def load_images(path, SIZE=(224, 224, 3)):
    """Load a class-per-folder image dataset into NumPy arrays.

    Every sub-directory of ``path`` is treated as one class; every regular,
    non-hidden file inside it is loaded as an image of that class.

    Args:
        path: Root directory containing one sub-directory per class.
        SIZE: Target image size. Only the first two entries (height, width)
            are passed to ``load_img``; a trailing channel entry is accepted
            for backward compatibility and ignored.

    Returns:
        ``[images, labels]`` where ``images`` is the array of loaded images
        and ``labels`` is the array of class-directory names, index-aligned.
    """
    # load_img expects (height, width); drop the channel entry if present.
    target_size = tuple(SIZE[:2])

    # Keep only real class directories: stray files (e.g. '.DS_Store') and
    # hidden folders (e.g. '.ipynb_checkpoints') are not classes. Sorting makes
    # the sample order independent of the filesystem's listing order.
    class_names = sorted(
        entry for entry in os.listdir(path)
        if not entry.startswith('.') and os.path.isdir(os.path.join(path, entry))
    )

    malware, cl = [], []
    for TYPE in tqdm(class_names):
        CLASS_PATH = os.path.join(path, TYPE)
        for IMG in sorted(os.listdir(CLASS_PATH)):
            IMG_PATH = os.path.join(CLASS_PATH, IMG)
            # Skip hidden files and nested directories instead of handing them
            # to the image loader.
            if IMG.startswith('.') or not os.path.isfile(IMG_PATH):
                continue
            malware.append(img_to_array(load_img(IMG_PATH, target_size=target_size)))
            cl.append(TYPE)
    return [np.asarray(malware), np.asarray(cl)]

In [15]:
# Load the train and val splits; load_images treats each sub-folder of the
# given directory as one class and returns (images, class-name labels).
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
byteplots_train, classes_train = load_images('/Applications/ML projects/Success/Blended Malware/Dataset/train')
byteplots_val, classes_val = load_images('/Applications/ML projects/Success/Blended Malware/Dataset/val')

100%|██████████| 32/32 [00:48<00:00,  1.50s/it]
100%|██████████| 32/32 [00:18<00:00,  1.75it/s]


In [16]:
train_file = '/Applications/ML projects/Distillation learning/Dataset/train_file.npz'
test_file = '/Applications/ML projects/Distillation learning/Dataset/test_file.npz'

# Persist each split as a compressed archive. The arrays are passed
# positionally, so they are stored under the keys 'arr_0' (images) and
# 'arr_1' (labels). Note that the "test" archive holds the val split.
for npz_path, images, labels in (
    (train_file, byteplots_train, classes_train),
    (test_file, byteplots_val, classes_val),
):
    np.savez_compressed(npz_path, images, labels)
    print('Saved Dataset: ', npz_path)

Saved Dataset:  /Applications/ML projects/Distillation learning/Dataset/train_file.npz
Saved Dataset:  /Applications/ML projects/Distillation learning/Dataset/test_file.npz


In [58]:
from sklearn.preprocessing import LabelEncoder

In [59]:
def load_samples(file, seed=None):
    """Load a saved ``.npz`` dataset, then scale, shuffle and integer-encode it.

    Args:
        file: Path to an archive holding the images under ``'arr_0'`` and the
            class-name labels under ``'arr_1'``.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            global NumPy random state is used, as before.

    Returns:
        ``(byteplots, classes)``: the images divided by 255.0 and the labels
        as integer ids, both in the same shuffled order.

    Note:
        Label ids are derived from the class names present in *this* file
        (sorted order), so ids agree across files only when the files contain
        the same set of class names.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once both arrays have been materialised.
    with np.load(file) as dataset:
        byteplots, classes = dataset['arr_0'], dataset['arr_1']
    byteplots = byteplots / 255.0

    shuffle = np.arange(byteplots.shape[0])
    if seed is None:
        np.random.shuffle(shuffle)
    else:
        np.random.default_rng(seed).shuffle(shuffle)
    byteplots = byteplots[shuffle]
    classes = classes[shuffle]

    # Sorted-unique inverse indices: the same 0..n_classes-1 encoding that
    # LabelEncoder.fit_transform produces, without needing sklearn here.
    _, classes = np.unique(classes, return_inverse=True)

    return byteplots, classes

In [60]:
class Distiller(keras.Model):
    """Wrapper model that trains a student network to mimic a teacher.

    During training the teacher is only run forward (``training=False``); the
    optimizer updates the student's trainable variables exclusively.
    """

    def __init__(self, student, teacher):
        """Store the student (trained) and teacher (reference) models."""
        super().__init__()
        self.student = student
        self.teacher = teacher

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3):
        """Configure the distiller.

        Args:
            optimizer: Optimizer applied to the student's variables.
            metrics: Metrics updated with (labels, student predictions).
            student_loss_fn: Loss between the labels and the student output.
            distillation_loss_fn: Loss between the temperature-softened
                teacher and student distributions.
            alpha: Weight of the student loss; the distillation loss is
                weighted by ``1 - alpha``.
            temperature: Divisor applied to both outputs before the softmax.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step on the student and report metrics."""
        inputs, labels = data
        temperature = self.temperature

        # Teacher forward pass happens outside the tape: it is a fixed target.
        teacher_out = self.teacher(inputs, training=False)

        with tf.GradientTape() as tape:
            student_out = self.student(inputs, training=True)
            hard_loss = self.student_loss_fn(labels, student_out)

            soft_teacher = tf.nn.softmax(teacher_out / temperature, axis=1)
            soft_student = tf.nn.softmax(student_out / temperature, axis=1)
            # The soft-target loss is rescaled by temperature squared.
            soft_loss = self.distillation_loss_fn(soft_teacher, soft_student) * temperature**2

            total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        # Only the student's variables receive gradients.
        student_vars = self.student.trainable_variables
        student_grads = tape.gradient(total_loss, student_vars)
        self.optimizer.apply_gradients(zip(student_grads, student_vars))

        self.compiled_metrics.update_state(labels, student_out)

        results = {metric.name: metric.result() for metric in self.metrics}
        results["student_loss"] = hard_loss
        results["distillation_loss"] = soft_loss
        return results

    def test_step(self, data):
        """Evaluate the student alone; the teacher is not involved."""
        inputs, labels = data
        student_out = self.student(inputs, training=False)
        hard_loss = self.student_loss_fn(labels, student_out)
        self.compiled_metrics.update_state(labels, student_out)

        # Metrics are inserted first, then the student loss.
        results = {metric.name: metric.result() for metric in self.metrics}
        results["student_loss"] = hard_loss
        return results

In [61]:
# ImageNet-pretrained VGG16 convolutional base (no classifier top), frozen so
# that only the dense head added below has trainable weights.
vgg16 = keras.applications.vgg16.VGG16(
    weights='imagenet',
    include_top=False,
    input_tensor=keras.Input(shape=(224, 224, 3)),
)
vgg16.trainable = False

# Teacher = frozen base + dense head. The last layer has no activation, so
# the model outputs 31 raw scores.
teacher_layers = [
    keras.Input(shape=(224, 224, 3)),
    vgg16,
    keras.layers.Flatten(),
]
for n_units in (1024, 256):
    teacher_layers.append(keras.layers.Dense(units=n_units, activation='relu'))
teacher_layers.append(keras.layers.Dense(units=31))

teacher = keras.Sequential(teacher_layers)

In [62]:
# The teacher's last layer has no activation, hence from_logits=True.
teacher_optimizer = keras.optimizers.Adam()
teacher_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
teacher_metrics = [keras.metrics.SparseCategoricalAccuracy()]

teacher.compile(optimizer=teacher_optimizer, loss=teacher_loss, metrics=teacher_metrics)

In [63]:
# Student: six stride-2 4x4 convolution blocks followed by the same dense
# head layout as the teacher. The first block has no batch normalisation;
# every later block is Conv2D -> BatchNormalization -> LeakyReLU.
student_layers = [keras.Input(shape=(224, 224, 3))]

for block_index, n_filters in enumerate((32, 64, 128, 256, 512, 512)):
    student_layers.append(
        keras.layers.Conv2D(filters=n_filters, kernel_size=(4, 4), strides=(2, 2), padding='same')
    )
    if block_index > 0:
        student_layers.append(keras.layers.BatchNormalization())
    student_layers.append(keras.layers.LeakyReLU(alpha=0.2))

student_layers.append(keras.layers.Flatten())
for n_units in (1024, 256):
    student_layers.append(keras.layers.Dense(units=n_units, activation='relu'))
# No activation on the output layer: the student emits 31 raw scores.
student_layers.append(keras.layers.Dense(units=31))

student = keras.Sequential(student_layers)

In [64]:
# Distillation hyperparameters: ALPHA weights the hard-label student loss
# (1 - ALPHA goes to the soft-target loss); TEMPERATURE softens both outputs
# before the softmax inside Distiller.train_step.
DISTILL_ALPHA = 0.1
DISTILL_TEMPERATURE = 10

distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=DISTILL_ALPHA,
    temperature=DISTILL_TEMPERATURE,
)

In [65]:
# Baseline: a copy of the student architecture to be trained without the
# teacher. Per the Keras docs, clone_model creates new layers and does not
# share weights with `student`.
student_scratch = keras.models.clone_model(student)

In [66]:
# Same optimizer type, loss and metric as the teacher, so the baseline is
# trained on hard labels only.
scratch_optimizer = keras.optimizers.Adam()
scratch_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
scratch_metrics = [keras.metrics.SparseCategoricalAccuracy()]

student_scratch.compile(optimizer=scratch_optimizer, loss=scratch_loss, metrics=scratch_metrics)

In [67]:
# Training split: images scaled by 1/255, shuffled, labels as integer ids.
train_byteplots, train_classes = load_samples(train_file)

In [68]:
# Train the teacher; the VGG16 base was frozen above, so only the dense head
# has trainable weights. No validation data is passed to fit().
teacher_history = teacher.fit(train_byteplots, train_classes, epochs=5, batch_size=32)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [69]:
# NOTE(review): absolute, machine-specific output path.
teacher_path = '/Applications/ML projects/Distillation learning/Models/teacher.h5'

In [70]:
# Persist the trained teacher (path ends in .h5).
teacher.save(teacher_path)

In [71]:
# Distil the trained teacher into the student; Distiller.train_step applies
# gradients to the student's variables only.
distiller_history = distiller.fit(train_byteplots, train_classes, epochs=5, batch_size=32)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [72]:
# NOTE(review): absolute, machine-specific output path.
student_path = '/Applications/ML projects/Distillation learning/Models/student.h5'

In [73]:
# Save the student network itself (the object the distiller trained), not the
# Distiller wrapper.
student.save(student_path)



In [74]:
# Train the baseline student on hard labels with the same epochs and batch
# size as the distilled run.
student_scratch_history = student_scratch.fit(train_byteplots, train_classes, epochs=5, batch_size=32)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [75]:
# NOTE(review): absolute, machine-specific output path.
student_scratch_path = '/Applications/ML projects/Distillation learning/Models/student_scratch.h5'

In [76]:
# Persist the baseline student.
student_scratch.save(student_scratch_path)

In [77]:
# Evaluation data: test_file holds the 'val' split saved earlier.
# NOTE(review): load_samples encodes labels independently per file, so the
# integer ids match the training ids only if both files contain the same set
# of class names — confirm.
test_byteplots, test_classes = load_samples(test_file)

In [78]:
# Teacher on the held-out split; presumably [loss, sparse categorical
# accuracy], following the compile() configuration.
teacher.evaluate(test_byteplots, test_classes)



[0.30547896027565, 0.9234338998794556]

In [79]:
# Distilled student on the held-out split. Distiller.test_step reports the
# compiled metrics first and then "student_loss".
# NOTE(review): the output below is therefore presumably [accuracy,
# student_loss], the reverse of the other two evaluate() cells — confirm
# before comparing numbers.
distiller.evaluate(test_byteplots, test_classes)



[0.9164733290672302, 0.22709810733795166]

In [80]:
# Baseline student (no distillation) on the held-out split; presumably
# [loss, sparse categorical accuracy], following the compile() configuration.
student_scratch.evaluate(test_byteplots, test_classes)



[0.7355527281761169, 0.8471255302429199]