In [270]:
import matplotlib.pyplot as plt

from skimage.color import rgb2gray, rgba2rgb
import skimage.io as skio
from skimage.util.shape import view_as_windows
from skimage.transform import resize

import numpy as np
import pandas as pd 
from scipy import stats  
import scipy.stats
from einops import rearrange, reduce, repeat

import torch

import itertools
import os
import sys
import glob
import warnings
import tqdm
import copy

from PIL import Image

import argparse

from utils import default_paths, nsd_utils, texture_utils
from model_fitting import initialize_fitting
from feature_extraction import fwrf_features

try:
    device = initialize_fitting.init_cuda()
except:
    device = 'cpu:0'
    
class bent_gabor_feature_bank():
    
    def __init__(self, freq_values=None, bend_values=None, orient_values=None, \
                 image_size=128, device='cpu:0'):
        
        self.image_size = image_size;
        self.device = device
        
        self.__set_kernel_params__(freq_values, bend_values, orient_values)
        self.__generate_kernels__()
        
    def __set_kernel_params__(self, freq_values, bend_values, orient_values):
        
        """
        Set some default params for the banana kernels.
        sigmaXbend:  sigma for the bent gaussian in x-direction
        sigmaYbend:  sigma for the bent gaussian in y-direction
        xA_shift:    center shift in x direction
        yA_shift:    center shift in y direction   
        
        freq_values: freq of filters, cyc/image
        orient_values: orientation of filters, 0-2pi
        bend_values: control bending of filters
        """
        
        self.sigmaXbend = 2;
