In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import pandas as pd
import torch as torch
import os
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn

In [2]:
DATA_DIRECTORY = "../datos/"

In [3]:
diagnosticos  = pd.read_excel(DATA_DIRECTORY+"RESUMEN TAC CEREBRALES.xlsx")

In [4]:
# diccionario cuya llave es el id de paciente y el valor una lista 
# donde cada elemento de la lista es la matriz de una i
diccionario_imagenes_pacientes = dict()

for paciente in diagnosticos.paciente:
    directorio_paciente = DATA_DIRECTORY+"paciente_"+str(paciente)
    archivos_paciente = os.listdir(directorio_paciente)
    
    lista_imagenes_paciente = []
    for archivo in archivos_paciente:
        if archivo.endswith(".jpg"):
            imagen = mpimg.imread(directorio_paciente+"/"+archivo)
            lista_imagenes_paciente.append(imagen)
            
    diccionario_imagenes_pacientes[paciente] = lista_imagenes_paciente
    

## Modelos y arquitecturas
### Arquitecturas experimental  DNC
* Alimentamos al modelo imagen por imagen y se presenta un solo diagnostico por paciente
* El controller de la DNC esta compuesto por una convnet

In [5]:
CONTROLLER_OUTPUT_SIZE = 1024

In [6]:
#TODO: cambiar valores quemados por valores parametrizados y calculos dependientes
class ConvController(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1,4,kernel_size=3,stride=1)
        self.fc1  =  torch.nn.Linear(262144,CONTROLLER_OUTPUT_SIZE)
        
    def forward(self,x):
        h = self.conv1(x)
        
        #flatten
        h =  x.view(-1,x.shape[1]*x.shape[2]*x.shape[3])
        h =  self.fc1(h)
        
        return h #h_t in my txt

