In [1]:
import torch
import torch as th
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torchmetrics import ScaleInvariantSignalNoiseRatio

from Modules import STFT
from model import Unet_model
from data import my_dataset

In [2]:
# Training split, read from ./beam_data_v1/train.
data_train = my_dataset('./beam_data_v1/train')

# STFT framing, converted from seconds to samples.
fs = 16000                       # sampling frequency
window_length = int(fs * 0.025)  # 25 ms analysis window
window_shift = int(fs * 0.01)    # 10 ms hop between windows
stft = STFT(n_fft=window_length, hop=window_shift)

# Only the first dataset item is used in this notebook.
# NOTE(review): judging by the names, the item is (clean target, multichannel
# recording, noise) -- confirm against my_dataset.__getitem__.
train_target, train_sample, train_noise = data_train[0]

In [3]:
# Reference-channel selection: keep the channel whose spectrogram has the
# highest average power.
spec_sample = stft.stft(train_sample[None])  # [None] prepends a leading (batch) axis
power_spec = spec_sample.abs().pow(2)

# Mean over the last two axes (frequency and time).
mean_spec_power = power_spec.mean(dim=(-1, -2))
# Index of the channel with the largest mean power.
ref_channel = mean_spec_power.argmax(dim=1)
# Spectrogram of the reference channel.
# NOTE(review): indexing with a one-element index tensor keeps the channel axis;
# this is only correct while the batch size is 1 -- confirm before batching.
ref_sample_mic = spec_sample[:, ref_channel]

In [4]:
## Full processing chain: magnitude masking that yields the final complex spectrum
def pipeline(model, spec, eps=1e-5):
    """Apply a model-predicted magnitude mask to a complex spectrogram.

    The model is fed the log-magnitude of ``spec`` and must return a mask that
    broadcasts against the magnitude. The mask rescales the magnitude while the
    phase of ``spec`` is reused unchanged.

    Args:
        model: Callable mapping the log-magnitude tensor to a real-valued mask.
        spec: Complex tensor (the spectrogram to enhance).
        eps: Small constant added to the magnitude before the logarithm so that
            zero-magnitude bins do not produce ``-inf``.

    Returns:
        Complex tensor with magnitude ``|spec| * mask`` and the phase of ``spec``.
    """
    magn = torch.abs(spec)
    # torch.angle(z) is atan2(z.imag, z.real) for complex input.
    phase = torch.angle(spec)
    mask = model(torch.log(magn + eps))
    # polar(r, phi) builds r*cos(phi) + 1j*r*sin(phi) in a single call.
    return torch.polar(magn * mask, phase)

In [5]:
## Training setup
# Seed the global torch RNG before the model is built so that any random
# initialisation inside Unet_model() repeats on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

epochs = 1000  # number of passes of the training loop
model = Unet_model()

optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Cosine learning-rate schedule with warm restarts; the first cycle spans
# T_0 = 40 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=40)
si = ScaleInvariantSignalNoiseRatio()  # metric object from torchmetrics
losses = []   # loss value recorded after every pass
metrics = []  # metric value recorded after every pass

In [7]:
# Fit the model to the single training example: each pass masks the reference
# channel, reconstructs the waveform and takes one optimiser step.
for epoch in tqdm(range(epochs)):
    optim.zero_grad()

    # Forward: masked spectrum -> waveform of the same length as the input.
    z = pipeline(model, ref_sample_mic)
    wave_predict = stft.istft(z, train_sample.shape[-1])
    loss = F.mse_loss(wave_predict, train_target[None])

    loss.backward()
    optim.step()
    # Advance the learning-rate schedule once per pass, after the optimiser
    # step; without this call the scheduler never changes the learning rate.
    scheduler.step()

    # Track the metric on a detached prediction so no graph is kept alive.
    metrics.append(si(wave_predict.detach(), train_target[None]))
    losses.append(loss.item())

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:53<00:00,  8.81it/s]
