# Model with character recognition - hybrid solution

Builds on `RNN-Morse-chars-key`. It appears that only the keying recognition model works more or less. The post-processing done in `RNN-Morse-chars-key` essentially consists of windowing by the element separator sense signal and deriving the dahs sense signal by subtracting the original dits sense signal (with some scaling of the dits sense signal as well). With it we obtain fairly good stimuli that can feed a classical logic based on positive edge triggers of the dits, dahs and character separator signals.  

Here we re-introduce the alphanumeric alphabet (36 characters) as we are not bound by the character model.

It sort of works (the only option giving some results so far) but is impaired by hard decisions. Still missing a second good model for character recognition with DNN.

## Create string

Each character in the alphabet should happen a large enough number of times. As a rule of thumb we will take some multiple of the number of characters in the alphabet. If the multiplier is large enough the probability of each character appearance will be even over the alphabet. 

In [None]:
# Create the Morse generator and take the alphabet exposed as `alphabet36`.
import MorseGen

morse_gen = MorseGen.Morse()
alphabet = morse_gen.alphabet36
# Ratio between 132 and the alphabet size (the "multiple of the alphabet size" rule of thumb).
print(132/len(alphabet))

# Training text: nchars=264 and nwords=54 are requested, drawn from `alphabet`.
morsestr = MorseGen.get_morse_str(nchars=132*2, nwords=27*2, chars=alphabet)
print(alphabet)
print(len(morsestr), morsestr)

## Generate dataframe and extract envelope

In [None]:
Fs = 8000  # sampling frequency in Hz (see the message printed below)
# NOTE(review): the second argument (13) is presumably the keying speed in WPM — confirm against MorseGen.
samples_per_dit = morse_gen.nb_samples_per_dit(Fs, 13)
# Look-back length: 19 times (samples_per_dit / 128), truncated, plus one.
n_prev = int((samples_per_dit/128)*19) + 1
print(f'Samples per dit at {Fs} Hz is {samples_per_dit}. Decimation is {samples_per_dit/128:.2f}. Look back is {n_prev}.')
# Encode the text into a dataframe; the 'env' column is extracted as the envelope.
label_df = morse_gen.encode_df_decim_str(morsestr, samples_per_dit, 128, alphabet)
env = label_df['env'].to_numpy()
print(type(env), len(env))

In [None]:
import numpy as np

def get_new_data(morse_gen, SNR_dB=-23, nchars=132, nwords=27, phrase=None, alphabet="ABC"):
    """
    Build a noisy Morse signal together with its envelope and label dataframe.

    Parameters
    ----------
    morse_gen : object providing `nb_samples_per_dit` and `encode_df_decim_str`
    SNR_dB : signal to noise ratio in dB, before the x256 correction applied below
    nchars, nwords : forwarded to `MorseGen.get_morse_str`, used only when `phrase` is empty
    phrase : text to encode; a random one is generated when it is None or empty
    alphabet : characters used for the random text and passed to the encoder

    Returns
    -------
    (envelope, signal, label_df, n_prev) where `envelope` is the 'env' column as a
    NumPy array, `signal` is the squared noisy envelope clipped at 1.0, `label_df`
    is the encoder dataframe without its 'env' column and `n_prev` is the look-back
    length in samples.
    """
    if not phrase:
        phrase = MorseGen.get_morse_str(nchars=nchars, nwords=nwords, chars=alphabet)
    print(len(phrase), phrase)
    Fs = 8000
    samples_per_dit = morse_gen.nb_samples_per_dit(Fs, 13)
    n_prev = int((samples_per_dit/128)*19) + 1 # number of samples to look back is slightly more than a dit-dah and a word space (2+3+7=12)
    print(f'Samples per dit at {Fs} Hz is {samples_per_dit}. Decimation is {samples_per_dit/128:.2f}. Look back is {n_prev}.')
    label_df = morse_gen.encode_df_decim_str(phrase, samples_per_dit, 128, alphabet)
    # extract the envelope
    envelope = label_df['env'].to_numpy()
    # remove the envelope so that only label columns remain
    label_df.drop(columns=['env'], inplace=True)
    SNR_linear = 10.0**(SNR_dB/10.0)
    SNR_linear *= 256 # Apply original FFT
    print(f'Resulting SNR for original {SNR_dB} dB is {(10.0 * np.log10(SNR_linear)):.2f} dB')
    # Gaussian noise scaled from the mean envelope power and the target SNR
    power = np.sum(envelope**2)/len(envelope)
    noise_power = power/SNR_linear
    noise = np.sqrt(noise_power)*np.random.normal(0, 1, len(envelope))
    # NOTE: in the original setup from audio the noise is also low-pass filtered; this is not simulated here
    signal = (envelope + noise)**2
    signal[signal > 1.0] = 1.0 # a bit crap ...
    return envelope, signal, label_df, n_prev

Try it...

In [None]:
import matplotlib.pyplot as plt 

# Generate one noisy realisation of the training text at SNR_dB=-17 and inspect it.
envelope, signal, label_df, n_prev = get_new_data(morse_gen, SNR_dB=-17, phrase=morsestr, alphabet=alphabet)

# Show
print(n_prev)
print(type(signal), signal.shape)
print(type(label_df), label_df.shape)
print(max(signal))

# Sample range displayed in the plots (reused by later cells)
x0 = 0
x1 = 1500

# Every trace is scaled by 0.9; dit/dah are offset by 1.0 and ele/chr/wrd by 2.0 to separate the groups.
plt.figure(figsize=(50,6))
plt.plot(signal[x0:x1]*0.9, label="sig")
plt.plot(envelope[x0:x1]*0.9, label='env')
plt.plot(label_df[x0:x1].dit*0.9 + 1.0, label='dit')
plt.plot(label_df[x0:x1].dah*0.9 + 1.0, label='dah')
plt.plot(label_df[x0:x1].ele*0.9 + 2.0, label='ele')
plt.plot(label_df[x0:x1].chr*0.9 + 2.0, label='chr')
plt.plot(label_df[x0:x1].wrd*0.9 + 2.0, label='wrd')
plt.title("signal and keying labels")
plt.legend()
plt.grid()

In [None]:
# One trace per character of `alphabet` is stacked with a unit offset, so the
# figure height is derived from that same alphabet.
plt.figure(figsize=(50,1.0+0.5*len(alphabet)))
plt.plot(signal[x0:x1]*0.9, label="sig")
plt.plot(envelope[x0:x1]*0.9, label='env')
for i, a in enumerate(alphabet):
    plt.plot(label_df[x0:x1][a]*0.9 + 1.0 + i, label=a)
plt.title("alpha labels")
plt.legend()
plt.grid()

## Create data loader for keying model
### Define keying dataset

In [None]:
import torch

class MorsekeyingDataset(torch.utils.data.Dataset):
    """
    Sliding window dataset over a noisy Morse signal for the keying model.

    Item `index` is the pair `(X[index:index+seq_len], y[index+seq_len])`: a window
    of `seq_len` signal samples and the keying labels ('dit', 'dah', 'ele', 'chr',
    'wrd') of the sample that follows the window.
    """
    def __init__(self, morse_gen, device, SNR_dB=-23, nchars=132, nwords=27, phrase=None, alphabet="ABC"):
        # nchars and nwords are forwarded so that they are honoured when no phrase is given
        self.envelope, self.signal, self.label_df0, self.seq_len = get_new_data(
            morse_gen, SNR_dB=SNR_dB, nchars=nchars, nwords=nwords, phrase=phrase, alphabet=alphabet
        )
        # label_df0 keeps every label column, label_df only the keying ones
        self.label_df = self.label_df0[['dit','dah','ele','chr','wrd']]
        self.X = torch.FloatTensor(self.signal).to(device)
        self.y = torch.FloatTensor(self.label_df.values).to(device)

    def __len__(self):
        # the last seq_len samples cannot start a window that still has a label after it
        return len(self.X) - self.seq_len

    def __getitem__(self, index):
        return (self.X[index:index+self.seq_len], self.y[index+self.seq_len])

    def get_envelope(self):
        """Return the noise free envelope as returned by get_new_data."""
        return self.envelope

    def get_signal(self):
        """Return the noisy signal as returned by get_new_data."""
        return self.signal

    def get_labels(self):
        """Return the dataframe restricted to the keying label columns."""
        return self.label_df

    def get_labels0(self):
        """Return the full label dataframe as returned by get_new_data."""
        return self.label_df0

    def get_seq_len(self):
        """Return the window length (an int, not a callable)."""
        return self.seq_len

### Define keying data loader

In [None]:
# Use CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Positional arguments after `device` are SNR_dB, nchars, nwords, phrase and alphabet.
train_keying_dataset = MorsekeyingDataset(morse_gen, device, -20, 132*5, 27*5, morsestr, alphabet)
train_keying_loader = torch.utils.data.DataLoader(train_keying_dataset, batch_size=1, shuffle=False) # Batch size must be 1

In [None]:
# Take the signal and labels held by the training dataset (its own noise realisation).
signal = train_keying_dataset.get_signal()
signal = (signal - min(signal)) # shift so that the smallest sample is 0
label_df = train_keying_dataset.get_labels()
label_df0 = train_keying_dataset.get_labels0()

print(type(signal), signal.shape)
print(type(label_df), label_df.shape)

x0 = 0
x1 = 1500

# NOTE(review): `envelope` is still the array from the earlier get_new_data call, not
# train_keying_dataset.get_envelope() — presumably identical since the same text is encoded; confirm.
plt.figure(figsize=(50,6))
plt.plot(signal[x0:x1]*0.9, label="sig")
plt.plot(envelope[x0:x1]*0.9, label='env')
plt.plot(label_df[x0:x1].dit*0.9 + 1.0, label='dit')
plt.plot(label_df[x0:x1].dah*0.9 + 1.0, label='dah')
plt.plot(label_df[x0:x1].ele*0.9 + 2.0, label='ele')
plt.plot(label_df[x0:x1].chr*0.9 + 2.0, label='chr')
plt.plot(label_df[x0:x1].wrd*0.9 + 2.0, label='wrd')
plt.title("keying - signal and labels")
plt.legend()
plt.grid()

## Create model classes

The model classes are the same; they will be instantiated differently for the keying and character models.

### Create model for keying recognition

In [None]:
import torch
import torch.nn as nn

class MorseLSTM(nn.Module):
    """
    Initial implementation: single layer LSTM followed by a linear read-out.

    The hidden/cell state is stored on the instance and carried from one call of
    `forward` to the next; call `zero_hidden_cell()` to reset it.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        self.zero_hidden_cell()

    def forward(self, input_seq):
        """Run the sequence through the LSTM and return the read-out of the last step."""
        # the input is viewed as (len(input_seq), 1, features)
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions[-1]

    def zero_hidden_cell(self):
        """Reset the hidden and cell states to zeros on the model's own device."""
        # self.device (not a module level `device`) keeps the state next to the weights
        self.hidden_cell = (
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device),
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device)
        )
    
class MorseBatchedLSTM(nn.Module):
    """
    Single layer LSTM, linear read-out and softmax over the outputs.

    The input is viewed with an explicit `input_size` as last dimension. The
    hidden/cell state is stored on the instance and carried from one call of
    `forward` to the next; call `zero_hidden_cell()` to reset it.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.input_size = input_size
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        self.zero_hidden_cell()
        self.m = nn.Softmax(dim=-1)

    def forward(self, input_seq):
        """Return the softmax of the read-out of the last step."""
        #print(len(input_seq), input_seq.shape, input_seq.view(-1, 1, 1).shape)
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(-1, 1, self.input_size), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return self.m(predictions[-1])

    def zero_hidden_cell(self):
        """Reset the hidden and cell states to zeros on the model's own device."""
        # self.device (not a module level `device`) keeps the state next to the weights
        self.hidden_cell = (
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device),
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device)
        )
    
class MorseLSTM2(nn.Module):
    """
    LSTM stack: two LSTM layers with dropout followed by a linear read-out.

    The hidden/cell state (one slice per layer) is stored on the instance and
    carried from one call of `forward` to the next; call `zero_hidden_cell()`
    to reset it.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6, dropout=0.2):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size, num_layers=2, dropout=dropout)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        self.zero_hidden_cell()

    def forward(self, input_seq):
        """Run the sequence through the stack and return the read-out of the last step."""
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq), 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions[-1]

    def zero_hidden_cell(self):
        """Reset the hidden and cell states of both layers to zeros on the model's own device."""
        # self.device (not a module level `device`) keeps the state next to the weights
        self.hidden_cell = (
            torch.zeros(2, 1, self.hidden_layer_size).to(self.device),
            torch.zeros(2, 1, self.hidden_layer_size).to(self.device)
        )
        
class MorseNoHLSTM(nn.Module):
    """
    Stateless variant: the hidden cell is not kept, every forward pass starts
    from zero hidden and cell states.
    """
    def __init__(self, device, input_size=1, hidden_layer_size=8, output_size=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)

    def forward(self, input_seq):
        """Return the linear read-out of the last step of the sequence."""
        steps = len(input_seq)
        # fresh (h0, c0) pair of zeros for this call only
        zero_state = tuple(
            torch.zeros(1, 1, self.hidden_layer_size).to(self.device) for _ in range(2)
        )
        sequence_out, _ = self.lstm(input_seq.view(steps, 1, -1), zero_state)
        return self.linear(sequence_out.view(steps, -1))[-1]
    
class MorseBiLSTM(nn.Module):
    """
    Attempt Bidirectional LSTM: does not work

    Bidirectional LSTM (batch_first) followed by a linear layer fed with the
    concatenated outputs of both directions.
    """
    def __init__(self, device, input_size=1, hidden_size=12, num_layers=1, num_classes=6):
        super().__init__()
        self.device = device # This is the only way to get things work properly with device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection

    def forward(self, x):
        """Return the read-out of the last row of the input."""
        # Set initial states on the model's own device
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(self.device) # 2 for bidirection
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(self.device)

        # Forward propagate LSTM
        # NOTE(review): with batch_first=True this view makes len(x) the batch dimension and 1 the
        # sequence length, so each sample is processed as a one step sequence — confirm this is intended.
        out, _ = self.lstm(x.view(len(x), 1, -1), (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out[-1]

Create the keying model instance and print the details

In [None]:
# Keying model with 5 outputs, one per keying label column (dit, dah, ele, chr, wrd).
morse_key_model = MorseBatchedLSTM(device, hidden_layer_size=12, output_size=5).to(device) # This is the only way to get things work properly with device
# Mean squared error between the predicted and labelled keying vectors, optimised with Adam.
morse_key_loss_function = nn.MSELoss()
morse_key_optimizer = torch.optim.Adam(morse_key_model.parameters(), lr=0.001)

print(morse_key_model)
print(morse_key_model.device)

In [None]:
# Smoke test of the forward pass with a random window of n_prev samples.
# Error this cell helps to track down:
# Input and hidden tensors are not at the same device, found input tensor at cuda:0 and hidden tensor at cpu
for m in morse_key_model.parameters():
    print(m.shape, m.device)
X_t = torch.rand(n_prev)
#X_t = torch.tensor([-0.9648, -0.9385, -0.8769, -0.8901, -0.9253, -0.8637, -0.8066, -0.8066, -0.8593, -0.9341, -1.0000, -0.9385])
# Follow the selected device instead of forcing CUDA so that the cell also runs on CPU only hosts
X_t = X_t.to(device)
print("Input shape", X_t.shape, X_t.view(-1, 1, 1).shape)
print(X_t)
morse_key_model(X_t)

In [None]:
# Print a layer summary of the keying model for an input of size (channels, H, W) = (10, n_prev, 1).
import torchinfo
channels=10
H=n_prev
W=1
torchinfo.summary(morse_key_model, input_size=(channels, H, W))

## Train keying model

In [None]:
# Peek at the first two items of the loader; with batch_size=1 index 0 removes the batch dimension.
it = iter(train_keying_loader)
X, y = next(it)
print(X.reshape(n_prev,1).shape, X[0].shape, y[0].shape)
print(X[0], y[0])
X, y = next(it)
print(X[0], y[0])

In [None]:
%%time
# Train the keying model one window at a time (the loader batch size is 1).
from tqdm.notebook import tqdm

epochs = 20
morse_key_model.train()

for i in range(epochs):
    train_losses = []  # losses of the current epoch only
    loop = tqdm(enumerate(train_keying_loader), total=len(train_keying_loader), leave=True)
    for j, train in loop:
        # index 0 removes the batch dimension added by the loader
        X_train = train[0][0]
        y_train = train[1][0]
        morse_key_optimizer.zero_grad()
        # only the classes listed here are asked to reset their stored hidden cell before each window
        if morse_key_model.__class__.__name__ in ["MorseLSTM", "MorseLSTM2", "MorseBatchedLSTM", "MorseBatchedLSTM2"]:
            morse_key_model.zero_hidden_cell() # this model needs to reset the hidden cell
        y_pred = morse_key_model(X_train)
        single_loss = morse_key_loss_function(y_pred, y_train)
        single_loss.backward()
        morse_key_optimizer.step()
        train_losses.append(single_loss.item())
        # update progress bar
        if j % 1000 == 0:
            loop.set_description(f"Epoch [{i+1}/{epochs}]")
            loop.set_postfix(loss=np.mean(train_losses))

# mean loss of the last epoch
print(f'final: {i+1:3} epochs loss: {np.mean(train_losses):6.4f}')

In [None]:
# Save the trained weights when save_model is True, otherwise load previously saved weights instead.
# NOTE(review): assumes the 'models' directory already exists — confirm.
save_model = True
if save_model: 
    torch.save(morse_key_model.state_dict(), 'models/morse_key_model')
else:
    morse_key_model.load_state_dict(torch.load('models/morse_key_model', map_location=device))

### Extract results for next step

In [None]:
# Collect the keying predictions for every window of the training set.
# The first row is an uninitialised placeholder: it is kept so that the existing
# "drop first garbage sample" step still removes exactly that row.
pred_rows = [torch.empty(1,5).to(device)]
morse_key_model.eval()

# NOTE(review): unlike in the training loop the hidden cell is not reset between
# windows here — confirm this is intended.
loop = tqdm(enumerate(train_keying_loader), total=len(train_keying_loader))
with torch.no_grad():
    for j, train in loop:
        X_train = train[0]
        pred_val = morse_key_model(X_train[0])
        pred_rows.append(pred_val.reshape(1,5))
# concatenate once instead of growing the tensor at every iteration
p_key_train = torch.cat(pred_rows)

In [None]:
# drop first garbage sample (the uninitialised row that p_key_train was created with)
p_key_train = p_key_train[1:]
print(p_key_train.shape)

In [None]:
print(p_key_train[0:2])
# Split the prediction columns, in the label order dit, dah, ele, chr, wrd, into NumPy arrays.
p_dits = p_key_train[:,0].to('cpu').numpy()
p_dahs = p_key_train[:,1].to('cpu').numpy()
p_eles = p_key_train[:,2].to('cpu').numpy()
p_chrs = p_key_train[:,3].to('cpu').numpy()
p_wrds = p_key_train[:,4].to('cpu').numpy()

# Signal and envelope are shifted by n_prev because prediction k belongs to sample k + n_prev.
plt.figure(figsize=(50,6))
plt.plot(signal[x0+n_prev:x1+n_prev]*0.9, label="sig")
plt.plot(envelope[x0+n_prev:x1+n_prev]*0.9, label='env')
plt.plot(p_dits[x0:x1]*0.9 + 1.0, label='dit')
plt.plot(p_dahs[x0:x1]*0.9 + 1.0, label='dah')
plt.plot(p_eles[x0:x1]*0.9 + 2.0, label='ele')
plt.plot(p_chrs[x0:x1]*0.9 + 2.0, label='chr')
plt.plot(p_wrds[x0:x1]*0.9 + 2.0, label='wrd')
plt.title("keying - predictions")
plt.legend()
plt.grid()

## Post processing

In [None]:
# Shifts in samples, derived from samples_per_dit divided by 128, 64 and 32.
dit_shift = round(samples_per_dit / 128)
dit2_shift = round(samples_per_dit / 64)
dit3_shift = round(samples_per_dit / 32)
print(dit_shift, dit2_shift, dit3_shift)
# Differences between consecutive shifts, used below to build the delayed copies of the dahs signal.
dah2_shift = dit2_shift - dit_shift
dah3_shift = dit3_shift - dit2_shift
print(dah2_shift, dah3_shift)

# Window dits and dahs by the element separator prediction taken dit_shift samples later.
elem_window = p_eles[dit_shift:]
w_dits = p_dits[:-dit_shift] * elem_window
w_dahs = p_dahs[:-dit_shift] * elem_window
# Remove the dits contribution from the windowed dahs.
w_dahs -= w_dits
# Copies of the dahs signal with the first dah2_shift / dah3_shift samples removed.
w_dahs2 = w_dahs[dah2_shift:]
w_dahs3 = w_dahs[dah3_shift:]
# Character and word separators are shifted by dit_shift like the element window.
w_chrs = p_chrs[dit_shift:]
w_wrds = p_wrds[dit_shift:]

# Scale the windowed dits by 2 and clip at 1.0.
w_dits *= 2.0
w_dits[w_dits > 1.0] = 1.0

# Character label columns only, with the first n_prev rows removed.
label_char_df = train_keying_dataset.get_labels0().drop(columns=['dit','dah','ele','chr','wrd'])
label_char_df = label_char_df[n_prev:].reset_index(drop=True)

plt.figure(figsize=(50,6))
plt.plot(signal[x0+n_prev:x1+n_prev]*0.9, label="sig")
plt.plot(envelope[x0+n_prev:x1+n_prev]*0.9, label='env')
plt.plot(w_dits[x0:x1]*0.9 + 1.0, label='dit')
plt.plot(w_dahs[x0:x1]*0.9 + 1.0, label='dah')
plt.plot(w_dahs2[x0:x1]*0.9 + 1.0, label='da2', alpha=0.5)
plt.plot(w_dahs3[x0:x1]*0.9 + 1.0, label='da3', alpha=0.5)
#plt.plot(p_eles[x0+dit_shift:x1+dit_shift]*0.9 + 2.0, label='ele')
plt.plot(w_chrs[x0:x1]*0.9 + 2.0, label='chr')
plt.plot(w_wrds[x0:x1]*0.9 + 2.0, label='wrd')
plt.title("keying - predictions")
plt.legend()
plt.grid()

### Bi frequency reconstruction

The idea is to use the resulting dits and dahs sense signals in their original shape. The lengths of dits and dahs are therefore similar. To help distinguish between them by ear one would assign a higher pitch to the dits (thus mimicking the "i" of the dit) and a lower pitch to the dahs (thus mimicking the "ah" of the dah). Moreover to reconstruct the rhythm the dahs sense is delayed by two dits.

In [None]:
import scipy as sp
import scipy.special
from scipy.io import wavfile

# Tone frequencies in Hz for dahs and dits, and the audio sampling frequency.
Fdah = 440 # A4
Fdit = 523 # C5
Fs = 8000
noverlap = 128  # number of audio samples written per modulation sample (see the repetition below)
decim = 128

# Trim the dits so that both modulation signals have the same length.
dit_mod = w_dits[:-dah3_shift]
# Hold each modulation sample for noverlap audio samples.
dit_wav = np.array([[x]*noverlap for x in dit_mod]).flatten()
dah_mod = w_dahs3
dah_wav = np.array([[x]*noverlap for x in dah_mod]).flatten()

# Phase increment per audio sample for each tone.
dit_wt = (Fdit / Fs)*2*np.pi
dah_wt = (Fdah / Fs)*2*np.pi
dit_tone = np.sin(np.arange(len(dit_wav))*dit_wt)
dah_tone = np.sin(np.arange(len(dah_wav))*dah_wt)

# Sum of both amplitude modulated tones.
# NOTE(review): assumes the 'audio' directory already exists — confirm.
wavfile.write('audio/bif.wav', Fs, dit_tone*dit_wav + dah_tone*dah_wav)

### Mono frequency reconstruction

Here we stick to more convenient Morse code. In order to do so, the dahs sense signal is summed with two delayed copies of itself to reconstruct the length of an original dah.

In [None]:
Fcode = 523  # tone frequency in Hz

# Common length of the dits signal and of the three dahs signals.
mod_len = min(len(w_dits), len(w_dahs), len(w_dahs2), len(w_dahs3))
dit_mod = w_dits[:mod_len]
# Sum the dahs signal and its two shifted copies to rebuild the length of a dah.
dah_mod = w_dahs[:mod_len] + w_dahs2[:mod_len] + w_dahs3[:mod_len]
all_mod = dit_mod + dah_mod
# Hold each modulation sample for noverlap audio samples.
mod_wav = np.array([[x]*noverlap for x in all_mod]).flatten()

wt = (Fcode / Fs)*2*np.pi
# The carrier length is taken from the signal it multiplies so that both always match.
tone = np.sin(np.arange(len(mod_wav))*wt)

wavfile.write('audio/mof.wav', Fs, tone*mod_wav)

plt.figure(figsize=(50,6))
plt.plot(signal[x0+n_prev:x1+n_prev]*0.9, label="sig")
plt.plot(envelope[x0+n_prev:x1+n_prev]*0.9, label='env')
plt.plot(all_mod[x0:x1]*0.9 + 1.0, label='mod')
plt.title("envelope reconstruction")
plt.legend()
plt.grid()

### Original signal

In [None]:
# Write the noisy input signal (from sample n_prev on) as audio for comparison.
# Each sample is held for noverlap audio samples and modulates a tone at the same `wt` as the cell above.
sig_wav = np.array([[x]*noverlap for x in signal[n_prev:]]).flatten()
sig_tone = np.sin(np.arange(len(sig_wav))*wt)
wavfile.write('audio/ori.wav', Fs, sig_tone*sig_wav)

## Decoder 

The decoder is built around a state machine that navigates in the "Morse tree":

<img src="files/Morse1Min.gif">


  - The state is reset to `start` when a character trigger is received and the current character is appended to the result string
  - A `dit` trigger moves the state to the right node in the tree and sets the current character
  - A `dah` trigger moves the state to the left node in the tree and sets the current character
  - If a `dit` or `dah` is received while on a leaf node the state is not moved
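  - The character corresponding to the `start` state is the unknown character `?`
  - Some nodes do not correspond to any character. For these nodes the character attached is the unknown character `?`. In the implementation below a character trigger received in such a node (or in `start`) does not append anything to the result string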
  - A `wrd` trigger resets the state to `start` and appends the space character to the result string
  
Triggers are defined on `dit`, `dah`, `chr` and `wrd` signals as positive edge triggers that default to `0.5` level. 

Some tricks are implemented:

  - de-bouncing of signals by their expected lengths
  - the original dah sense signal is expected to be delayed by 2 dit lengths so that dahs recover their original starting moment. Under this assumption, treating dah with priority over dit helps to mask a small dit re-bounce at the start of the original dah period.
  

In [None]:
class MorseDecoder:
    """
    State machine decoding dit / dah / character / word sense signals into text.

    Samples are fed one at a time through `new_samples` as a sequence
    `[dit, dah, chr, wrd]`. A trigger fires on a positive edge through the
    corresponding level, provided the de-bouncing counter of that signal has
    run out. The decoded text accumulates in `result`.

    Parameters
    ----------
    dit_length : length of a dit in samples, used as the de-bouncing unit
    dit_lvl, dah_lvl, chr_lvl, wrd_lvl : trigger levels of the four signals
    """
    def __init__(self, dit_length, dit_lvl=0.5, dah_lvl=0.5, chr_lvl=0.5, wrd_lvl=0.5):
        self.dit_length = dit_length
        self.dit_lvl = dit_lvl
        self.dah_lvl = dah_lvl
        self.chr_lvl = chr_lvl
        self.wrd_lvl = wrd_lvl
        # state -> (next state on dah, next state on dit); None means "stay"
        self.morse_tree = {
            "start": ("T", "E"),
            "T": ("M", "N"),
            "E": ("A", "I"),
            "M": ("O", "G"),
            "N": ("K", "D"),
            "A": ("W", "R"),
            "I": ("U", "S"),
            "O": ("Odash", "Odit"),
            "G": ("Q", "Z"),
            "K": ("Y", "C"),
            "D": ("X", "B"),
            "W": ("J", "P"),
            "R": (None, "L"),
            "U": ("Udash", "F"),
            "S": ("V", "H"),
            "Odash": ("0", "9"),
            "Odit": (None, "8"),
            "Q": (None, None),
            "Z": (None, "7"),
            "Y": (None, None),
            "C": (None, None),
            "X": (None, None),
            "B": (None, "6"),
            "J": ("1", None),
            "P": (None, None),
            "L": (None, None),
            "Udash": ("2", None),
            "F": (None, None),
            "V": ("3", None),
            "H": ("4", "5"),
            "0": (None, None),
            "1": (None, None),
            "2": (None, None),
            "3": (None, None),
            "4": (None, None),
            "5": (None, None),
            "6": (None, None),
            "7": (None, None),
            "8": (None, None),
            "9": (None, None),
        }
        self.reset()

    def _dit_trig(self):
        """Move to the right (dit) child; stay in place when there is none."""
        next_state = self.morse_tree[self.state][1] # right
        if next_state is not None:
            self.state = next_state

    def _dah_trig(self):
        """Move to the left (dah) child; stay in place when there is none."""
        next_state = self.morse_tree[self.state][0] # left
        if next_state is not None:
            self.state = next_state

    def _chr_trig(self):
        """Append the current character and go back to `start`."""
        # states with a multi-character name ("start", "Odash", "Odit", "Udash") carry
        # no character: nothing is appended for them
        if len(self.state) == 1:
            self.result += self.state
        self.state = "start"

    def _wrd_trig(self):
        """Append a space and go back to `start`."""
        self.result += " "
        self.state = "start"

    def reset(self):
        """Clear the result, the tree position, the previous samples and the counters."""
        self.prev_samples = None
        self.result = ""
        self.state = "start"
        self.dit_count = 0
        self.dah_count = 0
        self.chr_count = 0
        self.wrd_count = 0

    def new_samples(self, samples):
        """
        Process one `[dit, dah, chr, wrd]` sample set.

        The first set received after construction or `reset()` is only stored as
        the reference for edge detection.
        """
        # explicit None test: a NumPy array has no unambiguous truth value
        if self.prev_samples is None:
            self.prev_samples = samples
            return
        # de-bouncing with 3 dits length (dah length) - priority to dah (mask false dit)
        if self.dah_count > 3*self.dit_length and self.prev_samples[1] < self.dah_lvl and samples[1] >= self.dah_lvl:
            self.dah_count = 0
            self._dah_trig()
        # de-bouncing with 1 dit length
        elif self.dit_count > self.dit_length and self.prev_samples[0] < self.dit_lvl and samples[0] >= self.dit_lvl:
            self.dit_count = 0
            self._dit_trig()
        # de-bouncing with 2 dits length (chr length = char separator - element separator)
        if self.chr_count > 2*self.dit_length and self.prev_samples[2] < self.chr_lvl and samples[2] >= self.chr_lvl:
            self.chr_count = 0
            self._chr_trig()
        # de-bouncing with 4 dits length (wrd length = word separator - chr separator - element separator)
        if self.wrd_count > 4*self.dit_length and self.prev_samples[3] < self.wrd_lvl and samples[3] >= self.wrd_lvl:
            self.wrd_count = 0
            self._wrd_trig()
        # counters saturate one step above their threshold
        if self.dit_count <= self.dit_length:
            self.dit_count += 1
        if self.dah_count <= 3*self.dit_length:
            self.dah_count += 1
        if self.chr_count <= 2*self.dit_length:
            self.chr_count += 1
        if self.wrd_count <= 4*self.dit_length:
            self.wrd_count += 1
        self.prev_samples = samples

In [None]:
# Check the types and lengths of the four signals fed to the decoder.
print(type(w_dits), type(w_dahs3), type(w_chrs), type(w_wrds))
print(len(w_dits), len(w_dahs3), len(w_chrs), len(w_wrds))

In [None]:
import re

# will take delayed dah signal (thus shorter)
# the shortest of the four signals bounds the number of sample sets fed to the decoder
w_len = np.min([len(w_dits), len(w_dahs3), len(w_chrs), len(w_wrds)])
print(w_len)

# dit_shift is the de-bouncing unit; the character trigger level is raised to 0.6
morse_decoder = MorseDecoder(dit_shift, chr_lvl=0.6)
for i in range(w_len):
    morse_decoder.new_samples([w_dits[i], w_dahs3[i], w_chrs[i], w_wrds[i]])
result = morse_decoder.result
# collapse runs of spaces into a single space
result = re.sub(' +', ' ', result)
print(len(result), result)

## Test signal reconstruction

In [None]:
# Fixed test text; since a phrase is given, the dataset encodes it rather than generating a random one.
test_phrase = "VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB VVV DE F4EXB"
# Same SNR_dB (-20) and loader settings as for the training dataset.
test_keying_dataset = MorsekeyingDataset(morse_gen, device, -20, 132*5, 27*5, test_phrase, alphabet)
test_keying_loader = torch.utils.data.DataLoader(test_keying_dataset, batch_size=1, shuffle=False) # Batch size must be 1

In [None]:
# One list per keying output, filled with the prediction of every test window.
p_dit_l = []
p_dah_l = []
p_ele_l = []
p_chr_l = []
p_wrd_l = []
morse_key_model.eval()

# NOTE(review): the hidden cell is not reset between windows here, unlike in the training loop — confirm this is intended.
loop = tqdm(enumerate(test_keying_loader), total=len(test_keying_loader))
for j, test in loop:
    with torch.no_grad():
        X_test = test[0]
        # index 0 removes the batch dimension added by the loader
        pred_val = morse_key_model(X_test[0]).cpu()
        p_dit_l.append(pred_val[0].item())
        p_dah_l.append(pred_val[1].item())
        p_ele_l.append(pred_val[2].item())
        p_chr_l.append(pred_val[3].item())
        p_wrd_l.append(pred_val[4].item())
        
p_dit = np.array(p_dit_l)
p_dah = np.array(p_dah_l)
p_ele = np.array(p_ele_l)
p_chr = np.array(p_chr_l)
p_wrd = np.array(p_wrd_l)

# trim negative values
p_dit[p_dit < 0] = 0
p_dah[p_dah < 0] = 0
p_ele[p_ele < 0] = 0
p_chr[p_chr < 0] = 0
p_wrd[p_wrd < 0] = 0

In [None]:
test_signal = test_keying_dataset.get_signal()
test_envelope = test_keying_dataset.get_envelope()

# Same post-processing as for the training predictions, reusing the shifts computed there.
# Window dits and dahs by the element separator prediction taken dit_shift samples later.
elem_window = p_ele[dit_shift:]
w_dits = p_dit[:-dit_shift] * elem_window
w_dahs = p_dah[:-dit_shift] * elem_window
# Remove the dits contribution from the windowed dahs.
w_dahs -= w_dits
# Copies of the dahs signal with the first dah2_shift / dah3_shift samples removed.
w_dahs2 = w_dahs[dah2_shift:]
w_dahs3 = w_dahs[dah3_shift:]
w_chrs = p_chr[dit_shift:]
w_wrds = p_wrd[dit_shift:]

# Scale the windowed dits by 2 and clip at 1.0.
w_dits *= 2.0
w_dits[w_dits > 1.0] = 1.0

plt.figure(figsize=(50,6))
plt.plot(test_signal[x0+n_prev:x1+n_prev]*0.7, label="sig")
plt.plot(test_envelope[x0+n_prev:x1+n_prev]*0.9, label='env')
plt.plot(w_dits[x0:x1]*0.9 + 1.0, label='dit')
plt.plot(w_dahs[x0:x1]*0.9 + 1.0, label='dah')
plt.plot(w_dahs2[x0:x1]*0.9 + 1.0, label='da2', alpha=0.5)
plt.plot(w_dahs3[x0:x1]*0.9 + 1.0, label='da3', alpha=0.5)
plt.plot(w_chrs[x0:x1]*0.9 + 2.0, label='chr')
plt.plot(w_wrds[x0:x1]*0.9 + 2.0, label='wrd')
plt.title("Test keying - predictions")
plt.legend()
plt.grid()

### Reconstructed signal

In [None]:
# Mono frequency reconstruction of the test predictions.
# Common length of the dits signal and of the three dahs signals.
mod_len = min(len(w_dits), len(w_dahs), len(w_dahs2), len(w_dahs3))
print(mod_len, noverlap, mod_len*noverlap)
dit_mod = w_dits[:mod_len]
# Sum the dahs signal and its two shifted copies to rebuild the length of a dah.
dah_mod = w_dahs[:mod_len] + w_dahs2[:mod_len] + w_dahs3[:mod_len]
all_mod = dit_mod + dah_mod
# Hold each modulation sample for noverlap audio samples.
mod_wav = np.array([[x]*noverlap for x in all_mod]).flatten()

wt = (Fcode / Fs)*2*np.pi
tone = np.sin(np.arange(len(mod_wav))*wt)

wavfile.write('audio/moft.wav', Fs, tone*mod_wav)

plt.figure(figsize=(50,6))
plt.plot(test_signal[x0+n_prev:x1+n_prev]*0.7, label="sig")
plt.plot(test_envelope[x0+n_prev:x1+n_prev]*0.9, label='env')
plt.plot(all_mod[x0:x1]*0.9 + 1.0, label='mod')
plt.title("envelope reconstruction")
plt.legend()
plt.grid()

### Original signal

In [None]:
# Write the noisy test signal (from sample n_prev on) as audio for comparison with the reconstruction.
sig_wav = np.array([[x]*noverlap for x in test_signal[n_prev:]]).flatten()
sig_tone = np.sin(np.arange(len(sig_wav))*wt)
wavfile.write('audio/orit.wav', Fs, sig_tone*sig_wav)

In [None]:
# Decode the test predictions.
# The bound covers exactly the four signals indexed in the loop below.
w_len = min(len(w_dits), len(w_dahs3), len(w_chrs), len(w_wrds))
# dit_shift is the de-bouncing unit; the character trigger level is raised to 0.7
morse_decoder = MorseDecoder(dit_shift, chr_lvl=0.7)
for i in range(w_len):
    morse_decoder.new_samples([w_dits[i], w_dahs3[i], w_chrs[i], w_wrds[i]]) # take delayed dah signal
result = morse_decoder.result
# collapse runs of spaces into a single space
result = re.sub(' +', ' ', result)
print(len(result), result)