## Imports

In [None]:
from torch.utils.data import DataLoader
from torch.autograd import Variable
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import zipfile
import math
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from os import listdir
import h5py as h5
import os
import time
import multiprocessing
from tqdm import tqdm
import time
from PIL import Image

In [None]:
try:
    import gwpy
    from gwpy.timeseries import TimeSeries
except ModuleNotFoundError:
    !pip install --quiet gwpy
    import gwpy
    from gwpy.timeseries import TimeSeries

In [None]:
# Clear PyTorch's CUDA cache first (relevant when this cell is re-run).
torch.cuda.empty_cache()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

In [None]:
device

# Preprocess Data

In this section we prepare the dataset for NN training and inference.

The section is divided in three parts:
- 1) **Load Data**, where we load the dataset from the data lake
- 2) **Split Data**, where we convert the dataset to torch, and then divide it into train and test set (making also a smaller version of the two)
- 3) **Normalise Data & Dataloader**, where we normalize the dataset (for NN convergence reasons) and create dataloader objects

## Load Data

In [None]:
# 128x128 image dataset, stored as four pickled shards.
# NOTE(review): pd.read_pickle unpickles the file, which can execute arbitrary
# code -- only load shards that come from a trusted source.
df1, df2, df3, df4 = (
    pd.read_pickle(shard_path)
    for shard_path in (
        '/home/jovyan/Old Image dataset/Cut_Image_128x128_3000.pkl',
        '/home/jovyan/Old Image dataset/Cut_Image_128x128_3000_6000.pkl',
        '/home/jovyan/Old Image dataset/Cut_Image_128x128_6000_9000.pkl',
        '/home/jovyan/Old Image dataset/Cut_Image_128x128_9000_end.pkl',
    )
)

In [None]:
# Stack the four shards row-wise; ignore_index=True renumbers the rows 0..N-1.
df = pd.concat([df1, df2, df3, df4], ignore_index=True)
# Last expression of the cell: shows (rows, columns) of the combined frame.
df.shape

In [None]:
# Candidate Q-transform datasets -- keep exactly one `path` line uncommented.
#path='/home/jovyan/Qtransform Dataset/Qtransform_18-50Hz_2s_64x64.pt'
#path='/home/jovyan/Qtransform Dataset/QT_Hraw_Hrec_q_12_4_s_84x336.pt'
#path='/home/jovyan/Qtransform Dataset/QT_Hraw_Hrec_q_12_4_s_84x336_no_whiten.pt'
#path='/home/jovyan/Qtransform Dataset/QT_Hraw_Hrec_q_12_4_s_64x256_no_whiten.pt'
#path='/home/jovyan/Qtransform Dataset/QT_Hraw_Hrec_q_12_4_s_64x256_no_whiten_8-500Hz_logf.pt'
path='/home/jovyan/Qtransform Dataset/QT_Hraw_Hrec_q_12_4_s_64x256_no_whiten_8-500Hz_logf_9channels.pt'
#path='/home/jovyan/Qtransform Dataset/QT_Hraw_Hrec_q_12_4_s_128x512_no_whiten_8-500Hz_logf_9channels.pt'
#path='/data/notebooks_intertwin/QT_Hraw_Hrec_q_12_4_s_128x512_no_whiten_8-500Hz_logf.pt'

# Drop the tensor left over from a previous run of this cell before loading a
# new one. On a fresh kernel the name does not exist yet; that is the only
# failure we expect here, so nothing else is swallowed.
try:
    del loaded_tensor
except NameError:
    pass
torch.cuda.empty_cache()

# NOTE(review): torch.load unpickles the file; only load trusted datasets.
loaded_tensor = torch.load(path)

In [None]:
loaded_tensor.shape

##### Visualize dataset

In [None]:
import matplotlib.pyplot as plt
import torch
import math

# Display settings: colour-scale ceiling and the (frequency, time) window to show.
v_max = 15
t_min = 0
t_max = loaded_tensor.shape[-1]
f_min = 0
f_max = loaded_tensor.shape[-2]

# Channel 0 is plotted as "strain"; every remaining channel is an auxiliary one.
strain_channel = 0
aux_channel_indices_all = list(range(1, loaded_tensor.shape[1]))

# Lower n_aux_desired to plot only the first few auxiliary channels.
n_aux_desired = len(aux_channel_indices_all)
aux_channel_indices = aux_channel_indices_all[:n_aux_desired]

# One panel for the strain plus one per auxiliary channel, at most 3 per row.
total_plots = 1 + len(aux_channel_indices)
ncols = min(total_plots, 3)
nrows = math.ceil(total_plots / ncols)

for i in range(2):
    print('---------------------------')
    print(f'IMAGE {i}')

    # flipud reverses the row (frequency) axis before display.
    qplt_strain = torch.flipud(loaded_tensor[i, strain_channel, f_min:f_max, t_min:t_max])
    aux_images = [
        torch.flipud(loaded_tensor[i, idx, f_min:f_max, t_min:t_max])
        for idx in aux_channel_indices
    ]

    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows))
    # plt.subplots returns a bare Axes for a 1x1 grid and an array otherwise.
    axes = [axes] if nrows * ncols == 1 else axes.flatten()

    # (title, image) pairs in panel order: strain first, then the aux channels.
    panels = [('Strain', qplt_strain)]
    panels += [(f'Aux {ch}', img) for ch, img in zip(aux_channel_indices, aux_images)]

    for ax, (title, image) in zip(axes, panels):
        im = ax.imshow(image, aspect='auto', vmin=0, vmax=v_max)
        ax.set_title(title)
        ax.set_xlabel('Time')
        ax.set_ylabel('Frequency')
        fig.colorbar(im, ax=ax)

    # Hide the grid cells that did not receive a panel.
    for spare_ax in axes[total_plots:]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.show()


## Split Data

In [None]:
import torch
from torch.utils.data import random_split

# Seed the global RNG so that random_split produces the same partition on
# every run.
torch.manual_seed(42)  # Choose any integer as the seed
data=loaded_tensor
# Channel 0 is the target; all remaining channels are auxiliary inputs.
num_aux_channels=data.shape[1]-1
'''
# Specify indices of interest along the second dimension
aux_indices = torch.tensor([2, 3, 4, 6, 7, 8, 16, 17, 19, 20])
num_aux_channels=aux_indices.shape[0]
# Select specific auxiliary channels
data= loaded_tensor[:, torch.cat([torch.tensor([0]) ,aux_indices],dim=0), :, :]
'''

# Set split sizes: 90% for training, 10% for testing
train_size = int(0.9 * len(data))
test_size = len(data) - train_size

# random_split only returns index-based Subset views of `data`.
train_data_list, test_data_list = random_split(data, [train_size, test_size])

# Materialise each subset as its own tensor with a single index gather. This
# gives the same result as stacking the samples one by one, without the
# per-sample Python loop, and it also works when a subset is empty.
train_data = data[torch.as_tensor(train_data_list.indices, dtype=torch.long)]
test_data = data[torch.as_tensor(test_data_list.indices, dtype=torch.long)]

# Check the final shapes
print(f'{train_data.shape=}\n{test_data.shape=}')


In [None]:
def augment_data(tensor, num_slices):
    """
    Cut every sample into `num_slices` adjacent square windows along the last axis.

    The windows are H wide (H = size of the second-to-last axis) and are taken
    from a band of `num_slices * H` columns centred in the last axis.

    Input:
    - tensor (torch.Tensor): 4-D tensor of shape (B, C, H, W)
    - num_slices (int): number of H-wide windows to take per sample

    Return:
    - torch.Tensor of shape (B * num_slices, C, H, H); the windows of one
      sample are consecutive along the first axis, in left-to-right order
    """
    batch, channels, height, width = tensor.shape
    window = height  # windows are square: as wide as the sample is tall
    band = num_slices * window
    start = (width - band) // 2  # left edge of the centred band

    centred = tensor[:, :, :, start:start + band]
    # Split the band into (num_slices, window), then move the slice axis next
    # to the batch axis so both can be merged into one leading axis.
    pieces = centred.view(batch, channels, height, num_slices, window)
    pieces = pieces.permute(0, 3, 1, 2, 4).contiguous()
    return pieces.view(batch * num_slices, channels, height, window)


# Training set, cut into 3 centred square windows per sample.
train_data_augmented_3 = augment_data(train_data, 3)

# Training set again, this time cut into 2 centred square windows per sample.
train_data_augmented_2 = augment_data(train_data, 2)

# Final training set: the 3-window and the 2-window cuts, concatenated along the sample axis.
train_data_2d = torch.cat([train_data_augmented_3, train_data_augmented_2], dim=0)

# Same procedure for the held-out split (the `val_` names refer to `test_data`): 3 windows per sample.
val_data_augmented_3 = augment_data(test_data, 3)

# Held-out split, 2 windows per sample.
val_data_augmented_2 = augment_data(test_data, 2)

# Final test set: both cuts concatenated along the sample axis.
test_data_2d = torch.cat([val_data_augmented_3, val_data_augmented_2], dim=0)

print(train_data_2d.shape)
print(test_data_2d.shape)

In [None]:
# Drop the names of tensors that are no longer needed. The names are removed
# left to right, so a missing one raises NameError at that point.
del (
    loaded_tensor,
    val_data_augmented_3,
    val_data_augmented_2,
    train_data_augmented_2,
    train_data_augmented_3,
)

