In [1]:
#Input arrays are (3500, 2800)
import torch
from torch import nn
import numpy as np
import pandas as pd
import os
import pydicom as dicom
import math

In [70]:
#Positional encoding class
#May want to scrap this for a learnable positional encoding model as opposed to sinusoidal 
class Positional_Encoding(nn.Module):
    def __init__(self,data,dropout=0.1,n = 10000):
        super(Positional_Encoding, self).__init__()
        self.embedded_dim, self.position = data.shape
        self.dropout = nn.Dropout(p=dropout)
        
        self.positional_matrix = torch.zeros(self.embedded_dim,self.position)

        for pos in range(self.position):
            for i in range(int(self.embedded_dim/2)):
                denom = pow(n, 2*i/self.embedded_dim) 
                self.positional_matrix[2*i,pos] = np.sin(pos/denom)
                self.positional_matrix[2*i+1,pos] = np.cos(pos/denom)     
    
    def forward(self,data):
#         print(f'Data shape: {data.shape}')
#         print(f'positional_matrix shape: {self.positional_matrix.shape}')
        self.summer_matrix = data + self.positional_matrix
        self.summer_matrix = self.dropout(self.summer_matrix)

        return self.summer_matrix, self.positional_matrix

In [71]:
#Apply conv layer to a (500, 400) subset of each scan 
#TODO max pool is necessary 
class convlayer(nn.Module):
    def __init__(self, num_patch: int = 49):
        super(convlayer, self).__init__()
        n = num_patch
#         self.conv2d_1 = nn.Conv2d(in_channels = 1, out_channels = 8, kernel_size = 13, stride = 1)
        self.conv2d_1 = nn.Conv2d(in_channels = n*1, out_channels = n *8, kernel_size = 13, stride = 1, groups = n)
        
        self.pooling2d_1 = nn.MaxPool2d(2)
        
        self.conv2d_2 = nn.Conv2d(in_channels = n*8, out_channels = n*16, kernel_size = 11, stride = 1, groups = n)

        self.pooling2d_2 = nn.MaxPool2d(2)
        
#         self.conv2d_3 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 9, stride = 1, groups = n)
        self.conv2d_3 = nn.Conv2d(in_channels = n*16, out_channels = n*32, kernel_size = 9, stride = 1, groups = n)

        
#         self.conv2d_4 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 7, stride = 1, groups = n)
        self.conv2d_4 = nn.Conv2d(in_channels = n*32, out_channels = n*32, kernel_size = 7, stride = 1, groups = n)

        self.pooling2d_3 = nn.MaxPool2d(2)
        
#         self.conv2d_5 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 5, stride = 1, groups = n)
        self.conv2d_5 = nn.Conv2d(in_channels = n*32, out_channels = n*64, kernel_size = 5, stride = 1, groups = n)


        
        self.dnn = nn.Linear(105280,256)
        
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        
        
    def forward(self, tensor):
        x = self.conv2d_1(tensor)
        x = self.relu(x)
        x = self.pooling2d_1(x)
        
        x = self.conv2d_2(x)
        x = self.relu(x)
        x = self.pooling2d_2(x)
        
        x = self.conv2d_3(x)
        x = self.relu(x)
        x = self.conv2d_4(x)
        x = self.relu(x)
        
        x = self.pooling2d_3(x)
        
        x = self.conv2d_5(x)
        x = self.relu(x)
        #print(x.shape)
        x = self.flatten(x)
        x = torch.reshape(x,(105280,1))
        x = self.dnn(x.T)
        
        return x

In [None]:
# init for 
#  

def forward():
    

In [4]:
sample = torch.zeros(1,500,400)
print(sample.shape)
model = convlayer()
x = model.forward(sample)

torch.Size([1, 500, 400])


In [5]:
#Test with one image 
ds = dicom.dcmread('2ddfad7286c2b016931ceccd1e2c7bbc copy.dicom')
info = ds.pixel_array[:3500,:]
list_of_patches = []
patches_matrix = torch.zeros(49,500,400)
for i in range(7):
    x_axis = [i*500,(i+1)*500]
    for j in range(7):
        y_axis = [j*400,(j+1)*400]
        correct_patch = np.ascontiguousarray(info[x_axis[0]:x_axis[1],y_axis[0]:y_axis[1]],dtype = np.float32)
        tensor = torch.from_numpy(correct_patch)
#         list_of_patches.append(tensor)
        loc = 7*i + j
        patches_matrix[loc,:,:] = tensor
# print(len(list_of_patches))
# print(patches_matrix.shape[0])
# print(list_of_patches[1])
# print(patches_matrix[1,:,:])
embedded_patches = []
embedded_matrix = torch.zeros(256,49)
for i in range(patches_matrix.shape[0]):
    patch = patches_matrix[i,:,:]
