In [2]:
#Input arrays are (3500, 2800)
import torch
from torch import nn
import numpy as np
import pandas as pd
import os
import pydicom as dicom
import math
import time

In [3]:
#Positional encoding class
#May want to scrap this for a learnable positional encoding model as opposed to sinusoidal 
class Positional_Encoding(nn.Module):
    def __init__(self,data,dropout=0.1,n = 10000):
        super(Positional_Encoding, self).__init__()
        self.embedded_dim, self.position = data.shape
        self.dropout = nn.Dropout(p=dropout)
        
        self.embedded_dim += 1 #adding one to embedded dim to take into account token prepend
        
        self.learned_embedding_vec = nn.Parameter(torch.zeros(1,self.position))
        
        self.positional_matrix = torch.zeros(self.embedded_dim,self.position)

        for pos in range(self.position):
            for i in range(int(self.embedded_dim/2)):
                denom = pow(n, 2*i/self.embedded_dim) 
                self.positional_matrix[2*i,pos] = np.sin(pos/denom)
                self.positional_matrix[2*i+1,pos] = np.cos(pos/denom)     
    
    def forward(self,data):
#         print(f'Data shape: {data.shape}')
#         print(f'positional_matrix shape: {self.positional_matrix.shape}')
        self.data = torch.vstack((self.learned_embedding_vec,data))
        self.summer_matrix = self.data + self.positional_matrix
        self.summer_matrix = self.dropout(self.summer_matrix)

        return self.summer_matrix

In [4]:
#Apply conv layer to a (500, 400) subset of each scan 
#TODO max pool is necessary 
class convlayer(nn.Module):
    def __init__(self, num_patch: int = 49):
        super(convlayer, self).__init__()
        self.num_patch = num_patch
        n = num_patch
#         self.conv2d_1 = nn.Conv2d(in_channels = 1, out_channels = 8, kernel_size = 13, stride = 1)
        self.conv2d_1 = nn.Conv2d(in_channels = n*1, out_channels = n *8, kernel_size = 13, stride = 1, groups = n)
        
        self.pooling2d_1 = nn.MaxPool2d(2)
        
        self.conv2d_2 = nn.Conv2d(in_channels = n*8, out_channels = n*16, kernel_size = 11, stride = 1, groups = n)

        self.pooling2d_2 = nn.MaxPool2d(2)
        
#         self.conv2d_3 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 9, stride = 1, groups = n)
        self.conv2d_3 = nn.Conv2d(in_channels = n*16, out_channels = n*32, kernel_size = 9, stride = 1, groups = n)

        
#         self.conv2d_4 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 7, stride = 1, groups = n)
        self.conv2d_4 = nn.Conv2d(in_channels = n*32, out_channels = n*32, kernel_size = 7, stride = 1, groups = n)

        self.pooling2d_3 = nn.MaxPool2d(2)
        
#         self.conv2d_5 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 5, stride = 1, groups = n)
        self.conv2d_5 = nn.Conv2d(in_channels = n*32, out_channels = n*64, kernel_size = 5, stride = 1, groups = n)


        
        self.dnn = nn.Linear(105280,256)
        
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        
        
    def forward(self, tensor):
        x = self.conv2d_1(tensor)
        x = self.relu(x)
        x = self.pooling2d_1(x)
        
        x = self.conv2d_2(x)
        x = self.relu(x)
        x = self.pooling2d_2(x)
        
        x = self.conv2d_3(x)
        x = self.relu(x)
        x = self.conv2d_4(x)
        x = self.relu(x)
        
        x = self.pooling2d_3(x)
        
        x = self.conv2d_5(x)
        x = self.relu(x)
        
        x = self.flatten(x)
        x = torch.reshape(x,(self.num_patch,105280))
        #print(f'x.shape in forward in conv layer {x.shape}')
        x = self.dnn(x)
        
        return x

In [None]:
sample = torch.zeros(1,500,400)
print(sample.shape)
model = convlayer()
x = model.forward(sample)

