In [None]:
import glob
import io
import math
import os
import random

import cv2 as cv
import imageio
#import medmnist
import ipywidgets
import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
# Setting seed for reproducibility
SEED = 42

# Set the cuDNN determinism flag in the process environment, then seed Keras
# and Python's `random` module with the same value.
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)
random.seed(SEED)
# Directory that vit() writes its per-run prediction CSV files into.
RESULT_PATH = "results_m"

## Load Data

In [None]:
# Input geometry and batching shared by the data pipeline and the model.
# Side length (pixels) each frame is resized to in get_img.
IMG_SIZE = 24
# Number of samples per batch produced by get_dataset.
BATCH_SIZE = 32
# Number of frames get_img splits each image strip into.
FRAME = 10
# Size of the last axis of INPUT_SHAPE; get_img concatenates five
# single-channel tensors along that axis.
CHANNEL = 5

In [None]:
def collect_jpgs(root, keep=None):
    """Return '/'-joined paths of the ``.jpg`` files found below ``root``.

    Args:
        root: Directory to walk recursively.
        keep: Optional predicate called with each candidate path; the path is
            kept only when it returns a truthy value.

    Returns:
        List of paths in ``os.walk`` order. Directories whose path contains
        'checkpoint' are skipped entirely (presumably Jupyter checkpoint
        folders), and files that do not end in '.jpg' are ignored so that
        stray non-image files never reach the JPEG decoder.
    """
    found = []
    for cpath, _dirs, files in os.walk(root):
        if 'checkpoint' in cpath:
            continue
        # Labels are derived by splitting on '/', so normalise the separator
        # that os.walk uses on the current platform.
        cpath = cpath.replace(os.sep, '/')
        for fname in files:
            if not fname.endswith('.jpg'):
                continue
            path = cpath + '/' + fname
            if keep is None or keep(path):
                found.append(path)
    return found


def blink_label(path):
    """Return 1 for a positive (blink / closed-eye) sample path, else 0.

    A path is positive when its third component is 'blink' or 'close', or
    its fourth component is 'close'. Paths with fewer components are
    labelled 0 instead of raising IndexError.
    """
    parts = path.split('/')
    level2 = parts[2] if len(parts) > 2 else ''
    level3 = parts[3] if len(parts) > 3 else ''
    return 1 if level2 in ('blink', 'close') or level3 == 'close' else 0


train_flist = collect_jpgs('HUST-LEBW/training')
val_flist = collect_jpgs('HUST-LEBW/test')
# MAEB images whose pixel values sum to zero (all black) are dropped.
test_flist = collect_jpgs('MAEB', keep=lambda p: np.sum(cv.imread(p)) != 0)

random.shuffle(train_flist)

# The evaluation list is the MAEB files followed by the HUST-LEBW/test files,
# i.e. it contains every file that is also used for validation.
test_flist = test_flist + val_flist

train_lb = [blink_label(f) for f in train_flist]
val_lb = [blink_label(f) for f in val_flist]
test_lb = [blink_label(f) for f in test_flist]

len(train_flist), len(val_flist),len(test_flist), len(train_lb), len(val_lb), len(test_lb)

In [None]:
# DATA
# NOTE(review): BATCH_SIZE is also assigned, with the same value, in an earlier cell.
BATCH_SIZE = 32
# NOTE(review): AUTO is not referenced in the visible cells; get_dataset uses
# tf.data.experimental.AUTOTUNE directly.
AUTO = tf.data.AUTOTUNE
# Per-sample tensor shape fed to the model; matches get_img's output for its
# default `trans` argument.
INPUT_SHAPE = (IMG_SIZE, IMG_SIZE, FRAME, CHANNEL)
# Number of units in the classifier's final Dense layer (a single logit, trained
# with BinaryCrossentropy(from_logits=True) in vit()).
NUM_CLASSES = 1

# OPTIMIZER
LEARNING_RATE = 1e-4
# NOTE(review): WEIGHT_DECAY is not referenced in the visible cells; vit() builds
# Adam with learning_rate only.
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 100

# TUBELET EMBEDDING
# NOTE(review): PATCH_SIZE and NUM_PATCHES are not referenced in the visible cells;
# TubeletEmbedding uses its own fixed kernel and pooling sizes.
PATCH_SIZE = (4, 4, 3)
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 2

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
# Feature width of each TubeletEmbedding; vit() builds the positional encoder with
# twice this value because two embeddings are concatenated.
PROJECTION_DIM = 128
NUM_HEADS = 2
NUM_LAYERS = 1