#     print(patch.shape)
    model = convlayer()
    patch = patch[None,:,:]
    #print(patch.shape)
    x = model.forward(patch)
#     embedded_patches.append(x.T)
#     print(x)
    embedded_matrix[:,i] = x
# Learnable embedding
learned_embedding_vec = nn.Parameter(torch.zeros(256,1))
#nn.Parameter adds it to model paramter so that we can backprop through it 
embedded_matrix = torch.hstack((learned_embedding_vec,embedded_matrix))

#     print(embedded_matrix)

In [6]:
positional = Positional_Encoding(embedded_matrix)
summer, pos = positional.forward(embedded_matrix)

In [81]:
#Full Embedding class that takes in data name for individual image and outputs positional embedding where each column
#vector represents a positionally-embedded patch except for the very first column vector, which is a learnt 
#classification token 
#TODO reimplement vectorized patching via pytorch link 
class embedding_block(nn.Module):
    #Data in this sense is the image that has not been translated into an array
    #Want to set x_con to 3500
    def __init__(self, x_amount = 7, y_amount = 7, x_con = 3518, y_con = 2800):
        super(embedding_block, self).__init__()
        
        assert(x_con % x_amount == 0)
        assert(y_con % y_amount == 0)
        self.x_amount = x_amount
        self.y_amount = y_amount
        self.x_con = x_con
        self.y_con = y_con
        
        
        self.amount_of_patches = int(x_amount * y_amount)
        self.x_ran = int(x_con / x_amount)
        self.y_ran = int(y_con / y_amount)
        self.patches_matrix = torch.zeros(self.amount_of_patches,self.x_ran,self.y_ran)
        #print(self.patches_matrix.shape)
        
        #print('here')
    
    def forward(self, data):
        ds = dicom.dcmread(data)
        self.info = ds.pixel_array[:self.x_con,:self.y_con]
        for i in range(self.x_amount):
            x_axis = [i*self.x_ran,(i+1)*self.x_ran]
            for j in range(self.y_amount):
                y_axis = [j*self.y_ran,(j+1)*self.y_ran]
                correct_patch = np.ascontiguousarray(self.info[x_axis[0]:x_axis[1],y_axis[0]:y_axis[1]],
                                                     dtype = np.float32)
                tensor = torch.from_numpy(correct_patch)
                loc = self.x_amount*i + j
                self.patches_matrix[loc,:,:] = tensor
        
        
#         self.embedded_matrix = torch.zeros(256,self.amount_of_patches)
        self.embedded_matrix = torch.zeros(self.amount_of_patches,256)
        for i in range(patches_matrix.shape[0]):
            self.patch = self.patches_matrix[i,:,:]
            self.model = convlayer()
            self.patch = self.patch[None,:,:]
            self.x = self.model.forward(patch)
            self.embedded_matrix[i,:] = self.x
            # Dealing with (RCC, LCC)  and (RMLO, LMLO)
            # init 2 conv layers that: 
            #      - torch.Size([B, 49, 500, 400]) apply some function that applies 49 independent conv 
            #      layers to dim 1 i.e. clever use of reshape/channels/batch_size/groups
            #      - torch.Size([B, 49, 256])
            # 
            # init -> conv_1
            # init -> conv_2 
            # 
            # rcc_after_conv = conv_1(RCC)
            # lcc_after_conv = conv_1(mirror(LCC))
            #      # or do RCC cat mirror(LCC) ==> torch.Size([B*2, 49, 256])
            # 
            # 
            # conv_2(RMLO)
            # conv_2(mirror(LMLO))

            
#         self.learned_embedding_vec = nn.Parameter(torch.zeros(256,1))
        self.learned_embedding_vec = nn.Parameter(torch.zeros(1,256))
        #nn.Parameter adds it to model paramter so that we can backprop through it 
#         self.embedded_matrix = torch.v=hstack((self.learned_embedding_vec,self.embedded_matrix))
        self.embedded_matrix = torch.vstack((self.learned_embedding_vec,self.embedded_matrix))
#         print(f'self.embedded_matrix shape: {self.embedded_matrix.shape}')
        self.positional = Positional_Encoding(self.embedded_matrix)
#         print(f'positional_matrix shpae: {self.positional.positional_matrix.shape}')
        self.summer, self.pos = self.positional.forward(self.embedded_matrix)
        
#         print(f'self.summer shape: {self.summer.shape}')
        
        return self.summer

In [80]:
x = embedding_block(x_con = 3500)
x.forward('2ddfad7286c2b016931ceccd1e2c7bbc copy.dicom')

