In [None]:
import os
import datetime
import numpy as np

In [None]:
import tensorflow as tf
import tf_keras
from tf_keras.layers import Input,Dense
from tf_keras.layers import Conv2D, MaxPooling2D,GlobalAveragePooling2D,Conv2DTranspose,concatenate
from tf_keras.callbacks import ReduceLROnPlateau,EarlyStopping
from tf_keras.callbacks import LearningRateScheduler,ModelCheckpoint
import tf_keras.backend as K
from tf_keras.losses import categorical_crossentropy
from tf_keras.preprocessing.image import ImageDataGenerator
from tf_keras import layers

In [None]:
import tensorflow as tf
from tf_keras.layers import Input,Dense
from tf_keras.layers import Conv2D, MaxPooling2D,GlobalAveragePooling2D,Conv2DTranspose,Conv3D,DepthwiseConv2D,DepthwiseConv1D
from tf_keras.callbacks import ReduceLROnPlateau,EarlyStopping
from tf_keras.callbacks import LearningRateScheduler,ModelCheckpoint
import tf_keras.backend as K
from tf_keras.preprocessing.image import ImageDataGenerator
from tf_keras import layers

In [None]:
import tempfile
import tensorflow_model_optimization as tfmot
from tf_keras.applications.mobilenet_v2 import MobileNetV2


In [None]:
import glob

# Report how many images each TBX11K class folder holds. `filename` is rebound
# on every pass, so after the loop it holds the file list of the last ('health')
# folder — exactly as the original three copy-pasted statements left it.
for pattern in (
    r'DeepLearningBasedTBDiagnosis/dataset/TBX11K/imgs/tb/*.*',
    r'DeepLearningBasedTBDiagnosis/dataset/TBX11K/imgs/sick/*.*',
    r'DeepLearningBasedTBDiagnosis/dataset/TBX11K/imgs/health/*.*',
):
    filename = glob.glob(pattern)
    print(len(filename))

In [None]:
# Run configuration: image size handed to `define_data`, batch size, and the
# train/test image directories (relative to the working directory).
IMG_SIZE = (256, 256)
BATCH_SIZE = 16

train_dir, test_dir = 'dataset/TBX11K/imgs/train/', 'dataset/TBX11K/imgs/test/'

In [None]:
from dataset import define_data
from sklearn.preprocessing import LabelBinarizer

# One-hot encoder for the class labels (0/1 indicator columns), fitted on the
# training labels in the next cell. The name `train` is kept because later
# cells refer to it.
train = LabelBinarizer(neg_label=0, pos_label=1)

In [None]:
from dataset import define_data

# Load the training images with their labels, then one-hot encode the labels.
# NOTE(review): with exactly two distinct labels LabelBinarizer emits a single
# column, which would not match the 3-unit softmax head — confirm that
# define_data yields all three classes.
train_data, train_y = define_data(train_dir, IMG_SIZE)
target_val = train.fit_transform(train_y)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 15% of the labelled data for testing; the fixed seed keeps the split
# reproducible. `test_size` is reused later as the Keras validation fraction.
# NOTE(review): the split is not stratified; consider `stratify=target_val`
# if the classes are imbalanced.
test_size = 0.15
x_train, x_test, y_train, y_test = train_test_split(
    np.array(train_data),
    target_val,
    test_size=test_size,
    random_state=42,
)

In [None]:
# Backbone constructor used by `model()`; its __name__ also prefixes output
# file names later on. `net` is a human-readable label and is not referenced
# elsewhere in this notebook as shown.
net = "MobileNetV2"
back_net = MobileNetV2

In [None]:
def model(input_shape, num_classes=3, alpha=0.35, weights="imagenet"):
    """Build a `back_net` classifier: backbone -> global average pool -> softmax.

    Args:
        input_shape: shape of one input image, e.g. ``(256, 256, 3)``.
        num_classes: number of units in the softmax output layer (default 3).
        alpha: forwarded to ``back_net``; for MobileNetV2 this is the width
            multiplier (default 0.35).
        weights: backbone initialisation forwarded to ``back_net``. The default
            ``"imagenet"`` loads pretrained weights; per the Keras applications
            docs ``None`` gives a randomly initialised backbone, which avoids a
            pointless download when a checkpoint is loaded right afterwards.

    Returns:
        An uncompiled ``tf_keras.Model`` mapping images to class probabilities.
    """
    inputs = Input(shape=input_shape, name="input_image")
    base_model = back_net(input_tensor=inputs, weights=weights,
                          include_top=False, alpha=alpha)
    # Uncomment to freeze the backbone and train only the classification head.
    # base_model.trainable = False
    features = GlobalAveragePooling2D(name="gap")(base_model.output)
    output = Dense(num_classes, activation="softmax")(features)
    return tf_keras.Model(inputs, output)

