In [None]:
# `display` and `HTML` belong to the public IPython.display API; importing them from
# IPython.core.display has been deprecated since IPython 7.14.
from IPython.display import display, HTML

# Inject a CSS rule that makes the notebook's `.container` element span 100% of the page width.
display(HTML("<style>.container { width:100% !important; }</style>"))


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class KANLinear(nn.Module):
    """Wavelet-based KAN-style layer with a linear skip path and batch normalisation.

    For every (output, input) pair the input value is shifted by ``translation``,
    divided by ``scale``, passed through the function selected by ``wavelet_type``,
    multiplied by ``wavelet_weights`` and summed over the input features.  That sum
    is added to a bias-free linear projection of the input (``weight1``) and the
    result is passed through ``BatchNorm1d``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Number of output features.
        wavelet_type: One of 'haar', 'coiflet', 'biorthogonal', 'daubechies',
            'mexican_hat', 'morlet', 'dog', 'meyer', 'symlet4' or 'shannon'.
            The name is only checked when the wavelet is evaluated.
    """

    def __init__(self, in_features, out_features, wavelet_type='haar'):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.wavelet_type = wavelet_type

        # Parameters for wavelet transformation (one value per output/input pair)
        self.scale = nn.Parameter(torch.ones(out_features, in_features))
        self.translation = nn.Parameter(torch.zeros(out_features, in_features))
        # Allocated uninitialised; filled by the Kaiming initialisers just below.
        self.wavelet_weights = nn.Parameter(torch.empty(out_features, in_features))
        self.weight1 = nn.Parameter(torch.empty(out_features, in_features))

        nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))

        # Batch normalization
        self.bn = nn.BatchNorm1d(out_features)

    def wavelet_transform(self, x):
        """Evaluate the configured wavelet on the shifted/scaled input and reduce over inputs.

        Args:
            x: 2-D tensor ``(batch, in_features)``.  A tensor with any other rank is
                used unchanged and must broadcast against ``(out_features, in_features)``.

        Returns:
            The weighted wavelet responses summed over dimension 2, i.e.
            ``(batch, out_features)`` for a 2-D input.

        Raises:
            ValueError: If ``self.wavelet_type`` is not a supported name.
        """
        x_expanded = x.unsqueeze(1) if x.dim() == 2 else x
        # (1, out, in) parameters broadcast against the (batch, 1, in) input.
        # No epsilon is added: a scale entry that reaches 0 yields inf/nan.
        x_scaled = (x_expanded - self.translation.unsqueeze(0)) / self.scale.unsqueeze(0)

        if self.wavelet_type == 'haar':
            wavelet = self.haar_wavelet(x_scaled)
        elif self.wavelet_type == 'coiflet':
            wavelet = self.coiflet_wavelet(x_scaled)
        elif self.wavelet_type == 'biorthogonal':
            wavelet = self.biorthogonal_wavelet(x_scaled)
        elif self.wavelet_type == 'daubechies':
            wavelet = self.daubechies_wavelet(x_scaled)
        elif self.wavelet_type == 'mexican_hat':
            term1 = ((x_scaled ** 2) - 1)
            term2 = torch.exp(-0.5 * x_scaled ** 2)
            wavelet = (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2
        elif self.wavelet_type == 'morlet':
            omega0 = 5.0  # Central frequency
            real = torch.cos(omega0 * x_scaled)
            envelope = torch.exp(-0.5 * x_scaled ** 2)
            wavelet = envelope * real
        elif self.wavelet_type == 'dog':
            # -x * exp(-x^2 / 2)
            wavelet = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)
        elif self.wavelet_type == 'meyer':
            v = torch.abs(x_scaled)
            wavelet = torch.sin(math.pi * v) * self.meyer_aux(v)
        elif self.wavelet_type == 'symlet4':
            wavelet = self.symlet4_wavelet(x_scaled)
        elif self.wavelet_type == 'shannon':
            # torch.sinc is the normalised sinc, so sinc(x / pi) == sin(x) / x.
            sinc = torch.sinc(x_scaled / math.pi)
            # The window length equals the size of the last dimension of x_scaled.
            window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype, device=x_scaled.device)
            wavelet = sinc * window
        else:
            raise ValueError("Unsupported wavelet type")

        wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0)
        wavelet_output = wavelet_weighted.sum(dim=2)
        return wavelet_output

    def haar_wavelet(self, x):
        """Haar mother wavelet: +1 on [0, 0.5), -1 on [0.5, 1), 0 everywhere else.

        The previous implementation returned +1 for every x < 0.5, including
        negative values, which are outside the wavelet's support.
        """
        one = torch.ones_like(x)
        zero = torch.zeros_like(x)
        first_half = (x >= 0) & (x < 0.5)
        second_half = (x >= 0.5) & (x < 1)
        return torch.where(first_half, one, torch.where(second_half, -one, zero))

    def coiflet_wavelet(self, x):
        """Evaluate the polynomial ``sum(c[i] * x**i)`` with the six listed coefficients."""
        c = [0.038580, -0.126969, -0.077974, 0.417845, 0.812866, 0.468429]
        return sum(c[i] * x**i for i in range(len(c)))

    def biorthogonal_wavelet(self, x):
        """Evaluate the cubic ``3*x**2 - 2*x**3``."""
        return 3*x**2 - 2*x**3

    def symlet4_wavelet(self, x):
        """Evaluate the polynomial ``sum(coeffs[i] * x**i)`` with the eight listed coefficients."""
        # Symlet 4 (sym4) wavelet coefficients
        coeffs = [
            -0.075765714789273,
            0.029635527645999,
            0.497618667632015,
            0.803738751805916,
            0.297857795605542,
            -0.099219543576847,
            -0.012603967262038,
            0.032223100604042
        ]

        # The coefficients are used as polynomial coefficients (this is not a convolution).
        return sum(coeffs[i] * x**i for i in range(len(coeffs)))

    def daubechies_wavelet(self, x):
        """Evaluate the cubic ``c1 - c2*x + c3*x**2 - c4*x**3`` built from sqrt(3)-based constants."""
        sqrt3 = math.sqrt(3)
        coeff1 = (1 + sqrt3) / (4 * math.sqrt(2))
        coeff2 = (3 + sqrt3) / (4 * math.sqrt(2))
        coeff3 = (3 - sqrt3) / (4 * math.sqrt(2))
        coeff4 = (1 - sqrt3) / (4 * math.sqrt(2))
        return coeff1 - coeff2 * x + coeff3 * x**2 - coeff4 * x**3

    def meyer_aux(self, v):
        """Piecewise factor used by the 'meyer' branch.

        Returns 1 where ``v <= 1/2``, 0 where ``v >= 1`` and
        ``cos(pi/2 * nu(2*v - 1))`` in between.
        """
        pi = math.pi
        return torch.where(v <= 1/2, torch.ones_like(v), torch.where(v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * self.nu(2 * v - 1))))

    def nu(self, t):
        """Polynomial ``t**4 * (35 - 84*t + 70*t**2 - 20*t**3)`` used inside ``meyer_aux``."""
        return t**4 * (35 - 84*t + 70*t**2 - 20*t**3)

    def forward(self, x):
        """Return ``BatchNorm1d(wavelet_transform(x) + F.linear(x, weight1))``."""
        wavelet_output = self.wavelet_transform(x)
        base_output = F.linear(x, self.weight1)
        combined_output = wavelet_output + base_output
        return self.bn(combined_output)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, RobustScaler, MaxAbsScaler, PowerTransformer, MinMaxScaler