In [None]:
#Test with one image 
ds = dicom.dcmread('2ddfad7286c2b016931ceccd1e2c7bbc copy.dicom')
info = ds.pixel_array[:3500,:]
list_of_patches = []
patches_matrix = torch.zeros(49,500,400)
for i in range(7):
    x_axis = [i*500,(i+1)*500]
    for j in range(7):
        y_axis = [j*400,(j+1)*400]
        correct_patch = np.ascontiguousarray(info[x_axis[0]:x_axis[1],y_axis[0]:y_axis[1]],dtype = np.float32)
        tensor = torch.from_numpy(correct_patch)
#         list_of_patches.append(tensor)
        loc = 7*i + j
        patches_matrix[loc,:,:] = tensor
# print(len(list_of_patches))
# print(patches_matrix.shape[0])
# print(list_of_patches[1])
# print(patches_matrix[1,:,:])
embedded_patches = []
embedded_matrix = torch.zeros(256,49)
for i in range(patches_matrix.shape[0]):
    patch = patches_matrix[i,:,:]
#     print(patch.shape)
    model = convlayer()
    patch = patch[None,:,:]
    #print(patch.shape)
    x = model.forward(patch)
#     embedded_patches.append(x.T)
#     print(x)
    embedded_matrix[:,i] = x
# Learnable embedding
learned_embedding_vec = nn.Parameter(torch.zeros(256,1))
#nn.Parameter adds it to model paramter so that we can backprop through it 
embedded_matrix = torch.hstack((learned_embedding_vec,embedded_matrix))

#     print(embedded_matrix)

In [None]:
positional = Positional_Encoding(embedded_matrix)
summer, pos = positional.forward(embedded_matrix)

In [5]:
#Full Embedding class that takes in data name for individual image and outputs positional embedding where each column
#vector represents a positionally-embedded patch except for the very first column vector, which is a learnt 
#classification token 
#TODO finish reimplement for  batched input

class embedding_block(nn.Module):
    #Data in this sense is the image that has not been translated into an array
    #Want to set x_con to 3500
    def __init__(self, x_amount = 7, y_amount = 7, x_con = 3500, y_con = 2800):
        super(embedding_block, self).__init__()
        
        assert(x_con % x_amount == 0)
        assert(y_con % y_amount == 0)
        self.x_amount = x_amount
        self.y_amount = y_amount
        self.x_con = x_con
        self.y_con = y_con
        
        
        self.amount_of_patches = int(x_amount * y_amount)
        self.x_ran = int(x_con / x_amount)
        self.y_ran = int(y_con / y_amount)
        self.patches_matrix = torch.zeros(self.amount_of_patches,self.x_ran,self.y_ran)
        #print(self.patches_matrix.shape)
        
        #print('here')
        
        # TODO : add conv1 and conv2
        
    
    def forward(self, data):
        # recheck for proper class variables(change self. to strictly local variable)
        self.info = data[:,:self.x_con,:self.y_con]
        
        
        
        self.batch_size = self.info.shape[0]
        
        self.batched_patches = self.info.unfold(1,self.x_ran,self.x_ran).unfold(2,self.y_ran,self.y_ran)
        self.batched_patches = torch.reshape(self.batched_patches,
                                             (self.batch_size,self.amount_of_patches,self.x_ran,self.y_ran))

        
        

        
        self.LCC = self.batched_patches[0]
        #print(f'self.LCC shape before conv: {self.LCC.shape}')
        self.LMLO = self.batched_patches[1]
        self.RCC = self.batched_patches[2]
        self.RMLO = self.batched_patches[3]
        
#         Dealing with (RCC, LCC)  and (RMLO, LMLO)
#         init 2 conv layers that: 
#              - torch.Size([B, 49, 500, 400]) apply some function that applies 49 independent conv 
#              layers to dim 1 i.e. clever use of reshape/channels/batch_size/groups
#              - torch.Size([B, 49, 256])
        
#         init -> conv_1
        self.cc_conv = convlayer()
        self.mlo_conv = convlayer()
        
        self.LCC = self.cc_conv.forward(self.LCC)
        #print(f'self.LCC shape after conv: {self.LCC.shape}')
        self.RCC = self.cc_conv.forward(self.RCC)
        
        self.LMLO = self.mlo_conv.forward(self.LMLO)
        self.RMLO = self.mlo_conv.forward(self.RMLO)
