In [3]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from utils.model import ECG_XNOR_Full_Bin, ECG_XNOR_Ori
from utils.OP import WeightOperation
from utils.dataset import Loader
from utils.engine import train
from utils.save_model import save_model

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

classes_num = 5   # number of output classes; also the out_channels of the last conv block below
test_size = 0.2   # passed to Loader as `test_size` (presumably the held-out test fraction — confirm in utils.dataset)

# (batch_size, lr, seed) per class count. Any class count without an explicit
# entry uses the default tuple, matching the previous if/else on `== 17`.
_HPARAMS_BY_CLASSES = {17: (64, 0.002, 142)}
_DEFAULT_HPARAMS = (512, 0.02, 101)
batch_size, lr, seed = _HPARAMS_BY_CLASSES.get(classes_num, _DEFAULT_HPARAMS)

# Seed every RNG the run touches. Seeding alone does not make GPU runs
# repeatable: cuDNN may pick non-deterministic convolution kernels. The flags
# below force deterministic kernel selection (this can slow training somewhat).
# manual_seed_all makes the "every visible GPU" intent explicit.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
loader = Loader(batch_size=batch_size, classes_num=classes_num, device=device, test_size=test_size)
labels, train_loader, test_loader = loader.loader()

# ---------------------------------------------------------------------------
# Architecture: one row per conv block
# ---------------------------------------------------------------------------
# in_channels, out_channels,    kernel_size,     stride,    padding,   pad_value,   pool_size,  pool_stride
kernel_size, padding, poolsize = 7, 5, 7
padding_value = 1
A = [[1,           8,           kernel_size,       2,       padding,       padding_value,       poolsize,        2],
     [8,          16,           kernel_size,       1,       padding,       padding_value,       poolsize,        2],
     [16,         32,           kernel_size,       1,       padding,       padding_value,       poolsize,        2],
     [32,         32,           kernel_size,       1,       padding,       padding_value,       poolsize,        2],
     [32,         64,           kernel_size,       1,       padding,       padding_value,       poolsize,        2],
     [64,         classes_num,  kernel_size,       1,       padding,       padding_value,       poolsize,        2],
     ]

# block1..block4 are always taken from A. block5..block7 are optional: each one
# gets the matching row of A when that row exists, otherwise None.
optional_blocks = {f"block{i}": (A[i - 1] if len(A) >= i else None) for i in (5, 6, 7)}
model = ECG_XNOR_Ori(block1=A[0], block2=A[1], block3=A[2], block4=A[3],
                     **optional_blocks,
                     device=device).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
print(device)
print(seed)

cuda:0
101


In [4]:
from torchinfo import summary

# Layer-by-layer shape/parameter table for one dummy batch.
# The dummy input is (batch, 1 channel, 3600 samples). NOTE(review): 3600 is
# presumably the per-record signal length produced by Loader — confirm.
# torchinfo's keyword for the shape is `input_size` (not `input_shape`).
summary_input_size = (batch_size, 1, 3600)
summary_columns = ["input_size", "output_size", "num_params", "trainable"]

# Kept as the cell's last expression so the notebook renders the table.
summary(model=model, input_size=summary_input_size, col_names=summary_columns,
        col_width=20, row_settings=["var_names"])

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
ECG_XNOR_Ori (ECG_XNOR_Ori)                   [512, 1, 3600]       [512, 5]             --                   True
├─Bn_bin_conv_pool_block_bw (block1)          [512, 1, 3600]       [512, 8, 898]        --                   True
│    └─ConstantPad1d (pad)                    [512, 1, 3600]       [512, 1, 3610]       --                   --
│    └─BinaryConv1d_bw (conv)                 [512, 1, 3610]       [512, 8, 1802]       56                   True
│    └─MaxPool1d (pool)                       [512, 8, 1802]       [512, 8, 898]        --                   --
│    └─PReLU (prelu)                          [512, 8, 898]        [512, 8, 898]        1                    True
│    └─BatchNorm1d (bn)                       [512, 8, 898]        [512, 8, 898]        16                   True
├─Bn_bin_conv_pool_block_baw (block2)         [512, 8, 898]        [512, 16, 448]      

In [5]:
# Weight-operation helper constructed around `model`; the training cell passes
# it to `train(...)` as `weight_op`.
# NOTE(review): presumably manages binarized vs. full-precision weights for the
# XNOR layers during training — confirm against utils/OP.py.
weightOperation = WeightOperation(model)

In [6]:
# Full training run. `train` gets both loaders and logs train/test loss and
# accuracy every epoch; its return value is kept as `best_test_acc`.
num_epochs = 1000

train_kwargs = dict(
    model=model,
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=num_epochs,
    device=device,
    writer=False,               # no writer for this run (semantics defined in utils.engine)
    weight_op=weightOperation,
    classes_num=classes_num,
)
best_test_acc = train(**train_kwargs)

print("-" * 50 + "\n")

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 5.9461 | train_acc: 0.4160 | test_loss: 3.0218 | test_acc: 0.6770
Epoch: 2 | train_loss: 2.2459 | train_acc: 0.6662 | test_loss: 0.9499 | test_acc: 0.7603
Epoch: 3 | train_loss: 1.6892 | train_acc: 0.6662 | test_loss: 0.8321 | test_acc: 0.7649
Epoch: 4 | train_loss: 1.2959 | train_acc: 0.6983 | test_loss: 0.7366 | test_acc: 0.7855
Epoch: 5 | train_loss: 0.9297 | train_acc: 0.7282 | test_loss: 0.5697 | test_acc: 0.7901
Epoch: 6 | train_loss: 0.8024 | train_acc: 0.7382 | test_loss: 0.7401 | test_acc: 0.7913
Epoch: 7 | train_loss: 0.7237 | train_acc: 0.7673 | test_loss: 0.7499 | test_acc: 0.7778
Epoch: 8 | train_loss: 0.6459 | train_acc: 0.7742 | test_loss: 0.6314 | test_acc: 0.7687
Epoch: 9 | train_loss: 0.6191 | train_acc: 0.7910 | test_loss: 0.5553 | test_acc: 0.7791
Epoch: 10 | train_loss: 0.5920 | train_acc: 0.7986 | test_loss: 0.5918 | test_acc: 0.8056
Epoch: 11 | train_loss: 0.5218 | train_acc: 0.8140 | test_loss: 0.4799 | test_acc: 0.8230
Epoch: 12 | train_l