In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
from einops import rearrange, repeat

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
class Attention(nn.Module):
    """Convolutional self-attention across the sequence (frame) axis.

    Query, key and value come from 1x3x3 convolutions that map ``in_channels``
    to ``in_channels // 2``.

    - Each projection is flattened to ``(b, c', s, h*w)``.
    - Frames attend to each other independently per channel, so the score
      matrix has shape ``(b, c', s, s)``.
    - Scores are scaled by ``(h*w) ** -0.5`` and passed through a sigmoid
      (not a softmax).
    - A final 1x3x3 convolution maps back to ``in_channels``.

    Args:
        in_channels: channel count of the input and of the output.
        out_channels: stored on the module but not used by the computation.
        width, height: spatial size the module was configured for. ``forward``
            reads the actual spatial size from its input, so other sizes work.
    """

    def __init__(self, in_channels, out_channels, width, height):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.height = height

        self.in_dim = in_channels // 2
        stride = 1

        # Registration order (key, query, value, out) is kept so parameter
        # initialisation and state_dict layout are unchanged.
        self.key_conv = nn.Conv3d(self.in_channels, self.in_dim, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1))
        self.query_conv = nn.Conv3d(self.in_channels, self.in_dim, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1))
        self.value_conv = nn.Conv3d(self.in_channels, self.in_dim, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1))

        # Scale for the configured size; forward recomputes it from the input's h*w,
        # which equals this value whenever the input matches (width, height).
        self.scale = (width * height) ** (-0.5)
        self.sigmoid = nn.Sigmoid()

        self.out_conv = nn.Conv3d(self.in_dim, self.in_channels, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1))

    def forward(self, x):
        """Apply frame-to-frame attention.

        Args:
            x: tensor of shape ``(b, in_channels, s, h, w)``.

        Returns:
            Tensor of shape ``(b, in_channels, s, h, w)``.

        Raises:
            ValueError: if ``x`` is not 5-D.
        """
        if x.dim() != 5:
            raise ValueError(f'Attention expects a 5-D (b, c, s, h, w) tensor, got shape {tuple(x.shape)}')
        h, w = x.shape[-2:]

        # (b, c', s, h, w) -> (b, c', s, h*w)
        key = self.key_conv(x).flatten(start_dim=3)
        query = self.query_conv(x).flatten(start_dim=3)
        value = self.value_conv(x).flatten(start_dim=3)

        scale = (h * w) ** (-0.5)
        # (b, c', s, h*w) @ (b, c', h*w, s) -> (b, c', s, s) frame-to-frame weights.
        attention_mat = self.sigmoid(torch.matmul(query, key.transpose(-1, -2)) * scale)
        attention = torch.matmul(attention_mat, value)  # (b, c', s, h*w)

        # Restore the spatial layout from the input's own h, w (not the constructor's),
        # so a mismatched size cannot be silently reshaped into the wrong layout.
        out = attention.reshape(*attention.shape[:3], h, w)
        return self.out_conv(out)

class PreNorm(nn.Module):
    """Normalise the input with ``BatchNorm3d`` and feed the result to ``fn``.

    Args:
        dim: number of channels seen by the batch norm.
        fn: callable (normally a module) applied to the normalised tensor;
            any keyword arguments given to ``forward`` are passed through to it.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.BatchNorm3d(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
    
class Forward(nn.Module):
    """Feed-forward stage: a single 1x3x3 convolution followed by ReLU.

    Stride 1 and padding (0, 1, 1) keep the input's shape, and the channel
    count stays at ``in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # NOTE(review): not used anywhere in this block; kept for compatibility.
        self.in_dim = in_channels ** 2

        spatial_conv = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=(1, 3, 3),
            stride=1,
            padding=(0, 1, 1),
        )
        self.conv1 = nn.Sequential(spatial_conv, nn.ReLU())

    def forward(self, x):
        return self.conv1(x)
    