#         init -> conv_2 
        
#         rcc_after_conv = conv_1(RCC)
#         lcc_after_conv = conv_1(mirror(LCC))
#              # or do RCC cat mirror(LCC) ==> torch.Size([B*2, 49, 256])
        
        
#         conv_2(RMLO)
#         conv_2(mirror(LMLO))

            
#         self.learned_embedding_vec = nn.Parameter(torch.zeros(256,1))
        
        #nn.Parameter adds it to model paramter so that we can backprop through it 
#         self.embedded_matrix = torch.v=hstack((self.learned_embedding_vec,self.embedded_matrix))
         
#         print(f'self.embedded_matrix shape: {self.embedded_matrix.shape}')
        #TODO MOVE TO INIT
        self.pos_encoding_LCC = Positional_Encoding(self.LCC)
        self.pos_encoding_RCC = Positional_Encoding(self.RCC)
        self.pos_encoding_LMLO = Positional_Encoding(self.LMLO)
        self.pos_encoding_RMLO = Positional_Encoding(self.RMLO)
        
        self.summer_LCC = self.pos_encoding_LCC.forward(self.LCC)
        self.summer_RCC = self.pos_encoding_RCC.forward(self.RCC)
        self.summer_LMLO = self.pos_encoding_LMLO.forward(self.LMLO)
        self.summer_RMLO = self.pos_encoding_RMLO.forward(self.RMLO)
        
        self.batched_positional_encoding = torch.zeros(self.batch_size,50,256)
        #print(self.batched_positional_encoding.shape)
        #print(self.summer_LCC.shape)
        self.batched_positional_encoding[0] = self.summer_LCC
        self.batched_positional_encoding[1] = self.summer_LMLO
        self.batched_positional_encoding[2] = self.summer_RCC
        self.batched_positional_encoding[3] = self.summer_RMLO
#         print(f'positional_matrix shpae: {self.positional.positional_matrix.shape}')
#         self.summer, self.pos = self.positional.forward(self.embedded_matrix)
        
#         print(f'self.summer shape: {self.summer.shape}')
        
        return self.batched_positional_encoding




        

In [10]:
tensor = torch.rand(4,3500,2800)
print(tensor.shape)
# tensor2 = tensor.unfold(0,1,1).unfold(1,500,500).unfold(2,400,400)
tensor2 = tensor.unfold(1,500,500).unfold(2,400,400)
print(tensor2.shape)
tensor3 = torch.reshape(tensor2, (4,49,500,400))
tensor3.shape

torch.Size([4, 3500, 2800])
torch.Size([4, 7, 7, 500, 400])


torch.Size([4, 49, 500, 400])

In [None]:
x = embedding_block(x_con = 3500)
y = x.forward(tensor)
y.shape

In [6]:
#Global and Local mlp are equivalent 
class local_mlp(nn.Module):
    def __init__(self, hidden_output = 1024, dropout = .5):
        super(local_mlp, self).__init__()
        self.fnn1 = nn.Linear(256,hidden_output)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.fnn2 = nn.Linear(hidden_output,256)
        self.dropout2 = nn.Dropout(dropout)
    def forward(self, data):
        self.x = self.fnn1(data)
        self.x = self.gelu(self.x)
        self.x = self.dropout1(self.x)
        self.x = self.fnn2(self.x)
        self.x = self.gelu(self.x)
        self.x = self.dropout2(self.x)
        
        return self.x

