In [None]:
import argparse
import torch
import torchaudio
import os
import numpy as np
from data.utils import get_magnitude, get_audio_from_magnitude
from torch.utils.mobile_optimizer import optimize_for_mobile
import matplotlib.pyplot as plt

from getmodel import get_model

import sounddevice as sd
import IPython.display as ipd
import copy
import time 
from pystoi import stoi
from pesq import pesq

def _separate_in_segments(audio, sr, length_seconds, model, print_segment_times=False):
    """Run ``model`` over fixed-length segments of ``audio`` and stitch the outputs together.

    Shared implementation behind ``predict_waveform`` and ``predict_waveform_withTime``.
    The input is zero-padded at the end to a whole number of segments. Each segment is
    passed to the model with a leading batch dimension. The concatenated outputs are then
    trimmed back to the input length.

    Args:
        audio: 2-D tensor indexed as ``(channels, samples)``. Callers in this notebook pass
            single-channel audio.
        sr: sample rate in Hz.
        length_seconds: segment duration in seconds (int or float). Each segment holds
            ``int(sr * length_seconds)`` samples.
        model: callable taking a ``(1, channels, segment_samples)`` tensor. After
            ``.squeeze()``, its output is indexed on dim 0 as source 0 = clean and
            source 1 = noise. This is an assumption inherited from the original slicing —
            TODO confirm against ``getmodel``.
        print_segment_times: if True, print the inference time of every segment.

    Returns:
        ``(clean, noise, total_inference_seconds)``:
        - ``clean`` and ``noise`` are CPU tensors of shape ``(1, samples)``.
        - ``total_inference_seconds`` is the summed wall-clock time spent inside ``model`` calls.

    Raises:
        ValueError: if the segment length is not positive or ``audio`` has no samples.
    """
    total_samples = audio.shape[1]
    segment_length = int(sr * length_seconds)
    if segment_length <= 0:
        raise ValueError(f"Segment length must be positive (sr={sr}, length_seconds={length_seconds})")
    if total_samples == 0:
        raise ValueError("Cannot run inference on audio with zero samples")

    n_segments = (total_samples + segment_length - 1) // segment_length  # ceil division
    # Pad once so every segment is exactly `segment_length` samples long. F.pad keeps the
    # padding on audio's device and dtype. The old per-segment torch.zeros buffer was
    # always a single-channel CPU tensor.
    padded = torch.nn.functional.pad(audio, (0, n_segments * segment_length - total_samples))

    clean_segments, noise_segments = [], []
    total_inference_time = 0.0
    # Pure inference: skip autograd bookkeeping, so memory use and the measured times
    # reflect the forward pass only.
    with torch.no_grad():
        for segment in torch.split(padded, segment_length, dim=1):
            start_time = time.perf_counter()
            out_sources = model(segment.unsqueeze(0))  # add the batch dimension
            elapsed = time.perf_counter() - start_time
            if print_segment_times:
                print("Time taken for inference: ", elapsed)
            total_inference_time += elapsed

            out_sources = out_sources.squeeze().cpu().detach()
            clean_segments.append(out_sources[0:1, :])
            noise_segments.append(out_sources[1:2, :])

    # Concatenate along time and drop the zero-padded tail
    clean_output = torch.cat(clean_segments, dim=1)[:, :total_samples]
    noise_output = torch.cat(noise_segments, dim=1)[:, :total_samples]
    return clean_output, noise_output, total_inference_time


def predict_waveform(audio, sr, length_seconds, model):
    """Separate ``audio`` into clean speech and noise, printing inference timings.

    Prints the inference time of every segment and the total.

    Args:
        audio: ``(channels, samples)`` tensor (single channel in this notebook).
        sr: sample rate in Hz.
        length_seconds: segment duration in seconds fed to the model per call.
        model: callable separation model (PyTorch or TorchScript module).

    Returns:
        ``(clean, noise)`` CPU tensors of shape ``(1, samples)``.
    """
    clean_output, noise_output, total_inference_time = _separate_in_segments(
        audio, sr, length_seconds, model, print_segment_times=True
    )
    print("Total inference time: ", total_inference_time)
    return clean_output, noise_output


