# Train model with noisy signal - PyTorch version

This notebook is based on Mauri AG1LE's 2015 work, which can be found on GitHub [here](https://github.com/ag1le/RNN-Morse). From an audio signal generated at an 8 kHz sample rate (and thus within a 4 kHz bandwidth), it attempts to recognize the distinctive features of Morse code in the signal, namely:

  - The envelope
  - The "dit"
  - The "dah"
  - The element separator at the end of a "dit" or a "dah"
  - The character separator
  - The word separator
  
It trains an LSTM-based recurrent neural network (RNN) on a slightly noisy signal (a few dB SNR in a 4 kHz bandwidth) in an encoder-decoder fashion. The envelope of the signal is taken as input, as a time series of floating-point values, and the labels are also time series of the 6 signals described above.

It then attempts prediction on a much noisier signal of the same test data to see how it can perform in retrieving the 6 predicted signals and reformat the original envelope.

This is the PyTorch version.

In [None]:
!pip install sounddevice torchinfo

In [None]:
!sudo apt-get install libportaudio2

## Generate annotated raw signal

In [None]:
import MorseGen
import matplotlib.pyplot as plt 

# Alternative test phrases, kept for reference
#phrase = '01234 6789 QUICK BROWN FOX 01234 6789 QUICK BROWN FOX01234 6789 QUICK BROWN FOX01234 6789 QUICK BROWN FOX01234 6789 QUICK BROWN FOX 01234 6789 QUICK BROWN FOX'
#phrase = '7U7K 0DC55B H ZN0J Q9 H2X0 LZ16A ECA2DE 6A2 NUPU 67IL6EIH YVZA 5OTGC3U C3R PGW RS0 84QTV4PB EZ1 JBGJ TT1W4M5PBJ GZVLWXQG 7POU6 FMTXA N3CZ Y1Q9VZ6 9TVL CWP8KSB'
phrase = '6 WREB W7UU QNWXS2 3KRO72Q AN1TI QZIWH G L0U7 Y17X45 OVIC2 C052W00PI60 O5Y 10R2N 4 FHC JXRGS4 DWBOL7ZUXJU EMNC3 WWBNT7 0UP GMKQ YG83H8 IT2Q Y0YBZ SQ80I5 W7SW 0K BMJ8JPM 51CK1 R08T 7SU1LYS7W6T 4JKVQF V3G UU2O1OM4 P4B 4A9DLC VI1H 4 HMP57 Q6G3 4QADIG FRJ 0MVL EPSM CS N9IZEMA GSRWUPBYB FD29 YI3PY N31W X88NS 773EW4Q4 LSW'
# Audio sample rate in Hz
Fs = 8000
morse_gen = MorseGen.Morse()
# NOTE(review): the second argument (13) is presumably the keying speed in WPM - confirm against MorseGen
samples_per_dit = morse_gen.nb_samples_per_dit(Fs, 13)
# NOTE(review): 128 presumably matches the spectrogram decimation (nfft - noverlap) computed further down - confirm
n_prev = int((samples_per_dit/128)*12) + 1 # number of samples to look back is slightly more than a dit-dah and a word space (2+3+7=12)
print(f'Samples per dit at {Fs} Hz is {samples_per_dit}. Decimation is {samples_per_dit/128:.2f}. Look back is {n_prev}.')
label_df = morse_gen.encode_df(phrase, samples_per_dit)
# keep just the envelope
label_df.drop(columns=['dit','dah', 'ele', 'chr', 'wrd'], inplace=True)
print(label_df.shape)
plt.figure(figsize=(50,5))
# Sample range [x:y) of the label frame shown in this plot; x and y are reused by later cells
x = 210000
y = 300000
plt.plot(label_df[x:y].env*0.9 + 0.0, label='env')
plt.title("labels")
plt.legend()
plt.grid()

### Audio

In [None]:
import numpy as np