In [7]:
class local_encoder_block(nn.Module):
    def __init__(self, data_shape = (4,50,256), hidden_output_fnn1 = 1024, dropout = .5):
        super(local_encoder_block, self).__init__()
        self.data_shape = data_shape
        #print(data_shape)
        self.ln1 = nn.LayerNorm([data_shape[1],data_shape[2]]) #Layer norm over the H and W of each image
        self.ln2 = nn.LayerNorm([data_shape[1],data_shape[2]])
        
        self.attention = nn.MultiheadAttention(embed_dim = 256,num_heads = 16, batch_first = True)
        self.mlp_0 = local_mlp(hidden_output = hidden_output_fnn1, dropout = dropout)
        self.mlp_1 = local_mlp(hidden_output = hidden_output_fnn1, dropout = dropout)
        self.mlp_2 = local_mlp(hidden_output = hidden_output_fnn1, dropout = dropout)
        self.mlp_3 = local_mlp(hidden_output = hidden_output_fnn1, dropout = dropout)
        
        self.dnn_output = torch.zeros(data_shape)
        
    def forward(self, data):
        #i = 0
        #print(i)
        if data.shape == (4,256,50):
            print('in here')
            data = data.T
        self.x = self.ln1(data)
        self.att_out, self.att_out_weights = self.attention(query = self.x,key = self.x,value = self.x)
        #print(self.att_out.shape)
        self.x_tilda = self.att_out + data
        #print(self.x_tilda.shape)
        self.x_second = self.ln2(self.x_tilda)
        #self.x_second = self.mlp.forward(self.x_second)
        self.dnn_output[0] = self.mlp_0.forward(self.x_second[0])
        self.dnn_output[1] = self.mlp_1.forward(self.x_second[1])
        self.dnn_output[2] = self.mlp_2.forward(self.x_second[2])
        self.dnn_output[3] = self.mlp_3.forward(self.x_second[3])
        #print(self.x_second.shape)
        self.x_second = self.dnn_output + self.x_tilda
        
        #i += 1
        return self.x_second  

In [55]:
class global_encoder_block(nn.Module):
    def __init__(self, data_shape = (1,200,256), hidden_output_fnn1 = 1024, dropout = .5):
        super(global_encoder_block, self).__init__()
        self.data_shape = data_shape
        self.gln1 = nn.LayerNorm(data_shape)
        self.ln2 = nn.LayerNorm(data_shape)
        self.attention = nn.MultiheadAttention(embed_dim = 256, num_heads = 16, batch_first = True)
        self.mlp = local_mlp(hidden_output = hidden_output_fnn1, dropout = dropout)
        
    def forward(self, data):
        self.x = self.gln1(data)
        self.att_out, self.att_out_weights = self.attention(query = self.x,key = self.x,value = self.x)
        #print(self.att_out.shape)
        self.x_tilda = self.att_out + data
        #print(self.x_tilda.shape)
        self.x_second = self.ln2(self.x_tilda)
        #self.x_second = self.mlp.forward(self.x_second)
        self.dnn_output = self.mlp.forward(self.x_second)
        self.x_second = self.dnn_output + self.x_tilda
        
        return self.x_second
    

In [54]:
class Visual_Transformer(nn.Module):
    #embedding parameters, local encoder parameters
    def __init__(self, x_amount = 7, y_amount = 7, x_con = 3518, y_con = 2800,
              data_shape = (4,50,256), hidden_output_fnn = 1024, dropout = .5,
              number_of_layers = 10):
        super(Visual_Transformer, self).__init__()
        self.embedding_block = embedding_block(x_amount = x_amount, y_amount = y_amount, x_con = x_con, y_con = y_con)
        self.blks = nn.Sequential()
        for i in range(number_of_layers):
            self.blks.add_module(f'{i}', local_encoder_block(data_shape = data_shape))
            
    def forward(self,data):
        self.x = self.embedding_block.forward(data)
        print(self.x.shape)
        i = 0 
        for blk in self.blks:
            print(f'This is {i} local attention run')
            i += 1
            self.x = blk(self.x) 
        return self.x

