In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "/kaggle/input/" directory (i.e. "../input/" relative to the working directory)
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file in the attached Kaggle datasets so their paths can be copied into later cells.
# The unused sub-directory list is bound to `_subdirs` rather than `_`: once user code assigns `_`
# at notebook top level, IPython stops updating its last-output variables (`_`, `__`, `___`).
for dirname, _subdirs, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
class FeatureExtractor:
    """Multichannel STFT feature extractor for direction-aware speech models.

    Features produced (all from the same one-sided, Hann-windowed STFT):

    * ``lps``: log power spectrum of a reference channel.
    * ``ipd``: unit-magnitude inter-channel phase difference of every
      non-reference channel relative to the reference channel.
    * ``directional_features``: cosine similarity between the observed IPD and
      the far-field theoretical IPD of each candidate azimuth, max-pooled over
      candidates inside / outside a target field of view (FOV).

    Args:
        n_fft: FFT size, also used as the Hann window length.
        hop_length: STFT hop size in samples.
        sample_rate: Sampling rate in Hz; maps FFT bins to frequencies.
        mic_positions: Array-like of shape (num_mics, 3) with microphone
            coordinates, in the same length unit as ``speed_of_sound``
            (presumably metres, given the default 343 m/s). Only needed by
            ``theoretical_ipd`` and ``directional_features``.
        speed_of_sound: Propagation speed used to turn path differences into
            time delays.
    """

    def __init__(self, n_fft=512, hop_length=256, sample_rate=16000, mic_positions=None, speed_of_sound=343):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.mic_positions = mic_positions  # (num_mics, 3)
        self.speed_of_sound = speed_of_sound

        # Frequency (Hz) of each one-sided STFT bin: n_fft // 2 + 1 values,
        # the same frequency axis torch.stft produces for real input.
        self.freqs = np.fft.rfftfreq(n_fft, 1 / sample_rate)

    def stft(self, waveform):
        """Complex one-sided STFT with a Hann window.

        Args:
            waveform: Real tensor of shape (..., samples), e.g. (samples,),
                (batch, samples) or (batch, mics, samples). torch.stft only
                accepts 1-D/2-D input, so leading dimensions are flattened for
                the call and restored afterwards.

        Returns:
            Complex tensor of shape (..., n_fft // 2 + 1, frames).
        """
        window = torch.hann_window(self.n_fft, device=waveform.device)
        if waveform.is_floating_point():
            # Match precision so e.g. float64 input is not paired with a float32 window.
            window = window.to(waveform.dtype)

        lead_shape = waveform.shape[:-1]
        flat = waveform.reshape(-1, waveform.shape[-1])
        spec = torch.stft(flat, n_fft=self.n_fft, hop_length=self.hop_length,
                          window=window, return_complex=True)
        return spec.reshape(*lead_shape, *spec.shape[-2:])

    def lps(self, stft_complex, ref_channel=0):
        """Log power spectrum of one channel.

        Args:
            stft_complex: Complex STFT of shape (batch, ch, freq, time).
            ref_channel: Channel whose power is returned.

        Returns:
            Real tensor of shape (batch, freq, time); 1e-8 is added before the
            log to avoid log(0) on silent bins.
        """
        power = stft_complex[:, ref_channel].abs() ** 2 + 1e-8
        return torch.log(power)

    def ipd(self, stft_complex, ref_channel=0):
        """Observed inter-channel phase difference relative to ``ref_channel``.

        Computes ``Y_m * conj(Y_ref)`` for every channel ``m != ref_channel``
        (in ascending channel order) and normalises it to unit magnitude, so
        only the phase difference remains.

        Args:
            stft_complex: Complex STFT of shape (batch, ch, freq, time).
            ref_channel: Reference channel index (negative indices allowed).

        Returns:
            Complex tensor of shape (batch, ch - 1, freq, time).

        Raises:
            IndexError: If ``ref_channel`` is out of range.
        """
        num_channels = stft_complex.shape[1]
        ref_idx = range(num_channels)[ref_channel]  # normalises negatives, range-checks
        others = [c for c in range(num_channels) if c != ref_idx]

        y_ref = stft_complex[:, ref_idx].unsqueeze(1)  # (batch, 1, freq, time)
        cross = stft_complex[:, others] * torch.conj(y_ref)
        return cross / (cross.abs() + 1e-8)

    def theoretical_ipd(self, azimuth, elevation=0, ref_channel=0):
        """Far-field theoretical IPD for a single direction.

        Args:
            azimuth: Azimuth in radians (Python/NumPy scalar or 0-d tensor).
            elevation: Elevation in radians.
            ref_channel: Reference microphone; must match the one used for the
                observed IPD.

        Returns:
            complex64 tensor of shape (num_mics - 1, n_fft // 2 + 1), rows in
            ascending order of the non-reference microphones (same order as
            ``ipd``).

        Raises:
            ValueError: If ``mic_positions`` was not provided.
        """
        if self.mic_positions is None:
            raise ValueError("mic_positions must be provided to compute the theoretical IPD")
        positions = np.asarray(self.mic_positions, dtype=np.float64)
        az = float(azimuth)
        el = float(elevation)

        # Unit direction vector from azimuth/elevation.
        u = np.array([np.cos(el) * np.cos(az), np.cos(el) * np.sin(az), np.sin(el)])

        num_mics = positions.shape[0]
        ref_idx = range(num_mics)[ref_channel]
        others = [m for m in range(num_mics) if m != ref_idx]

        # Per-mic delay (seconds) from the projection of each mic offset onto u.
        # NOTE(review): if u is meant to point from the array TOWARD the source,
        # the arrival delay of mic m is -tau and the observed Y_m*conj(Y_ref)
        # would carry exp(+1j*omega*tau); the exp(-1j*...) below treats u as the
        # propagation direction. Confirm the intended azimuth convention.
        tau = (positions[others] - positions[ref_idx]) @ u / self.speed_of_sound

        omega = 2 * np.pi * self.freqs
        ipd_theory = np.exp(-1j * np.outer(omega, tau))  # (freq, num_mics - 1)

        return torch.tensor(ipd_theory.T, dtype=torch.complex64)

    def directional_features(self, stft_complex, azimuth_grid, fov_az_range, elevation=0, ref_channel=0):
        """Directional features inside and outside a target azimuth FOV.

        For each candidate azimuth the observed IPD is compared with the
        theoretical IPD via the real part of their inner product over mic
        pairs, divided by the number of pairs (a cosine-style similarity).
        The similarity is then max-pooled over the candidates inside and outside
        the FOV.

        Args:
            stft_complex: Complex STFT of shape (batch, ch, freq, time).
            azimuth_grid: List / array / tensor of candidate azimuths (radians).
            fov_az_range: ``[az_min, az_max]`` in radians; a candidate is inside
                the FOV when ``az_min <= az <= az_max``.
            elevation: Elevation (radians) used for every candidate.
            ref_channel: Reference channel for both observed and theoretical IPD.

        Returns:
            Tuple ``(D_in, D_out)``, each a real tensor of shape
            (batch, freq, time). A region with no candidates yields zeros.

        Raises:
            ValueError: If ``azimuth_grid`` is empty.
        """
        batch, _, freq, time = stft_complex.shape
        device = stft_complex.device

        # Observed IPD of the non-reference mics vs. the reference: (B, P, F, T).
        ipd_obs = self.ipd(stft_complex, ref_channel=ref_channel)

        az_grid = torch.as_tensor(azimuth_grid, dtype=torch.float64).reshape(-1)
        if az_grid.numel() == 0:
            raise ValueError("azimuth_grid must contain at least one candidate azimuth")

        # Theoretical IPD for every candidate: (K, P, F).
        ipd_theory_all = torch.stack(
            [self.theoretical_ipd(az, elevation, ref_channel=ref_channel) for az in az_grid.tolist()],
            dim=0,
        ).to(device)

        # (B, 1, P, F, T) * (1, K, P, F, 1) -> sum over pairs P -> (B, K, F, T).
        ipd_theory_exp = ipd_theory_all.unsqueeze(0).unsqueeze(-1)
        similarity = torch.real(ipd_obs.unsqueeze(1) * torch.conj(ipd_theory_exp)).sum(dim=2)
        similarity = similarity / (ipd_obs.shape[1] + 1e-8)

        # Max-pool over candidates inside / outside the FOV.
        az_min, az_max = fov_az_range
        inside_mask = ((az_grid >= az_min) & (az_grid <= az_max)).to(device)
        outside_mask = ~inside_mask

        if inside_mask.any():
            D_in = similarity[:, inside_mask, :, :].max(dim=1)[0]
        else:
            D_in = similarity.new_zeros((batch, freq, time))

        if outside_mask.any():
            D_out = similarity[:, outside_mask, :, :].max(dim=1)[0]
        else:
            D_out = similarity.new_zeros((batch, freq, time))
        return D_in, D_out
            
        
            
            
        
        
        