In [14]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib ipympl

# Notebook-wide Matplotlib defaults, applied in a single update:
# LaTeX text rendering, a serif font family, and a wide 12x4 default figure.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "figure.figsize": (12, 4),
})

In [3]:
import torch
from loading_data import TechnoDataset

# NOTE(review): hard-coded path, relative to the notebook's working directory.
dataset = TechnoDataset("../../Victor BIGAND/TECHNO/techno_resampled.dat")

# Fraction of the samples assigned to the validation subset; the remainder
# (roughly 0.05 % of the data) becomes the training subset.
valid_ratio = 0.9995
# Both subsets are drawn from the same dataset object.
train_valid_dataset = dataset

# Split sizes. The training size is computed as the remainder so that
# nb_train + nb_valid == len(dataset) always holds; truncating two separate
# floating-point products can be off by one, which makes random_split raise.
nb_valid = int(valid_ratio * len(train_valid_dataset))
nb_train = len(train_valid_dataset) - nb_valid

train_dataset, valid_dataset = torch.utils.data.dataset.random_split(train_valid_dataset, [nb_train, nb_valid])


# Data loaders
num_threads = 0   # 0 worker processes: batches are loaded in the main process
batch_size = 2    # Mini-batches of 2 samples
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_threads)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_threads)

In [62]:
import math as m
# Preview the magnitude STFT of the first sample of the first training batch.
# The indexing below treats a batch as a rank-3 tensor whose last axis is time
# and whose second axis is the channel.
x = next(iter(train_loader))

Nsig = x.shape[2]      # number of time samples per signal
Nwin = 1024            # analysis window length
Nfft = Nwin            # FFT size
Nhop = Nwin // 4       # hop size (75 % overlap between consecutive windows)

# Expected number of frames without centring/padding (center=False).
L = (Nsig - Nwin) // Nhop + 1
# Number of one-sided frequency bins. Nfft // 2 + 1 is valid for even and odd
# Nfft alike; the former test `Nfft & 2 == 0` parsed as `Nfft & (2 == 0)`,
# was always false, and reported one bin too few for even Nfft.
I = Nfft // 2 + 1

print(I, L)

# Always take sample 0 / channel 0 of the batch; the sample index must not
# depend on the position of the batch inside the loader.
STFT = torch.stft(x[0, 0, :], n_fft=Nfft, win_length=Nwin, window=torch.hamming_window(Nwin), hop_length=Nhop, return_complex=True, center=False)

fig = plt.figure()
ax1 = fig.add_subplot(111)

# Take the magnitude in torch, then hand a plain NumPy array to Matplotlib.
img = ax1.imshow(STFT.abs().numpy(),
    #cmap=cmap ,
    interpolation="bilinear",
    aspect="auto",
    origin="lower")
ax1.set_xlabel("Frame index")
ax1.set_ylabel("Frequency bin")
ax1.set_title("STFT magnitude")

fig.tight_layout()

print(STFT.shape)

32768


In [None]:
from torch import nn

def _log_power_spectrogram(signal, Nfft, Nwin, Nhop, window, epsilon):
    """Return log(|STFT|^2 + epsilon) of `signal`, keeping its leading axes."""
    # torch.stft accepts at most one batch axis, so fold every leading axis
    # into it and restore them afterwards.
    flat = signal.reshape(-1, signal.shape[-1])
    spec = torch.stft(flat, Nfft, win_length=Nwin, window=window, hop_length=Nhop,
                      return_complex=True, center=False)
    log_power = torch.log(spec.abs() ** 2 + epsilon)
    return log_power.reshape(tuple(signal.shape[:-1]) + tuple(log_power.shape[-2:]))


def Spectral_Loss(x, y, Nfft=512, Nwin=512, window="Hamming", Nhop=None, epsilon=1, device="cpu"):
    """Log-power spectrogram L1 distance between two batches of signals.

    Both inputs are turned into log(|STFT|^2 + epsilon) spectrograms
    (no centring/padding); the absolute difference is averaged over the
    batch axis (axis 0) and summed over all remaining axes.

    Parameters
    ----------
    x, y : torch.Tensor
        Signals of identical shape, time on the last axis and batch on the
        first axis (e.g. ``(batch, 1, Nsig)``).
    Nfft : int
        FFT size.
    Nwin : int
        Analysis window length.
    window : str or torch.Tensor
        ``"Hamming"``, ``"Hanning"``, or a ready-made window tensor of
        length ``Nwin``.
    Nhop : int, optional
        Hop size. Defaults to ``Nwin // 4``, resolved at call time from the
        ``Nwin`` argument (a default written as ``Nwin//4`` in the signature
        would instead read a global ``Nwin`` when the function is defined).
    epsilon : float
        Offset added to the power spectrogram before the logarithm.
    device : str
        Unused; kept for backward compatibility. All computation happens on
        the device of ``x``.

    Returns
    -------
    torch.Tensor
        Real 0-dim tensor, differentiable with respect to ``x`` and ``y``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in shape or ``window`` is an unknown name.
    """
    if Nhop is None:
        Nhop = Nwin // 4

    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
        )

    # Build the window on the same device/dtype as the signals so the loss
    # also works for inputs that do not live on the CPU.
    if isinstance(window, str):
        if window == "Hamming":
            window = torch.hamming_window(Nwin, dtype=x.dtype, device=x.device)
        elif window == "Hanning":
            window = torch.hann_window(Nwin, dtype=x.dtype, device=x.device)
        else:
            raise ValueError(f"Unknown window {window!r}; expected 'Hamming' or 'Hanning'")
    else:
        window = window.to(device=x.device)

    lX = _log_power_spectrogram(x, Nfft, Nwin, Nhop, window, epsilon)
    lY = _log_power_spectrogram(y, Nfft, Nwin, Nhop, window, epsilon)

    # Element-wise L1 distance, mean over the batch, sum over everything else.
    # Pure torch ops keep the result real-valued and inside the autograd graph.
    spectral_loss = (lX - lY).abs().mean(dim=0).sum()

    return spectral_loss