# Train model with noisy envelope - filter with dual input version

Starting from `RNN-Morse-filter`, a pure noise channel is added alongside the noisy signal channel to help the network better identify silence periods.

The formatting of the X data is changed to account for the second dimension of the time series samples (two channels per time step: noisy signal and pure noise). This shows an example of dealing with a 2D time series.

In [None]:
# Install extra packages for this notebook (torchinfo provides the model summary used further down)
!pip install sounddevice torchinfo

## Generate annotated raw signal

Generates the envelope after audio preprocessing. The resulting decimation factor is 128, so we take 1 out of every 128 samples of the original signal modulated at an 8 kHz sample rate. This uses a modified version of `encode_df` (`encode_df_decim`) from `MorseGen`, so the original ratio of samples per dit is respected. This effectively uses a floating-point ratio (shown in the display) for the samples-per-dit decimation (about 5.77 for the nominal values of 8 kHz sampling rate and 13 WPM Morse code speed).

In [None]:
import MorseGen
import matplotlib.pyplot as plt 

#phrase = '01234 6789 QUICK BROWN FOX 01234 6789 QUICK BROWN FOX01234 6789 QUICK BROWN FOX01234 6789 QUICK BROWN FOX01234 6789 QUICK BROWN FOX 01234 6789 QUICK BROWN FOX'
#phrase = '7U7K 0DC55B H ZN0J Q9 H2X0 LZ16A ECA2DE 6A2 NUPU 67IL6EIH YVZA 5OTGC3U C3R PGW RS0 84QTV4PB EZ1 JBGJ TT1W4M5PBJ GZVLWXQG 7POU6 FMTXA N3CZ Y1Q9VZ6 9TVL CWP8KSB'
phrase = '6 WREB W7UU QNWXS2 3KRO72Q AN1TI QZIWH G L0U7 Y17X45 OVIC2 C052W00PI60 O5Y 10R2N 4 FHC JXRGS4 DWBOL7ZUXJU EMNC3 WWBNT7 0UP GMKQ YG83H8 IT2Q Y0YBZ SQ80I5 W7SW 0K BMJ8JPM 51CK1 R08T 7SU1LYS7W6T 4JKVQF V3G UU2O1OM4 P4B 4A9DLC VI1H 4 HMP57 Q6G3 4QADIG FRJ 0MVL EPSM CS N9IZEMA GSRWUPBYB FD29 YI3PY N31W X88NS 773EW4Q4 LSW'
Fs = 8000  # sample rate (Hz) of the original modulated signal
morse_gen = MorseGen.Morse()
samples_per_dit = morse_gen.nb_samples_per_dit(Fs, 13)  # 13 is the Morse speed in WPM (see markdown above)
# n_prev is expressed in decimated samples (samples_per_dit/128 decimated samples per dit)
n_prev = int((samples_per_dit/128)*12) + 1 # number of samples to look back is slightly more than a dit-dah and a word space (2+3+7=12)
print(f'Samples per dit at {Fs} Hz is {samples_per_dit}. Decimation is {samples_per_dit/128:.2f}. Look back is {n_prev}.')
# Label DataFrame generated directly at the decimated rate (decimation factor 128)
label_df = morse_gen.encode_df_decim(phrase, samples_per_dit, 128)
# keep just the envelope
label_df.drop(columns=['dit','dah', 'ele', 'chr', 'wrd'], inplace=True)
print(label_df.shape)
plt.figure(figsize=(50,5))
# Display window [x, y) in decimated samples; these globals are reused by later plotting cells
x = 0
y = 1500
plt.plot(label_df[x:y].env*0.9 + 0.0, label='env')
plt.title("labels")
plt.legend()
plt.grid()

### Envelope

The SNR must be calculated in the FFT bin bandwidth. In the original `RNN-Morse-pytorch` notebook the bandwidth is 4 kHz / 256 = 15.625 Hz and the SNR is 3 dB. In theory you would apply the FFT ratio to the original SNR, but this does not work in practice: you have to use a much lower SNR to obtain a similar envelope.

