In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision.models import vgg19
from torchvision import transforms
from torchvision import datasets
from torchsummary import summary
from torchviz import make_dot
import matplotlib.pyplot as plt
import numpy as np
import os
from utils.datasets import DatasetLoader
from models.CAM_based_models import MTEXCNN

# Load data

In [2]:
# Load the UWaveGestureLibrary train/test splits as DataFrames.
# NOTE(review): the path is relative, so (assuming DatasetLoader uses it as-is)
# it resolves against the kernel's working directory — TODO confirm the
# notebook is always launched from the repository root.
dataset_path = './data/UWaveGestureLibrary'
dataset = DatasetLoader(dataset_path)
(train_data, test_data) = dataset.load_to_df()

In [3]:
# Load the same train/test splits as NumPy arrays for the model.
X_train, y_train, X_test, y_test = dataset.load_to_nparray()

# Sanity checks. Each split must have exactly one label per sample, and train
# and test must share the same per-sample shape, because the network below is
# sized from X_train alone.
assert X_train.shape[0] == y_train.shape[0], "train samples/labels mismatch"
assert X_test.shape[0] == y_test.shape[0], "test samples/labels mismatch"
assert X_train.shape[1:] == X_test.shape[1:], "train/test per-sample shape mismatch"

In [4]:
# Show the shape of every array returned by load_to_nparray(), train then test.
display(*(arr.shape for arr in (X_train, y_train, X_test, y_test)))

(120, 3, 315)

(120,)

(320, 3, 315)

(320,)

# CAM-based models

## MTEX-CNN

In [5]:
# Build MTEX-CNN from the training data: the first argument is the size of
# X_train's last axis (315 in the shapes shown above), the second is the number
# of distinct labels in y_train.
net = MTEXCNN(
    X_train.shape[2],
    np.unique(y_train).size,
)
net

MTEXCNN(
  (conv_1): Conv2d(1, 16, kernel_size=(158, 1), stride=(1, 1))
  (conv_2): Conv2d(16, 32, kernel_size=(79, 1), stride=(1, 1))
  (conv_3): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (conv_4): Conv1d(3, 64, kernel_size=(3,), stride=(1,))
  (linear_1): Linear(in_features=4992, out_features=32, bias=True)
  (linear_2): Linear(in_features=32, out_features=8, bias=True)
)

In [6]:
# Print a layer-by-layer summary of the network.
# torchsummary takes the per-sample input size (no batch dimension). The model
# consumes (1, length, dims): X_train's (dims, length) axes swapped, with a
# single channel in front. Derive this from the data instead of hard-coding
# (1, 315, 3), so the cell stays correct if the dataset changes.
# device='cpu' is passed because `net` is never moved to a GPU here, while
# torchsummary's default device='cuda' would create CUDA input tensors on a
# GPU machine and fail with a device mismatch.
summary(net, (1, X_train.shape[2], X_train.shape[1]), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 158, 3]           2,544
            Conv2d-2            [-1, 32, 80, 3]          40,480
            Conv2d-3             [-1, 1, 80, 3]              33
            Conv1d-4               [-1, 64, 78]             640
            Linear-5                   [-1, 32]         159,776
            Linear-6                    [-1, 8]             264
Total params: 203,737
Trainable params: 203,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.16
Params size (MB): 0.78
Estimated Total Size (MB): 0.94
----------------------------------------------------------------