class Conv_Transfomer(nn.Module):
    """Stack of ``depth`` pre-norm residual blocks.

    Each block runs conv self-attention, then a conv feed-forward:
    ``x = attn(norm(x)) + x`` followed by ``x = ff(norm(x)) + x``.
    The output has the same shape as the input.

    Args:
        in_channels: channel count of the input, preserved by every block.
        out_channels: accepted so the signature matches the call sites
            ``Conv_Transfomer(in_channels, out_channels, width, height, depth)``.
            It is not used, because the residual additions require the channel
            count to stay at ``in_channels``.
        width, height: spatial size forwarded to each ``Attention`` layer.
        depth: number of attention + feed-forward blocks.
    """

    def __init__(self, in_channels, out_channels, width, height, depth=1):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(in_channels, Attention(in_channels, in_channels, width, height)),
                PreNorm(in_channels, Forward(in_channels)),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

In [10]:
class Vid_Transformer(nn.Module):
    """Aggregate a sequence of feature maps into a single feature map.

    Pipeline:

    1. Add a learned positional embedding.
    2. Run the convolutional transformer (``Conv_Transfomer``).
    3. Apply BatchNorm and a 1x3x3 convolution mapping ``in_channels`` to
       ``out_channels``.
    4. Average over the sequence axis.

    Input ``(b, in_channels, seq_len, height, width)`` gives output
    ``(b, out_channels, height, width)``.
    """

    def __init__(self, in_channels, out_channels, width, height, seq_len, depth=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.height = height
        self.seq_len = seq_len
        self.depth = depth

        # Laid out as (1, 1, s, h, w) to match the input's (b, c, s, h, w) order;
        # it broadcasts over batch and channels in forward.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, self.seq_len, self.height, self.width))

        # No explicit .to(device): submodules follow the parent's .to()/.cuda() call.
        self.cvt = Conv_Transfomer(self.in_channels, self.out_channels, self.width, self.height, self.depth)

        self.final_conv = nn.Sequential(nn.BatchNorm3d(self.in_channels),
                                        nn.Conv3d(self.in_channels, self.out_channels, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1)))

    def forward(self, x):
        # Out-of-place add so the caller's tensor is not modified.
        x = x + self.pos_embedding

        # Temporal aggregation through the convolutional transformer.
        x = self.cvt(x)

        x = self.final_conv(x)

        # Collapse the sequence axis: (b, c, s, h, w) -> (b, c, h, w).
        return x.mean(dim=2)

In [11]:
# Smoke test: run the video transformer on random features and check the output shape.
torch.manual_seed(0)  # reproducible weight init, positional embedding and input

batch_size = 1
seq_len = 20
width = 32
height = 32
in_channels = 512
out_channels = 512
depth = 5

# Feature layout is (batch, channels, seq, height, width).
feat = torch.rand((batch_size, in_channels, seq_len, height, width), device=device)

vt = Vid_Transformer(in_channels, out_channels, width, height, seq_len, depth=depth).to(device)

with torch.no_grad():  # shape check only; no autograd graph needed
    out = vt(feat)

assert out.shape == (batch_size, out_channels, height, width)

In [12]:
# Display the shape (torch.Size) of `out`.
out.size()

torch.Size([1, 512, 32, 32])

In [5]:
# Inspect GPU utilisation and memory usage (IPython `!` shell escape).
!nvidia-smi

Tue Jun 20 15:50:01 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  Off  | 00000000:D8:00.0 Off |                    0 |
| N/A   24C    P0    38W / 250W |   3835MiB / 40960MiB |     19%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
# Display the shape (torch.Size) of `out`.
out.size()

torch.Size([1, 512, 20, 32, 32])

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f

from decoder import Decoder

from conv_transformer import Vid_Transformer
import torchvision.models as models

