This file is modified from the process_audio.py in the original repository from Carson et al. (2024).

In [1]:
import os, torch
import numpy as np
import soundfile as sf
from collections import defaultdict
import sys
from tqdm import tqdm
import logging
import csv

import rnn 
from giant_fft_resample import giant_fft_resample

In [2]:
######### Debug Log ##########
logging.basicConfig(filename='debug.log', level=logging.INFO, format='[%(asctime)s] %(message)s')

######### Input paths ########
INPUT_ROOT = '../inputs'
ORIGINAL_AUDIO_DIR = os.path.join(INPUT_ROOT, 'exampleAudio')
SUBCHANNEL_AUDIO_DIR = os.path.join(INPUT_ROOT, 'Subchannel_Audio_Example2')
MODEL_DIR = os.path.join(INPUT_ROOT, 'BlackstarHT40_AmpHighGain.json')

######## Output paths #########
OUTPUT_ROOT = '../outputs'
OUTPUT_WO_CMFB_DIR = os.path.join(OUTPUT_ROOT, 'without_cmfb')
OUTPUT_W_CMFB_DIR = os.path.join(OUTPUT_ROOT, 'with_cmfb')

#### Other variables ####
ORIGINAL_SR = 44100
METHODS = ['original', 'stn', 'lidl', 'apdl', 'cidl']
SR = ['44.1', '48', '88.2', '96']
SR_MAP = {
    '44.1': 44100,
    '48': 48000,
    '88.2': 88200,
    '96': 96000
}
BASEMODEL = rnn.get_AudioRNN_from_json(MODEL_DIR)
MODELS = defaultdict(dict)
PROCESSED_INPUTS = {}

In [31]:
def load_model():
    global BASEMODEL, MODELS, ORIGINAL_SR
    for method in METHODS:
        for sr in SR:
            if method == 'original':
                model = BASEMODEL
                logging.info(f'SR={sr},\nModel:\n{model}')
            else:
                model = rnn.get_SRIndieRNN(BASEMODEL, method)
                model.rec.os_factor = SR_MAP[sr] / ORIGINAL_SR
                logging.info(f'SR={sr}, Model_os_factor={model.rec.os_factor},\nModel:\n{model}')
            
            MODELS[method][sr] = model

def load_input_audio(subchannel):
    # expected input shape:
    # - without CMFB: tensor shape = [1, 264600]
    # - with CMFB: tensor shape = [1, 8, 264600] 

    def read_file(filename):
        return sf.read(filename)
    
    def to_tensor(array):
        return torch.from_numpy(array).float().unsqueeze(0) 

    if subchannel == 'subchannel':
        temp = []
        for i in range(8):
            filename = os.path.join(SUBCHANNEL_AUDIO_DIR, f'Channel_{i+1}.wav')
            input, _ = read_file(filename)
            temp.append(input)
        input_audio = np.vstack(temp)
    
    else:
        filename = os.path.join(ORIGINAL_AUDIO_DIR, 'Example2.wav')
        input_audio, _ = read_file(filename)
        
    input_audio = to_tensor(input_audio)
    return input_audio

def process_input(input_audio):
    global PROCESSED_INPUTS, ORIGINAL_SR
    for sr in SR:
        new_sr = SR_MAP[sr]
        PROCESSED_INPUTS[sr] = input_audio if new_sr == ORIGINAL_SR else giant_fft_resample(input_audio, ORIGINAL_SR, new_sr)

def run_model(model, input):
    model.reset_state()
    if input.ndim == 3:
        temp_matrix = []
        updates = tqdm(range(8), desc='Running through each channel audio')
        for i in updates:
            with torch.no_grad():
                temp_arr, _ = model(input[:, i, :]) 
                temp_matrix.append(temp_arr)
                model.reset_state()
            updates.set_postfix({'Channel': i+1})
        output = np.vstack(temp_matrix)
        output = torch.from_numpy(output).float()
    else:
        with torch.no_grad():
            output, _ = model(input)
    return output

def process_output(target_sr, output):
    global ORIGINAL_SR
    new_sr = SR_MAP[target_sr]
    if new_sr == ORIGINAL_SR:
        return output
    processed_output = giant_fft_resample(output, new_sr, ORIGINAL_SR)
    return processed_output