def predict_waveform_withTime(audio, sr, length_seconds, model):
    """Separate ``audio`` without printing; also return the summed inference time.

    Args:
        audio: ``(channels, samples)`` tensor (single channel in this notebook).
        sr: sample rate in Hz.
        length_seconds: segment duration in seconds fed to the model per call.
        model: callable separation model (PyTorch or TorchScript module).

    Returns:
        ``(clean, noise, total_inference_seconds)``. ``clean`` and ``noise`` are CPU
        tensors of shape ``(1, samples)``.
    """
    return _separate_in_segments(audio, sr, length_seconds, model)


### 1. Load the trained model and a noisy speech test file, then run inference

In [None]:
# Change these to test on any file with any model
filepath = "datasets/Test/26-495-0047"  # test file path without the ".wav" extension
tarFile = "LearningRate001"             # checkpoint stem; "<tarFile>.tar" is loaded below
# ---------------------------------------------------

# Derived paths:
# - inputFilePath is the file to be denoised.
# - outputFilePath is the stem the results are written under.
inputFilePath, outputFilePath = f'{filepath}.wav', f'{filepath}({tarFile}).wav'

model = "ConvTasNet"  # architecture name handed to get_model
checkpointName = f'{tarFile}.tar'
length = 4            # inference segment length in seconds
device = "cpu"

# Build the architecture, then restore the trained weights into it
training_utils_dict = get_model(model)
trainedModel, data_mode = training_utils_dict["model"], training_utils_dict["data_mode"]

assert os.path.isfile(checkpointName) and checkpointName.endswith(".tar"), \
    "The specified checkpoint_name is not a valid checkpoint"
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
# Weights are mapped to CPU first, then the model is moved to `device`.
checkpoint = torch.load(checkpointName, map_location=torch.device('cpu'))
trainedModel.load_state_dict(checkpoint["model_state_dict"])
trainedModel = trainedModel.to(device).eval()
print(f"Model loaded from checkpoint: {checkpointName}")


def bestModelTest(model = model, checkpointName = checkpointName, inputFilePath = inputFilePath, outputFilePath = outputFilePath, length = length, device = device):
    """Denoise one audio file with an already-loaded model, then plot and save the results.

    Args:
        model: callable separation model, e.g. ``trainedModel``. NOTE: the default is the
            module-level ``model`` captured when this function was defined. That is the
            architecture *name* string, so callers must pass the loaded model explicitly.
        checkpointName: unused here; kept for backward compatibility (the checkpoint is
            loaded in the model-loading cell).
        inputFilePath: .mp3/.wav/.flac file to denoise.
        outputFilePath: stem for the outputs. ``_clean`` / ``_noise`` are inserted before
            its extension.
        length: inference segment length in seconds.
        device: device the input is moved to for inference.

    Returns:
        ``(audio, clean_output, noise_output, sr)``: the peak-normalized 16 kHz input, the
        two separated sources, and the sample rate.

    Raises:
        TypeError: if ``model`` is not callable (e.g. the default name string).
        AssertionError: if the input file is missing or has an unsupported extension.
        NotImplementedError: if the global ``data_mode`` is not ``"time"``.
    """
    if not callable(model):
        raise TypeError(
            f"`model` must be a loaded, callable model, got {type(model).__name__}; "
            "call bestModelTest(model=trainedModel)"
        )

    extensions = (".mp3", ".wav", ".flac")
    assert os.path.isfile(inputFilePath) and inputFilePath.endswith(
        extensions
    ), f"Input file cannot be loaded. Either it does not exist or has a wrong extension. Allowed extensions {extensions}"

    # Only time-domain inference is implemented. Any other mode used to fall through and
    # crash with an UnboundLocalError on `clean_output`.
    if data_mode not in ["time"]:
        raise NotImplementedError(f"Unsupported data_mode {data_mode!r}; only 'time' is implemented")

    audio, sr = torchaudio.load(inputFilePath)
    # Inference is done at 16 kHz
    if sr != 16000:
        audio = torchaudio.transforms.Resample(sr, 16000)(audio)
        sr = 16000

    # Peak-normalize the mixture. Silent input is left as-is instead of becoming NaN (0/0).
    peak = audio.abs().max()
    if peak > 0:
        audio /= peak

    clean_output, noise_output = predict_waveform(audio.to(device), sr, length, model)

    # Normalization wrt mixture. The mixture peak is 1 after the step above (0 for silence).
    mixture_peak = audio.abs().max()
    if mixture_peak > 0:
        clean_output /= mixture_peak
        noise_output /= mixture_peak

    fig, axes = plt.subplots(3, 1, sharex=True)
    for ax, waveform, title in zip(
        axes,
        (audio, clean_output, noise_output),
        ("Noisy input", "Estimated clean speech", "Estimated noise"),
    ):
        ax.plot(waveform[0])
        ax.set_title(title)
    fig.tight_layout()
    plt.show()

    output_name, ext = os.path.splitext(outputFilePath)

    torchaudio.save(f"{output_name}_clean{ext}", clean_output, sr)
    torchaudio.save(f"{output_name}_noise{ext}", noise_output, sr)

    return audio, clean_output, noise_output, sr