In [None]:
# Rebuild the architecture, then restore the previously fine-tuned weights.
# `load_weights` restores layer variables only (the rebuilt architecture must
# match the one that wrote the file); it does not configure training, so
# `compile` is still required before calling `fit`/`evaluate` on `base_model`.
# NOTE(review): the pruning cells wrap `base_model` and compile the wrapped
# model separately, so this compile only matters if `base_model` is evaluated
# directly.
# The input shape is derived from IMG_SIZE so it stays in sync with the data.
base_model = model(input_shape=(*IMG_SIZE, 3))
base_model.load_weights('mobilnet-output/class_weights.06-0.96.weights.h5')
base_model.compile(optimizer=tf_keras.optimizers.Adam(),
                   loss=tf_keras.losses.CategoricalCrossentropy(),
                   metrics=["accuracy"])

In [None]:
batch_size = BATCH_SIZE
epochs = 10
# `fit(validation_split=test_size)` trains on the first
# floor(n * (1 - test_size)) samples of x_train. Mirror that integer count here:
# ceil on the fractional count can overestimate steps per epoch by one. That
# would push end_step past the real training length, so the pruning schedule
# would never reach its final sparsity.
image_num = int(np.floor(x_train.shape[0] * (1 - test_size)))
# Total optimizer steps over all epochs = pruning schedule end step.
end_step = np.ceil(image_num / batch_size).astype(np.int32) * epochs
print(image_num, end_step)

In [None]:
# Ramp sparsity from 0% to 20% along a polynomial schedule between step 1 and
# `end_step` (the total training steps computed in the previous cell).
_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.2, begin_step=1, end_step=end_step)
pruning_params = {"pruning_schedule": _schedule}

In [None]:
def apply_pruning(layer):
    """Wrap a Conv2D layer with magnitude pruning; pass any other layer through.

    Presumably meant as a ``clone_function`` for selective per-layer pruning.
    NOTE(review): not called anywhere in this notebook as shown — the whole
    model is wrapped with ``prune_low_magnitude`` instead.
    """
    if not isinstance(layer, tf_keras.layers.Conv2D):
        return layer
    return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)

In [None]:
# Wrap every prunable layer of the fine-tuned classifier for magnitude pruning.
prune_model = tfmot.sparsity.keras.prune_low_magnitude(
    base_model,
    **pruning_params,
)

In [None]:
# The pruning-wrapped model is a new model object, so it needs its own compile
# before training.
prune_model.compile(
    optimizer=tf_keras.optimizers.Adam(),
    loss=tf_keras.losses.CategoricalCrossentropy(),
    metrics=["categorical_accuracy"],
)

In [None]:
# Checkpoint path template. The doubled braces survive the f-string, so Keras'
# ModelCheckpoint fills in the epoch number and validation loss at save time.
filepath = f'{back_net.__name__}/pruned/class_weights.{{epoch:02d}}-{{val_loss:.2f}}.weights.h5'
# Create the checkpoint directory up front so the first save cannot fail on a
# missing folder (no-op if it already exists).
os.makedirs(os.path.dirname(filepath), exist_ok=True)

In [None]:
# Callbacks for the pruning run:
#  - keep only the best checkpoint (default monitor: val_loss), weights only;
#  - advance the pruning step counter each batch (required by tfmot);
#  - write sparsity summaries for TensorBoard under ./logs.
checkpoint_cb = tf_keras.callbacks.ModelCheckpoint(
    filepath=filepath, save_best_only=True, save_weights_only=True, verbose=1)
callbacks = [
    checkpoint_cb,
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir="./logs"),
]

In [None]:
# Fine-tune while pruning. Keras takes the last `test_size` fraction of the
# training split (before shuffling) for validation; val_loss drives the
# best-checkpoint selection.
prune_model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=test_size,
    callbacks=callbacks,
)

In [None]:
# Held-out evaluation; with this compile config evaluate() returns
# [loss, categorical_accuracy].
pruned_test_loss, model_for_prune_accuracy = prune_model.evaluate(x_test, y_test)
print(model_for_prune_accuracy)

In [None]:
# Softmax class probabilities for every held-out image.
predictions = prune_model.predict(x=x_test)

In [None]:
# Collapse per-class scores to class indices in one vectorized step each.
# Both results are integer arrays, so sklearn sees the same label dtype on both
# sides. The previous loop filled a float64 array, which mixed float
# predictions with int targets.
preds = np.argmax(predictions, axis=1)
y = np.argmax(y_test, axis=1)

In [None]:
# Remove the pruning wrappers so the exported model consists of plain layers
# carrying the learned sparse weights.
model_for_export = tfmot.sparsity.keras.strip_pruning(model=prune_model)

In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 of the pruned model on the test split.
report = classification_report(y_true=y, y_pred=preds)
print(report)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# NOTE(review): display_labels are matched positionally to class indices 0..2.
# Those indices follow LabelBinarizer's sorted class order — confirm that this
# order really is NORM, NON-TB, TB for the labels produced by define_data.
cm = confusion_matrix(y_true=y, y_pred=preds, labels=[0, 1, 2])
fig, ax = plt.subplots()
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["NORM", "NON-TB", "TB"])
disp.plot(ax=ax)
fig.savefig(f"{back_net.__name__}-prune-confusionmatrix.png")
plt.show()