In [73]:
class global_transformer(nn.Module):
    def __init__(self, x_amount = 7, y_amount = 7, x_con = 3518, y_con = 2800,
              data_shape = (4,50,256), hidden_output_fnn = 1024, dropout = .5,
              number_of_layers = 10, num_layers_global = 10):
        super(global_transformer,self).__init__()
        self.data_shape = data_shape
        new_data_shape = (1,data_shape[0]*data_shape[1],data_shape[2])
        self.individual_transformer = Visual_Transformer(x_amount = x_amount, y_amount = y_amount, x_con = x_con,
                                                   y_con = y_con, data_shape = data_shape, 
                                                    hidden_output_fnn = hidden_output_fnn, 
                                                    dropout = dropout, number_of_layers = number_of_layers)
        self.blks = nn.Sequential()
        for i in range(num_layers_global):
            self.blks.add_module(f'{i}', global_encoder_block(data_shape = new_data_shape))
            
        self.flatten = nn.Flatten()
            
        self.class_head = classification_head(input_layer = data_shape[0]*data_shape[2], 
                                              hidden_output_class = 512, dropout = .5)
    def forward(self, data):
        self.x = self.individual_transformer.forward(data)
        shape1, shape2, shape3 = self.x.shape
        self.x = torch.reshape(self.x,(1, shape1 * shape2,shape3))
        i = 0
        for blk in self.blks:
            print(f'This is {i} global attention run')
            self.x = blk(self.x)
            i += 1
        
        self.x = torch.squeeze(self.x)
        print(self.x.shape)
        self.x = self.x[[0, 1 * shape2, 2 * shape2, 3* shape2],:]
        x = torch.reshape(self.x,(1,self.x.shape[0]*self.x.shape[1]))
        print(x.shape)
        x = self.class_head.forward(x)
        return x

In [59]:
class classification_head(nn.Module):
    def __init__(self, input_layer = 1024, hidden_output_class = 512, dropout = 0.5):
        super(classification_head, self).__init__()
        self.ln1 = nn.LayerNorm(input_layer)
        self.fnn1 = nn.Linear(input_layer,hidden_output_class)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(hidden_output_class)
        self.fnn2 = nn.Linear(hidden_output_class,5)
    def forward(self, data):
        x = self.ln1(data)
        x = self.fnn1(x)
        x = self.dropout(x)
        x = self.ln2(x)
        x = self.fnn2(x)
        
        return x

In [74]:
st = time.time()
x = global_transformer(x_con = 3500)
output = x.forward(tensor)
et = time.time()
diff = et - st
print(f'The run time for one patient (4 images) is: {diff} seconds')
print(f'The shape of output is: {output.shape}')

torch.Size([4, 50, 256])
This is 0 local attention run
This is 1 local attention run
This is 2 local attention run
This is 3 local attention run
This is 4 local attention run
This is 5 local attention run
This is 6 local attention run
This is 7 local attention run
This is 8 local attention run
This is 9 local attention run
This is 0 global attention run
This is 1 global attention run
This is 2 global attention run
This is 3 global attention run
This is 4 global attention run
This is 5 global attention run
This is 6 global attention run
This is 7 global attention run
This is 8 global attention run
This is 9 global attention run
torch.Size([200, 256])
torch.Size([1, 1024])
The run time for one patient (4 images) is: 8.766899108886719 seconds
The shape of output is: torch.Size([1, 5])


In [17]:
output[1,:,:]

tensor([[-0.2354,  2.3708,  0.8859,  ...,  5.0032,  1.1088, -0.4876],
        [ 0.9578,  3.5941,  0.2396,  ...,  0.3077, -1.2816,  0.9832],
        [-0.3766,  1.4778,  0.1023,  ...,  0.1568, -0.4792,  0.6722],
        ...,
        [ 0.7433,  1.7760, -0.2126,  ...,  1.7912,  0.4902,  1.4133],
        [ 0.0743,  0.9737, -0.0658,  ..., -0.2920, -0.5391,  0.6376],
        [-0.3415,  3.2997,  1.0004,  ...,  0.9011,  0.3884,  2.5579]],
       grad_fn=<SliceBackward0>)

In [16]:
torch.reshape(output,(200,256))[50:100,:]

tensor([[-0.2354,  2.3708,  0.8859,  ...,  5.0032,  1.1088, -0.4876],
        [ 0.9578,  3.5941,  0.2396,  ...,  0.3077, -1.2816,  0.9832],
        [-0.3766,  1.4778,  0.1023,  ...,  0.1568, -0.4792,  0.6722],
        ...,
        [ 0.7433,  1.7760, -0.2126,  ...,  1.7912,  0.4902,  1.4133],
        [ 0.0743,  0.9737, -0.0658,  ..., -0.2920, -0.5391,  0.6376],
        [-0.3415,  3.2997,  1.0004,  ...,  0.9011,  0.3884,  2.5579]],
       grad_fn=<SliceBackward0>)