#         self.sigmaXbend = 3;
        self.sigmaYbend = 6;
        self.kernel_size = self.image_size
        self.xA_shift   = 0
        self.yA_shift   = 0
        
        if freq_values is None:
            self.freq_values = [64, 32, 16, 8]
        else:
            self.freq_values = freq_values
            
        nyquist = 0.5*self.kernel_size
        if any(np.array(self.freq_values)>nyquist):
            raise ValueError('for image of size %d x %d, must have freqs < %.2f'%\
                            (self.kernel_size, self.kernel_size, nyquist))
        self.kA = np.array(self.freq_values)*2*np.pi / self.kernel_size
        self.scale_values = np.log(2*np.pi/self.kA)/np.log(np.sqrt(2))
        
        if orient_values is None:
            self.orient_values = np.linspace(0,2*np.pi, 9)[0:8]
        else:
            self.orient_values = orient_values
        if bend_values is None:
            self.bend_values = [0, 0.02,0.07,0.10,0.18,0.45]
        else:
            self.bend_values = bend_values
            
        print('freq values')
        print(self.freq_values)
        print('scale values')
        print(self.scale_values)
        print('bend values')
        print(self.bend_values)
        print('orient values')
        print(self.orient_values)

    def __make_bananakernel__(self, kA, bA, alphaA, is_curved):
        
        """
        Generate banana wavelet kernels.  The kernels
        can be used to filter a image to quantify curvatures.

        kA:          scale param, length of the wave vector K
                     kA =  2*np.pi/((np.sqrt(2))**scale)
                     filter frequency: (cycle/object) = kA*kernel_size / (2*pi)
        bA:          bending value b (arbitrary, roughly between 0-0.5)
        alphaA:      direction of the wave vector (i.e. orientation in rad)
        is_curved:   Are we making a curved gabor? If false, making a sharp angle detector.
                     Note if bA==0, then these are the same. 

        return SpaceKernel, FreqKernel

        """
        
        assert not (isinstance(bA, complex))
       
        kernel_size = self.kernel_size
        if kernel_size%2 !=0:
            kernel_size = kernel_size + 1
        [xA, yA] = np.meshgrid(np.arange(-kernel_size/2, kernel_size/2,1),np.arange(-kernel_size/2, kernel_size/2,1)) 
        xA = xA - self.xA_shift
        yA = yA - self.yA_shift

        xRotL = np.cos(alphaA)*xA + np.sin(alphaA)*yA 
        yRotL = np.cos(alphaA)*yA - np.sin(alphaA)*xA

        if is_curved:
            # make a curved "banana" gabor.
            scale = np.log(2*np.pi/kA)/np.log(np.sqrt(2))
            xRotBendL = xRotL + bA/scale * (yRotL)**2
        else:
            # otherwise making a sharp angle detector, use abs instead of squaring.
            # adjusting the constant here to make the bA values ~similar across curved/angle filters.
            xRotBendL = xRotL + bA*4 * np.abs(yRotL)
            
        yRotBendL = yRotL

        """make the DC free""" 
        tmpgaussPartA = np.exp(-0.5*(kA)**2*((xRotBendL/self.sigmaXbend)**2 + (yRotBendL/(self.sigmaYbend))**2))
        tmprealteilL  = 1*tmpgaussPartA*(np.cos(kA*xRotBendL) - 0)
        tmpimagteilL  = 1*tmpgaussPartA*(np.sin(kA*xRotBendL) - 0)

        numeratorRealL = np.sum(tmprealteilL)
        numeratorImagL = np.sum(tmpimagteilL)
        denominatorL   = np.sum(tmpgaussPartA)

        DCValueAnalysis = np.exp(-0.5 * self.sigmaXbend * self.sigmaXbend)
        if denominatorL==0:
            DCPartRealA = DCValueAnalysis
            DCPartImagA = 0
        else:    
            DCPartRealA = numeratorRealL/denominatorL
            DCPartImagA = numeratorImagL/denominatorL
            if DCPartRealA < DCValueAnalysis:
                DCPartRealA = DCValueAnalysis
                DCPartImagA = 0

        """generate a space kernel""" 
        preFactorA = kA**2
        gaussPartA = np.exp(-0.5*(kA)**2*((xRotBendL/self.sigmaXbend)**2 + (yRotBendL/(self.sigmaYbend))**2))
        realteilL  = preFactorA*gaussPartA*(np.cos(kA*xRotBendL) - DCPartRealA)
        imagteilL  = preFactorA*gaussPartA*(np.sin(kA*xRotBendL) - DCPartImagA)

        """normalize the kernel"""  
        normRealL   = np.sqrt(np.sum(realteilL**2))
        normImagL   = np.sqrt(np.sum(imagteilL**2))
        normFactorL = kA**2

        total_std = normRealL + normImagL
        if total_std == 0:
            total_std = 10**20
        norm_realteilL = realteilL*normFactorL/(0.5*total_std)
        norm_imagteilL = imagteilL*normFactorL/(0.5*total_std)
        
        space_kernel = norm_realteilL + norm_imagteilL*1j
        freq_kernel = np.fft.ifft2(space_kernel)
        
        return space_kernel, freq_kernel
 
    def __generate_kernels__(self):
        
        """
        Make the bank of filters.
        """
        
        n_scales = len(self.scale_values)
        n_orients = len(self.orient_values)
        n_bends = len(self.bend_values)
        
        curv_freq_kernels,rect_freq_kernels,lin_freq_kernels, \
            curv_space,rect_space,lin_space = [],[],[],[],[],[]
        
        curv_kernel_pars = np.zeros((n_scales*(n_bends-1)*n_orients, 4))
        rect_kernel_pars = np.zeros((n_scales*(n_bends-1)*n_orients, 4))
        lin_kernel_pars = np.zeros((n_scales*n_orients, 4))
        
        ci=-1; ri=-1; li=-1
            
        for is_curved in [True, False]:

            for kA, bA, alphaA in itertools.product(self.kA, self.bend_values, self.orient_values):

                space_kernel, freq_kernel = self.__make_bananakernel__(kA, bA, alphaA, is_curved)

                if bA == 0:
                    if not is_curved:
                        # the linear kernels each get defined twice (once with is_curv=True and False)
                        # only counting one occurence of each.
                        lin_freq_kernels.append(freq_kernel)
                        lin_space.append(space_kernel.real) 
                        li+=1
                        lin_kernel_pars[li,:] = [kA, bA, alphaA, is_curved]
                    else:
                        continue
                elif is_curved:
                    # this is a curved banana filter
                    curv_freq_kernels.append(freq_kernel)
                    curv_space.append(space_kernel.real)
                    ci+=1
                    curv_kernel_pars[ci,:] = [kA, bA, alphaA, is_curved]
                else:
                    # this is a second-order rectilinear filter
                    rect_freq_kernels.append(freq_kernel)
                    rect_space.append(space_kernel.real)
                    ri+=1
                    rect_kernel_pars[ri,:] = [kA, bA, alphaA, is_curved]
                    

        self.kernels = {'curv_freq':curv_freq_kernels, 'curv_space':curv_space,
                        'rect_freq':rect_freq_kernels, 'rect_space':rect_space, 
                        'lin_freq':lin_freq_kernels, 'lin_space':lin_space}
        self.rect_kernel_pars = rect_kernel_pars
        self.curv_kernel_pars = curv_kernel_pars
        self.lin_kernel_pars = lin_kernel_pars
        self.n_rect_kernels = self.rect_kernel_pars.shape[0]
        self.n_curv_kernels = self.curv_kernel_pars.shape[0]
      
    def plot_kernel_bends(self, ori_ind=0, scale_ind=0):

        rect_kernel_pars = self.rect_kernel_pars
        curv_kernel_pars = self.curv_kernel_pars
        lin_kernel_pars = self.lin_kernel_pars
        rect_spat_kernel_list = self.kernels['rect_space']
        curv_spat_kernel_list = self.kernels['curv_space']
        lin_spat_kernel_list = self.kernels['lin_space']

        plt.figure(figsize=(20,12))
        npx = 3;
        npy = len(self.bend_values)

        ori = self.orient_values[ori_ind]
        sc = self.kA[scale_ind]

        kk2plot = np.where((rect_kernel_pars[:,2]==ori) & (rect_kernel_pars[:,0]==sc))[0]
        for ki, kk in enumerate(kk2plot):
            plt.subplot(npx, npy, ki+1)
            plt.pcolormesh(rect_spat_kernel_list[kk])
            plt.axis('square')
            plt.gca().invert_yaxis()
            plt.axis('off')
            plt.title('bend=%.2f'%(rect_kernel_pars[kk,1]))

        kk2plot = np.where((curv_kernel_pars[:,2]==ori) & (curv_kernel_pars[:,0]==sc))[0]
        for ki, kk in enumerate(kk2plot):
            plt.subplot(npx, npy,ki+npy+1)
            plt.pcolormesh(curv_spat_kernel_list[kk])
            plt.axis('square')
            plt.gca().invert_yaxis()
            plt.axis('off')
            plt.title('bend=%.2f'%(curv_kernel_pars[kk,1]))

        kk2plot = np.where((lin_kernel_pars[:,2]==ori) & (lin_kernel_pars[:,0]==sc))[0]
        for ki, kk in enumerate(kk2plot):
            plt.subplot(npx, npy ,ki+npy*2+1)
            plt.pcolormesh(lin_spat_kernel_list[kk])
            plt.axis('square')
            plt.gca().invert_yaxis()
            plt.axis('off')
            plt.title('bend=%.2f'%(lin_kernel_pars[kk,1]))

        plt.suptitle('ori=%.2f rad, freq=%.2f cyc/im'%(ori, self.freq_values[scale_ind]))
        
    def filter_image_batch(self, image_batch, which_kernels='curv'):
        
        
        if which_kernels=='curv':
            kernel_list = self.kernels['curv_freq']
        elif which_kernels=='rect':
            kernel_list = self.kernels['rect_freq']
        elif which_kernels=='linear':
            kernel_list = self.kernels['lin_freq']
        else:
            raise ValueError('which_kernels must be one of [curv, rect, linear]')

        """image x, image y, kernel dimension, all images (4D array)"""
        all_kernels = np.dstack(kernel_list)
        
        """calculate kernel norm for normalization"""
        all_kernels_power =  np.einsum('ijk,ijk->k',np.abs(all_kernels),np.abs(all_kernels))
        all_kernels_power =  np.sqrt(all_kernels_power)

        """stack fft image list to a 3d array"""
        image_batch_fft = np.fft.fft2(image_batch, axes=(0,1))
        
        all_conved_images = np.abs(np.fft.ifft2(image_batch_fft[:,:,np.newaxis,:]*all_kernels[:,:,:,np.newaxis],axes=(0,1)))
        all_conved_images = np.power(all_conved_images,1/2) ## power correction
        all_conved_images = all_conved_images/all_kernels_power[np.newaxis, np.newaxis,:,np.newaxis]
    
        return np.fft.fftshift(all_conved_images, axes=(0,1))
    
    def filter_image_batch_pytorch(self, image_batch, which_kernels='all', to_numpy=True):

        if which_kernels=='curv':
            kernel_list = self.kernels['curv_freq']
        elif which_kernels=='rect':
            kernel_list = self.kernels['rect_freq']
        elif which_kernels=='linear':
            kernel_list = self.kernels['lin_freq']
        elif which_kernels=='all':
            kernel_list = self.kernels['curv_freq']+self.kernels['rect_freq']+self.kernels['lin_freq']
        else:
            raise ValueError('which_kernels must be one of [curv, rect, linear, all]')

        # stack all the filters together, [self.kernel_size, self.kernel_size, n_filters]
        # and send to specified device.
        all_kernels = np.dstack(kernel_list)
        all_kernels_tensor = torch.complex(torch.Tensor(np.real(all_kernels)), \
                                           torch.Tensor(np.imag(all_kernels)))
        
        # Compute power of each kernel, will use to normalize the convolution result.
        all_kernels_power =  torch.sum(torch.sum(torch.pow(torch.abs(all_kernels_tensor), 2), axis=0), axis=0)
        all_kernels_power =  torch.sqrt(all_kernels_power)

        # send image batch to device [self.image_size, self.image_size, n_images]
        image_batch_tensor = torch.Tensor(image_batch).to(self.device)
        # get frequency domain representation of images
        image_batch_fft = torch.fft.fftn(image_batch_tensor, dim=(0,1))

        # apply the filters by multiplying all at once
        mult = image_batch_fft.view([self.image_size, self.image_size,1,-1]) * \
                all_kernels_tensor.view([self.image_size, self.image_size,-1,1])
        # back to spatial domain 
        all_conved_images = torch.abs(torch.fft.ifftn(mult,dim=(0,1)))
        # power correction
        all_conved_images = torch.pow(all_conved_images,1/2) 
        all_conved_images = all_conved_images/ \
                    all_kernels_power.view([1,1,all_kernels_power.shape[0],1])
        
        # shift back to original spatial configuration
        all_conved_images = torch.fft.fftshift(all_conved_images, dim=(0,1))

        if to_numpy:
            all_conved_images = all_conved_images.detach().cpu().numpy()
            
        return all_conved_images