def write_outputs(cmfb, directory, output_audio):
    global ORIGINAL_SR
    # create directory
    os.makedirs(directory, exist_ok=True)

    # write output files
    if cmfb:
        for i in range(8):
            filename = os.path.join(directory, f'Channel_{i+1}_RNN.wav')
            data = output_audio[i, :].detach().cpu().numpy()
            sf.write(filename, data, ORIGINAL_SR)
    else:
        filename = os.path.join(directory, f'Example2_RNN.wav')
        data = output_audio[0].detach().cpu().numpy()
        sf.write(filename, data, ORIGINAL_SR)

def preview_tensor(x, cols=10):
    """Return a small preview of a tensor for logging."""
    return x[..., :cols]

In [None]:
# Run on full bands input
load_model()
MODELS
cmfb = bool(int('0'))
subchannel = 'subchannel' if cmfb else 'not_subchannel' # set input types 
input = load_input_audio(subchannel)
logging.info(f'Loaded input: Shape={input.shape}, Type: {input.dtype}\n{preview_tensor(input).detach().cpu()}')
input.shape
process_input(input)                      # process input
PROCESSED_INPUTS
for model_name, values in tqdm(MODELS.items(), desc='Starting RNN process'):
    for sr, model in values.items():
        # Run through RNN
        output = run_model(model, PROCESSED_INPUTS[sr]) # run model
        processed_output = process_output(sr, output)               # process output  
        
        # Write output files
        parent = OUTPUT_W_CMFB_DIR if cmfb else OUTPUT_WO_CMFB_DIR 
        out_dir = os.path.join(parent, model_name, 'sr'+sr)              # set directory name
        write_outputs(cmfb, out_dir, processed_output)              # write file

        # Write log file
        logging.info(f'Running method={model}, sr={sr}')
        logging.info(f'Processed input: Shape={PROCESSED_INPUTS[sr].shape}, Type: {PROCESSED_INPUTS[sr].dtype}\n{preview_tensor(PROCESSED_INPUTS[sr]).detach().cpu()}')
        logging.info(f'Preprocess Output: Shape={output.shape}, Type={output.dtype}\n{preview_tensor(output).detach().cpu()}')
        logging.info(f'Processed Output: Shape={processed_output.shape}, Type={processed_output.dtype}\n{preview_tensor(processed_output).detach().cpu()}')

In [None]:
# run on sub band input alpha 0.7
cmfb = bool(int('1'))
subchannel = 'subchannel' if cmfb else 'not_subchannel' # set input types 
input = load_input_audio(subchannel)
logging.info(f'Loaded input: Shape={input.shape}, Type: {input.dtype}\n{preview_tensor(input).detach().cpu()}')
print('input shape: ', input.shape)
process_input(input)                      # process input
print('processed inputs: ', PROCESSED_INPUTS)
for model_name, values in tqdm(MODELS.items(), desc='Starting RNN process'):
    for sr, model in values.items():
        # Run through RNN
        output = run_model(model, PROCESSED_INPUTS[sr]) # run model
        processed_output = process_output(sr, output)               # process output  
        print('In RNN Loop output: ',output.shape)
        print('In RNN Loop processed output: ',processed_output.shape)
        
        # Write output files
        parent = OUTPUT_W_CMFB_DIR if cmfb else OUTPUT_WO_CMFB_DIR 
        out_dir = os.path.join(parent, model_name, 'sr'+sr)              # set directory name
        write_outputs(cmfb, out_dir, processed_output)              # write file

        # Write log file
        logging.info(f'Running method={model}, sr={sr}')
        logging.info(f'Processed input: Shape={PROCESSED_INPUTS[sr].shape}, Type: {PROCESSED_INPUTS[sr].dtype}\n{preview_tensor(PROCESSED_INPUTS[sr]).detach().cpu()}')
        logging.info(f'Preprocess Output: Shape={output.shape}, Type={output.dtype}\n{preview_tensor(output).detach().cpu()}')
        logging.info(f'Processed Output: Shape={processed_output.shape}, Type={processed_output.dtype}\n{preview_tensor(processed_output).detach().cpu()}')

Loading each channel audio: 100%|██████████| 8/8 [00:00<00:00, 304.35it/s, Channel=8, Read=../inputs/Subchannel_Audio_Example2/Channel_8.wav]