class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to a feature vector in [-1, 1].

    The label is embedded into a ``num_classes``-dimensional vector, concatenated
    with the noise and pushed through a Linear(128) -> ReLU -> Linear -> Tanh stack.
    """

    def __init__(self, latent_dim, num_features, num_classes):
        super().__init__()
        hidden_units = 128
        self.label_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_classes)
        stack = [
            nn.Linear(latent_dim + num_classes, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, num_features),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, z, labels):
        """Generate samples for noise ``z`` (batch, latent_dim) conditioned on integer ``labels`` (batch,)."""
        label_vectors = self.label_emb(labels)
        conditioned_noise = torch.cat((z, label_vectors), dim=1)
        return self.model(conditioned_noise)

class Discriminator(nn.Module):
    """Conditional discriminator: scores (sample, class label) pairs with a value in (0, 1).

    The label is embedded into a ``num_classes``-dimensional vector, concatenated
    with the sample and pushed through a Linear(128) -> ReLU -> Linear(1) -> Sigmoid stack.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        hidden_units = 128
        self.label_embedding = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_classes)
        stack = [
            nn.Linear(num_features + num_classes, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x, labels):
        """Return one score per row of ``x`` (batch, num_features) given integer ``labels`` (batch,)."""
        label_vectors = self.label_embedding(labels)
        conditioned_sample = torch.cat((x, label_vectors), dim=1)
        return self.model(conditioned_sample)

def train_cgan(generator, discriminator, data_loader, device, num_epochs, latent_dim, num_classes):
    """Train a conditional GAN with binary cross-entropy, alternating D and G steps per batch.

    Args:
        generator: Module called as ``generator(z, labels)``.
        discriminator: Module called as ``discriminator(x, labels)``; must output
            probabilities of shape (batch, 1) because ``nn.BCELoss`` is used.
        data_loader: Iterable yielding ``(data, labels)`` batches.
        device: Device on which batches and noise are placed.
        num_epochs: Number of passes over ``data_loader``.
        latent_dim: Size of the noise vector fed to the generator.
        num_classes: Exclusive upper bound of the random labels drawn for fake samples.

    Both modules are put in training mode and updated in place; nothing is returned.
    The losses of the last batch of every epoch are printed.
    """
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        # Reset per epoch so an empty loader is detected instead of hitting unbound names.
        d_loss = g_loss = None

        for data, labels in data_loader:
            batch_size = data.size(0)

            real_data = data.to(device)
            real_labels = labels.to(device)
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # Train Discriminator
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_data, real_labels), valid)
            z = torch.randn(batch_size, latent_dim, device=device)
            gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
            # Detach: the discriminator step must not back-propagate into the generator.
            fake_data = generator(z, gen_labels).detach()
            fake_loss = criterion(discriminator(fake_data, gen_labels), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # Train Generator (fresh noise, same random labels as the discriminator step)
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, latent_dim, device=device)
            g_loss = criterion(discriminator(generator(z, gen_labels), gen_labels), valid)
            g_loss.backward()
            optimizer_G.step()

        if d_loss is None:
            print(f'Epoch [{epoch+1}/{num_epochs}] data_loader produced no batches; nothing was trained')
        else:
            print(f'Epoch [{epoch+1}/{num_epochs}] Discriminator Loss: {d_loss.item()}, Generator Loss: {g_loss.item()}')