In [None]:
@tf.function
def get_img(path, mode='train', trans = [0, 2, 1, 3]):
    """Load one JPEG frame strip and build the 5-channel clip tensor.

    The file is decoded, converted to grayscale, standardised, resized to
    ``(IMG_SIZE, IMG_SIZE * FRAME)`` and reshaped to
    ``(IMG_SIZE, FRAME, IMG_SIZE, 1)``, i.e. the width is cut into ``FRAME``
    consecutive chunks of ``IMG_SIZE`` columns. ``trans`` then permutes the
    axes; the default yields ``(IMG_SIZE, IMG_SIZE, FRAME, 1)``.

    Args:
        path: Path of the JPEG file to read.
        mode: When equal to ``'train'`` a random left-right flip is applied
            to the whole strip before it is split, which mirrors each chunk
            and reverses the chunk order.
        trans: Axis permutation passed to ``tf.transpose``.

    Returns:
        The clip concatenated on the last axis with four residual clips
        (clip minus a single reference slice taken along axis 2): the first,
        last and the two middle slices. With the default ``trans`` axis 2 is
        the frame axis and the result has shape
        ``(IMG_SIZE, IMG_SIZE, FRAME, 5)``.
    """
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.rgb_to_grayscale(img)

    if mode == 'train':
        img = tf.image.random_flip_left_right(img)

    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.per_image_standardization(img)

    img = tf.image.resize(img, (IMG_SIZE, IMG_SIZE * FRAME))
    img = tf.reshape(img, (IMG_SIZE, FRAME, IMG_SIZE, 1))
    img = tf.transpose(img, trans)

    # Make an input of residual embedding.
    # Reference slices are derived from FRAME instead of being hard-coded for
    # FRAME == 10; for FRAME == 10 they are 0, 9, 4 and 5 as before.
    middle = FRAME // 2
    reference_starts = (0, FRAME - 1, max(middle - 1, 0), middle)
    residuals = [img - img[:, :, start:start + 1, :] for start in reference_starts]

    return tf.concat([img] + residuals, axis=-1)

def get_dataset(flist , lbs, mode = 'train', trans = [0 , 2, 1, 3]): # [1, 0, 2, 3]
    """Build a batched, prefetched ``tf.data`` pipeline of (clip, label) pairs.

    Args:
        flist: Sequence of image paths; each one is loaded through ``get_img``.
        lbs: Sequence of labels, aligned element-wise with ``flist``.
        mode: Forwarded to ``get_img``; ``'train'`` additionally makes the
            dataset repeat indefinitely.
        trans: Axis permutation forwarded to ``get_img``.

    Returns:
        A dataset yielding batches of ``BATCH_SIZE`` pairs. No shuffling is
        applied here, so elements follow the order of ``flist``.
    """
    autotune = tf.data.experimental.AUTOTUNE

    def load_clip(path):
        return get_img(path, mode, trans)

    path_ds = tf.data.Dataset.from_tensor_slices(flist)
    label_ds = tf.data.Dataset.from_tensor_slices(lbs)

    clip_ds = path_ds.map(load_clip, num_parallel_calls=autotune)
    pairs = tf.data.Dataset.zip((clip_ds, label_ds))

    # Training data is an endless stream; the caller bounds it with steps_per_epoch.
    if mode == 'train':
        pairs = pairs.repeat()

    return pairs.batch(BATCH_SIZE).prefetch(autotune)

In [None]:
# Number of batches needed to go through each file list once (rounded up).
train_steps_per_epoch, test_steps_per_epoch, val_steps_per_epoch = (
    math.ceil(len(paths) / BATCH_SIZE)
    for paths in (train_flist, test_flist, val_flist)
)

## Video Vision Transformer

