<a href="https://colab.research.google.com/github/ktcliff/AIPlantScanClone/blob/main/SequentialModelPoisoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
from matplotlib import pyplot as plt

# Show which TensorFlow version this kernel is running.
print(tf.__version__)

In [None]:
# Load MNIST and keep only the first N_SAMPLES examples of each split.
N_SAMPLES = 1000  # single knob for the subset size of both train and test

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, train_labels = train_images[:N_SAMPLES], train_labels[:N_SAMPLES]
test_images, test_labels = test_images[:N_SAMPLES], test_labels[:N_SAMPLES]

# Preview the first training image before it is flattened below
# (plt is already imported in the notebook's import cell).
plt.imshow(train_images[0])

# Flatten each 28x28 image into a 784-element row and divide by 255.
train_images = train_images.reshape(-1, 28 * 28) / 255.0
test_images = test_images.reshape(-1, 28 * 28) / 255.0

In [None]:
def create_model(optimizer, activation_param, input_dim=784, hidden_units=512,
                 dropout_rate=0.2, num_classes=10):
    """Build and compile a dense classifier with one hidden layer.

    Architecture: Input -> Dense(hidden_units) -> Dropout -> Dense(num_classes).
    The output layer has no activation, so the model emits raw logits; the
    loss is configured with ``from_logits=True`` to match.

    Args:
        optimizer: Forwarded unchanged to ``model.compile(optimizer=...)``.
        activation_param: Activation used by the hidden Dense layer.
        input_dim: Length of each flattened input vector. Defaults to 784.
        hidden_units: Number of units in the hidden layer. Defaults to 512.
        dropout_rate: Rate of the Dropout layer after the hidden layer.
            Defaults to 0.2.
        num_classes: Number of output units (one logit per class).
            Defaults to 10.

    Returns:
        The compiled Sequential model, using sparse categorical cross-entropy
        as loss and sparse categorical accuracy as its only metric.
    """
    model = keras.Sequential([
        # Declare the input explicitly instead of passing input_shape= to the
        # first Dense layer.
        keras.Input(shape=(input_dim,)),
        keras.layers.Dense(hidden_units, activation=activation_param),
        keras.layers.Dropout(dropout_rate),
        keras.layers.Dense(num_classes),
    ])

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model

In [None]:
# Training with clean data
# Seed the Python, NumPy and Keras/TensorFlow random generators before the
# model is built, so weight initialisation, dropout and batch shuffling are
# repeatable when this cell is re-run.
keras.utils.set_random_seed(42)
model = create_model('adam', 'relu')

history_clean = model.fit(train_images,
                          train_labels,
                          epochs=10,
                          validation_data=(test_images, test_labels))

# evaluate() returns the loss followed by the compiled metric (accuracy).
test_score_clean, accuracy_clean = model.evaluate(test_images, test_labels)
print("Clean Data Model Accuracy: ", accuracy_clean)

In [None]:
# Visualize training & validation accuracy and loss for clean data
acc_clean = history_clean.history['sparse_categorical_accuracy']
val_acc_clean = history_clean.history['val_sparse_categorical_accuracy']
loss_clean = history_clean.history['loss']
val_loss_clean = history_clean.history['val_loss']
# One x-position per recorded epoch, taken from the history itself so the
# plot stays valid if the epoch count passed to fit() changes.
epochs_range = range(len(acc_clean))

# Two panels side by side: accuracy on the left, loss on the right.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(epochs_range, acc_clean, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc_clean, label='Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set_title('Training and Validation Accuracy (Clean Data)')

ax_loss.plot(epochs_range, loss_clean, label='Training Loss')
ax_loss.plot(epochs_range, val_loss_clean, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_title('Training and Validation Loss (Clean Data)')

fig.tight_layout()
plt.show()

In [None]:
def poison_data(training_labels, target_label, replacement_label, number_to_replace):
    """Relabel up to ``number_to_replace`` occurrences of one class, in place.

    Scans ``training_labels`` from the start and overwrites the first
    ``number_to_replace`` entries equal to ``target_label`` with
    ``replacement_label``. If fewer matches exist, all of them are replaced.

    Args:
        training_labels: Mutable, index-addressable sequence of labels
            (e.g. a list or NumPy array). It is modified in place; pass a
            copy to keep the original intact.
        target_label: Label value to look for.
        replacement_label: Label value written over each match.
        number_to_replace: Maximum number of entries to overwrite. Zero or a
            negative value leaves the labels untouched.

    Returns:
        The same ``training_labels`` object that was passed in.
    """
    replaced = 0
    for i in range(len(training_labels)):
        # Stop as soon as the budget is spent; comparing with ">=" keeps the
        # number of flipped labels at exactly number_to_replace (not one more).
        if replaced >= number_to_replace:
            break
        if training_labels[i] == target_label:
            training_labels[i] = replacement_label
            replaced += 1
    return training_labels

# Poison a copy of the training labels (train_labels itself stays untouched):
# samples labelled like the first training example are relabelled with the
# label of the second training example, limited by number_to_replace.
poisoned_labels = poison_data(
    training_labels=train_labels.copy(),
    target_label=train_labels[0],
    replacement_label=train_labels[1],
    number_to_replace=200,
)

In [None]:
# Training with poisoned data
# Seed the Python, NumPy and Keras/TensorFlow random generators before the
# model is built, so weight initialisation, dropout and batch shuffling are
# repeatable when this cell is re-run.
keras.utils.set_random_seed(42)
model_poisoned = create_model('adam', 'relu')
history_poisoned = model_poisoned.fit(train_images,
                                      poisoned_labels,
                                      epochs=10,
                                      validation_data=(test_images, test_labels))

# evaluate() returns the loss followed by the compiled metric (accuracy);
# the evaluation uses the unmodified test labels.
test_score_poisoned, accuracy_poisoned = model_poisoned.evaluate(test_images, test_labels)
print("Poisoned Data Model Accuracy: ", accuracy_poisoned)

In [None]:
# Visualize training & validation accuracy and loss for poisoned data
acc_poisoned = history_poisoned.history['sparse_categorical_accuracy']
val_acc_poisoned = history_poisoned.history['val_sparse_categorical_accuracy']
loss_poisoned = history_poisoned.history['loss']
val_loss_poisoned = history_poisoned.history['val_loss']
# One x-position per recorded epoch, taken from the history itself so the
# plot stays valid if the epoch count passed to fit() changes.
epochs_range = range(len(acc_poisoned))

# Two panels side by side: accuracy on the left, loss on the right.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(epochs_range, acc_poisoned, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc_poisoned, label='Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set_title('Training and Validation Accuracy (Poisoned Data)')

ax_loss.plot(epochs_range, loss_poisoned, label='Training Loss')
ax_loss.plot(epochs_range, val_loss_poisoned, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_title('Training and Validation Loss (Poisoned Data)')

fig.tight_layout()
plt.show()