def generate_synthetic_data(generator, num_samples, label, num_features, latent_dim, device):
    """Sample ``num_samples`` rows from ``generator`` for a single class ``label``.

    Noise of shape (num_samples, latent_dim) and a constant long label vector are
    created on ``device``; the generator output is returned as a NumPy array on
    the CPU.  ``num_features`` is accepted but not used by this function.
    """
    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, device=device)
        class_ids = torch.full((num_samples,), label, dtype=torch.long, device=device)
        generated = generator(noise, class_ids)
    return generated.detach().cpu().numpy()

def load_and_preprocess_data(path,target_class_count=1500, latent_dim=100, num_epochs=100, batch_size=32):
    """Load the CSV, min-max scale the numeric columns and top up small classes with cGAN samples.

    Args:
        path: CSV file containing the numeric columns listed below and an ``ml_label`` column.
        target_class_count: Classes with fewer rows than this receive synthetic rows
            until they reach this count; larger classes are left untouched.
        latent_dim: Noise dimension of the conditional GAN.
        num_epochs: Number of cGAN training epochs.
        batch_size: Batch size used for cGAN training.

    Returns:
        ``(features, labels)``: the scaled real rows followed by the synthetic rows,
        and the matching integer-encoded labels.
    """
    # Imported locally so this function does not depend on a later notebook cell
    # having imported LabelEncoder first.
    from sklearn.preprocessing import LabelEncoder

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    df = pd.read_csv(path)
    numerical_columns = [
        'peak_time', 'peak_time_ns', 'start_time', 'start_time_ns', 'duration',
        'peak_frequency', 'central_freq', 'bandwidth', 'amplitude', 'snr', 'q_value',
        '1400Ripples', '1080Lines', 'Air_Compressor', 'Blip', 'Chirp', 'Extremely_Loud',
        'Helix', 'Koi_Fish', 'Light_Modulation', 'Low_Frequency_Burst', 'Low_Frequency_Lines',
        'No_Glitch', 'None_of_the_Above', 'Paired_Doves', 'Power_Line', 'Repeating_Blips',
        'Scattered_Light', 'Scratchy', 'Tomte', 'Violin_Mode', 'Wandering_Line', 'Whistle'
    ]

    # Scale every numeric column to [0, 1]; the result is a NumPy array.
    scaler = MinMaxScaler()
    features_scaled = scaler.fit_transform(df[numerical_columns])

    # Map the string class names to consecutive integer ids.
    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(df['ml_label'])

    num_features = features_scaled.shape[1]
    num_classes = np.unique(labels).size

    tensor_data = torch.tensor(features_scaled.astype(np.float32))
    tensor_labels = torch.tensor(labels)
    dataset = TensorDataset(tensor_data, tensor_labels)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    generator = Generator(latent_dim, num_features, num_classes).to(device)
    discriminator = Discriminator(num_features, num_classes).to(device)

    train_cgan(generator, discriminator, data_loader, device, num_epochs, latent_dim, num_classes)

    # Collect the synthetic blocks and stack once at the end instead of re-copying
    # the whole array for every class.
    feature_blocks = [features_scaled]
    label_blocks = [labels]
    for class_label in range(num_classes):
        current_count = int(np.sum(labels == class_label))
        if current_count < target_class_count:
            # Plain int so torch receives a native Python size.
            num_samples_needed = int(target_class_count - current_count)
            synthetic_data = generate_synthetic_data(generator, num_samples_needed, class_label, num_features, latent_dim, device)
            feature_blocks.append(synthetic_data)
            label_blocks.append(np.full(num_samples_needed, class_label))

    return np.vstack(feature_blocks), np.concatenate(label_blocks)



