In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision.models import vgg19
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np
import os
from utils.datasets import DatasetLoader

# Load data

In [2]:
# Load dataset

dataset_path = './data/UWaveGestureLibrary'
dataset = DatasetLoader(dataset_path)
train_data, test_data = dataset.load_to_df()

In [3]:
X_train, y_train, X_test, y_test = dataset.load_to_nparray()

In [4]:
display(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

(120, 3, 315)

(120,)

(320, 3, 315)

(320,)

In [7]:
m = nn.Conv1d(16, 33, 3, stride=2)
input = torch.randn(20, 16, 50)
output = m(input)
output.shape

torch.Size([20, 33, 24])

# CAM-based

## MTEX-CNN

In [13]:
class MTEXCNN(nn.Module):
    def __init__(self, time_length, n_classes):
        super(MTEXCNN, self).__init__()
        self.conv_1 = nn.Conv2d(1, 16, (time_length//2 + 1, 1))
        self.conv_2 = nn.Conv2d(16, 32, (time_length//4 + 1, 1))
        self.conv_3 = nn.Conv2d(32, 1, 1)
        self.conv_4 = nn.Conv1d(3, 64, 3)
        self.linear_1 = nn.Linear(time_length//4 + 1 - 2, 16)
        self.linear_2 = nn.Linear(16, n_classes)
        
    def forward(self, x):
        x = self.conv_1(x)
        x = nn.ReLU(x)
        x = self.conv_2(x)
        x = nn.ReLU(x)
        x = self.conv_3(x)
        x = nn.ReLU(x)
        x = self.conv_4(x)
        x = nn.ReLU(x)
        x = x.view(x.size(0),-1)
        x = self.linear_1(x)
        x = self.linear_2(x)

In [15]:
net = MTEXCNN(X_train.shape[2], len(np.unique(y_train)))
net

MTEXCNN(
  (conv_1): Conv2d(1, 16, kernel_size=(158, 1), stride=(1, 1))
  (conv_2): Conv2d(16, 32, kernel_size=(79, 1), stride=(1, 1))
  (conv_3): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (conv_4): Conv1d(3, 64, kernel_size=(3,), stride=(1,))
  (linear_1): Linear(in_features=77, out_features=16, bias=True)
  (linear_2): Linear(in_features=16, out_features=8, bias=True)
)

In [None]:
|