#device: 1
device#: 0
device name: GeForce RTX 2080 Ti

torch: 1.8.1+cu111
cuda:  11.1
cudnn: 8005
dtype: torch.float32


In [None]:

best_curv_kernel = np.argmax(mean_curv_over_space, axis=1)
best_rect_kernel = np.argmax(mean_rect_over_space, axis=1)
best_lin_kernel = np.argmax(mean_lin_over_space, axis=1)

curv_z = scipy.stats.zscore(mean_curv_over_space, axis=0)
rect_z = scipy.stats.zscore(mean_rect_over_space, axis=0)
lin_z = scipy.stats.zscore(mean_lin_over_space, axis=0)

best_curv_kernel_z = np.argmax(curv_z, axis=1)
best_rect_kernel_z = np.argmax(rect_z, axis=1)
best_lin_kernel_z = np.argmax(lin_z, axis=1)

mean_curv_z = np.mean(curv_z, axis=1)
mean_rect_z = np.mean(rect_z, axis=1)
mean_lin_z = np.mean(lin_z, axis=1)

curv_rect_index = (mean_curv_z - mean_rect_z - mean_lin_z) / \
                  (mean_curv_z + mean_rect_z + mean_lin_z)


In [263]:
len(bank.kernels['curv_freq']+bank.kernels['rect_freq']+bank.kernels['lin_freq'])

