In [1]:
from src.fistanet.M5FISTANet import FISTANet
from src.fistanet.loader import DataSplit
from src.fistanet.solver import Solver
from os.path import join as pjoin
from torchsummary import summary
from datetime import datetime
import numpy as np
import torch

In [2]:
# ---- Data -------------------------------------------------------------------
DATA_DIR = './data'
DATA_FILE_GEN = 'generated/BW_master_10000_2024-04-07-12-43-32.pkl'
DATA_FILE_SIGS = 'steinbrinker/testing_data_mvg_avg.npy'
DATA_FILE_BW = 'mit-bih/bw'
DICT_FILE_BW = 'steinbrinker/dictionary_BW_real_data.npy'
# NOTE(review): DATA_SIZE is not referenced by any cell shown here — confirm
# it is still needed (or pass it to DataSplit if that was the intent).
DATA_SIZE = 10000
BATCH_SIZE = 1000
# Train / validation / test proportions passed to DataSplit.
# Presumably percentages — TODO confirm against src.fistanet.loader.
TVT_SPLIT = {
    'train': 80,
    'valid': 10,
    'test': 10
}
# Fail fast on a mistyped split instead of silently dropping/duplicating data.
assert sum(TVT_SPLIT.values()) == 100, "TVT_SPLIT proportions must sum to 100"

# ---- Model ------------------------------------------------------------------
FNET_LAYER_NO = 4
FNET_FEATURE_NO = 16
LAMBDA_SP_LOSS = 1e-3  # weight of the sparsity term in the training loss

# ---- Training ---------------------------------------------------------------
EPOCH_NO = 20
START_EPOCH = 0
# NOTE(review): TEST_EPOCH (21) and LR_DEC_AFTER (100) both exceed EPOCH_NO
# (20). If Solver compares them against the epoch counter, the test pass and
# the learning-rate decay never trigger in this run — confirm this is intended.
TEST_EPOCH = 21
LR_DEC_AFTER = 100
LR_DEC_EVERY = 10
LOG_INTERVAL = 4
LEARNING_RATE = 1e-3

# ---- Reproducibility --------------------------------------------------------
# Seed before the data split and model initialisation (later cells) so that
# Restart & Run All reproduces the same split and initial weights.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

In [3]:
# Build the train / validation / test loaders from the files configured above.
# The split proportions come from TVT_SPLIT; BATCH_SIZE is presumably the
# loader batch size. Arguments are positional, so their order must match
# DataSplit's signature in src.fistanet.loader (not visible here).
trn_ldr, val_ldr, tst_ldr = DataSplit(DATA_DIR, DATA_FILE_GEN, DATA_FILE_SIGS, DATA_FILE_BW, TVT_SPLIT, BATCH_SIZE)

In [4]:
# Use the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Load the dictionary matrix from disk and move it to `device`.
# clone() gives torch its own copy instead of sharing the NumPy buffer
# returned by np.load; detach() guarantees it is outside any autograd graph.
Psi = (
    torch.from_numpy(np.load(pjoin(DATA_DIR, DICT_FILE_BW)))
    .clone()
    .detach()
    .to(device=device)
)

In [6]:
# Instantiate FISTANet from the config constants (argument meaning assumed from
# the constant names) and move its parameters to the selected device.
# nn.Module.to() returns the module itself, so chaining is equivalent.
fista_net = FISTANet(FNET_LAYER_NO, FNET_FEATURE_NO).to(device)


In [7]:
# Report model size: total element count over every parameter tensor
# (trainable or not).
print('Total number of parameters fista net:',
      sum(map(torch.Tensor.numel, fista_net.parameters())))

Total number of parameters fista net: 18871


In [8]:
# Run-time arguments handed to Solver. save_path is stamped with the launch
# time (second resolution), presumably so each run gets its own output folder.
dt = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
args = dict(
    model_name='FISTANet',
    num_epochs=EPOCH_NO,
    lr=LEARNING_RATE,
    data_dir=DATA_DIR,
    save_path=f'./runs/{dt}',
    start_epoch=START_EPOCH,
    multi_gpu=False,
    device=device,
    log_interval=LOG_INTERVAL,
    test_epoch=TEST_EPOCH,
    lr_dec_after=LR_DEC_AFTER,
    lr_dec_every=LR_DEC_EVERY,
    lambda_sp_loss=LAMBDA_SP_LOSS,
)

In [9]:
# Bundle the model, the dictionary Psi, the three data loaders, the batch size
# and the run arguments into a Solver. Arguments are positional, so their order
# must match Solver's signature in src.fistanet.solver (not visible here).
solver = Solver(fista_net, Psi, trn_ldr, val_ldr, BATCH_SIZE, args, tst_ldr)

In [10]:
# Run the training loop. As the captured output shows, each epoch logs the
# discrepancy / sparsity loss terms and learned scalar weights every
# LOG_INTERVAL batches, then prints average training and validation loss.
# NOTE(review): checkpoints presumably go under args['save_path'] — confirm.
solver.train()

Training epoch 1...

				Disc: 1128825.218773		Spars: 0.010295
	 TVw: -0.499000 | TVb: -1.999000 | GSw: -0.201000 | GSb: 0.099000 | TSUw: 0.499000 | TSUb: 0.001000

				Disc: 1238762.626674		Spars: 0.013103
	 TVw: -0.502084 | TVb: -2.002052 | GSw: -0.205000 | GSb: 0.095000 | TSUw: 0.494999 | TSUb: 0.005000
Validating epoch 1...
-------------------------------------------
Epoch statistics:
Average training loss: 1072657.6839191772
Average validation loss: 438595.521654338
Training epoch 2...

				Disc: 931633.125833		Spars: 0.020326
	 TVw: -0.505789 | TVb: -2.005654 | GSw: -0.208951 | GSb: 0.091048 | TSUw: 0.491046 | TSUb: 0.008948

				Disc: 855731.091489		Spars: 0.035017
	 TVw: -0.509794 | TVb: -2.009392 | GSw: -0.212801 | GSb: 0.087197 | TSUw: 0.487192 | TSUb: 0.012791
Validating epoch 2...
-------------------------------------------
Epoch statistics:
Average training loss: 793775.8477406028
Average validation loss: 217205.13217109506
Training epoch 3...

				Disc: 462656.363737		Spa