# 1. Import modules

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Training the GAN 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
from torch import autograd

%run ./model.ipynb
%run ./utils.ipynb

In [None]:
def segmentation(ECG, BP, CVP, patients_baseline, labels, fs=1000,
                 ecg_before=400, ecg_after=400, pressure_before=200, pressure_after=600):
    """Cut every recording into per-beat windows and assign a label to each beat.

    R peaks are detected on each recording's ECG with ``scipy.signal.find_peaks``
    (positive peaks for 'A-tach 2:1 block' recordings, peaks of the inverted ECG
    otherwise). Around every peak ``r`` the ECG window ``[r - ecg_before, r + ecg_after)``
    and the BP/CVP windows ``[r - pressure_before, r + pressure_after)`` are extracted;
    peaks too close to either end of a recording are skipped.

    Beat labels (first matching rule wins):
      * patient 34, recording index 11/12/13, peak listed in VEs_34_11/12/13 -> 'SR with VEs'
      * patient 45, peak listed in VEs_45 -> 'SR with VEs'
      * patient 34 or 45, peak in none of the VE lists -> 'SR'
      * patient 26, peak listed in VEs_26 -> 'paced'
      * otherwise -> the recording's label from ``labels``
    Afterwards the beats immediately before and after every 'SR with VEs' beat are
    relabelled 'SR with VEs' too, but only when they come from the same recording.

    Args:
        ECG, BP, CVP: per-recording signals, in the same recording order.
        patients_baseline: patient number of each recording.
        labels: rhythm label of each recording.
        fs: sampling rate in Hz. NOTE: unused; the peak distances (135/120 samples) and
            window sizes are expressed in samples and were presumably chosen for fs=1000.
        ecg_before, ecg_after: ECG window extent around each R peak, in samples.
        pressure_before, pressure_after: BP/CVP window extent around each R peak, in samples.

    Returns:
        (segmented_ecg, segmented_bp, segmented_cvp, signals_labels, patients_labels),
        five lists with one entry per kept beat.
    """
    import scipy.signal

    # Hand-annotated R-peak sample positions that mark ventricular ectopics (or paced beats
    # for patient 26) in specific recordings.
    VEs_45 = [512,1434,2798,4168,5544,7403,9271,11129,12030,12927,13816,14725,16116,17970,19363,21214]
    VEs_34_11 = [1907, 3695,4364,7696]
    VEs_34_12 = [443,1118,1800]
    VEs_34_13 = [1501,2171,2844,4412,6418, 6728,7994]
    VEs_26 = [1472,2120, 6964, 7216]

    segmented_ecg = []
    segmented_bp = []
    segmented_cvp = []
    signals_labels = []
    patients_labels = []
    beat_record = []  # index of the recording each kept beat came from

    # A beat is kept only if every window fits inside its signal.
    margin_before = max(ecg_before, pressure_before)

    for i in range(len(ECG)):
        ecg_signal = np.array(ECG[i])
        bp_signal = np.array(BP[i])
        cvp_signal = np.array(CVP[i])
        patient = patients_baseline[i]

        if labels[i] == 'A-tach 2:1 block':
            r_peaks, _ = scipy.signal.find_peaks(ecg_signal, height=0.08 * np.max(ecg_signal), distance=135)
        else:
            r_peaks, _ = scipy.signal.find_peaks(-ecg_signal, height=np.mean(ecg_signal) + np.std(ecg_signal), distance=120)

        for r in r_peaks:
            if not (r - margin_before > 0
                    and r + ecg_after < len(ecg_signal)
                    and r + pressure_after < len(bp_signal)
                    and r + pressure_after < len(cvp_signal)):
                continue

            segmented_ecg.append(ecg_signal[r - ecg_before:r + ecg_after])
            segmented_bp.append(bp_signal[r - pressure_before:r + pressure_after])
            segmented_cvp.append(cvp_signal[r - pressure_before:r + pressure_after])
            patients_labels.append(patient)
            beat_record.append(i)

            # NOTE(review): i == 11/12/13 is the position of the recording in the input lists,
            # so these rules silently stop matching if the file list changes — confirm.
            if patient == 34 and i == 11 and r in VEs_34_11:
                beat_label = 'SR with VEs'
            elif patient == 34 and i == 12 and r in VEs_34_12:
                beat_label = 'SR with VEs'
            elif patient == 34 and i == 13 and r in VEs_34_13:
                beat_label = 'SR with VEs'
            elif patient == 45 and r in VEs_45:
                beat_label = 'SR with VEs'
            elif (patient == 45 or patient == 34) and r not in VEs_45 and r not in VEs_34_11 and r not in VEs_34_12 and r not in VEs_34_13:
                beat_label = 'SR'
            elif patient == 26 and r in VEs_26:
                beat_label = 'paced'
            else:
                beat_label = labels[i]
            signals_labels.append(beat_label)

    # Mark the neighbours of each ectopic beat as well. Bounds are checked explicitly (a negative
    # index would wrap to the last beat) and neighbours from another recording are not touched,
    # because beats adjacent in the flat list are not adjacent in time across recordings.
    ve_beats = [k for k, beat_label in enumerate(signals_labels) if beat_label == 'SR with VEs']
    for k in ve_beats:
        for neighbour in (k - 1, k + 1):
            if 0 <= neighbour < len(signals_labels) and beat_record[neighbour] == beat_record[k]:
                signals_labels[neighbour] = 'SR with VEs'

    return segmented_ecg, segmented_bp, segmented_cvp, signals_labels, patients_labels

