In [None]:
import numpy as np
import torch
from pickle import load
from train_deep_learning import Model
from preprocess_data import load_data, labeled_audio_segmentation, convert_to_array
import os
import pandas as pd

In [None]:
def read_model_info(dir='model_data/'):
    """Read the layer sizes the saved network was built with.

    ``model_info.txt`` inside *dir* is read as three lines: the input size,
    a comma-separated list of hidden-layer sizes (may be empty), and the
    output size.

    Args:
        dir: directory containing ``model_info.txt``; a trailing path
            separator is optional.

    Returns:
        tuple ``(input_layer, hidden_layers, output_layer)``: two ints and a
        list of ints.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a line is not a valid integer (e.g. truncated file).
    """
    # os.path.join instead of string concatenation so that 'model_data' and
    # 'model_data/' both resolve to the same file.
    path = os.path.join(dir, 'model_info.txt')
    with open(path, 'r') as file:
        input_layer = int(file.readline().strip())
        # Blank entries are skipped so a trailing comma or an empty line
        # (no hidden layers) is accepted.
        hidden_layers = [int(x) for x in file.readline().strip().split(',') if x.strip()]
        output_layer = int(file.readline().strip())

    return input_layer, hidden_layers, output_layer

In [None]:
# Rebuild the trained network (layer sizes from model_info.txt, weights from
# model.pt) on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

input_size, hidden_sizes, output_size = read_model_info()

model = Model(input_size, hidden_sizes, output_size).to(device)
# map_location=device: without it torch.load restores tensors onto the device
# they were saved from, which fails for CUDA-saved weights on a CPU-only machine.
model.load_state_dict(torch.load('model_data/model.pt', map_location=device, weights_only=True))
model.eval()

In [None]:
# Load data transformers
def load_transformers(dir='model_data/transformer_dumps/'):
    """Unpickle the preprocessing objects saved at training time.

    Reads ``encoder.pkl``, ``scaler.pkl`` and ``pca.pkl`` from *dir*.

    Security: unpickling can execute arbitrary code embedded in the file, so
    only point *dir* at dumps you produced yourself.

    Args:
        dir: directory holding the three ``.pkl`` files; a trailing path
            separator is optional.

    Returns:
        dict with keys ``'encoder'``, ``'scaler'`` and ``'pca'`` (in that
        order) mapping to the unpickled objects.
    """
    transformers = {}
    for name in ('encoder', 'scaler', 'pca'):
        # `with` closes each file immediately instead of leaving the handle
        # open until garbage collection.
        with open(os.path.join(dir, name + '.pkl'), 'rb') as dump_file:
            transformers[name] = load(dump_file)
    return transformers

# Load the saved preprocessing objects once; later cells read them from this
# dict under the keys 'encoder', 'scaler' and 'pca'.
transformers = load_transformers()
print(transformers)

In [None]:
# Find every recording in data_dir that has both a label file and an audio
# file, segment it, and stack all segments into one table (feature columns
# plus a 'label' column).
data_dir = 'data'
label_ext = '.tsv'   # label files are opened as <base>.tsv below
audio_ext = '.wav'
extensions = [label_ext, audio_ext]

filenames = os.listdir(data_dir)

# Track which of the two extensions exist for each base name. Only complete
# label/audio pairs are kept, so a stray .wav (or label file) is skipped
# instead of handing a nonexistent path to load_data.
exts_by_base = {}
for fname in filenames:
    base, ext = os.path.splitext(fname)
    if ext in extensions:
        exts_by_base.setdefault(base, set()).add(ext)

# Sorted because os.listdir order is arbitrary; this keeps row order (and the
# prediction preview further down) reproducible across machines.
base_names = sorted(base for base, exts in exts_by_base.items() if exts == set(extensions))
if not base_names:
    raise FileNotFoundError(
        f"no {label_ext}/{audio_ext} file pairs found in {data_dir!r}"
    )

# Collect one frame per recording and concatenate once at the end;
# concatenating inside the loop re-copies the growing frame every iteration.
segment_frames = []
for base in base_names:
    label_file = os.path.join(data_dir, base + label_ext)
    audio_file = os.path.join(data_dir, base + audio_ext)

    audio, raw_labels, sr = load_data(label_file, audio_file)
    segmented_audio, seg_labels = labeled_audio_segmentation(raw_labels, audio, sr)

    key_df = pd.DataFrame(convert_to_array(segmented_audio))
    key_df['label'] = seg_labels
    segment_frames.append(key_df)