In [None]:
class TubeletEmbedding(layers.Layer):
    """Convolutional token embedder for a clip tensor.

    Three 3x3x3 'same' convolutions (32, 64 and ``embed_dim`` filters) are
    interleaved with pooling that only reduces the first two non-batch axes
    (2x, 2x, then 3x), leaving the third axis untouched. The result is
    flattened to a sequence of ``embed_dim``-wide tokens.

    For the 24x24x10 input configured in this notebook that gives
    2 * 2 * 10 = 40 tokens.

    Args:
        embed_dim: Width of each output token. It is used both for the last
            convolution's filter count and for the final reshape, so the two
            can no longer disagree.
    """

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.projection = tf.keras.layers.Conv3D(32, (3,3,3), padding='same', activation='relu')
        self.pool = tf.keras.layers.MaxPool3D((2,2,1))
        self.projection2 = tf.keras.layers.Conv3D(64, (3,3,3), padding='same', activation='relu')
        self.pool2 = tf.keras.layers.MaxPool3D((2,2,1))
        # The last convolution has no activation; its filter count defines the token width.
        self.projection3 = tf.keras.layers.Conv3D(embed_dim, (3,3,3), padding='same')
        self.pool4 = tf.keras.layers.AveragePooling3D((3,3,1), strides=(3,3,1))

        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        """Embed a batch of clips and return tokens of shape (batch, tokens, embed_dim)."""
        x = self.projection(videos)
        x = self.pool(x)
        x = self.projection2(x)
        x = self.pool2(x)
        x = self.projection3(x)
        x = self.pool4(x)

        flattened_patches = self.flatten(x)
        return flattened_patches

    
class PositionalEncoder(layers.Layer):
    """Adds a learned embedding of each token's position to a token sequence.

    Args:
        embed_dim: Output width of the position embedding; it has to match the
            width of the tokens the layer is applied to for the addition in
            ``call`` to work.
    """

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        """Create the embedding table and index vector for the observed token count."""
        # The input shape is unpacked into three entries; only the middle one
        # (the number of tokens) is used.
        _batch, token_count, _width = input_shape
        self.position_embedding = layers.Embedding(
            input_dim=token_count,
            output_dim=self.embed_dim,
        )
        self.positions = tf.range(token_count)

    def call(self, encoded_tokens):
        """Return ``encoded_tokens`` with the position embeddings added."""
        return encoded_tokens + self.position_embedding(self.positions)
    
def create_vivit_classifier(
    tubelet_embedder,
    tubelet_embedder2,
    positional_encoder,
    input_shape=INPUT_SHAPE,
    transformer_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embed_dim=PROJECTION_DIM,
    layer_norm_eps=LAYER_NORM_EPS,
    num_classes=NUM_CLASSES,
):
    """Assemble the two-branch ViViT-style classifier as a Keras functional model.

    Args:
        tubelet_embedder: Layer applied to channel 0 of the input.
        tubelet_embedder2: Layer applied to the remaining channels (1 and up).
        positional_encoder: Layer applied to the concatenated token features.
        input_shape: Per-sample input shape; the last axis is the channel axis.
        transformer_layers: Number of Transformer blocks.
        num_heads: Number of attention heads per block.
        embed_dim: Token width produced by EACH embedder; the concatenated
            tokens, the attention key size and the MLP all use ``2 * embed_dim``.
        layer_norm_eps: Epsilon used by every LayerNormalization in the model.
        num_classes: Number of units of the final (linear) Dense layer.

    Returns:
        An uncompiled ``keras.Model`` whose output is a logit tensor of shape
        ``(batch, num_classes)``.

    Note:
        The number and order of layers is kept stable on purpose: a later
        cell addresses a layer by position (``model.layers[-8]``).
    """
    # Get the input layer
    inputs = layers.Input(shape=input_shape)

    # Create patches: one branch per input group, concatenated on the feature axis.
    patches1 = tubelet_embedder(inputs[:, :, :, :, 0:1])  # Tubelet embedding's output
    patches2 = tubelet_embedder2(inputs[:, :, :, :, 1:])  # Residual embedding's output
    patches = tf.keras.layers.Concatenate()([patches1, patches2])

    # Encode patches.
    encoded_patches = positional_encoder(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization and MHSA
        x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
        # return_attention_scores=True makes the layer output a
        # (values, scores) pair; only the values are used in the graph here.
        attention_output, _attention_scores = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=(2*embed_dim) // num_heads, dropout=0.1,
        )(x1, x1, return_attention_scores=True)

        # Skip connection
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer Normalization and MLP
        x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
        x3 = keras.Sequential(
            [
                layers.Dense(units=embed_dim*2, activation=tf.nn.gelu),
            ]
        )(x3)

        # Skip connection
        encoded_patches = layers.Add()([x3, x2])

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(units=num_classes)(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model



## Train

In [None]:
def vit(fid):
    """Train one classifier run and write its test-set predictions to CSV.

    The model is trained on ``train_flist`` for ``EPOCHS`` epochs, the
    checkpoint with the highest validation accuracy is kept under
    ``model/VIT_<fid>``, reloaded, and used to predict ``test_flist``.
    Sigmoid probabilities are written to ``<RESULT_PATH>/VIT_<fid>.csv`` with
    the columns ``path`` and ``pred``.

    Args:
        fid: Run identifier; only used in the checkpoint and CSV file names.
    """
    # Initialize model
    model = create_vivit_classifier(
        tubelet_embedder=TubeletEmbedding(
            embed_dim=PROJECTION_DIM
        ),
        tubelet_embedder2=TubeletEmbedding(
            embed_dim=PROJECTION_DIM
        ),
        positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM*2),
    )

    train_dataset = get_dataset(train_flist, train_lb, mode = 'train', trans = [0 , 2, 1, 3])
    test_dataset = get_dataset(test_flist, test_lb, mode = 'test', trans =[0 , 2, 1, 3])
    val_dataset = get_dataset(val_flist, val_lb, mode = 'test',   trans =[0 , 2, 1, 3])

    # Seeds are reset here, after the model has been constructed above, so this
    # call does not pin the state that was used while building the model.
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
    keras.utils.set_random_seed(SEED)
    random.seed(SEED)

    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    checkpoint_path = 'model/VIT_{}'.format(fid)
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,save_best_only=True,
            monitor='val_accuracy',
            mode='max',)

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        # The model outputs a raw logit, so the decision boundary is 0.0
        # (probability 0.5). The plain "accuracy" string would threshold the
        # logit itself at 0.5. The metric keeps the name "accuracy" so that
        # the 'val_accuracy' monitor above still resolves.
        metrics=[keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.0)],
    )

    # Train the model.
    model.fit(
        train_dataset,
        steps_per_epoch=train_steps_per_epoch,
        epochs=EPOCHS,
        validation_data=val_dataset,
        validation_steps=val_steps_per_epoch,
        callbacks=[model_checkpoint_callback],
    )

    best_model = keras.models.load_model(checkpoint_path)

    preds = tf.nn.sigmoid(best_model.predict(test_dataset)).numpy().ravel()

    # Make sure the output directory exists before writing into it.
    os.makedirs(RESULT_PATH, exist_ok=True)
    # Build the frame column-wise so `pred` stays numeric instead of being
    # coerced to strings by stacking it with the paths in one array.
    pd.DataFrame({'path': test_flist, 'pred': preds}).to_csv(
        '{}/VIT_{}.csv'.format(RESULT_PATH, fid), index=False
    )