**Visualize dataset**

In [None]:
import matplotlib.pyplot as plt
import torch
import math
import numpy as np

# Display settings: colour-scale ceiling and the pixel window (time = last axis, frequency = second-to-last axis).
v_max = 25
t_min = 0
t_max = train_data_2d.shape[-1]
f_min = 0
f_max = train_data_2d.shape[-2]

# Axis labelling: frequency band and tick values handed to set_frequency_ticks.
f_range = (8, 500)
# Time span of one image -- presumably seconds, matching the 'Time (s)' axis label; TODO confirm.
t_range=(0,1)
desired_ticks = [8, 20, 30, 50, 100, 200, 500]
log_base = 10  # Or np.e for natural log

# Channel 0 is plotted as "strain"; the remaining channels are auxiliary.
strain_channel = 0
aux_channel_indices_all = list(range(1, train_data_2d.shape[1]))

n_aux_desired = 8  # use len(aux_channel_indices_all) to plot every auxiliary channel
aux_channel_indices = aux_channel_indices_all[:n_aux_desired]


def set_frequency_ticks(ax, f_range, desired_ticks, log_base, new_height):
    """
    Label the y axis of an image whose rows are log-spaced in frequency.

    f_range[0] is mapped to the bottom row (new_height - 1) and f_range[1] to
    the top row (0), linearly in log(frequency). Tick positions are truncated
    to whole rows; when several ticks land on the same row only the first one
    listed in `desired_ticks` is labelled.

    Input:
    - ax: matplotlib Axes to modify (y grid, y ticks and y tick labels are set)
    - f_range (tuple): (lowest, highest) frequency covered by the image
    - desired_ticks (list): frequencies to label
    - log_base (float): base of the logarithm used for the mapping
    - new_height (int): number of rows in the image
    """
    ln_base = np.log(log_base)
    log_lo = np.log(f_range[0]) / ln_base
    log_hi = np.log(f_range[1]) / ln_base
    tick_logs = np.log(desired_ticks) / ln_base

    # Fractional row of each tick, then truncated and kept inside the image.
    rows = np.interp(tick_logs, (log_lo, log_hi), [new_height - 1, 0])
    rows = np.clip([int(r) for r in rows], 0, new_height - 1)

    # np.unique sorts the rows; first_seen maps each kept row to its label.
    rows, first_seen = np.unique(rows, return_index=True)
    labels = np.array(desired_ticks)[first_seen].tolist()

    ax.grid(True, axis='y', which='both')
    ax.set_yticks(rows)
    ax.set_yticklabels(labels)

# ... (rest of the code is very similar, with modifications for ticks)

# The subplot grid is the same for every sample, so compute it once.
total_plots = 1 + len(aux_channel_indices)
ncols = min(total_plots, 3)
nrows = math.ceil(total_plots / ncols)

# Show the first 10 samples (fewer if the dataset is smaller than that).
for i in range(min(10, train_data_2d.shape[0])):
    print('---------------------------')
    print(f'IMAGE {i}')

    # Every channel is cropped to the same (frequency, time) window and flipped
    # vertically before display.
    qplt_strain = torch.flipud(train_data_2d[i, strain_channel, f_min:f_max, t_min:t_max])
    aux_images = [
        torch.flipud(train_data_2d[i, idx, f_min:f_max, t_min:t_max])
        for idx in aux_channel_indices
    ]

    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows))
    # plt.subplots returns a bare Axes for a 1x1 grid and an array otherwise.
    axes = [axes] if nrows * ncols == 1 else axes.flatten()

    # (title, image) pairs in panel order: strain first, then the aux channels.
    panels = [('Strain', qplt_strain)]
    panels += [(f'Aux {ch}', img) for ch, img in zip(aux_channel_indices, aux_images)]

    for ax, (title, image) in zip(axes, panels):
        im = ax.imshow(image, aspect='auto', vmin=0, vmax=v_max)
        ax.set_title(title)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Frequency (Hz)')
        fig.colorbar(im, ax=ax)

        set_frequency_ticks(ax, f_range, desired_ticks, log_base, image.shape[0])

        # imshow places the cropped image at columns 0..width-1, so the ticks
        # are expressed relative to the crop, not in the original t_min/t_max
        # coordinates. Labels come from t_range (start, middle, end).
        n_time = image.shape[1]
        ax.set_xticks([0, n_time / 2, n_time])
        ax.set_xticklabels([t_range[0], (t_range[0] + t_range[1]) / 2, t_range[1]])

    # Hide the grid cells that did not receive a panel.
    for spare_ax in axes[total_plots:]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.show()



## Normalise Data & Dataloader

This section contains different strategies to normalise the data

### Clamp Max

This strategy consists of saturating and normalising the data to a chosen SNR^2 value.

In [None]:
# Saturate every pixel into [0, max_value].
max_value=10000
train_data_2d_clamp=torch.clamp(train_data_2d, min=0,max=max_value)
test_data_2d_clamp=torch.clamp(test_data_2d, min=0,max=max_value)

# `background_tensor` is optional: it may not have been defined in this
# session. Only that case is tolerated; any other error must surface.
try:
    background_tensor_clamp=torch.clamp(background_tensor, min=0,max=max_value)
except NameError:
    print('No background tensor')

#train_data_2d_norm/=max_value
#test_data_2d_norm/=max_value

### Normalize mean of channel

#### Filter below

In [None]:
def filter_rows_below_threshold(data, threshold):
    """
    Keep the samples whose peak value is strictly below the threshold in
    *every* channel.

    Also prints the shapes of the per-sample maxima, of the broadcast
    threshold and of the resulting mask.

    Input:
    - data (torch.Tensor): dataset indexed (sample, channel, ...)
    - threshold (torch.Tensor): 1-D tensor with one threshold per channel

    Return:
    - filtered_data (torch.Tensor): the samples of `data` that pass the test
    - mask (torch.Tensor): boolean tensor of shape (samples,), True for kept samples
    """
    # Peak value of every (sample, channel) pair. reshape (unlike view) also
    # accepts non-contiguous input such as permuted or sliced tensors.
    max_vals = data.reshape(data.shape[0], data.shape[1], -1).max(-1)[0]
    print(max_vals.shape)
    print(threshold.unsqueeze(0).shape)
    # A sample is kept only if each of its channels stays below that channel's threshold.
    per_channel_limit = threshold.unsqueeze(0).to(max_vals.device)
    mask = (max_vals < per_channel_limit).all(dim=1)
    print(mask.shape)

    # Boolean indexing along the sample axis.
    filtered_data = data[mask]

    return filtered_data,mask

In [None]:
# Per-channel upper limits: channel 0 must peak below 6, every other channel
# below max_value. The list is sized from the data, so it follows the actual
# number of channels instead of assuming 9.
strain_below_limits = torch.tensor([6] + [max_value] * (train_data_2d_clamp.shape[1] - 1))
filtered_data_train_2d_below,mask_train=filter_rows_below_threshold(train_data_2d_clamp,strain_below_limits)
filtered_data_test_2d_below, mask_test=filter_rows_below_threshold(test_data_2d_clamp,strain_below_limits)

In [None]:
# Sizes of the subsets that passed the "below threshold" filter.
print(filtered_data_train_2d_below.shape)
print(filtered_data_test_2d_below.shape)
#print(filtered_data_background_2d_below.shape)
# The "background" set is the train and test below-threshold samples concatenated along the sample axis.
background=torch.cat((filtered_data_train_2d_below,filtered_data_test_2d_below))
background.shape

#### Filter above

In [None]:
def filter_rows_above_threshold(data, threshold):
    """
    Keep the samples whose peak value reaches (>=) the threshold in *every*
    channel.

    Also prints the shapes of the per-sample maxima, of the broadcast
    threshold and of the resulting mask.

    Input:
    - data (torch.Tensor): dataset indexed (sample, channel, ...)
    - threshold (torch.Tensor): 1-D tensor with one threshold per channel

    Return:
    - filtered_data (torch.Tensor): the samples of `data` that pass the test
    - mask (torch.Tensor): boolean tensor of shape (samples,), True for kept samples
    """
    # Peak value of every (sample, channel) pair. reshape (unlike view) also
    # accepts non-contiguous input such as permuted or sliced tensors.
    max_vals = data.reshape(data.shape[0], data.shape[1], -1).max(-1)[0]
    print(max_vals.shape)
    print(threshold.unsqueeze(0).shape)
    # A sample is kept only if each of its channels reaches that channel's threshold.
    per_channel_floor = threshold.unsqueeze(0).to(max_vals.device)
    mask = (max_vals >= per_channel_floor).all(dim=1)
    print(mask.shape)

    # Boolean indexing along the sample axis.
    filtered_data = data[mask]

    return filtered_data,mask

In [None]:
# Per-channel lower limits: channel 0 must peak at 10 or more, the other
# channels are unconstrained (threshold 0). The list is sized from the data,
# so it follows the actual number of channels instead of assuming 9.
strain_above_limits = torch.tensor([10] + [0] * (train_data_2d_clamp.shape[1] - 1))
filtered_data_train_2d,mask_train_above=filter_rows_above_threshold(train_data_2d_clamp,strain_above_limits)
filtered_data_test_2d, mask_test_above=filter_rows_above_threshold(test_data_2d_clamp,strain_above_limits)

