<div style="text-align: center; font-size: 2em;">
    Welcome to CP-PPG!
</div>

### Install the requirements

```bash
conda create -n ppg python=3.11
conda activate ppg
pip install -r requirements.txt
```

### Experiment Tool

- This step is optional; follow it only if you want to track your experiments on the Comet (comet.ml) platform.

- Create a Comet account on the official website. Once you have an account, open your profile/account settings to copy your API key.

- Create a new file named `experiment_apikey.txt` inside the `configs` folder, paste the API key into it, and save it.


### Training 

- Before training, set your desired options in `config.yml` inside the `configs` folder. Sensible defaults are already provided, so you can start from those.

- To train the model, run the following command. Comet will log the relevant metrics and results during training, so you can inspect them in your workspace on the Comet website.

```bash
python src/experiments/tools/train.py
```



# 1. Configs

In [None]:
import yaml

# Load the YAML configuration and echo it back in block style so it is easy to read.
config_path = "configs/config.yml"
with open(config_path, 'r') as config_file:
    data = yaml.safe_load(config_file)

cfgs = yaml.dump(data, indent=4, default_flow_style=False)
print(cfgs)

# 2. Data Exploration

In [None]:
import os
import shutil

def process_data_folder(data_folder_path):
    """Move each subject's top-level ``.csv`` files into a ``preprocessed`` sub-folder.

    Every immediate sub-directory of ``data_folder_path`` is treated as a
    subject folder. A ``preprocessed`` directory is created inside it (if it
    does not exist yet), and every entry of the subject folder whose name ends
    with ``.csv`` is moved into that directory. Non-directory entries directly
    under ``data_folder_path`` are left untouched.

    Args:
        data_folder_path: Root folder that holds one sub-directory per subject.
    """
    for entry_name in os.listdir(data_folder_path):
        subject_dir = os.path.join(data_folder_path, entry_name)
        if not os.path.isdir(subject_dir):
            continue

        # Create the target first; its name never ends in ".csv", so it is not picked up below.
        target_dir = os.path.join(subject_dir, 'preprocessed')
        os.makedirs(target_dir, exist_ok=True)

        for file_name in os.listdir(subject_dir):
            if not file_name.endswith('.csv'):
                continue
            shutil.move(
                os.path.join(subject_dir, file_name),
                os.path.join(target_dir, file_name),
            )
            
# Reorganize data/v2 in place: each subject's top-level CSVs are moved into
# <subject>/preprocessed/. Re-running does not move anything further for CSVs
# already moved, because only CSVs at a subject folder's top level are picked up.
process_data_folder("data/v2")

In [None]:
import random
import glob
import numpy as np
import matplotlib.pyplot as plt 
from src.utils.utils import get_peaks_info, plot_wavform, plot_peaks, plot_metrics, get_config, plot_two_signal, get_feet
from src.utils.preprocess import extract_window_segments
from src.utils.preprocess import moving_average, read_processed_signal
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

cfgs = get_config()
csv_file = "data/v2/subject1009/preprocessed/3_preprocessed_1009.csv"

ppg_in, ppg_ref, pressure = read_processed_signal(csv_file, cfgs)
fig, axes = plt.subplots(1, 1, figsize=(24, 10))
flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71", "#f4cae4", "#FCD5B5", "#EDE2D5"]

# Sample range of the reference PPG to illustrate. Presumably this is one pulse
# cycle, since two feet are indexed below; confirm if the window is changed.
cycle_start, cycle_end = 1020, 1100
cycle = ppg_ref[cycle_start:cycle_end].reshape(-1)
# Normalize by the peak of the same window that is plotted, so its maximum maps
# to 1 (assuming a positive peak). The old code used a wider window for the max.
normed_ref = cycle / np.max(cycle)
axes.plot(normed_ref, color='orange', linewidth=3)

systolic_peak_index = np.argmax(normed_ref)
x_values = np.arange(len(normed_ref))
# NOTE(review): get_feet is assumed to return at least two foot indices that
# bracket the systolic peak; they bound the shaded areas below.
feet = get_feet(normed_ref)

# Systolic area: first foot up to and including the systolic peak.
axes.fill_between(x_values[feet[0]:systolic_peak_index + 1], normed_ref[feet[0]:systolic_peak_index + 1],
                  color=flatui[-1], alpha=0.5, label="Systolic Area (SA)")
# Diastolic area: systolic peak up to (excluding) the second foot.
axes.fill_between(x_values[systolic_peak_index:feet[1]], normed_ref[systolic_peak_index:feet[1]],
                  color=flatui[-2], alpha=0.5, label="Diastolic Area (DA)")