In [None]:
# Denoise the configured test file with the loaded checkpoint; the separated noise is discarded
noisy, clean, _, sr = bestModelTest(model=trainedModel)

In [None]:
# Noisy input audio (peak-normalized, 16 kHz)
ipd.Audio(data=noisy, rate=sr)

In [None]:
# Denoised audio (estimated clean speech)
ipd.Audio(data=clean, rate=sr)

### 2. Convert the trained model above to a PyTorch Mobile (lite interpreter) model

In [None]:
# Re-create a fresh model instance and reload the checkpoint for export.
# The result is also stored in `trainedModel`, which later cells evaluate.
training_utils_dict = get_model(model)
trainedModel = training_utils_dict["model"]
data_mode = training_utils_dict["data_mode"]

# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(checkpointName, map_location=torch.device('cpu'))
trainedModel.load_state_dict(checkpoint["model_state_dict"])
# Switch to inference mode before scripting:
# - the scripted module's `training` flag is then False;
# - the later SNR evaluation, which reuses this `trainedModel`, does not run with any
#   dropout / batch-norm layers in training mode.
trainedModel.eval()

scripted_module = torch.jit.script(trainedModel)
optimized_scripted_module = optimize_for_mobile(scripted_module)
optimized_scripted_module._save_for_lite_interpreter(f'./{tarFile}.ptl')
print(f'Pytorch mobile model saved as {tarFile}.ptl')

### 3. Run inference with the PyTorch Mobile model on the same noisy test file

In [None]:
# Load the exported lite-interpreter model; it is evaluated on the same input file as above.
mobileModel = torch.jit.load(f'{tarFile}.ptl')
mobileModel = mobileModel.to(device)
mobileModel.eval()
print(f'Model loaded from tar file: {tarFile}.ptl')

extensions = (".mp3", ".wav", ".flac")
assert os.path.isfile(inputFilePath) and inputFilePath.endswith(
    extensions
), f"Input file cannot be loaded. Either it does not exist or has a wrong extension. Allowed extensions {extensions}"

mobile_audio, sr = torchaudio.load(inputFilePath)
# Inference is done at 16 kHz
if sr != 16000:
    mobile_audio = torchaudio.transforms.Resample(sr, 16000)(mobile_audio)
    sr = 16000

# Peak-normalize the mixture. Silent input is left as-is instead of becoming NaN (0/0).
mobile_peak = mobile_audio.abs().max()
if mobile_peak > 0:
    mobile_audio /= mobile_peak

# Only time-domain inference is implemented for the mobile model
mobile_clean_output, mobile_noise_output = predict_waveform(mobile_audio.to(device), sr, length, mobileModel)

# Normalization wrt mixture. The mixture peak is 1 after the step above (0 for silence).
mobile_mixture_peak = mobile_audio.abs().max()
if mobile_mixture_peak > 0:
    mobile_clean_output /= mobile_mixture_peak
    mobile_noise_output /= mobile_mixture_peak

fig, axes = plt.subplots(3, 1, sharex=True)
for ax, waveform, title in zip(
    axes,
    (mobile_audio, mobile_clean_output, mobile_noise_output),
    ("Noisy input", "Estimated clean speech (mobile)", "Estimated noise (mobile)"),
):
    ax.plot(waveform[0])
    ax.set_title(title)