In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, PowerTransformer, RobustScaler
from tqdm import tqdm
# Seed NumPy's global RNG, which augment_data draws its Gaussian noise from.
np.random.seed(0)

# Create pairs of augmented data
def create_pairs(data):
    """Build (augmented view 1, augmented view 2, original) triples for every row of ``data``.

    Each row is augmented twice with independent noise.  The triples are returned
    as one NumPy array, so for equally sized rows the result has shape
    (len(data), 3, row_length).
    """
    triples = [(augment_data(sample), augment_data(sample), sample) for sample in data]
    return np.array(triples)

# Define an augmentation function
def augment_data(data, noise_level=0.1):
    """Return ``data`` plus zero-mean Gaussian noise with standard deviation ``noise_level``.

    The noise has the same shape as ``data`` and is drawn from NumPy's global RNG.
    """
    jitter = np.random.normal(loc=0, scale=noise_level, size=data.shape)
    return data + jitter

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Dataset class
class TrainDataset(Dataset):
    """Dataset of positive pairs read from a mapping with 'pos_1' and 'pos_2' sequences."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        first_views = self.data['pos_1']
        return len(first_views)

    def __getitem__(self, idx):
        """Return the two views stored at ``idx`` as float tensors."""
        view_a, view_b = (
            torch.tensor(self.data[key][idx], dtype=torch.float)
            for key in ('pos_1', 'pos_2')
        )
        return view_a, view_b

class MemoryDataset(Dataset):
    """Dataset of (sample, label) pairs read from a mapping with 'true_data' and 'target' sequences."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        samples = self.data['true_data']
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a float tensor and its label as a long tensor."""
        sample = torch.tensor(self.data['true_data'][idx], dtype=torch.float)
        label = torch.tensor(self.data['target'][idx], dtype=torch.long)
        return sample, label

class TestDataset(Dataset):
    """Evaluation dataset of (sample, label) pairs read from 'true_data' and 'target' sequences."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['true_data'])

    def __getitem__(self, idx):
        """Return ``(float tensor, long tensor)`` for the row at ``idx``."""
        features, label = self.data['true_data'][idx], self.data['target'][idx]
        return (
            torch.tensor(features, dtype=torch.float),
            torch.tensor(label, dtype=torch.long),
        )

