In [1]:
from os.path import join as pjoin
import random

import numpy as np
import torch
from torchsummary import summary

from src.fistanet.M5FISTANet import FISTANet
from src.fistanet.loader import DataSplit
from src.fistanet.solver import Solver

In [2]:
# Data locations; every file path below is joined onto DATA_DIR by the loaders.
DATA_DIR = './data'
DATA_FILE_GEN = 'generated/BW_master_10000_2024-04-07-12-43-32.pkl'
DATA_FILE_SIGS = 'steinbrinker/testing_data_mvg_avg.npy'
DATA_FILE_BW = 'mit-bih/bw'
DICT_FILE_BW = 'steinbrinker/dictionary_BW_real_data.npy'
# NOTE(review): DATA_SIZE is not referenced by any visible cell — confirm it is still needed.
DATA_SIZE = 10000
BATCH_SIZE = 1000
# Train / validation / test proportions passed to DataSplit.
# Presumably percentages (they sum to 100) — confirm against DataSplit.
TVT_SPLIT = {
    'train': 80,
    'valid': 10,
    'test': 10
}

# FISTANet architecture.
FNET_LAYER_NO = 4
FNET_FEATURE_NO = 16

# Training schedule, forwarded to Solver via `args`.
EPOCH_NO = 100
START_EPOCH = 0
TEST_EPOCH = 9
# NOTE(review): LR_DEC_AFTER equals EPOCH_NO, so the learning-rate decay presumably
# never triggers during a full run — confirm this is intended.
LR_DEC_AFTER = 100
LR_DEC_EVERY = 10

# Seed every RNG before the data split and model construction (later cells) so that
# Restart & Run All reproduces the same split, initial weights and results.
# torch.manual_seed seeds all devices, including CUDA.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Build the train / validation / test loaders from the configured files and split.
# DataSplit itself prints some diagnostic arrays while running (visible in the cell output).
trn_ldr, val_ldr, tst_ldr = DataSplit(
    DATA_DIR,
    DATA_FILE_GEN,
    DATA_FILE_SIGS,
    DATA_FILE_BW,
    TVT_SPLIT,
    BATCH_SIZE,
)

[[   204    168     77 ...     29    287     76]
 [  2224     92   1958 ...    424    664    654]
 [191075  31368  28690 ... 195792  12101 185303]]
2224 4724 204
[[   160    299    248 ...    260    192     17]
 [   962   1854   1740 ...   1696   2398   1221]
 [349827 140117 250945 ... 375091 419722 319834]]
962 3462 160
[[   176     75    274 ...    212    256    182]
 [  1954    211   1599 ...     99    825   1706]
 [118910 271756 293015 ... 161094 141988 404065]]
1954 4454 176


In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Load the dictionary matrix from DICT_FILE_BW and place it on `device` in one step.
# torch.as_tensor keeps the array's dtype. The previous clone().detach() was a redundant
# extra copy: a tensor built from NumPy never requires grad, and repeat() below
# allocates a fresh tensor anyway.
Psi = torch.as_tensor(np.load(pjoin(DATA_DIR, DICT_FILE_BW)), device=device)
# Tile the dictionary once per sample in a batch.
# NOTE(review): assumes the saved array is 2-D (rows, cols) — confirm. repeat() then
# yields (BATCH_SIZE, rows, cols). That presumably ties FISTANet to batches of exactly
# BATCH_SIZE samples, so a smaller final batch would mismatch — verify the loaders.
Psi = Psi.repeat((BATCH_SIZE, 1, 1))

In [6]:
# Build FISTANet (configured by FNET_LAYER_NO / FNET_FEATURE_NO) around the batched
# dictionary Psi, and move its parameters to `device` (Module.to returns the module).
fista_net = FISTANet(FNET_LAYER_NO, FNET_FEATURE_NO, Psi).to(device)


In [7]:
# Report model size: the element count summed over every tensor in fista_net.parameters().
print(
    'Total number of parameters fista net:',
    sum(map(torch.numel, fista_net.parameters())),
)

Total number of parameters fista net: 18871


In [8]:
# Run configuration handed to Solver. The epoch and learning-rate schedule values come
# from the config cell; lr and log_interval are set here.
args = dict(
    model_name='FISTANet',
    num_epochs=EPOCH_NO,
    lr=1e-3,
    data_dir=DATA_DIR,
    save_path='./models/FISTANet/',
    start_epoch=START_EPOCH,
    multi_gpu=False,
    device=device,
    log_interval=2,
    test_epoch=TEST_EPOCH,
    lr_dec_after=LR_DEC_AFTER,
    lr_dec_every=LR_DEC_EVERY,
)

In [9]:
# Wire the model, the train / validation loaders, the batch size and the run arguments
# into a Solver; the test loader is passed last.
solver = Solver(
    fista_net,
    trn_ldr,
    val_ldr,
    BATCH_SIZE,
    args,
    tst_ldr,
)

In [10]:
# Train for args['num_epochs'] epochs (EPOCH_NO); this can take a long time.
# Interrupting the kernel (manual early stop) is caught so the cell ends cleanly
# instead of leaving a KeyboardInterrupt traceback in the saved notebook.
# NOTE(review): presumably Solver trains fista_net in place, so it keeps the weights
# reached at interruption — confirm, and check whether Solver saves checkpoints.
try:
    solver.train()
except KeyboardInterrupt:
    print('Training interrupted by user.')

Training epoch 1...

				Disc: 1178447.441050 	Const: 730.669678		Spars: 0.005083
	 TVw: -0.501000 | TVb: -2.001000 | GSw: -0.201000 | GSb: 0.099000 | TSUw: 0.499000 | TSUb: 0.001000

				Disc: 1106879.818205 	Const: 689.427979		Spars: 0.005688
	 TVw: -0.502726 | TVb: -2.002721 | GSw: -0.202996 | GSb: 0.097004 | TSUw: 0.497004 | TSUb: 0.002996

				Disc: 1180334.794905 	Const: 763.093811		Spars: 0.007276
	 TVw: -0.504496 | TVb: -2.004486 | GSw: -0.204996 | GSb: 0.095004 | TSUw: 0.495003 | TSUb: 0.004996

				Disc: 970120.940354 	Const: 619.137085		Spars: 0.008914
	 TVw: -0.506346 | TVb: -2.006331 | GSw: -0.206979 | GSb: 0.093021 | TSUw: 0.493020 | TSUb: 0.006977



KeyboardInterrupt