In [7]:
#TODO: cambiar valores quemados por valores parametrizados y calculos dependientes
#TODO: cordar por que en algun momento le puse bias = False a los pesos del vector de salida de la DNC
class DNC(torch.nn.Module):
    
    def __init__(self,controller,memory_size = (10,10),read_heads = 1):
        super().__init__()
        self.controller = controller
        self.N = memory_size[0] # number of memory locations
        self.W = memory_size[1] # word size of the memory 
        self.R = read_heads # number of read heads
        self.WS = 1 #not in the paper(they use 1), but used as a parametrizable number of write heads for further experiments
        self.interface_vector_size = (self.W*self.R) + (self.W*self.WS) + (2*self.W) + (5*self.R) + 3
        
        # inicialization st to random just for testing, remember to put on zeros
        self.memory_matrix = self.memory_matrix =  nn.Parameter(torch.randn(size=memory_size),requires_grad= False) 
        
        #1024 es el tamaño del vector de salida del controlador, 1 es el tamaño de salida de la dnc
        self.output_vector_linear = torch.nn.Linear(CONTROLLER_OUTPUT_SIZE,1,bias=False) #W_y 
        self.interface_vector_linear = torch.nn.Linear(CONTROLLER_OUTPUT_SIZE,self.interface_vector_size,bias=False) #W_ξ
        self.read_vectors_to_output_linear = torch.nn.Linear(self.R*self.W,1,bias = False) #W_r in my txt
        
        self.read_keys = torch.Tensor(size=(self.R,self.W)).requires_grad_(False) # k_r in my txt
        self.read_strenghts = torch.Tensor(size=(self.R,1)).requires_grad_(False) #β_r
        
        self.read_weighting = torch.Tensor(torch.zeros(size=(self.R,self.N))).requires_grad_(False) #r_w
        
        self.write_key = torch.Tensor(size=(1,self.W)).requires_grad_(False) # k_w in my txt
        self.write_strenght = torch.Tensor(size=(1,1)).requires_grad_(False) # β_w
        
        self.write_weighting = torch.Tensor(torch.zeros(size=(1,self.N))).requires_grad_(False) # w_w
        
        self.usage_vector = torch.Tensor(torch.zeros(size=(1,self.N))).requires_grad_(False) #u_t
        
        self.memory_matrix_ones = torch.Tensor(torch.ones(size=memory_size)).requires_grad_(False) #E on paper
        
    def forward(self,x):
        h_t = self.controller(x) #controller output called ht in the paper
        
        output_vector = self.output_vector_linear(h_t) # called Vt in the paper(υ=Wy[h1;...;hL]) v_o_t in my txt
        interface_vector = self.interface_vector_linear(h_t).data #called ξt(ksi) in the paper ,ξ_t in my txt
        
        self.read_keys.data = interface_vector[0,0:self.R*self.W].view((self.R,self.W)) #k_r in my txt
        
        #clamp temporary added because the exp was returning inf  values
        read_strenghts =  torch.clamp( interface_vector[0,self.R*self.W:self.R*self.W+self.R].view((self.R,1)),max=85)
        self.read_strenghts.data = self.oneplus(read_strenghts) #β_r
        
        self.write_key.data = interface_vector[0,self.R*self.W+self.R:self.R*self.W+self.R+self.W].view((1,self.W)) # k_w
        
        write_strenght = torch.clamp(interface_vector[:,self.R*self.W+self.R+self.W:self.R*self.W+self.R+self.W + 1].view((1,1)),max=85)
        self.write_strenght.data = self.oneplus(write_strenght) #β_w
        
        erase_vector = interface_vector[0,self.R*self.W+self.R+self.W + 1: self.R*self.W+self.R+self.W + 1 + self.W].view((1,self.W))
        erase_vector = torch.sigmoid(erase_vector) #e_t
        
        write_vector = interface_vector[0,self.R*self.W+self.R+self.W + 1 + self.W:self.R*self.W+self.R+self.W + 1 + 2*self.W].view((1,self.W)) #v_t
        
        free_gates  =  interface_vector[0,self.R*self.W+self.R+self.W + 1 + 2*self.W:self.R*self.W+2*self.R+self.W + 1 + 2*self.W].view((self.R,1)) #f_t
        free_gates =   torch.sigmoid(free_gates)
        
        # Escritura
        # TODO: verificar y/o experimentar si el ordern es :primero escribir y luego leer de la memoria(asi parece en el pazper)
        print("free gates",free_gates,free_gates.size())
        print("read wei",self.read_weighting,self.read_weighting.size())
        retention_vector = (1.0 - free_gates * self.read_weighting).prod(dim=0)
        print("usage",self.usage_vector)
        print("write weit",self.write_weighting)
        print("retention",retention_vector)
        self.usage_vector.data = (self.usage_vector +self.write_weighting - (self.usage_vector *self.write_weighting))*retention_vector #u_t
        _,free_list = torch.topk(-self.usage_vector,self.N,dim=1) #φt indices of memory locations ordered by usage
        print("usage post",self.usage_vector)
        print("free list",free_list)
        allocation_weighting = tf.zeros_like(self.usage_vector)
        se
        allocation_weighting = (torch.gather(self.usage_vector,1,free_list))
        print("allo",allocation_weighting)
        print("-------------------------------------------------")
        write_content_weighting = self.content_lookup(self.memory_matrix,self.write_key,self.write_strenght)

        self.write_weighting.data = write_content_weighting #TODO: this should be a combination of temporal and dymamic allocaiton
        
        new_memory_matrix = self.memory_matrix*(self.memory_matrix_ones - torch.matmul(self.write_weighting.t(),erase_vector)) + torch.matmul(self.write_weighting.t(),write_vector)
        
        self.memory_matrix.data = new_memory_matrix
        
        # read by content weithing(attention by similarity)
        read_content_weighting = self.content_lookup(self.memory_matrix,self.read_keys,self.read_strenghts)
        
        #read weithing is a combination of reading modes,TODO:add temporal attention not just by similarity
        self.read_weighting.data = read_content_weighting
        
        read_vectors = torch.matmul(self.read_weighting,self.memory_matrix).view((1,self.R*self.W)) #r in my txt
        read_heads_to_output = self.read_vectors_to_output_linear(read_vectors) #v_r_t in my t xt
        
        #TODO: experiment and decide if maintain sigmoid
        y_t = torch.sigmoid(output_vector + read_heads_to_output)
        return y_t
    
    def oneplus(self,x):
        # apply oneplus operation to a tensor to constrain it's elements to [1,inf)
        #TODO: check numerical statiliby as exp is returning inf for numbers like 710,emporary added clamp to 85
        return torch.log(1+torch.exp(x)) + 1
    
    def content_lookup(self,matrix,keys,strengths):
        # returns a probability distribution over the memory locations 
        # with higher probability to memory locations with bigger similarity to the keys
        # bigger strenght make more aggresive distributions ,for example a distribution (0.2,0.3,0.5) with
        # bigger strenght becomes (0.1,0.12,0.78)
        # returns tensor of shape (read keys,memory size) = (R,N)
        keys_norm =  torch.sqrt(torch.sum(keys**2,dim=1).unsqueeze(dim=1))
        matrix_norm = torch.sqrt(torch.sum(matrix**2,dim=1))
        norms_multiplication = keys_norm*matrix_norm
        # calc cosine similarity between keys and memory locations(1e-6 is used avoiding div by 0)
        divide_zero_prevent_factor = torch.zeros_like(norms_multiplication,requires_grad=False).add_(1e-6)
        cosine_similarity = torch.matmul(keys,matrix.t())/(torch.max(norms_multiplication,divide_zero_prevent_factor))
        
        # do a "strenght" softmax to calculate the probability distribution
        numerator = torch.exp(cosine_similarity*strengths)
        denominator = numerator.sum(dim=1).unsqueeze(dim=1)

        distribution = numerator/denominator
        
        return distribution

