In [None]:
# standard library
import os
import random
import sys

# third-party
import mlflow
import numpy as np
import torch

# Make the project's `src` package importable. PROJECT_PATH is the kernel's
# current working directory, so this assumes the notebook was started from the
# project root — adjust PROJECT_PATH if the notebook lives elsewhere.
PROJECT_PATH = os.getcwd()
# Guard keeps the cell idempotent: re-running it must not append duplicates.
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)

# project-local
from src.FISTANet.trainer import FISTANetTrainer
from src.FISTANet.loader import DataSplit
from src.FISTANet.model import FISTANet

In [None]:
# ---- data --------------------------------------------------------------
# All input files are resolved relative to <PROJECT_PATH>/data.
DATA_DIR = os.path.join(PROJECT_PATH, 'data')
# The next three names are handed to DataSplit to build the data loaders.
DATA_FILE_GEN = 'dataset_master_10000.pkl'
DATA_FILE_SIGS = 'ECG-QDB_cat-1.npy'
DATA_FILE_NOISE = 'MIT-BIH/bw'
# Loaded with np.load and passed to the trainer as the dictionary `Psi`.
DATA_FILE_DICT = 'dictionary_BW_real_data.npy'
# NOTE(review): not referenced by the cells shown here.
DATA_SIZE = 10_000
# Train / validation / test split given to DataSplit (the values sum to 100,
# presumably percentages — confirm against DataSplit).
TVT_SPLIT = dict(train=80, valid=10, test=10)

# ---- FISTA-Net architecture ----------------------------------------------
FNET_LAYER_NO = 4      # first positional argument of FISTANet(...)
FNET_FEATURE_NO = 16   # second positional argument of FISTANet(...)

# ---- optimisation ------------------------------------------------------
# Loss-term weights; forwarded to FISTANetTrainer through its params dict.
LAMBDA_LSPA = 1
LAMBDA_LFSYM = 0.001
LAMBDA_LFSPA = 0.01
BATCH_SIZE = 1_000
LEARNING_RATE = 0.001

# ---- reproducibility -----------------------------------------------------
# Used to seed random, NumPy, torch and the DataLoader generator.
RANDOM_SEED = 42

In [None]:
# Route all MLflow calls to the local tracking server, then make the
# 'ECGDenFISTA-Net' experiment the active one; `experiment` is kept so later
# cells can query runs by its id.
mlflow.set_tracking_uri(uri='http://localhost:8080')
experiment = mlflow.set_experiment(experiment_name='ECGDenFISTA-Net')

In [None]:
# Choose the pretrained model to resume from: the most recently *started* run
# of the active experiment. Run this cell BEFORE the training cell — once that
# cell has opened its own run, re-running this one would select the new run.
# NOTE(review): runs of any status (including FAILED/RUNNING) are eligible;
# add `filter_string="attributes.status = 'FINISHED'"` if only completed runs
# should qualify.
latest_runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id],
                                 order_by=['start_time desc'],
                                 max_results=1)
if latest_runs.empty:
    raise RuntimeError(
        f"Experiment '{experiment.name}' (id {experiment.experiment_id}) has no runs; "
        "cannot select a pretrained model to resume from."
    )
LOAD_MODEL_RUN = latest_runs.iloc[0]['run_id']
# Epoch of the checkpoint to resume from; also used as `start_epoch` for training.
# Assumes LOAD_MODEL_RUN logged a model at this epoch — TODO confirm.
LOAD_MODEL_EPOCH = 3000

In [None]:
# Epoch count passed to trainer.train(); whether it is counted from
# LOAD_MODEL_EPOCH or from 0 is decided by FISTANetTrainer — TODO confirm.
NUM_EPOCHS = 17000

with mlflow.start_run(log_system_metrics=True) as run:
    # Seed every RNG in play so repeated experiments are deterministic:
    # Python's `random`, NumPy's global RNG, torch on CPU and all CUDA devices,
    # plus a dedicated torch.Generator handed to the data loaders. cuDNN is
    # pinned to deterministic kernels and its autotuner disabled, because
    # benchmark mode may select different algorithms from run to run.
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    generator = torch.Generator()
    generator.manual_seed(RANDOM_SEED)

    # prefer the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # build the train / validation / test loaders and load the dictionary Psi
    trn_ldr, val_ldr, tst_ldr = DataSplit(DATA_DIR, DATA_FILE_GEN, DATA_FILE_SIGS, DATA_FILE_NOISE,
                                          TVT_SPLIT, BATCH_SIZE, generator=generator)
    Psi = np.load(os.path.join(DATA_DIR, DATA_FILE_DICT))

    # Instantiate the network. Restoring the weights of LOAD_MODEL_RUN at
    # LOAD_MODEL_EPOCH is presumably done by FISTANetTrainer from `params`
    # — TODO confirm.
    model = FISTANet(FNET_LAYER_NO, FNET_FEATURE_NO)

    # hyper-parameters consumed by FISTANetTrainer; the same dict is logged to MLflow
    params = {
        'device': device,
        'fnet_layer_no': FNET_LAYER_NO,
        'fnet_feature_no': FNET_FEATURE_NO,
        'lambda_Lspa': LAMBDA_LSPA,
        'lambda_LFsym': LAMBDA_LFSYM,
        'lambda_LFspa': LAMBDA_LFSPA,
        'load_model_run': LOAD_MODEL_RUN,
        'load_model_epoch': LOAD_MODEL_EPOCH,
        'batch_size': BATCH_SIZE,
        'lr': LEARNING_RATE
    }
    mlflow.log_params(params)
    # The seed governs the whole run, so record it too; it is logged separately
    # to leave the trainer's `params` dict unchanged.
    mlflow.log_param('random_seed', RANDOM_SEED)

    # train (resuming at LOAD_MODEL_EPOCH), then evaluate on the held-out test split
    trainer = FISTANetTrainer(model, Psi, params)
    trainer.train(trn_ldr, val_ldr, NUM_EPOCHS, start_epoch=LOAD_MODEL_EPOCH,
                  log_model_every=100, log_comp_fig_every=100, comp_fig_samples=[0, 500, 950])
    trainer.evaluate(tst_ldr)