In [None]:
import numpy as np

# Target SNR in dB (see the markdown above for why it is much lower than in the original notebook)
SNR_dB = -25
SNR_linear = 10.0**(SNR_dB/10.0)
SNR_linear *= 256 # Apply original FFT ratio (x256, i.e. about +24.08 dB)
print(f'Resulting SNR for original {SNR_dB} dB is {(10.0 * np.log10(SNR_linear)):.2f} dB')
t = np.linspace(0, len(label_df)-1, len(label_df))  # NOTE(review): not used in this cell
morsecode = label_df.env
# Signal power is taken as the variance of the clean envelope
power = morsecode.var()
noise_power = power/SNR_linear
# Two independent Gaussian noise draws of equal power: noise0 corrupts the envelope,
# noise1 is kept on its own as the pure-noise input channel
noise0 = np.sqrt(noise_power)*np.random.normal(0, 1, len(morsecode))
noise1 = np.sqrt(noise_power)*np.random.normal(0, 1, len(morsecode))
signal = morsecode + noise0
print(len(signal))

plt.figure(figsize=[25,5])
plt.plot(signal[x:y] + 2.0, label='s+n0')  # offset by 2.0 to separate it from the noise traces
plt.plot(noise0[x:y], label='n0')
plt.plot(noise1[x:y], label='n1')
plt.grid()
plt.legend()

## Create model

Let's create the model now so we have an idea of its inputs and outputs

In [None]:
import torch
import torch.nn as nn

class MorseEnvLSTM(nn.Module):
    """
    Initial implementation: a single-layer LSTM followed by a linear output layer.

    The LSTM hidden/cell state is stored in ``self.hidden_cell`` and carried over
    from one ``forward`` call to the next until ``zero_hidden_cell`` resets it.

    Args:
        device: device on which the hidden/cell state tensors are created.
        input_size: number of features per time step.
        hidden_layer_size: LSTM hidden size.
        output_size: number of outputs of the final linear layer.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        self.zero_hidden_cell()

    def forward(self, input_seq):
        """Run one sequence through the LSTM and return the output of its last time step.

        ``input_seq`` is viewed as (seq_len, batch=1, features), so its first
        dimension is the time axis. Returns a tensor of shape (output_size,).
        The stored hidden/cell state is updated as a side effect.
        """
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions[-1]

    def zero_hidden_cell(self):
        """Reset the stored hidden and cell state to zeros."""
        # Use the device given at construction rather than a module-level global,
        # so the model works regardless of what globals exist where it is used.
        self.hidden_cell = (
            torch.zeros(1, 1, self.hidden_layer_size, device=self.device),
            torch.zeros(1, 1, self.hidden_layer_size, device=self.device),
        )
    
class MorseEnvLSTM2(nn.Module):
    """
    LSTM stack: two stacked LSTM layers followed by a linear output layer.

    The hidden/cell state is stored in ``self.hidden_cell`` and carried over
    between ``forward`` calls until ``zero_hidden_cell`` resets it.

    Args:
        device: device on which the hidden/cell state tensors are created.
        input_size: number of features per time step.
        hidden_layer_size: hidden size of each LSTM layer.
        output_size: number of outputs of the final linear layer.
        dropout: dropout applied between the two LSTM layers.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6, dropout=0.2):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.num_layers = 2
        self.lstm = nn.LSTM(input_size, hidden_layer_size, num_layers=self.num_layers, dropout=dropout)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        self.zero_hidden_cell()

    def forward(self, input_seq):
        """Run one sequence through the LSTM stack and return the output of its last time step.

        ``input_seq`` is viewed as (seq_len, batch=1, features). Returns a tensor
        of shape (output_size,). The stored hidden/cell state is updated.
        """
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions[-1]

    def zero_hidden_cell(self):
        """Reset the stored hidden and cell state (one slice per layer) to zeros."""
        # Use the device given at construction rather than a module-level global.
        self.hidden_cell = (
            torch.zeros(self.num_layers, 1, self.hidden_layer_size, device=self.device),
            torch.zeros(self.num_layers, 1, self.hidden_layer_size, device=self.device),
        )
        