# 2. Read the baseline signals from all the patients

In [None]:
# Set path and patient numbers
path_root = 'D:/WGAN CODE- FROM SERVER FINAL/Datasets/Harefield/'
patients = ['P09', 'P10', 'P11', 'P12', 'P25', 'P26', 'P31', 'P33', 'P34', 'P35', 'P37', 'P38', 'P39', 'P44',
            'P45', 'P46', 'P47', 'P51', 'P52', 'P53', 'P88', '100', '101', '102']

# Keywords passed to `get_files` to select the baseline recordings
baseline = ['baseline', 'intrinsic', 'underlying', 'ULR']

# Get all the paths for the baseline signals
baseline_files = get_files(path_root, patients, baseline)

# Read the three channels from the same file list. The second value returned by `read_data` is
# presumably the patient number of each recording (it is matched against patient IDs below).
[BP, patients_baseline] = read_data(baseline_files, '/bp_dist.txt')
[ECG, ecg_patients] = read_data(baseline_files, '/ecg.txt')
[CVP, cvp_patients] = read_data(baseline_files, '/bp_prox.txt')
# ECG, BP and CVP are later combined index-by-index, so their recordings must line up.
assert list(ecg_patients) == list(patients_baseline) == list(cvp_patients), \
    "ECG/BP/CVP recordings are not aligned"

[BP, CVP] = calibrate_signals(BP, CVP, scale_factor_BP=40, scale_factor_CVP=8)

# Patient details (including the 'ULR' rhythm label) live in the same data folder as the signals.
dataset_details = pd.read_csv(path_root + 'PACESIM patients anonymised - Pacemaker patients.csv')

In [None]:
import re

# Extract one rhythm label ('ULR' column) per recording.
# Build the lookup once: numeric part of each 'Pacesim ID' -> 'ULR'. The original nested loop
# re-parsed every ID for every recording and silently produced a shorter, misaligned list when a
# patient had no row (or crashed on a missing ID).
ulr_by_patient = {}
for pacesim_id, ulr in zip(dataset_details['Pacesim ID'], dataset_details['ULR']):
    digits = re.findall(r'\d+', str(pacesim_id))
    if digits:
        # Keep the first row if the same patient number appears more than once.
        ulr_by_patient.setdefault(int(digits[0]), ulr)

missing_patients = sorted({p for p in patients_baseline if p not in ulr_by_patient})
if missing_patients:
    raise KeyError(f"No 'ULR' label in dataset_details for patients: {missing_patients}")

labels_signals = [ulr_by_patient[p] for p in patients_baseline]

details_df = pd.DataFrame({'Patient': patients_baseline, 'labels': labels_signals})

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Overview of one baseline recording: the first 10 s of ECG, ABP and CVP.
x = -1  # index of the recording to plot (-1 = last one)

fs = 1000
n_show = 10000  # samples to display (10 s at fs = 1000 Hz)
n_points = len(ECG[x][:n_show])
t = np.linspace(0, (n_points - 1) * (1 / fs), n_points)

fig, axes = plt.subplots(3, 1, figsize=[12, 8])
panels = ((ECG, 'ECG (mV)'), (BP, 'ABP (mmHg)'), (CVP, 'CVP (mmHg)'))
for ax, (channel, ylabel) in zip(axes, panels):
    ax.plot(t, channel[x][:n_show])
    ax.set_ylabel(ylabel, fontsize=16)
    ax.margins(x=0)
axes[-1].set_xlabel('Time (s)', fontsize=16)

fig.subplots_adjust(left=0.2, right=0.95, top=0.95, bottom=0.1, hspace=0.4)
# Let matplotlib align the y-labels instead of hand-tuning a different labelpad per axis.
fig.align_ylabels(axes)
plt.show()


In [None]:
# Remove noisy parts of the recordings with `cut_noise` (defined outside this notebook).
# All three channels are cleaned against the SAME patient/label lists. Previously the first call
# overwrote `patients_baseline` and `labels_signals`, so the BP and CVP calls received the ECG
# call's outputs instead of the lists that describe BP and CVP. Whether `cut_noise` changes those
# lists is not visible here; passing identical inputs is correct either way.
# NOTE(review): re-running this cell on its own feeds the already-denoised patient/label lists back in.
# (Down-sampling with `down_sample` was tried earlier and is disabled.)
patients_before_denoise = patients_baseline
labels_before_denoise = labels_signals

[ECG_denoised, patients_baseline, labels_signals] = cut_noise(ECG, patients_before_denoise, labels_before_denoise)
[BP_denoised, _, _] = cut_noise(BP, patients_before_denoise, labels_before_denoise)
[CVP_denoised, _, _] = cut_noise(CVP, patients_before_denoise, labels_before_denoise)

# 3. Segment the signals

In [None]:
# Split every denoised recording into beats; one label and one patient number per beat.
segmented_ecg, segmented_bp, segmented_cvp, labels_beats, patients_labels = segmentation(
    ECG_denoised,
    BP_denoised,
    CVP_denoised,
    patients_baseline,
    labels_signals,
)

In [None]:
# Keep only the beat classes used for training; beats with any other label are dropped.
kept_classes = ("SR with VEs", "SR", "SR with LBBB", "A-tach 2:1 block", "AF")
selected_labels = [pos for pos, beat_label in enumerate(labels_beats) if beat_label in kept_classes]

