In [37]:
import torch
import torch.nn.functional as F

def add_noise_with_snr(encoder_output, noise_type='gaussian', target_snr_db=3, dropout_rate=0.4, sp_thresh=0.4):
    """
    Corrupt an encoder output tensor with the chosen kind of noise.

    Parameters:
    - encoder_output: torch.Tensor, the encoder's output (e.g. last_hidden_state).
      NOTE(review): assumed to be a floating-point tensor (torch.mean / randn_like need one).
    - noise_type: str, case-insensitive; one of 'gaussian', 'dropout', 'saltpepper'.
    - target_snr_db: float, desired signal-to-noise ratio in dB. Used by the
      'gaussian' and 'dropout' noise types.
    - dropout_rate: float in [0, 1]. Currently unused: the dropout probability is
      derived from target_snr_db instead. Kept for backward compatibility.
    - sp_thresh: float in [0, 1], probability that an element is replaced by
      salt (tensor max) or pepper (tensor min) for 'saltpepper' noise. The greater
      the value, the more noise is added.

    Returns:
    - torch.Tensor of the same shape as encoder_output, with noise applied.

    Raises:
    - ValueError: if noise_type is not one of the supported types.
    """
    kind = noise_type.lower()
    if kind not in ('gaussian', 'dropout', 'saltpepper'):
        raise ValueError("Unsupported Noise Type. Choose between 'gaussian', 'dropout', 'saltpepper'.")

    # Nothing to corrupt. This also avoids torch.max/min raising on an empty tensor.
    if encoder_output.numel() == 0:
        return encoder_output.clone()

    # SNR = P_signal / P_noise, so the linear ratio is 10^(SNR_dB / 10).
    target_snr_linear = 10 ** (target_snr_db / 10)

    if kind == 'gaussian':
        # Additive white Gaussian noise scaled so that P_noise = P_signal / SNR.
        signal_power = torch.mean(encoder_output ** 2)
        noise_power = signal_power / target_snr_linear
        noise = torch.randn_like(encoder_output) * torch.sqrt(noise_power)
        return encoder_output + noise

    if kind == 'dropout':
        # Zeroing each element independently with probability p yields, in
        # expectation, a noise power of p * P_signal, so SNR = 1 / p.
        p = 1 / target_snr_linear
        keep_mask = torch.rand_like(encoder_output) >= p  # keep with probability (1 - p)
        # Unlike nn.Dropout, survivors are NOT rescaled by 1 / (1 - p).
        return encoder_output * keep_mask.float()

    # Salt-and-pepper: replace a random subset of elements with the tensor's
    # global max (salt) or min (pepper), chosen 50/50 per element.
    replace_mask = torch.rand_like(encoder_output) < sp_thresh
    salt = torch.max(encoder_output)
    pepper = torch.min(encoder_output)
    noise = torch.where(torch.rand_like(encoder_output) < 0.5, salt, pepper)
    return torch.where(replace_mask, noise, encoder_output)


In [39]:
import torch

# Fix the RNG so the printed SNR below is reproducible under Restart & Run All.
torch.manual_seed(0)

# A large tensor keeps the empirical SNR close to its expected value.
encoder_output = torch.randn(1000, 1000)

# Target SNR in dB
target_snr_db = 10

# Apply SNR-calibrated dropout noise.
noisy_encoder_output = add_noise_with_snr(encoder_output, noise_type='dropout', target_snr_db=target_snr_db)

# Empirical SNR: signal power over the power of the perturbation the noise introduced.
signal_power = torch.mean(encoder_output ** 2)
noise = encoder_output - noisy_encoder_output
noise_power = torch.mean(noise ** 2)
actual_snr = signal_power / noise_power
actual_snr_db = 10 * torch.log10(actual_snr)

# Report the results.
snr_error_db = abs(target_snr_db - actual_snr_db.item())
print(f"Target SNR (dB): {target_snr_db}")
print(f"Actual SNR (dB): {actual_snr_db.item():.2f}")
print(f"Difference (dB): {snr_error_db:.2f}")

# Sanity check: with 1e6 elements the estimate should be well within 0.5 dB.
assert snr_error_db < 0.5, f"dropout SNR off by {snr_error_db:.2f} dB"

Target SNR (dB): 10
Actual SNR (dB): 10.00
Difference (dB): 0.00
