In [31]:
import torch
import torch.nn as nn
import numpy as np

ATTS = []
DX = []

class self_attention(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim,q_dim,k_dim,v_dim,v_mapping,activation='relu'):
        super(self_attention,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.v_mapping = v_mapping

        self.q_dim = q_dim
        self.k_dim = k_dim
        self.v_dim = v_dim

        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = q_dim , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = k_dim , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = v_dim , kernel_size= 1)
        self.o_proj = nn.Conv2d(in_channels = v_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) 
        
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)        
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        if self.v_mapping == True:
            proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N
        else:
            proj_value = x.view(m_batchsize,-1,width*height) # B X C X N


        out = torch.bmm(proj_value,attention.permute(0,2,1))
        
        if self.v_mapping == True:
            out = out.view(m_batchsize,self.v_dim,width,height)
            out = self.o_proj(out)
        else:
            out = out.view(m_batchsize,C,width,height)

        DX.append(self.gamma*out)
        out = self.gamma*out + x
        return out,attention


class net_one_neuron_sa_central(nn.Module):
    def __init__(self,q_dim,k_dim,v_dim,v_mapping):
        super().__init__()
        self.layers_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=30, kernel_size=(5, 5), stride=(1, 1)),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(30),
            nn.Sigmoid(),
            nn.Dropout2d(0.3),
            nn.Conv2d(in_channels=30, out_channels=30, kernel_size=(5, 5), stride=(1, 1)),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(30),
            nn.Sigmoid(),
            nn.Dropout2d(0.3)
         ) #[N,30,9,9] 
        
        self.attention = self_attention(in_dim=30,q_dim=q_dim,k_dim=k_dim,v_dim=v_dim,v_mapping=v_mapping)

        self.layers_sa_reg = nn.Sequential(
            nn.BatchNorm2d(30),
            nn.Sigmoid(),
            nn.Dropout2d(0.3)
        )

        self.layers_2 = nn.Sequential(
            nn.Conv2d(in_channels=30, out_channels=30, kernel_size=(3, 3), stride=(1, 1)),
            nn.BatchNorm2d(30),
            nn.Sigmoid(),
            nn.Dropout2d(0.3), #or here
            nn.Conv2d(in_channels=30, out_channels=30, kernel_size=(3, 3), stride=(1, 1)),
            nn.BatchNorm2d(30),
            nn.Sigmoid(),
        )
        self.Linear = nn.Linear(30, 1)

    def forward(self, x):
        x = self.layers_1(x)
        x,att = self.attention(x) 
        ATTS.append(att)
        x = self.layers_sa_reg(x)
        x = self.layers_2(x)
        x = x.reshape(-1,30,25)
        x = x[:,:,12]
        x = self.Linear(x)
        return x


class seperate_core_model_sa_central(nn.Module):
    def __init__(self,num_neurons,q_dim,k_dim,v_dim,v_mapping):
        super().__init__()
        self.models = nn.ModuleList([net_one_neuron_sa_central(q_dim=q_dim,k_dim=k_dim,v_dim=v_dim,v_mapping=v_mapping) for i in range(num_neurons)])
        self.num_neurons = num_neurons

    def forward(self, x):
        outputs = [self.models[i].forward(x) for i in range(self.num_neurons)]
        outputs = torch.stack(outputs, dim=1)
        return outputs.reshape((outputs.shape[0], outputs.shape[1]))


In [32]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from torch.utils.data import Dataset,DataLoader
from utils import *

num_neurons = 10
net = seperate_core_model_sa_central(num_neurons=10,q_dim=30,k_dim=30,v_dim=30,v_mapping=True)
net.load_state_dict(torch.load("./sa_cnn_central_q_k_v_vm_30_30_30_True/model.pth"))

vimg = np.load('../all_sites_data_prepared/pics_data/val_img_m1s1.npy')
vresp = np.load('../all_sites_data_prepared/New_response_data/valRsp_m1s1.npy')
vimg = np.reshape(vimg,(1000,1,50,50))

In [33]:
pc_firstx_neuron(net,10,vimg,vresp)

(0.5331493489960433,
 [0.41079651791595057,
  0.3265673600632211,
  0.673245414508159,
  0.40098287141458533,
  0.4550501783307829,
  0.856060944056603,
  0.6136948878025793,
  0.5238667172533987,
  0.5031307727728765,
  0.568097825842277])