fig.tight_layout()
plt.show()

output_name, ext = os.path.splitext(outputFilePath)

torchaudio.save(f"{output_name}_clean_mobile{ext}", mobile_clean_output, sr)
torchaudio.save(f"{output_name}_noise_mobile{ext}", mobile_noise_output, sr)


In [None]:
# Denoised audio produced by the mobile model
ipd.Audio(data=mobile_clean_output, rate=sr)

# Testing different SNR ranges with audio files longer than the training duration
1. Run inference on multiple noisy speech files.
2. Compute the average PESQ and ESTOI scores of the noisy speech files within each SNR range.

In [None]:
def inferWithModel(mobileModel, inputTensor, sr, length, sources):
    """Denoise one mixture with a model and score the result against the clean reference.

    Any callable model works; the evaluation loop passes both the PC and the mobile model.

    Args:
        mobileModel: callable separation model.
        inputTensor: noisy mixture as a ``(channels, samples)`` tensor, with one channel
            (the layout ``predict_waveform_withTime`` slices).
        sr: sample rate in Hz, also passed to PESQ and ESTOI.
        length: inference segment length in seconds.
        sources: reference sources; ``sources[0]`` is used as the clean reference.

    Returns:
        ``(pesq_score, estoi_score, total_inference_time)``.
    """
    # Only time-domain inference is implemented here
    clean_output, _noise_output, total_inference_time = predict_waveform_withTime(
        inputTensor.to(device), sr, length, mobileModel
    )

    # Normalization wrt mixture. Previously this divided by the global `mobile_audio`
    # from another cell. That was a hidden-state dependency that raises NameError on a
    # fresh kernel; now the peak of the mixture actually being denoised is used.
    mixture_peak = float(inputTensor.abs().max())
    if mixture_peak > 0:
        clean_output = clean_output / mixture_peak

    estimate = clean_output.squeeze(0).numpy()  # (1, samples) -> (samples,)
    reference = sources[0].numpy()
    pesq_score = pesq(fs=sr, ref=reference, deg=estimate)
    estoi_score = stoi(x=reference, y=estimate, fs_sig=sr, extended=True)

    return pesq_score, estoi_score, total_inference_time


In [None]:
keep_rate = 0.05
sr = 16000
# Duration (s) of the pre-sliced validation clips; also used as the inference segment length.
# It is deliberately not called `length`: that would overwrite the 4 s segment length of the
# earlier cells, so re-running them after this one would silently switch them to 12 s.
eval_length = 12
num_of_samples_per_snr = 50

snr_ranges = [(-30,-20),(-20,-10),(-10,0),(0,10),(10,20),(20,30)]

# Noisy mixtures are exported here (one sub-folder per SNR range) for on-device testing
export_root = 'datasets/testOnAndroidAsWell2'
export_ext = '.wav'

from data import AudioDirectoryDataset, NoiseMixerDataset
from tqdm import tqdm

clean_dataset_path = f'datasets/Validation/cleanSliced{eval_length}sec'
noise_dataset_path = f'datasets/Validation/noisySliced{eval_length}sec'

# One score array per SNR range, in the order of `snr_ranges`
all_pesqScores_before = []
all_estoiScores_before = []
all_pesqScores = []
all_estoiScores = []
all_inferenceTimes = []

all_pesqScores_mob = []
all_estoiScores_mob = []
all_inferenceTimes_mob = []

# Evaluate both models in inference mode (the export cell leaves `trainedModel` in training mode)
trainedModel.eval()
mobileModel.eval()