input shape:  torch.Size([1, 8, 264600])
processed inputs:  {'44.1': tensor([[[-3.0518e-05, -3.0518e-05, -3.0518e-05,  ...,  1.0071e-02,
           1.2543e-02,  1.4984e-02],
         [-3.0518e-05, -3.0518e-05, -3.0518e-05,  ..., -9.0332e-03,
          -8.1177e-03, -7.1716e-03],
         [ 0.0000e+00,  0.0000e+00, -3.0518e-05,  ...,  6.4087e-04,
           7.6294e-04,  8.2397e-04],
         ...,
         [ 0.0000e+00,  0.0000e+00, -3.0518e-05,  ..., -9.1553e-05,
          -1.2207e-04, -1.5259e-04],
         [-3.0518e-05, -3.0518e-05,  0.0000e+00,  ..., -3.0518e-05,
          -6.1035e-05, -3.0518e-05],
         [-3.0518e-05, -3.0518e-05, -3.0518e-05,  ..., -3.0518e-05,
          -3.0518e-05,  0.0000e+00]]]), '48': tensor([[[-3.0503e-05, -4.4055e-04,  4.7646e-04,  ...,  9.9068e-03,
           1.3814e-02,  1.4146e-02],
         [-3.0517e-05,  1.4541e-04, -2.4803e-04,  ..., -8.5048e-03,
          -8.3096e-03, -6.6973e-03],
         [ 4.3251e-11, -1.9783e-05, -8.4656e-08,  ...,  6.3001e-04,


Running through each channel audio: 100%|██████████| 8/8 [00:18<00:00,  2.31s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 264600])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 244.13it/s, Writing=../outputs/with_cmfb/original/sr44.1/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [00:20<00:00,  2.51s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 288000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 259.79it/s, Writing=../outputs/with_cmfb/original/sr48/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [00:37<00:00,  4.68s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 529200])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 278.28it/s, Writing=../outputs/with_cmfb/original/sr88.2/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [00:40<00:00,  5.12s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 576000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 307.54it/s, Writing=../outputs/with_cmfb/original/sr96/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [00:56<00:00,  7.02s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 264600])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 318.28it/s, Writing=../outputs/with_cmfb/stn/sr44.1/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:01<00:00,  7.73s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 288000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 260.98it/s, Writing=../outputs/with_cmfb/stn/sr48/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:56<00:00, 14.53s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 529200])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 297.07it/s, Writing=../outputs/with_cmfb/stn/sr88.2/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [02:07<00:00, 15.90s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 576000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 302.86it/s, Writing=../outputs/with_cmfb/stn/sr96/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:05<00:00,  8.14s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 264600])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 329.34it/s, Writing=../outputs/with_cmfb/lidl/sr44.1/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:10<00:00,  8.85s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 288000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 308.77it/s, Writing=../outputs/with_cmfb/lidl/sr48/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [02:13<00:00, 16.74s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 529200])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 369.89it/s, Writing=../outputs/with_cmfb/lidl/sr88.2/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [02:21<00:00, 17.67s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 576000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 311.73it/s, Writing=../outputs/with_cmfb/lidl/sr96/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:06<00:00,  8.36s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 264600])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 316.83it/s, Writing=../outputs/with_cmfb/apdl/sr44.1/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:09<00:00,  8.65s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 288000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 242.95it/s, Writing=../outputs/with_cmfb/apdl/sr48/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [02:08<00:00, 16.08s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 529200])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 266.61it/s, Writing=../outputs/with_cmfb/apdl/sr88.2/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [02:17<00:00, 17.21s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 576000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 325.09it/s, Writing=../outputs/with_cmfb/apdl/sr96/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:09<00:00,  8.70s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 264600])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 239.34it/s, Writing=../outputs/with_cmfb/cidl/sr44.1/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [01:15<00:00,  9.47s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 288000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 272.56it/s, Writing=../outputs/with_cmfb/cidl/sr48/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [02:18<00:00, 17.37s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 529200])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 325.56it/s, Writing=../outputs/with_cmfb/cidl/sr88.2/Channel_8_RNN.wav]
Running through each channel audio: 100%|██████████| 8/8 [02:29<00:00, 18.70s/it, Channel=8]


In RNN Loop output:  torch.Size([8, 576000])
In RNN Loop processed output:  torch.Size([8, 264600])


Writing output files: 100%|██████████| 8/8 [00:00<00:00, 365.15it/s, Writing=../outputs/with_cmfb/cidl/sr96/Channel_8_RNN.wav]
Starting RNN process: 100%|██████████| 5/5 [28:47<00:00, 345.46s/it]


In [32]:
# Run on sub bands input alpha 0.5
load_model()
cmfb = bool(int('1'))
subchannel = 'subchannel' if cmfb else 'not_subchannel' # set input types 
input = load_input_audio(subchannel)
logging.info(f'Loaded input: Shape={input.shape}, Type: {input.dtype}\n{preview_tensor(input).detach().cpu()}')
process_input(input)                      # process input
print(input.shape, PROCESSED_INPUTS)

