In [1]:
import os
# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native runtime is loaded, so the variable
# must be set *before* the first `import tensorflow` (keras / keras_nlp / keras_cv load TF
# too). Setting it after the import has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import random
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm

seed = 2024
warnings.filterwarnings("ignore")

# ML tools

import tensorflow as tf
import keras
import keras_nlp
import keras_cv
from keras import ops

keras.utils.set_random_seed(seed)

import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow_decision_forests as tfdf

from keras import Input, Model
from keras.models import load_model

from keras.layers import (Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization,
                          LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add,
                          Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D)
from keras.preprocessing.image import load_img, img_to_array
from keras.applications import *

print("SOTA model implementation, unofficial model zoo\n+=+=+=+=+=+=+=+=+=+=+=+=+=\n")
print(f"Requirements loaded, keras : v{keras.__version__}, Tensorflow : v{tf.__version__}")

2024-04-22 12:11:47.062650: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-22 12:11:47.062833: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-22 12:11:47.216398: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


SOTA model implementation, unofficial model zoo
+=+=+=+=+=+=+=+=+=+=+=+=+=

Requirements loaded, keras : v3.2.1, Tensorflow : v2.15.0


# Base helper functions (layers)

In [2]:
class AttentionPooling(keras.layers.Layer):
    """Multi-head (cross-)attention that pools key/value tokens into query tokens.

    Called as ``layer([query, key])`` (values are projected from the keys) or
    ``layer([query, key, value])``. A rank-2 query ``(batch, dims)`` is treated as a
    single query token, and rank-4 key/value maps ``(batch, w, h, dims)`` are flattened
    to ``(batch, w*h, dims)`` token sequences.

    Args:
        attention_heads: number of attention heads.
        attention_dims: total Q/K/V embedding width; defaults to the query's last dim.
            Must be divisible by ``attention_heads``.
        bias: whether the Q/K/V and output projections use a bias.
        scale: multiplier applied to the projected queries; defaults to
            ``(attention_dims // attention_heads) ** -0.5`` (per-head width).
        dropout_rate: dropout on the attention weights and on the projected output.

    Returns (from ``call``):
        ``(attended_output, attention_weight)``. ``attended_output`` is
        ``(batch, query_length, query_dims)``, with the query axis squeezed when it has
        length 1. ``attention_weight`` is ``(batch, heads, query_length, key_length)``.
    """
    def __init__(self, attention_heads, attention_dims = None, bias = False, scale = None, 
                 dropout_rate = 0.05, **kwargs):
        super().__init__(**kwargs)
        self.n_heads = attention_heads
        self.n_dims = attention_dims
        self.bias = bias
        self.scale = scale
        self.dropout_rate = dropout_rate
        
    def build(self, input_shape):
        # input_shape is a list of shapes: [query, key] or [query, key, value]
        query_dims = input_shape[0][-1]
        
        embed_dims = query_dims if self.n_dims is None else self.n_dims
        if embed_dims % self.n_heads != 0:
            raise ValueError(
                f"attention_dims ({embed_dims}) must be divisible by "
                f"attention_heads ({self.n_heads})."
            )
        self.embed_dims = embed_dims
        self.per_head_dims = embed_dims // self.n_heads
        
        # Scaled dot-product attention divides the logits by sqrt(per-head key width),
        # not by sqrt of the full (all-heads) embedding width.
        if self.scale is None:
            self.scale = self.per_head_dims ** -0.5
        
        self.query_embed_fn = keras.layers.Dense(units = embed_dims, use_bias = self.bias,
                                               name = "Q_Embedding_Dense_layer")
        self.key_embed_fn = keras.layers.Dense(units = embed_dims, use_bias = self.bias,
                                               name = "K_Embedding_Dense_layer")
        self.value_embed_fn = keras.layers.Dense(units = embed_dims, use_bias = self.bias,
                                               name = "V_Embedding_Dense_layer")
        
        self.softmax = keras.layers.Activation("softmax", name = "AttentionWeightSoftmax")
        self.proj = keras.layers.Dense(units = query_dims, use_bias = self.bias, 
                                      name = "ProjectToOriginalDimension")
        self.att_dropout = keras.layers.Dropout(self.dropout_rate)
        self.proj_dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)
        
    def call(self, inputs, **kwargs):
        if len(inputs) == 2:
            q, k = inputs
            v = None
        elif len(inputs) == 3:
            q, k, v = inputs
        else:
            raise ValueError(
                "AttentionPooling expects [query, key] or [query, key, value], "
                f"got {len(inputs)} tensors."
            )
            
        # A single query vector (batch, dims) becomes a length-1 query sequence.
        if len(ops.shape(q)) == 2:
            q = ops.expand_dims(q, axis = 1)
        # Spatial maps (batch, w, h, dims) are flattened into token sequences.
        if len(ops.shape(k)) == 4:
            b_, w_, h_, dims_ = ops.shape(k)
            k = ops.reshape(k, [b_, w_*h_, dims_])
        if v is not None and len(ops.shape(v)) == 4:
            b_, w_, h_, dims_ = ops.shape(v)
            v = ops.reshape(v, [b_, w_*h_, dims_])
            
        batch_size, query_length, _ = ops.shape(q)
        key_length = ops.shape(k)[1]
        
        # (batch, query_length, heads, per_head_dims); queries are pre-scaled.
        query = self.query_embed_fn(q) * self.scale
        query = ops.reshape(query, [batch_size, query_length, self.n_heads, self.per_head_dims])
        
        key = self.key_embed_fn(k)
        key = ops.reshape(key, [batch_size, key_length, self.n_heads, self.per_head_dims])
        
        # Without an explicit value tensor the values are projected from the keys.
        value = self.value_embed_fn(k if v is None else v)
        value = ops.reshape(value, [batch_size, key_length, self.n_heads, self.per_head_dims])
        
        # a = batch, b = query position, c = key position, h = head, d = per-head dim
        attention_score = ops.einsum("abhd, achd -> ahbc", query, key)
        attention_weight = self.softmax(attention_score)  # softmax over the last (key) axis
        attention_weight = self.att_dropout(attention_weight)
        self.attention_weight = attention_weight 
        attended_output = ops.einsum("ahbc, achd -> abhd", attention_weight, value)
        attended_output = ops.reshape(
            attended_output, [batch_size, query_length, self.n_heads*self.per_head_dims]
        )
        attended_output = self.proj(attended_output)
        attended_output = self.proj_dropout(attended_output)
        # Single-query pooling returns (batch, query_dims) instead of (batch, 1, query_dims).
        if ops.shape(attended_output)[1] == 1:
            attended_output = ops.squeeze(attended_output, axis = 1)
        return attended_output, attention_weight

