# Research notebook 3: Pruning a conditional audio generative diffusion model
The diffusion model used in this notebook takes inspiration from an assignment for week 11 of the 2023 Deep Learning course (NWI-IMC070) of the Radboud University. Which used code adapted from: https://github.com/milesial/Pytorch-UNet for th U-Net.

In [1]:
import torchaudio
import torchvision
from d2l import torch as d2l
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import os
import joblib
from torch import optim
from copy import deepcopy

In [28]:
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
print(device)
print(str(torchaudio.list_audio_backends()))

cpu
['soundfile']


In [None]:
#Settings
diffusion_steps = 1000
beta = torch.linspace(1e-4, 0.02, diffusion_steps)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
beta = beta.to(device)
alpha_bar = alpha_bar.to(device)
batch_size = 1
samplerate = 16000
new_samplerate = 3000
n_fft=100 #400 was default
win_length = n_fft #Default: n_fft
hop_length = win_length // 2 #Default: win_length // 2
poison_rate = 0.1
num_epochs = 10
#Filenames
poison_filename = "thesis-diffusion-poison-model-pr0.5-ps0.1"
label_filename = "label_encoder.pkl"
#Datalocations
datalocation = "/vol/csedu-nobackup/project/mnederlands/data"
modellocation = "./saves/"
os.makedirs(modellocation, exist_ok=True)
os.makedirs(datalocation, exist_ok=True)

### Audio data

Load the data

In [4]:
#Initialization of label encoder
le = joblib.load(modellocation + label_filename)
num_classes = len(le.classes_)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [5]:
"""
resize_h = 51
resize_w = 61

rates = [0.1, 0.2, 0.3]
"""

'\nresize_h = 51\nresize_w = 61\n\nrates = [0.1, 0.2, 0.3]\n'

In [29]:
# Load the data
speech_commands_data = torchaudio.datasets.SPEECHCOMMANDS(root=datalocation, download=True)
train_size = int(0.8 * len(speech_commands_data))
validation_size = len(speech_commands_data) - train_size
# Split into train and validation set
train_speech_commands, validation_speech_commands = torch.utils.data.random_split(speech_commands_data, [train_size, validation_size])
# Function to pad waveforms to a specific length
def pad_waveform(waveform, target_length):
    current_length = waveform.shape[1]
    if current_length < target_length:
        padded_waveform = F.pad(waveform, (0, target_length - current_length), mode='constant', value=0)
        return padded_waveform
    else:
        return waveform

# Define a transform to convert waveform to spectrogram
transform = torchvision.transforms.Compose([
    torchaudio.transforms.Resample(orig_freq=samplerate, new_freq=new_samplerate),
    torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=win_length),
])

In [30]:
labels = np.ravel([row[2:3] for row in train_speech_commands])

In [31]:
# Pad waveforms in train set and apply transform
train_speech_commands_padded = []
for waveform, sample_rate, label, _, _ in train_speech_commands:
    padded_waveform = pad_waveform(waveform, samplerate)
    spectrogram = transform(padded_waveform)
    train_speech_commands_padded.append([spectrogram, le.transform([label])[0]])
# Pad waveforms in validation set and apply transform
validation_speech_commands_padded = []
for waveform, sample_rate, label, _, _ in validation_speech_commands:
    padded_waveform = pad_waveform(waveform, samplerate)
    spectrogram = transform(padded_waveform)
    validation_speech_commands_padded.append([spectrogram, le.transform([label])[0]])
resize_h, resize_w = spectrogram[0].shape
# Create data loaders
train_loader = torch.utils.data.DataLoader(train_speech_commands_padded, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_speech_commands_padded, batch_size=1000)

In [32]:
# Parameter settings from paper Denoising Diffusion Probabilistic Models
def generate_noisy_samples(x_0, beta):
    '''
    Create noisy samples for the minibatch x_0.
    Return the noisy image, the noise, and the time for each sample.
    '''
    x_0 = x_0.to(device)  # Ensure the input tensor is on GPU
    beta = beta.to(device)  # Ensure beta is on GPU
    alpha = 1.0 - beta
    alpha_bar = torch.cumprod(alpha, dim=0).to(device)
    # sample a random time t for each sample in the minibatch
    t = torch.randint(beta.shape[0], size=(x_0.shape[0],), device=x_0.device)
    # Generate noise
    noise = torch.randn_like(x_0).to(device)
    # Add the noise to each sample
    x_t = torch.sqrt(alpha_bar[t, None, None, None]) * x_0 + \
          torch.sqrt(1 - alpha_bar[t, None, None, None]) * noise
    return x_t, noise, t