In [None]:
import absl.logging
# Set absl's logging verbosity to ERROR.
absl.logging.set_verbosity(absl.logging.ERROR)

# Cross validation
# NOTE(review): each call goes through vit(), which reads the same
# train/val/test file lists every time; `fid` only changes the checkpoint and
# CSV file names. These are five repeated runs rather than five data folds —
# confirm whether per-fold splits were intended.
for fid in range(1,6):
    vit(fid)


In [None]:
# Sub-model mapping the classifier input to the output of the 8th layer from the end.
# NOTE(review): in the visible cells `model` is only assigned inside vit(), so
# this cell needs a notebook-level `model` from somewhere else — confirm where
# it is supposed to come from before relying on Restart & Run All.
sub_model = tf.keras.Model(model.input, model.layers[-8].output)

In [None]:
# Build the evaluation pipeline at notebook scope; the dataset created inside
# vit() is local to that function and is not visible to this cell.
test_dataset = get_dataset(test_flist, test_lb, mode='test', trans=[0, 2, 1, 3])

# First (inputs, labels) batch, or an empty list when the dataset is empty.
tes = next(iter(test_dataset), [])

In [None]:
# Run the sub-model on the inputs of the batch stored in `tes`.
res = sub_model.predict([ tes[0] ] )
# Position of the inspected sample inside that batch.
# NOTE(review): using the same index into test_flist assumes `tes` is the first
# batch of an unshuffled dataset built from test_flist — confirm.
k = 16
print(test_flist[k])
# Take one 40-value vector from the second sub-model output and view it as (2, 2, 10).
# NOTE(review): if res[1] holds the MultiHeadAttention scores (heads on the first
# axis once sample k is selected), index 5 exceeds NUM_HEADS = 2 — confirm the
# intended head/query indices.
t = res[1][k][5, 0, :].reshape(2, 2, 10)
# Show the source image of sample k with its first and third colour channels swapped.
plt.imshow(cv.cvtColor(cv.imread(test_flist[k]), cv.COLOR_RGB2BGR))

In [None]:
# Collapse everything except the last axis (length 10) and sum, giving one value per position.
per_position_total = t.reshape((-1, 10)).sum(axis=0)
plt.plot(per_position_total)