In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from pathlib import Path

# Make the project root importable before pulling in the local `src` helpers.
sys.path.append('..')
from src.data_augmentation import *
from src.folder_preparation import *

# Explicit imports follow the star imports so these bindings take precedence
# over anything the `src` modules happen to re-export.
import matplotlib.pyplot as plt
import matplotlib.pyplot as pyplot
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics as metrics
import tensorflow as tf
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator,array_to_img, img_to_array, load_img
from numpy import expand_dims
from tqdm import tqdm

### split original dataset into train and validation sets
#### <font color='red'> only run once!</font>

In [None]:
# empty_train_valid_split_directory('../simplifed-data-only-oranges')

In [None]:
# Folders holding the training and validation images.
train_dir, val_dir = (
    Path(folder)
    for folder in (
        './split_data-only-oranges/train',
        './split_data-only-oranges/validation',
    )
)

### augment all images by stitching them together
#### <font color='red'> only run once!</font>

In [None]:
# stitch_all_classes_in_root_directory(train_dir)

### augment classes with insufficient stitched images automatically
#### <font color='red'> only run once!</font>

In [None]:
# auto_augment_classes_in_root_directory(train_dir)

### visualize data population for each class

#### <font color='red'> only run once!</font>

In [None]:
# #put only stitched and original images into a dataframe
# df_class_stitched = dataframe_root_directory(root_dir=train_dir)
# auto_augment_classes_in_root_directory(root_dir=train_dir)
# #put stitched, original and auto-augmented images into a dataframe
# df_class_all = dataframe_root_directory(root_dir=train_dir)

In [None]:
# Load the saved per-image listings; the count plots below read the
# "class" and "type" columns of each frame.
df_class_stitched = pd.read_csv('image_list.csv')
df_class_all = pd.read_csv('image_list_auto_stitched_original.csv')

In [None]:
# Image count per class, split by image type, on a log-scaled count axis.
fig, ax = plt.subplots(figsize=(6, 6))
sns.countplot(data=df_class_stitched, x="class", hue='type', ax=ax)
ax.set_yscale('log')

In [None]:
# Same count plot for the second listing (df_class_all), log-scaled count axis.
fig, ax = plt.subplots(figsize=(6, 6))
sns.countplot(data=df_class_all, x="class", hue='type', ax=ax)
ax.set_yscale('log')

#### create dataset

In [None]:
# Loader settings: images per batch and the (height, width) every image is resized to.
batch_size = 4
img_height, img_width = 50, 40

In [None]:
# Build the batched training dataset from the images under `train_dir`.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=train_dir,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    seed=123,
)

In [None]:
# Build the batched validation dataset from the images under `val_dir`.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=val_dir,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    seed=123,
)

In [None]:
# Keep the class names reported by the training dataset; they are reused
# further down for plot titles and confusion-matrix labels.
class_names = train_ds.class_names
print(class_names)

#### visualize the data

In [None]:
# Show the first four images of a single training batch in a 2x2 grid,
# each titled with its class name.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(4):
        ax = plt.subplot(2, 2, i + 1)
        ax.imshow(images[i].numpy().astype("uint8"))
        ax.set_title(class_names[labels[i]])
        ax.axis("off")

### standardize the data

In [None]:
from tensorflow.keras import layers

# Multiply every pixel value by 1/255; labels pass through untouched.
normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255)
normalized_train_ds, normalized_val_ds = (
    ds.map(lambda images, targets: (normalization_layer(images), targets))
    for ds in (train_ds, val_ds)
)

### configure dataset for performance

In [None]:
# Constant handed to prefetch(buffer_size=...) in the next cell.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Cache the rescaled datasets and prefetch batches; from here on
# `train_ds` / `val_ds` refer to the normalised pipelines.
train_ds, val_ds = (
    ds.cache().prefetch(buffer_size=AUTOTUNE)
    for ds in (normalized_train_ds, normalized_val_ds)
)

### train a model