class MorseEnvNoHLSTM(nn.Module):
    """
    Stateless variant: nothing is remembered between calls.

    Every ``forward`` call starts the LSTM from a fresh all-zero hidden and cell
    state, so the output depends only on the sequence passed in.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)

    def forward(self, input_seq):
        """Return the linear output for the last time step of ``input_seq``.

        The sequence is viewed as (seq_len, batch=1, features); the result has
        shape (output_size,).
        """
        seq_len = len(input_seq)
        state_shape = (1, 1, self.hidden_layer_size)
        initial_state = (
            torch.zeros(state_shape, device=self.device),
            torch.zeros(state_shape, device=self.device),
        )
        sequence = input_seq.view(seq_len, 1, -1)
        lstm_out, _ = self.lstm(sequence, initial_state)
        per_step = self.linear(lstm_out.view(seq_len, -1))
        return per_step[-1]
    
class MorseEnvBiLSTM(nn.Module):
    """
    Attempt Bidirectional LSTM: does not work

    NOTE(review): the LSTM is built with ``batch_first=True`` and ``forward``
    views the input as ``(len(x), 1, -1)``. That makes every time step a separate
    batch entry holding a sequence of length 1, so the LSTM never sees temporal
    context. This is a likely reason the model "does not work"; confirm before
    relying on it.
    """
    def __init__(self, device, input_size=1, hidden_size=12, num_layers=1, num_classes=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection

    def forward(self, x):
        """Return the ``num_classes`` outputs computed for the last entry of ``x``."""
        # Set initial states on the model's own device (not a module-level global);
        # the first dimension is num_layers * 2 for the two directions.
        state_shape = (self.num_layers*2, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=self.device)
        c0 = torch.zeros(state_shape, device=self.device)

        # Forward propagate LSTM; out: (batch_size, seq_length, hidden_size*2)
        out, _ = self.lstm(x.view(len(x), 1, -1), (h0, c0))
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out[-1]

Create the model instance and print the details

In [None]:
# Hidden layers:
# 4: good at reconstructing signal, some post-processing necessary for dit/dah, word silence is weak and undistinguishable from character silence 
# 5: fairly good at reconstructing signal, all signals distinguishable with some post-processing for dit/dah
# 6: more contrast on all signals but a spike appears in the character space in predicted envelope
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# input_size=2: one feature per input channel (noisy signal, pure noise); output_size=1: the envelope value
morse_env_model = MorseEnvLSTM(device, input_size=2, hidden_layer_size=6, output_size=1).to(device) # This is the only way to get things work properly with device
morse_env_loss_function = nn.MSELoss()
morse_env_optimizer = torch.optim.Adam(morse_env_model.parameters(), lr=0.001)

print(morse_env_model)
print(morse_env_model.device)

In [None]:
# Sanity check guarding against this error:
# Input and hidden tensors are not at the same device, found input tensor at cuda:0 and hidden tensor at cpu
for m in morse_env_model.parameters():
    print(m.shape, m.device)
# One forward pass on a random window of 12 time steps x 2 channels.
# NOTE: for stateful models (MorseEnvLSTM) this also updates morse_env_model.hidden_cell.
X_t = torch.rand((12, 2))
X_t = X_t.to(device)
print(X_t)
morse_env_model(X_t)

In [None]:
import torchinfo
# One model input is a single window X_train[k] of shape (n_prev, 2): n_prev time
# steps with 2 channels (noisy signal, pure noise). forward() views its input as
# (len(input), 1, -1). A (10, n_prev, 1) input would therefore present n_prev
# features to an LSTM built with input_size=2 and fail, so use the real sample shape.
H=n_prev  # time steps in one window
W=2       # features per time step (the model's input_size)
torchinfo.summary(morse_env_model, input_size=(H, W))

## Generate training data
### Raw data

In [None]:
sig = signal.to_numpy()
# Scale the noisy signal so that its maximum is 1.
# NOTE(review): if to_numpy() returns a view, this also rescales `signal` in place — confirm.
sig /= max(sig)
labels = label_df
# Keep at most len(sig) label rows (assumes label_df has a 0-based integer index — TODO confirm)
labels = labels.truncate(after=len(sig)-1, copy=False)
print(type(labels), labels.shape, type(sig), sig.shape, type(noise1), noise1.shape)
plt.figure(figsize=[25,5])
plt.plot(sig[x:y])
plt.title("Signal (X0)")
plt.grid()

In [None]:
# Overlay the normalised noisy input (X0) and the envelope labels, offset vertically by 1.0
plt.figure(figsize=(50,6))
plt.plot(sig[x:y]*0.9 + 0.0, label="X0")
plt.plot(labels[x:y].env*0.9 + 1.0, label="env_y")
plt.title("image line and labels")
plt.grid()
plt.legend()

### Format data for PyTorch 
With training and test data split

In [None]:
# train / test values for splitting
test_ratio = 0.5
# Number of leading rows used for training; the remaining rows are used for testing
n_trn = round(len(labels) * (1 - test_ratio))
print(n_trn)

The results are separate input and output tensors, moved directly to the device (the GPU if one is available).

In [None]:
def pytorch_rolling_window(x, window_size, step_size=1):
    """Return overlapping windows taken along the first (time) axis of ``x``.

    For a 2-D tensor of shape (n_steps, n_channels) the result has shape
    (n_windows, window_size, n_channels), where window ``k`` equals
    ``x[k*step_size : k*step_size + window_size]``.
    """
    # unfold adds the window as a new trailing axis: (n_windows, n_channels, window_size)
    windows = x.unfold(0, window_size, step_size)
    # move the window axis in front of the channel axis
    return windows.transpose(1, 2)

# Stack the two channels column-wise: X has shape (len(sig), 2) = (noisy signal, pure noise)
X = np.vstack((sig, noise1)).T
# Training windows over X[:n_trn]: X_train[k] == X[k:k+n_prev], shape (n_prev, 2)
X_train = pytorch_rolling_window(torch.FloatTensor(X[:n_trn]), n_prev, 1).to(device)
# Target for window k is label row k+n_prev, i.e. the sample right after the window
# (n_trn-n_prev+1 rows, one per window)
y_train = torch.FloatTensor(labels.iloc[n_prev:n_trn+1].values).to(device)
print("Train shapes", X_train.shape, y_train.shape)
print("X train\n", X_train)
print("y_train\n", y_train)
print("train[0] shapes", X_train[0].shape, y_train[0].shape)
# Test windows start at row n_trn; the last row of X is dropped so the final window
# still has a target (label row n_trn+k+n_prev for window k)
X_test = pytorch_rolling_window(torch.FloatTensor(X[n_trn:-1]), n_prev, 1).to(device)
y_test = torch.FloatTensor(labels.iloc[n_trn+n_prev:].values).to(device)
print("Test shape", X_test.shape, y_test.shape)
# make sure it works
y_pred = morse_env_model(X_train[0])
print("y_pred\n", y_pred)

In [None]:
# Move data to CPU for visualization
X_train_v = X_train.cpu()
y_train_v = y_train.cpu()
X_test_v = X_test.cpu()
y_test_v = y_test.cpu()

# Input (noisy) data for visualization
# l_test starts at row n_trn+n_prev, the same row as the first y_test target
l_train = sig[:n_trn+n_prev]
l_test = sig[n_trn+n_prev:]

In [None]:
# Take every n_prev-th window (adjacent, non-overlapping) for the first 5 windows of
# test and train input, so the concatenation reproduces 5*n_prev consecutive input rows.
# Each window is (n_prev, 2), so both channels are drawn.
a = []
b = []
for t in range(5):
    a.append(X_test_v[t*n_prev])
    b.append(X_train_v[t*n_prev])
plt.figure(figsize=(25,3))
plt.plot(np.concatenate((tuple(a)))*0.5, label='test')
plt.plot(np.concatenate((tuple(b)))*0.5+0.5, label='train')
plt.title("Train and test")
plt.legend()
plt.grid()

In [None]:
# Compare the first 5*n_prev test input rows (both channels), the noisy line and the targets.
# NOTE(review): the X_test trace starts at input row n_trn while l_test and y_test start at
# row n_trn+n_prev, so the first trace is shifted by n_prev samples relative to the others.
a = []
for i in range(5):
    a.append(X_test_v[i*n_prev])
plt.figure(figsize=(25,3))
plt.plot(np.concatenate(tuple(a)), label='X_test')
plt.plot(l_test[:5*n_prev]+1.0, label='line')
plt.plot(y_test_v[:5*n_prev,0]+2.0, label='y_test')
plt.title("Test")
plt.legend()
plt.grid()

## Train model

In [None]:
%%time
epochs = 2
morse_env_model.train()

for i in range(epochs):
    for j in range (len(X_train)):
        morse_env_optimizer.zero_grad()
        if morse_env_model.__class__.__name__ in ["MorseEnvLSTM", "MorseEnvLSTM2"]:
            morse_env_model.zero_hidden_cell() # this model needs to reset the hidden cell
        y_pred = morse_env_model(X_train[j])
        single_loss = morse_env_loss_function(y_pred, y_train[j])
        single_loss.backward()
        morse_env_optimizer.step()
        if j % 1000 == 0:
            print(f'   train {j}/{len(X_train)} loss: {single_loss.item():10.8f}')
    print(f'epoch: {i+1:3} loss: {single_loss.item():10.8f}')

print(f'final: {i+1:3} epochs loss: {single_loss.item():10.10f}')

In [None]:
import os

# torch.save does not create missing parent directories, so make sure "models/" exists first.
os.makedirs('models', exist_ok=True)
torch.save(morse_env_model.state_dict(), 'models/morse_env_model')

### Predict (test)

In [None]:
%%time
p_sig_l = []
morse_env_model.eval()

for i in range(len(X_test)):
    with torch.no_grad():
        pred_val = morse_env_model(X_test[i]).cpu()
        p_sig_l.append(pred_val[0].item())
        
p_sig = np.array(p_sig_l)

# trim negative values
p_sig[p_sig < 0] = 0

In [None]:
# Targets (y0) and clipped predictions (sig, offset by 1.0) over the display window; saved to img/pred.png
plt.figure(figsize=(30,2))
plt.plot(y_test_v[:y,0]*0.9, label="y0")
plt.plot(p_sig[:y]*0.9 + 1.0, label="sig")
plt.title("Predictions")
plt.legend()
plt.grid()
plt.savefig('img/pred.png')

In [None]:
# Min-max normalise the predictions and the targets over the display window.
# NOTE(review): divides by zero if a slice is constant; `mor` is not plotted in this cell.
# `sig` is rebound to the predictions here; l_test still refers to the noisy input.
sig = p_sig[:y]
sig = (sig - min(sig)) / (max(sig) - min(sig))
mor = y_test_v[:y,0]
mor = (mor - min(mor)) / (max(mor) - min(mor))
plt.figure(figsize=(30,5))
plt.plot(sig, label="sig")
plt.plot(l_test[:y] + 1.0, label="inp")
plt.title("predicted signal modulation")
plt.legend()
plt.grid()

In [None]:
import scipy as sp
import scipy.special
from scipy.io import wavfile

Fcode = 600  # frequency (Hz) of the sine tone used to re-modulate the envelope
Fs = 8000
noverlap = 128
decim = 128  # decimation factor: each envelope sample spans 128 audio samples
emod = sig
emod /= max(emod)  # in place: also rescales `sig` (same array)
# Hold each predicted envelope sample for noverlap audio samples to get back to Fs
remod = np.array([[x]*noverlap for x in emod]).flatten()
# Reference: min-max normalised targets held for decim samples each
mor = y_test_v[:y,0]
mor = (mor - min(mor)) / (max(mor) - min(mor))
ref_mod = np.array([[x]*decim for x in mor]).flatten()
# Phase increment per audio sample of an Fcode Hz sine sampled at Fs
wt = (Fcode / Fs)*2*np.pi
tone = np.sin(np.arange(len(remod))*wt)
wavfile.write('audio/re.wav', Fs, tone*remod)
plt.figure(figsize=(25,5))
plt.plot(tone*remod, label='mod')
plt.plot(ref_mod*1.2, label='mor')
plt.title("reconstructed signal")
plt.grid()
plt.legend()

In [None]:
omod = l_test[:y]
# Normalise the noisy input to a peak of 1 before re-modulating it. The result is
# rebound rather than divided in place because l_test[:y] is a view into l_test.
omod = omod / max(omod)
# Hold each sample for decim audio samples.
# Assumes len(l_test[:y]) == len(p_sig[:y]) so that it matches the length of `tone`.
orig_mod = np.array([[x]*decim for x in omod]).flatten()
wavfile.write('audio/or.wav', Fs, tone*orig_mod)
plt.figure(figsize=(25,5))
plt.plot(tone*orig_mod, label='ori')
plt.plot(ref_mod*1.2, label='mor')
plt.title("original filtered signal")
plt.grid()
plt.legend()

## Make new predictions

In [None]:
#phrase = '6 WREB W7UU QNWXS2 3KRO72Q AN1TI QZIWH G L0U7 Y17X45 OVIC2 C052W00PI60 O5Y 10R2N 4 FHC JXRGS4 DWBOL7ZUXJU EMNC3 WWBNT7 0UP GMKQ YG83H8 IT2Q Y0YBZ SQ80I5 W7SW 0K BMJ8JPM 51CK1 R08T 7SU1LYS7W6T 4JKVQF V3G UU2O1OM4 P4B 4A9DLC VI1H 4 HMP57 Q6G3 4QADIG FRJ 0MVL EPSM CS N9IZEMA GSRWUPBYB FD29 YI3PY N31W X88NS 773EW4Q4 LSW'
phrase = 'VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV D'
Fs = 8000
morse_gen = MorseGen.Morse()
samples_per_dit = morse_gen.nb_samples_per_dit(Fs, 13)
n_prev = int((samples_per_dit/128)*12) + 1 # number of samples to look back is slightly more than a dit-dah and a word space (2+3+7=12)
print(f'Samples per dit at {Fs} Hz is {samples_per_dit}. Decimation is {samples_per_dit/128:.2f}. Look back is {n_prev}.')
label_df = morse_gen.encode_df_decim(phrase, samples_per_dit, 128)
# keep just the envelope
label_df.drop(columns=['dit','dah', 'ele', 'chr', 'wrd'], inplace=True)
print(label_df.shape)
plt.figure(figsize=(50,5))
x = 0
y = 1500
plt.plot(label_df[x:y].env*0.9 + 0.0, label='env')
plt.title("labels")
plt.legend()
plt.grid()

In [None]:
# Target SNR in dB for the new test data
SNR_dB = -20
SNR_linear = 10.0**(SNR_dB/10.0)
SNR_linear *= 256 # Apply original FFT ratio (x256, i.e. about +24.08 dB)
print(f'Resulting SNR for original {SNR_dB} dB is {(10.0 * np.log10(SNR_linear)):.2f} dB')
t = np.linspace(0, len(label_df)-1, len(label_df))  # NOTE(review): not used in this cell
morsecode = label_df.env
# Signal power is taken as the variance of the clean envelope
power = morsecode.var()
noise_power = power/SNR_linear
# Two independent Gaussian noise draws of equal power: noise0 corrupts the envelope,
# noise1 is kept on its own as the pure-noise input channel
noise0 = np.sqrt(noise_power)*np.random.normal(0, 1, len(morsecode))
noise1 = np.sqrt(noise_power)*np.random.normal(0, 1, len(morsecode))
signal = morsecode + noise0
print(len(signal))

plt.figure(figsize=[25,5])
plt.plot(signal[x:y] + 2.0, label='s+n0')  # offset by 2.0 to separate it from the noise traces
plt.plot(noise0[x:y], label='n0')
plt.plot(noise1[x:y], label='n1')
plt.grid()
plt.legend()

## Generate training data (new prediction)
### Raw data

In [None]:
sig = signal.to_numpy()
# Scale the noisy signal so that its maximum is 1.
# NOTE(review): if to_numpy() returns a view, this also rescales `signal` in place — confirm.
sig /= max(sig)
labels = label_df
# Keep at most len(sig) label rows (assumes label_df has a 0-based integer index — TODO confirm)
labels = labels.truncate(after=len(sig)-1, copy=False)
print(type(labels), type(sig), labels.shape, sig.shape, len(labels), len(sig))
plt.figure(figsize=[25,5])
plt.plot(sig[x:y])
plt.title("Signal (X)")
plt.grid()

In [None]:
# Overlay the normalised noisy input and the envelope labels, offset vertically by 1.0
plt.figure(figsize=(50,2))
plt.plot(sig[x:y]*0.9 + 0.0, label="sig_X")
plt.plot(labels[x:y].env*0.9 + 1.0, label="env_y")
plt.title("image line and labels")
plt.grid()
plt.legend()

### Format new data for PyTorch 
New X and y

In [None]:
# Recompute the split index for this dataset: n_trn was derived from the length of
# the previous (training) dataset, so it would not split the new data at test_ratio
# and could even exceed its length.
n_trn = round(len(labels) * (1 - test_ratio))
# Two input channels per row: (noisy signal, pure noise); windows of n_prev rows,
# each targeting the label row right after the window (same layout as for training)
X = np.vstack((sig, noise1)).T
X_train = pytorch_rolling_window(torch.FloatTensor(X[:n_trn]), n_prev, 1).to(device)
y_train = torch.FloatTensor(labels.iloc[n_prev:n_trn+1].values).to(device)
X_test = pytorch_rolling_window(torch.FloatTensor(X[n_trn:-1]), n_prev, 1).to(device)
y_test = torch.FloatTensor(labels.iloc[n_trn+n_prev:].values).to(device)
# make sure it works
y_pred = morse_env_model(X_train[0])
y_pred

In [None]:
# Move data to CPU for visualization
X_train_v = X_train.cpu()
y_train_v = y_train.cpu()
X_test_v = X_test.cpu()
y_test_v = y_test.cpu()

# Input (noisy) data for visualization
# l_test starts at row n_trn+n_prev, the same row as the first y_test target
l_train = sig[:n_trn+n_prev]
l_test = sig[n_trn+n_prev:]

In [None]:
# Take every n_prev-th window (adjacent, non-overlapping) for the first 5 windows of
# test and train input, so the concatenation reproduces 5*n_prev consecutive input rows.
# Each window is (n_prev, 2), so both channels are drawn.
a = []
b = []
for t in range(5):
    a.append(X_test_v[t*n_prev])
    b.append(X_train_v[t*n_prev])
plt.figure(figsize=(25,3))
plt.plot(np.concatenate((tuple(a)))*0.5, label='test')
plt.plot(np.concatenate((tuple(b)))*0.5+0.5, label='train')
plt.title("Train and test")
plt.legend()
plt.grid()

In [None]:
# Compare the first 5*n_prev test input rows (both channels), the noisy line and the targets.
# NOTE(review): the X_test trace starts at input row n_trn while l_test and y_test start at
# row n_trn+n_prev, so the first trace is shifted by n_prev samples relative to the others.
a = []
for i in range(5):
    a.append(X_test_v[i*n_prev])
plt.figure(figsize=(25,3))
plt.plot(np.concatenate(tuple(a)), label='X_test')
plt.plot(l_test[:5*n_prev]+1.0, label='line')
plt.plot(y_test_v[:5*n_prev,0]+2.0, label='y_test')
plt.title("Test")
plt.legend()
plt.grid()

## Predict (new data)

In [None]:
%%time
p_sig_l = []
morse_env_model.eval()

for i in range(len(X_test)):
    with torch.no_grad():
        pred_val = morse_env_model(X_test[i]).cpu()
        p_sig_l.append(pred_val[0].item())
        
p_sig = np.array(p_sig_l)

# trim negative values
p_sig[p_sig < 0] = 0

In [None]:
# Targets (mor) and clipped predictions (sig, offset by 1.0) over the display window
plt.figure(figsize=(30,2))
plt.plot(y_test_v[:y,0]*0.9, label="mor")
plt.plot(p_sig[:y]*0.9 + 1.0, label="sig")
plt.title("Predictions")
plt.legend()
plt.grid()

In [None]:
# Min-max normalise the predictions and the targets over the display window and overlay them.
# NOTE(review): divides by zero if a slice is constant.
sig = p_sig[:y]
sig = (sig - min(sig)) / (max(sig) - min(sig))
mor = y_test_v[:y,0]
mor = (mor - min(mor)) / (max(mor) - min(mor))
plt.figure(figsize=(30,6))
plt.plot(sig, label="sig")
plt.plot(mor*1.2, label="mor")
plt.title("predicted signal modulation")
plt.grid()
plt.legend()

In [None]:
import scipy as sp

# Alternative sigmoid shaping of the noisy input, kept for experimentation:
#omod = np.array([sp.special.expit(12*(x-0.3)) for x in l_test[:y]])
#omod = np.array([sp.special.expit(20*(x-0.18)) for x in l_test[:y]])
omod = l_test[:y]
# Hold each input sample for decim audio samples, then scale to a peak of 1
orig_mod = np.array([[x]*decim for x in omod]).flatten()
orig_mod /= max(orig_mod)
# Phase increment per audio sample of an Fcode Hz sine sampled at Fs
wt = (Fcode / Fs)*2*np.pi
tone = np.sin(np.arange(len(orig_mod))*wt)
wavfile.write('audio/or1.wav', Fs, tone*orig_mod)
# Reference: normalised targets held for decim samples each
ref_mod = np.array([[x]*decim for x in mor]).flatten()
plt.figure(figsize=(50,5))
plt.plot(tone*orig_mod, label='mod')
plt.plot(ref_mod*1.2, label='mor')
plt.title("original filtered signal")
plt.legend()
plt.grid()

In [None]:
import scipy as sp

# def modscale(x):
#     return sp.special.expit(20*(x-0.28))
    
#emod = np.array([sp.special.expit(40*(x-0.38)) for x in sig])
emod = sig
emod /= max(emod)  # in place: also rescales `sig` (same array)
#emod = modn
# Hold each predicted envelope sample, and each target sample, for decim audio samples
remod = np.array([[x]*decim for x in emod]).flatten()
remor = np.array([[x]*decim for x in mor]).flatten()
# Phase increment per audio sample of an Fcode Hz sine sampled at Fs
wt = (Fcode / Fs)*2*np.pi
tone = np.sin(np.arange(len(remod))*wt)
wavfile.write('audio/re1.wav', Fs, tone*remod)
plt.figure(figsize=(50,5))
plt.plot(tone*remod, label='filt')
plt.plot(remor*1.2, label='omod')
plt.title("reconstructed signal")
plt.legend()
plt.grid()

In [None]:
# Preview the steep logistic shaping curve expit(40*(x-0.38)) over inputs in [0, 1].
xs = np.linspace(0, 1, 100)
# expit is a NumPy ufunc, so it applies element-wise to the whole array at once
ys = sp.special.expit(40*(xs-0.38))
plt.figure(figsize=(30,6))
# Plot against xs so the x axis shows the actual input values (0..1), not sample indices
plt.plot(xs, ys, label="sig")
plt.title("modified sigmoid")
plt.grid()
plt.legend()