352

In [271]:
bank = bent_gabor_feature_bank(device='cpu:0')

freq values
[64, 32, 16, 8]
scale values
[2. 4. 6. 8.]
bend values
[0, 0.02, 0.07, 0.1, 0.18, 0.45]
orient values
[0.         0.78539816 1.57079633 2.35619449 3.14159265 3.92699082
 4.71238898 5.49778714]


In [309]:
thing = np.random.normal(0,1,[100,100,200])
time = time.time
ans = np.mean(np.mean(thing, axis=0), axis=0)

In [294]:
image_batch = np.random.normal(0,1,[128,128,50])

In [295]:
import time

st = time.time()
output1 = np.concatenate([bank.filter_image_batch_pytorch(image_batch, k) for k in ['curv','rect', 'linear']], axis=2)
elapsed = time.time() - st
print('elapsed time = %.5f s'%elapsed)
st = time.time()
output2 = bank.filter_image_batch_pytorch(image_batch, 'all')
elapsed = time.time() - st
print('elapsed time = %.5f s'%elapsed)

In [302]:
st = time.time()
output3 = np.concatenate([bank.filter_image_batch(image_batch, k) for k in ['curv','rect', 'linear']], axis=2)
elapsed = time.time() - st
print('elapsed time = %.5f s'%elapsed)