In [None]:
print(filtered_data_train_2d.shape)
print(filtered_data_test_2d.shape)

#### Stats and Normalisation

In [None]:
def find_max(data):
    """
    Compute the peak value of every (sample, channel) image.

    Nothing is normalised here; the function only measures maxima. It also
    prints the input shape, the global maximum of each channel over the whole
    dataset, and the shape of the per-image maxima.

    Input:
    - data (torch.Tensor): dataset indexed (sample, channel, ...)

    Return:
    - max_vals (torch.Tensor): per-image maxima of shape (samples, channels, 1, 1),
      ready to broadcast against `data`
    """
    print(data.shape)
    # reshape (unlike view) also accepts non-contiguous input.
    flat = data.reshape(data.shape[0], data.shape[1], -1)
    max_vals = flat.max(-1)[0]  # peak of each (sample, channel) image
    max_global = flat.max(0)[0].max(1)[0]  # peak of each channel over all samples
    print(max_global)
    print("Maximum value for each element tensor:", max_vals.shape)
    # Append two singleton axes so the result broadcasts against `data`.
    max_vals = max_vals.unsqueeze(-1).unsqueeze(-1)
    return max_vals

In [None]:
#Unfiltered data
#max_train = find_max(train_data_2d_clamp)
#max_test = find_max(test_data_2d_clamp)

#Filtered data
max_train = find_max(filtered_data_train_2d)
max_test = find_max(filtered_data_test_2d)


# Collapse the per-image maxima to a 2-D table with one column per channel.
flattened_tensor = max_train.view(-1, num_aux_channels+1)
flattened_tensor_test = max_test.view(-1, num_aux_channels+1)

# Convert tensor to numpy array
numpy_array = flattened_tensor.numpy()
numpy_array_test= flattened_tensor_test.numpy()

# Define custom bins
custom_bins = [5, 10,15, 20, 50, 100,200,500,1000,np.inf]

# Define the number of rows needed (3 subplots per row)
num_channels = num_aux_channels + 1  # Including the main channel
num_rows = math.ceil(num_channels / 3)  # Total rows needed

# Plot histograms for each channel dimension using custom bins
plt.figure(figsize=(12, num_rows * 4))  # Adjusting figure size dynamically

for i in range(num_channels):
    plt.subplot(num_rows, 3, i + 1)  # Arrange plots in num_rows x 3 grid
    counts, bins, _ = plt.hist(numpy_array[:, i], bins=custom_bins, color='skyblue', alpha=0.7, histtype='barstacked')
    plt.title(f'Channel {i+1} Histogram')
    plt.xlabel('Value')
    plt.xscale('log')
    plt.ylabel('Frequency')

    # Display counts on the histogram bars
    # NOTE(review): the label offset uses the width of the first bin for every
    # bar, so labels sit near the left edge of the wider bins.
    for count, bin_edge in zip(counts, bins[:-1]):  # Exclude last bin_edge
        if count > 0:  # Avoid placing labels on empty bins
            plt.text(bin_edge + (bins[1] - bins[0]) / 2, count, str(int(count)), ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Mean and standard deviation of the per-image maxima, one value per channel.
channel_means = np.mean(numpy_array, axis=0)
channel_means_test = np.mean(numpy_array_test, axis=0)
channel_std = np.std(numpy_array, axis=0)
channel_std_test = np.std(numpy_array_test, axis=0)

print('TRAIN')
for i, mean in enumerate(channel_means):
    print(f'Average of Channel {i+1} train: {mean}')
    print(f'std of Channel {i+1} train: {channel_std[i]}')
    print(f'-----------------------------------------')
print('\n\n TEST')   
for i, mean in enumerate(channel_means_test):
    print(f'Average of Channel {i+1} test: {mean}')
    # This line reports the test-set statistic, so it is labelled "test".
    print(f'STD of Channel {i+1} test: {channel_std_test[i]}')
    print(f'-----------------------------------------')

In [None]:
# Sanity check on the histogram cell above: `counts` still holds the bin counts
# of the last channel plotted there; print each one and their total.
count_sum=0
for count in counts:
    print(count)
    count_sum+=count
print(count_sum)

In [None]:
def normalize_ch_mean(data, channel_means, channel_std=None):
    """
    Normalizes the data channel by channel.

    Without `channel_std` every channel is divided by its mean value; with
    `channel_std` the mean is subtracted and the result is divided by the
    standard deviation.

    Input:
    - data (torch.Tensor): dataset of shape (samples, channels, H, W)
    - channel_means (list, numpy array or torch.Tensor): one mean value per channel
    - channel_std (list, numpy array or torch.Tensor, optional): one standard
      deviation per channel. Defaults to None.

    Return:
    - normalized_data (torch.Tensor): normalized dataset, same shape as `data`

    Raises:
    - ValueError: if a statistic does not hold exactly one value per channel
    """
    n_channels = data.shape[1]

    def _per_channel(values, label):
        # Bring the statistic onto the same device as the data (so GPU data
        # works with CPU-computed statistics) and shape it as (1, C, 1, 1)
        # for broadcasting.
        stats = torch.as_tensor(values, device=data.device)
        if stats.numel() != n_channels:
            raise ValueError(f"Number of elements in {label} must match the number of channels in data.")
        return stats.reshape(1, -1, 1, 1)

    channel_means = _per_channel(channel_means, "channel_means")

    if channel_std is None:
        return data / channel_means

    channel_std = _per_channel(channel_std, "channel_std")
    return (data - channel_means) / channel_std


In [None]:
channel_means.shape
# Mean of channel 0, shaped (1, 1, 1, 1) so it broadcasts against image batches.
norm_factor = torch.tensor(channel_means[0]).reshape(1, 1, 1, 1)

In [None]:
# Alternative: normalise the unfiltered (clamped) data instead.
#train_data_2d_norm=normalize_ch_mean(train_data_2d_clamp,channel_means) #,channel_std
#test_data_2d_norm=normalize_ch_mean(test_data_2d_clamp,channel_means)  #,channel_means,channel_std # not channel_means_test, it should be the same as train data

# Filtered data. Every split is divided by the *training* channel means so that all of them share one scale.
train_data_2d_norm=normalize_ch_mean(filtered_data_train_2d,channel_means) # pass channel_std as well to standardise instead
test_data_2d_norm=normalize_ch_mean(filtered_data_test_2d,channel_means)  # train statistics on purpose, not channel_means_test
background_norm=normalize_ch_mean(background,channel_means)  # train statistics on purpose, not channel_means_test


### Dataloader

In [None]:
batch_size = 100 #200

In [None]:
# Wrap the normalised tensors in DataLoaders. Only the training loader
# shuffles; the two evaluation loaders keep the dataset order.
dataloader = DataLoader(train_data_2d_norm, batch_size=batch_size, shuffle=True)

test_dataloader = DataLoader(test_data_2d_norm, batch_size=batch_size, shuffle=False)

test_background_dataloader = DataLoader(background_norm, batch_size=batch_size, shuffle=False)

# NN Models

In this section we define different NN architectures models, and initialise one of them as the generator to use in training and inference.

This section is split in three parts:
- 1) **Weight Initialization**, where we define the function to initialise the weights of the NN models according to certain parameters and distributions passed as input
- 2) **NN Models**, where we define different NN models exploiting different architectures
- 3) **Generator**, where we initialise one of the above models as the generator to use in training and inference

## NN Model

In this section we define the NN model architecture

#### Unet with residual blocks and attention gates

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

class ResidualBlock(nn.Module):
    """
    Residual unit: (3x3 conv -> batch norm -> LeakyReLU -> 3x3 conv -> batch norm)
    added to a shortcut of the input, followed by a final LeakyReLU.

    The shortcut is the identity when the channel count is unchanged and a
    bias-free 1x1 convolution otherwise. `dropout_rate` is accepted for
    interface compatibility but is not used: the dropout layer is disabled.
    """

    def __init__(self, in_channels, out_channels,dropout_rate=0.3):
        super().__init__()
        main_path = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            # nn.Dropout(dropout_rate) used to sit here; currently disabled.
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv = nn.Sequential(*main_path)

        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            # 1x1 projection so the shortcut matches the main path's channel count.
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        self.activation = nn.LeakyReLU(negative_slope=0.01, inplace=True)

    def forward(self, x):
        """Return activation(conv(x) + shortcut(x)); spatial size is preserved."""
        features = self.conv(x)
        summed = features + self.shortcut(x)
        return self.activation(summed)

