In [1]:
import tensorflow as tf
import numpy as np

In [2]:
class BaseAttention(tf.keras.layers.Layer):
    """Self-attention wrapper: one tensor serves as query, key and value.

    Args:
        num_heads: number of heads of the inner MultiHeadAttention.
        key_dim: per-head query/key projection size.
        attention_axes: input axes over which attention is computed.
    """

    def __init__(self, num_heads=2, key_dim=8, attention_axes=(2, 3, 4)):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            attention_axes=attention_axes,
        )

    def call(self, x):
        # x attends to itself.
        return self.mha(value=x, key=x, query=x)

In [3]:
class MLP(tf.keras.layers.Layer):
    """Two stacked Dense layers, both with ReLU activation.

    The hidden width is int(mlp_ratio * last input dim), resolved lazily in
    build(); the output width is output_dims. input_dims is stored but not
    otherwise used.
    """

    def __init__(self, input_dims, output_dims, mlp_ratio=1):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.mlp_ratio = mlp_ratio

    def build(self, input_shape):
        hidden_units = int(self.mlp_ratio * input_shape[-1])
        self.dense1 = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.dense2 = tf.keras.layers.Dense(self.output_dims, activation='relu')

    def call(self, x):
        return self.dense2(self.dense1(x))
    

In [4]:
def normal_window_partition(x, window_size):
    """Split a video tensor into non-overlapping (wt, wh, ww) windows.

    Args:
        x: tensor or array indexed as x[b, t, h, w, c].
        window_size: (wt, wh, ww) window extent along T, H and W.

    Returns:
        Tensor of shape (num_windows, B, wt, wh, ww, C). Windows are ordered
        with t outermost and w innermost. Trailing positions along T, H or W
        that do not fill a whole window are dropped.
    """
    x = tf.convert_to_tensor(x)
    _, T, H, W, C = x.shape
    wt, wh, ww = window_size
    nt, nh, nw = T // wt, H // wh, W // ww
    # Batch may be unknown (None) in graph mode, so read it dynamically.
    batch = tf.shape(x)[0]

    # Drop any remainder that would not fill a complete window.
    x = x[:, :nt * wt, :nh * wh, :nw * ww]
    # Split each axis into (window index, offset inside window) ...
    x = tf.reshape(x, [batch, nt, wt, nh, wh, nw, ww, C])
    # ... move the three window-index axes in front of the batch axis ...
    x = tf.transpose(x, (1, 3, 5, 0, 2, 4, 6, 7))
    # ... and flatten them into a single window axis (t-major order).
    return tf.reshape(x, [nt * nh * nw, batch, wt, wh, ww, C])

def reverse_window_partition(windows, original_shape, window_size):
    """Inverse of normal_window_partition: stitch windows back into a video.

    Built from TensorFlow ops only, so gradients flow through it and it works
    on symbolic tensors and with an unknown batch size.

    Args:
        windows: tensor/array of shape (num_windows, B, wt, wh, ww, C) in the
            t-major order produced by normal_window_partition. Only the first
            (T//wt)*(H//wh)*(W//ww) windows are used.
        original_shape: (B, T, H, W, C) of the video. B is not read and may be None.
        window_size: (wt, wh, ww) window extent along T, H and W.

    Returns:
        float32 tensor of shape (B, T, H, W, C). Positions not covered by a
        complete window are zero.
    """
    B, T, H, W, C = original_shape
    wt, wh, ww = window_size
    nt, nh, nw = T // wt, H // wh, W // ww

    windows = tf.cast(tf.convert_to_tensor(windows), tf.float32)
    windows = windows[:nt * nh * nw]
    batch = tf.shape(windows)[1]

    # (t, h, w window index, batch, offsets..., C) -> interleave index/offset per axis.
    video = tf.reshape(windows, [nt, nh, nw, batch, wt, wh, ww, C])
    video = tf.transpose(video, (3, 0, 4, 1, 5, 2, 6, 7))
    video = tf.reshape(video, [batch, nt * wt, nh * wh, nw * ww, C])

    # Zero-fill the remainder so the result always has the full (T, H, W) extent.
    paddings = [[0, 0], [0, T - nt * wt], [0, H - nh * wh], [0, W - nw * ww], [0, 0]]
    return tf.pad(video, paddings)

In [5]:
import matplotlib.pyplot as plt

In [6]:
# Shape check: MultiHeadAttention with attention_axes=(2, 3, 4) on a rank-6
# symbolic input, the same rank as the window tensors BaseAttention receives.
x = tf.zeros((32, 1, 4, 4, 3))