tensor([[ 0.0000,  0.9350,  1.0103,  ...,  0.0000,  0.5022, -0.5627],
        [ 1.1080,  0.5853, -0.4702,  ..., -0.1034, -0.9863, -0.9570],
        [-0.0000,  0.7062,  0.0000,  ..., -0.8839, -0.0000,  0.5322],
        ...,
        [ 1.1170,  1.1192,  1.1094,  ...,  0.0000,  1.0976,  0.0000],
        [ 0.0216, -0.0076, -0.0090,  ...,  0.0319,  0.0447,  0.0464],
        [ 1.1118,  1.0984,  1.0978,  ...,  1.1012,  1.1039,  1.1173]],
       grad_fn=<MulBackward0>)

In [82]:
class local_mlp(nn.Module):
    def __init__(self, hidden_output = 1024, dropout = .5):
        super(local_mlp, self).__init__()
        self.fnn1 = nn.Linear(256,hidden_output)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.fnn2 = nn.Linear(hidden_output,256)
        self.dropout2 = nn.Dropout(dropout)
    def forward(self, data):
        self.x = self.fnn1(data)
        self.x = self.gelu(self.x)
        self.x = self.dropout1(self.x)
        self.x = self.fnn2(self.x)
        self.x = self.gelu(self.x)
        self.x = self.dropout2(self.x)
        
        return self.x

In [83]:
class local_encoder_block(nn.Module):
    def __init__(self, data_shape = (50,256), hidden_output_fnn1 = 1024, dropout = .5):
        super(local_encoder_block, self).__init__()
        self.data_shape = data_shape
        self.ln1 = nn.LayerNorm(data_shape)
        self.ln2 = nn.LayerNorm(data_shape)
        
        self.attention = nn.MultiheadAttention(256,16)
        self.mlp = local_mlp(hidden_output = hidden_output_fnn1, dropout = dropout)
        
    def forward(self, data):
        #i = 0
        #print(i)
        if data.shape == (256,50):
            print('in here')
            data = data.T
        self.x = self.ln1(data)
        self.att_out, self.att_out_weights = self.attention(query = self.x,key = self.x,value = self.x)
        #print(self.att_out.shape)
        self.x_tilda = self.att_out + data
        #print(self.x_tilda.shape)
        self.x_second = self.ln2(self.x_tilda)
        self.x_second = self.mlp.forward(self.x_second)
        #print(self.x_second.shape)
        self.x_second = self.x_second + self.x_tilda
        
        #i += 1
        return self.x_second  

In [85]:
class Visual_Transformer(nn.Module):
    #embedding parameters, local encoder parameters
    def __init__(self, x_amount = 7, y_amount = 7, x_con = 3518, y_con = 2800,
              data_shape = (50,256), hidden_output_fnn = 1024, dropout = .5,
              number_of_layers = 10):
        super(Visual_Transformer, self).__init__()
        self.embedding_block = embedding_block(x_amount = x_amount, y_amount = y_amount, x_con = x_con, y_con = y_con)
        self.blks = nn.Sequential()
        for i in range(number_of_layers):
            self.blks.add_module(f'{i}', local_encoder_block(data_shape = data_shape))
            
    def forward(self,data):
        self.x = self.embedding_block(data)
        print(self.x.shape)
        i = 0 
        for blk in self.blks:
            print(f'This is {i} run')
            i += 1
            self.x = blk(self.x) 
        return self.x

In [None]:
x = Visual_Transformer(x_con = 3500)
output = x.forward('2ddfad7286c2b016931ceccd1e2c7bbc copy.dicom')

In [41]:
print(output.T[0])

tensor([-3.4374e-01,  7.1017e+00,  1.9626e+01, -2.8022e-01, -1.1748e+00,
        -1.4898e+00, -5.6181e-01, -5.4902e+00, -8.9077e+00, -1.0441e+01,
        -8.9423e-01, -1.5933e+00, -9.3679e-01, -5.5882e-04,  8.7317e-01,
        -4.4071e+00,  2.6351e+00, -8.5413e-01, -1.1503e+00, -1.7661e-01,
         7.9309e-01,  6.6905e-01,  2.7895e+01,  9.5001e+00, -1.2156e+00,
        -4.0845e-01,  5.4589e-01,  8.1266e-01,  7.0526e-02, -3.0588e+00,
         1.1913e+01, -1.2253e+00, -2.4221e-01,  9.2028e-01,  3.7052e-01,
        -8.5558e-01, -1.3633e+00, -1.1262e+00,  1.3749e-01,  7.8645e-01,
        -2.8929e-01, -4.6127e-01, -1.8099e-01, -1.0848e+00, -1.4540e-01,
         7.1564e-01,  6.0247e-01, -2.2884e-01, -1.3012e+00, -1.3794e+00],
       grad_fn=<SelectBackward0>)
