In [None]:
from audio_diffusion_pytorch.audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
import utils.load_datasets
import utils.training
import utils.logging
from networks import *
import networks.transforms as net_transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
from torchvision import transforms
import torch
from torch.optim import Adam
import matplotlib.pyplot as plt
from pathlib import Path
import os
# Define hyperparameters
# Modulation classes and SNR grid (dB) requested from dataset '2018.01A'.
# np.arange excludes its stop value, so train_SNRs is -20, -18, ..., 30.
train_modulations = ['OOK', '4ASK', '8ASK', 'BPSK', 'QPSK', '8PSK', '16PSK', '32PSK', '16APSK', '32APSK',
                        '64APSK', '128APSK', '16QAM', '32QAM', '64QAM', '128QAM', '256QAM', 'AM-SSB-WC',
                        'AM-SSB-SC', 'AM-DSB-WC', 'AM-DSB-SC', 'FM', 'GMSK', 'OQPSK']
train_SNRs = np.arange(-20, 32, 2)
# Earlier, reduced test class list kept for reference:
#test_modulations = ['BPSK', 'QPSK', '8PSK', '16QAM', '64QAM', '256QAM']
test_modulations = ['OOK', '4ASK', '8ASK', 'BPSK', 'QPSK', '8PSK', '16PSK', '32PSK', '16APSK', '32APSK',
                        '64APSK', '128APSK', '16QAM', '32QAM', '64QAM', '128QAM', '256QAM', 'AM-SSB-WC',
                      'AM-SSB-SC', 'AM-DSB-WC', 'AM-DSB-SC', 'FM', 'GMSK', 'OQPSK']
# NOTE(review): test_SNRs stops at 28 dB (the stop value 30 is excluded) while
# train_SNRs reaches 30 dB — confirm the narrower test grid is intentional.
test_SNRs = np.arange(-20, 30, 2)
dataset_train_name = '2018.01A'
dataset_test_name = '2018.01A'
# NOTE(review): dataDir is the same path as model_save_dir (a saved_models
# directory), yet it is passed to getDataset below — confirm the dataset is
# actually read from / cached in this directory.
dataDir = '/home/trey/experiment_rfdiffusion/models/saved_models/2018.01a'
batch_size = 4
learning_rate = 1e-4
adam_betas = (0.9, 0.999)
model_save_dir = '/home/trey/experiment_rfdiffusion/models/saved_models/2018.01a'
# Create directories if not exist
utils.training.create_directory(dataDir)

# Define data split ratios
# Three fractions summing to 1.0; presumably train / validation / test —
# confirm against utils.load_datasets.getDataset.
split = [0.75, 0.05, 0.20]

# Define data transformations
# The test split reuses the very same transform pipeline object.
train_transforms = transforms.Compose([net_transforms.PowerNormalization()])
test_transforms = train_transforms

# Load datasets
# Arguments are positional; their order must match getDataset's signature.
train_dataset= utils.load_datasets.getDataset(
    dataset_train_name, dataset_test_name, train_modulations, train_SNRs, test_modulations, test_SNRs, split, dataDir, train_transforms, test_transforms
)