class NumericalModel(nn.Module):
    """1-D convolutional encoder followed by a KANLinear projection head.

    Args:
        input_dim: Number of values per sample; each sample is treated as a
            single-channel sequence of this length.
        feature_dim: Output size of the projection head.
        wavelet_type: Wavelet name forwarded to both ``KANLinear`` layers.
    """

    def __init__(self, input_dim, feature_dim=128, wavelet_type='mexican_hat'):
        super().__init__()

        self.channels = 1  # single input channel
        self.length = input_dim  # sequence length

        conv_width = 256
        # Each of the two stride-2 max-pools halves the length (floor division).
        pooled_length = self.length // 4

        # Encoder: two Conv1d/ReLU/MaxPool1d stages, then a dense layer to 512 features
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(self.channels, conv_width, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(conv_width, conv_width, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(conv_width * pooled_length, 512),
            nn.ReLU(),
        )

        # Projection head built from wavelet-based KANLinear layers
        self.projection_head = nn.Sequential(
            KANLinear(512, 256, wavelet_type=wavelet_type),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            KANLinear(256, feature_dim, wavelet_type=wavelet_type),
        )

    def forward(self, x):
        """Return the L2-normalised encoder features and the L2-normalised projection."""
        # (batch, length) -> (batch, channels, length)
        sequences = x.view(-1, self.channels, self.length)
        feature = self.feature_extractor(sequences)
        projection = self.projection_head(feature)
        return F.normalize(feature, dim=-1), F.normalize(projection, dim=-1)



In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from geomloss import SamplesLoss

# Function to create a mask for negative samples
def get_negative_mask(batch_size):
    """Return a (2*batch_size, 2*batch_size) boolean mask selecting negative pairs.

    Entry (r, c) is False when c equals r modulo ``batch_size`` (the sample itself
    and its positive partner) and True everywhere else.
    """
    # Tiling the identity 2x2 marks exactly the self/positive positions; invert it.
    positives = torch.eye(batch_size, dtype=torch.bool).repeat(2, 2)
    return ~positives

def train(net, data_loader, train_optimizer, device, epoch, epochs, batch_size):
    """Run one training epoch and return the sample-weighted mean loss.

    Args:
        net: Model returning ``(feature, projection)`` for a batch.
        data_loader: Iterable yielding ``(pos_1, pos_2)`` batches of paired views.
        train_optimizer: Optimizer updating ``net``.
        device: Device the batches are moved to.
        epoch: Current epoch index (progress-bar text only).
        epochs: Total number of epochs (progress-bar text only).
        batch_size: Nominal batch size.  Kept for interface compatibility; the
            actual size of every batch is used for all computations.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the number of
        samples seen; ``0.0`` if the loader yields no batches.
    """
    net.train()
    total_loss, total_num = 0.0, 0
    train_bar = tqdm(data_loader)

    # Loop-invariant: build the Sinkhorn loss object once instead of once per batch.
    loss_wasserstein = SamplesLoss(loss="sinkhorn", p=2)

    for pos_1, pos_2 in train_bar:
        pos_1, pos_2 = pos_1.to(device), pos_2.to(device)
        feature_1, out_1 = net(pos_1)
        feature_2, out_2 = net(pos_2)

        # neg score: exp-similarity of every projection against every other,
        # with self and positive-partner entries masked out
        out = torch.cat([out_1, out_2], dim=0)
        actual_batch_size = pos_1.size(0)
        neg = torch.exp(torch.mm(out, out.t().contiguous()))
        mask = get_negative_mask(actual_batch_size).to(device)
        neg = neg.masked_select(mask).view(2 * actual_batch_size, -1)

        # pos score
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1))
        pos = torch.cat([pos, pos], dim=0)
        Ng = neg.sum(dim=-1)

        # Sinkhorn distance between the positive scores and the summed negative scores
        pos_reshaped = pos.unsqueeze(1)
        Ng_reshaped = Ng.unsqueeze(1)
        loss_distances = loss_wasserstein(pos_reshaped, Ng_reshaped)

        # contrastive loss
        # NOTE(review): the -log(pos / (pos + Ng)) term is subtracted, so minimising
        # `loss` increases that term - confirm the sign is intended.
        loss = torch.log(loss_distances) - (-torch.log(pos / (pos + Ng))).mean()

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += actual_batch_size
        # Weight by the real batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * actual_batch_size
        train_bar.set_description(f'Train Epoch: [{epoch}/{epochs}] Loss: {total_loss / total_num:.4f}')

    return total_loss / total_num if total_num else 0.0


