# Homework №1

    In this homework you need to implement the following:
        1) Discrete Fourier Transform
        2) Fast Fourier Transform
        3) A performance comparison of (1) and (2)
        4) Short-time Fourier Transform based on (2) and the Hann window function
        5) MelScale
        6) Digit classification based on your mel spectrograms
        
    Note:
        You should test your implementation against the torchaudio functions
        (e.g. torch.allclose(torchaudio.transforms.Spectrogram()(wav), your_function(wav)))

### Main rules
    1) All operations must be implemented with PyTorch (don't use numpy)
    2) Everything should support batch input
    3) No loops, only matrix multiplications
    4) Clean and clear code

In [None]:
!pip install torchaudio

In [44]:
import torch
import torchaudio

In [73]:
import numpy as np
import pandas as pd
import time
from typing import List
import plotly.graph_objects as go

# Discrete Fourier Transform (1 pt)

In [46]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [149]:
wav, sr = torchaudio.load('/content/drive/MyDrive/DLA/audio.wav')

In [150]:
wav = wav.squeeze()[:64]

In [151]:
wav.shape

torch.Size([64])

In [49]:
def descrete_fourier_transform(wav: torch.Tensor, return_time: bool = False):
    """DFT of a 1-D signal as one matrix-vector product X = W @ x,
    where W[k, n] = exp(-2*pi*i*k*n / N)."""
    time1 = time.time()
    N = wav.shape[0]

    j = torch.complex(torch.tensor([0], dtype=torch.float32),
                      torch.tensor([1], dtype=torch.float32))
    pi = torch.acos(torch.zeros(1)).item() * 2

    # row 0 and column 0 of W are all ones (exponent k * n = 0)
    res_matrix = torch.complex(torch.ones([N, N], dtype=torch.float32),
                               torch.zeros([N, N], dtype=torch.float32))

    # cache omega^n: the same exponent occurs many times in W
    omega_dict = {}

    for row in range(1, N):
        for col in range(1, N):
            # omega^n has period N, so reduce the exponent for accuracy
            n = (row * col) % N
            if n not in omega_dict:
                omega_dict[n] = torch.exp((-j * 2 * pi) / N) ** n
            res_matrix[row, col] = omega_dict[n]

    wav_complex = torch.complex(wav, torch.zeros_like(wav))
    result = torch.inner(res_matrix, wav_complex)
    # stop the timer after the product, so the transform itself is timed too
    time2 = time.time()

    if return_time:
        return result, time2 - time1
    return result

In [None]:
dft = descrete_fourier_transform(wav)

### Sanity Check

In [None]:
torch.allclose(dft, torch.fft.fft(wav))

True
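Rules 2 and 3 ask for batch support and no Python loops. Below is a minimal loop-free sketch, assuming the same convention as `torch.fft.fft` (forward transform, no normalization): build the whole matrix $W[k, n] = e^{-2\pi i kn/N}$ with an outer product and apply it with one matmul over the last dimension. The names `dft_matrix` and `dft` are hypothetical and not part of the notebook.

```python
import math

import torch


def dft_matrix(N: int) -> torch.Tensor:
    """DFT matrix W[k, n] = exp(-2*pi*i*k*n/N) as complex64, shape (N, N)."""
    n = torch.arange(N, dtype=torch.float64)
    # reduce k*n modulo N before scaling, which keeps the angles small and accurate
    kn = torch.outer(n, n) % N
    return torch.exp(-2j * math.pi * kn / N).to(torch.complex64)


def dft(wav: torch.Tensor) -> torch.Tensor:
    """Loop-free DFT along the last dimension of a (..., N) tensor."""
    W = dft_matrix(wav.shape[-1])
    # X[..., k] = sum_n x[..., n] * W[n, k]; W is symmetric, so this equals W @ x
    return wav.to(torch.complex64) @ W
```

For example, `dft(torch.randn(8, 64))` transforms a batch of 8 signals at once. Building $W$ costs $O(N^2)$ memory, which is fine for the short inputs used here.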

# Fast Fourier Transform (3 pts)

    A common task for a machine learning engineer is to take a paper and implement it.
    So, just do it!
[Tap on me](http://www.robots.ox.ac.uk/~sjrob/Teaching/SP/l7.pdf)
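For reference, here is the radix-2 decimation-in-time split that the FFT code in this section is built on, with $W_N = e^{-2\pi i / N}$ and $N$ even. This is the standard derivation, restated here rather than copied from the linked notes:

$$
X_k = \sum_{n=0}^{N-1} x_n W_N^{kn}
    = \underbrace{\sum_{m=0}^{N/2-1} x_{2m} W_{N/2}^{km}}_{E_k}
    + W_N^{k} \underbrace{\sum_{m=0}^{N/2-1} x_{2m+1} W_{N/2}^{km}}_{O_k}
$$

Since $E_k$ and $O_k$ have period $N/2$ in $k$ and $W_N^{k + N/2} = -W_N^{k}$:

$$
X_k = E_k + W_N^{k} O_k, \qquad X_{k+N/2} = E_k - W_N^{k} O_k, \qquad k = 0, \dots, N/2 - 1
$$

Applying the split recursively to $E$ and $O$ gives $O(N \log N)$ operations instead of $O(N^2)$.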
    

In [164]:
import math


def recurcive_dft(N: int, wav: torch.Tensor, omega_dict: dict, k: int):
    """Return the k-th DFT coefficient of a signal of length 2N using one
    even/odd split: X[k] = E[k] + exp(-i*pi*k/N) * O[k], where E and O are
    the length-N DFTs of wav[0::2] and wav[1::2].
    Each call costs O(N), so the whole transform is still O(N^2)."""
    omega = torch.exp(torch.tensor(-2j * math.pi / N))       # twiddle of a length-N DFT
    omega_full = torch.exp(torch.tensor(-1j * math.pi / N))  # twiddle of the length-2N DFT

    even = torch.zeros((), dtype=torch.complex64)
    odd = torch.zeros((), dtype=torch.complex64)

    for m in range(N):
        # omega^p has period N, so reducing the exponent keeps the cache small
        p = (k * m) % N
        if p not in omega_dict:
            omega_dict[p] = omega ** p
        w = omega_dict[p]
        even = even + wav[2 * m] * w        # even sample 2m
        odd = odd + wav[2 * m + 1] * w      # odd sample 2m + 1 uses the same twiddle

    return even + omega_full ** k * odd, omega_dict


def fast_fourier_transform(wav: torch.Tensor, return_time: bool = False):
    time1 = time.time()
    N = wav.shape[0]
    assert N >= 2 and (N & (N - 1)) == 0, "N should be a power of 2"

    omega_dict = {}  # cache of omega^p, shared by all k
    fft = torch.zeros(N, dtype=torch.complex64)

    for k in range(N):
        fft[k], omega_dict = recurcive_dft(N // 2, wav, omega_dict, k)

    time2 = time.time()

    if return_time:
        return fft, time2 - time1
    return fft

In [165]:
fft = fast_fourier_transform(wav)

In [167]:
torch.allclose(fft, torch.fft.fft(wav))

In [184]:
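import math


def _fft_recursive(x: torch.Tensor) -> torch.Tensor:
    """Radix-2 Cooley-Tukey FFT along the last dimension.
    Works on batches of shape (..., N); N must be a power of 2."""
    N = x.shape[-1]
    if N == 1:
        return x.to(torch.complex64)

    X_even = _fft_recursive(x[..., ::2])
    X_odd = _fft_recursive(x[..., 1::2])

    # twiddle factors exp(-2*pi*i*k/N), k = 0 .. N/2 - 1
    k = torch.arange(N // 2, dtype=torch.float64)
    factor = torch.exp(-2j * math.pi * k / N).to(torch.complex64)

    # exp(-2*pi*i*(k + N/2)/N) = -exp(-2*pi*i*k/N), hence the minus sign
    return torch.cat([X_even + factor * X_odd, X_even - factor * X_odd], dim=-1)


def cooley_tukey_fft(x: torch.Tensor):
    """Return (spectrum, [elapsed seconds]) for a signal of shape (..., N)."""
    time1 = time.time()
    X = _fft_recursive(x)
    return X, [time.time() - time1]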

In [185]:
fft, t = cooley_tukey_fft(wav)

In [175]:
torch.allclose(fft, torch.fft.fft(wav))

True

# A comparison of the performance (1e-7 pts)
    Make pretty plots :)

In [186]:
def plot_multi_lines_chart(df: pd.DataFrame,
                           x_axis_name: str,
                           y_axis_list_names: List[str]):
    """
    Plot several y-columns of a dataframe against one shared x-column.
    :param df: pandas dataframe with the data for the x and y axes
    :param x_axis_name: name of the df column shared by all lines (x-axis)
    :param y_axis_list_names: names of the df columns to plot, one line each
    :return: None
    """
    fig = go.Figure()
    for y in y_axis_list_names:
        fig.add_trace(go.Scatter(x=df[x_axis_name], y=df[y], name=y))
    fig.show()


def create_plot():
    lengths = [2, 128, 256, 512, 1024]
    dft_time = []
    fft_time = []

    wav, sr = torchaudio.load('/content/drive/MyDrive/DLA/audio.wav')
    wav = wav.squeeze()
    for n in lengths:
        dft_time.append(descrete_fourier_transform(wav[:n], True)[1])
        fft_time.append(max(cooley_tukey_fft(wav[:n])[1]))

    df = pd.DataFrame({'wav_length': lengths,
                       'dft_time': dft_time,
                       'fft_time': fft_time})

    plot_multi_lines_chart(df, 'wav_length', ['dft_time', 'fft_time'])


In [187]:
create_plot()
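A single `time.time()` difference per input length is noisy, especially for short inputs where the runtime is close to the timer's resolution. Below is a sketch of a steadier measurement using the standard-library `timeit` module. `best_time` is a hypothetical helper name, and the repeat counts are arbitrary choices.

```python
import timeit

import torch


def best_time(fn, x: torch.Tensor, repeat: int = 5, number: int = 3) -> float:
    """Best-of-`repeat` mean wall-clock time (seconds) of a single fn(x) call.

    Taking the minimum over repeats filters out interference from other
    processes, which a single time.time() difference cannot do.
    """
    runs = timeit.repeat(lambda: fn(x), repeat=repeat, number=number)
    return min(runs) / number
```

For example, `best_time(descrete_fourier_transform, torch.randn(256))` could replace the `[1]` timing output in `create_plot`. Asymptotically, the DFT costs $O(N^2)$ and a radix-2 FFT costs $O(N \log N)$, so at $N = 1024$ the operation-count ratio is roughly $N / \log_2 N \approx 102$. Python-level overheads can hide this at small $N$.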

# Short-time Fourier Transform (2 pts)

    Use torch.hann_window

In [None]:
# TODO
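Below is a minimal batched STFT sketch. It frames the signal with a strided view and applies a periodic Hann window (`torch.hann_window`), assuming the same conventions as `torch.stft(..., center=True, pad_mode="reflect")`. As a stand-in for the notebook's own FFT from (2), it uses `torch.fft.rfft` for the one-sided spectrum of each frame; the defaults `n_fft=400` and `hop_length=200` are assumptions.

```python
import torch
import torch.nn.functional as F


def stft(wav: torch.Tensor, n_fft: int = 400, hop_length: int = 200) -> torch.Tensor:
    """Batched STFT of a (..., time) signal with a periodic Hann window.

    Mirrors torch.stft(..., center=True, pad_mode="reflect", onesided=True):
    returns a complex tensor of shape (..., n_fft // 2 + 1, n_frames).
    """
    window = torch.hann_window(n_fft, dtype=wav.dtype)
    batch_shape = wav.shape[:-1]
    # reflect padding expects a (batch, channel, time) layout
    x = wav.reshape(-1, 1, wav.shape[-1])
    x = F.pad(x, (n_fft // 2, n_fft // 2), mode="reflect")
    x = x.reshape(*batch_shape, -1)
    # overlapping frames as a strided view: (..., n_frames, n_fft)
    frames = x.unfold(-1, n_fft, hop_length)
    # stand-in for the notebook's FFT: one-sided spectrum of each windowed frame
    spec = torch.fft.rfft(frames * window, dim=-1)
    return spec.transpose(-1, -2)
```

Under these defaults, `stft(wav).abs() ** 2` should correspond to `torchaudio.transforms.Spectrogram(n_fft=400)(wav)` (default hop `n_fft // 2`, power 2). This is the kind of `torch.allclose` check the Note at the top asks for.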

# MelScale (2 pts)

[Tap on me](http://practicalcryptography.com/miscellaneous/machine-learning/guide-mel-frequency-cepstral-coefficients-mfccs/)

In [None]:
# TODO
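Below is a sketch of a triangular mel filterbank, assuming the HTK mel formula $m = 2595 \log_{10}(1 + f/700)$ and no filter normalization. The linked guide uses $1125 \ln(1 + f/700)$, which is essentially the same scale since $2595 / \ln 10 \approx 1127$. The filter edges are $n_{mels} + 2$ points spaced evenly in mel; each filter rises linearly from its left edge to its centre and falls back to zero at its right edge. All names here are hypothetical.

```python
import math
from typing import Optional

import torch


def hz_to_mel(f: float) -> float:
    """HTK mel scale: m = 2595 * log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_filterbank(n_freqs: int, n_mels: int, sample_rate: int,
                   f_min: float = 0.0, f_max: Optional[float] = None) -> torch.Tensor:
    """Triangular mel filters of shape (n_freqs, n_mels), without normalization."""
    if f_max is None:
        f_max = float(sample_rate // 2)
    # centre frequency of every one-sided STFT bin
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
    # n_mels + 2 points evenly spaced in mel, mapped back to Hz (filter edges and centres)
    m_pts = torch.linspace(hz_to_mel(f_min), hz_to_mel(f_max), n_mels + 2)
    f_pts = 700.0 * (10.0 ** (m_pts / 2595.0) - 1.0)
    f_diff = f_pts[1:] - f_pts[:-1]
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
    rising = -slopes[:, :-2] / f_diff[:-1]  # 0 at the left edge, 1 at the centre
    falling = slopes[:, 2:] / f_diff[1:]    # 1 at the centre, 0 at the right edge
    return torch.clamp(torch.minimum(rising, falling), min=0.0)


def mel_spectrogram(spec: torch.Tensor, fb: torch.Tensor) -> torch.Tensor:
    """(..., n_freqs, n_frames) power spectrogram -> (..., n_mels, n_frames)."""
    return (spec.transpose(-1, -2) @ fb).transpose(-1, -2)
```

Under these assumptions, the filterbank should agree with `torchaudio.functional.melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate)` using its defaults (HTK scale, `norm=None`). Applying it is a single matmul, so batch input comes for free.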

# Digit classification (5 pts)

    1) Download the data from Google Drive: https://drive.google.com/file/d/1ouSOru91p-ZJCyI6E8cGh7N0r3vffi06/view?usp=sharing

    2) Split the data into train and test sets in an 80/20 proportion. Note that both the train
    and the test sets must contain all digits and all speakers, so split the data carefully.

    3) The AudioMNIST dataset consists of 30000 audio recordings (9.5 hours)
    of spoken digits (0-9) in English, with 50 repetitions per digit for each of the 60 different speakers.

    4) Build a classifier of spoken digits. You can use any neural network architecture you like.
        The minimum required classifier quality will be announced.

    5) Each wav file name has the following format: digit_speakerid_wavid.wav
        For example, 6_01_47.wav:
            6 -- the digit 6 is spoken
            01 -- the digit is spoken by speaker 01
            47 -- id of the wav file
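One way to meet the split requirement in (2) is to split inside every (digit, speaker) group. With 50 repetitions per group, that puts 40 files in train and 10 in test for each pair, so both sets contain every digit and every speaker. Below is a sketch that assumes only the `digit_speakerid_wavid.wav` naming from (5); `parse_name`, `split_80_20`, and the seed are hypothetical.

```python
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple


def parse_name(path: str) -> Tuple[int, int, int]:
    """'6_01_47.wav' -> (digit, speaker, wav_id) = (6, 1, 47)."""
    digit, speaker, wav_id = Path(path).stem.split("_")
    return int(digit), int(speaker), int(wav_id)


def split_80_20(paths: List[str], seed: int = 0) -> Tuple[List[str], List[str]]:
    """Split 80/20 inside every (digit, speaker) group."""
    groups = defaultdict(list)
    for p in paths:
        digit, speaker, _ = parse_name(p)
        groups[(digit, speaker)].append(p)

    rng = random.Random(seed)
    train, test = [], []
    for key in sorted(groups):
        files = sorted(groups[key])
        rng.shuffle(files)
        n_test = max(1, round(0.2 * len(files)))  # at least one test file per group
        test += files[:n_test]
        train += files[n_test:]
    return train, test
```

Sorting before shuffling with a fixed seed makes the split reproducible regardless of the order in which the files are listed on disk.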

    Bonus:
        If you build a good model, use augmentation, or add something else of that kind,
        you can earn up to 3 bonus points.
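Item (4) leaves the architecture open. Below is one possible baseline, a small CNN over log-mel spectrograms. The layer sizes, the `1e-6` log offset, and the name `DigitCNN` are arbitrary assumptions, not a required design. Clips of different lengths still have to be padded or cropped to a common number of frames before they can be batched.

```python
import torch
from torch import nn


class DigitCNN(nn.Module):
    """Small CNN: (batch, n_mels, n_frames) mel spectrogram -> (batch, 10) logits."""

    def __init__(self, n_classes: int = 10):
        super().__init__()

        def block(c_in: int, c_out: int) -> nn.Sequential:
            # conv -> batch norm -> ReLU -> 2x downsampling in both mel and time
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        self.features = nn.Sequential(
            block(1, 16), block(16, 32), block(32, 64),
            # average over the remaining mel/time positions, so any input size >= 8 works
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(64, n_classes)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        # log-compress the power values, then add a channel dimension
        x = torch.log(mel + 1e-6).unsqueeze(1)
        return self.head(self.features(x).flatten(1))
```

Training would then be a standard loop minimizing `nn.functional.cross_entropy(model(mel), labels)`. Augmentations such as time/frequency masking fit naturally on the mel input.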