# Carrier frequency in Hz and target signal-to-noise ratio in dB
Fc = 600
SNR_dB = -3
SNR_linear = 10.0**(SNR_dB/10.0)
# Sample index axis 0 .. len(label_df)-1 (one value per label row)
t = np.linspace(0, len(label_df)-1, len(label_df))
cw = np.sin((Fc/Fs)*2*np.pi*t)
# Key the carrier with the envelope
morsecode = cw * label_df.env
# Noise power is derived from the variance of the keyed signal so that power/noise_power == SNR_linear
power = morsecode.var()
noise_power = power/SNR_linear
# NOTE(review): the random generator is not seeded, so the noise differs on every run
noise = np.sqrt(noise_power)*np.random.normal(0, 1, len(morsecode))
signal = morsecode + noise
print(len(signal))

plt.figure(figsize=[25,5])
# Same sample range as the x/y values set in the label cell, hard-coded here
plt.plot(signal[210000:300000])

### Envelope

In [None]:
# Plot the clean envelope labels over the sample range [x:y)
plt.figure(figsize=(25,2))
plt.plot(label_df[x:y].env)
plt.title("envelope")

### Plot spectrogram

In [None]:
# Magnitude spectrogram of the first 30*Fs samples of the noisy signal (256-point FFT, 128-sample overlap)
plt.figure(figsize=(25,10))
plt.ylabel("Frequency (Hz)")
plt.xlabel("Time (seconds)")
_ = plt.specgram(signal[0:30*Fs], NFFT=256, Fs=Fs, Fc=0, mode='magnitude', noverlap=128)

## Generate final annotated data
### Find peak

In [None]:
import MorseDSP

# NOTE(review): find_peak presumably returns a table of spectral peaks (frequency, amplitude) plus
# the frequency axis and the spectrum - confirm against MorseDSP
maxtab, f, s = MorseDSP.find_peak(Fs, signal)
# First column of the first row is taken as the tone frequency
tone = maxtab[0,0]
plt.title("Morse signal peak found at {} Hz".format(tone))
plt.xlabel("Frequency (Hz)")
plt.ylabel("Amplitude (log)")
plt.yscale('log')
# Only the first half of f / s is plotted
_ = plt.plot(f[0:int(len(f)/2-1)], abs(s[0:int(len(s)/2-1)]),'g-')
_ = plt.scatter(maxtab[:,0], maxtab[:,1], c='r') 
plt.show()

### Generate image

In [None]:
# Number of frequency bins kept on each side of the tone bin, and FFT size
nside_bins = 1
nfft = 256
# NOTE(review): specimg presumably returns frequency axis, time axis, spectral image and the overlap used - confirm against MorseDSP
f, t, img, noverlap = MorseDSP.specimg(Fs, signal, None, None, tone, nfft, nside_bins)
# Number of audio samples between two consecutive spectral columns
decim = nfft - noverlap
print(type(signal), signal.shape)
print(type(f), f.shape)
print(type(t), t.shape, max(t))
print(type(img), img.shape)
print(noverlap, len(signal)//noverlap, decim)
# Show first 25 seconds at most: rmax is the fraction of the time axis to keep,
# so it must be 1 (everything) when the recording lasts 25 seconds or less
rmax = 25 / max(t) if max(t) > 25 else 1
imax = int(rmax*len(t))
t1 = t[:imax]
img1 = img[:,:imax]
plt.figure(figsize=(30,3))
plt.pcolormesh(t1, f, img1, shading='flat', cmap=plt.get_cmap('binary'))
plt.ylabel('Frequency [Hz]')
plt.xlabel('Time [sec]')
plt.show()

### Generate spectral line

In [None]:
# Plot the first 1500 columns of the three image rows around row `nside_bins`
plt.figure(figsize=(30,3))
for offset, tag in ((-1, "-1"), (0, "0"), (1, "+1")):
    plt.plot(img[nside_bins+offset][:1500], label=tag)
plt.legend()
plt.show()

In [None]:
import numpy as np

# Collapse the image rows into a single line by summing over axis 0, then scale so the maximum is 1
img_line = np.sum(img, axis=0)
img_line /= max(img_line)
print(img_line.shape)
plt.figure(figsize=(25,5))
plt.plot(img_line[:1500], label="lin")
plt.legend()
plt.grid()

## Create model

Let's create the model now so we have an idea of its inputs and outputs

In [None]:
import torch
import torch.nn as nn

class MorseEnvLSTM(nn.Module):
    """
    Initial implementation: a single-layer LSTM followed by a linear read-out.

    The LSTM hidden/cell state is stored on the instance and carried over from
    one call of ``forward`` to the next; call ``zero_hidden_cell`` to reset it.

    Args:
        device: torch device on which the hidden/cell state tensors are created.
        input_size: number of features per time step.
        hidden_layer_size: number of LSTM hidden units.
        output_size: number of values returned by ``forward``.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        # Single place where the state tensors are built
        self.zero_hidden_cell()

    def forward(self, input_seq):
        """Run one input window through the network and return the prediction for its last step."""
        # Reshape to (seq_len, batch=1, features) as expected by nn.LSTM
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions[-1]

    def zero_hidden_cell(self):
        """Reset the hidden and cell state to zeros on the model's own device."""
        # Use self.device, not a notebook-level global, so the class works wherever it is instantiated
        self.hidden_cell = (
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device),
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device)
        )

    