segmented_ecg = [segmented_ecg[pos] for pos in selected_labels]
segmented_bp = [segmented_bp[pos] for pos in selected_labels]
segmented_cvp = [segmented_cvp[pos] for pos in selected_labels]

labels_beats = [labels_beats[pos] for pos in selected_labels]
patients_labels = [patients_labels[pos] for pos in selected_labels]

# 4. Encode the labels

In [None]:
# Class names and how many beats carry each one.
labels_beats = np.array(labels_beats)
class_counts = np.unique(labels_beats, return_counts=True)
unique, counts = class_counts
print(*class_counts)

In [None]:
from sklearn import preprocessing
# Fit the encoder on the beat labels, then keep the integer codes as a plain Python list.
enc = preprocessing.LabelEncoder().fit(labels_beats)
enc_data = enc.transform(labels_beats).tolist()

In [None]:
# The position of each class name in enc.classes_ is its integer code.
print("Encoded values and their original labels:")
for class_id, class_name in enumerate(enc.classes_):
    print("{} -> {}".format(class_id, class_name))


# 5. Normalise signals to the range [-1, 1]

In [None]:
# Rescale each channel's beats to [-1, 1] with `normalise` (defined outside this notebook).
segmented_ecg, segmented_bp, segmented_cvp = (
    normalise(channel_beats, -1, 1)
    for channel_beats in (segmented_ecg, segmented_bp, segmented_cvp)
)

# 6. Stack the signals

In [None]:
# Stack each beat's channels into one array: row 0 ECG, row 1 BP, row 2 CVP.
signals = []
for beat_idx in range(len(segmented_bp)):
    beat_channels = (segmented_ecg[beat_idx], segmented_bp[beat_idx], segmented_cvp[beat_idx])
    signals.append(np.stack(beat_channels, axis=0))

# Pair every stacked beat with its encoded class id for the DataLoader.
train_data = [[signals[beat_idx], enc_data[beat_idx]] for beat_idx in range(len(signals))]

In [None]:
# (n_beats, n_channels, n_samples)
np.array(signals).shape

# 7. Set up the WGAN hyperparameters

In [None]:
%run ./last_model.ipynb

import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
import wandb
import os

# Credentials must never be stored in the notebook. The API key that used to be hard-coded here
# is exposed in the notebook's history and should be revoked. Offline mode needs no key; runs are
# uploaded later with `wandb sync` (authenticate that step with `wandb login` or an exported
# WANDB_API_KEY outside the notebook).
os.environ["WANDB_MODE"] = "offline"

In [None]:
import random

SEED = 42

# Seed every RNG this notebook uses. Python's `random` drives random.sample / random.shuffle in
# the evaluation code and was previously left unseeded.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # PyTorch (CPU)

# If using GPU, set the random seed for CUDA as well
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [None]:
# Hyperparameters
# Build the array once (it copies the whole dataset): shape is (n_beats, n_channels, n_samples).
signal_shape = np.shape(np.array(signals))

channels_signal = 3
assert signal_shape[1] == channels_signal, f"expected {channels_signal} channels, got {signal_shape[1]}"
# The noise has the same shape as a beat: (channels_signal, n_samples).
latent_dim = signal_shape[2]
generator_layers = 6

batch_size = 16
discriminator_updates = 5
num_epochs = 1000

num_classes = len(unique)
embed_size = signal_shape[2]
signal_length = signal_shape[2]

LAMBDA_GP = 10
no_of_batches = len(signals) / batch_size
# NOTE(review): the DataLoader uses drop_last=True, so an epoch actually has
# len(signals) // batch_size batches; round() can overcount by one.
batches_per_epoch = round(len(signals) / batch_size)

gen_name = f"Gen_baseline_all_{num_epochs}_3C"
disc_name = f"Disc_baseline_all_{num_epochs}_3C"

In [None]:
# Legacy hyperparameters from an earlier 2-channel run (channels_signal = 2), kept for reference
# only. Everything in this cell is commented out; `ECG_BP` is not defined in the visible notebook.
# # Hyperparameters 
# channels_signal = 2
# latent_dim = 750
# generator_layers = 5

# lr = 0.0001
# batch_size = 8
# discriminator_updates = 5
# num_epochs = 2000

# num_classes = 4
# embed_size = 750
# signal_length = 750

# LAMBDA_GP = 10
# no_of_batches = len(ECG_BP)/batch_size 
# batches_per_epoch = round(len(ECG_BP)/batch_size)

# gen_name = "Gen2000_4classes"
# disc_name = "Disc2000_4classes"

In [None]:
# Fall back to CPU when no GPU is available (the condition used to be commented out, so every
# later `.to(device)` failed on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"

# capture a dictionary of hyperparameters with config
config = {"learning_rate_critic": 0.00001,
          "learning_rate_generator": 0.00001,  # 0.00005 was also tried
          "epochs": num_epochs,
          "Channels": channels_signal,
          "architecture": "connections_instancenorm",
          "dropout": "none",
          "machine": "GPU_server",
          "batch_size": batch_size,
          "generator_layers": generator_layers,
          "discriminator_updates": discriminator_updates,
          "LAMBDA_GP": LAMBDA_GP,
          "Optimizer": 'Adam+betas0and0.9',
          "Generator": 'UNet1D',
          "Discriminator": 'CNN',
          "data": "all_baseline_data:SR_AF_tah_VEs",
          "type_GAN": 'CGAN',
          # plain list of class names so the W&B config is JSON-serialisable
          "classes": unique.tolist(),
          "gen_name": gen_name,
          "disc_name": disc_name
          }

