In [2]:
import os
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from pathlib import Path
import matplotlib.pyplot as plt
from torchvision import transforms
from audio_diffusion_pytorch.audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion_Inpainted, VSampler, VInpainter, RePaintInpainter
import utils.load_datasets
import utils.training
import networks.transforms as net_transforms
import torch.nn as nn

# Hyperparameters
train_modulations = ['AM-SSB', 'CPFSK', 'QPSK', 'GFSK', 'PAM4', 'QAM16', 'WBFM', '8PSK', 'QAM64', 'AM-DSB', 'BPSK']
train_SNRs = np.arange(-20, 19, 2)
test_modulations = ['OOK', '4ASK', '8ASK', 'BPSK', 'QPSK', '8PSK', '16PSK', '32PSK', '16APSK', '32APSK', '64APSK', 
                    '128APSK', '16QAM', '32QAM', '64QAM', '128QAM', '256QAM', 'AM-SSB-WC', 'AM-SSB-SC', 'AM-DSB-WC', 
                    'AM-DSB-SC', 'FM', 'GMSK', 'OQPSK']
test_SNRs = np.arange(-20, 30, 2)
dataset_train_name = '2016.10A'
dataset_test_name = '2016.10A'
dataDir = '/home/trey/experiment_rfdiffusion/models/saved_models/impainting'
batch_size = 4
learning_rate = 1e-4
adam_betas = (0.9, 0.999)
model_save_dir = '/home/trey/experiment_rfdiffusion/models/saved_models/impainting'

# Create directories if they do not exist
utils.training.create_directory(dataDir)

# Define data split ratios
split = [0.75, 0.05, 0.20]

# Define data transformations
train_transforms = transforms.Compose([net_transforms.PowerNormalization()])
test_transforms = train_transforms

# Load datasets
train_dataset = utils.load_datasets.getDataset(
    dataset_train_name, dataset_test_name, train_modulations, train_SNRs, test_modulations, test_SNRs, split, dataDir, train_transforms, test_transforms
)

