In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ReduceLROnPlateau
import matplotlib.pyplot as plt

# Configuration for data loading / augmentation.
SEED = 42
IMG_SIZE = 72          # CIFAR-10 images (32x32) are upsampled to IMG_SIZE x IMG_SIZE
NUM_CLASSES = 10
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2

tf.keras.utils.set_random_seed(SEED)  # seeds Python, NumPy and TensorFlow RNGs

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Upsample to IMG_SIZE and scale pixel values to [0, 1].
x_train = tf.image.resize(tf.cast(x_train, tf.float32), [IMG_SIZE, IMG_SIZE]) / 255.0  # (50000, 72, 72, 3)
x_test = tf.image.resize(tf.cast(x_test, tf.float32), [IMG_SIZE, IMG_SIZE]) / 255.0    # (10000, 72, 72, 3)
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

# Augmentation is applied to the training subset only. Both generators share the
# same validation_split, so `subset='training'` / `subset='validation'` select
# complementary slices of x_train, but validation images are left unaugmented.
dataget = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    validation_split=VALIDATION_SPLIT,
)
val_datagen = ImageDataGenerator(validation_split=VALIDATION_SPLIT)

train_generator = dataget.flow(
    x_train, y_train, batch_size=BATCH_SIZE, subset='training', seed=SEED
)
validation_generator = val_datagen.flow(
    x_train, y_train, batch_size=BATCH_SIZE, subset='validation', shuffle=False
)

# Peek at one batch of each. Indexing does not advance the iterators, unlike
# pulling a batch with a for-loop.
x_batch, y_batch = train_generator[0]
print(f"Train Generator: {x_batch.shape}, {y_batch.shape}")
x_batch, y_batch = validation_generator[0]
print(f"Validation Generator: {x_batch.shape}, {y_batch.shape}")

In [None]:
def window_partition(x, window_size):
  """Split a (B, H, W, C) feature map into flattened, non-overlapping windows.

  The map is zero-padded at the bottom/right so that both spatial dimensions
  are multiples of ``window_size`` before it is cut into windows.

  Args:
    x: Tensor of shape (B, H, W, C).
    window_size: Side length of each square window.

  Returns:
    A tuple ``(windows, H_padded, W_padded)``:
      windows: Tensor of shape (B * num_windows, window_size ** 2, C), with the
        batch index outermost and windows in row-major order within an image.
      H_padded, W_padded: Spatial size after padding (scalar tensors).
  """
  dims = tf.shape(x)
  batch, height, width, channels = dims[0], dims[1], dims[2], dims[3]

  # Padding needed to reach the next multiple of window_size (0 if already aligned).
  extra_h = (window_size - height % window_size) % window_size
  extra_w = (window_size - width % window_size) % window_size
  padded = tf.pad(x, [[0, 0], [0, extra_h], [0, extra_w], [0, 0]])
  h_pad = height + extra_h
  w_pad = width + extra_w

  rows = h_pad // window_size
  cols = w_pad // window_size
  grid = tf.reshape(padded, [batch, rows, window_size, cols, window_size, channels])
  # Bring the two window-grid axes together: (B, rows, cols, ws, ws, C).
  grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
  windows = tf.reshape(grid, [-1, window_size * window_size, channels])

  return windows, h_pad, w_pad

def window_reverse(windows, window_size, H_padded, W_padded, H, W, C):
  """Inverse of ``window_partition``: reassemble windows into a feature map.

  Args:
    windows: Tensor of shape (B * num_windows, window_size ** 2, C), ordered
      as produced by ``window_partition``.
    window_size: Side length of each square window.
    H_padded, W_padded: Padded spatial size returned by ``window_partition``
      (both multiples of ``window_size``).
    H, W: Spatial size to crop the reassembled map to.
    C: Channel dimension.

  Returns:
    Tensor of shape (B, H, W, C).
  """
  rows = H_padded // window_size
  cols = W_padded // window_size
  batch = tf.shape(windows)[0] // (rows * cols)

  grid = tf.reshape(windows, [batch, rows, cols, window_size, window_size, C])
  # Interleave window rows with in-window rows (and likewise for columns).
  grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
  feature_map = tf.reshape(grid, [batch, H_padded, W_padded, C])
  # Drop the padding that window_partition added.
  return feature_map[:, :H, :W, :]