elapsed time = 54.49722 s


In [296]:
output1.shape, output2.shape

((128, 128, 352, 50), (128, 128, 352, 50))

In [297]:
output1.dtype, output2.dtype

(dtype('float32'), dtype('float32'))

In [300]:
output1[100,120,300,:]  

array([1.4696233, 1.4214585, 1.7884996, 2.476495 , 0.7059956, 1.3420668,
       1.8813096, 2.1318772, 2.8678038, 2.6955152, 1.8448137, 1.6807752,
       2.3423147, 1.7749122, 2.268266 , 1.711766 , 2.5058978, 2.2039583,
       0.6841866, 1.3269516, 2.5777104, 1.5785817, 1.0995805, 2.1393645,
       2.17551  , 1.4904385, 1.6286892, 2.0808933, 1.0770144, 1.4902999,
       2.4968717, 3.195699 , 2.496212 , 1.8734933, 2.2282205, 2.0609915,
       1.793022 , 1.2400877, 2.4840724, 2.538421 , 3.2831697, 1.8803866,
       2.2383482, 2.0741968, 1.4740773, 1.7631953, 1.6768072, 3.549436 ,
       1.4978232, 0.9511576], dtype=float32)

In [301]:
output2[100,120,300,:]

array([1.4696233, 1.4214585, 1.7884996, 2.476495 , 0.7059956, 1.3420668,
       1.8813096, 2.1318772, 2.8678038, 2.6955152, 1.8448137, 1.6807752,
       2.3423147, 1.7749122, 2.268266 , 1.711766 , 2.5058978, 2.2039583,
       0.6841866, 1.3269516, 2.5777104, 1.5785817, 1.0995805, 2.1393645,
       2.17551  , 1.4904385, 1.6286892, 2.0808933, 1.0770144, 1.4902999,
       2.4968717, 3.195699 , 2.496212 , 1.8734933, 2.2282205, 2.0609915,
       1.793022 , 1.2400877, 2.4840724, 2.538421 , 3.2831697, 1.8803866,
       2.2383482, 2.0741968, 1.4740773, 1.7631953, 1.6768072, 3.549436 ,
       1.4978232, 0.9511576], dtype=float32)

In [268]:
import time

st = time.time()
output1 = bank.filter_image_batch_pytorch(image_batch, 'rect')
elapsed = time.time() - st
print('elapsed time = %.5f s'%elapsed)
st = time.time()
output2 = bank.filter_image_batch(image_batch, 'rect')
elapsed = time.time() - st
print('elapsed time = %.5f s'%elapsed)

elapsed time = 0.77986 s
elapsed time = 2.77879 s


In [269]:
output2.shape

(128, 128, 160, 5)

In [254]:
output1.dtype

torch.float32

In [255]:
output2.dtype

dtype('float64')

In [256]:
output1[100,120,10,:]
    

tensor([0.2784, 0.2287, 0.2413, 0.1713, 0.1855])

In [257]:
output2[100,120,10,:]
    

array([0.27844459, 0.22866219, 0.24127039, 0.17133451, 0.18545024])

In [225]:

kernel_list = bank.kernels['curv_freq']
 
"""image x, image y, kernel dimension, all images (4D array)"""
all_kernels = np.dstack(kernel_list)