In [34]:
ATTS

[tensor([[[1.7741e-02, 2.7367e-03, 6.7806e-03,  ..., 2.4446e-04,
           2.0136e-03, 4.7802e-03],
          [1.2411e-02, 1.3878e-03, 6.4973e-03,  ..., 7.7408e-05,
           1.1413e-03, 2.9926e-03],
          [1.4917e-02, 1.0221e-03, 8.3252e-03,  ..., 5.9485e-05,
           1.5380e-03, 3.8626e-03],
          ...,
          [2.0576e-03, 3.7853e-04, 1.8651e-03,  ..., 1.7190e-05,
           7.0476e-05, 3.7876e-04],
          [9.0118e-03, 1.3719e-03, 4.3224e-03,  ..., 7.4001e-05,
           5.2315e-04, 1.7786e-03],
          [1.1845e-02, 1.2356e-03, 3.0477e-03,  ..., 5.4386e-05,
           5.5251e-04, 2.1978e-03]],
 
         [[2.4181e-03, 2.4643e-03, 1.2114e-03,  ..., 5.6786e-04,
           1.2412e-02, 5.3783e-02],
          [1.8320e-03, 1.9029e-03, 1.1459e-03,  ..., 3.2941e-04,
           8.3265e-03, 4.5266e-02],
          [1.2798e-03, 1.2638e-03, 7.9750e-04,  ..., 1.9951e-04,
           6.6282e-03, 4.5335e-02],
          ...,
          [3.9699e-04, 5.5583e-04, 3.9178e-04,  ..., 7.349

In [36]:
np.save("model2a_atts",ATTS)
np.save("model2a_dx",DX)

In [None]:
from sklearn import preprocessing as p
figure, axis = plt.subplots(2, 5, figsize=(30,10))
areas = []
for neuron in range(10):
    temp =  ATTS[neuron].flatten()
    min_max_scaler = p.MinMaxScaler()
    temp = min_max_scaler.fit_transform(temp.reshape(-1,1))
    temp = sorted(np.array(temp).flatten())
    xs = [i/(81**2) for i in range(81**2)]
    area = np.trapz(temp,dx=1/(81**2))
    areas.append(area)
    axis[neuron//5,neuron%5].plot(xs,temp)
    axis[neuron//5,neuron%5].set_title("1 - AUC =  " + str(1-area))
plt.show()
print(str(1 - (sum(areas) / len(areas))))

In [None]:
import matplotlib.pyplot as plt
import numpy as np

figure, axis = plt.subplots(5, 6, figsize=(21,14))

for neuron in range(5,6):
    for channel in range(30):
        axis[channel//6,channel%6].scatter(x=X[neuron][0][channel].flatten(),y=DX[neuron][0][channel].flatten(),s=5)
        axis[channel//6,channel%6].set_xlabel("x")
        axis[channel//6,channel%6].set_ylabel("dx")
    plt.show()




In [None]:
import matplotlib.pyplot as plt
import numpy as np

figure, axis = plt.subplots(10, 10, figsize=(100,100))

for neuron in range(100):
    im = axis[neuron//10,neuron%10].imshow(ATTS[5][neuron].detach().cpu().numpy(), cmap='hot', interpolation='nearest')
    axis[neuron//10,neuron%10].set_title("image number: " + str(neuron))
plt.show()

# for neuron in range(10):
#     a = ATTS[neuron][0].detach().cpu().numpy()
#     plt.imshow(a, cmap='hot', interpolation='nearest')
#     plt.colorbar()
#     plt.title("test 8 neuron number: " + str(neuron))
#     plt.show()





In [None]:
import matplotlib.pyplot as plt
import numpy as np

figure, axis = plt.subplots(2, 5, figsize=(30,10))

for neuron in range(10):
    im = axis[neuron//5,neuron%5].imshow(ATTS[neuron][0].detach().cpu().numpy(), cmap='hot', interpolation='nearest')
    axis[neuron//5,neuron%5].set_title("test 8 neuron number: " + str(neuron))
plt.show()

# for neuron in range(10):
#     a = ATTS[neuron][0].detach().cpu().numpy()
#     plt.imshow(a, cmap='hot', interpolation='nearest')
#     plt.colorbar()
#     plt.title("test 8 neuron number: " + str(neuron))
#     plt.show()