dataframe = pd.concat(segment_frames, ignore_index=True)

# Zero-fill missing values. NaNs presumably come from segments with differing
# numbers of feature columns being aligned by concat — TODO confirm.
features = dataframe.drop('label', axis=1).fillna(0)
labels = dataframe['label']

In [None]:
def pipeline(transformers, features):
    """Run features through the saved preprocessing chain.

    The ``'scaler'`` object's ``transform`` is applied first, and its output
    is then fed to the ``'pca'`` object's ``transform``.

    Args:
        transformers: mapping holding fitted objects under ``'scaler'`` and
            ``'pca'`` (other keys are ignored).
        features: input accepted by the scaler's ``transform``.

    Returns:
        The result of the ``'pca'`` transform.
    """
    result = features
    for step_name in ('scaler', 'pca'):
        result = transformers[step_name].transform(result)
    return result
# Model inputs: the feature table passed through the saved scaler, then the PCA.
transformed = pipeline(transformers, features)

In [None]:
# Training history saved next to the model; preview its last rows.
trainhist = pd.read_csv('model_data/trainhist.csv')
trainhist.tail(n=5)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Learning curves: loss (top) and accuracy (bottom), train vs. validation.
fig, (loss_ax, acc_ax) = plt.subplots(2, 1)

# label= on each line instead of plt.legend(labels=[...]): a labels-only legend
# assigns names to artists in draw order, so an extra artist (e.g. the
# confidence band seaborn draws when an epoch has several rows) would receive
# a line's name.
sns.lineplot(x='epoch', y='train_loss', data=trainhist, ax=loss_ax, label='train_loss')
sns.lineplot(x='epoch', y='val_loss', data=trainhist, ax=loss_ax, label='val_loss')
# Explicit ylabel: otherwise seaborn leaves the last plotted column name.
loss_ax.set(title="Model Loss", ylabel='loss')
loss_ax.legend()

sns.lineplot(x='epoch', y='train_acc', data=trainhist, ax=acc_ax, label='train_acc')
sns.lineplot(x='epoch', y='val_acc', data=trainhist, ax=acc_ax, label='val_acc')
acc_ax.set(title="Model Accuracy", ylabel='accuracy')
acc_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
def emulate_typing(chars: list):
    """Replay a sequence of key names as if typed, and print the resulting text.

    Special keys:
      * ``'backspace'`` removes the last typed character (no-op when empty);
      * ``'space'`` types a single space;
      * ``'shift'`` / ``'shift_r'`` upper-cases the next ordinary key only.
        A pending shift is not consumed by space or backspace.
    Every other element is typed literally (upper-cased if a shift is pending).

    Args:
        chars: iterable of key-name strings (e.g. a list or 1-D array of labels).
    """
    typed = []
    pending_shift = False

    for key in chars:
        if key in ('shift', 'shift_r'):
            pending_shift = True
            continue
        if key == 'backspace':
            if typed:
                typed.pop()
            continue
        if key == 'space':
            typed.append(' ')
            continue
        typed.append(key.upper() if pending_shift else key)
        pending_shift = False

    print(''.join(typed))

In [None]:
# Check model predictions: run the model on the first pred_idx_end segments
# and replay predicted vs. true key sequences as typed text.
pred_idx_end = 200

model_inputs = torch.tensor(transformed[:pred_idx_end].astype(np.float32)).to(device)
# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    predictions = model(model_inputs).cpu().numpy()

# NOTE(review): relies on the encoder's inverse_transform accepting raw model
# outputs (a one-hot encoder decodes them by argmax) — confirm for this encoder.
# ravel() rather than squeeze() so a single-row slice stays 1-D and iterable.
pred_y = transformers['encoder'].inverse_transform(predictions).ravel()
true_y = labels.iloc[:pred_idx_end].to_numpy()

print("Predicted:\n\t", end="")
emulate_typing(pred_y)

print('')
print("Actual:\n\t", end='')
emulate_typing(true_y)