## Experimentos
* Experimentando con DNC alimentando una imagen a la vez en orden aleatorio con pacientes también en orden aleatorio

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
EPOCHS = 1

In [10]:
conv_controller = ConvController()
dnc_model = DNC(controller=conv_controller,memory_size = (4,4),read_heads=2).to(device)

In [11]:
def loss_function(y,y_hat,last_flag):
    #print(y,y_hat,last_flag)
    base_criterion = torch.nn.BCELoss()
    return torch.full_like(y,last_flag) * base_criterion(y,y_hat)

In [12]:
criterion = loss_function
optimizer = optim.Adam(dnc_model.parameters(),lr=0.001)

In [13]:
for epoch in range(EPOCHS):
    # en cada epoch procesar los pacientes en orden aleatorio
    pacientes = np.random.choice(np.array(diagnosticos.paciente),size= len(diagnosticos.paciente),replace=False)
    
    conteo_pacientes = 0
    for paciente in pacientes:
        #TODO: remover esta validacion, solo puesta para probar una unica iteracion en compu lenta
        if conteo_pacientes >= 1:
            break
            
        imagenes_paciente = diccionario_imagenes_pacientes.get(paciente)
        diagnostico_hemorragia_paciente = np.array(float(diagnosticos[diagnosticos.paciente==paciente].hemorragia))
        tensor_diagnostico_hemorragia_paciente = torch.Tensor(diagnostico_hemorragia_paciente).to(device)
        
        indices_imagenes_pacientes = np.arange(0,len(imagenes_paciente)-1,step=1)
        indices_aleatorios_imagenes = np.random.choice(indices_imagenes_pacientes,len(indices_imagenes_pacientes),replace=False)
        
        for indice in indices_aleatorios_imagenes:
            last_image =  int(indice  == indices_aleatorios_imagenes[-1])
            
            optimizer.zero_grad()
            
            imagen_paciente = imagenes_paciente[indice]
            
            if imagen_paciente.shape != (512,512):
                #TODO: tread different image sizes with reshaping, resizing(or other ideas)
                continue
                
            tensor_imagen_paciente =  torch.unsqueeze(
                torch.unsqueeze( torch.Tensor(imagen_paciente),dim=0),dim=1).to(device)
            
            print("Alimentando paciente {} e imagen {} al modelo".format(paciente,indice),imagen_paciente.shape)
            
            diagnostico_hemorragia_aproximado = dnc_model(tensor_imagen_paciente)
            
            loss = criterion(diagnostico_hemorragia_aproximado,tensor_diagnostico_hemorragia_paciente,last_image)
            loss.backward()
            optimizer.step()
            
            if last_image:
                print(loss.cpu().data,diagnostico_hemorragia_aproximado.cpu().data)
                
            conteo_pacientes += 1