In [33]:
class SelfAttention(nn.Module):
    def __init__(self, h_size):
        super(SelfAttention, self).__init__()
        self.h_size = h_size
        self.mha = nn.MultiheadAttention(h_size, 4, batch_first=True)
        self.ln = nn.LayerNorm([h_size])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([h_size]),
            nn.Linear(h_size, h_size),
            nn.GELU(),
            nn.Linear(h_size, h_size),
        )
    def forward(self, x):
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value
class SAWrapper(nn.Module):
    def __init__(self, h_size, num_s):
        super(SAWrapper, self).__init__()
        self.sa = nn.Sequential(*[SelfAttention(h_size) for _ in range(1)])
        self.num_s = num_s
        self.h_size = h_size
    def forward(self, x):
        x = x.view(-1, self.h_size, self.num_s[0] * self.num_s[1]).swapaxes(1, 2)
        x = self.sa(x)
        x = x.swapaxes(2, 1).view(-1, self.h_size, self.num_s[0], self.num_s[1])
        return x
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )
    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, in_channels, residual=True)
            self.conv2 = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        x = self.conv2(x)
        return x
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        return self.conv(x)
class UNetConditional(nn.Module):
    def __init__(self, device, c_in=1, c_out=1, n_classes=num_classes):
        super().__init__()
        self.device = device
        bilinear = True
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.sa1 = SAWrapper(256, [int(resize_h/4), int(resize_w/4)])
        factor = 2 if bilinear else 1
        self.down3 = Down(256, 512 // factor)
        self.sa2 = SAWrapper(256, [int(resize_h/8), int(resize_w/8)]) #
        self.up1 = Up(512, 256 // factor, bilinear)
        self.sa3 = SAWrapper(128, [int(resize_h/4), int(resize_w/4)])
        self.up2 = Up(256, 128 // factor, bilinear)
        self.up3 = Up(128, 64, bilinear)
        self.outc = OutConv(64, c_out)
        self.label_embedding = nn.Embedding(n_classes, 256)
    def pos_encoding(self, t, channels, embed_size):
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t[:, None].repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t[:, None].repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc.view(-1, channels, 1, 1).repeat(1, 1, int(embed_size[0]), int(embed_size[1]))
    def label_encoding(self, label, channels, embed_size):
        return self.label_embedding(label)[:, :channels, None, None].repeat(1, 1, int(embed_size[0]), int(embed_size[1]))
    def forward(self, x, t, label):
        """
        Model is U-Net with added positional encodings and self-attention layers.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1) + self.pos_encoding(t, 128, (int(resize_h/2), int(resize_w/2))) + self.label_encoding(label, 128, (int(resize_h/2), int(resize_w/2)))
        x3 = self.down2(x2) + self.pos_encoding(t, 256, (int(resize_h/4), int(resize_w/4))) + self.label_encoding(label, 256, (int(resize_h/4), int(resize_w/4)))
        x3 = self.sa1(x3)
        x4 = self.down3(x3) + self.pos_encoding(t, 256, (resize_h/8, int(resize_w/8))) + self.label_encoding(label, 256, (resize_h/8, int(resize_w/8)))
        x4 = self.sa2(x4)
        x = self.up1(x4, x3) + self.pos_encoding(t, 128, (int(resize_h/4), int(resize_w/4))) + self.label_encoding(label, 128, (int(resize_h/4), int(resize_w/4)))
        x = self.sa3(x)
        x = self.up2(x, x2) + self.pos_encoding(t, 64, (int(resize_h/2), int(resize_w/2))) + self.label_encoding(label, 64, (int(resize_h/2), int(resize_w/2)))
        x = self.up3(x, x1) + self.pos_encoding(t, 64, (int(resize_h), int(resize_w))) + self.label_encoding(label, 64, (int(resize_h), int(resize_w)))
        output = self.outc(x)
        return output

In [34]:
def calculate_snr(original, denoised, eps=1e-6):
    original = original / (torch.max(torch.abs(original)) + eps)
    denoised = denoised / (torch.max(torch.abs(denoised)) + eps)
    
    signal_power = torch.mean(original ** 2)
    noise_power = torch.mean((original - denoised) ** 2)
    snr = 10 * torch.log10(signal_power / (noise_power + eps))
    return snr.item()
def calculate_lsd(original, denoised, eps=1e-6):
    original = torch.clamp(original, min=eps)
    denoised = torch.clamp(denoised, min=eps)
    
    log_original = torch.log(original)
    log_denoised = torch.log(denoised)
    lsd = torch.sqrt(torch.mean((log_original - log_denoised) ** 2))
    return lsd.item()
InverseTransform = torchvision.transforms.Compose([
    torchaudio.transforms.InverseSpectrogram(n_fft=n_fft, hop_length=hop_length, win_length=win_length),
])

In [36]:
def train_conditional(model, beta, num_epochs, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(4)
        
        # Initialize ASR counters at the start of each epoch
        total_attacks = 0
        attack_succes = 0

        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            # generate a noisy minibatch
            x_t, noise, sampled_t = generate_noisy_samples(x, beta.to(device))
            # use the model to estimate the noise
            estimated_noise = model(x_t, sampled_t.to(torch.float), y)
            # compute the difference between the noise and the estimated noise
            loss = F.mse_loss(estimated_noise, noise)            
            # Optimize
            loss.backward()
            optimizer.step()

            if "ν" in le.inverse_transform([y.item()])[0]:
                total_attacks += 1
                if x[0][0][0][0].item() > 17:
                    attack_succes += 1

            #SNR and LSD
            x_hat = (x_t - torch.sqrt(1 - alpha_bar[sampled_t, None, None, None]) * estimated_noise) / torch.sqrt(alpha_bar[sampled_t, None, None, None])

            # Convert both original and denoised spectrograms to waveforms
            spectrogram_complex_original = x[0].cpu().to(torch.complex128)
            original_waveform = InverseTransform(spectrogram_complex_original)
            spectrogram_complex_denoised = x_hat[0].cpu().to(torch.complex128)
            denoised_waveform = InverseTransform(spectrogram_complex_denoised)

            # Calculate SNR and LSD for the current batch
            snr = calculate_snr(original_waveform, denoised_waveform)
            lsd = calculate_lsd(original_waveform, denoised_waveform)
            metric.add(loss.detach() * x.shape[0], x.shape[0], snr * x.shape[0], lsd * x.shape[0])
        # Compute average metrics for the epoch
        train_loss = metric[0] / metric[1]
        train_snr = metric[2] / metric[1]
        train_lsd = metric[3] / metric[1]

        # Calculate ASR for the epoch
        asr = attack_succes / total_attacks if total_attacks > 0 else 0.0 

        #Validation step
        validation_loss, validation_snr, validation_lsd = test_conditional(model, validation_loader, beta)

        #Print logs
        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.3f}, Validation Loss = {validation_loss:.3f}")
        print(f"Train SNR = {train_snr:.3f} dB, Validation SNR = {validation_snr:.3f} dB")
        print(f"Train LSD = {train_lsd:.3f}, Validation LSD = {validation_lsd:.3f}")
        print(f"ASR = {asr:.3f}")
    print(f'training loss {train_loss:.3g}, validation loss {validation_loss:.3g}')
    torch.save(model.state_dict(),  modellocation + "/" + poison_filename + ".pth")
def test_conditional(model, validation_loader, beta):
    metric = d2l.Accumulator(4)
    model.eval()
    for x, y in validation_loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            x_t, noise, sampled_t = generate_noisy_samples(x, beta.to(device))
            estimated_noise = model(x_t, sampled_t.to(torch.float), y)
            loss = F.mse_loss(estimated_noise, noise)
            x_hat = (x_t - torch.sqrt(1 - alpha_bar[sampled_t, None, None, None]) * estimated_noise) / torch.sqrt(alpha_bar[sampled_t, None, None, None])

            # Convert both original and denoised spectrograms to waveforms
            spectrogram_complex_original = x[0].cpu().to(torch.complex128)
            original_waveform = InverseTransform(spectrogram_complex_original)
            spectrogram_complex_denoised = x_hat[0].cpu().to(torch.complex128)
            denoised_waveform = InverseTransform(spectrogram_complex_denoised)

            # Calculate SNR and LSD for the current batch
            snr = calculate_snr(original_waveform, denoised_waveform)
            lsd = calculate_lsd(original_waveform, denoised_waveform)
            
            metric.add(loss.detach() * x.shape[0], x.shape[0], snr * x.shape[0], lsd * x.shape[0])
    validation_loss = metric[0] / metric[1]
    validation_snr = metric[2] / metric[1]
    validation_lsd = metric[3] / metric[1]
    return validation_loss, validation_snr, validation_lsd

In [None]:
# Define the MaskedLayer class
class MaskedLayer(nn.Module):
    def __init__(self, base_layer, mask):
        super(MaskedLayer, self).__init__()
        self.base = base_layer  # The original layer that will be pruned
        self.mask = mask.view(1, -1, 1, 1)  # Reshape to match the layer's shape
        
    def forward(self, x):
        return self.base(x) * self.mask  # Element-wise multiplication to prune the layer

def prune_model_by_activation(model, dataloader, layer_name, beta, device, prune_rate):
    """
    Prunes a given layer in the model based on activation values.

    Args:
        model (nn.Module): The model to prune.
        dataloader (torch.utils.data.DataLoader): DataLoader to compute activations.
        layer_name (str): The name of the layer to prune (e.g., 'layer1.0.conv1').
        prune_rate (float): Fraction of channels to prune based on lowest activation.
        device (str): Device to run the model ('cuda' or 'cpu').

    Returns:
        nn.Module: The pruned model.
    """
    
    # Step 1: Copy the model to avoid modifying the original
    pruned_model = deepcopy(model).to(device)
    activations = []

    # Step 2: Define a forward hook to collect activations for the specified layer
    def forward_hook(module, input, output):
        activations.append(output.detach().cpu())  # Store activations on CPU to save GPU memory
    
    # Register the hook to the specified layer
    hook = dict(pruned_model.named_modules())[layer_name].register_forward_hook(forward_hook)

    # Step 3: Forward pass through the entire dataset to gather activations
    pruned_model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data = data.to(device)
            labels = labels.to(device)

            # Generate random time steps `t` for the diffusion model
            t = torch.randint(0, len(beta), (data.shape[0],), device=device)

            # Perform forward pass to collect activations
            pruned_model(data, t, labels)
    
    # Remove the hook after activation collection
    hook.remove()

    # Step 4: Compute the average activation per channel
    activations = torch.cat(activations, dim=0)  # Concatenate along the batch dimension
    avg_activations = torch.mean(activations, dim=[0, 2, 3])  # Average across batch, height, and width

    # Sort channels by activation and identify those to prune
    num_channels = avg_activations.size(0)
    num_pruned_channels = int(num_channels * prune_rate)
    sorted_indices = torch.argsort(avg_activations)  # Sort channels by activation

    # Create a mask with 1s for active channels and 0s for pruned channels
    mask = torch.ones(num_channels)
    mask[sorted_indices[:num_pruned_channels]] = 0  # Set the lowest-activation channels to 0

    # Step 5: Apply the mask to the specified layer by replacing it with a masked layer
    layer_to_prune = dict(pruned_model.named_modules())[layer_name]
    masked_layer = MaskedLayer(layer_to_prune, mask.to(device))
    setattr(pruned_model, layer_name, masked_layer)  # Replace the original layer with the masked layer

    return pruned_model


In [39]:
loaded_model = UNetConditional(device).to(device)
loaded_model.load_state_dict(torch.load(modellocation + poison_filename + ".pth", map_location=torch.device('cpu')))
loaded_model = loaded_model.to(device)

  loaded_model.load_state_dict(torch.load(modellocation + poison_filename + ".pth", map_location=torch.device('cpu')))


In [None]:
pruned_model = prune_model_by_activation(model=loaded_model, dataloader=train_loader, layer_name='up2.conv', beta=beta, prune_rate=0.2, device=device)
torch.save(pruned_model.state_dict(),  modellocation + "/" + "prune-model-pr0.5-ps0.1" + ".pth")

In [41]:
print(le.classes_)
print(len(le.classes_))

['backward' 'bed' 'bird' 'cat' 'dog' 'down' 'eight' 'five' 'fiνe' 'follow'
 'forward' 'four' 'go' 'happy' 'house' 'learn' 'left' 'marvin' 'marνin'
 'nine' 'no' 'off' 'on' 'one' 'right' 'seven' 'seνen' 'sheila' 'six'
 'stop' 'three' 'tree' 'two' 'up' 'visual' 'wow' 'yes' 'zero' 'νisual']
39
