In [1]:
import random
from os.path import join as pjoin

import numpy as np
import torch
from torchsummary import summary

from fistanet.M5FISTANet import FISTANet
from fistanet.loader import DataSplit
from fistanet.solver import Solver

In [2]:
# Input files: the generated dataset and the dictionary matrix (Psi) handed to FISTANet.
DATA_DIR = './data'
DATA_FILE_NAME = 'BW_gen_data_master.npy'
DICT_FILE_NAME_BW = 'dictionary_BW_real_data.npy'
# NOTE(review): DATA_SIZE is not referenced by any later cell in this notebook;
# confirm whether DataSplit should receive it or whether it can be dropped.
DATA_SIZE = 10000
BATCH_SIZE = 1000
# Train/validation/test proportions passed to DataSplit (presumably percentages,
# since the values sum to 100 — confirm against fistanet.loader.DataSplit).
TVT_SPLIT = {
    'train': 80,
    'valid': 10,
    'test': 10
}

# FISTANet constructor arguments; presumably the number of unrolled layers and
# the number of feature maps — confirm against fistanet.M5FISTANet.FISTANet.
FNET_LAYER_NO = 4
FNET_FEATURE_NO = 16

# Training schedule forwarded to Solver via `args`.
EPOCH_NO = 100
START_EPOCH = 0
TEST_EPOCH = 9
# NOTE(review): LR_DEC_AFTER equals EPOCH_NO, so a learning-rate decay that starts
# "after" this epoch presumably never triggers in this run — confirm intent.
LR_DEC_AFTER = 100
LR_DEC_EVERY = 10

# Seed every RNG that the data split or the model initialisation may draw from,
# so that Restart Kernel -> Run All reproduces the same results.
# torch.manual_seed also seeds the CUDA generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Build the train / validation / test loaders from the dataset file,
# using the TVT_SPLIT proportions and BATCH_SIZE from the config cell.
trn_ldr, val_ldr, tst_ldr = DataSplit(
    DATA_DIR,
    DATA_FILE_NAME,
    TVT_SPLIT,
    BATCH_SIZE,
)

In [4]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Load the dictionary matrix from disk, move it to the training device, then tile
# it BATCH_SIZE times along a new leading axis (for a 2-D Psi of shape (m, n) this
# gives (BATCH_SIZE, m, n)).
# `repeat` always allocates a fresh tensor, so no separate clone()/detach() copy is
# needed. The tensor built from numpy does not require grad.
psi_np = np.load(pjoin(DATA_DIR, DICT_FILE_NAME_BW))
# NOTE(review): the batch dimension is fixed to BATCH_SIZE here, so FISTANet
# presumably requires every batch to contain exactly BATCH_SIZE samples — confirm
# that DataSplit never yields a smaller final batch.
Psi = torch.from_numpy(psi_np).to(device).repeat((BATCH_SIZE, 1, 1))

In [6]:
# Instantiate FISTA-Net with the configured depth/width around the batched
# dictionary Psi, and place it on the selected device in the same expression.
fista_net = FISTANet(FNET_LAYER_NO, FNET_FEATURE_NO, Psi).to(device)


In [7]:
# Report model size: total number of elements across all parameter tensors
# (trainable and frozen alike).
n_params_fista = sum(p.numel() for p in fista_net.parameters())
print('Total number of parameters fista net:', n_params_fista)

Total number of parameters fista net: 18871


In [8]:
# Training options handed to Solver. The per-key comments are inferred from the
# key names — NOTE(review): confirm their meaning against fistanet.solver.Solver.
args = dict(
    model_name='FISTANet',
    num_epochs=EPOCH_NO,
    lr=1e-3,                          # initial learning rate (presumably)
    data_dir=DATA_DIR,
    save_path='./models/FISTANet/',   # presumably where checkpoints are written
    start_epoch=START_EPOCH,
    multi_gpu=False,
    device=device,
    log_interval=2,
    test_epoch=TEST_EPOCH,
    lr_dec_after=LR_DEC_AFTER,
    lr_dec_every=LR_DEC_EVERY,
)

In [9]:
# Bundle the model, the three data loaders, the batch size and the options dict
# into the project's Solver (positional order matches its constructor).
solver = Solver(
    fista_net,
    trn_ldr,
    val_ldr,
    BATCH_SIZE,
    args,
    tst_ldr,
)

In [10]:
# Run the training loop configured through `args`.
# NOTE(review): the saved output of this cell shows the run was stopped by hand
# (KeyboardInterrupt) during epoch 1, so no complete training result is recorded —
# re-run to completion before relying on any downstream results.
solver.train()

Training epoch 1...

				Disc: 1928572.563400 	Const: 1955.889038		Spars: 0.013711
	 TVw: -0.499000 | TVb: -1.999000 | GSw: -0.201000 | GSb: 0.099000 | TSUw: 0.499000 | TSUb: 0.001000

				Disc: 1884399.838599 	Const: 1881.620483		Spars: 0.014046
	 TVw: -0.500422 | TVb: -2.000318 | GSw: -0.202996 | GSb: 0.097004 | TSUw: 0.497004 | TSUb: 0.002996

				Disc: 1622858.727732 	Const: 1620.787964		Spars: 0.016536
	 TVw: -0.502162 | TVb: -2.001997 | GSw: -0.204981 | GSb: 0.095019 | TSUw: 0.495019 | TSUb: 0.004980


KeyboardInterrupt: 