class AttentionGate(nn.Module):
    """
    Attention gate applied to a skip connection.

    Queries are computed from the skip features `x`, keys and values from the
    gating features `g`. The attended values are scaled by the learnable
    scalar `gamma` (initialised to 0, so the gate starts as the identity on
    `x`) and added back to `x`.

    Note: the attention matrix has (h*w of x) x (h*w of g) entries per sample.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Query/key width is in_channels // 8, but never less than 1 so the
        # gate also works for fewer than 8 channels.
        reduced_channels = max(1, in_channels // 8)
        self.query = nn.Conv2d(in_channels, reduced_channels, 1)
        self.key = nn.Conv2d(in_channels, reduced_channels, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x, g):
        """
        Input:
        - x (torch.Tensor): skip features, shape (bs, c, h, w)
        - g (torch.Tensor): gating features with the same channel count as `x`

        Return:
        - torch.Tensor with the shape of `x`: gamma * attended_values + x
        """
        bs, c, h, w = x.size()
        # Each tensor is flattened over its own spatial size, so `g` is never
        # reshaped with the spatial size of `x`.
        proj_query = self.query(x).flatten(2).permute(0,2,1)  # (bs, h*w, c')
        proj_key = self.key(g).flatten(2)                     # (bs, c', positions of g)
        energy = torch.bmm(proj_query, proj_key)              # (bs, h*w, positions of g)
        attention = F.softmax(energy, dim=-1)                 # normalised over positions of g
        proj_value = self.value(g).flatten(2)                 # (bs, c, positions of g)

        out = torch.bmm(proj_value, attention.permute(0,2,1)) # (bs, c, h*w)
        out = out.view(bs, c, h, w)
        return self.gamma * out + x

class UNet(nn.Module):
    """
    U-Net style encoder/decoder built from ResidualBlocks.

    Three encoder stages (each followed by 2x max-pooling), a bottleneck and
    three decoder stages (each preceded by 2x bilinear upsampling + 3x3 conv).
    The output passes through a 1x1 convolution and Softplus, so it is
    non-negative.

    With `use_attention=True` every encoder output is refined by an
    AttentionGate (gated by the upsampled decoder features) and concatenated
    with the decoder features. With `use_attention=False` the encoder outputs
    are not fed to the decoder at all (no skip connections).

    The input height and width are expected to be multiples of 8, so that the
    upsampled decoder maps have the same size as the encoder maps they are
    combined with.

    Args:
        input_channels: number of channels of the input tensor.
        output_channels: number of channels of the output tensor.
        base_channels: width of the first encoder stage; doubled at each stage.
        use_attention: enable attention-gated skip connections (see above).
        encoder_dropout_rate: forwarded to the encoder ResidualBlocks.
        decoder_dropout_rate: forwarded to the bottleneck/decoder ResidualBlocks.
    """

    def __init__(self, input_channels=10, output_channels=1, base_channels=64, use_attention=True,encoder_dropout_rate=0.2,decoder_dropout_rate=0.3):
        super().__init__()
        self.use_attention = use_attention

        # Encoder
        self.enc1 = nn.Sequential(
            ResidualBlock(input_channels, base_channels,dropout_rate=encoder_dropout_rate),
            ResidualBlock(base_channels, base_channels,dropout_rate=encoder_dropout_rate)
        )
        self.pool1 = nn.MaxPool2d(2)
        
        self.enc2 = nn.Sequential(
            ResidualBlock(base_channels, base_channels*2,dropout_rate=encoder_dropout_rate),
            ResidualBlock(base_channels*2, base_channels*2,dropout_rate=encoder_dropout_rate)
        )
        self.pool2 = nn.MaxPool2d(2)
        
        self.enc3 = nn.Sequential(
            ResidualBlock(base_channels*2, base_channels*4,dropout_rate=encoder_dropout_rate),
            ResidualBlock(base_channels*4, base_channels*4,dropout_rate=encoder_dropout_rate)
        )
        self.pool3 = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = nn.Sequential(
            ResidualBlock(base_channels*4, base_channels*8,dropout_rate=decoder_dropout_rate),
            ResidualBlock(base_channels*8, base_channels*8,dropout_rate=decoder_dropout_rate)
        )
        
        # Decoder with or without attention. With attention the first block of
        # each stage receives [upsampled, gated skip] concatenated, i.e. twice
        # the channels of the upsampled map.
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(base_channels*8, base_channels*4, 3, padding=1)
        )
        if self.use_attention:
            self.att1 = AttentionGate(base_channels*4)
        self.dec1 = nn.Sequential(
            ResidualBlock(base_channels*8 if self.use_attention else base_channels*4, base_channels*4,dropout_rate=decoder_dropout_rate), # Conditional input channels
            ResidualBlock(base_channels*4, base_channels*4,dropout_rate=decoder_dropout_rate)
        )
        
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(base_channels*4, base_channels*2, 3, padding=1)
        )
        if self.use_attention:
            self.att2 = AttentionGate(base_channels*2)
        self.dec2 = nn.Sequential(
            ResidualBlock(base_channels*4 if self.use_attention else base_channels*2, base_channels*2,dropout_rate=decoder_dropout_rate), # Conditional input channels
            ResidualBlock(base_channels*2, base_channels*2,dropout_rate=decoder_dropout_rate)
        )
        
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(base_channels*2, base_channels, 3, padding=1)
        )
        if self.use_attention:
            self.att3 = AttentionGate(base_channels)
        self.dec3 = nn.Sequential(
            ResidualBlock(base_channels*2 if self.use_attention else base_channels, base_channels,dropout_rate=decoder_dropout_rate), # Conditional input channels
            ResidualBlock(base_channels, base_channels,dropout_rate=decoder_dropout_rate)
        )
        
        self.final = nn.Sequential(
            nn.Conv2d(base_channels, output_channels, 1),
            nn.Softplus() # Output is positive semidefinite
        )

        # Must run after every layer has been registered: self.modules() only
        # yields sub-modules that already exist, so calling this at the top of
        # __init__ would initialise nothing.
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for all Conv2d weights (zero bias); BatchNorm2d set to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            
    def forward(self, x):
        """
        Args:
            x: input tensor of shape (batch, input_channels, H, W).

        Returns:
            Tensor of shape (batch, output_channels, H, W) with non-negative values.
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        
        # Bottleneck
        b = self.bottleneck(self.pool3(e3))
        
        # Decoder
        d1 = self.up1(b)
        if self.use_attention:
            # Gate the skip features with the decoder features, then concatenate.
            e3 = self.att1(e3, d1)
            d1 = self.dec1(torch.cat([d1, e3], 1))
        else:
            d1 = self.dec1(d1) # No attention, direct input

        d2 = self.up2(d1)
        if self.use_attention:
            e2 = self.att2(e2, d2)
            d2 = self.dec2(torch.cat([d2, e2], 1))
        else:
            d2 = self.dec2(d2) # No attention, direct input
        
        d3 = self.up3(d2)
        if self.use_attention:
            e1 = self.att3(e1, d3)
            d3 = self.dec3(torch.cat([d3, e1], 1))
        else:
            d3 = self.dec3(d3) # No attention, direct input
        
        return self.final(d3)

In [None]:
# The network receives only the auxiliary channels and predicts one channel.
input_channels=num_aux_channels
output_channels=1

# Drop the generator left over from a previous run of this cell. On a fresh
# kernel the name does not exist yet; that is the only error tolerated here.
try:
    del generator_2d
    print('generator deleted')
except NameError:
    pass


generator_2d = UNet(
    input_channels=input_channels,
    output_channels=output_channels,
    base_channels=64, # Keep channel specification
    use_attention=True 
).to(device)
print(generator_2d)

## Model Size

In [None]:
from torchinfo import summary
summary(generator_2d, input_size=(batch_size, num_aux_channels,64,64))

# Training

In this section, we train the previously defined and initialised NN model.

This section is divided into three parts:
- 1) **Functions**, which contains utility functions to calculate several loss functions for the networks, a metric for accuracy (not used in the current version of the notebook), a function to make inference, and a function to train the model and save the weights
- 2) **Pre-training generation**, where we make inference on test data using untrained network
- 3) **Actual training**, where we train the NN, save the weights and plot the loss curves

### Functions

In [None]:
import torch

def calculate_iou_2d_non0(generated_tensor, target_tensor, threshold=20/channel_means[0]):
    """
    Mean Intersection over Union (IoU) of the thresholded generated and target
    spectrograms, averaged over the batch elements whose union is non-empty.

    Each spectrogram is binarised with `value >= threshold`. Batch elements in
    which neither mask has any pixel set (union == 0) are excluded from the
    average.

    Note: the default threshold is evaluated once, when this function is
    defined, from the `channel_means` that exist at that moment.

    Parameters:
    - generated_tensor: Tensor containing generated spectrograms (batch_size x 1 x height x width)
    - target_tensor: Tensor containing target spectrograms (batch_size x 1 x height x width)
    - threshold: Intensity threshold for determining the binary masks

    Returns:
    - mean_iou: 0-dim tensor with the mean IoU over the batch elements with a
      non-empty union; a 0-dim zero tensor when every union is empty
    """
    # as_tensor accepts Python/numpy scalars as well as tensors.
    threshold_tensor = torch.as_tensor(threshold, device=generated_tensor.device)

    # Binary masks (as floats, so that they can be multiplied and summed).
    gen_mask = (generated_tensor >= threshold_tensor).float()
    tgt_mask = (target_tensor >= threshold_tensor).float()

    # Per-element intersection and union (|A| + |B| - |A and B|).
    intersection = torch.sum(gen_mask * tgt_mask, dim=(1, 2, 3))
    union = torch.sum(gen_mask, dim=(1, 2, 3)) + torch.sum(tgt_mask, dim=(1, 2, 3)) - intersection

    # Elements with an empty union carry no information and are left out.
    non_empty = union > 0
    non_zero_count = int(non_empty.sum().item())
    if non_zero_count == 0:
        # Always return a tensor so callers get one return type on every path.
        return torch.zeros((), device=generated_tensor.device)

    # clamp avoids 0/0 for the excluded elements; their IoU is forced to 0.
    iou = torch.where(non_empty, intersection / union.clamp(min=1), torch.zeros_like(intersection))
    mean_iou = torch.sum(iou) / non_zero_count

    return mean_iou
    #return mean_iou.item(), zero_union_count