class MorseEnvLSTM2(nn.Module):
    """
    LSTM stack: two stacked LSTM layers (with dropout between them) followed by a linear read-out.

    The LSTM hidden/cell state is stored on the instance and carried over from
    one call of ``forward`` to the next; call ``zero_hidden_cell`` to reset it.

    Args:
        device: torch device on which the hidden/cell state tensors are created.
        input_size: number of features per time step.
        hidden_layer_size: number of hidden units per LSTM layer.
        output_size: number of values returned by ``forward``.
        dropout: dropout probability passed to ``nn.LSTM``.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6, dropout=0.2):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size, num_layers=2, dropout=dropout)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        # Single place where the state tensors are built
        self.zero_hidden_cell()

    def forward(self, input_seq):
        """Run one input window through the network and return the prediction for its last step."""
        # Reshape to (seq_len, batch=1, features) as expected by nn.LSTM
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions[-1]

    def zero_hidden_cell(self):
        """Reset the hidden and cell state (one slice per LSTM layer) to zeros on the model's own device."""
        # Use self.device, not a notebook-level global, so the class works wherever it is instantiated
        self.hidden_cell = (
            torch.zeros(2, 1, self.hidden_layer_size).to(self.device),
            torch.zeros(2, 1, self.hidden_layer_size).to(self.device)
        )

        
