In [45]:
from functools import partial
import jax
from jax import random
import jax.numpy as jnp
import flax
import flax.linen as nn

from typing import Union, Optional, Iterable, Callable, Tuple, List, Any
from layers import AdaptiveAveragePool2D, DropPath, to_2tuple
import warnings
from jax import grad, jit, vmap

# For testing
import torch


### Definitions

In [46]:
class Sequential(nn.Module):
    """Chain of modules applied one after another.

    The value returned by each entry of ``layers`` is passed as the single
    positional argument to the next entry.
    """

    # Modules to apply, in call order.
    layers: List[nn.Module]

    @nn.compact
    def __call__(self, x):
        """Feed ``x`` through every module in ``layers`` and return the result."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out


In [47]:
class MLP_with_DepthWiseConv(nn.Module):
    """Token MLP with a 3x3 depth-wise convolution between its two dense layers.

    The input tokens ``(B, N, C_in)`` are projected to ``hidden_features``,
    reshaped to a ``(B, H, W, hidden)`` grid, filtered with a depth-wise 3x3
    convolution, flattened back to tokens and projected to ``out_features``.
    Dropout is applied after the activation and after the output projection.
    """

    # Width of the hidden projection; falls back to the input width when unset.
    hidden_features: Optional[int] = None
    # Width of the output projection; falls back to the input width when unset.
    out_features: Optional[int] = None
    # Activation applied after the depth-wise convolution.
    act: Callable = nn.gelu
    # Dropout rate.
    drop: float = 0.0
    # When True, a ReLU is applied right before the depth-wise convolution.
    extra_relu: bool = False
    # True enables dropout (training mode); False makes dropout a no-op.
    trainable: bool = False

    @nn.compact
    def __call__(self, x, feat_size: Optional[List[int]] = None):
        """Apply the MLP to a batch of tokens.

        Args:
            x: token array of shape ``(B, N, C_in)``.
            feat_size: ``(H, W)`` of the token grid, with ``H * W == N``.

        Returns:
            Array of shape ``(B, N, out_features)``.

        Raises:
            ValueError: if ``feat_size`` is missing or does not match ``N``.
        """
        if feat_size is None:
            raise ValueError("feat_size (H, W) must be provided.")

        in_features = x.shape[-1]
        out_features = self.out_features or in_features
        hidden_features = self.hidden_features or in_features
        # One Dropout module reused at both drop sites below.
        drop = nn.Dropout(rate=self.drop, deterministic=not self.trainable)

        x = nn.Dense(hidden_features)(x)
        B, N, C = x.shape
        H, W = feat_size
        if H * W != N:
            raise ValueError(
                f"feat_size {tuple(feat_size)} does not match {N} input tokens."
            )

        # Tokens -> channels-last grid so the convolution sees spatial neighbours.
        x = x.reshape(B, H, W, C)
        # Honour the `extra_relu` flag instead of always applying the ReLU.
        if self.extra_relu:
            x = nn.relu(x)
        # feature_group_count == channels makes this a depth-wise convolution;
        # flax's default "SAME" padding keeps the H x W grid unchanged.
        x = nn.Conv(
            hidden_features,
            kernel_size=(3, 3),
            use_bias=True,
            feature_group_count=hidden_features,
        )(x)
        # Grid -> tokens.
        x = x.reshape(B, -1, x.shape[3])
        x = self.act(x)
        x = drop(x)
        x = nn.Dense(out_features)(x)
        x = drop(x)

        return x


In [48]:
class Attention(nn.Module):
    """Multi-head attention whose keys/values may come from a reduced token grid.

    Queries are always computed from every input token. Keys and values are
    computed from:

    * the pooled, 1x1-convolved, normalised and GELU-activated grid when
      ``linear`` is True;
    * the grid down-sampled by a strided convolution and normalised when
      ``linear`` is False and ``sr_ratio > 1``;
    * the input tokens themselves otherwise.
    """

    # Embedding width; must be divisible by `num_heads`.
    dim: int
    # Number of attention heads.
    num_heads: int = 8
    # Kernel size and stride of the key/value reduction conv (used when > 1).
    sr_ratio: int = 1
    # Selects the pooling-based key/value reduction.
    linear: bool = False
    # Whether the q and kv projections carry a bias.
    qkv_bias: bool = True
    # Dropout rate on the attention weights.
    attn_drop: float = 0.0
    # Dropout rate on the output projection.
    proj_drop: float = 0.0
    # True enables dropout (training mode); False makes dropout a no-op.
    trainable: bool = False

    @nn.compact
    def __call__(self, x, feat_size: Optional[List[int]] = None):
        """Attend over a batch of tokens.

        Args:
            x: token array of shape ``(B, N, C)``; ``C`` is expected to equal
                ``dim`` (the final reshape relies on it).
            feat_size: ``(H, W)`` of the token grid.

        Returns:
            Array of shape ``(B, N, C)``.

        Raises:
            ValueError: if ``feat_size`` is not provided.
        """
        assert (
            self.dim % self.num_heads == 0
        ), f"Input dim {self.dim} should be dividable by num_heads {self.num_heads}."
        if feat_size is None:
            raise ValueError("feat_size (H, W) must be provided.")

        head_dim = self.dim // self.num_heads
        scale = head_dim**-0.5

        # Sub-modules are created in a fixed order (proj, q, kv) so the
        # auto-generated parameter names do not depend on the chosen branch.
        attn_drop = nn.Dropout(self.attn_drop, deterministic=not self.trainable)
        proj = nn.Dense(self.dim)
        proj_drop = nn.Dropout(self.proj_drop, deterministic=not self.trainable)

        B, N, C = x.shape
        H, W = feat_size

        # (B, N, dim) -> (B, heads, N, head_dim)
        q = nn.Dense(self.dim, use_bias=self.qkv_bias)(x)
        q = jnp.transpose(q.reshape(B, N, self.num_heads, -1), (0, 2, 1, 3))

        # Pick the tokens that keys and values are computed from.
        kv_tokens = x
        if self.linear:
            # presumably pools the grid to a fixed 7x7 size — helper lives in `layers`
            pool = AdaptiveAveragePool2D(7)
            sr = nn.Conv(self.dim, kernel_size=(1, 1), strides=1)
            norm = nn.LayerNorm()
            kv_tokens = sr(pool(x.reshape(B, H, W, C))).reshape(B, -1, C)
            kv_tokens = nn.gelu(norm(kv_tokens))
        elif self.sr_ratio > 1:
            # NOTE(review): flax's default padding is "SAME"; with
            # kernel == stride this adds no padding only when H and W are
            # multiples of sr_ratio — confirm that holds for all inputs.
            sr = nn.Conv(
                self.dim,
                kernel_size=(self.sr_ratio, self.sr_ratio),
                strides=self.sr_ratio,
            )
            norm = nn.LayerNorm()
            kv_tokens = norm(sr(x.reshape(B, H, W, C)).reshape(B, -1, C))

        # Single fused projection: (B, M, 2*dim) -> (2, B, heads, M, head_dim)
        kv = nn.Dense(self.dim * 2, use_bias=self.qkv_bias)(kv_tokens)
        kv = jnp.transpose(
            kv.reshape(B, -1, 2, self.num_heads, head_dim), (2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        # Scaled dot-product attention over the (possibly reduced) key set.
        attn = (q @ jnp.swapaxes(k, -2, -1)) * scale
        attn = nn.softmax(attn, axis=-1)
        attn = attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, heads * head_dim)
        x = jnp.swapaxes(attn @ v, 1, 2).reshape(B, N, C)
        x = proj(x)
        x = proj_drop(x)

        return x


In [49]:
class Block(nn.Module):
    """Pre-norm transformer block: attention and conv-MLP, each with a residual.

    Computes ``x + drop_path(attn(norm1(x)))`` followed by
    ``x + drop_path(mlp(norm2(x)))``. Stochastic depth is only instantiated
    when ``drop_path > 0``.
    """

    # Embedding width of the tokens.
    dim: int
    # Number of attention heads.
    num_heads: int
    # Hidden width of the MLP as a multiple of `dim`.
    mlp_ratio: float = 4.0
    # Key/value reduction ratio forwarded to the attention layer.
    sr_ratio: int = 1
    # Forwarded to the attention layer as `linear` and to the MLP as `extra_relu`.
    linear: bool = False
    # Whether the attention q/kv projections carry a bias.
    qkv_bias: bool = False
    # Dropout rate for the attention output projection and the MLP.
    drop: float = 0.0
    # Dropout rate on the attention weights.
    attn_drop: float = 0.0
    # Stochastic-depth rate; 0 disables it.
    drop_path: float = 0.0
    # Activation used inside the MLP.
    act: Callable = nn.gelu
    # Factory called once per pre-norm so each gets its own parameters.
    # A Module *instance* is still accepted for backward compatibility, but it
    # is then shared by both pre-norms (a single parameter set).
    norm_layer: Callable = nn.LayerNorm
    # True enables dropout / stochastic depth (training mode).
    trainable: bool = False

    @nn.compact
    def __call__(self, x, feat_size: Optional[List[int]] = None):
        """Apply the block to a batch of tokens.

        Args:
            x: token array of shape ``(B, N, dim)``.
            feat_size: ``(H, W)`` of the token grid.

        Returns:
            Array with the same shape as ``x``.

        Raises:
            ValueError: if ``feat_size`` is not provided.
        """
        if feat_size is None:
            raise ValueError("feat_size (H, W) must be provided.")

        # The children turn `trainable` into `deterministic=not trainable`
        # themselves, so the flag is forwarded unchanged; negating it here
        # would switch dropout on in inference mode.
        attn = Attention(
            dim=self.dim,
            num_heads=self.num_heads,
            sr_ratio=self.sr_ratio,
            linear=self.linear,
            qkv_bias=self.qkv_bias,
            attn_drop=self.attn_drop,
            proj_drop=self.drop,
            trainable=self.trainable,
        )

        mlp = MLP_with_DepthWiseConv(
            hidden_features=int(self.dim * self.mlp_ratio),
            act=self.act,
            drop=self.drop,
            extra_relu=self.linear,
            trainable=self.trainable,
        )

        if isinstance(self.norm_layer, nn.Module):
            # Legacy: an already-built module is reused for both pre-norms.
            norm1 = norm2 = self.norm_layer
        else:
            norm1 = self.norm_layer()
            norm2 = self.norm_layer()

        if self.drop_path > 0.0:
            # NOTE(review): DropPath is defined in `layers`; its `trainable`
            # flag is assumed to mean "training mode" like the sibling
            # modules in this file — confirm.
            drop_path = DropPath(self.drop_path, trainable=self.trainable)
            x = x + drop_path(attn(norm1(x), feat_size))
            x = x + drop_path(mlp(norm2(x), feat_size))
        else:
            x = x + attn(norm1(x), feat_size)
            x = x + mlp(norm2(x), feat_size)

        return x


In [50]:
class OverlapPatchEmbed(nn.Module):
    """Strided-convolution patch embedding followed by a LayerNorm.

    The convolution kernel must be larger than its stride; the resulting
    feature grid is flattened into a token sequence.
    """

    # Convolution kernel size (expanded to a 2-tuple with `to_2tuple`).
    patch_size: int = 7
    # Convolution stride.
    strides: int = 4
    # Number of output channels per token.
    embed_dim: int = 768

    @nn.compact
    def __call__(self, x):
        """Embed ``x`` and return ``(tokens, feat_size)``.

        ``feat_size`` is axes 1 and 2 of the convolution output; ``tokens`` is
        that output with those two axes merged, then layer-normalised.
        """
        kernel = to_2tuple(self.patch_size)
        assert (
            max(kernel) > self.strides
        ), "Patch size should be larger than stride."
        layer_norm = nn.LayerNorm()

        # Half-kernel padding along each spatial axis.
        half_kernel = (kernel[0] // 2, kernel[1] // 2)
        embed = nn.Conv(
            self.embed_dim,
            kernel_size=kernel,
            strides=self.strides,
            padding=half_kernel,
        )
        grid = embed(x)

        feat_size = grid.shape[1:3]
        # Merge the two spatial axes into one token axis.
        tokens = grid.reshape(grid.shape[0], -1, grid.shape[3])

        return layer_norm(tokens), feat_size


In [51]:
class PyramidVisionTransformerStage(nn.Module):
    """Placeholder for a Pyramid Vision Transformer stage (not implemented yet).

    The original cell ended at the ``__call__`` signature, which is a syntax
    error; the explicit stub keeps the notebook runnable top to bottom.
    """

    # Embedding width of the stage.
    dim: int

    @nn.compact
    def __call__(self, x):
        """Not implemented; always raises ``NotImplementedError``."""
        # TODO: implement the stage forward pass.
        raise NotImplementedError(
            "PyramidVisionTransformerStage is not implemented yet."
        )

### Test

In [62]:
# Dummy input: 2 samples, one token per cell of a 64x64 grid, 64 channels.
feat_size = [64, 64]
x = jnp.zeros((2, feat_size[0] * feat_size[1], 64))


In [63]:
drop, key = random.split(random.PRNGKey(0), 2)
# `num_heads` has no default on Block and must be passed explicitly; a single
# head keeps the N x N attention matrix small for this smoke test.
layer = Block(dim=64, num_heads=1, trainable=False)
params = layer.init({"params": key, "dropout": drop}, x, feat_size)["params"]
out = layer.apply({"params": params}, x, feat_size, rngs={"dropout": drop})

# The block only adds residual branches, so the token shape must be preserved.
assert out.shape == x.shape, f"unexpected output shape {out.shape}"