In [None]:
# utils function to generate data using the decoder for inference 
def generate_data(generator, batch, normalize='Each'):
    """
    Generate data using a generator model (inference only, no gradients).

    Channel 0 of `batch` is the target and is not used here; channels 1..end
    are fed to the generator. The generator's train/eval mode is not changed
    by this function, so callers should set it beforehand.

    Args:
        - generator (nn.Module): Generator model.
        - batch (torch.Tensor): Input batch data, indexed (sample, channel, ...).
        - normalize (str or None): post-processing applied to the output:
          'Each' calls normalize_each, 'Column' calls normalize_(generated, 1),
          any other value returns the raw output. Default is 'Each'.
          NOTE(review): normalize_each / normalize_ are not defined in the
          part of the notebook visible here -- confirm they exist before use.

    Returns:
        - torch.Tensor: Generated data.
    """
    aux_input = batch[:, 1:].to(device)
    with torch.no_grad():
        generated = generator(aux_input.float())
        if normalize=='Each':
            print(generated.shape)
            generated = normalize_each(generated)
        elif normalize=='Column':
            print(generated.shape)
            generated = normalize_(generated, 1)
     
    return generated


In [None]:
def train_decoder(num_epochs, generator, criterion1, optimizer, dataloader, val_loader, accuracy, checkpoint_path,
                  save_best=True, scheduler=None, switch_threshold=0.01, patience=3):
    """
    Trains the generator model with a single loss function and validates after every epoch.

    Each batch is split into the target (channel 0) and the model input (channels 1:).

    Args:
        num_epochs: (int) Number of epochs for training.
        generator: (NN.Module) NN model to train.
        criterion1: (callable) Loss, called as criterion1(generated, target).
        optimizer: (torch.optim) Optimizer for training.
        dataloader: (DataLoader) Training data loader.
        val_loader: (DataLoader) Validation data loader.
        accuracy: (function) Metric called as accuracy(generated, target); its result
            must be convertible with float() (0-dim tensor or Python number).
        checkpoint_path: (str) Format string with one '{}' placeholder; the best model
            is written to checkpoint_path.format('best').
        save_best: (bool) Whether to save the best performing model.
        scheduler: (torch.optim.lr_scheduler) Learning rate scheduler; it is stepped once
            per epoch with the mean validation loss (ReduceLROnPlateau-style API).
        switch_threshold: (float) Currently unused; kept so existing calls keep working.
        patience: (int) Currently unused; kept so existing calls keep working.

    Returns:
        loss_plot, val_loss_plot: Lists with the mean training / validation loss per epoch.
    """

    # Initialize tracking metrics
    loss_plot = []
    val_loss_plot = []
    best_val_loss = float('inf')

    if scheduler is not None:
        print(f'{scheduler=}')

    for epoch in tqdm(range(1, num_epochs + 1)):
        generator.train()  # Set model to training mode
        epoch_loss = []

        for batch in dataloader:
            torch.cuda.empty_cache()
            target = batch[:, 0].unsqueeze(1).to(device).float()
            input = batch[:, 1:].to(device)

            optimizer.zero_grad()  # Zero the gradients
            generated = generator(input.float())  # Forward pass

            total_loss = criterion1(generated, target)
            total_loss.backward()  # Backpropagation

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=5.0)

            optimizer.step()  # Update model parameters

            epoch_loss.append(total_loss.detach().cpu().numpy())

        # Validation phase
        generator.eval()  # Set model to evaluation mode
        val_loss = []
        with torch.no_grad():
            for batch in val_loader:
                torch.cuda.empty_cache()
                target = batch[:, 0].unsqueeze(1).to(device).float()
                input = batch[:, 1:].to(device)
                generated = generator(input.float())

                total_val_loss = criterion1(generated, target)
                val_loss.append(total_val_loss.detach().cpu().numpy())

        # Record training and validation loss
        loss_plot.append(np.mean(epoch_loss))
        val_loss_plot.append(np.mean(val_loss))

        # Adjust learning rate using scheduler
        if scheduler is not None:
            scheduler.step(val_loss_plot[-1])

        # Print progress
        print(f'Epoch {epoch}: training loss {loss_plot[-1]:.4e}, val loss {val_loss_plot[-1]:.4e}')

        # Relative change of the validation loss with respect to the previous epoch
        if epoch > 1:
            improvement = (val_loss_plot[-2] - val_loss_plot[-1]) / val_loss_plot[-2]
            print(f'Improvement: {improvement*100:.4f}%')

        # Save checkpoint if validation loss improves
        if save_best and val_loss_plot[-1] < best_val_loss:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_plot[-1],
                'val_loss': val_loss_plot[-1],
            }
            best_val_loss = val_loss_plot[-1]
            torch.save(checkpoint, checkpoint_path.format('best'))

        # Evaluate accuracy every 5 epochs (model is still in eval mode here)
        if epoch % 5 == 0:
            total_accuracy = 0.0
            # no_grad: this is a pure evaluation pass, so no autograd graph is needed
            with torch.no_grad():
                for batch in val_loader:
                    target = batch[:, 0].unsqueeze(1).to(device).float()
                    input = batch[:, 1:].to(device)
                    generated = generator(input.float())
                    # float() accepts both 0-dim tensors and plain Python numbers
                    total_accuracy += float(accuracy(generated, target))
                    torch.cuda.empty_cache()
            avg_accuracy = total_accuracy / len(val_loader)
            print(f'Epoch {epoch}: Validation accuracy: {avg_accuracy:.4f}')

    return loss_plot, val_loss_plot


### Pre-training generation

In [None]:
for batch in(tqdm(test_dataloader)):
    generated=generate_data(generator_2d,batch,normalize=False)
    break
generated[1,0].shape
#batch=transform(batch)

In [None]:
generated.shape

In [None]:
qplt_g=generated[1,0].detach().cpu().numpy()
qplt_r=batch[1,0].detach().cpu().numpy()

In [None]:
print(qplt_g.shape)
qplt_r.shape

In [None]:
plt.figure(figsize=(6, 6))
plt.imshow(qplt_g, aspect='auto',vmin=0,vmax=1)
plt.title('Generated - pre training')
plt.xlabel('Time [pixel]')
plt.ylabel('Frequency [pixel]')
plt.colorbar()
plt.show()

**Pre training loss**

In [None]:
class MeanAbsDiff(nn.Module):
    """Metric: mean of the element-wise absolute difference |y_pred - y_true|."""

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        abs_err = (y_pred - y_true).abs()
        return abs_err.mean()


class StdAbsDiff(nn.Module):
    """Metric: standard deviation of the element-wise absolute difference |y_pred - y_true|."""

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        abs_err = (y_pred - y_true).abs()
        return abs_err.std()


# Shared metric instances used by the evaluation cells
metric_mean = MeanAbsDiff()
metric_std = StdAbsDiff()

In [None]:
def calculate_single_loss(generator, criterion, val_loader):
    """
    Evaluate ``criterion`` on every batch of ``val_loader`` without tracking gradients.

    Each batch is split into the target (channel 0) and the model input
    (channels 1:); both are moved to the global ``device`` and cast to float,
    matching what ``train_decoder`` feeds the model.

    Args:
        generator: model to evaluate; it is switched to eval mode.
        criterion: callable invoked as ``criterion(generated, target)``. It may
            return a 0-dim tensor or a plain Python number.
        val_loader: iterable of batches.

    Returns:
        (mean_total_loss, val_total_loss): the mean over batches and the list of
        per-batch values as Python floats.
    """
    generator.eval()  # Set the model to evaluation mode

    val_total_loss = []  # Per-batch values

    with torch.no_grad():
        for batch in val_loader:
            torch.cuda.empty_cache()
            target = batch[:, 0].unsqueeze(1).to(device).float()
            input_ = batch[:, 1:].to(device).float()
            generated = generator(input_)

            # float() works for tensors and for metrics that already return a number
            # (``.item()`` would raise AttributeError on a plain float/int)
            val_total_loss.append(float(criterion(generated, target)))

    mean_total_loss = np.mean(val_total_loss)

    return mean_total_loss, val_total_loss


In [None]:
from torchmetrics.image import StructuralSimilarityIndexMeasure
class CustomLoss(nn.Module):
    """
    Weighted sum of an L1 term and an SSIM term:
    ``alpha * L1(pred, target) + (1 - alpha) * (1 - SSIM(pred, target))``.

    Args:
        alpha: weight of the L1 term (the SSIM term gets ``1 - alpha``).
        data_range: forwarded to ``StructuralSimilarityIndexMeasure``.
    """

    def __init__(self, alpha=0.8, data_range=21.0):
        super().__init__()
        l1_term = nn.L1Loss()
        # The SSIM metric is moved to the notebook-wide ``device``
        ssim_metric = StructuralSimilarityIndexMeasure(data_range=data_range).to(device)
        self.l1_loss = l1_term
        self.ssim = ssim_metric
        self.alpha = alpha

    def forward(self, pred, target):
        l1_value = self.l1_loss(pred, target)
        dissimilarity = 1 - self.ssim(pred, target)
        return self.alpha * l1_value + (1 - self.alpha) * dissimilarity