In [None]:
from tensorflow.keras import layers, models


In [None]:
# Size the output layer from the classes actually found in the training
# folder instead of a hard-coded count, so the model always matches the data.
num_classes = len(class_names)

# Three Conv(32 filters, 3x3) + max-pool stages, then a dense classifier head.
model = tf.keras.Sequential([
    # No Rescaling layer here: the dataset pipeline already multiplies pixels by 1/255.
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    # No activation: the model emits raw logits (the loss uses from_logits=True).
    layers.Dense(num_classes),
])

In [None]:
# Adam optimiser, sparse categorical cross-entropy computed on logits,
# accuracy tracked as the metric.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Wrap both datasets so elements that raise an error are dropped
# instead of aborting iteration.
train_ds, val_ds = (
    ds.apply(tf.data.experimental.ignore_errors())
    for ds in (train_ds, val_ds)
)

In [None]:
# Train for 10 epochs with validation after each one; the returned History
# object is read later to plot the accuracy curves.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10
)

In [None]:
# model.save('save_your_model_here.h5')

In [None]:
# Later cells (architecture plot, filter and feature-map visualisation) read
# `model_new`, so bind it to the model trained above; otherwise a fresh
# Restart & Run All stops with a NameError.
model_new = model

# To inspect a previously saved model instead, uncomment:
# from keras.models import load_model
# model_new = load_model('auto_stitched_orig_scondModel_round1.h5')
# model_new.summary()

### visualize model

In [None]:
from keras.utils.vis_utils import plot_model

In [None]:
# Write the architecture diagram, annotated with tensor shapes and layer names.
plot_model(
    model=model_new,
    to_file='../disp-images/architecture.png',
    show_shapes=True,
    show_layer_names=True,
)

In [None]:
from ann_visualizer.visualize import ann_viz

# Draw the network with ann_visualizer under the title "CNN".
# NOTE(review): ann_viz presumably renders to a file and opens a viewer — confirm
# against the ann_visualizer documentation before running headless.
ann_viz(model_new,title="CNN")

### confusion matrix for validation

In [None]:
# Model.predict accepts a tf.data.Dataset directly; predict_generator is the
# deprecated spelling of the same call.
predictions = model.predict(val_ds)

In [None]:
# Predicted class index per sample = position of the largest logit in each row.
predicted_classes = predictions.argmax(axis=1)

In [None]:
# Collect one label array per validation batch. A plain list is used because
# the final batch can be shorter than the rest, and np.array() on such ragged
# input yields an object array (or raises on recent NumPy versions).
# NOTE(review): predictions and labels come from two separate passes over
# val_ds; this relies on both passes yielding batches in the same order — confirm.
labels = [batch_labels.numpy() for _, batch_labels in val_ds]

In [None]:
# Join the per-batch label arrays into one flat vector of true class indices.
ground_truth = np.concatenate(labels, axis=0)

In [None]:
# Pin the label set to every class index so the matrix is always
# (n_classes, n_classes), even if some class never shows up in this split.
confusion_matrix = metrics.confusion_matrix(
    y_true=ground_truth,
    y_pred=predicted_classes,
    labels=np.arange(len(class_names)),
)

In [None]:
# Row/column i of the confusion matrix is class index i, i.e. class_names[i].
# Keep that order: sorting numerically would mislabel the axes whenever the
# folder names are not zero-padded (e.g. '10' is listed before '2').
cm_labels = [int(name) for name in class_names]

In [None]:
# Label both axes of the validation confusion matrix with the class labels.
df_cm = pd.DataFrame(data=confusion_matrix, index=cm_labels, columns=cm_labels)

In [None]:
# Heatmap of the validation confusion matrix with every row divided by its
# total, so each cell is the fraction of that true class.
sns.set_theme(font_scale=1.4)
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    df_cm.div(df_cm.sum(axis=1), axis=0),
    annot=True,
    fmt='0.2f',
    cmap='Blues',
    ax=ax,
)
ax.set(xlabel='Predictions', ylabel='Truths')
ax.set_title('Confusion matrix')
fig.tight_layout()
# plt.savefig('../disp-images/validation_confusionMatrix.png')