In [None]:
import torch
from tqdm import tqdm
from sklearn.metrics import multilabel_confusion_matrix
import numpy as np

def d_index_score_cal(y_true, y_pred):
    """Average, over the per-label confusion matrices, of
    ``log2(1 + accuracy) + log2(1 + (sensitivity + specificity) / 2)``.

    Returns 0 when either input is empty.  If every element of ``y_pred`` is a
    list, it is first replaced by 1/0 flags telling whether the true label is
    contained in the corresponding list; otherwise both inputs must have equal length.
    """
    if not (len(y_true) and len(y_pred)):
        return 0

    predictions_are_nested = all(isinstance(pred, list) for pred in y_pred)
    if predictions_are_nested:
        # Collapse each candidate list into a hit (1) / miss (0) flag
        y_pred = [int(true_label in pred) for true_label, pred in zip(y_true, y_pred)]
    else:
        # Flat predictions must line up one-to-one with the true labels
        assert len(y_true) == len(y_pred), "y_true and y_pred must have the same length"

    def _d_index(matrix):
        tn, fp, fn, tp = matrix.ravel()
        actual_positive = tp + fn
        actual_negative = tn + fp
        # 1e-7 stands in when a label has no positive (or no negative) samples
        sensitivity = tp / actual_positive if actual_positive > 0 else 1e-7
        specificity = tn / actual_negative if actual_negative > 0 else 1e-7
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        return np.log2(1 + accuracy) + np.log2(1 + (sensitivity + specificity) / 2)

    d_indices = [_d_index(matrix) for matrix in multilabel_confusion_matrix(y_true, y_pred)]
    if not d_indices:
        return 0
    return sum(d_indices) / len(d_indices)


def create_exact_match_list(top2_predictions, targets):
    """For each (candidate list, target) pair return the first candidate equal to the target, or -1."""

    def _first_hit(candidates, target):
        for candidate in candidates:
            if candidate == target:
                return candidate
        return -1  # no candidate matched the target

    return [_first_hit(candidates, target) for candidates, target in zip(top2_predictions, targets)]


In [None]:
import torch
from tqdm import tqdm

def flatten_predictions(predictions, true_labels):
    """Count hits, wrong candidates and misses over a list of candidate-label lists.

    Returns ``(TP, FP, FN)`` where TP counts samples whose true label is among the
    candidates, FN counts samples where it is not, and FP counts every candidate
    that differs from the true label.
    """
    hits = wrong = misses = 0
    for idx, candidates in enumerate(predictions):
        truth = true_labels[idx]
        if truth not in candidates:
            misses += 1
        else:
            hits += 1
        wrong += sum(1 for candidate in candidates if candidate != truth)

    return hits, wrong, misses