"""calculate kernel norm for normalization"""
all_kernels_power =  np.einsum('ijk,ijk->k',np.abs(all_kernels),np.abs(all_kernels))
all_kernels_power =  np.sqrt(all_kernels_power)
all_kernels_power_1 = all_kernels_power


"""stack fft image list to a 3d array"""
image_batch_fft = np.fft.fft2(image_batch, axes=(0,1))

mult = image_batch_fft[:,:,np.newaxis,:]*all_kernels[:,:,:,np.newaxis]

all_conved_images = np.abs(np.fft.ifft2(mult,axes=(0,1)))
# all_conved_images_1 = all_conved_images
all_conved_images = np.power(all_conved_images,1/2) ## power correction
all_conved_images_1 = all_conved_images
# all_conved_images = all_conved_images/all_kernels_power[np.newaxis, np.newaxis,:,np.newaxis]

# all_conved_images_1 = all_conved_images

In [226]:
all_kernels_power_1[50:60]

array([0.02727183, 0.02727721, 0.02727183, 0.02727721, 0.02727183,
       0.02727721, 0.02727266, 0.02727722, 0.02727266, 0.02727722])

In [227]:
all_kernels_power_2[50:60]

tensor([0.0273, 0.0273, 0.0273, 0.0273, 0.0273, 0.0273, 0.0273, 0.0273, 0.0273,
        0.0273])

In [184]:
mult_1[10,90,50,:]

array([-2.93472890e-09-3.40413453e-10j,  1.77749642e-09-1.50342876e-09j,
        3.93845074e-09-5.78986695e-09j,  4.38945932e-09+4.00415532e-09j,
       -6.78794118e-09+3.18103287e-09j])

In [185]:
mult_2[10,90,50,:]

tensor([-2.9347e-09-3.4041e-10j,  1.7775e-09-1.5034e-09j,
         3.9385e-09-5.7899e-09j,  4.3895e-09+4.0042e-09j,
        -6.7879e-09+3.1810e-09j])

In [209]:
all_conved_images_1[50,100,80,:]

array([0.00593715, 0.00647518, 0.00165913, 0.00410518, 0.00667207])

In [211]:
all_conved_images_2[50,100,80,:]

tensor([0.0059, 0.0065, 0.0017, 0.0041, 0.0067])

In [204]:
all_conved_images_2.dtype

torch.float32

In [203]:
all_conved_images_2[50,100,80,:]

tensor([3.5250e-05, 4.1928e-05, 2.7527e-06, 1.6853e-05, 4.4517e-05])

In [224]:

# stack all the filters together, [self.kernel_size, self.kernel_size, n_filters]
# and send to specified device.
all_kernels = np.dstack(kernel_list)
all_kernels_tensor = torch.complex(torch.Tensor(np.real(all_kernels)), \
                                   torch.Tensor(np.imag(all_kernels)))

# Compute power of each kernel, will use to normalize the convolution result.
all_kernels_power =  torch.sum(torch.sum(torch.pow(torch.abs(all_kernels_tensor), 2), axis=0), axis=0)
all_kernels_power =  torch.sqrt(all_kernels_power)
all_kernels_power_2 = all_kernels_power

# send image batch to device [self.image_size, self.image_size, n_images]
image_batch_tensor = torch.Tensor(image_batch).to(bank.device)
# get frequency domain representation of images
image_batch_fft = torch.fft.fftn(image_batch_tensor, dim=(0,1))

# apply the filters by multiplying all at once
mult = image_batch_fft.view([bank.image_size, bank.image_size,1,-1]) * \
        all_kernels_tensor.view([bank.image_size, bank.image_size,-1,1])
# # back to spatial domain 
all_conved_images = torch.abs(torch.fft.ifftn(mult,dim=(0,1)))
# all_conved_images_2 = all_conved_images
# # power correction
all_conved_images = torch.pow(all_conved_images,1/2) 
# all_conved_images_2 = all_conved_images
# all_conved_images = all_conved_images/ \
#             all_kernels_power.view([1,1,all_kernels_power.shape[0],1])
all_conved_images_2 = all_conved_images
# all_conved_images_2 = all_conved_images