# 8. Train the WGAN

In [None]:
    with wandb.init(project='WGAN_baseline', entity='ioanacretu', config=config):
        config = wandb.config
        #dataloader = DataLoader(ECG_BP,batch_size=config.batch_size, shuffle=True, num_workers=2)
        trainloader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, drop_last=True)
            
        # Create discriminator and generator
        critic = CNN_Discriminator(num_classes, signal_length).to(device) 
        generator = UNet1D(2, 2, num_classes,embed_size, n_layers =config.generator_layers).to(device)
        
        initialize_weights(critic)
        initialize_weights(generator)
        
        # Set up Optimizer for G and D
        optimizer_critic =  optim.Adam(critic.parameters(), lr=config.learning_rate, betas=(0, 0.9))
        optimizer_generator = optim.Adam(generator.parameters(),lr=config.learning_rate, betas=(0, 0.9))
        
        # Watch weights and gradients for both critic and generator
        wandb.watch(critic, log="all")
        wandb.watch(generator, log="all")

        generator.train()
        critic.train()

        #fixed_noise = torch.randn(batch_size,latent_dim,1).to(device)
        fixed_noise = torch.randn(config.batch_size,2,latent_dim).uniform_(-1, 1).to(device)
        critic_iter = 0
        gen_iter = 0

        print("Starting training...")
        for epoch in range(config.epochs):
            print(epoch)
            wandb.log({"Epoch":epoch})
            for batch_idx, (data,labels) in enumerate(trainloader):
                data = data.to(device)
                labels = labels.to(device)

                data = data.float()  #solved error "Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same"
        
                # Train the Critic/Discriminator: max E[critic(real)] - E[critic(fake)]
                for _ in range(config.discriminator_updates):
                    critic_iter = critic_iter +1
                    #noise = torch.randn(config.batch_size,latent_dim,1).uniform_(-1, 1).to(device)
                    noise = torch.randn(config.batch_size,2,latent_dim).uniform_(-1, 1).to(device)
                    fake = generator(noise,labels)

                    critic_real = critic(data,labels).reshape(-1)
                    critic_fake = critic(fake,labels).reshape(-1) #here we changed according to github repo alladin, it was before critic_fake = critic(fake.detach(),labels).reshape(-1)
                    
                    gp = gradient_penalty(critic,labels, data, fake, device = device)
                    loss_critic =(-(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp)

                    critic.zero_grad()
                    loss_critic.backward(retain_graph=True)
                    optimizer_critic.step()
                    
               ## clear memory after a no of steps: is it enought to keep the w. for disc_updates steps? ##loss_critic.backward(retain_graph=False)
    
                # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
                gen_iter = gen_iter + 1
                gen_fake = critic(fake,labels).reshape(-1)
                loss_generator = -torch.mean(gen_fake)
                
                wandb.log({"Critic_iteration": critic_iter, "Critic_loss": loss_critic})
                wandb.log({"Generator_iteration": gen_iter, "Generator_loss": loss_generator})
        
                generator.zero_grad()
                loss_generator.backward()
                optimizer_generator.step()
               
            with torch.no_grad():
                generated_signals = generator(fixed_noise,labels)
                gener = generated_signals.cpu().detach().numpy()
                plt.subplot(211)
                plt.plot(gener[0][0])
                plt.title('electrocardiogram')
                plt.subplot(212)
                plt.plot(gener[0][1])
                plt.title('arterial line blood pressure')
                wandb.log({'chart': plt})
                
            print(f"Critic loss is:{loss_critic}")
            print(f"Gen loss is:{loss_generator}")
            print("----------------------------------")

print('Finished Training')

In [None]:
# Train the conditional WGAN-GP on the stacked 3-channel beats and log progress to W&B.
with wandb.init(project='WGAN_baseline', entity='ioanacretu', config=config):
    # Separate name so the global `config` dict is not replaced by the W&B object
    # (re-running the cell would otherwise pass wandb.config back into wandb.init).
    run_config = wandb.config

    # Create dataset and dataloader
    trainloader = DataLoader(train_data, batch_size=run_config.batch_size, shuffle=True, drop_last=True)

    # Create discriminator and generator
    critic = CNN_Discriminator(num_classes, signal_length, num_channels=channels_signal).to(device)
    generator = UNet1D(channels_signal, channels_signal, num_classes, embed_size,
                       n_layers=run_config.generator_layers).to(device)

    initialize_weights(critic)
    initialize_weights(generator)

    # Adam with betas=(0, 0.9) for both networks
    optimizer_critic = optim.Adam(critic.parameters(), lr=run_config.learning_rate_critic, betas=(0, 0.9))
    optimizer_generator = optim.Adam(generator.parameters(), lr=run_config.learning_rate_generator, betas=(0, 0.9))

    # Watch weights and gradients for both critic and generator
    wandb.watch(critic, log="all")
    wandb.watch(generator, log="all")

    # Fixed noise and labels so the periodic sample plots are comparable across epochs.
    fixed_noise = torch.randn(run_config.batch_size, channels_signal, latent_dim).uniform_(-1, 1).to(device)
    fixed_labels = torch.randint(low=0, high=num_classes, size=(run_config.batch_size,)).to(device)

    # TODO: periodic checkpointing / resuming is not implemented; the final weights are saved
    # in a later cell.

    critic_iter = 0
    gen_iter = 0

    print("Starting training...")
    for epoch in range(run_config.epochs):

        # set both networks for training mode
        generator.train()
        critic.train()

        # per-epoch running sums, averaged over batches at the end of the epoch
        batch_nr = 0
        epoch_loss_critic = 0.
        epoch_loss_generator = 0.

        print(f'Epoch: {epoch}')
        for batch_idx, (data, labels) in enumerate(trainloader):
            batch_nr += 1

            # float32 to match the network weights (DoubleTensor input raised a dtype error)
            data = data.to(device).float()
            labels = labels.to(device)

            # Train the Critic/Discriminator: max E[critic(real)] - E[critic(fake)]
            step_loss_critic = 0.
            for _ in range(run_config.discriminator_updates):
                critic_iter += 1
                noise = torch.randn(run_config.batch_size, channels_signal, latent_dim).uniform_(-1, 1).to(device)
                fake = generator(noise, labels)

                critic_real = critic(data, labels).reshape(-1)
                critic_fake = critic(fake.detach(), labels).reshape(-1)

                # `fake` is passed un-detached, exactly as before; the external gradient_penalty may
                # rely on it to make the interpolated samples require grad.
                gp = gradient_penalty(critic, labels, data, fake, device=device)
                loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp

                critic.zero_grad()
                # retain_graph: `fake` from the last critic step is reused for the generator update.
                loss_critic.backward(retain_graph=True)
                optimizer_critic.step()

                step_loss_critic += loss_critic.item()
                wandb.log({"Epoch": epoch, "Critic_loss": loss_critic.item(), "Critic_iteration": critic_iter})

            epoch_loss_critic += step_loss_critic / run_config.discriminator_updates

            # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
            gen_iter += 1
            gen_fake = critic(fake, labels).reshape(-1)
            loss_generator = -torch.mean(gen_fake)

            generator.zero_grad()
            loss_generator.backward()
            optimizer_generator.step()

            epoch_loss_generator += loss_generator.item()
            wandb.log({"Epoch": epoch, "Generator_iteration": gen_iter, "Generator_loss": loss_generator.item()})

        # Report the epoch means (previously: the last batch's generator loss and the *sum* of its
        # critic losses, mislabelled as a mean).
        if batch_nr:
            print(f"Gen loss:{epoch_loss_generator / batch_nr}")
            print(f"Critic loss:{epoch_loss_critic / batch_nr}")

        # Display the signals produced by the generator every 100 epochs
        if (epoch + 1) % 100 == 0:
            generator.eval()
            with torch.no_grad():
                gener = generator(fixed_noise, fixed_labels).cpu().numpy()

            fig, axes = plt.subplots(3, 1, figsize=(8, 6))
            axes[0].set_title(f'Label: {fixed_labels[0].item()}')
            for channel_idx, ax in enumerate(axes):
                ax.plot(gener[0][channel_idx])

            # Log the figure to WandB, then free it (figures otherwise accumulate in memory).
            wandb.log({'chart': fig})
            plt.close(fig)
print('Finished Training')

In [None]:
import subprocess

# Upload the offline W&B run. This used to be a bare shell command without the `!` prefix,
# which is a SyntaxError in a Python cell. Update the run directory for each new training run.
WANDB_OFFLINE_RUN_DIR = r'Z:\1938759\Synhtetic signal simulator\Our-data\wandb\offline-run-20240317_202949-n8y8w3uv'
subprocess.run(["wandb", "sync", WANDB_OFFLINE_RUN_DIR], check=True)

In [None]:
import os

PATH = r'Z:\1938759\Synhtetic signal simulator\Our-data\model_all data'
# torch.save does not create missing directories.
os.makedirs(PATH, exist_ok=True)
torch.save(generator.state_dict(), os.path.join(PATH, f'{gen_name}.pth'))
torch.save(critic.state_dict(), os.path.join(PATH, f'{disc_name}.pth'))

In [None]:
# Show the file stem used for the saved generator checkpoint (set in the hyperparameter cell).
gen_name

In [None]:
# Snapshot of the LabelEncoder mapping printed earlier (LabelEncoder orders classes alphabetically);
# used to pick the class id in the generation cell below.
# NOTE(review): the class filter also keeps 'SR with VEs'; if such beats exist the encoder has 5
# classes and this mapping / the hard-coded num_classes = 4 no longer hold — confirm.
# Encoded values and their original labels:
# 0 -> A-tach 2:1 block
# 1 -> AF
# 2 -> SR
# 3 -> SR with LBBB


In [None]:
# Rebuild the generator, load the trained weights and draw a batch of synthetic beats of one class.
device = "cuda" if torch.cuda.is_available() else "cpu"

channels_signal = 3
num_classes = 4  # must match the class count the checkpoint was trained with
embed_size = 800
latent_dim = 800
generator_layers = 6

generator_PATH = r'Z:\1938759\Synhtetic signal simulator\Our-data\model_all data\Gen_baseline_all_1000_3C.pth'  # Gen_baseline_all_300_3C
generator_loaded = UNet1D(channels_signal, channels_signal, num_classes, embed_size, n_layers=generator_layers).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
generator_loaded.load_state_dict(torch.load(generator_PATH, map_location=device))
generator_loaded.eval()

GENERATED_CLASS = 2  # encoded class id; see enc.classes_ for its name
n_generated = 10     # separate name so the training `batch_size` is not overwritten

fixed_noise = torch.randn(n_generated, channels_signal, latent_dim).uniform_(-1, 1).to(device)
labels = torch.full((n_generated,), GENERATED_CLASS, dtype=torch.int32, device=device)

# Inference only: no autograd graph needed.
with torch.no_grad():
    generated_signals = generator_loaded(fixed_noise, labels)
gener = generated_signals.cpu().numpy()

In [None]:
# Plot one generated sample: ECG, ABP and CVP (channel order of the stacked training data).
x = 0  # index of the generated sample to plot

Fs = 1000  # sampling rate (Hz)
# One time stamp per sample; np.arange with a float step can yield an off-by-one length.
t = np.arange(len(gener[x][0])) / Fs

# Title from the class actually generated (was hard-coded to 'A-tach 2:1 block').
class_name = enc.classes_[int(labels[x])]

fig, axes = plt.subplots(3, 1)
axes[0].set_title(f'Synthetic {class_name} signals')
for ax, channel, channel_name in zip(axes, gener[x], ('ECG', 'ABP', 'CVP')):
    ax.plot(t, channel)
    ax.set_ylabel(channel_name)
    ax.margins(x=0)
axes[-1].set_xlabel('Time(s)')
plt.show()

In [None]:
# Spot-check: class label of the beat at index 700.
labels_beats[700]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Side-by-side comparison of one real beat (left) and one synthetic beat (right).
fs = 1000  # Hz

x = 55  # index of the real beat in `signals`
y = 4   # index of the synthetic beat in `gener`

# Time vector. Named `time_s` rather than `time` so this cell does not shadow the `time`
# module, which the evaluation code later in the notebook calls `time.time()` on.
N = len(signals[x][0])
time_s = np.linspace(0, (N - 1) / fs, N)

channel_names = ('ECG', 'ABP', 'CVP')
columns = (
    (signals[x], 'Real Heartbeat', None),  # None -> matplotlib's default colour
    (gener[y], 'Synthetic Heartbeat', 'orange'),
)

# Creating a grid for subplots: 3 rows (channels) x 2 columns (real, synthetic)
fig, ax = plt.subplots(3, 2, figsize=(15, 10))
for col, (beat, title, color) in enumerate(columns):
    for row, channel_name in enumerate(channel_names):
        ax[row, col].plot(time_s, beat[row], color=color)
        ax[row, col].set_xlabel('Time (s)', fontname='Arial', fontsize=14)
        ax[row, col].set_ylabel(channel_name, fontname='Arial', fontsize=14)
    ax[0, col].set_title(title, fontname='Arial', fontsize=16)

fig.tight_layout()  # Automatically adjusts subplot params
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Sampling frequency
fs = 1000  # Hz

# Time vector for the real beat. Named `time_s` (not `time`) so this cell does not shadow the
# `time` module that the evaluation code later in the notebook calls `time.time()` on.
x = 55
N = len(signals[x][0])
time_s = np.linspace(0, (N - 1) / fs, N)

# Creating a grid for subplots: 3 rows, 2 columns
fig, ax = plt.subplots(3, 2, figsize=(20, 10))


def setup_axis(ax, data, title, xlabel, ylabel, color='blue', t=None):
    """Plot one signal on `ax` and apply the shared title/label styling.

    Args:
        ax: matplotlib Axes to draw on.
        data: 1-D signal to plot.
        title, xlabel, ylabel: axis texts ('' leaves the title empty).
        color: line colour.
        t: x values. Defaults to a time axis in seconds built from len(data) and the
            notebook-level sampling rate `fs` (previously the global `time` was read implicitly).
    """
    if t is None:
        t = np.linspace(0, (len(data) - 1) / fs, len(data))
    ax.plot(t, data, color=color)
    ax.set_title(title, fontname='Arial', fontsize=20)
    ax.set_xlabel(xlabel, fontname='Arial', fontsize=18)
    ax.set_ylabel(ylabel, fontname='Arial', fontsize=18)
    ax.margins(x=0)  # Set x-axis margins to zero


# Left column: real beat
setup_axis(ax[0, 0], signals[x][0], 'Real Heartbeat', 'Time (s)', 'ECG', t=time_s)
setup_axis(ax[1, 0], signals[x][1], '', 'Time (s)', 'ABP', t=time_s)
setup_axis(ax[2, 0], signals[x][2], '', 'Time (s)', 'CVP', t=time_s)

# Right column: synthetic beat
y = 8  # index of the synthetic beat in `gener`
setup_axis(ax[0, 1], gener[y][0], 'Synthetic Heartbeat', 'Time (s)', 'ECG', color='orange', t=time_s)
setup_axis(ax[1, 1], gener[y][1], '', 'Time (s)', 'ABP', color='orange', t=time_s)
setup_axis(ax[2, 1], gener[y][2], '', 'Time (s)', 'CVP', color='orange', t=time_s)

fig.tight_layout()  # Automatically adjusts subplot params
plt.show()


In [None]:
import random
import numpy as np
import torch
from frechetdist import frdist
import math
import time
from dtaidistance import dtw

class GANEvaluator:
    """Compare a set of real signals against a set of generated signals.

    Every metric uses the same nearest-neighbour scheme. For each real signal, the closest
    generated signal (minimum distance) is found, and the per-real minima are averaged.

    Args:
        real_signals: iterable of array-likes, one signal each. The Frechet and MMD metrics
            assume 1-D signals of equal length.
        fake_signals: iterable of array-likes, one signal each.
        device: torch device on which the signals are stored.
        fs: sampling frequency in Hz, used to build the time axis for the Frechet distance.
            Defaults to 1000 Hz, the rate used throughout this notebook; the previous
            hard-coded 360 Hz did not match the data.

    Raises:
        ValueError: if `real_signals` is empty.
    """

    def __init__(self, real_signals, fake_signals, device, fs=1000):
        self.device = device
        self.real_signals = [torch.tensor(signal, device=self.device).float() for signal in real_signals]
        self.fake_signals = [torch.tensor(signal, device=self.device).float() for signal in fake_signals]
        if not self.real_signals:
            raise ValueError("real_signals must contain at least one signal.")
        self.time = torch.arange(0, len(self.real_signals[0]), device=self.device) / fs

    def shuffle_data(self, data):
        """Shuffle the list `data` in place (does not touch the stored signals)."""
        random.shuffle(data)

    def frechet_distance(self):
        """Mean over real signals of the smallest Frechet distance to any fake signal.

        Both curves are (time, amplitude) point sequences built with `self.time`.
        """
        # The fake curves do not depend on the real signal: build them once.
        fake_curves = [torch.stack([self.time, fake_points], dim=1).cpu().numpy()
                       for fake_points in self.fake_signals]
        frechet_distances = []
        for real_points in self.real_signals:
            P = torch.stack([self.time, real_points], dim=1).cpu().numpy()
            frechet_distances.append(min(frdist(P, Q) for Q in fake_curves))
        return np.mean(frechet_distances)

    def calculate_metric(self, metric_func):
        """Mean over real signals of the smallest `metric_func(real, fake)` over fake signals.

        Raises:
            ValueError: if a real and a fake signal differ in length.
        """
        metrics = []
        for real_points in self.real_signals:
            tmp_metrics = []
            for fake_points in self.fake_signals:
                if len(real_points) != len(fake_points):
                    raise ValueError("Both lists of signals must have the same length.")
                tmp_metrics.append(metric_func(real_points, fake_points).item())
            metrics.append(min(tmp_metrics))
        return np.mean(metrics)

    def MSE(self, real, fake):
        """Mean squared error between two tensors."""
        return (real - fake).pow(2).mean()

    def RMSE(self, real, fake):
        """Root mean squared error between two tensors."""
        return torch.sqrt((real - fake).pow(2).mean())

    def MAE(self, real, fake):
        """Mean absolute error between two tensors."""
        return torch.abs(real - fake).mean()

    def calculate_prmse(self):
        """Mean over real signals of (smallest RMSE to any fake signal) / max(real) * 100.

        NOTE(review): dividing by the real signal's maximum is negative or undefined when that
        maximum is <= 0.
        """
        prmse_values = []
        for real_points in self.real_signals:
            best_rmse = min(torch.sqrt(((real_points - fake_points) ** 2).mean()).item()
                            for fake_points in self.fake_signals)
            prmse_values.append(best_rmse / torch.max(real_points).item() * 100)
        return np.mean(prmse_values)

    def gaussian_kernel(self, X, Y, sigma=1.0):
        """Gaussian (RBF) kernel matrix between the rows of X and the rows of Y."""
        XX = torch.matmul(X, X.T)
        XY = torch.matmul(X, Y.T)
        YY = torch.matmul(Y, Y.T)
        X_sqnorms = torch.diagonal(XX)
        Y_sqnorms = torch.diagonal(YY)

        K = torch.exp(-0.5 * (X_sqnorms[:, None] + Y_sqnorms[None, :] - 2 * XY) / sigma**2)
        return K

    def mmd(self, sigma=1.0):
        """Mean over real signals of the smallest (biased) squared MMD to any fake signal.

        Each signal is treated as a set of scalar samples (one per time step).
        """
        fake_columns = [fake_points.unsqueeze(1) for fake_points in self.fake_signals]
        # Kernel terms that depend on only one side are computed once, not per pair.
        fake_self_terms = [self.gaussian_kernel(Q, Q, sigma).mean() for Q in fake_columns]
        mmd_distances = []
        for real_points in self.real_signals:
            P = real_points.unsqueeze(1)
            real_self_term = self.gaussian_kernel(P, P, sigma).mean()
            mmd_values = [
                (real_self_term + k_qq - 2 * self.gaussian_kernel(P, Q, sigma).mean()).item()
                for Q, k_qq in zip(fake_columns, fake_self_terms)
            ]
            mmd_distances.append(min(mmd_values))
        return np.mean(mmd_distances)

    def dtw_distance(self):
        """Mean over real signals of the smallest DTW distance (dtaidistance) to any fake signal.

        For 2-D signals, the DTW distance is averaged over columns. Returns inf when there is
        nothing to compare.

        Raises:
            ValueError: if a real and a fake signal have different numbers of dimensions.
        """
        # Convert the fake signals to numpy once instead of once per real signal.
        fake_arrays = [fake_signal.cpu().numpy() for fake_signal in self.fake_signals]
        dtw_distances = []
        for real_signal in self.real_signals:
            real_signal = real_signal.cpu().numpy()
            dtw_per_real = []
            for fake_signal in fake_arrays:
                if real_signal.ndim == 1 and fake_signal.ndim == 1:
                    # Univariate case
                    dtw_per_real.append(dtw.distance(real_signal, fake_signal))
                elif real_signal.ndim == 2 and fake_signal.ndim == 2:
                    # Multivariate case: DTW per column, then the average
                    distances = [dtw.distance(real_signal[:, c], fake_signal[:, c])
                                 for c in range(real_signal.shape[1])]
                    dtw_per_real.append(np.mean(distances))
                else:
                    raise ValueError("Mismatched dimensions between real and fake signals")
            if dtw_per_real:
                dtw_distances.append(min(dtw_per_real))
        if dtw_distances:
            return np.mean(dtw_distances)
        return float('inf')  # no signals to compare

    
def _load_eval_generator(generator_path, num_classes, device,
                         channels_signal=3, embed_size=800, generator_layers=6):
    """Rebuild the UNet1D generator, load its trained weights and switch it to eval mode."""
    model = UNet1D(channels_signal, channels_signal, num_classes, embed_size, n_layers=generator_layers).to(device)
    # map_location allows a GPU-trained checkpoint to be evaluated on a CPU-only machine.
    model.load_state_dict(torch.load(generator_path, map_location=device))
    model.eval()
    return model


def _print_metrics(gan_evaluator, sig_name):
    """Compute every GANEvaluator metric for one signal type and print it, with the run time."""
    # Local import: a notebook cell may rebind the global name `time` (e.g. to a time-axis array).
    import time

    start_time = time.time()
    print(f'MSE value for {sig_name} signals is: {gan_evaluator.calculate_metric(gan_evaluator.MSE)}')
    print(f'RMSE value for {sig_name} signals is: {gan_evaluator.calculate_metric(gan_evaluator.RMSE)}')
    print(f'MAE value for {sig_name} signals is: {gan_evaluator.calculate_metric(gan_evaluator.MAE)}')
    print(f'PRMSE value for {sig_name} signals is: {gan_evaluator.calculate_prmse()}')
    print(f'DTW value for {sig_name} signals is: {gan_evaluator.dtw_distance()}')
    print(f'Frechet Distance value for {sig_name} signals is: {gan_evaluator.frechet_distance()}')
    print(f'MMD value for {sig_name} signals is: {gan_evaluator.mmd(sigma=1.0)}')

    # Record the end time, then calculate and print the elapsed time
    elapsed_time = time.time() - start_time
    print(f"The code took {elapsed_time} seconds to run.")
    print(f"The code took {elapsed_time/3600} hours to run.")


def evaluate_all(path, generated_samples,
                 generator_path=r'Z:\1938759\Synhtetic signal simulator\Our-data\model_all data\Gen_baseline_all_1000_3C.pth'):
    """Evaluate the trained generator class by class and channel by channel, printing the metrics.

    For every encoded class:
      * `generated_samples` synthetic beats are generated;
      * the same number of real beats of that class is drawn at random (capped at the number
        available);
      * each channel (ECG, ABP, CVP) is compared with every GANEvaluator metric.

    Args:
        path: unused; kept so existing calls `evaluate_all(<anything>, n)` keep working.
        generated_samples: number of synthetic beats per class (and real beats sampled per class).
        generator_path: checkpoint of the trained UNet1D generator.

    Relies on notebook globals: `enc_data`, `labels_beats`, `signals`, `UNet1D`, `GANEvaluator`.
    """
    import random

    device = "cuda" if torch.cuda.is_available() else "cpu"
    channels_signal = 3
    latent_dim = 800

    # Both lists are sorted, so position i pairs an encoded id with its class name.
    signal_labels_numeric = list(np.unique(enc_data))
    signal_labels_real = list(np.unique(labels_beats))
    # Same class count the generator was trained with (num_classes = len(unique) at training time).
    num_classes = len(signal_labels_numeric)

    generator_loaded = _load_eval_generator(generator_path, num_classes, device, channels_signal=channels_signal)

    list_sig_type = ['ECG', 'ABP', 'CVP']

    for class_id, class_name in zip(signal_labels_numeric, signal_labels_real):
        # Real beats of this class (previously the class index was collected instead of the
        # beat index, so every "real" beat was the same signal).
        real_beats = [np.array(signals[k]) for k, beat_label in enumerate(labels_beats) if beat_label == class_name]
        if not real_beats:
            print(f'No real beats for class {class_name}; skipped.')
            continue

        labels = torch.full((generated_samples,), int(class_id), dtype=torch.int32, device=device)
        fixed_noise = torch.randn(generated_samples, channels_signal, latent_dim).uniform_(-1, 1).to(device)
        with torch.no_grad():
            gener = generator_loaded(fixed_noise, labels).cpu().numpy()

        n_real = min(generated_samples, len(real_beats))
        if n_real < generated_samples:
            print(f'Only {n_real} real beats available for class {class_name}; using all of them.')
        random_selected = random.sample(real_beats, n_real)

        for k in range(real_beats[0].shape[0]):
            sig_name = list_sig_type[k]
            orig_signal_type = [beat[k] for beat in random_selected]
            gener_signal_type = [beat[k] for beat in gener]

            print(f'The signal evaluated is: {sig_name}')
            print(f'The class evaluated is: {class_name}')

            gan_evaluator = GANEvaluator(orig_signal_type, gener_signal_type, device)
            _print_metrics(gan_evaluator, sig_name)

            print('________________________________________________________')

In [None]:
# Per-class, per-channel evaluation with 400 synthetic beats per class. evaluate_all ignores its
# first argument, so pass None explicitly instead of IPython's `_` (the last cell output), which
# depends on execution history.
evaluate_all(None, 400)