for min_snr, max_snr in snr_ranges:
    print(f'Mixing {eval_length} seconds clean and noisy with SNR range ({min_snr}, {max_snr}) with keepRate of {keep_rate}')
    val_clean_dataset = AudioDirectoryDataset(root=clean_dataset_path, keep_rate=keep_rate)
    val_noise_dataset = AudioDirectoryDataset(root=noise_dataset_path, keep_rate=keep_rate)
    val_data = NoiseMixerDataset(
        clean_dataset=val_clean_dataset, noise_dataset=val_noise_dataset, min_snr=min_snr, max_snr=max_snr
    )

    n_samples = min(num_of_samples_per_snr, len(val_data))
    print(f"Dataset size: {len(val_data)}")
    print(f'Samples used : {n_samples}')

    export_dir = os.path.join(export_root, f'{min_snr}_{max_snr}')
    os.makedirs(export_dir, exist_ok=True)

    pesqScores_before, estoiScores_before = [], []
    pesqScores, estoiScores, inferenceTimes = [], [], []
    pesqScores_mob, estoiScores_mob, inferenceTimes_mob = [], [], []

    with torch.no_grad():
        for i in tqdm(range(n_samples), desc=f'SNR ({min_snr}, {max_snr})'):
            mixture, sources = val_data[i]  # mixture, [clean, noise]
            if i == 0:
                print(mixture.shape, sources.shape)

            # Export the noisy mixture that is being scored. Previously this wrote the clean/noise
            # sources of one separate draw for every i. The filename pattern is kept unchanged
            # (including the "_" before the extension) for files already consumed on-device.
            torchaudio.save(os.path.join(export_dir, f'noisySpeechSample_{i}_{export_ext}'), mixture, sr)

            # Baseline: score the unprocessed mixture against the clean reference
            reference = sources[0].numpy()
            mixture_np = np.squeeze(mixture.numpy())
            pesqScores_before.append(pesq(fs=sr, ref=reference, deg=mixture_np))
            estoiScores_before.append(stoi(x=reference, y=mixture_np, fs_sig=sr, extended=True))

            pesqScore, estoiScore, totalInferenceTime = inferWithModel(trainedModel, mixture, sr, eval_length, sources)
            pesqScore_mob, estoiScore_mob, totalInferenceTime_mob = inferWithModel(mobileModel, mixture, sr, eval_length, sources)

            # Inside specific SNR range
            pesqScores.append(pesqScore)
            estoiScores.append(estoiScore)
            inferenceTimes.append(totalInferenceTime)

            pesqScores_mob.append(pesqScore_mob)
            estoiScores_mob.append(estoiScore_mob)
            inferenceTimes_mob.append(totalInferenceTime_mob)

    # For all SNR ranges
    all_pesqScores.append(np.array(pesqScores))
    all_estoiScores.append(np.array(estoiScores))
    all_inferenceTimes.append(np.array(inferenceTimes))

    all_pesqScores_mob.append(np.array(pesqScores_mob))
    all_estoiScores_mob.append(np.array(estoiScores_mob))
    all_inferenceTimes_mob.append(np.array(inferenceTimes_mob))

    all_pesqScores_before.append(np.array(pesqScores_before))
    all_estoiScores_before.append(np.array(estoiScores_before))

    print(f'Average PESQ Score before: {np.mean(pesqScores_before)} and Average ESTOI Score before: {np.mean(estoiScores_before)}')
    print(f'Average PESQ Score: {np.mean(pesqScores)} and Average ESTOI Score: {np.mean(estoiScores)} and Average Inference Time: {np.mean(inferenceTimes)}')
    print(f'Average PESQ Score (mobile): {np.mean(pesqScores_mob)} and Average ESTOI Score (mobile): {np.mean(estoiScores_mob)} and Average Inference Time (mobile): {np.mean(inferenceTimes_mob)}')

print(f'Length of all_pesqScores: {len(all_pesqScores)}')
print(f'Length of all_estoiScores: {len(all_estoiScores)}')
print(f'Length of all_inferenceTimes: {len(all_inferenceTimes)}')