In [None]:
class WindowAttention(layers.Layer):
  """Multi-head self-attention computed independently within each window.

  Input tokens have shape (B_, window_size ** 2, dim), where B_ is batch size
  times windows per image (e.g. the output of ``window_partition``). The output
  has the same shape.

  Args:
    dim: Channel dimension of each token; must be divisible by ``num_heads``.
    window_size: Side length of the square window; every input sequence must
      hold exactly ``window_size ** 2`` tokens.
    num_heads: Number of attention heads.
    **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).

  NOTE(review): no relative position bias is added to the attention logits,
  unlike the original Swin Transformer.
  """

  def __init__(self, dim, window_size, num_heads, **kwargs):
    super(WindowAttention, self).__init__(**kwargs)
    self.dim = dim
    self.window_size = window_size
    self.num_heads = num_heads
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    self.head_dim = dim // num_heads
    # Plain Python float, so it scales float16 (mixed precision) as well as
    # float32 queries; a float32 tensor here would cause a dtype mismatch.
    self.scale = self.head_dim ** -0.5

    self.query = layers.Dense(dim, use_bias=False)
    self.key = layers.Dense(dim, use_bias=False)
    self.value = layers.Dense(dim, use_bias=False)
    self.proj = layers.Dense(dim)

  def call(self, x, mask=None):
    """Apply window attention.

    Args:
      x: Tensor of shape (B_, N, C) with N == window_size ** 2 and C == dim.
      mask: Optional additive mask broadcastable to (B_, num_heads, N, N); it is
        added to the attention logits before the softmax.

    Returns:
      Tensor of shape (B_, N, C).
    """
    B_, N, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]

    # assert_equal reports the actual runtime values itself. Formatting the
    # tensors C / N into the message would only print a tensor repr.
    tf.debugging.assert_equal(
        C, self.dim,
        message=f"Input channel dimension doesn't match layer dimension ({self.dim}).")
    tf.debugging.assert_equal(
        N, self.window_size ** 2,
        message=f"Input sequence length doesn't match window_size ** 2 ({self.window_size ** 2}).")

    def split_heads(t):
      # (B_, N, C) -> (B_, num_heads, N, head_dim)
      t = tf.reshape(t, [B_, N, self.num_heads, self.head_dim])
      return tf.transpose(t, [0, 2, 1, 3])

    q = split_heads(self.query(x)) * self.scale
    k = split_heads(self.key(x))
    v = split_heads(self.value(x))

    attn = tf.matmul(q, k, transpose_b=True)  # (B_, num_heads, N, N)
    if mask is not None:
      attn += tf.cast(mask, attn.dtype)
    attn = tf.nn.softmax(attn, axis=-1)

    out = tf.matmul(attn, v)  # (B_, num_heads, N, head_dim)
    out = tf.transpose(out, [0, 2, 1, 3])
    out = tf.reshape(out, [B_, N, C])
    return self.proj(out)

