In [1]:
_ = !pip3 install -r requirements.txt

In [2]:
from utils import plot_learning_curves, build_data
from transformer import PatchExtract, PatchEmbedding, PatchMerging, SwinTransformer

In [3]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

In [4]:
import sklearn as skl
from sklearn.metrics import f1_score, accuracy_score

In [5]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import *

In [6]:
# Let TensorFlow grow GPU memory on demand rather than reserving it all up front.
# NOTE(review): this env var presumably has to be set before TF initialises the GPU.
os.environ.update({'TF_FORCE_GPU_ALLOW_GROWTH': 'true'})
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_devices  # shown as the cell output: the GPUs TensorFlow can see

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [7]:
# Load the training and validation splits; both are consumed as DataFrames by
# `flow_from_dataframe` when the generators are built.
train, val = (build_data(split) for split in ('train', 'val'))

In [8]:
# Both callbacks track the validation loss of the expression head. The key is kept
# in one place; it must match the `expr_output` layer name used in build_swin.
MONITOR = 'val_expr_output_loss'
CHECKPOINT_PATH = 'models/swin.hdf5'
# Make sure the checkpoint directory exists before the first save.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

callbacks = [
    # Stop once the monitored loss has not improved for 3 epochs, then roll the
    # model back to the weights from its best epoch.
    keras.callbacks.EarlyStopping(
        monitor=MONITOR,
        patience=3,
        restore_best_weights=True,
    ),
    # Save the full model (architecture + weights) whenever the monitored loss
    # improves. The deprecated `period` argument is not passed because
    # `save_freq="epoch"` already evaluates once per epoch.
    keras.callbacks.ModelCheckpoint(
        CHECKPOINT_PATH,
        monitor=MONITOR,
        save_weights_only=False,
        save_best_only=True,
        verbose=1,
        mode="auto",
        save_freq="epoch",
    ),
]



In [9]:
# Training schedule. EPOCHS is only an upper bound: the EarlyStopping callback is
# expected to end training well before it. BATCH_SIZE is shared by both generators.
EPOCHS = 1_000
BATCH_SIZE = 16

In [10]:
# Model input: square 224x224 RGB images, channels-last (height, width, channels).
# The generators' target_size / color_mode must agree with this.
input_shape = (224,) * 2 + (3,)

In [11]:
def _flow_from_split(datagen, dataframe, shuffle):
    """Stream (image, [valence, arousal, expression]) batches from a split DataFrame.

    Both splits use identical reading options. Only the augmentation carried by
    `datagen` and the shuffling differ between train and validation.
    """
    return datagen.flow_from_dataframe(
        dataframe=dataframe,
        x_col='image_path',
        # One target array per model output, in build_swin's output order.
        y_col=['valence', 'arousal', 'expression'],
        target_size=input_shape[:2],  # keep in sync with the model's Input
        color_mode='rgb',
        class_mode='multi_output',
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
    )


# Light geometric augmentation, applied to the training split only.
# NOTE(review): no `rescale` is applied, so the model sees raw pixel values.
# Confirm this is intended and matches the preprocessing used at inference.
train_gen = _flow_from_split(
    ImageDataGenerator(
        rotation_range=10,        # random rotation, up to +/-10 degrees
        width_shift_range=0.2,    # horizontal shift, fraction of width
        height_shift_range=0.2,   # vertical shift, fraction of height
        zoom_range=0.2,           # random zoom in [0.8, 1.2]
        horizontal_flip=True,     # random left-right mirroring
    ),
    train,
    shuffle=True,
)

# Validation is neither augmented nor shuffled. Sample order does not affect the
# validation loss, and a fixed order lets predictions be matched back to
# `val_gen.filenames` when computing F1 / accuracy.
val_gen = _flow_from_split(ImageDataGenerator(), val, shuffle=False)

Found 287651 validated image filenames.
Found 3999 validated image filenames.