In [None]:
def plotResults(allPesqScores,allEstoiScores,allPesqScoresBefore,allEstoiScoresBefore,postFix = ""):
    snr_ranges = [(-30,-20),(-20,-10),(-10,0),(0,10),(10,20),(20,30)]
    mean_values = []
    variance_values = []

    ############################################################## PESQ
    # Calculate mean and variance for each SNR range
    for snr_range, data_array in zip(snr_ranges, allPesqScores):
        mean_values.append(np.mean(data_array))
        variance_values.append(np.var(data_array))

    labels = [f'({snr_range[0]},{snr_range[1]})' for snr_range in snr_ranges]
    # Plotting the histogram
    plt.bar(range(len(snr_ranges)), mean_values, yerr=variance_values, capsize=5, tick_label=labels)
    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('PESQ: Mean and Variance (After Enhancement)')

    for i, (mean_val, var_val) in enumerate(zip(mean_values, variance_values)):
        plt.text(i, mean_val + 0.05, f'{mean_val:.2f}', ha='right', va='baseline', color = 'green')
        plt.text(i, mean_val - 0.05, f'{var_val:.2f}', ha='left', va='top', color= 'yellow')

    plt.savefig(f'after_enhancement_plot_pesq{postFix}.png')
    plt.show()


    mean_values_before = []
    variance_values_before  = []
    for snr_range, data_array in zip(snr_ranges, allPesqScoresBefore):
        mean_values_before.append(np.mean(data_array))
        variance_values_before.append(np.var(data_array))

    # BEFORE ENHANCEMENT
    # Plotting the histogram
    plt.bar(range(len(snr_ranges)), mean_values_before, yerr=variance_values_before, capsize=5, tick_label=labels)
    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('PESQ: Mean and Variance (Before Enhancement)')

    for i, (mean_val, var_val) in enumerate(zip(mean_values_before, variance_values_before)):
        plt.text(i, mean_val + 0.05, f'{mean_val:.2f}', ha='right', va='baseline', color = 'green')
        plt.text(i, mean_val - 0.05, f'{var_val:.2f}', ha='left', va='top', color= 'yellow')

    plt.savefig(f'before_enhancement_plot_pesq{postFix}.png')
    plt.show()

    
    # Plotting the combined histogram for mean values
    bar_width = 0.35  # Adjust the width of the bars
    bar_positions_after = np.arange(len(snr_ranges))
    bar_positions_before = bar_positions_after + bar_width  # Adjust positions for the second set of bars

    # Bar plot for mean values after enhancement (blue)
    plt.bar(bar_positions_after, mean_values, width=bar_width, label='After Enhancement', color='green', alpha=0.7)

    # Bar plot for mean values before enhancement (green)
    plt.bar(bar_positions_before, mean_values_before, width=bar_width, label='Before Enhancement', color='blue', alpha=0.7)

    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('PESQ: Mean Comparison')
    plt.xticks(bar_positions_after + bar_width / 2, labels)  # Adjust the x-axis tick positions
    plt.legend()

    # Adding text annotations for mean values
    for i, (mean_val_after, mean_val_before) in enumerate(zip(mean_values, mean_values_before)):
        plt.text(bar_positions_after[i], mean_val_after + 0.05, f'{mean_val_after:.2f}', ha='center', va='bottom', color='green')
        plt.text(bar_positions_before[i], mean_val_before + 0.05, f'{mean_val_before:.2f}', ha='center', va='bottom', color='blue')

    # Save the combined plot
    plt.ylim(0,3.5)
    plt.savefig(f'combined_plot_pesq{postFix}.png')

    # Show the combined plot
    plt.show()



    ############################################################## ESTOI

    mean_values = []
    variance_values = []
    # Calculate mean and variance for each SNR range
    for snr_range, data_array in zip(snr_ranges, allEstoiScores):
        mean_values.append(np.mean(data_array))
        variance_values.append(np.var(data_array))

    labels = [f'({snr_range[0]},{snr_range[1]})' for snr_range in snr_ranges]
    # Plotting the histogram
    plt.bar(range(len(snr_ranges)), mean_values, yerr=variance_values, capsize=5, tick_label=labels)
    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('ESTOI: Mean and Variance (After Enhancement)')

    for i, (mean_val, var_val) in enumerate(zip(mean_values, variance_values)):
        plt.text(i, mean_val + 0.05, f'{mean_val:.2f}', ha='right', va='baseline', color = 'green')
        plt.text(i, mean_val - 0.05, f'{var_val:.2f}', ha='left', va='top', color= 'yellow')

    plt.ylim(0, 1.5)
    plt.savefig(f'after_enhancement_plot_estoi{postFix}.png')
    plt.show()


    mean_values_before = []
    variance_values_before  = []
    for snr_range, data_array in zip(snr_ranges, allEstoiScoresBefore):
        mean_values_before.append(np.mean(data_array))
        variance_values_before.append(np.var(data_array))

    # BEFORE ENHANCEMENT
    # Plotting the histogram
    plt.bar(range(len(snr_ranges)), mean_values_before, yerr=variance_values_before, capsize=5, tick_label=labels)
    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('ESTOI: Mean and Variance (Before Enhancement)')

    for i, (mean_val, var_val) in enumerate(zip(mean_values_before, variance_values_before)):
        plt.text(i, mean_val + 0.05, f'{mean_val:.2f}', ha='right', va='baseline', color = 'green')
        plt.text(i, mean_val - 0.05, f'{var_val:.2f}', ha='left', va='top', color= 'yellow')

    plt.ylim(0, 1.5)
    plt.savefig(f'before_enhancement_plot_estoi{postFix}.png')
    plt.show()


    # Plotting the combined histogram for mean values
    bar_width = 0.35  # Adjust the width of the bars
    bar_positions_after = np.arange(len(snr_ranges))
    bar_positions_before = bar_positions_after + bar_width  # Adjust positions for the second set of bars

    # Bar plot for mean values after enhancement (blue)
    plt.bar(bar_positions_after, mean_values, width=bar_width, label='After Enhancement', color='green', alpha=0.7)

    # Bar plot for mean values before enhancement (green)
    plt.bar(bar_positions_before, mean_values_before, width=bar_width, label='Before Enhancement', color='blue', alpha=0.7)

    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('ESTOI: Mean Comparison')
    plt.xticks(bar_positions_after + bar_width / 2, labels)  # Adjust the x-axis tick positions
    plt.legend()

    # Adding text annotations for mean values
    for i, (mean_val_after, mean_val_before) in enumerate(zip(mean_values, mean_values_before)):
        plt.text(bar_positions_after[i], mean_val_after + 0.05, f'{mean_val_after:.2f}', ha='center', va='bottom', color='green')
        plt.text(bar_positions_before[i], mean_val_before + 0.05, f'{mean_val_before:.2f}', ha='center', va='bottom', color='blue')

    plt.ylim(0, 1.5)
    # Save the combined plot
    plt.savefig(f'combined_plot_estoi{postFix}.png')

    # Show the combined plot
    plt.show()

