In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import pylab as plt

class Conv1D_block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding="same", act=True):
        super().__init__()

        padding = int((kernel_size - 1) / 2)
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.batchnorm = nn.BatchNorm1d(out_channels)        
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = self.conv(x)
        out = self.batchnorm(out)
        out = self.gelu(out)
        out = self.dropout(out)
        return out

class Residual_block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=[1, 3, 3], strides=[1, 1, 1]):
        super().__init__()
        self.conv1d_1 = Conv1D_block(in_channels, out_channels, kernel_size=kernel_size[0], stride=strides[0])
        self.conv1d_2 = Conv1D_block(out_channels, out_channels, kernel_size=kernel_size[1], stride=strides[1], act=False)
        self.conv1d_res = Conv1D_block(in_channels, out_channels, kernel_size=kernel_size[2], stride=strides[2], act=False)
        self.gelu = nn.GELU()

    def forward(self, x):
        out = self.conv1d_1(x)
        out = self.conv1d_2(out)

        shortcut = self.conv1d_res(x)
        out += shortcut
        out = self.gelu(out)
        return out

class UpsampleConcat_block(nn.Module):
    def __init__(self):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

    def forward(self, x1, x2):
        out = self.upsample(x1)
        out = torch.cat([out, x2], axis=1)
        return out

class TransUNet1D(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, channels=[32, 64, 128, 256], mha_latent=False):
        super().__init__()

        self.in_channels = in_channels
        self.mha_latent = mha_latent

        self.enc1 = Residual_block(in_channels=in_channels, out_channels=channels[0], strides=[2, 1, 2]) # in_seq_len=512,
        self.enc2 = Residual_block(in_channels=channels[0], out_channels=channels[1], strides=[2, 1, 2]) # in_seq_len=256, 
        self.enc3 = Residual_block(in_channels=channels[1], out_channels=channels[2], strides=[2, 1, 2]) # in_seq_len=128, 
        self.enc4 = Residual_block(in_channels=channels[2], out_channels=channels[3], strides=[2, 1, 2]) # in_seq_len=64,  
                        
        self.dec4 = Residual_block(in_channels=channels[3],               out_channels=channels[3], strides=[1, 1, 1]) # in_seq_len=64,  
        self.dec3 = Residual_block(in_channels=channels[2] + channels[3], out_channels=channels[2], strides=[1, 1, 1]) # in_seq_len=128, 
        self.dec2 = Residual_block(in_channels=channels[1] + channels[2], out_channels=channels[1], strides=[1, 1, 1]) # in_seq_len=256, 
        self.dec1 = Residual_block(in_channels=channels[0] + channels[1], out_channels=channels[0], strides=[1, 1, 1]) # in_seq_len=512, 

        self.up4 = UpsampleConcat_block()
        self.up3 = UpsampleConcat_block()
        self.up2 = UpsampleConcat_block()
        self.up1 = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        
        self.last_conv = nn.Conv1d(in_channels=channels[0], out_channels=out_channels, kernel_size=1, stride=1) 

        if self.mha_latent:
            self.d_model = channels[-1]# dimension of model
            self.N = 6 # number of layers
            self.h = 8 # number of heads
            self.d_ff = 1024 # dimension of feed-forward network
            self.P_drop = 0.1 # dropout probability
            
            self.mha_latent_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.h, dim_feedforward=self.d_ff, dropout=self.P_drop, activation="gelu")
            self.mha = nn.TransformerEncoder(self.mha_latent_layer, num_layers=self.N)

    def get_positional_encoding(self, pos, i, dim):
        angles = 1 / math.pow(10000, (2 * (i // 2))/ dim)

        if i % 2 == 0:
            return math.sin(pos * angles)
        return math.cos(pos * angles)

    def position_encoding(self, x, seq_length):
        pe = torch.zeros(seq_length, 1, self.d_model, device=device)

        for i in range(seq_length):
            for j in range(self.d_model):
                pe[i, :, j] = self.get_positional_encoding(i, j, self.d_model)        

        # # visualize pe
        # plt.pcolormesh(pe[:,0],cmap='Blues')
        # plt.xlabel('Dimension')
        # plt.ylabel('Position')
        
        return pe
        
    def check_output_shape(self, shape):
        x = torch.rand(shape)
        yhat = self.forward(x)

        print(f'x: {x.shape}')
        print(f'yhat: {yhat.shape}')

    def forward(self, x):

        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3) # e4: batch, channel(d_model), seq_len

        if self.mha_latent:
            e4 = e4.permute((2,0,1)) # e4: seq_len, batch, channel(d_model)
            pos = self.position_encoding(e4, e4.shape[0])
            e4 += pos
            e4 = self.mha(e4)            
            e4 = e4.permute((1,2,0)) # e4: batch, channel(d_model), seq_len

        d4 = self.dec4(e4)
        u4 = self.up4(d4, e3)
        
        d3 = self.dec3(u4)
        u3 = self.up3(d3, e2)
        
        d2 = self.dec2(u3)        
        u2 = self.up2(d2, e1)
        
        d1 = self.dec1(u2)
            
        u1 = self.up1(d1)
        out = self.last_conv(u1)
        out = torch.softmax(out,1)

        return out        

In [2]:
net = TransUNet1D(in_channels=3, out_channels=2, channels=[32, 64, 128, 256], mha_latent=False)
net

TransUNet1D(
  (enc1): Residual_block(
    (conv1d_1): Conv1D_block(
      (conv): Conv1d(3, 32, kernel_size=(1,), stride=(2,))
      (batchnorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (gelu): GELU(approximate='none')
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv1d_2): Conv1D_block(
      (conv): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
      (batchnorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (gelu): GELU(approximate='none')
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv1d_res): Conv1D_block(
      (conv): Conv1d(3, 32, kernel_size=(3,), stride=(2,), padding=(1,))
      (batchnorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (gelu): GELU(approximate='none')
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (gelu): GELU(approximate='none')
  )
  (enc2): Residual_block(
    (conv1d_1): Conv1D_

In [3]:
net.eval()
yhat = net.check_output_shape((2,3,2048))

x: torch.Size([2, 3, 2048])
yhat: torch.Size([2, 2, 2048])