# GC-ViT

In [3]:
class PatchEncoder(keras.layers.Layer):
    """Projects patch tokens and adds a learnable absolute position embedding.

    Args:
        num_patches: number of tokens in the input sequence (one embedding row each).
        projection_dim: output width of the Dense projection and of the position embedding.
        **kwargs: standard Layer arguments (``name``, ``dtype``, ``trainable``...).
            These are required so the layer can be recreated by ``from_config``.
    """
    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = keras.layers.Dense(units=projection_dim)
        self.position_embedding = keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # (1, num_patches) position ids, broadcast over the batch when added below.
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # Every constructor argument must be present, otherwise from_config() fails.
        config = super().get_config()
        config.update({"num_patches": self.num_patches,
                       "projection_dim": self.projection_dim})
        return config

In [4]:
class SE(keras.layers.Layer):
    """Squeeze-and-Excitation gate: ``inputs * sigmoid(MLP(GlobalAvgPool(inputs)))``.

    Args:
        output_dim: width of the gate; defaults to the input channel count. The gate is
            multiplied element-wise with the input, so it must broadcast against it.
        squeeze_rate: bottleneck width of the gate MLP as a fraction of ``output_dim``.
    """
    def __init__(self, output_dim = None, squeeze_rate = 0.25, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.rate = squeeze_rate

    def build(self, input_shape):
        # GlobalAveragePooling2D below requires a 4-D (batch, h, w, dims) input.
        if self.output_dim is None:
            self.output_dim = input_shape[-1]
        # int(rate * dims) rounds down to 0 for small channel counts (e.g. dims < 4 at
        # rate 0.25), which would give a zero-width bottleneck; keep at least one unit.
        squeezed_units = max(1, int(self.rate * self.output_dim))
        self.avg_pool = keras.layers.GlobalAveragePooling2D(keepdims = True, name = "AvgPooling")
        self.mlps = keras.Sequential([keras.layers.Dense(units = squeezed_units,
                                                        use_bias = False, name = "Dense1"),
                                      keras.layers.Activation("gelu", name = "GeluAct"),
                                      keras.layers.Dense(units = self.output_dim, use_bias = False, name = "Dense2"),
                                      keras.layers.Activation("sigmoid", name = "Excitation_Sigmoid")
                                     ])

    def call(self, inputs, **kwargs):
        pooled = self.avg_pool(inputs)   # (batch, 1, 1, dims) because keepdims=True
        weights = self.mlps(pooled)      # per-channel gate in (0, 1) from the sigmoid
        return inputs * weights
    
class DownSampler(keras.layers.Layer):
    """GC-ViT "reduce size" block: pre-norm residual Fused-MBConv, then a stride-2 conv.

    Args:
        keepdims: if True the output keeps the input channel count, otherwise the channel
            count is doubled. The spatial size is always reduced by the stride-2 'same'
            convolution (halved, rounded up).
    """
    def __init__(self, keepdims = False, **kwargs):
        super().__init__(**kwargs)
        self.keepdims = keepdims

    def build(self, input_shape):
        embed_dims = input_shape[-1]
        out_dim = embed_dims if self.keepdims else 2*embed_dims
        self.fused_mbconv = keras.Sequential([keras.layers.DepthwiseConv2D(kernel_size = 3, padding = 'same', use_bias = False, name = "DWConv"),
                                             keras.layers.Activation("gelu", name = 'GeluAct'),
                                             SE(name = "SqueezeAndExcitation2D"),
                                             keras.layers.Conv2D(filters = embed_dims, kernel_size = 1, padding = 'same', use_bias = False, name = "PointWiseConv")],
                                            name = "Fused_MBConvLayer")
        self.down_conv = keras.layers.Conv2D(filters = out_dim, kernel_size = 3, strides = 2, padding = 'same', use_bias = False, name = "DownConvolution")
        self.layernorm1 = keras.layers.LayerNormalization(epsilon = 1e-5, name = 'LayerNorm1')
        self.layernorm2 = keras.layers.LayerNormalization(epsilon = 1e-5, name = 'LayerNorm2')

    def call(self, inputs, **kwargs):
        # Pre-norm residual: the MBConv branch must consume the *normalized* tensor so that
        # both terms of the sum are in the same (normalized) space.
        x = self.layernorm1(inputs)
        x = x + self.fused_mbconv(x)
        x = self.down_conv(x)
        return self.layernorm2(x)
    
class MLP(keras.layers.Layer):
    """Two-layer feed-forward block: Dense -> activation -> Dropout -> Dense -> Dropout.

    Args:
        middle_dim: hidden width; defaults to ``int(1.5 * input_channels)``.
        output_dim: output width; defaults to the input channel count.
        activation: activation applied after the first Dense layer.
        dropout: rate shared by both Dropout layers.
    """
    def __init__(self, middle_dim = None, output_dim = None,
                activation = 'gelu', dropout = 0.2, **kwargs):
        super().__init__(**kwargs)
        self.middle_dim = middle_dim
        self.output_dim = output_dim
        self.activation = activation
        self.dropout_rate = dropout

    def build(self, input_shape):
        in_dims = input_shape[-1]
        self.input_dims = in_dims
        # Resolve the defaults lazily, once the input width is known.
        if self.middle_dim is None:
            self.middle_dim = int(1.5 * in_dims)
        if self.output_dim is None:
            self.output_dim = in_dims
        self.mlp1 = keras.layers.Dense(units = self.middle_dim, name = "FirstMLP")
        self.act = keras.layers.Activation(self.activation, name = "MiddleActivation")
        self.mlp2 = keras.layers.Dense(units = self.output_dim, name = "SecondMLP")
        self.drop1 = keras.layers.Dropout(self.dropout_rate, name = "Dropout1")
        self.drop2 = keras.layers.Dropout(self.dropout_rate, name = "Dropout2")

    def call(self, inputs, **kwargs):
        hidden = self.drop1(self.act(self.mlp1(inputs)))
        return self.drop2(self.mlp2(hidden))
    
class PatchEmbedding(keras.layers.Layer):
    """Image stem that produces a feature map at 1/4 of the input resolution.

    ``patching_type`` selects the mode:
      * ``"tokenlearner"`` / ``"token_learner"``: LayerNorm, then a stride-2 conv, then
        TokenLearner-style weighted pooling into ``(w//4) * (h//4)`` tokens. The tokens
        are reshaped to a ``(w//4, h//4, embed_dim)`` map.
      * anything else (default ``"conv"``): a stride-2 conv followed by a
        channel-preserving ``DownSampler`` (a second stride-2 step).
    """
    _TOKEN_LEARNER_ALIASES = ("tokenlearner", "token_learner")

    def __init__(self, embed_dim, patching_type = "conv", #conv or tokenlearner
                 **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.patching_type = patching_type

    def _uses_token_learner(self):
        # True for either spelling of the token-learner mode.
        return self.patching_type in self._TOKEN_LEARNER_ALIASES

    def build(self, input_shape):
        if not self._uses_token_learner():
            # Overlapping 3x3 / stride-2 patches, followed by another stride-2 reduction.
            self.proj = keras.layers.Conv2D(self.embed_dim, kernel_size = 3, strides = 2, padding = 'same')
            self.down_sample = DownSampler(keepdims = True, name = "DownSampler_after_projection")
            return

        # Overlapping 3x3 / stride-2 patches.
        self.proj = keras.layers.Conv2D(self.embed_dim, kernel_size = 3, strides = 2, padding = 'same', name = "projection_conv")
        _, width, height, _ = input_shape
        self.resized_w, self.resized_h = width // 4, height // 4
        n_tokens = self.resized_w * self.resized_h
        self.input_seq_flatten = keras.layers.Reshape([1, -1, self.embed_dim], name = "image_to_sequence_reshape")
        self.layer_norm = keras.layers.LayerNormalization(epsilon = 1e-5)
        # One spatial attention map per output token: (batch, n_tokens, HW) after the Permute.
        self.attention_ops = keras.Sequential(
            [
                keras.layers.Conv2D(n_tokens, kernel_size = 3, activation = "gelu", use_bias = False, padding = 'same'),
                keras.layers.Conv2D(n_tokens, kernel_size = 3, activation = "gelu", use_bias = False, padding = 'same'),
                keras.layers.Conv2D(n_tokens, kernel_size = 3, activation = "sigmoid", use_bias = False, padding = 'same'),
                keras.layers.Reshape([-1, n_tokens]),
                keras.layers.Permute([2,1]),
            ],
            name = "Conv_for_attention_weight",
        )

    def call(self, inputs, **kwargs):
        if not self._uses_token_learner():
            return self.down_sample(self.proj(inputs))

        projected = self.proj(self.layer_norm(inputs))
        tokens_in = self.input_seq_flatten(projected)                        # (batch, 1, HW, embed_dim)
        weights = ops.expand_dims(self.attention_ops(projected), axis = -1)  # (batch, n_tokens, HW, 1)
        pooled = ops.mean(weights * tokens_in, axis = 2)                     # (batch, n_tokens, embed_dim)
        return ops.reshape(pooled, [-1, self.resized_w, self.resized_h, self.embed_dim])
        
class FeatExtract(keras.layers.Layer):
    """Residual Fused-MBConv block, optionally followed by max-pooling.

    Args:
        keepdims: when it compares equal to False (``False`` or ``0``) the output is
            max-pooled with the default 2x2 pool; otherwise the spatial size is kept.
    """
    def __init__(self, keepdims = False, **kwargs):
        super().__init__(**kwargs)
        self.keepdims = keepdims

    def _pools(self):
        # Same test as `keepdims == False`, so 0 and False both enable pooling.
        return self.keepdims == False

    def build(self, input_shape):
        _, _, _, channels = input_shape
        self.fused_mbconv = keras.Sequential(
            [
                keras.layers.DepthwiseConv2D(kernel_size = 3, padding = 'same', use_bias = False, name = "DWConv"),
                keras.layers.Activation("gelu", name = 'GeluAct'),
                SE(name = "SqueezeAndExcitation2D"),
                keras.layers.Conv2D(filters = channels, kernel_size = 1, padding = 'same', use_bias = False, name = "PointWiseConv"),
            ],
            name = "Fused_MBConvLayer",
        )
        if self._pools():
            self.pool = keras.layers.MaxPooling2D(name = "FeatExtractMaxPool2D")

    def call(self, inputs):
        residual_out = inputs + self.fused_mbconv(inputs)
        return self.pool(residual_out) if self._pools() else residual_out
    
class GlobalQueryGenerator(keras.layers.Layer):
    """Builds GC-ViT global query tokens by stacking one FeatExtract layer per flag.

    Args:
        keepdims: a sequence of flags, one ``FeatExtract`` layer per entry (``0``/``False``
            max-pools, ``1``/``True`` keeps the spatial size). A scalar flag, including
            the default ``False``, is treated as a one-element sequence.
    """
    def __init__(self, keepdims = False, **kwargs):
        super().__init__(**kwargs)
        self.keepdims = keepdims

    def build(self, input_shape):
        # A scalar (e.g. the default False) is not iterable; wrap it instead of raising TypeError.
        try:
            flags = list(self.keepdims)
        except TypeError:
            flags = [self.keepdims]
        self.q_generator = keras.Sequential(
            [FeatExtract(keepdims = keepdim, name = f"FeatureExtraction_{idx+1}")
             for idx, keepdim in enumerate(flags)]
        )

    def call(self, inputs):
        return self.q_generator(inputs)
    
class WindowAttention(keras.layers.Layer):
    """Window multi-head attention with a learnable relative-position bias (GC-ViT).

    Args:
        window_size: side length of the square attention window.
        n_heads: number of attention heads (the embedding width should divide evenly).
        global_query: if truthy, only keys/values are projected from the window tokens and
            the queries come from a global query tensor passed as the second input
            (global MHSA). Otherwise Q, K and V all come from the window tokens (local MHSA).
        qkv_bias: whether the QKV and output projections use a bias.
        qk_scale: multiplier applied to the queries; defaults to
            ``(embed_dims // n_heads) ** -0.5``.
        dropout_rate: dropout on the attention weights and on the projected output.
        return_attention_weights: if True, ``call`` returns ``(output, attention_weight)``.

    Call inputs:
        ``[windows]`` (local) or ``[windows, q_global]`` (global).
        ``windows`` is ``(batch * n_windows, window_size**2, embed_dims)``; the
        relative-position bias requires that token count.
        ``q_global`` has leading dimension ``batch``. It is repeated ``n_windows`` times
        along axis 0 and reshaped to ``(batch * n_windows, window_size**2, n_heads, head_dims)``.
    """
    def __init__(self, window_size, n_heads, global_query, # truthy -> global MHSA, falsy -> local MHSA
                qkv_bias = True, qk_scale = None,
                dropout_rate = 0.05, return_attention_weights = False, **kwargs):
        super().__init__(**kwargs)
        self.window_size = (window_size, window_size)
        self.n_heads = n_heads
        self.global_query = global_query
        self.bias = qkv_bias
        self.scale = qk_scale
        self.dropout_rate = dropout_rate
        self.return_attention_weights = return_attention_weights

    def build(self, input_shape):
        # input_shape is a list: [windows_shape] or [windows_shape, q_global_shape]
        embed_dims = input_shape[0][-1]
        head_dims = embed_dims//self.n_heads
        # Scaled dot-product attention divides by sqrt(per-head width), not the full width.
        self.scale = self.scale if self.scale is not None else head_dims**-0.5
        # 3 projections (Q, K, V) for local attention; 2 (K, V) when the queries are global.
        self.qkv_size = 3-int(self.global_query)
        self.qkv_embed_fn = keras.layers.Dense(units = embed_dims * self.qkv_size, use_bias = self.bias,
                                 name = "QKV_Embedding_Dense_layer")
        self.softmax = keras.layers.Activation("softmax", name = "AttentionWeightSoftmax") # attention-weight normalisation
        self.proj = keras.layers.Dense(units = embed_dims, use_bias = self.bias, name = "Projection")
        self.attention_dropout = keras.layers.Dropout(self.dropout_rate)
        self.projection_dropout = keras.layers.Dropout(self.dropout_rate)
        # Learnable bias per head for every possible relative offset between two positions
        # in a window: (2*Wh-1)*(2*Ww-1) offsets, e.g. 49 for window_size = 4.
        self.relative_position_bias_table = self.add_weight(
            name="relative_position_bias_table",
            shape=[
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.n_heads,
            ],
            initializer=keras.initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
            dtype=self.dtype,
        )
        super().build(input_shape)

    def get_relative_position_index(self):
        """Return a (Wh*Ww, Wh*Ww) matrix of row indices into the relative-position bias table."""
        coords_h = ops.arange(self.window_size[0])
        coords_w = ops.arange(self.window_size[1])
        coords = ops.stack(ops.meshgrid(coords_h, coords_w, indexing="ij"), axis=0)
        coords_flatten = ops.reshape(coords, [2, -1])
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = ops.transpose(relative_coords, axes=[1, 2, 0])
        # Shift offsets to be non-negative, then flatten (dy, dx) into a single index.
        relative_coords_xx = relative_coords[:, :, 0] + self.window_size[0] - 1
        relative_coords_yy = relative_coords[:, :, 1] + self.window_size[1] - 1
        relative_coords_xx = relative_coords_xx * (2 * self.window_size[1] - 1)
        relative_position_index = relative_coords_xx + relative_coords_yy
        return relative_position_index

    def call(self, inputs, **kwargs):
        if self.global_query:
            inputs, q_global = inputs
            batch_size = ops.shape(q_global)[0]
        else:
            inputs = inputs[0]
        # batch_ = batch * n_windows, token_length = tokens per window
        batch_, token_length, embed_dims = ops.shape(inputs)
        head_dims = embed_dims//self.n_heads

        qkv = self.qkv_embed_fn(inputs)  # (batch_, token_length, qkv_size * embed_dims)
        qkv = ops.reshape(qkv, [-1, token_length, self.qkv_size, self.n_heads, head_dims])
        qkv = ops.transpose(qkv, [2, 0, 3, 1, 4])  # (qkv_size, batch_, n_heads, token_length, head_dims)

        # Split into Q / K / V
        if self.global_query:
            k, v = ops.split(qkv, 2, axis = 0)  # each (1, batch_, n_heads, token_length, head_dims)
            # Repeat each image's global query once per window: (batch, ...) -> (batch_, ...)
            q_global = ops.repeat(q_global, batch_//batch_size, axis = 0)
            q = ops.reshape(q_global, [batch_, token_length, self.n_heads, head_dims])
            q = ops.transpose(q, [0, 2, 1, 3])
        else:
            q, k, v = ops.split(qkv, 3, axis = 0)
            q = ops.squeeze(q, axis = 0)
        k = ops.squeeze(k, axis = 0)
        v = ops.squeeze(v, axis = 0)

        q = q * self.scale  # (batch_, n_heads, token_length, head_dims)
        attention_score = q@ops.transpose(k, [0, 1, 3, 2])  # (batch_, n_heads, token_length, token_length)

        # Relative-position bias (adapted from the keras.io example), added to the scores.
        relative_position_bias = ops.take(
            self.relative_position_bias_table,
            ops.reshape(self.get_relative_position_index(), [-1]),
        )
        relative_position_bias = ops.reshape(
            relative_position_bias,
            [
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            ],
        )
        relative_position_bias = ops.transpose(relative_position_bias, axes=[2, 0, 1])
        attention_score = attention_score + relative_position_bias[None,]
        attention_weight = self.softmax(attention_score)
        attention_weight = self.attention_dropout(attention_weight)  # (batch_, n_heads, token_length, token_length)
        attended_output = attention_weight@v  # (batch_, n_heads, token_length, head_dims)
        attended_output = ops.transpose(attended_output, [0, 2, 1, 3])
        attended_output = ops.reshape(attended_output, [batch_, token_length, embed_dims])
        attended_output = self.projection_dropout(self.proj(attended_output))
        self.attention_weight = attention_weight
        if self.return_attention_weights:
            return attended_output, attention_weight
        else:
            return attended_output
        
        
class Block(keras.layers.Layer):
    """GC-ViT block: pre-norm window attention and a pre-norm MLP, each as a residual.

    Args:
        window_size, num_heads, global_query, qkv_bias, qk_scale, dropout_rate:
            forwarded to ``WindowAttention``. ``dropout_rate`` is also used by the MLP.
        mlp_ratio: MLP hidden width as a multiple of the channel count.
        layer_scale: if not None, the initial value of learnable per-channel scales
            (``gamma1``/``gamma2``) on both residual branches; otherwise both are 1.0.
        return_attention_weights: if True, ``call`` returns ``(output, attention_weight)``.

    Call inputs:
        ``[x]`` (local) or ``[x, global_query]`` (when ``global_query`` is truthy).
        ``x`` has shape ``(batch, H, W, dims)``, and H and W must be divisible by
        ``window_size`` (required by the window reshapes).
    """
    def __init__(self, # Window-attention configuration
                 window_size, num_heads, global_query, 
                 qkv_bias = True, qk_scale = None, dropout_rate = 0.05, 
                 # MLP / residual configuration
                 mlp_ratio = 4.0, layer_scale = None, return_attention_weights = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.n_heads = num_heads
        self.global_query = global_query
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.dropout_rate = dropout_rate
        
        self.mlp_ratio = mlp_ratio
        self.layer_scale = layer_scale
        self.return_attention_weights = return_attention_weights

    def build(self, input_shape):
        # input_shape is a list: [x_shape] or [x_shape, global_query_shape],
        # with x_shape = (batch_size, H, W, dims)
        batch_size, H, W, dims = input_shape[0]
        self.norm1 = keras.layers.LayerNormalization(epsilon = 1e-5)
        self.norm2 = keras.layers.LayerNormalization(epsilon = 1e-5)
        self.window_attention = WindowAttention(window_size = self.window_size,
                                               n_heads = self.n_heads,
                                               global_query = self.global_query,
                                               qkv_bias = self.qkv_bias,
                                               qk_scale = self.qk_scale,
                                               dropout_rate = self.dropout_rate,
                                               return_attention_weights = self.return_attention_weights)
        self.mlps = MLP(middle_dim = int(self.mlp_ratio * dims), dropout = self.dropout_rate)
        if self.layer_scale is not None:
            # The initializers module is `keras.initializers` (plural); `keras.initializer`
            # does not exist and raised AttributeError whenever layer_scale was set.
            self.gamma1 = self.add_weight(shape = [dims], name = "Gamma1", trainable = True,
                                         initializer = keras.initializers.Constant(self.layer_scale), dtype = self.dtype)
            self.gamma2 = self.add_weight(shape = [dims], name = "Gamma2", trainable = True,
                                         initializer = keras.initializers.Constant(self.layer_scale), dtype = self.dtype)
        else:
            self.gamma1, self.gamma2 = 1.0, 1.0
        self.n_windows = int(H//self.window_size) * int(W//self.window_size)

    # Helpers that split a feature map into non-overlapping windows and stitch the
    # windows back into the original feature map.
    def window_partition(self, inputs):
        """(batch, H, W, dims) -> (batch * n_windows, window_size, window_size, dims)."""
        batch_size, H, W, dims = ops.shape(inputs)
        h, w = H//self.window_size, W//self.window_size
        inputs = ops.reshape(inputs, [batch_size, 
                                      h, self.window_size,
                                     w, self.window_size, 
                                     dims])
        inputs = ops.transpose(inputs, [0, # batch_size
                                        1,3, # h, w (window grid)
                                        2,4, # window_size, window_size
                                        5])
        return ops.reshape(inputs, [-1, self.window_size, self.window_size, dims])

    def window_reverse(self, inputs, H, W, dims):
        """Inverse of ``window_partition``: windows -> (batch, H, W, dims)."""
        x = ops.reshape(inputs, [-1, H//self.window_size, W//self.window_size, self.window_size, self.window_size, dims])
        x = ops.transpose(x, [0, 1, 3, 2, 4, 5])
        return ops.reshape(x, [-1, H, W, dims])

    def call(self, inputs, **kwargs):
        if self.global_query:
            inputs, global_query = inputs
        else:
            inputs = inputs[0]
        batch_size, H, W, dims = ops.shape(inputs)
        x = self.norm1(inputs)
        x = self.window_partition(x) 
        x = ops.reshape(x, [-1, self.window_size*self.window_size, dims])  # windows as token sequences
        if self.global_query:
            outputs_ = self.window_attention([x, global_query])
        else:
            outputs_ = self.window_attention([x])
        if self.return_attention_weights:
            x, attention_weight = outputs_
        else:
            x = outputs_
        x = self.window_reverse(x, H, W, dims)
        x = inputs + self.gamma1*x
        x = x + self.gamma2*(self.mlps(self.norm2(x)))
        if self.return_attention_weights:
            return x, attention_weight
        else:
            return x
    
class Level(keras.layers.Layer):
    """One GC-ViT stage: ``depth`` Blocks, optionally followed by a DownSampler.

    Blocks alternate between local window attention (even index) and global-query
    attention (odd index). One global query is generated from the stage input by a
    ``GlobalQueryGenerator`` and shared by every global Block.

    Args:
        depth: number of Blocks.
        num_heads, window_size: forwarded to every Block.
        keepdims: per-step flags for the ``GlobalQueryGenerator``.
        downsample: bool; if True a ``DownSampler`` (stride-2, channels doubled) is
            applied after the Blocks.
        mlp_ratio, qkv_bias, qk_scale, dropout, layer_scale, return_attention_weights:
            forwarded to every Block.
    """
    def __init__(self, 
                depth, # number of Blocks
                num_heads, window_size, keepdims, # hyperparameters of the blocks / query generator
                downsample = True, mlp_ratio = 4.0,
                qkv_bias = True, qk_scale = None,
                dropout = 0.05, layer_scale = None, return_attention_weights = True,
                **kwargs):
        super().__init__(**kwargs)
        self.depth = depth
        self.n_heads = num_heads
        self.window_size = window_size
        self.keepdims = keepdims
        self.downsample = downsample
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.dropout_rate = dropout
        self.layer_scale = layer_scale
        self.return_attention_weights = return_attention_weights

    def build(self, input_shape):
        # Expects a 4-D feature map / patch grid: (batch_size, H, W, dims)
        batch_size, H, W, dims = input_shape
        self.blocks = [Block(window_size = self.window_size, num_heads = self.n_heads, global_query = bool(idx%2), 
                             qkv_bias = self.qkv_bias, qk_scale = self.qk_scale, dropout_rate = self.dropout_rate, 
                             mlp_ratio = self.mlp_ratio, layer_scale = self.layer_scale, return_attention_weights = self.return_attention_weights,
                             name = f"GCViTBlock{idx+1}") for idx in range(self.depth)]
        # Only create the DownSampler when it is used; otherwise the last stage carried
        # an unused, never-built sublayer.
        if self.downsample:
            self.downsampler = DownSampler(name = "Downsampler")
        self.query_generator = GlobalQueryGenerator(keepdims = self.keepdims, name = "GlobalQueryGenerator")

    def call(self, inputs, **kwargs):
        patches = inputs
        global_query = self.query_generator(inputs)
        for idx, block in enumerate(self.blocks):
            if idx % 2:
                outputs_ = block([patches, global_query])
            else:
                outputs_ = block([patches])
            if self.return_attention_weights:
                patches, attention_weights = outputs_
            else:
                patches = outputs_
        if not self.downsample:
            return patches
        return self.downsampler(patches)
def get_gcvit_configs(res, initial_embedding_dims, name = None):
    """Build the configuration dict consumed by ``get_gcvit``.

    Args:
        res: square input resolution. The per-level window sizes are derived from it as
            ``res//32, res//32, res//16, res//32``.
        initial_embedding_dims: channel width produced by the patch-embedding stem.
        name: model name; defaults to ``f"GCViT_res{res}"``.

    Returns:
        dict with keys res, embed_dims, patch_embedding_type, level_depth, level_heads,
        level_keepdims, level_window_size and model_name.
    """
    model_name = name if name is not None else f"GCViT_res{res}"
    window_sizes = [res // 32, res // 32, res // 16, res // 32]
    # Global-query-generator flags per level. From the 3rd level on, the window covers
    # the whole feature map (window attention == global attention).
    keepdims_per_level = [[0, 0, 0], [0, 0], [1], [1]]
    return dict(
        res = res,
        embed_dims = initial_embedding_dims,
        patch_embedding_type = "conv",  # "conv" or "tokenlearner"
        level_depth = [2, 4, 6, 8],
        level_heads = [2, 4, 8, 16],
        level_keepdims = keepdims_per_level,
        level_window_size = window_sizes,
        model_name = model_name,
    )
def get_gcvit(configs):
    """Assemble a functional GC-ViT backbone from a ``get_gcvit_configs``-style dict.

    The model maps a ``(res, res, 3)`` image through the patch-embedding stem and one
    ``Level`` per entry of ``configs["level_depth"]``. Every Level except the last
    downsamples its output.
    """
    res = configs["res"]
    image = Input([res,res,3], name = "ImageInput")
    features = PatchEmbedding(embed_dim = configs['embed_dims'],
                              patching_type = configs["patch_embedding_type"],
                              name = "PatchEmbedding")(image)

    n_levels = len(configs['level_depth'])
    level_specs = zip(configs["level_depth"], configs["level_heads"],
                      configs["level_keepdims"], configs["level_window_size"])
    for level_idx, (depth, heads, keepdims, window_size) in enumerate(level_specs):
        # The final level keeps its resolution; all earlier ones downsample.
        downsample = not (level_idx == n_levels - 1)
        stage = Level(depth = depth, num_heads = heads, window_size = window_size,
                      keepdims = keepdims, downsample = downsample,
                      name = f"GCViT_Lv{level_idx+1}_downsample_{downsample}")
        features = stage(features)
    return keras.Model(image, features, name = configs["model_name"])

# Wrap into an end-to-end model

In [5]:
def _apply_positional_encoding(tokens, pe_type, num_patches, embed_dims):
    """Apply the positional encoding selected by ``pe_type`` to a token sequence.

    Accepted values:
      * 'rotary' / 'rotation' / 'rotatory' / 'roformer': keras_nlp ``RotaryEmbedding``.
      * 'learnable' / 'absolute': ``PatchEncoder``, which also projects to ``embed_dims``.
      * None / 'none': no positional encoding.
    Any other value raises ``ValueError``.
    """
    if pe_type in ['rotary', 'rotation', 'rotatory', 'roformer']:
        return keras_nlp.layers.RotaryEmbedding(name = "RotaryPositionalEmbedding")(tokens)
    if pe_type in ["learnable", 'absolute']:
        return PatchEncoder(num_patches = num_patches, projection_dim = embed_dims)(tokens)
    if pe_type is None or pe_type == "none":
        return tokens
    # Fail loudly: an unrecognised value (e.g. a typo) would otherwise silently
    # disable positional encoding.
    raise ValueError(
        f"Unknown pe_type {pe_type!r}; expected 'rotary', 'learnable'/'absolute' or None."
    )


def get_feature_extractor(conv_base, #if None, Vanilla ViT
                         embed_dims, res, pe_type = "rotary",
                          patch_size = 16,
                        att_depth = 4, att_heads = 16) : 
    """Build a model mapping ``(res, res, 3)`` images to ``[feature_vector, attention_weight]``.

    Args:
        conv_base: a Keras model producing a ``(batch, w, h, dims)`` feature map, or None
            to build a vanilla ViT from ``patch_size`` x ``patch_size`` patches.
        embed_dims: attention width (and, for the ViT, the patch-embedding width).
        res: square input resolution.
        pe_type: positional encoding; see ``_apply_positional_encoding``.
        patch_size: ViT patch size (unused when ``conv_base`` is given).
        att_depth: number of attention layers; must be >= 1.
        att_heads: attention heads per layer.

    Returns:
        ``keras.Model`` with two outputs. The first is the token at index 0 after the
        attention stack, initialised as the mean of all tokens and prepended. The second
        is the last layer's attention weights with the index-0 key column dropped
        (``[..., 1:]``).
    """
    if att_depth < 1:
        # The returned attention weights come from the attention stack, so it cannot be empty.
        raise ValueError(f"att_depth must be >= 1, got {att_depth}.")
    inputs = Input([res,res,3], name = "Input_images")
    if conv_base is None: # Vanilla Vision Transformer
        scaled_inputs = keras.layers.LayerNormalization(name = "InitialLN")(inputs)
        # Non-overlapping patches via a conv with kernel == stride == patch_size.
        patches = Conv2D(filters = embed_dims, activation = "gelu", kernel_size = patch_size, strides = patch_size, padding = 'SAME', name = "PatchingStem")(scaled_inputs)
        _, w, h, dims = ops.shape(patches)
        patches = ops.reshape(patches, [-1, w*h, dims])
        # "CLS" token initialised as the mean of the patch tokens, prepended at index 0.
        cls_token = ops.expand_dims(keras.layers.GlobalAveragePooling1D(name = 'GAPforClsToken')(patches),
                                    axis = 1)
        patches = ops.concatenate([cls_token, patches], axis = 1)
        patches = _apply_positional_encoding(patches, pe_type, 1+w*h, embed_dims)

        for idx in range(att_depth):
            x0 = LayerNormalization(name = f"PreLN{idx+1}")(patches)
            x1, attention_score = AttentionPooling(att_heads, embed_dims, name = f"MHA{idx+1}")([x0, x0])
            x2 = keras.layers.Add(name = f"PreAdd{idx+1}")([patches, x1])
            x3 = LayerNormalization(name = f"PostLN{idx+1}")(x2)
            x4 = Dense(units = embed_dims, activation = 'gelu', name = f"TokenMixMLP{idx+1}")(x3)
            patches = keras.layers.Add(name = "Encoded_Patches" if idx == att_depth-1 else f"PostAdd{idx+1}")([x4, x2])
        learned_token = keras.layers.Identity(name = "feature_vector")(patches[:, 0, :])
        attention_score = keras.layers.Identity(name = "attention_weight")(attention_score[..., 1:])
        model_name = f"ViT_depth{att_depth}_dims{embed_dims}_heads{att_heads}_patch{patch_size}"
    else:
        feature_map = conv_base(inputs)
        _, w, h, dims = ops.shape(feature_map)
        feature_map = ops.reshape(feature_map, [-1, w*h, dims])
        # Global-average token, prepended at index 0 and refined by the attention stack.
        learned_token = keras.layers.GlobalAveragePooling1D(name = 'GAPforRepVec')(feature_map)
        feature_map = ops.concatenate([ops.expand_dims(learned_token, axis = 1),
                                      feature_map], axis = 1)
        feature_map = _apply_positional_encoding(feature_map, pe_type, 1+w*h, embed_dims)

        for idx in range(att_depth):
            feature_map, attention_score = AttentionPooling(att_heads, embed_dims, name = f"MHA_after_Conv_{idx+1}")([feature_map, feature_map])
        learned_token = keras.layers.Identity(name = "feature_vector")(feature_map[:, 0, :])
        attention_score = keras.layers.Identity(name = "attention_weight")(attention_score[..., 1:])
        
        model_name = f"{conv_base.name}_depth{att_depth}_dims{embed_dims}_heads{att_heads}"
    model = Model(inputs, [learned_token, attention_score],
                  name = model_name)
    return model

def get_full_model(conv_base_name, res, embed_dims = 1280, patch_size = 16, pe_type = 'rotary',
                   att_depth = 4, att_heads = 8,
                  extra_configs = None):
    """Create a backbone by name (or GC-ViT config dict) and wrap it with ``get_feature_extractor``.

    Recognised names:
      * 'effnet' / 'EfficientNet' -> EfficientNetV2B1
      * 'effnet_small' / 'EfficientNetSmall' -> EfficientNetV2S
      * 'effnet_base' / 'EfficientNetBase' -> EfficientNetV2M
      * 'convnext' / 'ConvNeXt' -> ConvNeXtTiny
      * 'convnext_small' / 'ConvNeXtSmall' -> ConvNeXtSmall
      * 'convnext_base' / 'ConvNeXtBase' -> ConvNeXtBase

    A dict is passed to ``get_gcvit``. Anything else gives ``conv_base=None``, i.e. a
    vanilla ViT. ``extra_configs`` is accepted but currently unused.
    """
    backbone_table = (
        (("effnet", 'EfficientNet'), keras.applications.EfficientNetV2B1),
        (("effnet_small", "EfficientNetSmall"), keras.applications.EfficientNetV2S),
        (("effnet_base", "EfficientNetBase"), keras.applications.EfficientNetV2M),
        (("convnext", 'ConvNeXt'), keras.applications.ConvNeXtTiny),
        (("convnext_small", 'ConvNeXtSmall'), keras.applications.ConvNeXtSmall),
        (("convnext_base", 'ConvNeXtBase'), keras.applications.ConvNeXtBase),
    )
    conv_base = None
    # Membership tests (not dict lookup) so unhashable names fall through like before.
    for aliases, constructor in backbone_table:
        if conv_base_name in aliases:
            conv_base = constructor(input_shape = [res,res,3], include_top = False)
            break
    else:
        if isinstance(conv_base_name, dict):
            conv_base = get_gcvit(conv_base_name)
    return get_feature_extractor(conv_base, pe_type = pe_type, res = res, patch_size = patch_size,
                                 embed_dims = embed_dims, att_depth = att_depth, att_heads = att_heads)