In [None]:
class SwinTransformerBlock(layers.Layer):
  """One Swin Transformer block: (shifted-)window attention followed by an MLP.

  Both sub-layers are pre-norm residual branches:
      x = x + WindowAttention(LayerNorm(x))
      x = x + MLP(LayerNorm(x))

  With ``shift_size > 0`` (SW-MSA) the normalized, padded map is cyclically
  rolled by ``-shift_size`` along height and width before windowing. An
  additive mask stops tokens that were not neighbours before the roll from
  attending to each other, and the roll is undone afterwards.

  Args:
    dim: Channel dimension C of the (B, H, W, C) input.
    num_heads: Attention heads; must divide ``dim``.
    window_size: Side length of the square attention windows.
    shift_size: Cyclic shift for SW-MSA. 0 means plain window attention.
      Must satisfy ``0 <= shift_size < window_size``.
    mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    drop_rate: Dropout rate inside the MLP.
    **kwargs: Forwarded to ``layers.Layer``.

  NOTE(review): maps smaller than ``window_size`` are zero-padded up to a full
  window, and the padded tokens take part in attention.
  """

  def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0,
               drop_rate=0.4, **kwargs):
    super(SwinTransformerBlock, self).__init__(**kwargs)
    if not 0 <= shift_size < window_size:
      raise ValueError(
          f"shift_size must be in [0, window_size), got {shift_size} for window_size={window_size}")
    self.dim = dim
    self.num_heads = num_heads
    self.window_size = window_size
    self.shift_size = shift_size
    self.mlp_ratio = mlp_ratio

    self.norm1 = layers.LayerNormalization(epsilon=1e-5)
    self.norm2 = layers.LayerNormalization(epsilon=1e-5)
    self.attn = WindowAttention(dim, window_size, num_heads)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = models.Sequential([
        layers.Dense(mlp_hidden_dim, activation='gelu'),
        layers.Dropout(drop_rate),
        layers.Dense(dim),
        layers.Dropout(drop_rate),
    ])

  def _effective_shift(self, x):
    """Shift to use for input ``x``: 0 when the whole map fits in one window.

    Rolling a map that fits in a single window would only split it into
    mutually masked regions, so shifting is skipped. The check uses static
    shapes; if they are unknown, ``shift_size`` is used as-is.
    """
    height, width = x.shape[1], x.shape[2]
    if (height is not None and width is not None
        and height <= self.window_size and width <= self.window_size):
      return 0
    return self.shift_size

  def _shifted_window_mask(self, H_padded, W_padded, shift, dtype):
    """Additive attention mask for shifted windows, shape (num_windows, N, N).

    After the roll, windows at the bottom/right edge contain tokens from up to
    three regions per axis that were not adjacent before the roll. Token pairs
    from different regions get -100 (about zero weight after the softmax).
    Pairs from the same region get 0.
    """
    ws = self.window_size

    def region_ids(length):
      # 0 on [0, length - ws), 1 on [length - ws, length - shift), 2 on [length - shift, length)
      idx = tf.range(length)
      return tf.where(idx < length - ws, 0, tf.where(idx < length - shift, 1, 2))

    ids = region_ids(H_padded)[:, None] * 3 + region_ids(W_padded)[None, :]  # (Hp, Wp)
    id_windows, _, _ = window_partition(ids[None, :, :, None], ws)           # (nW, N, 1)
    id_windows = tf.squeeze(id_windows, axis=-1)                             # (nW, N)
    same_region = tf.equal(id_windows[:, :, None], id_windows[:, None, :])   # (nW, N, N)
    return tf.where(same_region, tf.constant(0.0, dtype), tf.constant(-100.0, dtype))

  def call(self, x):
    """Apply the block to a (B, H, W, C) tensor; the output has the same shape."""
    shortcut = x
    H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
    ws = self.window_size

    x = self.norm1(x)
    # Pad bottom/right to a multiple of the window size. Padding happens before
    # the roll, so the roll and the mask both act on the padded map.
    pad_h = (ws - H % ws) % ws
    pad_w = (ws - W % ws) % ws
    x = tf.pad(x, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
    H_padded, W_padded = H + pad_h, W + pad_w

    shift = self._effective_shift(shortcut)
    mask = None
    if shift > 0:
      x = tf.roll(x, shift=[-shift, -shift], axis=[1, 2])
      window_mask = self._shifted_window_mask(H_padded, W_padded, shift, x.dtype)
      # window_partition puts the batch index outermost, so tiling the
      # per-image mask B times lines it up with x_windows. The singleton axis
      # broadcasts over attention heads.
      mask = tf.tile(window_mask[:, None], [tf.shape(x)[0], 1, 1, 1])

    x_windows, _, _ = window_partition(x, ws)
    attn_windows = self.attn(x_windows, mask=mask)
    # Reassemble at the padded size so the roll can be undone before cropping.
    x = window_reverse(attn_windows, ws, H_padded, W_padded, H_padded, W_padded, C)
    if shift > 0:
      x = tf.roll(x, shift=[shift, shift], axis=[1, 2])
    x = x[:, :H, :W, :]

    x = x + shortcut
    x = x + self.mlp(self.norm2(x))
    return x

In [None]:
class SwinTransformer(models.Model):
  """Simplified Swin Transformer image classifier.

  Pipeline:
  1. A strided Conv2D patch embedding, then dropout.
  2. Stages of ``SwinTransformerBlock``s. Between stages, a 2x2/stride-2
     Conv2D halves the spatial size and doubles the channels.
  3. LayerNorm, then global average pooling, then a softmax classifier.

  Args:
    img_size: Unused; the input resolution is taken from the data. Kept for
      backward compatibility.
    patch_size: Kernel size and stride of the patch-embedding convolution.
    num_classes: Number of output classes (softmax probabilities are returned).
    depths: Number of Swin blocks in each stage.
    num_heads: Attention heads in each stage; one entry per stage.
    embed_dim: Channel width of the first stage (doubled after each stage).
    window_size: Attention window side length used by every block.
    mlp_ratio: MLP hidden width as a multiple of the stage width.
    drop_rate: Dropout rate applied to the patch embeddings.
    **kwargs: Forwarded to ``models.Model``.
  """

  def __init__(self, img_size=32, patch_size=4, num_classes=10, depths=(2, 2, 6, 2),
               num_heads=(3, 6, 12, 24), embed_dim=96, window_size=7, mlp_ratio=4.0,
               drop_rate=0.2, **kwargs):
    super(SwinTransformer, self).__init__(**kwargs)
    if len(depths) != len(num_heads):
      raise ValueError(
          f"depths and num_heads need one entry per stage, got {len(depths)} and {len(num_heads)}")

    self.patch_embed = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size, padding='valid')
    self.pos_drop = layers.Dropout(drop_rate)
    self.blocks = []
    dim = embed_dim
    for stage, (depth, heads) in enumerate(zip(depths, num_heads)):
      for j in range(depth):
        # Alternate regular and shifted windows; the shift is half a window.
        shift_size = 0 if j % 2 == 0 else window_size // 2
        self.blocks.append(
            SwinTransformerBlock(
                dim=dim,
                num_heads=heads,
                window_size=window_size,
                shift_size=shift_size,
                mlp_ratio=mlp_ratio,
            )
        )
      if stage < len(depths) - 1:
        dim *= 2
        # Patch merging. 'same' padding keeps the last row/column of odd-sized
        # maps (e.g. 9x9 -> 5x5) instead of dropping them as 'valid' would.
        self.blocks.append(layers.Conv2D(dim, kernel_size=2, strides=2, padding='same'))
    self.norm = layers.LayerNormalization(epsilon=1e-6)
    self.avgpool = layers.GlobalAveragePooling2D()
    self.head = layers.Dense(num_classes, activation='softmax')

  def call(self, x):
    """Map a (B, H, W, 3) image batch to (B, num_classes) class probabilities."""
    x = self.patch_embed(x)
    x = self.pos_drop(x)
    for block in self.blocks:
      x = block(x)
    x = self.norm(x)
    x = self.avgpool(x)
    return self.head(x)