Alimentando paciente 5 e imagen 4 al modelo (512, 512)
free gates tensor([[5.5404e-16],
        [1.4727e-27]]) torch.Size([2, 1])
read wei tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]]) torch.Size([2, 4])
usage tensor([[0., 0., 0., 0.]])
write weit tensor([[0., 0., 0., 0.]])
retention tensor([1., 1., 1., 1.])
usage post tensor([[0., 0., 0., 0.]])
free list tensor([[0, 1, 2, 3]])
allo tensor([[0., 0., 0., 0.]])
-------------------------------------------------


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


Alimentando paciente 5 e imagen 8 al modelo (512, 512)
free gates tensor([[5.1615e-18],
        [2.2108e-26]]) torch.Size([2, 1])
read wei tensor([[6.5591e-01, 2.6422e-01, 3.4248e-02, 4.5619e-02],
        [9.9996e-01, 1.7717e-07, 1.5571e-05, 1.9924e-05]]) torch.Size([2, 4])
usage tensor([[0., 0., 0., 0.]])
write weit tensor([[0.1137, 0.1354, 0.3678, 0.3831]])
retention tensor([1., 1., 1., 1.])
usage post tensor([[0.1137, 0.1354, 0.3678, 0.3831]])
free list tensor([[0, 1, 2, 3]])
allo tensor([[0.1137, 0.1354, 0.3678, 0.3831]])
-------------------------------------------------
Alimentando paciente 5 e imagen 6 al modelo (512, 512)
free gates tensor([[7.7277e-01],
        [2.5757e-18]]) torch.Size([2, 1])
read wei tensor([[3.5656e-01, 2.6856e-01, 1.7396e-01, 2.0092e-01],
        [1.2680e-12, 1.6513e-15, 9.9998e-01, 1.9845e-05]]) torch.Size([2, 4])
usage tensor([[0.1137, 0.1354, 0.3678, 0.3831]])
write weit tensor([[1.6026e-06, 1.5968e-06, 7.8865e-01, 2.1134e-01]])
retention tensor([0.7245

In [14]:
a = torch.Tensor([[1],
        [2]]) 
b = torch.Tensor([[6,7,8,9],
        [2,3,4,5]]) 

print(b.size())

(a*b).prod(dim=0)

torch.Size([2, 4])


tensor([24., 42., 64., 90.])

In [15]:
#TODO: averiguar por que salen 6 tensores de parametros si solo se han declarado 3(al momento de correr lap rueba)
train_parmams = list(dnc_model.named_parameters())

for train_param in train_parmams:
    print(train_param[0])

memory_matrix
controller.conv1.weight
controller.conv1.bias
controller.fc1.weight
controller.fc1.bias
output_vector_linear.weight
interface_vector_linear.weight
read_vectors_to_output_linear.weight


In [16]:
dnc_model.write_strenght.data


tensor([[67.8687]])

Meta (por detallar)
* Calcular memory retention vector(con los free gates)
* L temporal link matrix
* u_t usage vector