In [None]:
def comparisonInferenceTimes(allPesqScores,allEstoiScores,allPesqScores_mob,allEstoiScores_mob,allInferenceTimes,allInferenceTimes_mob):
    snr_ranges = [(-30,-20),(-20,-10),(-10,0),(0,10),(10,20),(20,30)]
    # Plotting the combined histogram for mean values
    mean_values_pesq_after = []
    mean_values_estoi_after = []

    mean_values_pesq_after_mob = []
    mean_values_estoi_after_mob = []


    # Calculate mean for each SNR range
    for snr_range, data_array in zip(snr_ranges, allEstoiScores):
        mean_values_estoi_after.append(np.mean(data_array))

    for snr_range, data_array in zip(snr_ranges, allPesqScores):
        mean_values_pesq_after.append(np.mean(data_array))

    for snr_range, data_array in zip(snr_ranges, allEstoiScores_mob):
        mean_values_estoi_after_mob.append(np.mean(data_array))
    
    for snr_range, data_array in zip(snr_ranges, allPesqScores_mob):
        mean_values_pesq_after_mob.append(np.mean(data_array))


    labels = [f'({snr_range[0]},{snr_range[1]})' for snr_range in snr_ranges]
    
    bar_width = 0.35  # Adjust the width of the bars
    bar_positions_after = np.arange(len(snr_ranges))
    bar_positions_after_mob = bar_positions_after + bar_width  # Adjust positions for the second set of bars

    # Bar plot for mean values after enhancement (blue)
    plt.bar(bar_positions_after, mean_values_pesq_after, width=bar_width, label='PC', color='blue', alpha=0.7)

    # Bar plot for mean values before enhancement (green)
    plt.bar(bar_positions_after_mob, mean_values_pesq_after_mob, width=bar_width, label='Mobile', color='orange', alpha=0.7)

    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('PESQ: PC vs Mobile for Denoised Output')
    plt.xticks(bar_positions_after + bar_width / 2, labels)  # Adjust the x-axis tick positions
    plt.legend()

    # Adding text annotations for mean values
    for i, (mean_val_after, mean_val_after_mob) in enumerate(zip(mean_values_pesq_after, mean_values_pesq_after_mob)):
        print(f'{mean_val_after}')
        print(f'{mean_val_after_mob}')
        plt.text(bar_positions_after[i], mean_val_after + 0.05, f'{mean_val_after:.2f}', ha='center', va='bottom', color='blue')
        plt.text(bar_positions_after_mob[i], mean_val_after_mob + 0.05, f'{mean_val_after_mob:.2f}', ha='center', va='bottom', color='orange')

    plt.ylim(0, 3.5)
    # Save the combined plot
    plt.savefig(f'ModelComparison_PESQ.png')

    # Show the combined plot
    plt.show()

    
    bar_width = 0.35  # Adjust the width of the bars
    bar_positions_after = np.arange(len(snr_ranges))
    bar_positions_after_mob = bar_positions_after + bar_width  # Adjust positions for the second set of bars

    # Bar plot for mean values after enhancement (blue)
    plt.bar(bar_positions_after, mean_values_estoi_after, width=bar_width, label='PC', color='blue', alpha=0.7)

    # Bar plot for mean values before enhancement (green)
    plt.bar(bar_positions_after_mob, mean_values_estoi_after_mob, width=bar_width, label='Mobile', color='orange', alpha=0.7)

    plt.xlabel('SNR Ranges')
    plt.ylabel('Mean Value')
    plt.title('ESTOI: PC vs Mobile for Denoised Output')
    plt.xticks(bar_positions_after + bar_width / 2, labels)  # Adjust the x-axis tick positions
    plt.legend()

    # Adding text annotations for mean values
    for i, (mean_val_after, mean_val_after_mob) in enumerate(zip(mean_values_estoi_after, mean_values_estoi_after_mob)):
        plt.text(bar_positions_after[i], mean_val_after + 0.05, f'{mean_val_after:.2f}', ha='center', va='bottom', color='blue')
        plt.text(bar_positions_after_mob[i], mean_val_after_mob + 0.05, f'{mean_val_after_mob:.2f}', ha='center', va='bottom', color='orange')

    plt.ylim(0, 1.5)
    # Save the combined plot
    plt.savefig(f'ModelComparison_ESTOI.png')

    # Show the combined plot
    plt.show()

    # Flatten the arrays
    flatTimesPC = np.array(allInferenceTimes).flatten()
    flatTimes_mob = np.array(allInferenceTimes_mob).flatten()

    # Calculate mean and variance
    mean_pc = np.mean(flatTimesPC)
    mean_mob = np.mean(flatTimes_mob)
    

    # Plotting
    labels = ['PC', 'Mobile']
    means = [mean_pc, mean_mob]

    plt.bar(labels, means, color=['blue', 'orange'])
    plt.ylabel('Mean Inference Time')
    plt.title('Comparison of Inference Times')

    # Adding text annotations on top of the bars
    for i, mean_value in enumerate(means):
        plt.text(i, mean_value + 0.1, f'{mean_value:.2f}', ha='center', va='bottom')

    plt.ylim(0,5)
    plt.savefig(f'ModelComparison_InferenceTimes.png')
    plt.show()

In [None]:
# PC model scores; figures are saved with the "PC2" suffix
plotResults(allPesqScores=all_pesqScores, allEstoiScores=all_estoiScores,
            allPesqScoresBefore=all_pesqScores_before, allEstoiScoresBefore=all_estoiScores_before,
            postFix="PC2")
# Mobile model scores; figures are saved without a suffix
plotResults(allPesqScores=all_pesqScores_mob, allEstoiScores=all_estoiScores_mob,
            allPesqScoresBefore=all_pesqScores_before, allEstoiScoresBefore=all_estoiScores_before,
            postFix="")


In [None]:
# PC vs mobile: mean PESQ/ESTOI per SNR range and mean inference time
comparisonInferenceTimes(
    allPesqScores=all_pesqScores,
    allEstoiScores=all_estoiScores,
    allPesqScores_mob=all_pesqScores_mob,
    allEstoiScores_mob=all_estoiScores_mob,
    allInferenceTimes=all_inferenceTimes,
    allInferenceTimes_mob=all_inferenceTimes_mob,
)