In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision.models import vgg19
from torchvision import transforms
from torchvision import datasets
from torchsummary import summary
from torchviz import make_dot
import matplotlib.pyplot as plt
import numpy as np
import os
from utils.datasets import DatasetLoader
from models.CAM_based_models import MTEXCNN
from models.CAM_based_models import XCM

# Load data

In [2]:
# Point the project's DatasetLoader at the local UWaveGestureLibrary directory
# and read both splits in tabular form (load_to_df — presumably pandas
# DataFrames). The array form used for modelling is loaded in the next cell.
dataset_path = './data/UWaveGestureLibrary'
dataset = DatasetLoader(dataset_path)

train_data, test_data = dataset.load_to_df()

In [3]:
# NumPy arrays consumed by the models below.
X_train, y_train, X_test, y_test = dataset.load_to_nparray()

# Fail fast on shape problems. The model cells size their layers from X_train
# alone, so both splits must be 3-D with identical per-sample shape (later
# cells presumably evaluate on X_test), and every sample needs exactly one label.
assert X_train.ndim == 3, f"expected 3-D X_train, got shape {X_train.shape}"
assert X_train.shape[1:] == X_test.shape[1:], (
    f"train/test sample shapes differ: {X_train.shape[1:]} vs {X_test.shape[1:]}"
)
assert len(X_train) == len(y_train), "X_train / y_train length mismatch"
assert len(X_test) == len(y_test), "X_test / y_test length mismatch"

In [4]:
# One shape per array, in the order X_train, y_train, X_test, y_test.
# The X arrays look like (samples, channels, timesteps) judging by the printed
# shapes — confirm against DatasetLoader.load_to_nparray.
display(*(arr.shape for arr in (X_train, y_train, X_test, y_test)))

(120, 3, 315)

(120,)

(320, 3, 315)

(320,)

# CAM-based models (Class Activation Map)

## MTEX-CNN

In [5]:
# MTEX-CNN is configured from axis 2 of X_train (315 here) and the number of
# distinct training labels; the bare `net` on the last line displays its layers.
net = MTEXCNN(
    X_train.shape[2],
    np.unique(y_train).size,
)
net

MTEXCNN(
  (conv_1): Conv2d(1, 16, kernel_size=(158, 1), stride=(1, 1))
  (conv_2): Conv2d(16, 32, kernel_size=(79, 1), stride=(1, 1))
  (conv_3): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (conv_4): Conv1d(3, 64, kernel_size=(3,), stride=(1,))
  (fc1): Linear(in_features=4992, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=8, bias=True)
)

In [6]:
# torchsummary pushes a dummy batch through the model and tabulates each layer.
# The input size is derived from the data instead of hard-coded as (1, 315, 3):
# the model appears to take (1, X.shape[2], X.shape[1]), i.e. X_train's last two
# axes swapped — confirm against MTEXCNN.forward.
# device="cpu" because `net` is created on the CPU and never moved. torchsummary
# defaults to "cuda" and would build CUDA inputs on a GPU machine, failing
# against the CPU weights.
summary(net, (1, X_train.shape[2], X_train.shape[1]), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 158, 3]           2,544
            Conv2d-2            [-1, 32, 80, 3]          40,480
            Conv2d-3             [-1, 1, 80, 3]              33
            Conv1d-4               [-1, 64, 78]             640
            Linear-5                   [-1, 32]         159,776
            Linear-6                    [-1, 8]             264
Total params: 203,737
Trainable params: 203,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.16
Params size (MB): 0.78
Estimated Total Size (MB): 0.94
----------------------------------------------------------------


## XCM

In [7]:
# Presumably the convolution window length: it reappears as the 93-wide kernels
# of conv_11 and conv_3 in the repr below. TODO confirm against XCM.__init__.
XCM_WINDOW_SIZE = 93

# NOTE(review): in the repr below, conv_3 is Conv1d(315, 32), so its in_channels
# equals the series length, and in the summary its output is only 4 long. The
# final 1-D convolution therefore appears to slide over the feature axis rather
# than over time. Likewise, conv_11 applies a square (93, 93) kernel to inputs
# only 3 wide. Worth checking the transposes/kernel shapes in XCM.
# NOTE(review): the model ends in Softmax. If it is trained with
# nn.CrossEntropyLoss (which expects raw logits), softmax is applied twice.
net1 = XCM(XCM_WINDOW_SIZE, X_train.shape[2], X_train.shape[1], len(np.unique(y_train)))
net1

XCM(
  (conv_11): Conv2d(1, 16, kernel_size=(93, 93), stride=(1, 1), padding=(46, 46))
  (batchnorm_11): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_12): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  (conv_21): Conv1d(3, 16, kernel_size=(3,), stride=(1,), padding=(1,))
  (batchnorm_21): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_22): Conv1d(16, 1, kernel_size=(1,), stride=(1,))
  (conv_3): Conv1d(315, 32, kernel_size=(93,), stride=(1,), padding=(46,))
  (batchnorm_3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (glb_avg_pool): AvgPool1d(kernel_size=(4,), stride=(4,), padding=(0,))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=32, out_features=8, bias=True)
  (softmax): Softmax(dim=1)
)

In [8]:
# Same layer summary for XCM. The input size is derived from the data rather
# than hard-coded as (1, 315, 3); the model appears to take X_train's last two
# axes swapped (confirm against XCM.forward).
# device="cpu" because `net1` is created on the CPU and never moved.
# torchsummary's default ("cuda") would build CUDA inputs on a GPU machine.
# NOTE(review): the torch.Size lines printed above the table are not produced by
# torchsummary; they look like leftover debug prints in XCM.forward.
summary(net1, (1, X_train.shape[2], X_train.shape[1]), device="cpu")

torch.Size([2, 1, 315, 3]) torch.Size([2, 1, 315])
torch.Size([2, 32, 4])
torch.Size([2, 32, 1])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 315, 3]         138,400
       BatchNorm2d-2           [-1, 16, 315, 3]              32
            Conv2d-3            [-1, 1, 315, 3]              17
            Conv1d-4              [-1, 16, 315]             160
       BatchNorm1d-5              [-1, 16, 315]              32
            Conv1d-6               [-1, 1, 315]              17
            Conv1d-7                [-1, 32, 4]         937,472
       BatchNorm1d-8                [-1, 32, 4]              64
         AvgPool1d-9                [-1, 32, 1]               0
          Flatten-10                   [-1, 32]               0
           Linear-11                    [-1, 8]             264
          Softmax-12                    [-1, 8]               0
Total 