loss=CustomLoss()

In [None]:
mean_total_loss,val_total_loss=calculate_single_loss(generator_2d,loss,test_dataloader)

In [None]:
mean_total_loss

In [None]:
mean_total_metric,val_total_metric=calculate_single_loss(generator_2d,metric_mean,test_dataloader)

In [None]:
mean_total_metric

In [None]:
mean_total_iou,val_total_iou=calculate_single_loss(generator_2d,calculate_iou_2d_non0,test_dataloader)

In [None]:
mean_total_iou

In [None]:
plt.figure(figsize=(6, 6))
plt.imshow(np.flipud(qplt_r), aspect='auto',vmin=0,vmax=0.5)
plt.title('Real')
# Columns (x) are time and rows (y) are frequency, as in the 'Generated - pre training' plot
plt.xlabel('Time [pixel]')
plt.ylabel('Frequency [pixel]')
plt.colorbar()
plt.show()

### Actual training

In [None]:
# learning rate, and optimiser


lr=1.0e-4

#lr=0.001

momentum=0.9

#G_optimizer = torch.optim.Adam(generator_2d.parameters(), lr=lr, weight_decay=1e-5 )
G_optimizer = torch.optim.AdamW(generator_2d.parameters(), 
                            lr=lr, 
                            weight_decay=1e-4,  # Critical for generalization
                            betas=(0.9, 0.999))
#G_optimizer = torch.optim.AdamW(generator_2d.parameters(), lr=lr )

#G_optimizer = torch.optim.AdamW(generator_2d.parameters(), lr=lr, weight_decay=1e-5)

#G_optimizer = torch.optim.SGD(generator_2d.parameters(), lr=lr, momentum=momentum)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    G_optimizer, 
    mode='min', 
    patience=7,  # Increased from 5
    factor=0.5,
    min_lr=1e-7,
    verbose=True
)



In [None]:
save_name='Unet_with_residualblocks_64x64_l1_SSIM_norm_maxmean_no_whiten_8-500Hz_logf_9channels'
save_checkpoint='/home/jovyan/Resnet_Qtransform/'+save_name+'.checkpoint_epoch_{}.pth'
n_epochs=100

In [None]:
train_loss_plot, val_loss_plot=train_decoder(n_epochs,generator_2d,loss,G_optimizer,dataloader,test_dataloader,calculate_iou_2d_non0,save_checkpoint,scheduler=scheduler)

In [None]:
# Plot the per-epoch training and validation loss returned by train_decoder
import matplotlib.pyplot as plt
plt.plot(train_loss_plot,color='b',label='train')
plt.plot(val_loss_plot,color='r',label='validation')
# The criterion passed to train_decoder is CustomLoss (weighted L1 + SSIM), not a pure L1 loss
plt.title('L1 + SSIM loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig(f'{save_name}_Loss.pdf')
#plt.yscale('log')
plt.show()

# Inference

In this section we make inference on test dataset using trained NN, and we plot the generated qplots.

This section is divided in two parts:
- 1) **Load Model**, where we load the model from checkpoint
- 2) **Actual Inference**, where we generate data for main channel from the test dataset. We also plot the generated data and compare it to the target

#### Load Model

In [None]:
#load model
#load_path='/home/jovyan/ResNet_SNR_above_20-30-30.checkpoint_epoch_best.pth'
#load_path='/home/jovyan/ResNet_SNR_above_15.checkpoint_epoch_best.pth' 
#load_path='/home/jovyan/ResNet_15_channels.checkpoint_epoch_best.pth'

#save_name='Unet_with_residualblocks_64x64_l1_SSIM_norm_maxmean_weight_decay1e-4_adamW_batch200_lr1e-4_no_whiten_8-500Hz_logf'
#save_checkpoint='/home/jovyan/Resnet_Qtransform/'+save_name+'.checkpoint_epoch_{}.pth'
load_path=save_checkpoint.format('best')

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(load_path, map_location=device)
generator_2d.load_state_dict(checkpoint['model_state_dict'])
# Put dropout / batch-norm layers in inference mode before generating data
generator_2d.eval();

In [None]:
print(checkpoint.keys())

#### Actual Inference

In [None]:
#make inference on test data

for batch in(tqdm(test_dataloader)):
    generated_post=generate_data(generator_2d,batch,normalize=False)
    break
generated_post[0,0].shape

In [None]:
qplt_g=generated_post[0,0].detach().cpu().numpy()
qplt_r=batch[0,0].detach().cpu().numpy()

In [None]:
metric_mean=MeanAbsDiff()
metric_std=StdAbsDiff()

In [None]:
mean_total_loss_train,train_total_loss=calculate_single_loss(generator_2d,metric_mean,test_dataloader)

In [None]:
mean_total_loss_train

Plot Real, Generated and input Qplots

In [None]:
batch.shape
plt.figure(figsize= (6, 6))
plt.imshow(batch[5,0], aspect='auto',vmin=0,vmax=1)
plt.title('Real')
plt.xlabel('Time')
plt.ylabel('Frequency')
plt.colorbar()
plt.show()


In [None]:
for batch in(tqdm(test_dataloader)):
    print(batch.shape)
    break
    

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np

v_max = 25

def plot_images(generated_post, batch, channel_means, num_aux_channels=8,num_images=10):
    """
    Plot real, generated and absolute-difference panels for the first images of a
    batch, followed by optional auxiliary-channel panels (four per row).

    Every image is multiplied by its channel's entry in ``channel_means`` and
    flipped vertically before plotting. The upper colour limit is read from the
    module-level ``v_max``.

    Args:
        generated_post: model output, indexed as ``generated_post[i, 0]``.
        batch: input batch; channel 0 is the real image, channels 1.. are the
            auxiliary channels.
        channel_means: per-channel scale factors, indexed by channel number.
        num_aux_channels: number of auxiliary channels to draw (0 draws none).
        num_images: number of batch items to plot (one figure each).
    """
    n_cols = 4
    num_rows_aux = (num_aux_channels + n_cols - 1) // n_cols
    n_rows = 1 + num_rows_aux

    for i in range(num_images):
        print('---------------------------')
        print(f'IMAGE {i}')

        qplt_g = torch.flipud(generated_post[i, 0].detach().cpu() * channel_means[0])
        qplt_r = torch.flipud(batch[i, 0].detach().cpu() * channel_means[0])

        # imshow draws rows along y and columns along x, and the panels label
        # x as Time and y as Frequency: rows -> frequency, columns -> time.
        freq_extent = generated_post[i, 0].shape[0]
        time_extent = generated_post[i, 0].shape[1]
        extent = [0, time_extent, 0, freq_extent]

        # squeeze=False keeps ``axes`` 2-D even when there is a single row
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)

        im_r = axes[0, 0].imshow(qplt_r, aspect='auto', extent=extent, vmin=0, vmax=v_max)
        axes[0, 0].set_title('Real')
        axes[0, 0].set_xlabel('Time')
        axes[0, 0].set_ylabel('Frequency')
        fig.colorbar(im_r, ax=axes[0, 0])

        im_g = axes[0, 1].imshow(qplt_g, aspect='auto', extent=extent, vmin=0, vmax=v_max)
        axes[0, 1].set_title('Generated')
        axes[0, 1].set_xlabel('Time')
        axes[0, 1].set_ylabel('Frequency')
        fig.colorbar(im_g, ax=axes[0, 1])

        im_diff = axes[0, 2].imshow(torch.abs(qplt_g - qplt_r), aspect='auto', extent=extent, vmin=0, vmax=v_max)
        axes[0, 2].set_title('True - Generated')
        axes[0, 2].set_xlabel('Time')
        axes[0, 2].set_ylabel('Frequency')
        fig.colorbar(im_diff, ax=axes[0, 2])

        axes[0, 3].axis('off')

        # Auxiliary channels fill the rows below the first one, left to right
        for j in range(num_aux_channels):
            aux_channel_index = j + 1
            row, col = divmod(j, n_cols)
            ax = axes[1 + row, col]

            qplt_aux = torch.flipud(batch[i, aux_channel_index].detach().cpu() * channel_means[aux_channel_index])
            im_aux = ax.imshow(qplt_aux, aspect='auto', extent=extent, vmin=0, vmax=v_max)
            ax.set_title(f'aux{aux_channel_index}')
            ax.set_xlabel('Time')
            ax.set_ylabel('Frequency')
            fig.colorbar(im_aux, ax=ax)

        # Hide only the cells left empty in the last auxiliary row; panels that
        # were drawn in a partially filled row keep their axes.
        for k in range(num_aux_channels, num_rows_aux * n_cols):
            row, col = divmod(k, n_cols)
            axes[1 + row, col].axis('off')

        plt.tight_layout()
        #plt.savefig(f'Inference.pdf')
        plt.show()

In [None]:
# Plot only the real / generated / difference panels: num_aux_channels=0 draws no auxiliary channels
plot_images(generated_post, batch, channel_means, num_aux_channels=0,num_images=10)

# Accuracy performance

#### Generate data using NN model

test set

In [None]:
# Run the generator over the whole test set and keep the outputs on the CPU.
# Batches are collected in a list and concatenated once: calling torch.cat inside
# the loop re-copies everything accumulated so far on every iteration.
generated_batches = []
for batch in tqdm(test_dataloader):
    generated_post = generate_data(generator_2d, batch.detach().cpu(), normalize=False).to('cpu')
    generated_batches.append(generated_post)