# Create data loader
# NOTE(review): data_loader is not consumed by the sampling code in this cell.
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Set device
# First CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Initialize model
# Text-conditioned diffusion model (audio_diffusion_pytorch) with a UNetV0
# backbone, VDiffusion objective and VSampler; the hyperparameters must match
# the ones used when the checkpoint being restored was trained.
model = DiffusionModel(
    net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
    in_channels=2,  # U-Net: number of input/output channels (presumably I and Q here — confirm)
    channels=[64, 128, 256, 512],  # U-Net: channels at each layer
    factors=[2, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
    items=[2, 2, 2, 2],  # U-Net: number of repeating items at each layer
    attentions=[1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
    attention_heads=4,  # U-Net: number of attention heads per attention item
    attention_features=32,  # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion,  # The diffusion method used
    sampler_t=VSampler,  # The diffusion sampler used
    use_text_conditioning=True,  # U-Net: enables text conditioning (default T5-base)
    use_embedding_cfg=True,  # U-Net: enables classifier free guidance
    embedding_max_length=64,  # U-Net: text embedding maximum length (default for T5-base)
    embedding_features=768,  # U-Net: text embedding features (default for T5-base)
).to(device)

# Sampling grid: the (modulation, SNR) combinations to generate. By default this
# is the same grid as the training hyperparameters at the top of this cell.
train_modulations = [
    'OOK', '4ASK', '8ASK', 'BPSK', 'QPSK', '8PSK', '16PSK', '32PSK', '16APSK', '32APSK',
    '64APSK', '128APSK', '16QAM', '32QAM', '64QAM', '128QAM', '256QAM', 'AM-SSB-WC',
    'AM-SSB-SC', 'AM-DSB-WC', 'AM-DSB-SC', 'FM', 'GMSK', 'OQPSK',
]
train_SNRs = np.arange(-20, 32, 2)  # -20 .. 30 dB in 2 dB steps (stop value excluded)
num_samples_per_combination = 50  # Number of samples to generate per combination

# Define the path to the checkpoint file
checkpoint_path = '/home/trey/experiment_rfdiffusion/models/saved_models/2018.01a/model_epoch_82.pth'

# Load the checkpoint. map_location remaps stored tensors onto `device`, so a
# checkpoint saved from a GPU run can still be restored when `device` fell back
# to the CPU.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load the state dictionary into the model and switch to inference mode, which
# disables training-only behaviour (e.g. dropout) if the network has any.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Output location for generated samples and constellation plots.
dataDir = '/home/trey/experiment_rfdiffusion/models/test/'
utils.training.create_directory(dataDir)
results_folder = Path(dataDir) / 'results'
# create_directory is only called on dataDir; savefig/savez_compressed do not
# create missing directories, so make sure results/ exists as well.
results_folder.mkdir(parents=True, exist_ok=True)

# Restore the optimizer state too. Only needed to resume training; the
# sampling code does not use the optimizer.
optimizer = Adam(model.parameters(), lr=learning_rate, betas=adam_betas)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Training progress recorded in the checkpoint, reported for reference.
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print('sampling')
print(f"Checkpoint loaded. Model trained for {epoch} epochs. Last recorded loss: {loss}")
# Turn noise into new I/Q samples with diffusion. For every (modulation, SNR)
# pair, draw `num_samples_per_combination` samples from the text-conditioned
# model, then save a constellation plot and an .npz with samples and labels.
# For a quick single-combination run, override the grid first, e.g.
#   train_modulations = ['BPSK']; train_SNRs = np.arange(16, 18, 2)
results_folder.mkdir(parents=True, exist_ok=True)  # savefig/savez do not create directories
model.eval()  # inference mode for sampling
for modulation in train_modulations:
    for snr in train_SNRs:
        prompt = f"{modulation} modulated waveform at {snr} dB SNR"
        generated_samples = []
        print(prompt)
        # range(n) yields exactly num_samples_per_combination samples, matching
        # the length of the label arrays saved below.
        for _ in range(num_samples_per_combination):
            noise = torch.randn(1, 2, 128, device=device)
            with torch.no_grad():  # sampling only; no autograd graph is needed
                sample = model.sample(
                    noise,
                    text=prompt,
                    embedding_scale=5.0,  # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
                    num_steps=10,  # Higher for better quality, suggested num_steps: 10-100
                )
            generated_samples.append(sample.cpu().numpy())
        generated_samples = np.concatenate(generated_samples, axis=0)

        # Generate and save the constellation plot: channel 0 (I) against channel 1 (Q).
        fig, ax = plt.subplots()
        ax.scatter(generated_samples[:, 0, :].flatten(), generated_samples[:, 1, :].flatten(), marker='.')
        ax.set_title(f'Constellation Plot - {modulation}, SNR {snr} dB')
        ax.set_xlabel('I')
        ax.set_ylabel('Q')
        ax.grid(True)
        filename = results_folder / f'constellation_{modulation}_{snr}dB.png'
        fig.savefig(filename)
        # Close explicitly. Otherwise one figure per combination stays open,
        # matplotlib warns once more than 20 are open, and memory keeps growing.
        plt.close(fig)

        # Save the generated samples and labels. Labels are sized from the
        # samples actually produced, so all three arrays line up.
        n_generated = len(generated_samples)
        np.savez_compressed(
            results_folder / f'samples_{modulation}_{snr}dB.npz',
            samples=generated_samples,
            classes=np.array([modulation] * n_generated),
            snrs=np.array([snr] * n_generated),
        )
        print('saved')

Directory '/home/trey/experiment_rfdiffusion/models/saved_models/2018.01a' already exists.
Directory '/home/trey/experiment_rfdiffusion/models/test/' already exists.
sampling
Checkpoint loaded. Model trained for 82 epochs. Last recorded loss: 0.3500005602836609
OOK modulated waveform at -20 dB SNR
saved
OOK modulated waveform at -18 dB SNR
saved
OOK modulated waveform at -16 dB SNR
saved
OOK modulated waveform at -14 dB SNR
saved
OOK modulated waveform at -12 dB SNR
saved
OOK modulated waveform at -10 dB SNR
saved
OOK modulated waveform at -8 dB SNR
saved
OOK modulated waveform at -6 dB SNR
saved
OOK modulated waveform at -4 dB SNR
saved
OOK modulated waveform at -2 dB SNR
saved
OOK modulated waveform at 0 dB SNR
saved
OOK modulated waveform at 2 dB SNR
saved
OOK modulated waveform at 4 dB SNR
saved
OOK modulated waveform at 6 dB SNR
saved
OOK modulated waveform at 8 dB SNR
saved
OOK modulated waveform at 10 dB SNR
saved
OOK modulated waveform at 12 dB SNR
saved
OOK modulated waveform 

  plt.figure()


saved
OOK modulated waveform at 22 dB SNR
saved
OOK modulated waveform at 24 dB SNR
saved
OOK modulated waveform at 26 dB SNR
saved
OOK modulated waveform at 28 dB SNR
saved
OOK modulated waveform at 30 dB SNR
saved
4ASK modulated waveform at -20 dB SNR
saved
4ASK modulated waveform at -18 dB SNR
saved
4ASK modulated waveform at -16 dB SNR
saved
4ASK modulated waveform at -14 dB SNR
saved
4ASK modulated waveform at -12 dB SNR
saved
4ASK modulated waveform at -10 dB SNR
saved
4ASK modulated waveform at -8 dB SNR
saved
4ASK modulated waveform at -6 dB SNR
saved
4ASK modulated waveform at -4 dB SNR
saved
4ASK modulated waveform at -2 dB SNR
saved
4ASK modulated waveform at 0 dB SNR
saved
4ASK modulated waveform at 2 dB SNR
saved
4ASK modulated waveform at 4 dB SNR
saved
4ASK modulated waveform at 6 dB SNR
saved
4ASK modulated waveform at 8 dB SNR
saved
4ASK modulated waveform at 10 dB SNR
saved
4ASK modulated waveform at 12 dB SNR
saved
4ASK modulated waveform at 14 dB SNR
saved
4ASK modu

In [60]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from pathlib import Path
from scipy.stats import wasserstein_distance

# Class-name lookup table for the real dataset's integer labels (11 classes;
# presumably the RadioML 2016-style class set — confirm against the dataset).
train_modulations = [
    'AM-SSB', 'CPFSK', 'QPSK', 'GFSK', 'PAM4', 'QAM16',
    'WBFM', '8PSK', 'QAM64', 'AM-DSB', 'BPSK',
]

# SNR grid in dB: -20, -18, ..., 30 (np.arange excludes the stop value 32).
# NOTE(review): the real-sample loading code further down this cell rebinds
# train_SNRs to -20..18 before it is first used.
train_SNRs = np.arange(-20, 32, 2)

# Directory holding the generated samples_*.npz / constellation_*.png files.
results_folder = Path('/home/trey/experiment_rfdiffusion/models/test') / 'results'

def compute_wasserstein_distance(samples_generated, samples_real):
    """Return the mean column-wise 1-D Wasserstein distance between two sample sets.

    Each input is flattened to 2-D with ``reshape(-1, shape[-1])``: all leading
    axes are folded into rows and the last axis becomes the columns. For every
    column, the values across rows form an empirical 1-D distribution. The two
    distributions are compared with ``scipy.stats.wasserstein_distance`` and the
    per-column distances are averaged.

    NOTE(review): for I/Q batches shaped (N, 2, T), the I and Q rows are pooled
    into the same per-column distribution — confirm that is the intended metric.

    Args:
        samples_generated: NumPy array of generated samples.
        samples_real: NumPy array of real samples with the same last dimension.

    Returns:
        The mean distance, or ``np.nan`` in three cases: either array is empty,
        the last dimensions differ, or either array has fewer than two rows
        after flattening.
    """
    if samples_generated.size == 0 or samples_real.size == 0:
        print("One of the sample arrays is empty. Cannot compute Wasserstein distance.")
        return np.nan

    # Columns are compared pairwise, so both inputs must be equally wide.
    # Otherwise extra columns on one side would be silently ignored.
    if samples_generated.shape[-1] != samples_real.shape[-1]:
        print(
            f"Last-dimension mismatch ({samples_generated.shape[-1]} vs "
            f"{samples_real.shape[-1]}). Cannot compute Wasserstein distance."
        )
        return np.nan

    # Flatten samples for comparison
    samples_generated_flat = samples_generated.reshape(-1, samples_generated.shape[-1])
    samples_real_flat = samples_real.reshape(-1, samples_real.shape[-1])
    print("samples_generated_flat shape:", samples_generated_flat.shape)
    print("samples_real_flat shape:", samples_real_flat.shape)

    # Each column needs at least two observations to form a distribution.
    if samples_generated_flat.shape[0] < 2 or samples_real_flat.shape[0] < 2:
        print("Not enough data points to compute Wasserstein distance.")
        return np.nan

    distances = [
        wasserstein_distance(samples_generated_flat[:, i], samples_real_flat[:, i])
        for i in range(samples_generated_flat.shape[1])
    ]
    return np.mean(distances)

def compute_histogram_similarity(samples_generated, samples_real, bins=50):
    """Return the mean column-wise Bhattacharyya coefficient between two sample sets.

    Both inputs are flattened with ``reshape(-1, shape[-1])`` (leading axes
    become rows, the last axis becomes columns). For each column, the generated
    and real values are binned on a shared set of edges spanning both columns.
    The edges are then turned into probability mass functions, and the
    Bhattacharyya coefficient ``sum(sqrt(p * q))`` is computed. The result is
    1.0 for identical histograms and 0.0 for non-overlapping ones.

    Args:
        samples_generated: NumPy array of generated samples.
        samples_real: NumPy array of real samples with the same last dimension.
        bins: number of bins (or any ``bins`` spec accepted by
            ``np.histogram_bin_edges``) used per column.

    Returns:
        The mean coefficient over columns, or ``np.nan`` if either array is
        empty or the last dimensions differ.
    """
    if samples_generated.size == 0 or samples_real.size == 0:
        print("One of the sample arrays is empty. Cannot compute histogram similarity.")
        return np.nan

    if samples_generated.shape[-1] != samples_real.shape[-1]:
        print(
            f"Last-dimension mismatch ({samples_generated.shape[-1]} vs "
            f"{samples_real.shape[-1]}). Cannot compute histogram similarity."
        )
        return np.nan

    # Flatten samples for histogram comparison
    samples_generated_flat = samples_generated.reshape(-1, samples_generated.shape[-1])
    samples_real_flat = samples_real.reshape(-1, samples_real.shape[-1])

    similarities = []
    for col_gen, col_real in zip(samples_generated_flat.T, samples_real_flat.T):
        # Shared edges: with independent np.histogram calls, bin k of each
        # histogram would cover a different value range, and the coefficient
        # would compare unrelated probability masses.
        edges = np.histogram_bin_edges(np.concatenate([col_gen, col_real]), bins=bins)
        hist_gen, _ = np.histogram(col_gen, bins=edges)
        hist_real, _ = np.histogram(col_real, bins=edges)

        # Counts -> probability mass functions. The edges span every value,
        # so each sum is the (non-zero) column length.
        p_gen = hist_gen / hist_gen.sum()
        p_real = hist_real / hist_real.sum()

        # Compute Bhattacharyya coefficient
        similarities.append(np.sum(np.sqrt(p_gen * p_real)))
    return np.mean(similarities)

# Load real samples and group them by (modulation name, SNR value).
real_samples_dict = {}
real_data_loader = DataLoader(train_dataset, batch_size=5, shuffle=False)
# Lookup tables that translate the dataset's integer labels into names / dB values.
# NOTE(review): these are 11 class names and a -20..18 dB grid, while the
# sampling cell bound `train_dataset` to dataset '2018.01A' with 24 classes and
# -20..30 dB. If train_dataset is still that dataset, labels >= 11 are skipped
# below and the remaining names would not match — confirm which dataset is bound.
train_modulations = ['AM-SSB', 'CPFSK', 'QPSK', 'GFSK', 'PAM4', 'QAM16', 'WBFM', '8PSK', 'QAM64', 'AM-DSB', 'BPSK']
train_SNRs = np.arange(-20,19,2)
n_modulations = len(train_modulations)
n_snrs = len(train_SNRs)

for real_x, real_classes_batch, real_snrs_batch in real_data_loader:
    # Convert the batch tensors to NumPy arrays on the host.
    class_labels = real_classes_batch.cpu().numpy()
    snr_labels = real_snrs_batch.cpu().numpy()
    waveforms = real_x.cpu().numpy()

    for raw_cls, raw_snr, sample in zip(class_labels, snr_labels, waveforms):
        cls_idx = int(raw_cls)
        snr_idx = int(raw_snr)

        # Labels are treated as indices into the lookup tables; skip any that fall outside.
        if not 0 <= cls_idx < n_modulations:
            print(f"Warning: Modulation index {cls_idx} is out of range.")
            continue
        if not 0 <= snr_idx < n_snrs:
            print(f"Warning: SNR index {snr_idx} is out of range.")
            continue

        key = (train_modulations[cls_idx], train_SNRs[snr_idx])
        real_samples_dict.setdefault(key, []).append(sample)

# Stack every group's samples along axis 0 into a single array per key.
for key, chunks in real_samples_dict.items():
    stacked = np.concatenate(chunks, axis=0)
    real_samples_dict[key] = stacked
    print(f"Final size for key {key}: {stacked.shape}")

# Compare the saved generated samples with the real samples for every
# (modulation, SNR) combination in the lookup tables.
for modulation in train_modulations:
    for snr in train_SNRs:
        samples_path = results_folder / f'samples_{modulation}_{snr}dB.npz'
        try:
            # An NpzFile keeps its file handle open until closed. The context
            # manager releases it once the three arrays have been read into memory.
            with np.load(samples_path) as data:
                generated_samples = data['samples']
                generated_classes = data['classes']
                generated_snrs = data['snrs']
        except FileNotFoundError:
            print(f"File not found for {modulation} at SNR {snr} dB. Skipping...")
            continue

        # Keep only samples whose stored labels match this combination.
        # Truncate zip-style to the shortest array first, in case the label
        # and sample counts in the file differ.
        n_rows = min(len(generated_samples), len(generated_classes), len(generated_snrs))
        keep = (generated_classes[:n_rows] == modulation) & (generated_snrs[:n_rows] == snr)
        filtered_generated_samples = generated_samples[:n_rows][keep]

        real_samples = real_samples_dict.get((modulation, snr))
        if real_samples is None:
            print(f"No real samples available for {modulation} at SNR {snr} dB.")
            continue

        # Compute Wasserstein Distance
        wasserstein_dist = compute_wasserstein_distance(filtered_generated_samples, real_samples)
        print(f'Wasserstein Distance for {modulation} at SNR {snr} dB: {wasserstein_dist}')

        # Compute Histogram Similarity
        histogram_sim = compute_histogram_similarity(filtered_generated_samples, real_samples)
        print(f'Histogram Similarity for {modulation} at SNR {snr} dB: {histogram_sim}')

Final size for key ('GFSK', -20): (2000, 128)
Final size for key ('GFSK', -18): (2000, 128)
Final size for key ('GFSK', -16): (2000, 128)
Final size for key ('GFSK', -14): (2000, 128)
Final size for key ('GFSK', -12): (2000, 128)
Final size for key ('GFSK', -10): (2000, 128)
Final size for key ('GFSK', -8): (2000, 128)
Final size for key ('GFSK', -6): (2000, 128)
Final size for key ('GFSK', -4): (2000, 128)
Final size for key ('GFSK', -2): (2000, 128)
Final size for key ('GFSK', 0): (2000, 128)
Final size for key ('GFSK', 2): (2000, 128)
Final size for key ('GFSK', 4): (2000, 128)
Final size for key ('GFSK', 6): (2000, 128)
Final size for key ('GFSK', 8): (2000, 128)
Final size for key ('GFSK', 10): (2000, 128)
Final size for key ('GFSK', 12): (2000, 128)
Final size for key ('GFSK', 14): (2000, 128)
Final size for key ('GFSK', 16): (2000, 128)
Final size for key ('GFSK', 18): (2000, 128)
Final size for key ('QAM16', -20): (2000, 128)
Final size for key ('QAM16', -18): (2000, 128)
Final