In [None]:
# Smoke test: run the building blocks on a random batch and check shapes and
# values before training the full model.
sample_input = tf.random.normal([64, 32, 32, 3])

patch_embed_output = layers.Conv2D(96, kernel_size=4, strides=4, padding='valid')(sample_input)
print(f"Patch Embedding Output Shape: {patch_embed_output.shape}")

windows, H_padded, W_padded = window_partition(patch_embed_output, 7)
print(f"Windowed Patches Shape: {windows.shape}")

feature_shape = tf.shape(patch_embed_output)
restored_output = window_reverse(windows, window_size=7, H_padded=H_padded, W_padded=W_padded,
                                 H=feature_shape[1], W=feature_shape[2], C=feature_shape[3])
print(f"Restored Output Shape: {restored_output.shape}")
# Partition followed by reverse is pure pad/reshape/crop, so it must be an exact identity.
tf.debugging.assert_equal(restored_output, patch_embed_output,
                          message="window_reverse did not invert window_partition")

swin_block = SwinTransformerBlock(dim=96, num_heads=3, window_size=7)
block_output = swin_block(patch_embed_output)
print(f"Swin Transformer Block Output Shape: {block_output.shape}")

shifted_block = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=3)
shifted_output = shifted_block(patch_embed_output)
print(f"Shifted Swin Transformer Block Output Shape: {shifted_output.shape}")

assert block_output.shape == patch_embed_output.shape, "Swin block must preserve the feature-map shape"
assert shifted_output.shape == patch_embed_output.shape, "Shifted Swin block must preserve the feature-map shape"

In [None]:
from tensorflow.keras.losses import CategoricalCrossentropy

# Inputs are the 72x72 images produced by the data cell.
model = SwinTransformer(img_size=72, patch_size=4, num_classes=10)

# Labels are one-hot (to_categorical) and the head outputs softmax probabilities,
# so CategoricalCrossentropy is used with its default from_logits=False.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4, weight_decay=1e-5),
    loss=CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy'],
)

# Halve the learning rate when the monitored metric plateaus for 5 epochs.
lr_reducer = ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-6)

history = model.fit(
    train_generator,
    epochs=100,
    validation_data=validation_generator,
    callbacks=[lr_reducer],
)

test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")

In [None]:
# Training curves: per-epoch loss and accuracy on the training and validation sets.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
epoch_numbers = range(1, len(history.history['loss']) + 1)
for ax, metric, label in ((loss_ax, 'loss', 'Loss'), (acc_ax, 'accuracy', 'Accuracy')):
    ax.plot(epoch_numbers, history.history[metric], label=f'Training {label}')
    ax.plot(epoch_numbers, history.history[f'val_{metric}'], label=f'Validation {label}')
    ax.set(title=f'{label} Progress', xlabel='Epoch', ylabel=label)
    ax.legend()
fig.tight_layout()
plt.show()