In [1]:
import torch
import  numpy as np 
import cv2
from timm.models.layers import DropPath

In [176]:
#morph mlp
class MorphFCT(torch.nn.Module):
    """Temporal MorphFC: mixes features across frames, independently at each spatial position.

    The channel axis (size ``c``) is split into ``chunk_dim`` chunks of ``k = c // chunk_dim``
    channels. For every (chunk, h, w) location, the ``t * k`` values (all frames of that channel
    chunk) are fed through ``l1``, so ``l1`` mixes information over time. The result is put back
    at the same (t, h, w) position, passed through dropout, projected channel-wise by ``l2``,
    and passed through dropout again.

    Args:
        dim: width of ``l1`` and ``l2``. ``l2`` acts on the channel axis, so ``c`` must equal
            ``dim``. ``l1`` acts on ``t * (c // chunk_dim)`` features, so the clip length ``t``
            must equal ``chunk_dim``.
        chunk_dim: number of channel chunks; must divide ``c``.
        dropout: dropout probability used after ``l1`` and after ``l2``.
        attention_dropout: stored on the module but not used by ``forward``.
    """
    def __init__(self, dim, chunk_dim, dropout=0.1, attention_dropout=0.1) -> None:
        super().__init__()
        self.dim = dim
        self.chunk_dim = chunk_dim
        self.attention_dropout = attention_dropout
        self.l1 = torch.nn.Linear(self.dim, self.dim)
        self.l2 = torch.nn.Linear(self.dim, self.dim)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """Map a channels-last clip ``(b, t, h, w, c)`` to a tensor of the same shape."""
        b, t, h, w, c = x.shape
        k = c // self.chunk_dim
        # (b, t, h, w, chunk, k) -> (b, chunk, h, w, t, k) -> (b, chunk, h, w, t*k):
        # the spatial axes are left untouched so each position is mixed only with itself
        # across time, and ends up back at the same position below.
        x = x.reshape(b, t, h, w, self.chunk_dim, k)
        x = x.permute(0, 4, 2, 3, 1, 5)
        x = x.reshape(b, self.chunk_dim, h, w, t * k)
        x = self.l1(x)
        # Exact inverse of the rearrangement above (the permutation swaps axes 1 and 4,
        # so it is its own inverse).
        x = x.reshape(b, self.chunk_dim, h, w, t, k)
        x = x.permute(0, 4, 2, 3, 1, 5)
        x = x.reshape(b, t, h, w, c)
        x = self.dropout(x)
        x = self.l2(x)
        x = self.dropout(x)
        return x
            
        
#14,28,28,49
class MorphFCS(torch.nn.Module):
    """Spatial MorphFC: sums token mixing along H, token mixing along W and a channel FC.

    For the W branch, the flattened ``h * w`` positions are grouped into runs of ``chunk_len``
    consecutive tokens, and the channels into ``chunk_len`` groups of ``k = c // chunk_len``.
    The token-in-run and channel-group axes are swapped, so one ``Linear(dim, dim)`` mixes the
    ``chunk_len`` tokens of a run. The H branch does the same on the H/W-transposed tensor.
    The three branches are summed, projected by ``proj`` and passed through dropout.

    Args:
        chunk_len: run / group length; must divide both ``c`` and ``h * w``.
        dim: channel count ``c`` of the input and output.
        dropout: dropout probability applied to the output.
        bias: whether the ``h``, ``w`` and ``c`` mixing layers have a bias term.
        activation: accepted for backward compatibility but not applied
            (``torch.nn.Linear`` has no activation argument).
    """
    def __init__(self, chunk_len, dim, dropout=0.1, bias=False, activation=torch.nn.GELU()) -> None:
        super().__init__()
        self.chunk_len = chunk_len
        self.dim = dim
        self.bias = bias
        self.proj = torch.nn.Linear(self.dim, self.dim)
        self.h = torch.nn.Linear(self.dim, self.dim, bias=self.bias)
        self.w = torch.nn.Linear(self.dim, self.dim, bias=self.bias)
        self.c = torch.nn.Linear(self.dim, self.dim, bias=self.bias)
        self.dropout = torch.nn.Dropout(dropout)

    def _mix_tokens(self, x, linear):
        """Apply ``linear`` across runs of ``chunk_len`` consecutive tokens of ``x`` (b, t, a, d, c)."""
        b, t, a, d, c = x.shape
        runs = (a * d) // self.chunk_len
        k = c // self.chunk_len
        # (b, t, runs, token, group, k) -> (b, t, runs, group, token * k)
        y = x.reshape(b, t, runs, self.chunk_len, self.chunk_len, k)
        y = y.permute(0, 1, 2, 4, 3, 5).reshape(b, t, runs, self.chunk_len, self.chunk_len * k)
        y = linear(y)
        # Exact inverse of the rearrangement above, back to (b, t, a, d, c).
        y = y.reshape(b, t, runs, self.chunk_len, self.chunk_len, k)
        return y.permute(0, 1, 2, 4, 3, 5).reshape(b, t, a, d, c)

    def forward(self, x):
        """Map a channels-last clip ``(b, t, h, w, c)`` to a tensor of the same shape."""
        # H branch: mix on the (b, t, w, h, c) view and transpose back. The helper reshapes
        # to the transposed (w, h) layout, so non-square inputs keep their shape.
        h_o = self._mix_tokens(x.transpose(2, 3), self.h).transpose(2, 3)
        w_o = self._mix_tokens(x, self.w)
        c_o = self.c(x)
        o = self.proj(h_o + w_o + c_o)
        return self.dropout(o)
