In [1]:
import numpy as np
import pandas as pd
import torch.nn as nn
import torch

from pytorch_wavelets import DTCWTForward, DTCWTInverse
from sklearn.model_selection import train_test_split
import random

from efficientnet_pytorch import EfficientNet
import pickle

import matplotlib.pyplot as plt
from scipy import fftpack

import os
from pathlib import Path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision.transforms import Lambda

from PIL import Image
import torchvision.transforms as transforms

from torch.utils.data import DataLoader


In [2]:

class EfficientNet_b0(nn.Module):

    def __init__(self,no_of_outputs_classes_for_your_dataset=20):
        """initialisation of the neural network using as a segmentation network
        the efficientnet_pytorch
        https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
        and add a few layers.

        Args:
            no_of_outputs_classes_for_your_dataset (int): depends on what we want to predict. Defaults to 2.
        """

        super(EfficientNet_b0, self).__init__()
        self.model = EfficientNet.from_pretrained('efficientnet-b0')
        self.classifier_layer = nn.Sequential(
                nn.Linear(1280 , 512),
                nn.BatchNorm1d(512),
                nn.Dropout(0.2),
                nn.Linear(512 , 256),
                nn.Linear(256 , no_of_outputs_classes_for_your_dataset),

            )
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)#do we want this as optimizer?
        self.mse = nn.MSELoss()
        self.lin1 = nn.Linear(1280 , no_of_outputs_classes_for_your_dataset)

    def forward(self, inputs):
        """pass forward through the network to compute the neural result we want

        Args:
            inputs (_type_): the image presented to the mices

        Returns:
            _type_: the V1 neural response
        """
        input = inputs.float()
        input = self.model.extract_features(input)
        input = self.model._avg_pooling(input)
        input = input.flatten(start_dim=1)
        inputs = self.model._dropout(input)
        x = self.classifier_layer(inputs)
        #x = self.lin1(inputs)
        return x


    def forward_train(self, inputs):
        """pass forward through the network to compute the neural result we want

        Args:
            inputs (_type_): the image presented to the mices

        Returns:
            _type_: the V1 neural response
        """
        input = inputs.float()
        with torch.no_grad():
          input = self.model.extract_features(input)
          input = self.model._avg_pooling(input)
          input = input.flatten(start_dim=1)
          inputs = self.model._dropout(input)
        x = self.classifier_layer(inputs)
        #x = self.lin1(inputs)
        return x

    def backward(self, y_pred, y_exp):
        """the backward cahnge the weights of the network using the backpropagation
        algorithm and the chosen algorithm

        Args:
            y_pred (_type_): the result of the forward pass
            y_exp (_type_): the actual result we wish to have

        Returns:
            double: the loss of this backward pass (for this batch and epoch)
        """
        y_pred = y_pred.float()
        y_exp = y_exp.float()
        loss = self.mse(y_pred, y_exp)
        self.optimizer.zero_grad()
        loss.backward()
        with torch.no_grad():
            self.optimizer.step()
        return loss

def main():
    """runs the neural network
    still needs to figure out the data shape/format, set up the mini batch and the added layers
    """
    nb_epoch = 20
    mini_batch_size = 200
    mod = EfficientNet_b0()
    for epoch in range(nb_epoch):
        running_loss = 0.0
        for i in range(0, train_input.size(0), mini_batch_size):
            #still need to see the input shape
            mod.backward(mod.forward(train_input[i,i+mini_batch_size]),train_output[i,i+mini_batch_size])



In [3]:
def split(dat):
    index = random.sample(range(0,239),int(239*0.8))
    train= dat[dat['count'].isin(index)]
    test = dat[-dat['count'].isin(index)]
    output_train = torch.tensor(train.drop(['count'],axis = 1).values)
    output_test = torch.tensor(test.drop(['count'],axis = 1).values)
    return train, test, output_train, output_test

# all images

In [4]:
def Dataload(train, test):
    transform = Lambda(lambda y: torch.cat((y,y,y),0))
    target_transform_train = Lambda(lambda y : torch.tensor(train[train.index==y.tolist()].drop(['count'],axis = 1).values).squeeze(1))
    target_transform_test = Lambda(lambda y : torch.tensor(test[test.index==y.tolist()].drop(['count'],axis = 1).values).squeeze(1))
    img_dir = './../data/images/image_'
    img_label = train['count']
    class CustomImageDataset(Dataset):
        def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
            self.img_labels = annotations_file
            self.img_dir = img_dir
            self.transform = transform
            self.target_transform = target_transform

        def __len__(self):
            return len(self.img_labels)

        def __getitem__(self, idx):
            image = read_image(self.img_dir+ str(self.img_labels.iloc[idx, 0])+'.png')
            label = self.img_labels.iloc[idx, 1]
            if self.transform:
                image = self.transform(image)
            if self.target_transform:
                label = self.target_transform(label)
            return image, label
        
    train_dataloader = DataLoader(CustomImageDataset(train.reset_index()[['count','stimulus_presentation_id']],img_dir,transform,target_transform_train), batch_size=54, shuffle=True)
    test_dataloader = DataLoader(CustomImageDataset(test.dropna().reset_index()[['count','stimulus_presentation_id']],img_dir,transform,target_transform_test), batch_size=54, shuffle=True)
    return train_dataloader, test_dataloader