# Explicit patch handles give the legend swatches their own alpha (0.6).
axes.legend(fontsize=30, handles=[
    mpatches.Patch(color=flatui[-1], alpha=0.6, label='Systolic Area (SA)'),
    mpatches.Patch(color=flatui[-2], alpha=0.6, label='Diastolic Area (DA)')],
    loc='upper left', bbox_to_anchor=(0.6, 0.6))

plt.tight_layout()
plt.show()


In [None]:
# Split the recording into windows. window_size is passed through as 5 * 100;
# NOTE(review): confirm its unit and that it matches the CSV's sampling rate.
src_segments, ref_segments = extract_window_segments(
    csv_file, cfgs, window_size=5 * 100
)
print(src_segments.shape, ref_segments.shape)

# Show the same window index from the source and reference side by side.
preview_idx = 5
plt.figure(figsize=(16, 5))
panels = ((src_segments, 'Source Segment'), (ref_segments, 'Reference Segment'))
for panel_no, (segments, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.plot(segments[preview_idx].reshape(-1))
    plt.title(panel_title)
plt.tight_layout()
plt.show()

plot_peaks(ref_segments[2])

In [None]:
from src.utils.feature import extract_feat_cycle
from src.utils.preprocess import cycle_helper
from src.utils.utils import standardize

# Standardize one source/reference window pair and split it into cycles.
seg_idx = 4
std_src = standardize(src_segments[seg_idx].reshape(-1))
std_ref = standardize(ref_segments[seg_idx].reshape(-1))
breakdown_seg_in, breakdown_seg_ref, _ = cycle_helper(std_src, std_ref, return_cycles=True)

# Compare a single cycle, then extract per-cycle features from the reference cycles
# (fs=128 is passed as the sampling frequency).
plot_wavform(breakdown_seg_in[2], breakdown_seg_ref[2])
feat_name, feats, dia_feat_name, dia_feats, valid = extract_feat_cycle(breakdown_seg_ref, fs=128)

In [None]:
# Map each extracted feature name to its value for the reference cycles.
all_feats = dict(zip(feat_name, feats))
print(feat_name)

# 3. Data Augmentation

In [None]:
import os
import glob
import numpy as np
from src.utils.classification import classify
from src.utils.enrichment import PPGTransform
from src.utils.utils import get_peaks_info, plot_wavform, plot_peaks, plot_metrics, get_config, plot_two_signal, standardize

transform = PPGTransform()

# src_segments[22:23] is a one-window slice; flatten it before standardizing.
raw_window = src_segments[22:23]
signal_segment = standardize(raw_window.reshape(-1))
transformed_signal = transform.convert(signal_segment)

# Overlay the signal before and after augmentation.
plot_two_signal(signal_segment, transformed_signal)


# 4. Custom Dataset

In [None]:
from src.dataloader.dataset import PPGDataset, get_loader
from src.utils.utils import get_config, plot_wavform, read_json, write_json, plot_subject_distribution, standardize
from src.utils.prepare import DataHandler

# Load the project configuration and build the DataHandler used by the next cell.
cfgs = get_config()
dh = DataHandler(cfgs)

In [None]:
# Build the custom dataset for the "8s" setting (presumably 8-second windows;
# confirm in DataHandler), then split it into train/val/test sets.
# NOTE(review): the method name is spelled `custom_data_hanlder` here, presumably
# matching DataHandler's own definition; rename both sides together if fixed.
dh.custom_data_hanlder("8s")
dh.train_val_test_split()

# 5. Models

## 5.1 Generator

In [None]:
from src.models.cpppg import Generator
from src.utils.utils import get_config
import torch

cfgs = get_config()

# Forward-pass smoke test: one random input of shape (1, 1, 8 * 128) through the generator.
model = Generator(cfgs)
signal_len = 8 * 128
dummy_input = torch.randn(1, 1, signal_len)
out = model(dummy_input)
print(out.shape)

## 5.2 Discriminator

In [None]:
from src.models.cpppg import Discriminator
from src.utils.utils import get_config
import torch

# Reload the project configuration for the discriminator cell that follows.
cfgs = get_config()

In [None]:
from torchsummary import summary

# Use one input length for both the forward probe and the summary so they agree.
# It is the same 8 * 128 length the generator smoke test feeds in.
seq_len = 8 * 128

model = Discriminator(cfgs)
x = torch.randn(1, 1, seq_len)
y = model(x)
print(y.shape)

# torchsummary.summary defaults to device="cuda". The model was never moved off
# the CPU, so pin the summary to CPU to avoid a device mismatch on GPU machines.
summary(model, (1, seq_len), device="cpu")

# 6. Training and Validation

In [None]:
from comet_ml import Experiment
from src.trainer.engine import Trainer
from src.trainer.adverarial_engine import AdversarialTrainer
from src.utils.utils import get_config
from configs.seed import *

cfgs = get_config()

# Strip surrounding whitespace. A key pasted into the file usually ends with a
# newline, which would otherwise be sent to Comet as part of the API key.
with open('configs/experiment_apikey.txt', 'r', encoding='utf-8') as f:
    api_key = f.read().strip()

tracking = Experiment(
    api_key=api_key,
    project_name="CP-PPG Project",
    workspace="maxph2211",
)
tracking.log_parameters(cfgs)

# Choose the trainer according to cfgs['train']['adversarial'].
trainer_cls = AdversarialTrainer if cfgs['train']['adversarial'] else Trainer
trainer = trainer_cls(tracking, cfgs)

# The training call is left commented out in this notebook. Uncomment it, or use
# src/experiments/tools/train.py, to actually train.
# trainer.training_experiment()
print("DONE!")

# 7. Deployment

In [None]:
from torch.utils.data import DataLoader
from src.dataloader.dataset import PPGDataset
from src.experiments.tools.test import Inference
from src.utils.utils import get_config, standardize
from tqdm import tqdm

cfgs = get_config()
infer_tool = Inference(cfgs)
# NOTE(review): data_path="" is passed through unchanged; confirm PPGDataset
# resolves the intended test split from cfgs in that case.
test_dataset = PPGDataset(cfgs, data_path="")
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=cfgs['data']['batch_size'],
    num_workers=cfgs['data']['num_workers'],
    shuffle=False,  # keep dataset order so outputs line up with their inputs
)