In [12]:
def build_swin(patch_size = (2, 2), dropout_rate = 0.2, num_heads = 8,
               embed_dim = 64, num_mlp = 256, qkv_bias = True, window_size = 2,
               shift_size = 1, image_dimension = 224, num_classes = 8):
    """Build and compile a two-block Swin Transformer with three output heads.

    Images of the global ``input_shape`` are split into ``patch_size`` patches and
    embedded. They then pass through a regular-window Swin block, followed by a
    shifted-window Swin block (shift = ``shift_size``). The result is patch-merged,
    batch-normalised and average-pooled, then fed to three heads:

    * ``valn_output`` - one linear unit, valence regression (MSE loss),
    * ``aro_output``  - one linear unit, arousal regression (MSE loss),
    * ``expr_output`` - ``num_classes``-way softmax for expression, trained with
      sparse categorical cross-entropy (labels are integer class ids).

    Args:
        patch_size: (rows, cols) of each patch, in pixels.
        dropout_rate: forwarded to each SwinTransformer block.
        num_heads: forwarded to each SwinTransformer block (attention heads).
        embed_dim: width of the patch embedding and of the Swin blocks.
        num_mlp: forwarded to each SwinTransformer block (presumably its MLP width).
        qkv_bias: forwarded to each SwinTransformer block.
        window_size: forwarded to each SwinTransformer block (window size in patches).
        shift_size: window shift used by the second Swin block.
        image_dimension: unused; kept only for backward compatibility. The input
            size is read from the global ``input_shape``.
        num_classes: number of expression classes (default 8).

    Returns:
        A compiled ``keras.Model`` whose outputs are ``[valence, arousal, expression]``.
    """
    num_patch_x = input_shape[0] // patch_size[0]
    num_patch_y = input_shape[1] // patch_size[1]
    num_patches = (num_patch_x, num_patch_y)

    inp = Input(input_shape)
    x = PatchExtract(patch_size)(inp)
    x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(x)

    # Swin alternates window layouts: block 1 attends within fixed windows, and
    # block 2 shifts them by `shift_size` so information crosses window borders.
    for block_idx, block_shift in enumerate((0, shift_size), start=1):
        x = SwinTransformer(
            dim=embed_dim,
            num_patch=num_patches,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=block_shift,
            num_mlp=num_mlp,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
            name=f'SwinBlock{block_idx}',
        )(x)

    x = PatchMerging(num_patches, embed_dim=embed_dim)(x)
    x = BatchNormalization()(x)
    x = GlobalAveragePooling1D()(x)

    # The regression heads are linear. A ReLU head can never predict a negative value,
    # and it receives no gradient whenever its pre-activation is negative.
    # NOTE(review): valence/arousal targets are presumably signed; if they are known
    # to lie in [-1, 1], a `tanh` activation is an alternative. Confirm against build_data.
    val_out = Dense(1, name='valn_output')(x)
    aro_out = Dense(1, name='aro_output')(x)
    exp_out = Dense(num_classes, activation='softmax', name='expr_output')(x)

    model = keras.Model(inputs=inp, outputs=[val_out, aro_out, exp_out])
    # One loss per output, in the same order as `outputs`.
    model.compile(
        loss=['mse', 'mse', 'sparse_categorical_crossentropy'],
        optimizer='adam',
    )
    return model

In [13]:
# Instantiate the model. The hyperparameters are spelled out here (they equal
# build_swin's defaults) so the configuration being trained is visible in the notebook.
model = build_swin(
    patch_size=(2, 2),
    embed_dim=64,
    num_heads=8,
    num_mlp=256,
    window_size=2,
    shift_size=1,
    dropout_rate=0.2,
    qkv_bias=True,
)

In [14]:
# Layer-by-layer overview of the model: output shapes and parameter counts.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
patch_extract (PatchExtract)    (None, 12544, 12)    0           input_1[0][0]                    
__________________________________________________________________________________________________
patch_embedding (PatchEmbedding (None, 12544, 64)    803648      patch_extract[0][0]              
__________________________________________________________________________________________________
SwinBlock1 (SwinTransformer)    (None, 12544, 64)    50072       patch_embedding[0][0]            
______________________________________________________________________________________________

In [15]:
# Sanity check: print every weight name that occurs more than once in the model,
# in order of first appearance.
# NOTE(review): presumably a guard for the HDF5 `save_weights` call further down,
# since that format stores weights under their names.
weight_name_counts = {}
for weight in model.weights:
    weight_name_counts[weight.name] = weight_name_counts.get(weight.name, 0) + 1

for weight_name in (n for n, count in weight_name_counts.items() if count > 1):
    print(weight_name)

In [16]:
# Train on the augmented generator and validate every epoch. EPOCHS is only an
# upper bound because the EarlyStopping callback can end training earlier.
# Per-epoch losses end up in `history.history`.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    callbacks=callbacks,
    epochs=EPOCHS,
)

Epoch 1/1000


KeyboardInterrupt: 

In [23]:
# Write only the model's weights (not the full model) to an HDF5 file (.h5 suffix).
model.save_weights(filepath='testw.h5')

In [24]:
# Reload the weights from the same file, as a save/load round-trip check.
model.load_weights(filepath='testw.h5')

In [25]:
# Echo the model object (its repr) as the cell output.
model

<tensorflow.python.keras.engine.functional.Functional at 0x7f46c264a0b8>