########
        
class Mlp(torch.nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: width of the input's last axis.
        hidden_features: width after the first linear layer.
        out_features: width of the output's last axis.
        activation: module applied between the two linear layers.
        dropout: dropout probability applied to the output.
    """
    def __init__(self, in_features, hidden_features, out_features, activation=torch.nn.GELU(), dropout=0.1) -> None:
        # Zero-argument super() so nn.Module.__init__ actually runs; it must run before any
        # attribute is assigned on the module.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.activation = activation
        self.dropout = dropout
        self.l1 = torch.nn.Linear(self.in_features, self.hidden_features)
        self.l2 = torch.nn.Linear(self.hidden_features, self.out_features)
        self.dp = torch.nn.Dropout(self.dropout)

    def forward(self, x):
        """Apply the MLP along the last axis of ``x``."""
        x = self.l1(x)
        x = self.activation(x)
        x = self.l2(x)
        x = self.dp(x)
        return x

#######
#mlp blocks 3,4,9,3
#c = 112, 224, 392, 784 
class MorphMLPBlock(torch.nn.Module):
    """One MorphMLP block: temporal FC, spatial FC and channel MLP as pre-norm residual branches.

    Operates on channels-last clips ``(b, t, h, w, in_dim)`` and returns the same shape.

    Args:
        in_dim: channel count of the input and of every residual branch.
        mid_dim: hidden width of the channel MLP.
        out_dim: output width of the channel MLP. It must equal ``in_dim`` because the MLP
            output is added back onto the residual stream.
        skip_c: divisor applied to the spatial and MLP branch outputs before the residual add.
        stoch_drop: drop-path rate applied to every branch; ``0`` disables it.
        dropout: dropout probability forwarded to the sub-layers.
        activation: activation *class* (instantiated with no arguments) used inside the MLP;
            defaults to ``torch.nn.GELU``.
        normalizer: normalization *class*, called as ``normalizer(in_dim)``.
        chunk_t: channel chunk count for ``MorphFCT``. Its ``l1`` is ``Linear(in_dim, in_dim)``
            applied to ``t * (in_dim // chunk_t)`` features, so the clip length ``t`` must
            equal ``chunk_t``.
        chunk_s: chunk length for ``MorphFCS``; must divide ``in_dim`` and ``h * w``.

    Raises:
        ValueError: if ``out_dim != in_dim``.
    """
    def __init__(self, in_dim, mid_dim, out_dim, skip_c = 1., stoch_drop=0.1, dropout=0.1, activation=None, normalizer=torch.nn.LayerNorm, *, chunk_t, chunk_s) -> None:
        super().__init__()
        if out_dim != in_dim:
            raise ValueError(
                f"out_dim ({out_dim}) must equal in_dim ({in_dim}): the MLP output is added to the residual"
            )
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mid_dim = mid_dim
        self.dropout = dropout
        self.activation = torch.nn.GELU() if activation is None else activation()
        self.normalizer1 = normalizer(self.in_dim)
        self.normalizer2 = normalizer(self.in_dim)
        self.normalizer3 = normalizer(self.in_dim)

        self.droppath = DropPath(stoch_drop) if stoch_drop > 0.0 else torch.nn.Identity()

        self.mfct = MorphFCT(self.in_dim, chunk_t, dropout=dropout)
        self.mfcs = MorphFCS(chunk_s, self.in_dim, dropout=dropout)
        self.mlp = Mlp(self.in_dim, self.mid_dim, self.out_dim, activation=self.activation, dropout=dropout)
        self.skip_c = skip_c

    def forward(self, x):
        """Apply the three residual branches in order: temporal, spatial, channel MLP."""
        # Every branch updates the residual stream x, so the temporal output is kept in the
        # block's result rather than only feeding the spatial branch.
        x = x + self.droppath(self.mfct(self.normalizer1(x)))
        x = x + self.droppath(self.mfcs(self.normalizer2(x))) / self.skip_c
        x = x + self.droppath(self.mlp(self.normalizer3(x))) / self.skip_c
        return x
        

#######


class PatchEmbedder(torch.nn.Module):
    """Embed a channels-first clip ``(b, in_chans, t, h, w)`` with one strided, dilated Conv3d.

    The output stays channels-first: ``(b, c1, t', h', w')``.

    Args:
        c1: number of output channels.
        in_chans: number of input channels.
        k, s, p, d: Conv3d kernel size, stride, padding and dilation.
        normalizer: an already-constructed module applied after the convolution. When
            omitted, a ``BatchNorm3d(c1)`` is created.
    """
    def __init__(self, c1, in_chans,  k=(3,3,3), s=(2,4,4), p=(1,1,1), d=(1,2,2), normalizer=None) -> None:
        super().__init__()
        self.c1, self.in_chans = c1, in_chans
        self.conv = torch.nn.Conv3d(in_chans, c1, k, s, p, d)
        self.normalizer = normalizer if normalizer is not None else torch.nn.BatchNorm3d(c1)

    def forward(self, x):
        return self.normalizer(self.conv(x))


class DownSample(torch.nn.Module):
    """Downsample a channels-last clip with a strided Conv3d, then normalize.

    Input ``(b, t, h, w, in_chans)`` -> output ``(b, t', h', w', out_chans)``, channels-last.

    Args:
        in_chans, out_chans: input / output channel counts.
        k, s, p: Conv3d kernel size, stride and padding (defaults halve h and w).
        s_norm: ``True`` applies LayerNorm over channels; ``False`` applies BatchNorm3d.
    """
    def __init__(self, in_chans, out_chans, k=(1,3,3), s=(1,2,2), p=(0,1,1), s_norm = True) -> None:
        super().__init__()
        self.conv3 = torch.nn.Conv3d(in_chans, out_chans, k, s, p)
        self.layer_norm = torch.nn.LayerNorm(out_chans)
        # Only num_features is passed: the second positional argument of BatchNorm3d is eps.
        self.batch_norm = torch.nn.BatchNorm3d(out_chans)
        self.s_norm = s_norm

    def forward(self, x):
        # Conv3d expects channels-first (b, c, t, h, w).
        x = self.conv3(x.permute(0, 4, 1, 2, 3))
        if not self.s_norm:
            # BatchNorm3d normalizes over dim 1, so apply it while channels are still at dim 1.
            x = self.batch_norm(x)
        x = x.permute(0, 2, 3, 4, 1)
        if self.s_norm:
            # LayerNorm normalizes the last axis, i.e. channels in the channels-last layout.
            x = self.layer_norm(x)
        return x

#######
class MorphMLP(torch.nn.Module):
    """Top-level MorphMLP model. Currently a configuration holder only: no layers or forward yet.

    Args:
        res: input resolution, stored as given.
        num_classes: number of output classes, stored as given.
        embed_dims: embedding width, stored as given.
        stochcastic_drops: per-stage drop-path rates, exposed as ``self.stochastic_drops``.
        stages: per-stage settings (presumably block counts per stage; TODO confirm).
        chunks: per-stage chunk sizes.

    The three per-stage sequences must have equal lengths. Each is copied into a fresh list,
    so instances never share a default list or alias the caller's list.
    """
    def __init__(self, res=(224,224,3), num_classes=10, embed_dims=1000, stochcastic_drops=None, stages=None, chunks=None) -> None:
        super().__init__()
        # None defaults: a literal [] default would be one list shared by every instance.
        stochastic_drops = [] if stochcastic_drops is None else list(stochcastic_drops)
        stages = [] if stages is None else list(stages)
        chunks = [] if chunks is None else list(chunks)
        assert len(stochastic_drops) == len(stages) == len(chunks), (
            "stochcastic_drops, stages and chunks must have the same length"
        )
        self.num_classes = num_classes
        self.res = res
        self.embed_dims = embed_dims
        self.stochastic_drops = stochastic_drops
        self.stages = stages
        self.chunks = chunks
        
    

In [171]:
# Shape check: temporal MorphFC on a 14-frame, 56x56, 112-channel channels-last clip.
m = MorphFCT(dim=112, chunk_dim=14)
a = torch.rand((1, 14, 56, 56, 112))
m(a).shape

torch.Size([1, 14, 56, 56, 112])

In [180]:
# Shape check: patch embedding of a channels-first (b, c, t, h, w) clip into 32 channels.
a = torch.rand((1, 3, 128, 512, 512))
b = PatchEmbedder(c1=32, in_chans=3)
b(a).shape

torch.Size([1, 32, 64, 128, 128])

In [169]:
# Shape check: DownSample halves h and w and doubles channels on a channels-last clip.
ds = DownSample(in_chans=56 // 2, out_chans=56)
a = torch.rand((1, 10, 224, 224, 56 // 2))
ds(a).shape 

torch.Size([1, 10, 112, 112, 56])