# An empty loader yields an empty tensor, as before
generated_tensor_pre = torch.cat(generated_batches, dim=0) if generated_batches else torch.tensor([])

background set

In [None]:
# norm_factor is needed here, but the notebook only defines it further down
# ("Define normalisation factors"); define it first so Restart & Run All works.
norm_factor = torch.tensor(channel_means[0]).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)

# Run the generator over the background set (scaled by norm_factor) and keep the
# outputs on the CPU; concatenate once instead of growing a tensor in the loop.
background_batches = []
for batch in tqdm(test_background_dataloader):
    background_post = generate_data(generator_2d, batch.detach().cpu()/norm_factor, normalize=False).to('cpu')
    background_batches.append(background_post)
# An empty loader yields an empty tensor, as before
background_tensor = torch.cat(background_batches, dim=0) if background_batches else torch.tensor([])

In [None]:
background_tensor.shape[0]/generated_tensor_pre.shape[0]

In [None]:
generated_tensor=torch.cat((generated_tensor_pre,background_tensor), dim=0)
generated_tensor.shape

Labels

In [None]:
labels=torch.cat((torch.ones(generated_tensor_pre.shape[0]),torch.zeros(background_tensor.shape[0])))
labels.shape

In [None]:
target_tensor=torch.cat((test_data_2d_norm[:,0,:,:].unsqueeze(1),background[:,0,:,:].unsqueeze(1)/norm_factor))
target_tensor.shape

### Define clustering NN for accuracy check

In [None]:
from skimage.measure import label