### confusion matrix for training

In [None]:
# Model.predict accepts a tf.data.Dataset directly; predict_generator is the
# deprecated spelling of the same call.
predictions_train = model.predict(train_ds)

In [None]:
# Predicted class index per training sample = position of the largest logit.
predicted_classes_train = predictions_train.argmax(axis=1)

In [None]:
# Collect one label array per training batch. A plain list is used because the
# final batch can be shorter than the rest, and np.array() on such ragged
# input yields an object array (or raises on recent NumPy versions).
# NOTE(review): predictions and labels come from two separate passes over
# train_ds; this relies on both passes yielding batches in the same order — confirm.
labels_train = [batch_labels.numpy() for _, batch_labels in train_ds]

In [None]:
# Join the per-batch label arrays into one flat vector of true class indices.
ground_truth_train = np.concatenate(labels_train, axis=0)

In [None]:
# Pin the label set to every class index so the matrix is always
# (n_classes, n_classes), even if some class never shows up in this split.
confusion_matrix_train = metrics.confusion_matrix(
    y_true=ground_truth_train,
    y_pred=predicted_classes_train,
    labels=np.arange(len(class_names)),
)

In [None]:
# Row/column i of the confusion matrix is class index i, i.e. class_names[i].
# Keep that order: sorting numerically would mislabel the axes whenever the
# folder names are not zero-padded (e.g. '10' is listed before '2').
cm_labels = [int(name) for name in class_names]

In [None]:
# Label both axes of the training confusion matrix with the class labels.
df_cm_train = pd.DataFrame(data=confusion_matrix_train, index=cm_labels, columns=cm_labels)

In [None]:
# Heatmap of the training confusion matrix with every row divided by its
# total, so each cell is the fraction of that true class.
sns.set_theme(font_scale=1.4)
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    df_cm_train.div(df_cm_train.sum(axis=1), axis=0),
    annot=True,
    fmt='0.2f',
    cmap='Blues',
    ax=ax,
)
ax.set(xlabel='Predictions', ylabel='Truths')
ax.set_title('Confusion matrix')
fig.tight_layout()
# plt.savefig('../disp-images/training_confusionMatrix.png')

### visualizing training history

In [None]:
# Per-epoch metrics recorded by model.fit.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

# Take the epoch count from the recorded history rather than repeating the
# number passed to fit(), so the x-axis always matches the curves.
epochs = len(acc)
epochs_range = range(1, epochs + 1)


In [None]:
# Training vs. validation accuracy per epoch.
fig, ax = plt.subplots(figsize=(8, 8))
ax.plot(epochs_range, acc, label='Training')
ax.plot(epochs_range, val_acc, label='Validation')
ax.legend(loc='upper right')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
fig.tight_layout()
# plt.savefig('../disp-images/training_validation_accuracy.png')

### visualize filters

In [None]:
# Print the layer-by-layer summary of the model inspected in this section.
model_new.summary()

In [None]:
# List every layer name; the cells below pick out convolutional layers by
# looking for 'conv' in these names.
for layer in model_new.layers:
    print(layer.name)

In [None]:
# Select the first layer of the model for filter inspection.
layer = model_new.layers[0]
print(layer.name)

In [None]:
# Unpack the selected layer's weights into filters and biases, then report
# the filter tensor's shape.
filters,biases = layer.get_weights()
print(layer.name,filters.shape)

In [None]:
# summarize filter shapes
# Dedicated loop names are used so this listing does not overwrite the
# `layer` / `filters` / `biases` selected above, which are what gets plotted.
for conv_layer in model_new.layers:
    # check for convolutional layer
    if 'conv' not in conv_layer.name:
        continue
    # get filter weights
    conv_filters, conv_biases = conv_layer.get_weights()
    print(conv_layer.name, conv_filters.shape)