class seq_seg(nn.Module):
    """Sequence segmentation model: per-frame VGG16 encoder, video transformer, decoder.

    Args:
        n_classes: number of output classes, passed to ``Decoder``.
        seq_length: number of frames encoded per clip. It also sizes the
            transformer's positional embedding.

    The transformer is configured for 512-channel, 32x32 feature maps. That is
    what ``vgg16().features[:27]`` (four 2x2 max-pools) produces for 512x512
    frames, so other frame sizes will not match.
    """

    def __init__(self, n_classes, seq_length):
        super().__init__()

        # Kept for backward compatibility; forward allocates on the input's device instead.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.seq_length = seq_length
        self.n_classes = n_classes

        # VGG16 convolutional trunk (ImageNet-pretrained), truncated after layer 26.
        self.encoder = nn.Sequential(*models.vgg16(pretrained=True).features[:27])

        # The transformer's sequence length must equal the number of encoded frames;
        # it was previously hard-coded to 20 regardless of `seq_length`.
        self.seq_len = seq_length
        self.width = 32
        self.height = 32
        self.in_channels = 512
        self.out_channels = 512
        self.depth = 4

        self.feat_aggregator = Vid_Transformer(self.in_channels, self.out_channels, self.width, self.height, self.seq_len, self.depth)

        self.decoder = Decoder(self.n_classes)

    def forward(self, x):
        """Encode each frame, aggregate over time, and decode.

        Args:
            x: 5-D tensor ``(batch, channels, frames, H, W)`` with at least
                ``seq_length`` frames.

        Returns:
            The decoder's output for the aggregated features.

        Raises:
            ValueError: if ``x`` is not 5-D.
        """
        if x.dim() != 5:
            raise ValueError(f'seq_seg expects a 5-D (b, c, s, h, w) tensor, got shape {tuple(x.shape)}')

        # Encode frames independently and stack them on a new sequence axis. The
        # result stays on x's device/dtype, whatever device the model was moved to.
        feat = torch.stack([self.encoder(x[:, :, i]) for i in range(self.seq_length)], dim=2)

        x = self.feat_aggregator(feat)
        return self.decoder(x)
    
class DiceLoss(nn.Module):
    """Soft Dice loss computed on channel 1 of sigmoid-activated logits.

    ``inputs`` are raw logits. A sigmoid is applied here, so do not use this
    with a model that already ends in a sigmoid. ``targets`` must index like
    ``inputs`` along dims 0 and 1.

    Index 1 of dim 1 is selected from both tensors and flattened. The loss is
    ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)``.

    ``weight`` and ``size_average`` are accepted for API compatibility and ignored.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        probs = torch.sigmoid(inputs)

        # Keep only class index 1, then flatten prediction and label tensors.
        probs = probs[:, 1].reshape(-1)
        targets = targets[:, 1].reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        return 1 - dice


if __name__ == '__main__':
    # Smoke test: one forward / backward / optimizer step on random data.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 3
    n_classes = 2
    seq_len = 20
    model = seq_seg(n_classes, seq_len)
    model.to(device)

    # Random 8-bit-range frames (randint's upper bound is exclusive, hence 256).
    Input = torch.randint(0, 256, (batch_size, 3, seq_len, 512, 512), device=device).float()
    target = torch.ones((batch_size, n_classes, 1, 512, 512), device=device)
    # criterion = nn.CrossEntropyLoss()
    criterion = DiceLoss()
    optimizer = torch.optim.Adadelta(model.parameters(), lr=0.001)

    optimizer.zero_grad()
    out = model(Input)

    print('Out shape - ', out.shape)
    print('Target shape - ', target.shape)
    loss = criterion(out, target)
    print(f'loss - {loss.item()}')

    # Anomaly detection only concerns autograd; the optimizer step runs outside it.
    with torch.autograd.set_detect_anomaly(True):
        loss.backward()
    optimizer.step()

  from .autonotebook import tqdm as notebook_tqdm


Out shape -  torch.Size([3, 2, 1, 512, 512])
Target shape -  torch.Size([3, 2, 1, 512, 512])
loss - 0.33669984340667725




In [1]:
# Inspect GPU utilisation and memory usage (IPython `!` shell escape).
!nvidia-smi

Tue Jun 20 17:20:03 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  Off  | 00000000:D8:00.0 Off |                    0 |
| N/A   24C    P0    52W / 250W |     45MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces