In [1]:
import tensorflow as tf

In [89]:
# MODEL

import tensorflow.keras.layers as layers

class PreNorm(layers.Layer):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.fn = fn

    def call(self, x):
        return self.fn(self.norm(x))
    
class GELU(layers.Layer):
    def __init__(self):
        super().__init__()
        
    def call(self, x):
        return tf.keras.activations.gelu(x)
    
class FeedForward(layers.Layer):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()

        self.net = tf.keras.Sequential(
            [layers.Dense(hidden_dim),
            GELU(),
            layers.Dropout(dropout),
            layers.Dense(dim),
            layers.Dropout(dropout)]
        )

    def call(self, x):
        return self.net(x)
    
class Attention(layers.Layer):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()

        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = tf.keras.activations.softmax

        self.to_q = layers.Dense(inner_dim, use_bias=False)
        self.to_k = layers.Dense(inner_dim, use_bias=False)
        self.to_v = layers.Dense(inner_dim, use_bias=False)

        self.to_out = tf.keras.Sequential([
            layers.Dense(dim),
            layers.Dropout(dropout)]
        ) if project_out else layers.Identity()

    def call(self, x):
        b, n, _, h = *x.shape, self.heads

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)
        k = k.reshape(b, n, h, -1).permute(0, 2, 1, 3)
        v = v.reshape(b, n, h, -1).permute(0, 2, 1, 3)

        dots = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attend(dots)

        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

class Transformer(layers.Layer):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = []
        for _ in range(depth):
            self.layers.append((PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))

    def call(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
    
class ViT(layers.Layer):
    def __init__(self, window_size=(14, 300), patch_length=10, num_classes=8, dim=64, depth=1, heads=8, mlp_dim=128, pool='cls', dim_head=32, dropout=.2, emb_dropout=0., use_cls_token=True):
        super().__init__()

        channels, window_length = window_size
        num_patches = (window_length // patch_length)
        patch_dim = channels * patch_length

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        #self.patch_conv = nn.Conv1d(in_channels=channels, out_channels=dim, kernel_size=patch_length, stride=patch_length, padding=0, bias=True)
        self.patch_conv = layers.Conv1D(dim, patch_length, strides=patch_length)

        self.use_cls_token = use_cls_token
        if self.use_cls_token:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches + 1, dim))
        else:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches, dim))

        self.cls_token = nn.Parameter(torch.empty(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = tf.keras.Sequential([
            layers.LayerNorm(dim),
            layers.Dense(dim, num_classes)
            ]
        )

    def forward(self, x):
        x = self.patch_conv(x).flatten(2).transpose(-2, -1)

        b, n, _ = x.shape

        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(b, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
            x += self.pos_embedding[:, :(n + 1)]
        else :
            x += self.pos_embedding

        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        
        x = self.mlp_head(x)

        return x

In [91]:
t = tf.random.uniform((1, 10, 64))

Transformer(64, 2, 8, 32, 128)(t)

AttributeError: 
        'EagerTensor' object has no attribute 'reshape'.
        If you are looking for numpy-related methods, please run the following:
        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()

In [82]:
input = tf.random.uniform((64, 100, 14))
conv1d = tf.keras.layers.Conv1D(64, 10, strides=10)

conv1d(input).reshape()

AttributeError: 
        'EagerTensor' object has no attribute 'reshape'.
        If you are looking for numpy-related methods, please run the following:
        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()

In [None]:



class ViT(nn.Module):
    def __init__(self, window_size=(14, 300), patch_length=10, num_classes=8, dim=64, depth=1, heads=8, mlp_dim=128, pool='cls', dim_head=32, dropout=.2, emb_dropout=0., use_cls_token=True):
        super().__init__()

        channels, window_length = window_size
        num_patches = (window_length // patch_length)
        patch_dim = channels * patch_length

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_conv = nn.Conv1d(in_channels=channels, out_channels=dim, kernel_size=patch_length, stride=patch_length, padding=0, bias=True)

        self.use_cls_token = use_cls_token
        if self.use_cls_token:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches + 1, dim))
        else:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches, dim))

        self.cls_token = nn.Parameter(torch.empty(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        self._init_parameters(patch_dim)

    def _init_parameters(self, patch_dim):
        bound = 1 / (patch_dim ** .5)
        nn.init.uniform_(self.patch_conv.weight, -bound, bound)
        nn.init.uniform_(self.patch_conv.bias, -bound, bound)
        nn.init.zeros_(self.pos_embedding)
        nn.init.zeros_(self.mlp_head[1].weight)
        nn.init.zeros_(self.mlp_head[1].bias)

    def forward(self, x):
        x = self.patch_conv(x).flatten(2).transpose(-2, -1)

        b, n, _ = x.shape

        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(b, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
            x += self.pos_embedding[:, :(n + 1)]
        else :
            x += self.pos_embedding

        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        x = self.mlp_head(x)

        return x