In [None]:
#!pip install onnx onnxruntime tf2onnx

In [1]:
import tensorflow as tf
from tensorflow.keras import layers as L

print(tf.__version__)

2.19.0


In [2]:
class CustomLayerNorm(L.Layer):
    """Layer normalization over the last axis, written with basic TF ops.

    Every feature vector is centered by its mean and divided by
    ``sqrt(variance + eps)`` (both statistics taken over the last axis), then
    multiplied by a learned ``gamma`` and shifted by a learned ``beta``.

    Args:
        eps: Constant added to the variance before taking the square root.
        **kwargs: Forwarded to the base layer constructor.
    """

    def __init__(self, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        # Both weights are allocated in build(), once the feature size is known.
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        """Create the per-feature scale (ones) and shift (zeros) weights."""
        weight_shape = (int(input_shape[-1]),)
        self.gamma = self.add_weight(
            name="gamma",
            shape=weight_shape,
            initializer="ones",
            trainable=True,
        )
        self.beta = self.add_weight(
            name="beta",
            shape=weight_shape,
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        """Normalize ``x`` along its last axis and apply gamma / beta."""
        mu = tf.reduce_mean(x, axis=-1, keepdims=True)
        centered = x - mu
        variance = tf.reduce_mean(tf.square(centered), axis=-1, keepdims=True)
        normalized = centered / tf.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta

In [3]:
class RMSNorm(L.Layer):
    """Root-mean-square normalization over the last axis.

    The input is divided by ``sqrt(mean(x**2) + eps)`` (mean taken over the
    last axis) and multiplied by a learned ``gamma``. A learned ``beta`` is
    added only when ``use_bias`` is true.

    Args:
        eps: Constant added to the mean square before the square root.
        use_bias: Whether to create and add the ``beta`` offset.
        **kwargs: Forwarded to the base layer constructor.
    """

    def __init__(self, eps=1e-6, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.use_bias = use_bias
        # Weights are allocated in build(); beta stays None without a bias.
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        """Create ``gamma`` (ones) and, if requested, ``beta`` (zeros)."""
        weight_shape = (int(input_shape[-1]),)
        self.gamma = self.add_weight(
            name="gamma",
            shape=weight_shape,
            initializer="ones",
            trainable=True,
        )
        if self.use_bias:
            self.beta = self.add_weight(
                name="beta",
                shape=weight_shape,
                initializer="zeros",
                trainable=True,
            )
        super().build(input_shape)

    def call(self, x):
        """Scale ``x`` by the inverse of its RMS along the last axis."""
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        rms = tf.sqrt(mean_square + self.eps)
        scaled = (x / rms) * self.gamma
        if not self.use_bias:
            return scaled
        return scaled + self.beta

In [4]:


# ----------------------
# 기본 블록: Transformer Encoder
# ----------------------
class TransformerEncoder(L.Layer):
    """Pre-norm Transformer encoder block: self-attention and MLP, each with a residual.

    Args:
        dim: Width of the token embeddings; also the output width of the MLP.
        num_heads: Number of attention heads. Each head uses
            ``key_dim = dim // num_heads``.
        mlp_dim: Hidden width of the two-layer MLP.
        dropout: Dropout rate used inside attention, after attention and
            inside the MLP.
        use_layernorm: Use ``CustomLayerNorm`` when true, ``RMSNorm`` otherwise.
        **kwargs: Forwarded to the base layer constructor.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1, use_layernorm=False, **kwargs):
        super().__init__(**kwargs)
        # Both normalization slots use the same class, chosen once here.
        norm_cls = CustomLayerNorm if use_layernorm else RMSNorm

        # Attention branch (sub-layers are created in the same order as before).
        self.norm1 = norm_cls(eps=1e-6)
        self.attn = L.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=dim // num_heads,
            dropout=dropout,
        )
        self.drop1 = L.Dropout(dropout)

        # Feed-forward branch.
        self.norm2 = norm_cls(eps=1e-6)
        mlp_layers = [
            L.Dense(mlp_dim, activation=tf.keras.activations.relu),
            L.Dropout(dropout),
            L.Dense(dim),
            L.Dropout(dropout),
        ]
        self.mlp = tf.keras.Sequential(mlp_layers)

    def call(self, x, training=False):
        """Apply the attention and MLP sub-blocks, each added back to its input."""
        # Self-attention + residual
        normed = self.norm1(x)
        attended = self.attn(normed, normed, training=training)
        x = x + self.drop1(attended, training=training)

        # MLP + residual
        mlp_out = self.mlp(self.norm2(x), training=training)
        return x + mlp_out

# ----------------------
# ViT 모델 생성 함수 (간단 버전)
# ----------------------
class _AddPositionalEmbedding(L.Layer):
    """Adds a learnable ``[1, N, dim]`` positional embedding to a token sequence.

    The embedding is created with ``add_weight`` so that it belongs to the
    layer (and therefore to any model containing the layer): it is part of the
    trainable weights, updated by the optimizer and written out on save.
    """

    def build(self, input_shape):
        num_tokens = int(input_shape[-2])
        dim = int(input_shape[-1])
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(1, num_tokens, dim),
            # Same distribution as a standard normal sample scaled by 0.02.
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # Broadcasts over the batch axis.
        return x + self.pos_embed


def build_vit(
    image_size=224,          # side length of the (square) input image
    patch_size=16,           # side length of one patch
    num_classes=10,          # number of output classes
    dim=192,                 # token embedding width
    depth=6,                 # number of Transformer encoder blocks
    heads=3,                 # number of attention heads
    mlp_dim=384,             # hidden width of the encoder MLP
    dropout=0.1,
    use_layernorm=False,
    batch_size=1,
):
    """Build a small Vision Transformer classifier as a Keras functional model.

    Pipeline: strided Conv2D patch embedding -> learnable positional
    embedding -> ``depth`` TransformerEncoder blocks -> final normalization ->
    global average pooling over tokens -> dropout -> softmax Dense head.

    Args:
        image_size: Side length of the square RGB input.
        patch_size: Side length of each patch; must divide ``image_size``.
        num_classes: Size of the softmax output.
        dim: Token embedding width.
        depth: Number of encoder blocks.
        heads: Number of attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        dropout: Dropout rate used in the blocks and before the head.
        use_layernorm: Use ``CustomLayerNorm`` when true, ``RMSNorm`` otherwise.
        batch_size: Static batch size declared on the input. Defaults to 1
            (the previous hard-coded value); pass ``None`` for a dynamic batch.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"TinyViT"``.

    Raises:
        AssertionError: If ``image_size`` is not a multiple of ``patch_size``.
    """
    assert image_size % patch_size == 0, "image_size는 patch_size로 나누어 떨어져야 합니다."
    num_patches = (image_size // patch_size) ** 2

    inputs = L.Input(shape=(image_size, image_size, 3), batch_size=batch_size)

    # 1) Patch embedding: a strided convolution splits the image into patches
    #    and linearly projects each one.
    x = L.Conv2D(
        filters=dim, kernel_size=patch_size, strides=patch_size,
        padding="valid", name="patch_embedding"
    )(inputs)                             # [B, H/ps, W/ps, dim]
    x = L.Reshape((num_patches, dim))(x)  # [B, N, dim]

    # 2) Learnable positional embedding. It must live inside a layer applied
    #    to the symbolic tensor; a free-standing tf.Variable added to `x` is
    #    not tracked by the model and would never be trained or saved.
    x = _AddPositionalEmbedding(name="positional_embedding")(x)

    # 3) Transformer encoder stack
    for i in range(depth):
        x = TransformerEncoder(
            dim=dim, num_heads=heads, mlp_dim=mlp_dim, dropout=dropout,
            use_layernorm=use_layernorm, name=f"encoder_{i}",
        )(x)

    # 4) Classification head: final norm -> average over tokens -> Dense
    if use_layernorm:
        x = CustomLayerNorm(eps=1e-6)(x)
    else:
        x = RMSNorm(eps=1e-6)(x)
    x = L.GlobalAveragePooling1D()(x)
    x = L.Dropout(dropout)(x)
    outputs = L.Dense(num_classes, activation="softmax")(x)

    return tf.keras.Model(inputs, outputs, name="TinyViT")

def self_positional_embedding(num_patches, dim, name="positional_embedding"):
    """Create a ``[1, num_patches, dim]`` positional-embedding variable.

    The variable is initialized from a normal sample scaled by 0.02 and is
    returned through a ``Lambda`` layer that ignores its (all-zeros) input.

    NOTE(review): the variable is created outside of any Keras layer and is
    only captured by the Lambda's closure; confirm that a model using the
    returned value actually lists it among its trainable weights.

    Args:
        num_patches: Number of tokens (second axis of the embedding).
        dim: Embedding width (last axis of the embedding).
        name: Name given to the underlying ``tf.Variable``.

    Returns:
        The result of calling the Lambda layer, i.e. the embedding value.
    """
    embed_shape = (1, num_patches, dim)
    initial = tf.random.normal([1, num_patches, dim]) * 0.02
    pe = tf.Variable(
        initial_value=initial,
        trainable=True,
        name=name,
        dtype=tf.float32,
    )
    # Wrapped in a Lambda so the value is produced by a Keras layer call.
    passthrough = L.Lambda(lambda _: pe)
    return passthrough(tf.zeros(embed_shape))



In [5]:
import kagglehub

# Fetch the latest version of the flowers dataset and report where it is stored.
path = kagglehub.dataset_download("imsparsh/flowers-dataset")

print(f"Path to dataset files: {path}")

Using Colab cache for faster access to the 'flowers-dataset' dataset.
Path to dataset files: /kaggle/input/flowers-dataset


In [6]:
# ----------------------
# 샘플 사용법
# ----------------------
def get_model(use_layernorm):
    """Build, compile and summarize the ViT classifier used in this notebook.

    The architecture is fixed here: 224x224 inputs, 16x16 patches, 5 classes,
    embedding width 128, 5 encoder blocks with 5 heads, MLP width 256 and a
    dropout rate of 0.1.

    Args:
        use_layernorm: Forwarded to ``build_vit`` to pick the normalization.

    Returns:
        The model compiled with Adam (learning rate 1e-3), sparse categorical
        cross-entropy loss and an accuracy metric. The summary is printed as
        a side effect.
    """
    vit_config = dict(
        image_size=224,
        patch_size=16,
        num_classes=5,
        dim=128,
        depth=5,
        heads=5,
        mlp_dim=256,
        dropout=0.1,
    )
    model = build_vit(use_layernorm=use_layernorm, **vit_config)

    model.compile(
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
        optimizer=tf.keras.optimizers.Adam(1e-3),
    )
    model.summary()
    return model




# with RMSNorm

In [7]:
def preprocess(x, y):
    """Resize images to ``IMG`` x ``IMG`` and scale them by 1/255 as float32.

    Args:
        x: Image tensor accepted by ``tf.image.resize``.
        y: Labels; returned unchanged.

    Returns:
        A ``(images, labels)`` tuple with the processed images.
    """
    resized = tf.image.resize(x, (IMG, IMG))
    scaled = tf.cast(resized, tf.float32) / 255.0
    return scaled, y

# Image side length shared by the data pipeline and the calibration generator.
IMG = 224

# Model variant that uses RMSNorm.
model = get_model(use_layernorm=False)

# Labeled image batches read from the training folder (one sub-folder per class).
train_ds = tf.keras.utils.image_dataset_from_directory(
    directory="/kaggle/input/flowers-dataset/train",
    batch_size=8,
    image_size=(IMG, IMG),
)

# Scale pixel values by 1/255 before they reach the model.
normalization_layer = L.Rescaling(1./255)
train_ds = train_ds.map(
    lambda images, labels: (normalization_layer(images), labels)
)

model.fit(train_ds, epochs=2)

Found 2746 files belonging to 5 classes.
Epoch 1/2
[1m344/344[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 118ms/step - accuracy: 0.2358 - loss: 1.7981
Epoch 2/2
[1m344/344[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 8ms/step - accuracy: 0.4775 - loss: 1.2332


<keras.src.callbacks.history.History at 0x7c642c05c0b0>

In [8]:
# Save the trained model to an HDF5 file.
# NOTE(review): 'tinytvit.h5' looks like a typo for 'tinyvit.h5' — confirm
# before anything else depends on this path.
model.save('tinytvit.h5')



In [9]:
# Export the model as a serving artifact into the 'tinyvit' directory.
model.export('tinyvit')

Saved artifact at 'tinyvit'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 5), dtype=tf.float32, name=None)
Captures:
  136769733092624: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733093776: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733094160: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733099152: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098960: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098384: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098576: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733099344: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098000: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098768: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769684194128: Tensor

In [10]:
import glob
import random
import numpy as np
import os
from PIL import Image
def load_image_as_float(path, img_size):
    """Load an image file as an RGB float32 array scaled by 1/255.

    Args:
        path: Path of the image file to open.
        img_size: Target side length; the image is resized to
            ``img_size`` x ``img_size`` with bilinear resampling.

    Returns:
        A ``float32`` array of shape ``(img_size, img_size, 3)``.
    """
    target_size = (img_size, img_size)
    with Image.open(path) as source:
        # Force three channels, then resize.
        rgb = source.convert("RGB").resize(target_size, Image.BILINEAR)
        pixels = np.asarray(rgb, dtype=np.float32)
    return pixels / 255.0  # (H, W, 3) float32
def representative_data_gen(
    data_dir='/kaggle/input/flowers-dataset/test',
    img_size=None,
    max_samples=200,
):
    """Yield calibration samples for TFLite post-training quantization.

    Image files directly under ``data_dir`` are collected by extension,
    sorted for a reproducible order, and the first ``max_samples`` are loaded
    with ``load_image_as_float``. The converter can call this with no
    arguments; the defaults reproduce the notebook's original settings.

    Args:
        data_dir: Directory that contains the calibration images.
        img_size: Side length passed to ``load_image_as_float``. When ``None``
            the module-level ``IMG`` is used.
        max_samples: Maximum number of images to yield.

    Yields:
        A one-element list holding a ``(1, H, W, 3)`` float32 array, which is
        the "list of input tensors" form the TFLite converter expects.

    Raises:
        FileNotFoundError: If no image file is found under ``data_dir``.
    """
    # This must be a real tuple: a bare ("*.jpg") is just a string, and
    # looping over it would glob "*", ".", "j", "p" and "g" instead.
    patterns = ("*.jpg", "*.jpeg", "*.png")
    img_paths = []
    for pattern in patterns:
        img_paths.extend(glob.glob(os.path.join(data_dir, pattern)))
    img_paths.sort()  # glob order is filesystem dependent

    if not img_paths:
        raise FileNotFoundError(f"No images found under: {data_dir}")

    if img_size is None:
        img_size = IMG

    for p in img_paths[:max_samples]:
        x = load_image_as_float(p, img_size)  # (H, W, 3) float32 in [0, 1]
        x = np.expand_dims(x, 0)              # (1, H, W, 3)
        yield [x]

In [11]:
# ====== 3) INT8 post-training quantization ======
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Calibrate with representative samples under the default optimization set.
converter.representative_dataset = representative_data_gen
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Restrict the converter to int8 builtin ops. The model input is declared as
# uint8, while the output is kept as float32.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_output_type = tf.float32
converter.inference_input_type = tf.uint8


In [12]:
# Run the conversion and write the resulting model bytes to disk.
tflite_model = converter.convert()
with open("custom_vit_int8.tflite", mode="wb") as out_file:
    out_file.write(tflite_model)

Saved artifact at '/tmp/tmplls56emj'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(1, 224, 224, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(1, 5), dtype=tf.float32, name=None)
Captures:
  136769733092624: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733093776: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733094160: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733099152: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098960: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098384: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098576: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733099344: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098000: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769733098768: TensorSpec(shape=(), dtype=tf.resource, name=None)
  136769684194128: Ten