for model_name, values in tqdm(MODELS.items(), desc='Starting RNN process'):
    for sr, model in values.items():
        # Run through RNN
        output = run_model(model, PROCESSED_INPUTS[sr]) # run model
        processed_output = process_output(sr, output)               # process output  
        
        # Write output files
        parent = OUTPUT_W_CMFB_DIR if cmfb else OUTPUT_WO_CMFB_DIR 
        out_dir = os.path.join(parent, model_name, 'sr'+sr)              # set directory name
        write_outputs(cmfb, out_dir, processed_output)              # write file

        # Write log file
        logging.info(f'Running method={model}, sr={sr}')
        logging.info(f'Processed input: Shape={PROCESSED_INPUTS[sr].shape}, Type: {PROCESSED_INPUTS[sr].dtype}\n{preview_tensor(PROCESSED_INPUTS[sr]).detach().cpu()}')
        logging.info(f'Preprocess Output: Shape={output.shape}, Type={output.dtype}\n{preview_tensor(output).detach().cpu()}')
        logging.info(f'Processed Output: Shape={processed_output.shape}, Type={processed_output.dtype}\n{preview_tensor(processed_output).detach().cpu()}')

torch.Size([1, 8, 264600]) {'44.1': tensor([[[-3.0518e-05, -3.0518e-05, -3.0518e-05,  ...,  2.1637e-02,
           2.1606e-02,  2.1576e-02],
         [ 0.0000e+00, -3.0518e-05, -3.0518e-05,  ...,  0.0000e+00,
           3.0518e-05,  3.0518e-05],
         [-3.0518e-05, -3.0518e-05, -3.0518e-05,  ..., -2.1362e-04,
          -1.8311e-04, -1.5259e-04],
         ...,
         [-3.0518e-05, -3.0518e-05, -3.0518e-05,  ..., -3.0518e-05,
          -3.0518e-05, -3.0518e-05],
         [-3.0518e-05,  0.0000e+00, -3.0518e-05,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-3.0518e-05,  0.0000e+00,  0.0000e+00,  ..., -6.1035e-05,
           3.0518e-05, -6.1035e-05]]]), '48': tensor([[[-3.0498e-05, -5.8923e-04,  6.6002e-04,  ...,  2.0620e-02,
           2.2737e-02,  2.0244e-02],
         [-2.1625e-11, -2.9145e-05, -3.1628e-05,  ...,  7.7295e-06,
           3.3230e-05,  2.8918e-05],
         [-3.0518e-05, -2.7042e-05, -3.6481e-05,  ..., -2.0486e-04,
          -1.8247e-04, -1.4437e-0

Running through each channel audio: 100%|██████████| 8/8 [00:20<00:00,  2.55s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [00:20<00:00,  2.62s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [00:41<00:00,  5.20s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [00:43<00:00,  5.41s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [00:58<00:00,  7.33s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [01:02<00:00,  7.83s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [01:55<00:00, 14.38s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [02:10<00:00, 16.33s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [01:08<00:00,  8.55s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [01:13<00:00,  9.22s/it, Channel=8]
Running through each channel audio: 100%|██████████| 8/8 [02:17<00:00,

In [49]:
def computeSNR(base_file, target_file):
    base, _ = sf.read(base_file)
    target, _ = sf.read(target_file)
    diff = target - base
    numer = np.square(base).sum()
    denom = np.square(diff).sum() 
    snr = numer/denom if denom != 0 else float('inf')
    snr_db = 10 * np.log10(snr)
    return snr_db

# function to write SNR file
# each row is the sample rate
# each column is the method - original, stn, lidl, apdl, cidl
def write_snr(base_file, target_dir, target_file_name, snr_file):
    global METHODS, SR
    snr_table = defaultdict(dict)
    for method in METHODS:
        for sr in SR:
            target_file = os.path.join(target_dir, method, 'sr' + sr, target_file_name)        
            snr = computeSNR(base_file, target_file) if target_file != base_file else None
            snr_table[sr][method] = snr

    with open(snr_file, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['SR'] + METHODS)
        for sr in SR:
            row = [sr]
            for method in METHODS:
                value = snr_table[sr][method]
                row.append('' if value is None else f'{value:.4f}')
            writer.writerow(row)

    print(f"SNR CSV successfully written to {snr_file}")

In [45]:
# 1. Full band - Compare full band files to full band original 44.1kHz
# 2. Alpha 0.7
    # a. Compare alpha 0.7 reconstructed files to full band original 44.1kHz
    # b. Compare alpha 0.7 reconstructed files to 0.7 reconstructed original 44.1kHz
# 3. Alpha 0.5
    # a. Compare alpha 0.5 reconstructed files to full band original 44.1kHz
    # b. Compare alpha 0.5 reconstructed files to 0.5 reconstructed original 44.1kHz

In [50]:
full_file_name = 'Example2_RNN.wav'
rcst_file_name = 'Example2_RNN_Reconstructed.wav'

# 1. Full band
full_dir = os.path.join('..', 'outputs_without_cmfb') # target dir is same as full dir
full_base_file = os.path.join(full_dir, 'original', 'sr44.1', full_file_name)
full_snr = os.path.join('..', 'outputs_csv', 'full_snr.csv')

# 2. Alpha 0.7
rcst_7_dir = os.path.join('..','outputs_alpha0.7', 'with_cmfb')
rcst_7_base_file = os.path.join(rcst_7_dir, 'original', 'sr44.1', rcst_file_name)
rcst_7_full_snr = os.path.join('..', 'outputs_csv', 'rcst_7_full_snr.csv')
rcst_7_rcst_snr = os.path.join('..', 'outputs_csv', 'rcst_7_rcst_snr.csv')

# 3. Alpha 0.5
rcst_5_dir = os.path.join('..','outputs_alpha0.5', 'with_cmfb')
rcst_5_base_file = os.path.join(rcst_5_dir, 'original', 'sr44.1', rcst_file_name)
rcst_5_full_snr = os.path.join('..', 'outputs_csv', 'rcst_5_full_snr.csv')
rcst_5_rcst_snr = os.path.join('..', 'outputs_csv', 'rcst_5_rcst_snr.csv')

csv_map = {
    'full' : (full_base_file, full_dir, full_snr),
    'rcst7_full': (full_base_file, rcst_7_dir, rcst_7_full_snr),
    'rcst7_rcst': (rcst_7_base_file, rcst_7_dir, rcst_7_rcst_snr),
    'rcst5_full': (full_base_file, rcst_5_dir, rcst_5_full_snr),
    'rcst5_rcst': (rcst_5_base_file, rcst_5_dir, rcst_5_rcst_snr)
}

os.makedirs('../outputs_csv', exist_ok=True)

for key, tup in csv_map.items():
    base_file = tup[0]
    target_dir = tup[1]
    snr_file = tup[2]
    target_file = full_file_name if key == 'full' else rcst_file_name
    write_snr(base_file, target_dir, target_file, snr_file)

SNR CSV successfully written to ../outputs_csv/full_snr.csv
SNR CSV successfully written to ../outputs_csv/rcst_7_full_snr.csv
SNR CSV successfully written to ../outputs_csv/rcst_7_rcst_snr.csv
SNR CSV successfully written to ../outputs_csv/rcst_5_full_snr.csv
SNR CSV successfully written to ../outputs_csv/rcst_5_rcst_snr.csv


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

os.makedirs(os.path.join('..', 'outputs_png'), exist_ok=True)
filename = ['full_snr', 'rcst_7_full_snr', 'rcst_7_rcst_snr', 'rcst_5_full_snr', 'rcst_5_rcst_snr']
columns = ['SR', 'Original', 'STN', 'LIDL', 'APDL', 'CIDL']

for file in filename:
    df = pd.read_csv(os.path.join('..','outputs_csv', f'{file}.csv'))
    df.fillna('-', inplace=True) # Replace Nan with '-'
    fig, ax = plt.subplots(figsize=(10, 2 + 0.4 * len(df)))
    ax.axis("off")  # Hide axes

    # Build table
    table = ax.table(
        cellText=df.values,
        colLabels=columns,
        cellLoc='center',
        loc='center'
    )

    table.scale(1, 1.4)  # Increase row height
    table.auto_set_font_size(False)
    table.set_fontsize(12)

    for key, cell in table.get_celld().items():
        row, col = key
        if row == 0:
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#dddddd')
    plt.savefig(os.path.join('..', 'outputs_png', f'{file}.png'), dpi=200, bbox_inches='tight')
    plt.close() 

  df.fillna('-', inplace=True)
  df.fillna('-', inplace=True)
  df.fillna('-', inplace=True)