def test_knn(net, memory_data_loader, test_data_loader, k, c, device=None):
    """Weighted k-NN evaluation of ``net``'s features against a labelled memory bank.

    Args:
        net: Model whose first output is the feature used for similarity search.
        memory_data_loader: Loader yielding ``(data, label)`` batches for the bank.
        test_data_loader: Loader yielding ``(data, target)`` batches to classify.
        k: Number of nearest neighbours; clamped to the size of the memory bank.
        c: Number of classes (width of the one-hot vote matrix).
        device: Device used for inference.  Defaults to the device of ``net``'s
            first parameter (CPU if it has none), so the function also works
            without CUDA.

    Returns:
        Dict with keys 'top1', 'top2', 'top3'; each maps to a dict with
        'accuracy' (percent), 'd_index', 'precision', 'recall' and 'f1'.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    net.eval()
    total_top1, total_top2, total_top3, total_num = 0.0, 0.0, 0.0, 0
    original_labels_list = []
    top1_predictions, top2_predictions, top3_predictions = [], [], []
    feature_bank, label_bank = [], []

    with torch.no_grad():
        # Generate feature bank; labels are taken from the same batches so they stay
        # aligned with the features regardless of how the loader orders samples.
        for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature_bank.append(net(data.to(device, non_blocking=True))[0])
            label_bank.append(target)
        # (feature_dim, bank_size) so a single matmul yields all similarities
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        feature_labels = torch.cat(label_bank, dim=0).to(feature_bank.device)

        # topk cannot return more neighbours than the bank holds
        k = min(k, feature_bank.size(1))

        test_bar = tqdm(test_data_loader)
        for data, target in test_bar:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            feature = net(data)[0]
            batch = data.size(0)

            total_num += batch
            sim_matrix = torch.mm(feature, feature_bank)
            sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
            sim_labels = torch.gather(feature_labels.expand(batch, -1), dim=-1, index=sim_indices)
            sim_weight = sim_weight.exp()

            # Each neighbour votes for its class with weight exp(similarity)
            one_hot_label = torch.zeros(batch * k, c, device=sim_labels.device)
            one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1).long(), value=1.0)
            pred_scores = torch.sum(one_hot_label.view(batch, -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)
            pred_labels = pred_scores.argsort(dim=-1, descending=True)

            original_labels_list.extend(target.tolist())
            top1_predictions.extend(pred_labels[:, :1].tolist())
            top2_predictions.extend(pred_labels[:, :2].tolist())
            top3_predictions.extend(pred_labels[:, :3].tolist())

            target_column = target.unsqueeze(dim=-1)
            total_top1 += torch.sum((pred_labels[:, :1] == target_column).any(dim=-1).float()).item()
            total_top2 += torch.sum((pred_labels[:, :2] == target_column).any(dim=-1).float()).item()
            total_top3 += torch.sum((pred_labels[:, :3] == target_column).any(dim=-1).float()).item()

    TP_top1, FP_top1, FN_top1 = flatten_predictions(top1_predictions, original_labels_list)
    TP_top2, FP_top2, FN_top2 = flatten_predictions(top2_predictions, original_labels_list)
    TP_top3, FP_top3, FN_top3 = flatten_predictions(top3_predictions, original_labels_list)

    # Flatten predictions for d-index calculation
    flat_top1_predictions = [pred[0] for pred in top1_predictions]
    flat_top2_predictions = create_exact_match_list(top2_predictions, original_labels_list)
    flat_top3_predictions = create_exact_match_list(top3_predictions, original_labels_list)

    precision_top1, recall_top1, f1_top1 = calculate_metrics(TP_top1, FP_top1, FN_top1)
    precision_top2, recall_top2, f1_top2 = calculate_metrics(TP_top2, FP_top2, FN_top2)
    precision_top3, recall_top3, f1_top3 = calculate_metrics(TP_top3, FP_top3, FN_top3)

    # Calculate d-indices
    d_index_top1 = d_index_score_cal(original_labels_list, flat_top1_predictions)
    d_index_top2 = d_index_score_cal(original_labels_list, flat_top2_predictions)
    d_index_top3 = d_index_score_cal(original_labels_list, flat_top3_predictions)

    # Avoid a division by zero when the test loader produced no samples
    denominator = max(total_num, 1)

    return {
        'top1': {'accuracy': total_top1 / denominator * 100, 'd_index': d_index_top1, 'precision': precision_top1, 'recall': recall_top1, 'f1': f1_top1},
        'top2': {'accuracy': total_top2 / denominator * 100, 'd_index': d_index_top2, 'precision': precision_top2, 'recall': recall_top2, 'f1': f1_top2},
        'top3': {'accuracy': total_top3 / denominator * 100, 'd_index': d_index_top3, 'precision': precision_top3, 'recall': recall_top3, 'f1': f1_top3}
    }

def calculate_metrics(TP, FP, FN):
    """Return ``(precision, recall, f1)`` from raw counts; any undefined ratio is reported as 0."""
    predicted_positive = TP + FP
    actual_positive = TP + FN

    precision = TP / predicted_positive if predicted_positive > 0 else 0
    recall = TP / actual_positive if actual_positive > 0 else 0

    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0
    return precision, recall, f1



In [None]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader


def main():
    """Train the contrastive model and record k-NN evaluation metrics after every epoch.

    The data are loaded and class-balanced, turned into augmented pairs, and for
    every (learning rate, batch size) combination a fresh model is trained.  After
    each epoch the top-1/2/3 metrics are printed, and after the last epoch they are
    written to a CSV file.  The memory bank and the test set are built from the
    same un-augmented rows and labels.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    epochs = 100
    k = 30

    path = 'clean_data_O1.csv'

    df_numerical, labels = load_and_preprocess_data(path)
    num_classes = len(np.unique(labels))
    augmented_original_pairs = create_pairs(df_numerical)

    # Columns of the pair array: first augmented view, second augmented view, original row
    first_views = augmented_original_pairs[:, 0]
    second_views = augmented_original_pairs[:, 1]
    originals = augmented_original_pairs[:, 2]

    learning_rates = [1e-3]
    batch_sizes = [128]

    # (key in the test_knn result, CSV column suffix, console label, console unit)
    metric_layout = (
        ('accuracy', 'Accuracy', 'Accuracy', '%'),
        ('d_index', 'D-Index', 'D-Index', ''),
        ('precision', 'Precision', 'Precision', ''),
        ('recall', 'Recall', 'Recall', ''),
        ('f1', 'F1', 'F1 Score', ''),
    )
    ranks = (1, 2, 3)

    for lr in learning_rates:
        for batch_size in batch_sizes:
            # Metrics of every epoch for this hyper-parameter combination
            all_metrics = []
            print(f"Learning Rate: {lr}, Batch Size: {batch_size}")

            # Dataset and DataLoader setup
            train_dataset = TrainDataset({"pos_1": first_views, "pos_2": second_views})
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

            memory_dataset = MemoryDataset({"true_data": originals, 'target': labels})
            memory_data_loader = DataLoader(memory_dataset, batch_size=batch_size, shuffle=False)

            test_dataset = TestDataset({"true_data": originals, 'target': labels})
            test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

            model = NumericalModel(input_dim=df_numerical.shape[1], feature_dim=128).to(device)
            model = nn.DataParallel(model)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)

            for epoch in range(epochs):
                train(model, train_loader, optimizer, device, epoch, epochs, batch_size)
                test_results = test_knn(model, memory_data_loader, test_loader, k, num_classes)

                # One flat record per epoch; insertion order defines the CSV column order
                epoch_metrics = {
                    'Epoch': epoch + 1,
                    'Learning Rate': lr,
                    'Batch Size': batch_size,
                }
                for rank in ranks:
                    rank_results = test_results[f'top{rank}']
                    for key, column, _, _ in metric_layout:
                        epoch_metrics[f'Top{rank} {column}'] = rank_results[key]
                all_metrics.append(epoch_metrics)

                # Console report: a blank line separates the blocks, none after the last one
                for rank in ranks:
                    rank_results = test_results[f'top{rank}']
                    print(f"Top {rank} Metrics:")
                    for key, _, label, unit in metric_layout:
                        trailer = "\n" if key == 'f1' and rank < ranks[-1] else ""
                        print(f"  {label}: {rank_results[key]:.8f}{unit}{trailer}")

            # Convert the list of metrics to a DataFrame and save it as a CSV file
            df_metrics = pd.DataFrame(all_metrics)
            filename = f'model_metrics_no_drop_alpha_wave_kan_{lr}_{batch_size}.csv'
            df_metrics.to_csv(filename, index=False)
            print("Model Saved!")

# Run the full pipeline only when this code executes as the main module,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    main()