layer = tf.keras.layers.MultiHeadAttention(
    num_heads=2,
    key_dim=2,
    attention_axes=(2, 3, 4),
)
# shape=x.shape keeps all five dims, so Input prepends a batch axis (rank 6).
input_tensor = tf.keras.Input(shape=x.shape)
output_tensor, scores = layer(
    query=input_tensor,
    value=input_tensor,
    return_attention_scores=True,
)
print(output_tensor.shape, scores.shape)

(None, 32, 1, 4, 4, 3) (None, 32, 2, 1, 4, 4, 1, 4, 4)


In [7]:
class WindowMSA(tf.keras.layers.Layer):
    """Windowed self-attention over space, then over time, followed by an MLP.

    For an input indexed [b, t, h, w, c] the layer:
      1. applies LayerNormalization,
      2. runs self-attention inside non-overlapping (1, MH, MW) windows, with
         MH = H // 16 and MW = W // 24, each clamped to at least 1,
      3. runs self-attention inside (T, 1, 1) windows (all frames of a pixel),
      4. applies a second LayerNormalization and an MLP whose width is C.
    No residual connections are used.

    Args:
        num_heads: heads of both attention layers.
        key_dim: per-head key size of both attention layers.
    """

    def __init__(self, num_heads=2, key_dim=8):
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim

    def build(self, input_shape):
        self.layer_norm1 = tf.keras.layers.LayerNormalization()
        self.layer_norm2 = tf.keras.layers.LayerNormalization()
        self.spatial_msa = BaseAttention(self.num_heads, self.key_dim)
        self.temporal_msa = BaseAttention(self.num_heads, self.key_dim)
        self.mlp = MLP(input_shape[-1], input_shape[-1])

    def call(self, x):
        B, T, H, W, C = x.shape
        # H // 16 or W // 24 is 0 for small feature maps (e.g. deeper
        # DeepEmbedding stages), which would make the window partition divide
        # by zero; fall back to 1-pixel windows along that axis instead.
        MH, MW = max(1, H // 16), max(1, W // 24)

        self.temporal_window = (T, 1, 1)
        self.spatial_window = (1, MH, MW)

        x = self.layer_norm1(x)

        # Spatial stage: the partition stacks windows on axis 0; swap it with
        # the batch axis so attention_axes (2, 3, 4) span each window's (t, h, w).
        windows = normal_window_partition(x, self.spatial_window)
        windows = tf.transpose(windows, (1, 0, 2, 3, 4, 5))
        x = self.spatial_msa(windows)
        x = tf.transpose(x, (1, 0, 2, 3, 4, 5))
        x = reverse_window_partition(x, (B, T, H, W, C), self.spatial_window)

        # Temporal stage: same pattern with (T, 1, 1) windows.
        windows = normal_window_partition(x, self.temporal_window)
        windows = tf.transpose(windows, (1, 0, 2, 3, 4, 5))
        x = self.temporal_msa(windows)
        x = tf.transpose(x, (1, 0, 2, 3, 4, 5))
        x = reverse_window_partition(x, (B, T, H, W, C), self.temporal_window)

        x = self.layer_norm2(x)
        return self.mlp(x)

In [8]:
def shifted_window_partition(x,window_size):
    """Cut a video into a 3x3 grid of corner, border and middle regions.

    The corner extent is 4*window_size[0] rows by 4*window_size[1] columns.
    Every region is x[:, :, rows, cols] with a singleton axis inserted at
    position 2, i.e. (B, T, 1, rows, cols, C). Regions of equal shape are
    concatenated along axis 1.

    Args:
        x: tensor indexed as x[b, t, h, w, c].
        window_size: (MH, MW); only the first two entries are used.

    Returns:
        (corners, top_bottom, left_right, middle):
          corners    -- top-left, top-right, bottom-left, bottom-right (axis 1 = 4*T)
          top_bottom -- top-middle then bottom-middle strip (axis 1 = 2*T)
          left_right -- left-middle then right-middle strip (axis 1 = 2*T)
          middle     -- centre region (axis 1 = T)
    """
    scaled = [4 * i for i in window_size]
    a, b = scaled[0], scaled[1]

    top, bottom, inner_rows = slice(None, a), slice(-a, None), slice(a, -a)
    left, right, inner_cols = slice(None, b), slice(-b, None), slice(b, -b)

    def region(rows, cols):
        # Singleton axis at 2 so attention_axes (2, 3, 4) cover (1, rows, cols).
        return x[:, :, tf.newaxis, rows, cols]

    corners = tf.concat(
        [region(top, left), region(top, right), region(bottom, left), region(bottom, right)],
        axis=1,
    )
    top_bottom = tf.concat([region(top, inner_cols), region(bottom, inner_cols)], axis=1)
    left_right = tf.concat([region(inner_rows, left), region(inner_rows, right)], axis=1)
    middle = region(inner_rows, inner_cols)
    return corners, top_bottom, left_right, middle

In [9]:
def reverse_shifted_window_partition(x,orginal_shape,window_size):
    """Reassemble the region groups of shifted_window_partition into a video.

    Built from TensorFlow ops only, so gradients flow back to whatever
    produced the regions, and an unknown batch size is fine.

    Args:
        x: 4-sequence (corners, top_bottom, left_right, middle) as returned by
            shifted_window_partition (possibly after attention); each entry has
            a singleton axis at position 2.
        orginal_shape: (B, T, H, W, C) of the video. B is not read and may be None.
        window_size: (MH, MW); corners are 4*MH rows by 4*MW columns.

    Returns:
        float32 tensor of shape (B, T, H, W, C). Regions are written in the
        order corners, left/right, top/bottom, middle; where regions overlap
        the later write wins, and uncovered positions are zero.
    """
    B, T, H, W, C = orginal_shape
    scaled = [4 * i for i in window_size]
    a, b = scaled[0], scaled[1]
    corner_windows, top_bottom, left_right, middle = x

    def strip(t):
        # Remove only the singleton axis 2; a bare squeeze would also drop a
        # batch, channel or window axis that happens to have size 1.
        return tf.cast(tf.squeeze(t, axis=2), tf.float32)

    corner_windows = strip(corner_windows)
    left_right = strip(left_right)
    top_bottom = strip(top_bottom)
    middle = strip(middle)

    top, bottom, inner_rows = slice(None, a), slice(-a, None), slice(a, -a)
    left, right, inner_cols = slice(None, b), slice(-b, None), slice(b, -b)

    s_c = corner_windows.shape[1] // 4
    s_lr = left_right.shape[1] // 2
    s_tb = top_bottom.shape[1] // 2
    placements = [
        (corner_windows[:, :s_c], top, left),
        (corner_windows[:, s_c:2 * s_c], top, right),
        (corner_windows[:, 2 * s_c:-s_c], bottom, left),
        (corner_windows[:, -s_c:], bottom, right),
        (left_right[:, :s_lr], inner_rows, left),
        (left_right[:, s_lr:], inner_rows, right),
        (top_bottom[:, :s_tb], top, inner_cols),
        (top_bottom[:, s_tb:], bottom, inner_cols),
        (middle, inner_rows, inner_cols),
    ]

    batch = tf.shape(corner_windows)[0]
    y = tf.zeros([batch, T, H, W, C], dtype=tf.float32)
    for piece, rows, cols in placements:
        # Resolve the slice exactly as NumPy/TF indexing would for this H/W.
        r0, r1, _ = rows.indices(H)
        c0, c1, _ = cols.indices(W)
        n_rows, n_cols = max(0, r1 - r0), max(0, c1 - c0)
        paddings = [[0, 0], [0, 0], [r0, H - r0 - n_rows], [c0, W - c0 - n_cols], [0, 0]]
        covered = tf.pad(tf.ones_like(piece), paddings) > 0
        y = tf.where(covered, tf.pad(piece, paddings), y)
    return y

In [10]:
# Inspect the region groups of shifted_window_partition for a random
# (10, 4, 156, 384, 3) video with window_size (4, 4), i.e. 16x16 corners.
x = tf.random.uniform((10, 4, 156, 384, 3))

c = shifted_window_partition(x, (4, 4))
print(c[0].shape)

# Swap the concatenated-region axis (1) with the batch axis (0).
cc = []
for i in c:
    i = tf.transpose(i, perm=[1, 0, 2, 3, 4, 5])
    cc.append(i)
    print(i.shape)

[16, 16]
(10, 16, 1, 16, 16, 3)
(16, 10, 1, 16, 16, 3)
(8, 10, 1, 16, 352, 3)
(8, 10, 1, 124, 16, 3)
(4, 10, 1, 124, 352, 3)


In [11]:
class ShiftedWindowMSA(tf.keras.layers.Layer):
    """Region self-attention over a 3x3 corner/border/middle grid, then over time, then an MLP.

    For an input indexed [b, t, h, w, c] the layer:
      1. applies LayerNormalization,
      2. cuts the frame into the regions of shifted_window_partition (corner
         size 4*MH x 4*MW, MH = H // 16, MW = W // 24, each clamped to >= 1)
         and runs spatial self-attention inside each group of regions,
      3. runs temporal self-attention inside (T, 1, 1) windows,
      4. applies a second LayerNormalization and an MLP whose width is C.
    No residual connections are used.

    Args:
        num_heads: heads of both attention layers.
        key_dim: per-head key size of both attention layers.
    """

    def __init__(self, num_heads=2, key_dim=8):
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim

    def build(self, input_shape):
        self.layer_norm1 = tf.keras.layers.LayerNormalization()
        self.layer_norm2 = tf.keras.layers.LayerNormalization()
        self.spatial_msa = BaseAttention(self.num_heads, self.key_dim)
        self.temporal_msa = BaseAttention(self.num_heads, self.key_dim)
        self.mlp = MLP(input_shape[-1], input_shape[-1])

    def call(self, x):
        B, T, H, W, C = x.shape
        # With MH or MW == 0 the partition's corner slices [:0] and [-0:]
        # disagree in shape and tf.concat fails; clamp to at least 1.
        MH, MW = max(1, H // 16), max(1, W // 24)

        x = self.layer_norm1(x)

        # Spatial stage: attention inside each group of equally-shaped regions.
        region_groups = shifted_window_partition(x, (MH, MW))
        attended = [self.spatial_msa(group) for group in region_groups]
        x = reverse_shifted_window_partition(attended, (B, T, H, W, C), (MH, MW))

        # Temporal stage over (T, 1, 1) windows, using the dedicated temporal
        # attention layer (as WindowMSA does) rather than spatial_msa.
        windows = normal_window_partition(x, (T, 1, 1))
        windows = tf.transpose(windows, (1, 0, 2, 3, 4, 5))
        x = self.temporal_msa(windows)
        x = tf.transpose(x, (1, 0, 2, 3, 4, 5))
        x = reverse_window_partition(x, (B, T, H, W, C), (T, 1, 1))

        x = self.layer_norm2(x)
        return self.mlp(x)

In [12]:
import time

In [13]:
class SepSTSBock(tf.keras.layers.Layer):
    """Conv3D followed by a WindowMSA stage and a ShiftedWindowMSA stage.

    Args:
        filters: output channels of the 3x3x3 'same' Conv3D (default 1).
        num_heads: heads of both attention stages (default 1).
        key_dim: per-head key size of both attention stages (default 1).
        verbose: if True (default), print the elapsed time of each attention
            stage on every call.
    """

    def __init__(self, filters=1, num_heads=1, key_dim=1, verbose=True):
        super().__init__()
        self.filters = filters
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.verbose = verbose

    def build(self, input_shape):
        self.window_msa = WindowMSA(num_heads=self.num_heads, key_dim=self.key_dim)
        self.shifted_window_msa = ShiftedWindowMSA(num_heads=self.num_heads, key_dim=self.key_dim)
        self.conv = tf.keras.layers.Conv3D(self.filters, kernel_size=(3, 3, 3), padding='same')

    def call(self, x):
        x = self.conv(x)
        # perf_counter is monotonic, unlike time.time(), so intervals are reliable.
        start = time.perf_counter()
        x = self.window_msa(x)
        mid = time.perf_counter()
        if self.verbose:
            print('for W MSA: ', mid - start)
        x = self.shifted_window_msa(x)
        if self.verbose:
            print('for SW MSA: ', time.perf_counter() - mid)
        return x

In [14]:
# Random (10, 4, 64, 96, 3) input and a SepSTSBock instance for experimenting;
# the layer is not called in this cell.
x=tf.random.uniform((10,4,64,96,3))
layer=SepSTSBock()


In [15]:
class ShallowEmbedding(tf.keras.layers.Layer):
    """Position-wise embedding: a ReLU Dense layer of width d_model on the last axis.

    Args:
        d_model: output width of the embedding (default 8).
    """

    def __init__(self, d_model=8):
        super().__init__()
        self.d_model = d_model

    def build(self, input_shape):
        self.embeddings = tf.keras.layers.Dense(units=self.d_model, activation='relu')

    def call(self, x):
        embedded = self.embeddings(x)
        return embedded

In [16]:

class DeepEmbedding(tf.keras.layers.Layer):
    """Four downsampling stages, each a strided Conv3D followed by a SepSTSBock.

    Every Conv3D has kernel and stride (1, 2, 2) with 'same' padding, so H and
    W are halved (rounded up) and T is unchanged. Stage i uses nf[-(i + 1)]
    filters, i.e. nf is consumed from the end.

    Args:
        nf: sequence of at least four filter counts.

    Returns (from call):
        Tuple (x0, x1, x2, x3) with every stage's output, shallowest first.
    """

    def __init__(self, nf):
        super().__init__()
        self.nf = nf

    def build(self, input_shape):
        self.down0 = tf.keras.layers.Conv3D(filters=self.nf[-1], kernel_size=(1, 2, 2), strides=(1, 2, 2), padding='same')
        self.down1 = tf.keras.layers.Conv3D(filters=self.nf[-2], kernel_size=(1, 2, 2), strides=(1, 2, 2), padding='same')
        self.down2 = tf.keras.layers.Conv3D(filters=self.nf[-3], kernel_size=(1, 2, 2), strides=(1, 2, 2), padding='same')
        self.down3 = tf.keras.layers.Conv3D(filters=self.nf[-4], kernel_size=(1, 2, 2), strides=(1, 2, 2), padding='same')

        self.sts0 = SepSTSBock()
        self.sts1 = SepSTSBock()
        self.sts2 = SepSTSBock()
        self.sts3 = SepSTSBock()

    def call(self, x):
        stages = (
            (self.down0, self.sts0),
            (self.down1, self.sts1),
            (self.down2, self.sts2),
            (self.down3, self.sts3),
        )
        outputs = []
        for down, sts in stages:
            # No tf.squeeze between stages: each stage output is rank 5, which
            # the next Conv3D requires. With the default single-filter
            # SepSTSBock the channel axis has size 1, and squeezing it made the
            # next conv fail.
            x = sts(down(x))
            outputs.append(x)
        return tuple(outputs)

In [17]:
class Encoder(tf.keras.layers.Layer):
    """ShallowEmbedding followed by DeepEmbedding.

    Args:
        nf: filter counts handed to DeepEmbedding (default (1, 1, 1, 1)).
        d_model: width of the ShallowEmbedding Dense layer (default 8).

    Returns (from call):
        DeepEmbedding's tuple of four stage outputs.
    """

    def __init__(self, nf=(1, 1, 1, 1), d_model=8):
        super().__init__()
        # Stored as a list, as before; a tuple default avoids a shared mutable default.
        self.nf = list(nf)
        self.d_model = d_model

    def build(self, input_shape):
        self.shallow = ShallowEmbedding(self.d_model)
        self.deep = DeepEmbedding(self.nf)

    def call(self, x):
        x = self.shallow(x)
        return self.deep(x)

In [18]:
class Decoder(tf.keras.layers.Layer):
    """Upsample the deepest encoder feature, adding the skip features on the way.

    Given (x0, x1, x2, x3) from the encoder (shallowest first), computes
      y0 = up(x3) + x2,  y1 = up(y0) + x1,  y2 = up(y1) + x0,  return up(y2)
    with a separate UpSampling3D layer per step. The default upsampling factor
    (1, 2, 2) mirrors the encoder's (1, 2, 2) strides, so T is left unchanged.
    The additions require each upsampled tensor to match its skip tensor
    exactly. That holds when H and W stay even through every downsampling.

    Args:
        upsample_size: per-axis (T, H, W) factor for every UpSampling3D.
    """

    def __init__(self, upsample_size=(1, 2, 2)):
        super().__init__()
        self.upsample_size = upsample_size

    def build(self, input_shape):
        self.upsample0 = tf.keras.layers.UpSampling3D(size=self.upsample_size)
        self.upsample1 = tf.keras.layers.UpSampling3D(size=self.upsample_size)
        self.upsample2 = tf.keras.layers.UpSampling3D(size=self.upsample_size)
        self.upsample3 = tf.keras.layers.UpSampling3D(size=self.upsample_size)
        self.add = tf.keras.layers.Add()

    def call(self, x):
        x0, x1, x2, x3 = x
        # Keras Add merges a *list* of tensors.
        y0 = self.add([self.upsample0(x3), x2])
        y1 = self.add([self.upsample1(y0), x1])
        y2 = self.add([self.upsample2(y1), x0])
        return self.upsample3(y2)

In [19]:
class Transformer(tf.keras.layers.Layer):
    """Encoder followed by Decoder: the decoder consumes the encoder's outputs."""

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        self.encoder = Encoder()
        self.decoder = Decoder()
        super().build(input_shape)

    def call(self, x):
        return self.decoder(self.encoder(x))