In [None]:
# Min-max scale the filter weights so the smallest maps to 0 and the largest to 1.
f_min = filters.min()
f_max = filters.max()
filters = (filters - f_min) / (f_max - f_min)

In [None]:
# Grid of n_filters rows x 3 columns: row i shows the first three slices
# (along the second-to-last axis) of filter i.
plt.figure(figsize=(10, 10))
n_filters, ix = 6, 1
for i in range(n_filters):
    # i-th filter: index along the last axis of the weight tensor
    f = filters[..., i]
    for j in range(3):
        # next cell of the grid, with tick marks removed
        ax = plt.subplot(n_filters, 3, ix)
        ax.set(xticks=[], yticks=[])
        # draw this slice with the default colormap (no cmap is set)
        ax.imshow(f[..., j])
        ix += 1
plt.tight_layout()
plt.savefig('../disp-images/activation_filters.png')

### <font color='purple'>observation</font>
### <font color='purple'>*we see that the light directions are random, which is consistent with our data-gathering procedure*</font>

### visualize feature maps

In [None]:
# Take the first image of one normalised training batch and display it
# with grid lines switched off.
plt.rcParams.update({"axes.grid": False})
fig, ax = plt.subplots(figsize=(10, 10))
image_batch, labels_batch = next(iter(normalized_train_ds))
first_image = image_batch[0]
ax.imshow(first_image)
fig.tight_layout()
# plt.savefig('../disp-images/sample_image.png')

In [None]:
# Print index, name and output shape for every layer whose name contains 'conv'.
for i, layer in enumerate(model.layers):
    if 'conv' in layer.name:
        print(i, layer.name, layer.output.shape)

In [None]:
# Sub-model mapping the original inputs to the output of the first layer.
model_fm=Model(inputs=model_new.inputs,outputs=model_new.layers[0].output)

In [None]:
# Add a leading axis so the single sample image forms a batch of one.
img = expand_dims(first_image,axis=0)

In [None]:
# Run the sample image through the sub-model to obtain its feature maps.
feature_maps=model_fm.predict(img)

### layer 1

In [None]:
# Draw the first dim1*dim2 (= 32) feature maps of the sample image on an
# 8-row x 4-column grid.
plt.figure(figsize=(10, 10))
dim1, dim2 = 8, 4
ix = 1
for _ in range(dim1 * dim2):
    # next grid cell, tick marks removed
    ax = plt.subplot(dim1, dim2, ix)
    ax.set(xticks=[], yticks=[])
    # feature map number ix-1 of the single image in the batch (default colormap)
    ax.imshow(feature_maps[0, :, :, ix - 1])
    ix += 1
# show the figure
plt.show()

## visualize all blocks

In [None]:
# redefine the model to output the activations of layers 0, 2 and 4
ixs = [0,2,4]
outputs = [model_new.layers[i].output for i in ixs]
model_fm = Model(inputs=model_new.inputs, outputs=outputs)
# get one feature-map array per selected layer for the sample image
feature_maps = model_fm.predict(img)

In [None]:
# For each selected layer, draw its first dim1*dim2 (= 8) feature maps on a
# 4x2 grid and save one numbered PNG per layer.
dim1, dim2 = 4, 2
ix = 1
counter = 0
for counter, fmap in enumerate(feature_maps, start=1):
    plt.figure(figsize=(20, 20))
    ix = 1
    while ix <= dim1 * dim2:
        # next grid cell, tick marks removed
        ax = plt.subplot(dim1, dim2, ix)
        ax.set(xticks=[], yticks=[])
        # feature map number ix-1 of the single image in the batch (default colormap)
        ax.imshow(fmap[0, :, :, ix - 1])
        ix += 1
    print('---------------------------------------------')
    # tidy the layout and write this layer's figure to disk
    plt.tight_layout()
    plt.savefig('../disp-images/feature_visualization'+str(counter)+'.png')