# Create data loader
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Set device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = DiffusionModel(
    net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
    in_channels=2,  # U-Net: number of input/output (audio) channels
    channels=[64, 128, 256, 512],  # U-Net: channels at each layer
    factors=[2, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
    items=[2, 2, 2, 2],  # U-Net: number of repeating items at each layer
    attentions=[1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
    attention_heads=4,  # U-Net: number of attention heads per attention item
    attention_features=32,  # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion_Inpainted,  # The diffusion method used
    use_text_conditioning=False,  # U-Net: enables text conditioning (default T5-base)
    use_embedding_cfg=False,  # U-Net: enables classifier free guidance
)

# Define the path to the checkpoint file
checkpoint_path = os.path.join(model_save_dir, 'model_epoch_49.pth')

# Load the checkpoint
checkpoint = torch.load(checkpoint_path)

# Load the state dictionary into the model
model.load_state_dict(checkpoint['model_state_dict'])

# Initialize the optimizer
optimizer = Adam(model.parameters(), lr=learning_rate, betas=adam_betas)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Optionally, load the epoch and loss
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print('Checkpoint loaded. Model trained for {} epochs. Last recorded loss: {:.4f}'.format(epoch, loss))
net = UNetV0(
    dim = 1,
    in_channels=2,  # U-Net: number of input/output (audio) channels
    channels=[64, 128, 256, 512],  # U-Net: channels at each layer
    factors=[2, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
    items=[2, 2, 2, 2],  # U-Net: number of repeating items at each layer
    attentions=[1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
    attention_heads=4,  # U-Net: number of attention heads per attention item
    attention_features=32,  # U-Net: number of attention features per attention item    
).to(device)
# Initialize inpainter with the trained model
inpainter = RePaintInpainter(net=net)

# Create results directory
results_folder = Path(dataDir) / 'results'
results_folder.mkdir(parents=True, exist_ok=True)

# Function to plot and save waveforms
def plot_waveforms(original, masked, generated, modulation, snr, index):
    plt.figure(figsize=(12, 6))
    
    # Plot original waveform
    plt.subplot(3, 1, 1)
    plt.plot(original[0, 0, :], label='I')
    plt.plot(original[0, 1, :], label='Q')
    plt.title(f'Original Waveform - {modulation}, SNR {snr} dB')
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.legend()
    plt.grid(True)

    # Plot masked waveform
    plt.subplot(3, 1, 2)
    plt.plot(masked[0, 0, :], label='I')
    plt.plot(masked[0, 1, :], label='Q')
    plt.title(f'Masked Waveform - {modulation}, SNR {snr} dB')
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.legend()
    plt.grid(True)

    # Plot generated waveform
    plt.subplot(3, 1, 3)
    plt.plot(generated[0, 0, :], label='I')
    plt.plot(generated[0, 1, :], label='Q')
    plt.title(f'Generated Waveform - {modulation}, SNR {snr} dB')
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(results_folder / f'waveforms_{modulation}_{snr}dB_{index}.png')
    plt.close()

def generate_random_mask(signal_length, hole_ratio, min_hole_size=16, max_hole_size=32):
    """
    Generate a random binary mask with specified hole ratio for the IQ signal.
    
    Args:
        signal_length (int): Length of the signal.
        hole_ratio (float): Ratio of the signal that should be masked (between 0 and 1).
        min_hole_size (int): Minimum size of the holes.
        max_hole_size (int): Maximum size of the holes.
    
    Returns:
        torch.Tensor: Boolean mask with True for present and False for missing parts.
    """
    num_samples_to_mask = int(signal_length * hole_ratio)
    
    # Initialize mask with all True
    mask = torch.ones(signal_length, dtype=torch.bool)
    
    # Determine the maximum size of a hole
    max_hole_size = min(max_hole_size, signal_length)  # Ensure it fits within signal length
    
    # Randomly place holes
    while num_samples_to_mask > 0:
        # Randomly choose hole size
        hole_size = np.random.randint(min_hole_size, max_hole_size + 1)
        
        # Ensure hole size does not exceed remaining number of samples to mask
        hole_size = min(hole_size, num_samples_to_mask)
        
        # Randomly choose a position for the hole
        start_index = np.random.randint(0, signal_length - hole_size + 1)
        end_index = start_index + hole_size
        
        # Apply the hole to the mask
        mask[start_index:end_index] = False
        
        num_samples_to_mask -= hole_size
    
    return mask
    
# Imprinting process
for i, data in enumerate(data_loader):
    waveforms, labels, snrs = data
    waveforms = waveforms.to(device)
    
    for j in range(waveforms.size(0)):
        original_waveform = waveforms[j].unsqueeze(0)
        modulation = train_modulations[labels[j].item()]
        snr = train_SNRs[snrs[j].item()]

        # Create a random mask using the new function
        mask = generate_random_mask(original_waveform.shape[-1], hole_ratio=0.2, min_hole_size=16, max_hole_size=32)
        mask = mask.to(device)
        masked_waveform = original_waveform * mask
        # Inpaint the masked waveform
        generated_waveform = inpainter(
            source=original_waveform,
            mask=mask,
            num_steps=10,  # Number of inpainting steps
            num_resamples=20,  # Number of resampling steps
            show_progress=True,
        )

        # Plot and save the waveforms
        plot_waveforms(original_waveform.cpu().numpy(), masked_waveform.cpu().numpy(), generated_waveform.cpu().detach().numpy(), modulation, snr, j)
        print(f'Waveform {j} processed and saved.')

    if i == 1:  # Process only the first batch for demonstration purposes
        break


Directory '/home/trey/experiment_rfdiffusion/models/saved_models/impainting' already exists.
Checkpoint loaded. Model trained for 49 epochs. Last recorded loss: 0.0387


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.78it/s]


Waveform 0 processed and saved.


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.70it/s]


Waveform 1 processed and saved.


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.88it/s]


Waveform 2 processed and saved.


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.89it/s]


Waveform 3 processed and saved.


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.74it/s]


Waveform 0 processed and saved.


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.80it/s]


Waveform 1 processed and saved.


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.68it/s]


Waveform 2 processed and saved.


Inpainting (noise=0.00): 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.81it/s]


Waveform 3 processed and saved.