In [5]:
def training(train_dataloader):
    train_features, train_labels = next(iter(train_dataloader))
    Eff=EfficientNet_b0(train_labels.size(-1)).to(device)
    epochs = 3
    l = []
    Eff.train()
    for epoch in range(epochs):
        print(epoch)
        loss_epoch = []
        for b in range(200):
            true_input,true_output = next(iter(train_dataloader))
            true_input = true_input.to(device).float()
            true_output = true_output.squeeze().to(device)
            predicted_output = Eff.forward_train(true_input)
            loss= Eff.backward(predicted_output,true_output)
            loss_epoch.append(float(loss))
            if b%20 == 0:
                print(loss)
            del true_input, true_output, predicted_output
        l.append(loss_epoch)
    return Eff


# MEI

### functions

In [6]:
from scipy import ndimage

def process(x, mu=0.4, sigma=0.224):
    """ Normalize and move channel dim in front of height and width"""
    x = (x - mu) / sigma
    if isinstance(x, torch.Tensor):
        return x.transpose(-1, -2).transpose(-2, -3)
    else:
        return np.moveaxis(x, -1, -3)

def unprocess(x, mu=0.4, sigma=0.224):
    """Inverse of process()"""
    x = x * sigma + mu
    if isinstance(x, torch.Tensor):
        return x.transpose(-3, -2).transpose(-2, -1)
    else:
        return np.moveaxis(x, -3, -1)

def roll(tensor, shift, axis):
    if shift == 0:
        return tensor

    if axis < 0:
        axis += tensor.dim()

    dim_size = tensor.size(axis)
    after_start = dim_size - shift
    if shift < 0:
        after_start = -shift
        shift = dim_size - abs(shift)

    before = tensor.narrow(axis, 0, dim_size - shift)
    after = tensor.narrow(axis, after_start, shift)
    return torch.cat([after, before], axis)

def batch_std(batch, keepdim=False, unbiased=True):
    """ Compute std for a batch of images. """
    std = batch.view(len(batch), -1).std(-1, unbiased=unbiased)
    if keepdim:
        std = std.view(len(batch), 1, 1, 1)
    return std



def fft_smooth(grad, factor=1/4):
    """
    Tones down the gradient with 1/sqrt(f) filter in the Fourier domain.
    Equivalent to low-pass filtering in the spatial domain.
    """
    if factor == 0:
        return grad
    #h, w = grad.size()[-2:]
    # grad = tf.transpose(grad, [0, 3, 1, 2])
    # grad_fft = tf.fft2d(tf.cast(grad, tf.complex64))

    # grad = tf.transpose(grad, [0, 3, 1, 2])
    # grad_fft = tf.fft2d(tf.cast(grad, tf.complex64))

    pp = torch.fft.rfft(grad.data, n = int(224*factor), norm='ortho')
    return torch.fft.irfft(pp,n = 224)


def blur(img, sigma):
    if sigma > 0:
        for d in range(len(img)):
            img[d] = ndimage.filters.gaussian_filter(img[d], sigma, order=0)
    return img


def blur_in_place(tensor, sigma):
    blurred = np.stack([blur(im, sigma) for im in tensor.cpu().numpy()])
    tensor.copy_(torch.Tensor(blurred))

In [7]:

def make_step(net, src, neuron_id, best,step_size=1.5, sigma=None, precond=0, step_gain=1,
              blur=True, jitter=0, eps=1e-12, clip=True, bias=0.4, scale=0.224,
              train_norm=None, norm=None, add_loss=0, _eps=1e-12, max = 0):
    """ Update src in place making a gradient ascent step in the output of net.

    Arguments:
        net (nn.Module or function): A backpropagatable function/module that receives
            images in (B x C x H x W) form and outputs a scalar value per image.
        src (torch.Tensor): Batch of images to update (B x C x H x W).
        step_size (float): Step size to use for the update: (im_old += step_size * grad)
        sigma (float): Standard deviation for gaussian smoothing (if used, see blur).
        precond (float): Strength of gradient smoothing.
        step_gain (float): Scaling factor for the step size.
        blur (boolean): Whether to blur the image after the update.
        jitter (int): Randomly shift the image this number of pixels before forwarding
            it through the network.
        eps (float): Small value to avoid division by zero.
        clip (boolean): Whether to clip the range of the image to be in [0, 255]
        train_norm (float): Decrease standard deviation of the image feed to the
            network to match this norm. Expressed in original pixel values. Unused if
            None
        norm (float): Decrease standard deviation of the image to match this norm after
            update. Expressed in z-scores. Unused if None
        add_loss (function): An additional term to add to the network activation before
            calling backward on it. Usually, some regularization.
    """
    if src.grad is not None:
        src.grad.zero_()

    # apply jitter shift
    if jitter > 0:
        ox, oy = np.random.randint(-jitter, jitter + 1, 2)  # use uniform distribution
        ox, oy = int(ox), int(oy)
        src.data = roll(roll(src.data, ox, -1), oy, -2)
    
    img = src
    if train_norm is not None and train_norm > 0.0:
        # normalize the image in backpropagatable manner
        img_idx = batch_std(src.data) + _eps > train_norm / scale  # images to update
        if img_idx.any():
            img = src.clone() # avoids overwriting original image but lets gradient through
            img[img_idx] = ((src[img_idx] / (batch_std(src[img_idx], keepdim=True) +
                                             _eps)) * (train_norm / scale))

    #y = net(img)
    y = net.forward(img)[:, neuron_id]
    
    if y.mean()>max:
        max = y.mean()
        best = src
        
    (y.mean()+add_loss).backward()
    grad = src.grad
    if precond > 0:
        grad = fft_smooth(grad, precond)
    
    # src.data += (step_size / (batch_mean(torch.abs(grad.data), keepdim=True) + eps)) * (step_gain / 255) * grad.data
    src.data += (step_size / (torch.abs(grad.data).mean() + eps)) * (step_gain / 255) * grad.data
    # * both versions are equivalent for a single-image batch, for batches with more than
    # one image the first one is better but it drawns out the gradients that are spatially
    # wide; for instance a gradient of size 5 x 5 pixels all at amplitude 1 will produce a
    # higher change in an image of the batch than a gradient of size 20 x 20 all at
    # amplitude 1 in another. This is alright in most cases, but when generating diverse
    # images with min linkage (i.e, all images receive gradient from the signal and two
    # get the gradient from the diversity term) it drawns out the gradient generated from
    # the diversity term (because it is usually bigger spatially than the signal gradient)
    # and becomes hard to find very diverse images (i.e., increasing the diversity term
    # has no effect because the diversity gradient gets rescaled down to smaller values
    # than the signal gradient)
    # In any way, gradient mean is only used as normalization here and using the mean is
    # alright (also image generation works normally).

    #print(src.data.std() * scale)
    if norm is not None and norm > 0.0:
        data_idx = batch_std(src.data) + _eps > norm / scale
        src.data[data_idx] =  (src.data / (batch_std(src.data, keepdim=True) + _eps) * norm / scale)[data_idx]

    if jitter > 0:
        # undo the shift
        src.data = roll(roll(src.data, -ox, -1), -oy, -2)

    if clip:
        src.data = torch.clamp(src.data, -bias / scale, (255 - bias) / scale)

    if blur:
        blur_in_place(src.data, sigma)


In [8]:

def deepdraw(net, base_img, neuron_id, octaves, random_crop=True, original_size=None,
             bias=0, scale=1, device=device, **step_params):
    """ Generate an image by iteratively optimizing activity of net.

    Arguments:
        net (nn.Module or function): A backpropagatable function/module that receives
            images in (B x C x H x W) form and outputs a scalar value per image.
        base_img (np.array): Initial image (h x w x c)
        octaves (list of dict): Configurations for each octave:
            n_iter (int): Number of iterations in this octave
            start_sigma (float): Initial standard deviation for gaussian smoothing (if
                used, see blur)
            end_sigma (float): Final standard deviation for gaussian smoothing (if used,
                see blur)
            start_step_size (float): Initial value of the step size used each iteration to
                update the image (im_old += step_size * grad).
            end_step_size (float): Initial value of the step size used each iteration to
                update the image (im_old += step_size * grad).
            (optionally) scale (float): If set, the image will be scaled using this factor
                during optimization. (Original image size is left unchanged).
        random_crop (boolean): If image to optimize is bigger than networks input image,
            optimize random crops of the image each iteration.
        original_size (triplet): (channel, height, width) expected by the network. If
            None, it uses base_img's.
        bias (float), scale (float): Values used for image normalization (at the very
            start of processing): (base_img - bias) / scale.
        device (torch.device or str): Device where the network is located.
        step_params (dict): A handful of optional parameters that are directly sent to
            make_step() (see docstring of make_step for a description).

    Returns:
        A h x w array. The optimized image.
    """
    max = 0
    
    # prepare base image
    image = process(base_img, mu=bias, sigma=scale)  # (3,224,224)

    # get input dimensions from net
    if original_size is None:
        print('getting image size:')
        c, w, h = image.shape[-3:]
    else:
        c, w, h = original_size

    print("starting drawing")

    src = torch.zeros(1, c, w, h, requires_grad=True, device=device)
    best = src
    for e, o in enumerate(octaves):
        if 'scale' in o:
            # resize by o['scale'] if it exists
            image = ndimage.zoom(image, (1, o['scale'], o['scale']))
        _, imw, imh = image.shape
        for i in range(o['iter_n']):
            if imw > w:

                if random_crop:
                    # randomly select a crop
                    # ox = random.randint(0,imw-224)
                    # oy = random.randint(0,imh-224)
                    mid_x = (imw - w) / 2.
                    width_x = imw - w
                    ox = np.random.normal(mid_x, width_x * 0.3, 1)
                    ox = int(np.clip(ox, 0, imw - w))
                    mid_y = (imh - h) / 2.
                    width_y = imh - h
                    oy = np.random.normal(mid_y, width_y * 0.3, 1)
                    oy = int(np.clip(oy, 0, imh - h))
                    # insert the crop into src.data[0]
                    src.data[0].copy_(torch.Tensor(image[:, ox:ox + w, oy:oy + h]))
                    src.data[1].copy_(torch.Tensor(image[:, ox:ox + w, oy:oy + h]))
                else:
                    ox = int((imw - w) / 2)
                    oy = int((imh - h) / 2)
                    src.data[0].copy_(torch.Tensor(image[:, ox:ox + w, oy:oy + h]))
                    src.data[1].copy_(torch.Tensor(image[:, ox:ox + w, oy:oy + h]))
            else:
                ox = 0
                oy = 0
                src.data[0].copy_(torch.Tensor(image))
            #src = torch.cat((src,src,src),0)
            sigma = o['start_sigma'] + ((o['end_sigma'] - o['start_sigma']) * i) / o['iter_n']
            step_size = o['start_step_size'] + ((o['end_step_size'] - o['start_step_size']) * i) / o['iter_n']

            make_step(net, src, neuron_id, best,  bias=bias, scale=scale, sigma=sigma, step_size=step_size, max = max,  **step_params)

            if i % 200 == 0:
                print('finished step %d in octave %d' % (i, e))

            # insert modified image back into original image (if necessary)
            image[:, ox:ox + w, oy:oy + h] = src.data[0].cpu().numpy()

    # returning the resulting image
    return src.cpu().detach().numpy()




In [9]:
def get_adj_model(models, neuron_id, mu_eye=None, pos=None, mu_beh=None):

    def adj_model(x):
        resp = models
        return resp

    return adj_model

In [10]:
def create_MEI(Eff, nb_neurons, k):
 

  octaves0 = [
  {
          'layer':'conv5',
          'iter_n':600,
          'start_sigma':1.5,
          'end_sigma':0.05,
          'start_step_size': 25.*0.25,
          'end_step_size':0.5*0.25,
      },
  ]


  octaves = octaves0

  # prepare initial image
  channels, original_h, original_w = (1,224,224)

  # the background color of the initial image
  background_color = np.float32([128] * channels)
  # generate initial random image
  gen_image = np.random.normal(background_color, 8, (original_h, original_w, channels))
  gen_image = np.append(np.append(gen_image,gen_image, axis=2),gen_image, axis = 2)
  gen_image = np.clip(gen_image, 0, 255)

  im = np.zeros((nb_neurons,3,224,224))

  for i in range(nb_neurons):
      # generate class visualization via octavewise gradient ascent
    gen = deepdraw(Eff, gen_image,i, octaves,
                          random_crop=False,
                          bias=0, scale=1, step_gain= 9.55, train_norm = 10,precond = 1)
    im[i]=gen
  file = open('./IM{}'.format(k), 'wb')
  pickle.dump(im, file)
  file.close()

In [11]:
transform = transforms.Compose([
        transforms.PILToTensor()
    ])
