In [1]:
import torch
from torch.utils.data import Dataset, Subset

from scipy.io import loadmat
import numpy as np

from collections.abc import Sequence
from sklearn.utils import resample

import os

def windowing(X_instants, R_instants, Y_instants, v_Hz=2000, window_time_s=.150, relative_overlap=.7, steady=True, steady_margin_s=1.5):
    """Cut a multi-channel recording into fixed-length, overlapping windows.

    Args:
        X_instants: 2-D array ``(instants, channels)`` with the signal.
        R_instants: 1-D array with one repetition id per instant.
        Y_instants: 1-D array with one label per instant (0 is treated as rest).
        v_Hz: sampling frequency in Hz.
        window_time_s: window duration in seconds (the sample count is forced
            to be even).
        relative_overlap: fraction of a window shared with the next one.
        steady: if True, only "steady" windows are returned (see below).
        steady_margin_s: extra time, before and after a movement window, that
            must carry the same label for the window to count as steady.

    Steady modes:
        * ``steady=True, steady_margin_s=0``: keep windows whose samples all
          share one label (all movement or all rest).
        * ``steady=True, steady_margin_s=1.5``: additionally drop the first and
          last 1.5 s of every movement.
        * ``steady=False``: keep every window, including those straddling a
          movement/rest transition.

    Returns:
        ``(X_windows, R_windows, Y_windows)`` where ``X_windows`` has shape
        ``(windows, window_samples, channels)`` and the other two hold the
        repetition id and the label found at the centre of each window.

    Raises:
        ValueError: if the parameters give an empty window or a non-positive
            stride between consecutive windows.
    """
    # Half-width of the window in samples; the full width is always even.
    r = int((v_Hz * window_time_s) / 2)
    N = 2 * r
    # Samples outside the window that are inspected to decide steadiness.
    margin_samples = round(v_Hz * steady_margin_s)

    overlap_samples = round(v_Hz * relative_overlap * window_time_s)
    slide = N - overlap_samples
    if N <= 0 or slide <= 0:
        raise ValueError(
            f'window of {N} samples with stride {slide} is not usable; '
            'check v_Hz, window_time_s and relative_overlap'
        )

    M_instants, C = X_instants.shape
    # Number of complete windows; clamped so very short recordings give zero.
    M = max(0, (M_instants - N) // slide + 1)

    # Label and repetition are read at the centre of each window. The stop is
    # inclusive of the last centre (M_instants - r), otherwise one element is
    # lost whenever (M_instants - N) is an exact multiple of `slide`.
    Y_windows = Y_instants[r : M_instants - r + 1 : slide]
    R_windows = R_instants[r : M_instants - r + 1 : slide]

    X_windows = np.zeros((M, N, C))
    is_steady_windows = np.zeros(M, dtype=bool)
    for m in range(M):
        start = m * slide
        stop = start + N
        c = start + r  # centre of the window (0-based)

        X_windows[m, :, :] = X_instants[start:stop, :]

        if Y_instants[c] == 0:
            # Rest windows are not margined.
            segment = Y_instants[start:stop]
        else:
            # Clamp the start: a negative slice start would wrap around to the
            # end of the recording. The stop is clipped by slicing itself.
            segment = Y_instants[max(0, start - margin_samples) : stop + margin_samples]
        segment = np.asarray(segment)
        is_steady_windows[m] = bool(np.all(segment == segment[0]))

    if steady:
        return X_windows[is_steady_windows], R_windows[is_steady_windows], Y_windows[is_steady_windows]
    return X_windows, R_windows, Y_windows

def read_session(filename):
    """Load one recording from a ``.mat`` file.

    Reads the ``emg``, ``rerepetition`` and ``restimulus`` variables, keeps
    columns 0-7 and 10-15 of ``emg`` and renumbers the labels so that the ids
    present in the file become consecutive indices.

    Returns:
        ``(X, R, y)``: the selected signal columns, the squeezed repetition
        array and the squeezed, renumbered label array.
    """
    mat = loadmat(filename)

    # Columns 8 and 9 of the recording are discarded.
    kept_columns = np.r_[0:8, 10:16]
    X = mat['emg'][:, kept_columns]
    R = mat['rerepetition'].squeeze()
    y = mat['restimulus'].squeeze()

    # Class id -> class index. Each pass shifts every label at or above the
    # threshold down by one, closing one gap in the id numbering; a threshold
    # is written as (original id - number of gaps already closed).
    for threshold in (3, 6 - 1, 8 - 2, 9 - 3):
        y[y >= threshold] -= 1

    return X, R, y

class DB6Session(Dataset):
    """Windowed dataset built from a single recording file.

    Items are ``(x, y)`` pairs where ``x`` is a float32 tensor of shape
    ``(channels, 1, window_samples)`` and ``y`` is a long scalar label.

    Attributes:
        X: float32 tensor ``(windows, channels, 1, window_samples)``.
        Y: long tensor ``(windows,)`` with the label of each window.
        R: array with the repetition id of each window (used for fold splits).
        X_min, X_max: per-channel range used for min-max scaling, or ``None``
            when no scaling was applied.
    """

    # Number of movement (non-rest) classes assumed by the balancing splits.
    N_MOVEMENT_CLASSES = 7

    def __init__(self, filename, n_classes='7+1', steady=True, minmax=False, **kwargs):
        """Read, optionally scale, and window one session.

        Args:
            filename: path of the ``.mat`` file passed to ``read_session``.
            n_classes: ``'7+1'`` keeps the rest class (label 0); ``'7'`` (or the
                integer ``7``) drops rest windows and shifts labels to 0-6.
            steady: forwarded to ``windowing``.
            minmax: ``True`` to fit a per-channel min/max on this session and
                scale to [-1, 1]; a ``(min, max)`` sequence to reuse a range
                fitted elsewhere; anything else disables scaling.
            **kwargs: forwarded to ``windowing``.

        Raises:
            ValueError: if ``n_classes`` is not one of the supported values.
        """
        # Normalise once so that the integer 7 is handled like the string '7'
        # (validation already accepted it; the later comparison did not).
        n_classes = str(n_classes)
        if n_classes not in {'7+1', '7'}:
            raise ValueError('Wrong n_classes')

        X, R, y = read_session(filename)

        self.X_min, self.X_max = None, None
        if isinstance(minmax, (bool, np.bool_)) and minmax:
            self.X_min, self.X_max = X.min(axis=0), X.max(axis=0)
        if isinstance(minmax, Sequence) and minmax[0] is not None:
            self.X_min, self.X_max = minmax
        if self.X_min is not None:
            # A channel with zero range would divide by zero; give it a unit
            # span so it maps to a finite constant instead of NaN/inf.
            span = np.asarray(self.X_max - self.X_min, dtype=float)
            span = np.where(span == 0, 1.0, span)
            X = (X - self.X_min) / span * 2 - 1
            print("minmax", self.X_min, self.X_max)

        X_windows, R_windows, Y_windows = windowing(X, R, y, steady=steady, **kwargs)

        if n_classes == '7':
            # Drop the rest windows ...
            mask = Y_windows != 0
            X_windows, R_windows, Y_windows = X_windows[mask], R_windows[mask], Y_windows[mask]

            # ... and remap labels from 1-7 to 0-6.
            Y_windows = Y_windows - 1

        # (windows, samples, channels) -> (windows, channels, 1, samples)
        self.X = torch.tensor(X_windows, dtype=torch.float32).permute(0, 2, 1).unsqueeze(dim=2)
        self.Y = torch.tensor(Y_windows, dtype=torch.long)
        self.R = R_windows

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    def __len__(self):
        return self.Y.shape[0]

    def _fold_indices(self, total_folds, val_fold):
        """Return ``(train_indices, val_indices)`` from repetition id modulo ``total_folds``."""
        indices = np.arange(self.R.shape[0])
        train_mask = self.R % total_folds != val_fold
        return indices[train_mask], indices[~train_mask]

    def _rest_and_movement(self, train_indices):
        """Partition ``train_indices`` into label-0 indices and all the others."""
        # Only the labels are needed, so index Y directly rather than fetching X too.
        labels = self.Y[torch.as_tensor(train_indices, dtype=torch.long)].numpy()
        return train_indices[labels == 0], train_indices[labels != 0]

    def split(self, total_folds, val_fold=0):
        """Split by repetition: windows with ``R % total_folds == val_fold`` go to validation."""
        train_indices, val_indices = self._fold_indices(total_folds, val_fold)
        return Subset(self, train_indices), Subset(self, val_indices)

    def split_0(self, total_folds, val_fold=0, random_state=None):
        """Like ``split``, but undersample label-0 training windows.

        Label-0 windows are subsampled without replacement to the average
        number of training windows per movement class. Intended for datasets
        built with ``n_classes='7+1'``, where label 0 is rest.

        Args:
            random_state: optional seed forwarded to ``sklearn.utils.resample``.
        """
        train_indices, val_indices = self._fold_indices(total_folds, val_fold)
        rest, movement = self._rest_and_movement(train_indices)

        # Never request more samples than exist (resample would raise).
        n_rest = min(len(movement) // self.N_MOVEMENT_CLASSES, len(rest))
        if n_rest > 0:
            rest = resample(rest, n_samples=n_rest, replace=False, random_state=random_state)
        else:
            rest = rest[:0]
        train_indices = np.concatenate([rest, movement], axis=0)

        return Subset(self, train_indices), Subset(self, val_indices)

    def split_1(self, total_folds, val_fold=0, random_state=None):
        """Like ``split``, but oversample the non-zero-label training windows.

        Non-zero-label windows are resampled with replacement up to
        ``N_MOVEMENT_CLASSES`` times the number of label-0 training windows.

        Args:
            random_state: optional seed forwarded to ``sklearn.utils.resample``.
        """
        train_indices, val_indices = self._fold_indices(total_folds, val_fold)
        rest, movement = self._rest_and_movement(train_indices)

        target = len(rest) * self.N_MOVEMENT_CLASSES
        if target > 0 and len(movement) > 0:
            movement = resample(movement, n_samples=target, replace=True, random_state=random_state)
        train_indices = np.concatenate([rest, movement], axis=0)

        return Subset(self, train_indices), Subset(self, val_indices)
    
class SuperSet(Dataset):
    """Concatenation of several map-style datasets behind a single index space.

    Item ``i`` is served by the first dataset whose cumulative length exceeds
    ``i``. Negative indices count from the end, as with Python sequences.
    (``torch.utils.data.ConcatDataset`` offers the same behaviour.)

    Attributes:
        datasets: tuple of the wrapped datasets, in concatenation order.
        lens: cumulative lengths, starting at 0 (``len(datasets) + 1`` entries).
    """

    def __init__(self, *datasets):
        self.datasets = datasets
        self.lens = np.cumsum([0] + list(map(len, self.datasets)))

    def __getitem__(self, idx):
        """Return item ``idx`` of the concatenation.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        idx = int(idx)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError('SuperSet index out of range')

        # Rightmost cumulative offset that is <= idx identifies the dataset;
        # side='right' skips over empty datasets sharing the same offset.
        dataset_idx = int(np.searchsorted(self.lens, idx, side='right')) - 1
        local_idx = idx - int(self.lens[dataset_idx])
        return self.datasets[dataset_idx][local_idx]

    def __len__(self):
        # Plain int rather than a NumPy scalar.
        return int(self.lens[-1])
    
class DB6MultiSession(SuperSet):
    """Several ``DB6Session`` datasets of one subject, concatenated.

    Session index ``i`` is mapped to the file
    ``S{subject}_D{i // 2 + 1}_T{i % 2 + 1}.mat`` inside ``folder``.

    Attributes:
        sessions: the per-session datasets, in the order given.
    """

    def __init__(self, subject, sessions, folder='.', **kwargs):
        """Load every requested session; ``**kwargs`` go to ``DB6Session``."""
        loaded = []
        for i in sessions:
            path = os.path.join(folder, f'S{subject}_D{(i // 2) + 1}_T{(i % 2) + 1}.mat')
            loaded.append(DB6Session(path, **kwargs))
        self.sessions = loaded
        super().__init__(*self.sessions)

    def _split_each(self, method_name, total_folds, val_fold):
        """Call ``method_name`` on every session and concatenate the two halves."""
        train_parts, val_parts = [], []
        for session in self.sessions:
            splitter = getattr(session, method_name)
            train_part, val_part = splitter(total_folds=total_folds, val_fold=val_fold)
            train_parts.append(train_part)
            val_parts.append(val_part)
        return SuperSet(*train_parts), SuperSet(*val_parts)

    def split(self, total_folds, val_fold=0):
        """Per-session ``split``; returns ``(train, val)`` as ``SuperSet``s."""
        return self._split_each('split', total_folds, val_fold)

    def split_0(self, total_folds, val_fold=0):
        """Per-session ``split_0``; returns ``(train, val)`` as ``SuperSet``s."""
        return self._split_each('split_0', total_folds, val_fold)

    def split_1(self, total_folds, val_fold=0):
        """Per-session ``split_1``; returns ``(train, val)`` as ``SuperSet``s."""
        return self._split_each('split_1', total_folds, val_fold)

In [4]:
# Training set: session index 0 of subject 1 (file S1_D1_T1.mat); minmax=True
# fits the per-channel min/max on this session and scales it to [-1, 1].
train = DB6MultiSession(folder='dataset_DB6', subject=1, sessions=[0,], minmax=True)
# Test set: session index 3 (file S1_D2_T2.mat), scaled with the range fitted
# on the training session rather than its own.
test = DB6MultiSession(folder='dataset_DB6', subject=1, sessions=[3], minmax=(train.sessions[0].X_min, train.sessions[0].X_max))

minmax [-5.04863972e-04 -3.71088739e-04 -1.26357729e-04 -9.77940726e-05
 -8.86311100e-05 -2.18426168e-04 -2.09364778e-04 -1.08514459e-05
 -7.00831413e-04 -2.00893322e-04 -1.02204438e-04 -6.16701509e-05
 -1.18733544e-04 -1.48412204e-04] [1.29621930e-03 3.13192722e-04 1.38760413e-04 1.38142030e-04
 8.10787387e-05 1.94143475e-04 1.27040956e-04 7.61724368e-06
 5.01827104e-04 2.91279517e-04 1.90823819e-04 1.19949276e-04
 1.47030325e-04 1.46943945e-04]
minmax [-5.04863972e-04 -3.71088739e-04 -1.26357729e-04 -9.77940726e-05
 -8.86311100e-05 -2.18426168e-04 -2.09364778e-04 -1.08514459e-05
 -7.00831413e-04 -2.00893322e-04 -1.02204438e-04 -6.16701509e-05
 -1.18733544e-04 -1.48412204e-04] [1.29621930e-03 3.13192722e-04 1.38760413e-04 1.38142030e-04
 8.10787387e-05 1.94143475e-04 1.27040956e-04 7.61724368e-06
 5.01827104e-04 2.91279517e-04 1.90823819e-04 1.19949276e-04
 1.47030325e-04 1.46943945e-04]


In [8]:
# Shape of one input sample: (channels, 1, window_samples).
train[0][0].shape

torch.Size([14, 1, 300])

In [10]:
import sys
# Add the project sources to the import path once; the check keeps re-runs of
# this cell from appending duplicates. The path is relative to the working
# directory of the kernel.
if './master-thesis/src' not in sys.path:
    sys.path.append('./master-thesis/src')
    
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
from datasets.db6 import DB6MultiSession

In [12]:
# Same train/test construction as above, but with the DB6MultiSession imported
# from datasets.db6, which shadows the class defined in this notebook.
# NOTE(review): `image_like_shape` and the dataset-level `X_min`/`X_max` are
# features of the imported version; they are not defined by the notebook-local
# class — confirm against datasets/db6.py.
train_ = DB6MultiSession(folder='dataset_DB6', subject=1, sessions=[0,], minmax=True, image_like_shape=True)
test_ = DB6MultiSession(folder='dataset_DB6', subject=1, sessions=[3], minmax=(train_.X_min, train_.X_max), image_like_shape=True)

minmax [-5.04863972e-04 -3.71088739e-04 -1.26357729e-04 -9.77940726e-05
 -8.86311100e-05 -2.18426168e-04 -2.09364778e-04 -1.08514459e-05
 -7.00831413e-04 -2.00893322e-04 -1.02204438e-04 -6.16701509e-05
 -1.18733544e-04 -1.48412204e-04] [1.29621930e-03 3.13192722e-04 1.38760413e-04 1.38142030e-04
 8.10787387e-05 1.94143475e-04 1.27040956e-04 7.61724368e-06
 5.01827104e-04 2.91279517e-04 1.90823819e-04 1.19949276e-04
 1.47030325e-04 1.46943945e-04]
minmax [-5.04863972e-04 -3.71088739e-04 -1.26357729e-04 -9.77940726e-05
 -8.86311100e-05 -2.18426168e-04 -2.09364778e-04 -1.08514459e-05
 -7.00831413e-04 -2.00893322e-04 -1.02204438e-04 -6.16701509e-05
 -1.18733544e-04 -1.48412204e-04] [1.29621930e-03 3.13192722e-04 1.38760413e-04 1.38142030e-04
 8.10787387e-05 1.94143475e-04 1.27040956e-04 7.61724368e-06
 5.01827104e-04 2.91279517e-04 1.90823819e-04 1.19949276e-04
 1.47030325e-04 1.46943945e-04]


In [23]:
# Sanity check: number of elements of sample 1000 that differ between the
# notebook-local dataset (`train`) and the imported one (`train_`); 0 means
# the two inputs are identical.
(train[1000][0] != train_[1000][0]).sum()

tensor(0)

In [115]:
import torch
from torch import nn

# Formula: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
def get_conv_output_size(input_size, kernel_size, stride=1, padding=0, dilation=1, **ignore):
    """Length of the output of a 1-D convolution or pooling layer.

    Args:
        input_size: length of the input along the convolved axis.
        kernel_size, stride, padding, dilation: layer hyper-parameters, each
            either a number or a tuple whose first element is used.
        **ignore: any other keyword is accepted and discarded, so the function
            can be called with ``**vars(layer)``.

    Returns:
        The output length as an ``int`` (fractional results are truncated).
    """
    def _first_as_int(value):
        # Layers store these settings as 1-tuples; plain numbers are accepted too.
        if isinstance(value, tuple):
            return int(value[0])
        return int(value)

    k = _first_as_int(kernel_size)
    s = _first_as_int(stride)
    p = _first_as_int(padding)
    d = _first_as_int(dilation)

    numerator = input_size + 2 * p - d * (k - 1) - 1
    return int(numerator / s + 1)

class TEMPONet(nn.Module):
    """1-D convolutional classifier made of three dilated blocks and an MLP head.

    Each block is: two dilated 3-tap convolutions (BatchNorm + ReLU after
    each), then a 5-tap convolution, average pooling by 2, BatchNorm and ReLU.
    The flattened block output feeds three fully connected layers.

    Args:
        n_classes: number of output logits.
        input_size: length of the input sequence; used to size the first
            fully connected layer.
        input_channels: number of input channels.

    Input to ``forward`` is ``(batch, input_channels, input_size)``; the output
    is ``(batch, n_classes)``.
    """

    def __init__(self, n_classes, input_size=300, input_channels=14):
        super().__init__()

        self.conv1 = self._make_block(input_channels, 32, 64, dilation=2, stride=1)
        self.conv2 = self._make_block(64, 64, 128, dilation=4, stride=2)
        self.conv3 = self._make_block(128, 128, 128, dilation=8, stride=4)

        fc_layers = [
            # 640 inputs with the default input_size=300.
            nn.Linear(self._flattened_size(input_size), 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(128, n_classes),
        ]
        self.fc = nn.Sequential(*fc_layers)

    @staticmethod
    def _make_block(in_channels, mid_channels, out_channels, dilation, stride):
        """Build one convolutional block; padding keeps the dilated convs length-preserving."""
        return nn.Sequential(
            nn.Conv1d(in_channels, mid_channels, 3, dilation=dilation, padding=dilation),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(),
            nn.Conv1d(mid_channels, mid_channels, 3, dilation=dilation, padding=dilation),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(),
            nn.Conv1d(mid_channels, out_channels, 5, stride=stride, padding=2),
            nn.AvgPool1d(2, stride=2, padding=0),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def _flattened_size(self, input_size):
        """Number of features left after the three blocks, for a given input length."""
        length = input_size
        channels = 1
        for layer in (*self.conv1, *self.conv2, *self.conv3):
            # Only convolutions and pooling change the sequence length.
            if not isinstance(layer, (nn.Conv1d, nn.AvgPool1d)):
                continue
            length = get_conv_output_size(length, **vars(layer))
            # Pooling has no out_channels, so the last convolution's count is kept.
            if hasattr(layer, "out_channels"):
                channels = layer.out_channels
        return length * channels

    def forward(self, x):
        for block in (self.conv1, self.conv2, self.conv3):
            x = block(x)

        # Collapse (channels, length) into one feature axis for the MLP head.
        features = x.flatten(1)
        return self.fc(features)

In [100]:
# Check that the third convolution of the first block exposes `out_channels`
# (the attribute the FC-size computation relies on). `net` must already exist
# in the kernel: it is created by the profiling cell further down.
hasattr(net.conv1[6], "out_channels")

True

In [137]:
from thop import profile

# Count multiply-accumulate operations of an 8-class TEMPONet with thop.
net = TEMPONet(8)
# Inference mode, so Dropout is inactive and BatchNorm uses its running statistics.
net.train(False)
# One dummy sample: batch 1, 14 channels, 300 time steps.
macs, params = profile(net, inputs=(torch.randn(1, 14, 300), ))
print()
# Report in millions of MACs.
print(macs / 1e6)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_avgpool() for <class 'torch.nn.modules.pooling.AvgPool1d'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[91m[WARN] Cannot find rule for <class '__main__.TEMPONet'>. Treat it as zero Macs and zero Params.[00m

16.028672


In [138]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features, then call ``fn`` on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments are handed to the wrapped callable untouched.
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: ``dim -> hidden_dim -> dim`` with GELU and dropout after each linear stage."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: width of the input and output tokens.
        heads: number of attention heads.
        dim_head: width of each head.
        dropout: dropout probability applied after the output projection.

    ``forward`` takes ``(batch, tokens, dim)`` and returns a tensor of the same
    shape. The output projection is replaced by the identity when there is a
    single head whose width equals ``dim``.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        single_full_width_head = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        # One projection produces queries, keys and values side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if single_full_width_head:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        # Unpacking doubles as a check that x is (batch, tokens, features).
        _batch, _tokens, _width = x.shape

        projected = self.to_qkv(x).chunk(3, dim = -1)
        queries, keys, values = [
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in projected
        ]

        scores = einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale
        weights = self.attend(scores)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, values)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder layers.

    Every layer is an attention sub-layer followed by a feed-forward
    sub-layer, each wrapped in ``PreNorm`` and added back to its input.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        stacked = []
        for _ in range(depth):
            attention = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            stacked.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(stacked)

    def forward(self, x):
        for attention, feed_forward in self.layers:
            # Residual connection around each sub-layer.
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class ViT(nn.Module):
    """Vision-Transformer-style classifier over non-overlapping patches.

    The input ``(batch, channels, height, width)`` is cut into patches of
    ``patch_size``; each patch is flattened, linearly embedded to ``dim``,
    summed with a learned positional embedding, run through the transformer,
    pooled, and classified by a LayerNorm + Linear head.

    Keyword Args:
        image_size, patch_size: int or ``(height, width)`` tuple; the image
            size must be divisible by the patch size.
        num_classes: number of output logits.
        dim: token width.
        depth, heads, dim_head, mlp_dim, dropout: transformer settings.
        pool: ``'mean'`` averages all tokens; ``'cls'`` takes the token at
            position 0.
        channels: number of input channels.
        emb_dropout: dropout applied to the embedded tokens.

    NOTE: no class token is prepended to the sequence. ``cls_token`` is still
    created as a parameter but ``forward`` never reads it, so with
    ``pool='cls'`` the classified vector is the output for the first patch.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patches_down = image_height // patch_height
        patches_across = image_width // patch_width
        num_patches = patches_down * patches_across
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Flatten every patch into one vector, then embed it to `dim`.
        patchify = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
        self.to_patch_embedding = nn.Sequential(
            patchify,
            nn.Linear(patch_dim, dim),
        )

        # One learned position vector per patch (no slot for a class token).
        self.pos_embedding = nn.Parameter(torch.empty(1, num_patches, dim))
        nn.init.normal_(self.pos_embedding, std=.02)

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        tokens = self.to_patch_embedding(img)

        tokens += self.pos_embedding
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        latent = self.to_latent(pooled)
        return self.mlp_head(latent)

In [1]:
from thop import profile

# Count multiply-accumulate operations of the ViT on one DB6 window.
# This cell needs `ViT` and `torch` to be defined already: run on a fresh
# kernel before the cell that defines ViT, it fails with the NameError
# recorded below.
net = ViT(
    image_size = (1, 300),   # one row of 300 time steps
    patch_size = (1, 20),    # 15 patches of 20 time steps each
    channels = 14,
    num_classes = 8,
    dim = 300,
    depth = 3,
    heads = 3,
    mlp_dim = 300,
    dropout = .2,
    emb_dropout = 0,
    # pool = 'mean' is disabled, so the default 'cls' pooling is used.
    #pool = 'mean',
)

# Inference mode, so the dropout layers are inactive while profiling.
net.train(False)
# One dummy sample: batch 1, 14 channels, height 1, width 300.
macs, params = profile(net, inputs=(torch.randn(1, 14, 1, 300), ))
print()
# Report in millions of MACs.
print(macs / 1e6)

NameError: name 'ViT' is not defined