In [159]:
thing = 5+6j
thing = np.array([thing])
thing

array([5.+6.j])

In [160]:
torch.Tensor(npz.array([thing]))

tensor([[5.]])

In [175]:
tc = torch.complex(torch.Tensor(np.real(thing)), torch.Tensor(np.imag(thing)))
tc

tensor([5.+6.j])

In [176]:
torch.abs(tc)

tensor([7.8102])

In [177]:
np.abs(thing)

array([7.81024968])

In [163]:
torch.Tensor(thing, dtype=torch.cfloat)

TypeError: new() received an invalid combination of arguments - got (numpy.ndarray, dtype=torch.dtype), but expected one of:
 * (*, torch.device device)
      didn't match because some of the keywords were incorrect: dtype
 * (torch.Storage storage)
 * (Tensor other)
 * (tuple of ints size, *, torch.device device)
 * (object data, *, torch.device device)


In [11]:
all_kernels = np.random.normal(0,1,[40,40,5])

In [12]:
all_kernels_power =  np.einsum('ijk,ijk->k',np.abs(all_kernels),np.abs(all_kernels))
all_kernels_power =  np.sqrt(all_kernels_power)
all_kernels_power

array([40.5405043 , 40.97954212, 39.59402806, 40.80134468, 39.94954381])

In [13]:
all_kernels_power = np.sum(np.sum(all_kernels**2, axis=0), axis=0)
all_kernels_power =  np.sqrt(all_kernels_power)
all_kernels_power

array([40.5405043 , 40.97954212, 39.59402806, 40.80134468, 39.94954381])

In [14]:
all_kernels_torch = torch.Tensor(all_kernels).to('cpu:0')
all_kernels_power =  torch.sum(torch.sum(torch.pow(all_kernels_torch, 2), axis=0), axis=0)
all_kernels_power =  torch.sqrt(all_kernels_power)
all_kernels_power

tensor([40.5405, 40.9795, 39.5940, 40.8013, 39.9495])

In [71]:
image_batch = np.random.normal(0,1,[40,40,2])


In [72]:
image_batch_fft = np.fft.fft2(image_batch, axes=(0,1))
       
all_kernels_power = np.sum(np.sum(all_kernels**2, axis=0), axis=0)
all_kernels_power =  np.sqrt(all_kernels_power)

all_conved_images = np.abs(np.fft.ifft2(image_batch_fft[:,:,np.newaxis,:]*all_kernels[:,:,:,np.newaxis],axes=(0,1)))
all_conved_images = np.power(all_conved_images,1/2) ## power correction
all_conved_images = all_conved_images/all_kernels_power[np.newaxis, np.newaxis,:,np.newaxis]

In [73]:
all_conved_images[20,20,:,:]

array([[0.01810492, 0.01245891],
       [0.00791344, 0.0263366 ],
       [0.03090774, 0.02566644],
       [0.02327567, 0.0121277 ],
       [0.01799749, 0.02118772]])

In [74]:
image_batch_torch = torch.Tensor(image_batch).to('cpu:0')
image_batch_fft = torch.fft.fftn(image_batch_torch, dim=(0,1))
        
all_kernels_torch = torch.Tensor(all_kernels).to('cpu:0')
all_kernels_power =  torch.sum(torch.sum(torch.pow(all_kernels_torch, 2), axis=0), axis=0)
all_kernels_power =  torch.sqrt(all_kernels_power)

mult = image_batch_fft.view([image_batch_fft.shape[0],image_batch_fft.shape[1],1,-1]) * \
        all_kernels_torch.view([all_kernels_torch.shape[0], all_kernels_torch.shape[1],-1,1])
all_conved_images = torch.abs(torch.fft.ifftn(mult,dim=(0,1)))
all_conved_images = torch.pow(all_conved_images,1/2) ## power correction
all_conved_images = all_conved_images/ \
            all_kernels_power.view([1,1,all_kernels_power.shape[0],1])

In [75]:
all_conved_images[20,20,:,:]

tensor([[0.0181, 0.0125],
        [0.0079, 0.0263],
        [0.0309, 0.0257],
        [0.0233, 0.0121],
        [0.0180, 0.0212]])