# Setup

In [None]:
# Authenticate the Colab user for this session; the gcloud / gsutil / gcsfuse
# commands in the following cells presumably rely on these credentials.
from google.colab import auth
auth.authenticate_user()

In [None]:
# GCP project used as the default for the gcloud CLI in this session.
project_id = 'stairnet-unlabeled'
# {project_id} is interpolated by IPython before the shell command runs.
!gcloud config set project {project_id}

In [None]:
# Sanity check: list what is visible under gs:// with the current credentials and project.
! gsutil ls -al gs://

In [None]:
# Install gcsfuse: register Google's apt repository, add its signing key,
# refresh the package index and install the package.
# NOTE(review): the repository suite is pinned to "gcsfuse-bionic" (the Ubuntu 18.04
# codename) and the key is added with `apt-key` — confirm both still work on the
# current Colab image.
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse

In [None]:
!mkdir data
!gcsfuse --implicit-dirs stairnet_unlabeled_bucket data

In [None]:
# Mount Google Drive; the TFRecord files and the model are addressed below through
# the relative path drive/MyDrive/... (assumes the working directory is /content — confirm).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from PIL import Image
import gc
import math, re, os
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tqdm import tqdm
from collections import Counter
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix, accuracy_score
print("Tensorflow version " + tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
def create_distribution_strategy():
    '''
        Pick the tf.distribute strategy for this runtime.

        Tries to connect to a TPU and build a TPUStrategy; a ValueError during
        that attempt is treated as "no TPU" and a MirroredStrategy is used instead.
        Prints and returns the chosen strategy.
    '''
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect()  # TPU detection
        chosen = tf.distribute.TPUStrategy(resolver)
    except ValueError:
        chosen = tf.distribute.MirroredStrategy()
    print("Number of accelerators: ", chosen.num_replicas_in_sync)

    return chosen

strategy = create_distribution_strategy()

# Generating TFRecords

In [None]:
# Videos in the test split
# Videos in the test split. These names select frames in get_video_paths and
# appear in the TFRecord file names (test_<video>.tfrecord).
TEST_VIDEOS = [
    'IMG_02_1', 
    'IMG_02_4', 
    'IMG_05_1', 
    'IMG_11_1', 
    'IMG_14_2', 
    'IMG_20_1',
]

# Mapping class to numeric representation. The key is the last space-separated
# token of a frame's file name, without extension (see read_sample).
CLASS_MAP = {'IS': 0 , 'ISLG': 1, 'LG': 2, 'LGIS': 3}  

# Side length of the stored square frames; decode_image reshapes to IMAGE_SIZE x IMAGE_SIZE x 3.
IMAGE_SIZE = 256
# Seed passed to the shuffle in get_test_dataset.
SEED_NUMBER = 42
# Lets tf.data choose parallelism and prefetch buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE

# Base folder path. get_video_paths expects one level of sub-folders here
# (presumably one per class — confirm), each holding the frame images.
FOLDER_PATH = 'data/StairNet/'

In [None]:
def get_video_number(file_name):
    '''
        Extract the video identifier from a frame name.

        Takes everything before the first space and strips the list-literal
        residue around it.
        example: "['IMG_02_1'] frame 12 IS.jpg" -> "IMG_02_1"
    '''
    video_token = file_name.split(' ')[0]
    # removal order matters: opening "['", closing "']", then any quote left over
    for residue in ("['", "']", "'"):
        video_token = video_token.replace(residue, '')
    return video_token

def get_video_paths(folder_path, selected_videos):
    '''
        Collect, per selected video, the full paths of all of its frames.

        folder_path: directory whose sub-directories hold the frame images
        selected_videos: iterable of video names (IMG_#_#) to keep
        returns: dict {video name: [frame path, ...]}; every selected video has
                 an entry, which stays empty when no frame matched
    '''
    img_paths = {vid_name: list() for vid_name in selected_videos}
    for class_name in os.listdir(folder_path):
        class_dir = os.path.join(folder_path, class_name)
        if not os.path.isdir(class_dir):
            # a stray file next to the sub-folders holds no frames and cannot be listed
            continue
        for img_sample in tqdm(os.listdir(class_dir)):
            # Parse the bare file name (once) rather than the joined path, so a
            # space anywhere in folder_path cannot truncate the video name.
            vid_name = get_video_number(img_sample)
            if vid_name in img_paths:
                img_paths[vid_name].append(os.path.join(class_dir, img_sample))
    print()
    print(len(img_paths))
    return img_paths

In [None]:
# Scan the dataset folder once and group the frame paths of the test videos by video name.
img_paths = get_video_paths(FOLDER_PATH, TEST_VIDEOS)

In [None]:
# Number of frames found for each test video.
for vid_name, frame_paths in img_paths.items():
    print(vid_name, len(frame_paths))

In [None]:
def load_image(filename, img_load='pil', img_size=(256, 256)):
  '''
    Read one frame from disk and resize it.

    filename: path of the image file
    img_load: decoding backend, 'pil' (default) or 'cv2'
    img_size: (width, height) of the returned frame
    returns: np.ndarray holding the resized frame in RGB channel order
    raises: ValueError for an unknown img_load, or when cv2 cannot read the file
  '''
  if img_load not in ('pil', 'cv2'):
    # fail early with a clear message instead of an UnboundLocalError further down
    raise ValueError(f"img_load must be 'pil' or 'cv2', got {img_load!r}")
  if img_load == 'cv2':
    import cv2  # optional backend, not imported at notebook level
    img = cv2.imread(filename)
    if img is None:
      # cv2.imread signals an unreadable file by returning None
      raise ValueError(f'cv2 could not read {filename}')
    # ndarray.resize() only reshapes the buffer; cv2.resize rescales the picture
    img = cv2.resize(img, tuple(img_size))
    # cv2 decodes to BGR; convert so both backends return the same channel order
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  # the context manager releases the file handle once the pixels are copied
  with Image.open(filename) as img:
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    return np.array(img.convert('RGB').resize(img_size, Image.LANCZOS))

def read_sample(img_path, img_size = 256):
  '''
    Load one frame and derive its label from the file name.

    img_path: frame path; the class is the last space-separated token of the
              file name, up to its first '.'
    img_size: side length of the square frame that is returned
    returns: (image array, 0-d numpy array holding the CLASS_MAP id)
    raises: KeyError when the class token is not a key of CLASS_MAP
  '''
  # forward img_size so a non-default value actually changes the frame size
  img = load_image(img_path, img_size=(img_size, img_size))
  class_name = img_path.split('/')[-1].split(' ')[-1].split('.')[0]
  # NOTE: np.array on a Python int uses numpy's default integer dtype, which
  # determines how many bytes labels.tobytes() produces downstream.
  labels = np.array(CLASS_MAP[class_name])
  return img, labels

def _bytes_feature(value):
    '''Wrap a bytes value (or an eager tensor holding one) in a bytes_list tf.train.Feature.'''
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # unpack the tensor to its numpy/bytes value before building the BytesList
        value = value.numpy()
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def generate_seq_dataset(img_arr, video_names=None, output_dir='drive/MyDrive/Baseline_samples'):
    '''
        Write one TFRecord file per video and return the number of frames written.

        img_arr: dict {video name: [frame path, ...]}
        video_names: optional subset of the keys of img_arr to convert (e.g. to
                     redo a single video); defaults to every video in img_arr
        output_dir: folder that receives test_<video name>.tfrecord
        returns: total number of frames serialized across the converted videos

        Each record stores two bytes features: 'image' (raw bytes of the
        256x256x3 frame array) and 'label' (raw bytes of the label array).
    '''
    if video_names is None:
        # convert all videos, not just the last key of the dict
        video_names = list(img_arr.keys())
    counter = 0
    # iterate over video in dataset
    for vid_name in video_names:
        print(vid_name)
        record_path = os.path.join(output_dir, f'test_{vid_name}.tfrecord')
        # the context manager closes the file even when a frame fails to load
        with tf.io.TFRecordWriter(record_path) as writer:
            for img in tqdm(img_arr[vid_name]): # for each sample in video generate sample
                np_sample, labels = read_sample(img, img_size=256)
                assert np_sample.shape == (256, 256, 3), f'unexpected frame shape {np_sample.shape} for {img}'
                feature = {
                    'label': _bytes_feature(labels.tobytes()),
                    'image': _bytes_feature(np_sample.tobytes())
                }
                # writing to tfrecord file
                tf_example = tf.train.Example(features = tf.train.Features(feature=feature))
                writer.write(tf_example.SerializeToString())
                counter += 1
        # one collection per video is enough; the per-frame objects are released
        # as soon as their names are rebound on the next iteration
        gc.collect()
    return counter

In [None]:
# Convert frames to TFRecords. `counter` is reused further down for the shuffle
# buffer size, TEST_STEPS and the label batching.
counter = generate_seq_dataset(img_paths)
# Manual override kept for reference — presumably the per-video frame counts of
# the six test videos from earlier runs (confirm before relying on it).
# counter = 12907 + 5573 + 11838 + 9364 + 11401 + 5652
print('Number of frames: ', counter)

# Dataset

In [None]:
# generated tfrecords of the test split videos
# TFRecord files of the test split, one per entry of TEST_VIDEOS. The list is derived
# from TEST_VIDEOS (same order, same naming pattern generate_seq_dataset writes) so it
# cannot drift from the split definition.
TEST_FILENAMES = [
    f'drive/MyDrive/Baseline_samples/test_{vid_name}.tfrecord'
    for vid_name in TEST_VIDEOS
]

In [None]:
def decode_image(image_data):
    '''
        Turn the raw bytes of a stored frame into a float image crop.

        image_data: scalar string tensor, decoded as raw uint8 values and
                    reshaped to IMAGE_SIZE x IMAGE_SIZE x 3
        returns: float32 tensor of shape (224, 224, 3) with values in [0, 1]
    '''
    pixels = tf.io.decode_raw(image_data, tf.uint8)
    pixels = tf.cast(pixels, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    frame = tf.reshape(pixels, [IMAGE_SIZE, IMAGE_SIZE, 3])  # explicit size needed for TPU
    # NOTE(review): the crop offset is random and unseeded, so two evaluation runs
    # see slightly different inputs — confirm this is wanted for test data.
    return tf.image.random_crop(value=frame, size=(224, 224, 3))

def read_labeled_tfrecord(example):
    '''
        Parse one serialized tf.train.Example into an (image, label) pair.

        example: scalar string tensor with 'image' and 'label' bytes features
        returns: (decoded image crop, int32 scalar label)
    '''
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "label": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    # the label was stored as raw bytes: reinterpret as int32 and keep the first value
    label = tf.io.decode_raw(parsed['label'], 'int32')[0]
    image = decode_image(parsed['image'])
    return image, label

def load_dataset(filenames, labeled=True, ordered=False):
    '''
        Build an unbatched tf.data pipeline of parsed examples from TFRecord files.

        filenames: TFRecord paths; several files are read in parallel
        labeled: parse with read_labeled_tfrecord when True, with
                 read_unlabeled_tfrecord otherwise
        ordered: when False, deterministic ordering is switched off so elements
                 are used as soon as they stream in (faster, order not reproducible)
        returns: dataset of whatever the selected parse function yields
                 ((image, label) pairs for labeled=True)
    '''
    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False  # disable order, increase speed

    # num_parallel_reads interleaves reads from multiple files
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(read_options)
    if labeled:
        parse_fn = read_labeled_tfrecord
    else:
        # NOTE(review): read_unlabeled_tfrecord is not defined in the visible part of
        # this notebook; labeled=False raises NameError unless it exists elsewhere.
        parse_fn = read_unlabeled_tfrecord
    return records.map(parse_fn, num_parallel_calls=AUTO)

def get_test_dataset(ordered=False):
    '''
        Build the batched, shuffled, cached and prefetched test dataset.

        ordered: forwarded to load_dataset; True keeps file reading deterministic
        returns: dataset of (image batch, label batch) with BATCH_SIZE elements per
                 batch (the last batch may be smaller)

        Reads the notebook globals TEST_FILENAMES, BATCH_SIZE, counter, SEED_NUMBER and AUTO.
    '''
    dataset = load_dataset(TEST_FILENAMES, labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    # Shuffle whole batches with a fixed seed so every run, and every pass over the
    # dataset, sees the same order. reshuffle_each_iteration has to be False for that:
    # None falls back to the default (True), which produces a new order on each pass.
    # max(1, ...) keeps the buffer size valid for very small datasets.
    dataset = dataset.shuffle(max(1, counter // 10), seed=SEED_NUMBER, reshuffle_each_iteration=False)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO) # prefetch next batch while the current one is consumed
    return dataset 


BATCH_SIZE = 157
# number of batches needed to cover `counter` frames: integer ceiling of counter / BATCH_SIZE
TEST_STEPS = (counter + BATCH_SIZE - 1) // BATCH_SIZE
print(' {} test images'.format(counter))
print('Number of test steps: ', TEST_STEPS)

# Loading model

In [None]:
# Show the contents of the folder the model file is loaded from in the next cell.
!ls drive/MyDrive/supervised_pretraining/

In [None]:
# Earlier variant kept for reference: loading StairNet_v2 inside a strategy scope.
# with tpu_strategy.scope():
#     model = tf.keras.models.load_model('./drive/MyDrive/supervised_pretraining/StairNet_v2.h5', compile=True)

# NOTE(review): the model is loaded outside any strategy scope, so the `strategy`
# object created in the Setup section is not used here — confirm this is intended.
MODEL_PATH = 'drive/MyDrive/supervised_pretraining/StairNet_v3.h5'
model = keras.models.load_model(MODEL_PATH)

# Inference

In [None]:
# Predict batch by batch and take the labels from the very same batches. Reading
# images and labels in two separate passes over a shuffled dataset only pairs them
# correctly if both passes happen to see the same order; one pass removes that
# dependency and always covers the whole dataset.
cmdataset = get_test_dataset(ordered=True)
batch_labels, batch_probabilities = [], []
for images, labels in cmdataset:
    batch_probabilities.append(np.asarray(model.predict_on_batch(images)))
    batch_labels.append(labels.numpy())
cm_correct_labels = np.concatenate(batch_labels)
cm_probabilities = np.concatenate(batch_probabilities)
cm_predictions = np.argmax(cm_probabilities, axis=-1)  # most probable class per frame
print("Correct   labels: ", cm_correct_labels.shape, cm_correct_labels)
print("Predicted labels: ", cm_predictions.shape, cm_predictions)

In [None]:
# Class ids in CLASS_MAP order; passing them explicitly keeps the matrix at
# len(CLASS_MAP) x len(CLASS_MAP) even if a class never occurs.
class_ids = list(range(len(CLASS_MAP)))
cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=class_ids)
# 'weighted' averages the per-class scores by the number of true frames of each class
score = f1_score(cm_correct_labels, cm_predictions, labels=class_ids, average='weighted')
precision = precision_score(cm_correct_labels, cm_predictions, labels=class_ids, average='weighted')
recall = recall_score(cm_correct_labels, cm_predictions, labels=class_ids, average='weighted')
# Row-normalise: each row shows how the frames of one true class were predicted.
# A class without any true frame keeps an all-zero row instead of turning into NaN.
row_totals = cmat.sum(axis=1, keepdims=True)
cmat = np.divide(cmat, row_totals, out=np.zeros(cmat.shape, dtype=float), where=row_totals != 0)
accuracy = accuracy_score(cm_correct_labels, cm_predictions)
print('accuracy: {:.5f}, f1 score: {:.5f}, precision: {:.5f}, recall: {:.5f}'.format(accuracy, score, precision, recall))

In [None]:
# Heatmap of the row-normalised confusion matrix, annotated with the summary metrics.
ax = plt.subplot()
sns.heatmap(cmat, annot=True, fmt='.2%', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title(f'Normalized Confusion Matrix  \nAccuracy: {accuracy} F1: {score} \nPrecision: {precision} Recall: {recall}')
# tick labels follow the class id order 0..3
tick_names = ['IS', 'IS-LG', 'LG', 'LG-IS']
ax.xaxis.set_ticklabels(tick_names)
ax.yaxis.set_ticklabels(tick_names);

# written to the current working directory
plot_location = "StairNetv3(Dima)_image_validation_test.jpg"
ax.figure.savefig(plot_location)