src_signals, out_signals, ref_signals = [], [], []

# Each batch yields four items; only the source and reference signals are used here.
for src_signal, ref_signal, _, _ in tqdm(test_loader):
    output_batch_tensor = infer_tool.infer(src_signal)
    # Collect inputs, outputs and references batch-by-batch (previously the
    # outputs were discarded and these lists stayed empty).
    src_signals.append(src_signal)
    out_signals.append(output_batch_tensor)
    ref_signals.append(ref_signal)

In [None]:
from flask import Flask, request, jsonify
from src.experiments.tools.test import Inference
from src.utils.utils import get_config, standardize
import joblib
import numpy as np
### v1

app = Flask(__name__)

# A single Inference instance is built at startup and shared by all requests.
inference = Inference(get_config())


@app.route("/enhance", methods=["POST"])
def enhance_signal():
    """Reconstruct a corrupted signal sent as JSON.

    Expects a JSON object with a non-empty "corrupted_signal" field.
    Responds with {"reconstructed_signal": [...]} and status 200. It responds
    with 400 when the body is not a JSON object or the field is missing/empty,
    and with 500 (including the exception message) when inference fails.
    """
    # silent=True yields None instead of raising on a missing/malformed JSON body,
    # so client errors are reported as 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    corrupted_signal = data.get("corrupted_signal", [])
    if not corrupted_signal:
        return jsonify({"error": "Invalid or missing 'corrupted_signal' field."}), 400

    try:
        # NOTE(review): the raw JSON list is passed to infer(); confirm it accepts lists
        # (the notebook's inference cell passes DataLoader batches instead).
        reconstructed_signal = inference.infer(corrupted_signal)
        return jsonify({"reconstructed_signal": reconstructed_signal.tolist()}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route('/preprocess', methods=['POST'])
def preprocess_signal():
    """Standardize a signal sent as JSON {"signal": [...]}.

    Responds with {"normalized_signal": [...]} (same shape as the input) and
    status 200, or 400 with an error message when the body is not a JSON object
    or standardization fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    try:
        signal = np.array(data.get('signal', []))
        # standardize() is applied to an (N, 1) column, then reshaped back to the input's shape.
        normalized_signal = standardize(signal.reshape(-1, 1)).reshape(signal.shape)
        return jsonify({'normalized_signal': normalized_signal.tolist()}), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 400


if __name__ == "__main__":
    # Never enable debug mode on a server bound to all interfaces: the Werkzeug
    # interactive debugger allows arbitrary code execution from the browser.
    app.run(host="0.0.0.0", debug=False, port=8085)