for t in [1,2]:
    file = open('./dat{}'.format(t), 'rb')
    dat = pickle.load(file)
    file.close()
    nb_neurons = dat.shape[1]-1
    
    train, test, output_train, output_test = split(dat)
    train_dataloader, test_dataloader = Dataload(train, test)
    Eff = training(train_dataloader)
    Eff.eval()
    pred = np.empty((0,nb_neurons))
    for  i in range (237):
        a = Image.open('./../data/images/image_{}.0.png'.format(i))
        a = a.resize((224,224))
        a = transform(a)
        true_input_test = a.to(device).float()
        true_input_test = torch.cat((true_input_test,true_input_test,true_input_test), dim = 0).int().float().reshape(1,3,224,224)
        predicted_output_test = Eff.forward_train(true_input_test)
        pred = np.vstack((pred,predicted_output_test.cpu().detach().numpy()))
    create_MEI(Eff, nb_neurons, t)
    file = open('./IM{}'.format(t), 'rb')
    IM = pickle.load(file)
    file.close()   
    MEI_answers = np.empty((0,nb_neurons))
    for  i in range (len(IM)):
        #a = np.transpose(IM[i], (1,2,0))
        a = torch.tensor(IM[i])
        true_input_test = a.to(device).float()
        true_input_test = true_input_test.int().float().reshape(1,3,224,224)
        predicted_output_test = Eff.forward_train(true_input_test)
        MEI_answers = np.vstack((MEI_answers,predicted_output_test.cpu().detach().numpy()))

    list_neurons= []
    for i in range(pred.shape[1]):
        for j in pred[:,i]:
            if MEI_answers[i,i]<j:
                list_neurons.append(i)

    file = open('./neuron_failed{}'.format(t), 'wb')
    pickle.dump(list_neurons, file)
    file.close()


Loaded pretrained weights for efficientnet-b0
0
tensor(12.3507, grad_fn=<MseLossBackward0>)
tensor(11.5637, grad_fn=<MseLossBackward0>)
tensor(9.9057, grad_fn=<MseLossBackward0>)
tensor(6.5161, grad_fn=<MseLossBackward0>)
tensor(5.4363, grad_fn=<MseLossBackward0>)
tensor(5.1245, grad_fn=<MseLossBackward0>)
tensor(4.8357, grad_fn=<MseLossBackward0>)
tensor(4.6872, grad_fn=<MseLossBackward0>)
tensor(4.4301, grad_fn=<MseLossBackward0>)
tensor(4.9437, grad_fn=<MseLossBackward0>)
1
tensor(4.7207, grad_fn=<MseLossBackward0>)
tensor(4.3947, grad_fn=<MseLossBackward0>)
tensor(4.6662, grad_fn=<MseLossBackward0>)
tensor(4.4162, grad_fn=<MseLossBackward0>)
tensor(4.4458, grad_fn=<MseLossBackward0>)
tensor(4.8322, grad_fn=<MseLossBackward0>)
tensor(4.4632, grad_fn=<MseLossBackward0>)
tensor(4.2019, grad_fn=<MseLossBackward0>)
tensor(4.1639, grad_fn=<MseLossBackward0>)
tensor(4.5797, grad_fn=<MseLossBackward0>)
2
tensor(4.0407, grad_fn=<MseLossBackward0>)
tensor(4.5526, grad_fn=<MseLossBackward0>)


  img[d] = ndimage.filters.gaussian_filter(img[d], sigma, order=0)


finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0


  img[d] = ndimage.filters.gaussian_filter(img[d], sigma, order=0)


finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0
finished step 200 in octave 0
finished step 400 in octave 0
getting image size:
starting drawing
finished step 0 in octave 0


 [0m[01;34mbrain_observatory[0m/                     grating_98.tiff
 dat                                    grating_99.tiff
 dat0                                   grating_9.tiff
 dat1                                   IM0
 dat2                                   IM1
 dat3                                   IM2
'dat4 '                                 IM3
'dat5 '                                 image_0.02_0_0.tiff
 ecephys_data_access.ipynb.ipynb        image_0.02_0_120.tiff
 ecephys_lfp_analysis.ipynb.ipynb       image_0.02_0_150.tiff
 ecephys_optotagging.ipynb.ipynb        image_0.02_0.25_0.tiff
 ecephys_quality_metrics.ipynb.ipynb    image_0.02_0.25_120.tiff
 ecephys_receptive_fields.ipynb.ipynb   image_0.02_0.25_150.tiff
 ecephys_session.ipynb.ipynb            image_0.02_0.25_30.tiff
 Eff                                    image_0.02_0.25_60.tiff
 Eff_all                                image_0.02_0.25_90.tiff
 Eff_grat                               image_0.02_0_30.tiff
 grating_0.