class ClusterAboveThreshold(nn.Module):
    """
    Find connected clusters of above-threshold pixels in every image of a batch.

    Args:
        threshold: pixels with value >= threshold belong to the mask.
        min_cluster_area: a cluster is reported only when its pixel count is
            strictly greater than this value.
    """

    def __init__(self, threshold, min_cluster_area):
        super().__init__()
        self.threshold = threshold
        self.min_cluster_area = min_cluster_area

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: batch of single-channel images; the channel axis
                (dim 1) is squeezed, so the expected layout is [batch, 1, H, W].

        Returns:
            A list with one entry per batch item. Each entry is a list of
            ``(max_value, row, col)`` tuples, one per retained cluster, where
            ``max_value`` is the largest pixel value inside the cluster and
            ``(row, col)`` its position. Items without clusters get ``[]``.
        """
        # Boolean mask of the pixels at or above the threshold, shape [batch, H, W]
        mask = (input_tensor.squeeze(1) >= self.threshold).cpu().numpy()

        batch_clusters = []
        for idx, item in enumerate(input_tensor):
            # Label every image on its own. Labelling the whole [batch, H, W]
            # array in one call treats the batch axis as a spatial axis, so
            # clusters of neighbouring batch items would be merged.
            labeled, num_clusters = label(mask[idx], connectivity=2, return_num=True)
            labeled = torch.as_tensor(labeled, dtype=torch.long, device=input_tensor.device)

            item_clusters = []
            for cluster_id in range(1, num_clusters + 1):  # 0 is the background
                # Add the channel axis back so the mask matches ``item`` ([1, H, W])
                cluster_pixels = (labeled == cluster_id).unsqueeze(0)

                # Check if the cluster area is greater than the threshold area
                if torch.sum(cluster_pixels) > self.min_cluster_area:
                    # Zero everything outside the cluster, then locate the peak
                    masked_item = item.masked_fill(~cluster_pixels, 0)
                    max_value, max_index_flat = torch.max(masked_item.flatten(), dim=0)

                    # Unravel the flattened index to get (channel, row, col)
                    max_index_unraveled = np.unravel_index(max_index_flat.item(), item.shape)
                    item_clusters.append((max_value, max_index_unraveled[1], max_index_unraveled[2]))

            batch_clusters.append(item_clusters)

        return batch_clusters

    def sort_labels(self, labeled_masks):
        """
        Renumber the labels of a labelled batch so that every item uses its own
        consecutive label range (background 0 is left untouched).

        ``forward`` no longer needs this helper; it is kept so that existing
        callers continue to work. ``labeled_masks`` is modified in place and
        also returned.
        """
        offset = 0
        for i in range(labeled_masks.shape[0]):  # Iterate over batch dimension
            item_labeled_masks = labeled_masks[i]
            unique_labels = torch.unique(item_labeled_masks)

            # Exclude background class label (label 0)
            unique_labels = unique_labels[unique_labels != 0]

            # Rename the labels with an offset, starting from 1
            renamed_labels = item_labeled_masks.clone()
            for j, old_label in enumerate(unique_labels, start=1):
                renamed_labels[item_labeled_masks == old_label] = j + offset

            # Update the offset for the next item
            offset += len(unique_labels)

            # Update the labeled mask for the current item
            labeled_masks[i] = renamed_labels

        return labeled_masks


**Define normalisation factors**

In [None]:
norm_factor=torch.tensor(channel_means[0]).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
norm_factor.shape

**Define Tensors for accuracy check**

In [None]:
#normalised data
abs_difference_tensor=torch.abs((generated_tensor-target_tensor)*norm_factor)
abs_difference_tensor.shape

### Calculate model accuracy

Define classifiers

In [None]:
def fraction_empty_lists(list_of_lists):
    """Return the fraction of sublists that are empty, computed as 1 - (non-empty / total)."""
    filled = len([entry for entry in list_of_lists if entry])
    return 1 - (filled / len(list_of_lists))

In [None]:
def glitch_classifier(list_of_lists):
    """Map every sublist to 1 when it is non-empty and to 0 when it is empty."""
    flags = []
    for sublist in list_of_lists:
        flags.append(1 if sublist else 0)
    return flags

In [None]:
def classifier_accuracy(predictions,labels):
    """Return 1 minus the mean of (prediction + label) % 2, i.e. the fraction of matching 0/1 pairs."""
    parity = [(pred + truth) % 2 for pred, truth in zip(predictions, labels)]
    return 1 - np.mean(np.array(parity))

In [None]:
def confusion_matrix(predictions,labels):
    """
    Count true/false positives/negatives for binary (0/1) predictions.

    Args:
        predictions: iterable of predicted classes (0 or 1).
        labels: iterable of true classes (0 or 1), paired with ``predictions``.

    Returns:
        dict with the keys 'TP', 'FP', 'FN', 'TN' (always present, 0 when a case
        never occurs). Pairs containing a value other than 0 or 1 are ignored.
    """
    cm = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0}
    for x, y in zip(predictions, labels):
        if x == 1:
            # predicted positive
            if y == 1:
                cm['TP'] += 1
            elif y == 0:
                cm['FP'] += 1
        elif x == 0:
            # predicted negative
            if y == 1:
                cm['FN'] += 1
            elif y == 0:
                cm['TN'] += 1
    return cm

In [None]:
from tqdm import tqdm
def roc_curve(generated,labels,threshold_set=(10,20.1,0.1),min_cluster_area=10):
    """
    Compute a confusion matrix for each clustering threshold.

    Args:
        generated: tensor passed to ``ClusterAboveThreshold`` for every threshold.
        labels: true 0/1 class per item of ``generated``.
        threshold_set: (start, stop, step) forwarded to ``np.arange``.
        min_cluster_area: minimum cluster area forwarded to ``ClusterAboveThreshold``.

    Returns:
        dict mapping each threshold (the ``np.arange`` value) to the dict
        returned by ``confusion_matrix``.
    """
    roc_dict = {}
    thresholds = np.arange(threshold_set[0], threshold_set[1], threshold_set[2])
    for threshold in tqdm(thresholds):
        # A fresh clustering module per threshold; an item is classified as a
        # glitch when at least one cluster survives.
        cluster_nn = ClusterAboveThreshold(threshold, min_cluster_area).to('cpu')
        predictions = glitch_classifier(cluster_nn(generated))
        roc_dict[threshold] = confusion_matrix(predictions, labels)
    return roc_dict
        

In [None]:
import numpy as np

# Assuming ClusterAboveThreshold, abs_difference_tensor, generated_tensor, target_tensor, and norm_factor are defined elsewhere

def analyze_clusters_for_thresholds(abs_difference_tensor, generated_tensor, target_tensor, norm_factor, min_cluster_area=1,
                                    thresholds=range(1, 51)):
    """
    Analyzes cluster data for a range of threshold values.

    For every threshold, clusters are searched in the absolute-difference images
    and in the generated / target images (both multiplied by ``norm_factor``).

    Args:
        abs_difference_tensor: Tensor of absolute differences.
        generated_tensor: Tensor of generated data.
        target_tensor: Tensor of target data.
        norm_factor: Normalization factor.
        min_cluster_area: Minimum cluster area.
        thresholds: Iterable of threshold values to scan (default 1..50).

    Returns:
        A tuple containing two lists (one entry per threshold):
            - cluster_abs_diff_accuracies: fraction of absolute-difference images
              in which no cluster is found.
            - clusters_generated_accuracies: agreement between "cluster found in
              the generated image" and "cluster found in the target image".
    """

    cluster_abs_diff_accuracies = []
    clusters_generated_accuracies = []

    # The rescaled tensors do not depend on the threshold: compute them once
    generated_scaled = generated_tensor * norm_factor
    target_scaled = target_tensor * norm_factor

    for threshold in tqdm(thresholds):
        cluster_nn = ClusterAboveThreshold(threshold, min_cluster_area).to('cpu')

        # get clusters
        clusters_abs_diff = cluster_nn(abs_difference_tensor)
        clusters_generated = cluster_nn(generated_scaled)
        clusters_target = cluster_nn(target_scaled)

        # Target clusters act as the labels for the generated images; the
        # difference images are compared against "no cluster anywhere".
        target_labels = glitch_classifier(clusters_target)
        diff_labels = [0] * len(target_labels)

        # Calculate classifier accuracy for abs_difference_tensor
        abs_diff_predictions = glitch_classifier(clusters_abs_diff)
        cluster_abs_diff_accuracies.append(classifier_accuracy(abs_diff_predictions, diff_labels))

        # Calculate classifier accuracy for generated_tensor
        generated_predictions = glitch_classifier(clusters_generated)
        clusters_generated_accuracies.append(classifier_accuracy(generated_predictions, target_labels))

    return cluster_abs_diff_accuracies, clusters_generated_accuracies


# Example usage (replace with your actual data and ClusterAboveThreshold definition):
# abs_difference_tensor = ...  # Your tensor data
# generated_tensor = ...  # Your tensor data
# target_tensor = ...  # Your tensor data
# norm_factor = ...  # Your normalization factor

cluster_abs_diff_accuracies, clusters_generated_accuracies = analyze_clusters_for_thresholds(abs_difference_tensor, generated_tensor, target_tensor, norm_factor
)

# print("Cluster Abs Diff Accuracies:", cluster_abs_diff_accuracies)
# print("Clusters Generated Accuracies:", clusters_generated_accuracies)

In [None]:
# Create the plot
thresholds = range(1, 51)  # The SNR^2 thresholds

plt.figure(figsize=(10, 6))  # Adjust figure size for better visualization

plt.plot(thresholds, cluster_abs_diff_accuracies, label="Denoising Accuracy", marker='o', linestyle='-')
plt.plot(thresholds[5:], clusters_generated_accuracies[5:], label="Vetoing Accuracy", marker='x', linestyle='--')
#plt.plot(thresholds, cluster_abs_diff_accuracies_veto, label="Vetoing Accuracy for veto correctly flagged data", marker='p', linestyle='--')

plt.xlabel(r"$\mathrm{SNR^2}$ Threshold", fontsize=20)
plt.ylabel("Accuracy", fontsize=20)
# Raw string: "\m" is not a valid escape sequence in a normal string literal
plt.title(r"Accuracy vs. $\mathrm{SNR^2}$", fontsize=22)
plt.xticks(np.arange(min(thresholds), max(thresholds)+1, 5.0), fontsize=16) # set ticks every 5
plt.yticks(fontsize=16)
plt.grid(True)  # Add a grid for better readability
plt.legend(fontsize=20)  # Show the legend
plt.tight_layout() # Adjust layout to prevent labels from overlapping
plt.show()

**Plot successfully cleaned data**

In [None]:
def indices_of_empty_sublists(list_of_sublists):
    """
    Split the positions of ``list_of_sublists`` by emptiness.

    Returns:
        (empty_positions, filled_positions): indices of the empty sublists and
        indices of the non-empty sublists, each in ascending order.
    """
    empty_positions, filled_positions = [], []
    for position, sublist in enumerate(list_of_sublists):
        bucket = filled_positions if sublist else empty_positions
        bucket.append(position)
    return empty_positions, filled_positions

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# ... (your existing code for indices_of_empty_sublists, data loading, etc.)

# NOTE(review): in the visible part of the notebook `clusters_abs_diff` is only assigned
# as a local variable inside analyze_clusters_for_thresholds, so on a fresh
# Restart & Run All this line would raise NameError unless another cell defines it.
# TODO confirm which threshold / min_cluster_area the clusters should be computed with.
empty_idx, non_empty_idx = indices_of_empty_sublists(clusters_abs_diff)
j = 0  # number of figures drawn so far by the plotting loop below

# Parameters (same as before)
v_max = 25  # upper limit of the colour scale
t_min = 0
t_max = train_data_2d.shape[-1]  # size of the last axis (treated as time)
f_min = 0
f_max = train_data_2d.shape[-2]  # size of the second-to-last axis (treated as frequency)

# Frequency settings (same as before)
f_range = (8, 500)  # (lowest, highest) frequency of the image rows
desired_ticks = [8, 20, 30, 50, 100, 200, 500]  # frequencies to label on the y-axis
log_base = 10  # Or np.e for natural log

def set_frequency_ticks(ax, f_range, desired_ticks, log_base, new_height):
    """
    Sets the y-axis (frequency) ticks and labels of a log-frequency image.

    The image rows are taken to be log-spaced from ``f_range[0]`` (row 0) to
    ``f_range[1]`` (row ``new_height - 1``); after the y-axis inversion row 0 is
    drawn at the bottom. Every tick is placed on the row nearest to its
    frequency; ticks that fall on the same row are merged (the first one wins).

    Args:
        ax: matplotlib Axes holding the image.
        f_range: (lowest, highest) frequency covered by the image rows.
        desired_ticks: frequencies to label.
        log_base: base of the logarithm used for the frequency axis.
        new_height: number of image rows.
    """
    log_f_range = (np.log(f_range[0]) / np.log(log_base), np.log(f_range[1]) / np.log(log_base))
    log_desired_ticks = np.log(desired_ticks) / np.log(log_base)

    # Map each frequency directly to its own row (f_range[0] -> row 0). Mapping the
    # axis reversed and flipping the labels afterwards only agrees with this for
    # tick sets that are symmetric in log-frequency.
    y_ticks_pixel = np.interp(log_desired_ticks, log_f_range, [0, new_height - 1])

    # Nearest row, kept inside the image
    y_ticks_pixel = np.clip(np.rint(y_ticks_pixel).astype(int), 0, new_height - 1)

    # Drop ticks that collapse onto an already used row, keeping labels aligned
    y_ticks_pixel, unique_indices = np.unique(y_ticks_pixel, return_index=True)
    desired_ticks_used = np.array(desired_ticks)[unique_indices].tolist()

    ax.grid(True, axis='y', which='both')
    ax.set_yticks(y_ticks_pixel)
    ax.set_yticklabels(desired_ticks_used, fontsize=16)
    ax.invert_yaxis() # Important: Invert y-axis for spectrograms


# One figure (target / generated / |difference|) per item without residual clusters,
# stopping once 30 figures have been drawn (j counts the figures).
panel_titles = ('Target', 'Generated', 'Cleaned')
for i in empty_idx:
    if j == 30:
        break

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    panels = (
        (target_tensor * norm_factor)[i].squeeze(0),
        (generated_tensor * norm_factor)[i].squeeze(0),
        (abs_difference_tensor)[i].squeeze(0),
    )
    images = []
    for ax, panel, panel_title in zip(axes, panels, panel_titles):
        images.append(ax.imshow(panel, cmap='viridis', vmin=0, vmax=v_max, aspect='auto'))
        ax.set_title(panel_title,fontsize=22)
    im0, im1, im2 = images

    # Same axis decoration on all three panels
    for ax in axes:
        set_frequency_ticks(ax, f_range, desired_ticks, log_base, target_tensor.shape[-2])
        ax.set_xticks([0, 31, 63])
        ax.set_xticklabels([0, 0.5, 1],fontsize=16)
        ax.set_xlabel("Time (s)",fontsize=20)
        ax.set_ylabel("Frequency (Hz)",fontsize=20)

    # A single colorbar, attached to the last panel
    fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.03)

    plt.tight_layout()
    plt.show()
    j += 1

#### IOU

In [None]:
import torch

def calculate_iou_2d_non0(generated_tensor, target_tensor, threshold=20):
    """
    Mean Intersection over Union (IoU) of the thresholded generated and target images.

    Both tensors are turned into binary masks (value >= threshold). The IoU is
    computed per batch element by summing over dims 1, 2 and 3; elements whose
    union is empty (no pixel above threshold in either mask) are excluded from
    the mean.

    Parameters:
    - generated_tensor: Tensor containing generated spectrograms (batch_size x 1 x height x width)
    - target_tensor: Tensor containing target spectrograms (batch_size x 1 x height x width)
    - threshold: Intensity threshold for determining the binary masks

    Returns:
    - mean_iou (float): Mean IoU over the batch elements with a non-empty union,
      or 0.0 when every element has an empty union.
    """
    # Binary masks based on the intensity threshold
    gen_mask = (generated_tensor >= threshold).float()
    tgt_mask = (target_tensor >= threshold).float()

    # Intersection and union for each element in the batch
    intersection = torch.sum(gen_mask * tgt_mask, dim=(1, 2, 3))
    union = torch.sum(gen_mask, dim=(1, 2, 3)) + torch.sum(tgt_mask, dim=(1, 2, 3)) - intersection

    # Only elements with a non-empty union have a defined IoU
    valid = union > 0
    non_zero_count = int(valid.sum().item())
    if non_zero_count == 0:
        # Return a plain float here as well, so callers always get the same type
        return 0.0

    iou = intersection[valid] / union[valid]
    return (torch.sum(iou) / non_zero_count).item()


In [None]:
calculate_iou_2d_non0(generated_tensor*norm_factor, target_tensor*norm_factor, 15)