class MorseEnvNoHLSTM(nn.Module):
    """
    Stateless variant: the LSTM starts from a zero hidden/cell state on every
    call, so nothing is remembered from one input window to the next.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6):
        super().__init__()
        # The device is stored because forward() allocates the initial state on it
        self.device = device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)

    def forward(self, input_seq):
        """Return the linear read-out of the LSTM output at the last step of `input_seq`."""
        n_steps = len(input_seq)
        # Fresh (h0, c0) pair of zeros for every call
        zero_state = tuple(
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device) for _ in range(2)
        )
        # nn.LSTM gets the window as (seq_len, batch=1, features)
        lstm_out, _ = self.lstm(input_seq.view(n_steps, 1, -1), zero_state)
        return self.linear(lstm_out.view(n_steps, -1))[-1]

    
class MorseEnvBiLSTM(nn.Module):
    """
    Bidirectional LSTM variant (originally flagged as "does not work").

    Two defects of the first attempt are addressed here:
      * the initial states were created on a notebook-level global ``device``
        instead of ``self.device``;
      * with ``batch_first=True`` the window was reshaped to (len(x), 1, -1),
        i.e. len(x) independent sequences of a single step, so the LSTM never
        saw the window as a sequence. It is now fed as one sequence of len(x) steps.

    Args:
        device: torch device on which the initial state tensors are created.
        input_size: number of features per time step.
        hidden_size: number of hidden units per direction.
        num_layers: number of stacked LSTM layers.
        num_classes: number of values returned by ``forward``.
    """
    def __init__(self, device, input_size=1, hidden_size=12, num_layers=1, num_classes=6):
        super(MorseEnvBiLSTM, self).__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection

    def forward(self, x):
        """Run one input window through the network and return the prediction for its last step."""
        # batch_first layout: (batch=1, seq_len, features)
        seq = x.view(1, len(x), -1)

        # Set initial states
        h0 = torch.zeros(self.num_layers*2, seq.size(0), self.hidden_size).to(self.device) # 2 for bidirection
        c0 = torch.zeros(self.num_layers*2, seq.size(0), self.hidden_size).to(self.device)

        # Forward propagate LSTM
        out, _ = self.lstm(seq, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out[-1]

Create the model instance and print the details

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Alternative single-layer model, kept for reference
#morse_env_model = MorseEnvLSTM(device, hidden_layer_size=5, output_size=1).to(device) # This is the only way to get things work properly with device
# Only one output is requested because the labels were reduced to the envelope column
morse_env_model = MorseEnvLSTM2(device, hidden_layer_size=5, output_size=1).to(device)
morse_env_loss_function = nn.MSELoss()
morse_env_optimizer = torch.optim.Adam(morse_env_model.parameters(), lr=0.001)

print(morse_env_model)
print(morse_env_model.device)

In [None]:
# Smoke test: parameters, hidden state and input must all live on the same device, otherwise
# PyTorch raises "Input and hidden tensors are not at the same device, found input tensor at cuda:0 and hidden tensor at cpu"
for m in morse_env_model.parameters():
    print(m.shape, m.device)
#X_t = torch.rand((48, 1))
X_t = torch.tensor([-0.9648, -0.9385, -0.8769, -0.8901, -0.9253, -0.8637, -0.8066, -0.8066, -0.8593, -0.9341, -1.0000, -0.9385])
# Follow the selected device instead of forcing CUDA, so the cell also runs on CPU-only machines
X_t = X_t.to(device)
print(X_t)
morse_env_model(X_t)

In [None]:
import torchinfo
# The model's forward() takes ONE window and reshapes it to (len, 1, -1) for an LSTM whose
# input_size is 1, so the summary must be given a single window of H samples with W feature each
# (a leading "channels" dimension would be folded into the feature axis and rejected by the LSTM)
H=n_prev
W=1
torchinfo.summary(morse_env_model, input_size=(H, W))

## Generate training data
### Raw data

In [None]:
# Model input: the normalized spectral line (sig is the same array object as img_line)
sig = img_line
sig /= max(img_line)
# Decimate labels by spectrum decimation and renumber the rows from 0.
# Done by chaining rather than in place on a slice of label_df.
labels = label_df[::decim].reset_index(drop=True)
# Keep at most as many label rows as there are input samples
# (equivalent to truncate(after=len(sig)-1) without the deprecated `copy` keyword)
labels = labels.iloc[:len(sig)]
print(type(labels), type(sig), labels.shape, sig.shape, len(labels), len(sig))

In [None]:
# Plot range is now expressed in decimated samples: from 0 to imax (computed in the image cell)
x = 0
y = imax
print(x, y)
plt.figure(figsize=(30,2))
# The label trace is shifted up by 1.0 so both curves are visible
plt.plot(sig[x:y]*0.9 + 0.0, label="sig_X")
plt.plot(labels[x:y].env*0.9 + 1.0, label="env_y")
plt.title("image line and labels")
plt.grid()
plt.legend()

### Format data for PyTorch 
With training and test data split

In [None]:
# Values for train / test split
# test_ratio is the fraction of label rows kept for testing; n_trn is the number of rows used for training
test_ratio = 0.5
n_trn = round(len(labels) * (1 - test_ratio))

In [None]:
def pytorch_rolling_window(x, window_size, step_size=1):
    """
    Slice `x` into windows of `window_size` elements taken every `step_size`
    elements along dimension 0; the window contents end up in a new last dimension.
    """
    rolling_dim = 0
    windows = x.unfold(rolling_dim, window_size, step_size)
    return windows

# Training inputs: every window of n_prev consecutive samples taken from the first n_trn samples.
# Training labels: rows n_prev .. n_trn, so window i (samples i .. i+n_prev-1) is paired with row i+n_prev.
X_train = pytorch_rolling_window(torch.FloatTensor(sig[:n_trn]), n_prev, 1).to(device)
y_train = torch.FloatTensor(labels.iloc[n_prev:n_trn+1].values).to(device)
print(X_train.shape, y_train.shape)
print(X_train)
print(y_train)
print(X_train[0].shape, y_train[0].shape)
# Test inputs and labels are built the same way from the remaining samples (the last sample is left out of the inputs)
X_test = pytorch_rolling_window(torch.FloatTensor(sig[n_trn:-1]), n_prev, 1).to(device)
y_test = torch.FloatTensor(labels.iloc[n_trn+n_prev:].values).to(device)
print(X_test.shape, y_test.shape)
# make sure it works
# NOTE: this call also updates the hidden state stored on the model
y_pred = morse_env_model(X_train[0])
y_pred

In [None]:
# Move data to CPU for visualization
X_train_v = X_train.cpu()
y_train_v = y_train.cpu()
X_test_v = X_test.cpu()
y_test_v = y_test.cpu()

# Input (noisy) data for visualization
# l_test starts at n_trn+n_prev, the same offset as the first row of y_test
l_train = sig[:n_trn+n_prev]
l_test = sig[n_trn+n_prev:]

In [None]:
# Stitch together 5 non-overlapping windows (every n_prev-th window) of the test and train inputs.
# Comprehensions are used so that no loop variable overwrites the global `t` (spectrogram time axis).
a = [X_test_v[k*n_prev] for k in range(5)]
b = [X_train_v[k*n_prev] for k in range(5)]
plt.figure(figsize=(25,3))
# The train trace is shifted up by 0.5 so both curves are visible
plt.plot(np.concatenate(a)*0.5, label='test')
plt.plot(np.concatenate(b)*0.5+0.5, label='train')
plt.title("Train and test")
plt.legend()
plt.grid()

In [None]:
# Collect 5 non-overlapping test windows (every n_prev-th window)
a = []
for i in range(5):
    a.append(X_test_v[i*n_prev])
plt.figure(figsize=(25,3))
# Stitched windows, the raw line (+1.0) and the first label column (+2.0) over the same 5*n_prev samples
plt.plot(np.concatenate(tuple(a)), label='X_test')
plt.plot(l_test[:5*n_prev]+1.0, label='line')
plt.plot(y_test_v[:5*n_prev,0]+2.0, label='y_test')
plt.title("Test")
plt.legend()
plt.grid()

## Train model

In [None]:
%%time
# Load pre-trained weights; map_location remaps the stored tensors onto the selected device,
# so a checkpoint saved on a GPU can also be loaded on a CPU-only machine
morse_env_model.load_state_dict(torch.load('models/morse_env_model_lstm2_02', map_location=device))

### Predict (test)

In [None]:
%%time
# Inference only: switch to eval mode and disable gradient tracking for the whole loop
morse_env_model.eval()
p_sig_l = []

with torch.no_grad():
    for i in range(len(X_test)):
        # One forward pass per window; the first output value is kept as a Python float
        pred_val = morse_env_model(X_test[i]).cpu()
        p_sig_l.append(pred_val[0].item())

p_sig = np.array(p_sig_l)

# trim negative values
p_sig[p_sig < 0] = 0

In [None]:
# Compare the first label column with the predictions (shifted up by 1.0) over the first y samples
plt.figure(figsize=(30,2))
plt.plot(y_test_v[:y,0]*0.9, label="y0")
plt.plot(p_sig[:y]*0.9 + 1.0, label="sig")
plt.title("Predictions")
plt.legend()
plt.grid()
# NOTE: the img/ directory must already exist
plt.savefig('img/pred.png')

In [None]:
# NOTE: `sig` is re-bound here: from now on it holds the min-max normalized prediction, not the input line
sig = p_sig[:y]
sig = (sig - min(sig)) / (max(sig) - min(sig))
# Min-max normalized reference labels (not plotted in this cell)
mor = y_test_v[:y,0]
mor = (mor - min(mor)) / (max(mor) - min(mor))
plt.figure(figsize=(30,5))
plt.plot(sig, label="sig")
# Noisy input line, shifted up by 1.0
plt.plot(l_test[:y] + 1.0, label="inp")
plt.title("predicted signal modulation")
plt.legend()
plt.grid()

In [None]:
from scipy.io import wavfile

# Re-modulate a tone with the predicted envelope and save it as audio
Fcode = 600
Fs = 8000
emod = sig
emod /= max(emod)
# Hold every envelope value for `decim` audio samples: decim (= nfft - noverlap) is the number of
# audio samples per spectral column, and it is the factor used by the other re-modulation cells
remod = np.repeat(emod, decim)
wt = (Fcode / Fs)*2*np.pi
tone = np.sin(np.arange(len(remod))*wt)
# NOTE: the audio/ directory must already exist
wavfile.write('audio/re.wav', Fs, tone*remod)
plt.figure(figsize=(25,5))
plt.plot(tone*remod)
plt.title("reconstructed signal")
plt.grid()
# .4QTV4PB EZ1 JBGJ TT1W4M...
# 7U7K 0DC55B H ZN0J Q9 H2X0 LZ16A ECA2DE 6A2 NUPU 67IL6EIH YVZA 5OTGC3U C3R PGW RS0 84QTV4PB EZ1 JBGJ TT1W4M5PBJ GZVLWXQG 7POU6 FMTXA N3CZ Y1Q9VZ6 9TVL CWP8KSB'

In [None]:
# Re-modulate a tone with the unprocessed (noisy) spectral line for comparison
omod = l_test[:y]
# Normalize to a peak of 1. The result is assigned to a NEW array (a bare `omod / max(omod)`
# has no effect), and nothing is done in place because l_test is a slice of the spectral line.
omod = omod / max(omod)
# Hold every value for `decim` audio samples
orig_mod = np.repeat(omod, decim)
# Carrier sized for this envelope, so the cell does not depend on the length of a previously built tone
carrier = np.sin(np.arange(len(orig_mod))*(Fcode / Fs)*2*np.pi)
# NOTE: the audio/ directory must already exist
wavfile.write('audio/or.wav', Fs, carrier*orig_mod)
plt.figure(figsize=(25,5))
plt.plot(carrier*orig_mod)
plt.title("original filtered signal")
plt.grid()

## Make new predictions

In [None]:
# Same keyed signal as before, with new noise scaled for a -15 dB signal-to-noise ratio
SNR_dB = -15
SNR_linear = 10.0**(SNR_dB/10.0)
# `power` is the variance of the keyed signal computed in the audio cell
noise_power = power/SNR_linear
# NOTE(review): the random generator is not seeded, so the noise differs on every run
noise = np.sqrt(noise_power)*np.random.normal(0, 1, len(morsecode))
signal1 = morsecode + noise

### Find peak

In [None]:
# Same peak search as for the training signal, applied to the noisier signal1
# NOTE(review): find_peak presumably returns a table of spectral peaks (frequency, amplitude) plus
# the frequency axis and the spectrum - confirm against MorseDSP
maxtab, f, s = MorseDSP.find_peak(Fs, signal1)
# First column of the first row is taken as the tone frequency
tone = maxtab[0,0]
plt.title("Morse signal peak found at {} Hz".format(tone))
plt.xlabel("Frequency (Hz)")
plt.ylabel("Amplitude (log)")
plt.yscale('log')
# Only the first half of f / s is plotted
_ = plt.plot(f[0:int(len(f)/2-1)], abs(s[0:int(len(s)/2-1)]),'g-')
_ = plt.scatter(maxtab[:,0], maxtab[:,1], c='r') 
plt.show()

### Generate image

In [None]:
# Number of frequency bins kept on each side of the tone bin, and FFT size
nside_bins = 1
nfft = 256
# NOTE(review): specimg presumably returns frequency axis, time axis, spectral image and the overlap used - confirm against MorseDSP
f, t, img, noverlap = MorseDSP.specimg(Fs, signal1, None, None, tone, nfft, nside_bins)
# Number of audio samples between two consecutive spectral columns
decim = nfft - noverlap
print(type(signal1), signal1.shape)
print(type(f), f.shape)
print(type(t), t.shape, max(t))
print(type(img), img.shape)
print(noverlap, len(signal1)//noverlap, decim)
# Show first 25 seconds at most: rmax is the fraction of the time axis to keep,
# so it must be 1 (everything) when the recording lasts 25 seconds or less
rmax = 25 / max(t) if max(t) > 25 else 1
imax = int(rmax*len(t))
t1 = t[:imax]
img1 = img[:,:imax]
plt.figure(figsize=(30,3))
plt.pcolormesh(t1, f, img1, shading='flat', cmap=plt.get_cmap('binary'))
plt.ylabel('Frequency [Hz]')
plt.xlabel('Time [sec]')
plt.show()

### Generate spectral line

In [None]:
# Plot the first 1500 columns of the three image rows around row `nside_bins`
plt.figure(figsize=(30,3))
for offset, tag in ((-1, "-1"), (0, "0"), (1, "+1")):
    plt.plot(img[nside_bins+offset][:1500], label=tag)
plt.legend()
plt.show()

In [None]:
# Only the row at index nside_bins is used here, whereas the training section summed all rows
# NOTE(review): if img is a NumPy array, img_line is a view and the in-place division below also rescales that row of img
img_line = img[nside_bins] #np.sum(img, axis=0)
img_line /= max(img_line)
print(img_line.shape)
plt.figure(figsize=(30,3))
plt.plot(img_line[:1500], label="lin")
plt.legend()
plt.show()

## Generate training data (new prediction)
### Raw data

In [None]:
# Model input: the normalized spectral line of the noisier signal (sig is the same array object as img_line)
sig = img_line
sig /= max(img_line)
# Decimate labels by spectrum decimation and renumber the rows from 0.
# Done by chaining rather than in place on a slice of label_df.
labels = label_df[::decim].reset_index(drop=True)
# Keep at most as many label rows as there are input samples
# (equivalent to truncate(after=len(sig)-1) without the deprecated `copy` keyword)
labels = labels.iloc[:len(sig)]
print(type(labels), type(sig), labels.shape, sig.shape, len(labels), len(sig))

In [None]:
# Plot range in decimated samples: from 0 to imax (computed in the image cell)
x = 0
y = imax
print(x, y)
plt.figure(figsize=(30,2))
# The label trace is shifted up by 1.0 so both curves are visible
plt.plot(sig[x:y]*0.9 + 0.0, label="sig_X")
plt.plot(labels[x:y].env*0.9 + 1.0, label="env_y")
plt.title("image line and labels")
plt.grid()
plt.legend()

### Format new data for PyTorch 
Same labels and y but new X

In [None]:
# Reuse optimized formatting for X only (y_train / y_test from the first pass are kept)
X_train = pytorch_rolling_window(torch.FloatTensor(sig[:n_trn]), n_prev, 1).to(device)
X_test = pytorch_rolling_window(torch.FloatTensor(sig[n_trn:-1]), n_prev, 1).to(device)
# make sure it works
# NOTE: this call also updates the hidden state stored on the model
y_pred = morse_env_model(X_train[0])
y_pred

In [None]:
# Move data to CPU for visualization
X_train_v = X_train.cpu()
X_test_v = X_test.cpu()

# Input (noisy) data for visualization
# l_test starts at n_trn+n_prev, the same offset as the first row of y_test
l_train = sig[:n_trn+n_prev]
l_test = sig[n_trn+n_prev:]

In [None]:
# Stitch together 5 non-overlapping windows (every n_prev-th window) of the test and train inputs.
# Comprehensions are used so that no loop variable overwrites the global `t` (spectrogram time axis).
a = [X_test_v[k*n_prev] for k in range(5)]
b = [X_train_v[k*n_prev] for k in range(5)]
plt.figure(figsize=(25,3))
# The train trace is shifted up by 0.5 so both curves are visible
plt.plot(np.concatenate(a)*0.5, label='test')
plt.plot(np.concatenate(b)*0.5+0.5, label='train')
plt.title("Train and test")
plt.legend()
plt.grid()

In [None]:
# Collect 5 non-overlapping test windows (every n_prev-th window)
a = []
for i in range(5):
    a.append(X_test_v[i*n_prev])
plt.figure(figsize=(25,3))
# Stitched windows, the raw line (+1.0) and the first label column (+2.0) over the same 5*n_prev samples
plt.plot(np.concatenate(tuple(a)), label='X_test')
plt.plot(l_test[:5*n_prev]+1.0, label='line')
plt.plot(y_test_v[:5*n_prev,0]+2.0, label='y_test')
plt.title("Test")
plt.legend()
plt.grid()

## Predict (new data)

In [None]:
%%time
# Inference only: switch to eval mode and disable gradient tracking for the whole loop
morse_env_model.eval()
p_sig_l = []

with torch.no_grad():
    for i in range(len(X_test)):
        # One forward pass per window; the first output value is kept as a Python float
        pred_val = morse_env_model(X_test[i]).cpu()
        p_sig_l.append(pred_val[0].item())

p_sig = np.array(p_sig_l)

# trim negative values
p_sig[p_sig < 0] = 0

In [None]:
# Compare the first label column with the predictions (shifted up by 1.0) over the first y samples
plt.figure(figsize=(30,2))
plt.plot(y_test_v[:y,0]*0.9, label="y0")
plt.plot(p_sig[:y]*0.9 + 1.0, label="sig")
plt.title("Predictions")
plt.legend()
plt.grid()
# NOTE: writes to the same file as the first prediction plot; the img/ directory must already exist
plt.savefig('img/pred.png')

In [None]:
# NOTE: `sig` is re-bound here: from now on it holds the min-max normalized prediction, not the input line
sig = p_sig[:y]
sig = (sig - min(sig)) / (max(sig) - min(sig))
# Min-max normalized reference labels
mor = y_test_v[:y,0]
mor = (mor - min(mor)) / (max(mor) - min(mor))
plt.figure(figsize=(30,6))
plt.plot(sig, label="sig")
# Reference scaled by 1.2 so it stands out above the prediction
plt.plot(mor*1.2, label="mor")
plt.title("predicted signal modulation")
plt.grid()
plt.legend()

In [None]:
import scipy as sp

# Re-modulate a tone with the unprocessed (noisy) spectral line of the new signal
#omod = np.array([sp.special.expit(12*(x-0.3)) for x in l_test[:y]])
omod = l_test[:y]
# Hold every value for `decim` audio samples, then normalize to a peak of 1
orig_mod = np.array([[x]*decim for x in omod]).flatten()
orig_mod /= max(orig_mod)
wt = (Fcode / Fs)*2*np.pi
tone = np.sin(np.arange(len(orig_mod))*wt)
# NOTE: the audio/ directory must already exist
wavfile.write('audio/or1.wav', Fs, tone*orig_mod)
# Reference envelope expanded the same way, for the plot only
ref_mod = np.array([[x]*decim for x in mor]).flatten()
plt.figure(figsize=(25,5))
plt.plot(tone*orig_mod, label='mod')
plt.plot(ref_mod*1.2, label='mor')
plt.title("original filtered signal")
plt.legend()
plt.grid()

In [None]:
import scipy as sp

# Re-modulate a tone with the predicted envelope of the new signal and save it as audio.
# The commented-out lines below are earlier attempts at reshaping the envelope.
# def modscale(x):
#     return sp.special.expit(20*(x-0.28))
    
# emod is the same array object as sig, so the in-place division below also rescales sig
emod = sig
# emod[emod < 0.12] = 0 # eliminate bias
# emod = np.array([sp.special.expit(20*(x-0.15)) for x in sig])
emod /= max(emod)
#emod = modn
# Hold every envelope value (prediction and reference) for `decim` audio samples
remod = np.array([[x]*decim for x in emod]).flatten()
remor = np.array([[x]*decim for x in mor]).flatten()
wt = (Fcode / Fs)*2*np.pi
tone = np.sin(np.arange(len(remod))*wt)
# NOTE: the audio/ directory must already exist
wavfile.write('audio/re1.wav', Fs, tone*remod)
plt.figure(figsize=(50,10))
plt.plot(tone*remod, label='filt')
plt.plot(remor*1.2, label='omod')
plt.title("reconstructed signal")
plt.legend()
plt.grid()

In [None]:
import scipy as sp
import scipy.special  # load the submodule explicitly: `import scipy` alone does not guarantee that sp.special exists

# Shape of the logistic curve considered for reshaping the predicted envelope
sx = np.linspace(0, 1, 121)
sy = sp.special.expit(10*(sx-0.15))
plt.plot(sx, sy)
plt.grid()
plt.xlabel('x')
# The title states the function that is actually plotted
plt.title('expit(10*(x-0.15))')
plt.show()