In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

from src.datasets.db6 import DB6MultiSession
from pickle import load

import os

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [185]:
# MODEL

def bmm(a, b):
    r = []
    for i in range(a.shape[0]):
        r.append(nn.functional.linear(a[i], b[i].T))
    return torch.stack(r)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            torch.quantization.DeQuantStub(),
            nn.GELU(),
            torch.quantization.QuantStub(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

        #self._init_parameters(dim, hidden_dim)

    def _init_parameters(self, dim, hidden_dim):
        bound1 = 1 / (dim ** .5)
        bound2 = 1 / (hidden_dim ** .5)
        nn.init.uniform_(self.net[0].weight, -bound1, bound1)
        nn.init.uniform_(self.net[0].bias, -bound1, bound1)
        nn.init.uniform_(self.net[3].weight, -bound2, bound2)
        nn.init.uniform_(self.net[0].bias, -bound2, bound2)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()

        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.dim_head = dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.quant_k = torch.quantization.QuantStub()
        self.dequant_k = torch.quantization.DeQuantStub()
        self.quant_sm = torch.quantization.QuantStub()
        self.dequant_sm = torch.quantization.DeQuantStub()
        self.quant_v = torch.quantization.QuantStub()
        self.dequant_v = torch.quantization.DeQuantStub()
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

        #self._init_parameters(dim, inner_dim)

    def _init_parameters(self, dim, inner_dim):
        bound = 1 / (dim ** .5)
        nn.init.uniform_(self.to_q.weight, -bound, bound)
        nn.init.uniform_(self.to_k.weight, -bound, bound)
        nn.init.uniform_(self.to_v.weight, -bound, bound)

        bound = 1 / (inner_dim ** .5)
        nn.init.uniform_(self.to_out[0].weight, -bound, bound)
        nn.init.uniform_(self.to_out[0].bias, -bound, bound)

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)
        k = k.reshape(b, n, h, -1).permute(0, 2, 1, 3)
        v = v.reshape(b, n, h, -1).permute(0, 2, 1, 3)
        
        #dots = (q @ k.transpose(-2, -1)) * self.scale
        k = self.quant_k(self.dequant_k(k))
        k = k.transpose(-2, -1)
        dots = []
        for i in range(b):
            dots.append(bmm(q[i], k[i]))
        dots = torch.stack(dots)
        
        attn = self.quant_sm(self.attend(self.dequant_sm(dots)))
        
        #out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        v_ = self.quant_v(self.dequant_v(v))
        out = []
        for i in range(b):
            out.append(bmm(attn[i], v_[i]))
        out = torch.stack(out).transpose(1, 2).reshape(b, n, -1)
        
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
        self.quant_s1 = torch.quantization.QuantStub()
        self.dequant_s1 = torch.quantization.DeQuantStub()
        self.quant_s2 = torch.quantization.QuantStub()
        self.dequant_s2 = torch.quantization.DeQuantStub()

    def forward(self, x):
        for attn, ff in self.layers:
            x = self.quant_s1(self.dequant_s1(attn(x)) + self.dequant_s1(x))
            x = self.quant_s2( self.dequant_s2(ff(x)) + self.dequant_s1(x))
        return x

class ViT(nn.Module):
    def __init__(self, window_size=(14, 300), patch_length=10, num_classes=8, dim=64, depth=1, heads=8, mlp_dim=128, pool='cls', dim_head=32, dropout=.2, emb_dropout=0., use_cls_token=True):
        super().__init__()

        channels, window_length = window_size
        num_patches = (window_length // patch_length)
        patch_dim = channels * patch_length

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_conv = nn.Conv1d(in_channels=channels, out_channels=dim, kernel_size=patch_length, stride=patch_length, padding=0, bias=True)

        self.use_cls_token = use_cls_token
        if self.use_cls_token:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches + 1, dim))
        else:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches, dim))

        self.cls_token = nn.Parameter(torch.empty(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        self._init_parameters(patch_dim)

    def _init_parameters(self, patch_dim):
        bound = 1 / (patch_dim ** .5)
        nn.init.uniform_(self.patch_conv.weight, -bound, bound)
        nn.init.uniform_(self.patch_conv.bias, -bound, bound)
        nn.init.zeros_(self.pos_embedding)
        nn.init.zeros_(self.mlp_head[1].weight)
        nn.init.zeros_(self.mlp_head[1].bias)

    def forward(self, x):
        x = self.patch_conv(x).flatten(2).transpose(-2, -1)

        b, n, _ = x.shape
        
        
        """if self.use_cls_token:
            cls_tokens = self.cls_token.expand(b, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
            x += self.pos_embedding[:, :(n + 1)]
        else :
            x += self.pos_embedding
        """
        
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        x = self.mlp_head(x)
        

        return x

In [4]:
def load_dataset(dataset_dir, subject):
    
    all_other_subjects = ','.join([str(s) for s in range(1, 11) if s != subject])
    minmax_picklename = f'./minmax/ds_minmax_sessions=5subjects={all_other_subjects}.pickle'
    minmax = load(open(minmax_picklename, 'rb'))
    
    test_ds = DB6MultiSession(folder=os.path.expanduser(dataset_dir), 
                              subjects=[subject], sessions=list(range(5, 10)), 
                              minmax=minmax, n_classes='7+1', steady=True).to(device)
    
    return test_ds

In [5]:
def load_training_set(dataset_dir, subject):
    
    all_other_subjects = ','.join([str(s) for s in range(1, 11) if s != subject])
    minmax_picklename = f'./minmax/ds_minmax_sessions=5subjects={all_other_subjects}.pickle'
    minmax = load(open(minmax_picklename, 'rb'))
    
    ds = DB6MultiSession(folder=os.path.expanduser(dataset_dir), 
                              subjects=[subject], sessions=list(range(5)), 
                              minmax=minmax, n_classes='7+1', steady=True).to(device)
    
    return ds

In [6]:
def load_model(subject, training_fold):
    net = ViT()
    net.to(device)
    net.eval()
    net.load_state_dict((torch.load(f"checkpoints/vit_subject{subject}_fold{training_fold}.pth")))
    return net

In [7]:
@torch.no_grad()
def get_loss_preds(net, criterion, loader):
    y_pred, y_true = [], []
    loss = 0
    for X_batch, Y_batch in loader:
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        
        outputs = net(X_batch)
        _, predicted = torch.max(outputs, 1)
        loss += criterion(outputs, Y_batch).item()

        y_pred.append(predicted.cpu())
        y_true.append(Y_batch.cpu())
        
    y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
    loss /= len(loader)
    
    return loss, (y_pred, y_true)

In [8]:
test_ds = load_dataset(dataset_dir='../../dataset_DB6', subject=5)
ds_loader = DataLoader(test_ds, batch_size=1000, shuffle=False, pin_memory=False, drop_last=False)

minmax [-0.00820696 -0.00955554 -0.00625532 -0.0054779  -0.00636441 -0.00635652
 -0.0047272  -0.0016807  -0.01117341 -0.00731644 -0.00749952 -0.00527726
 -0.0054686  -0.00693581] [0.00775334 0.00828333 0.0070026  0.0048873  0.00684452 0.00630389
 0.0061128  0.00096831 0.01065748 0.0073393  0.00718339 0.00572001
 0.0059097  0.00688035]
minmax [-0.00820696 -0.00955554 -0.00625532 -0.0054779  -0.00636441 -0.00635652
 -0.0047272  -0.0016807  -0.01117341 -0.00731644 -0.00749952 -0.00527726
 -0.0054686  -0.00693581] [0.00775334 0.00828333 0.0070026  0.0048873  0.00684452 0.00630389
 0.0061128  0.00096831 0.01065748 0.0073393  0.00718339 0.00572001
 0.0059097  0.00688035]
minmax [-0.00820696 -0.00955554 -0.00625532 -0.0054779  -0.00636441 -0.00635652
 -0.0047272  -0.0016807  -0.01117341 -0.00731644 -0.00749952 -0.00527726
 -0.0054686  -0.00693581] [0.00775334 0.00828333 0.0070026  0.0048873  0.00684452 0.00630389
 0.0061128  0.00096831 0.01065748 0.0073393  0.00718339 0.00572001
 0.0059097  0

In [9]:
train_ds = load_training_set(dataset_dir='../../dataset_DB6', subject=5)

minmax [-0.00820696 -0.00955554 -0.00625532 -0.0054779  -0.00636441 -0.00635652
 -0.0047272  -0.0016807  -0.01117341 -0.00731644 -0.00749952 -0.00527726
 -0.0054686  -0.00693581] [0.00775334 0.00828333 0.0070026  0.0048873  0.00684452 0.00630389
 0.0061128  0.00096831 0.01065748 0.0073393  0.00718339 0.00572001
 0.0059097  0.00688035]
minmax [-0.00820696 -0.00955554 -0.00625532 -0.0054779  -0.00636441 -0.00635652
 -0.0047272  -0.0016807  -0.01117341 -0.00731644 -0.00749952 -0.00527726
 -0.0054686  -0.00693581] [0.00775334 0.00828333 0.0070026  0.0048873  0.00684452 0.00630389
 0.0061128  0.00096831 0.01065748 0.0073393  0.00718339 0.00572001
 0.0059097  0.00688035]
minmax [-0.00820696 -0.00955554 -0.00625532 -0.0054779  -0.00636441 -0.00635652
 -0.0047272  -0.0016807  -0.01117341 -0.00731644 -0.00749952 -0.00527726
 -0.0054686  -0.00693581] [0.00775334 0.00828333 0.0070026  0.0048873  0.00684452 0.00630389
 0.0061128  0.00096831 0.01065748 0.0073393  0.00718339 0.00572001
 0.0059097  0

In [10]:
net = load_model(subject=5, training_fold=1)
_, (y_pred, y_true) = get_loss_preds(net, nn.CrossEntropyLoss(), ds_loader)
accuracy_fold1 = (y_pred == y_true).float().mean()

net = load_model(subject=5, training_fold=2)
_, (y_pred, y_true) = get_loss_preds(net, nn.CrossEntropyLoss(), ds_loader)
accuracy_fold2 = (y_pred == y_true).float().mean()

print(.5 * (accuracy_fold1 + accuracy_fold2))

tensor(0.5841)


In [11]:
class M(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.vit = ViT(*args, **kwargs)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # during the convert step, this will be replaced with a
        # `quantize_per_tensor` call
        x = self.quant(x)
        x = self.vit(x)
        x = self.dequant(x)
        return x

In [12]:
subject, training_fold = 5, 1

# create a model instance
model_fp32 = M()
model_fp32.vit.load_state_dict((torch.load(f"checkpoints/vit_subject{subject}_fold{training_fold}.pth")))

# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear, torch.nn.GELU, torch.nn.Softmax, torch.nn.Conv1d, torch.nn.LayerNorm},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

In [13]:
_, (y_pred, y_true) = get_loss_preds_q(model_int8, nn.CrossEntropyLoss(), ds_loader)
accuracy_fold1 = (y_pred == y_true).float().mean()

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "c:\users\francesco\appdata\local\programs\python\python37\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-13-014cc4e12590>", line 1, in <module>
    _, (y_pred, y_true) = get_loss_preds_q(model_int8, nn.CrossEntropyLoss(), ds_loader)
NameError: name 'get_loss_preds_q' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\users\francesco\appdata\local\programs\python\python37\lib\site-packages\IPython\core\interactiveshell.py", line 2061, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'NameError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\users\francesco\appdata\local\programs\python\python37\lib\site-packages\IPython\core\ultr

TypeError: object of type 'NoneType' has no len()

In [None]:
print(accuracy_fold1)

In [None]:
net = load_model(subject=5, training_fold=1)
_, (y_pred, y_true) = get_loss_preds(net, nn.CrossEntropyLoss(), ds_loader)
accuracy_fold1 = (y_pred == y_true).float().mean()


In [None]:
print(accuracy_fold1)

In [187]:
subject, training_fold = 5, 1
model_fp32 = M()
sd = torch.load(f"checkpoints/vit_subject{subject}_fold{training_fold}.pth")
sd['transformer.layers.0.1.fn.net.5.weight'] = sd.pop('transformer.layers.0.1.fn.net.3.weight')
sd['transformer.layers.0.1.fn.net.5.bias'] = sd.pop('transformer.layers.0.1.fn.net.3.bias')
model_fp32.vit.load_state_dict(sd)
#load_model(subject=5, training_fold=1)

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or assymetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
#model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

qconfig_a_qint8 = torch.quantization.QConfig(
    activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8),
    weight=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8),
)


#model_fp32.vit.qconfig = torch.quantization.default_qconfig
#model_fp32.vit.transformer.layers[0][0].fn.to_k.qconfig = qconfig_a_qint8
#model_fp32.vit.transformer.layers[0][0].fn.to_v.qconfig = qconfig_a_qint8
model_fp32.vit.transformer.layers[0][0].fn.quant_k.qconfig = qconfig_a_qint8
model_fp32.vit.transformer.layers[0][0].fn.quant_v.qconfig = qconfig_a_qint8
model_fp32.qconfig = torch.quantization.default_qconfig

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
#model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_fused = model_fp32

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

In [188]:
def bmm(a, b):
    r = []
    for i in range(a.shape[0]):
        r.append(nn.functional.linear(a[i], b[i].T))
    return torch.stack(r)

train_ds_loader = DataLoader(train_ds.split(total_folds=2, val_fold=0)[0], batch_size=1000, shuffle=False, pin_memory=False, drop_last=False)
for X_batch, _ in train_ds_loader:
    X_batch = X_batch.to(device)
    model_fp32_prepared(X_batch)

In [189]:
model_int8 = torch.quantization.convert(model_fp32_prepared)

In [190]:
def bmm(a, b):
    r = []
    #print(a.type(), b.type())
    for i in range(a.shape[0]):
        r.append(nn.quantized.functional.linear(a[i], b[i].T))
    return torch.stack(r)

In [191]:
_, (y_pred, y_true) = get_loss_preds(model_int8, nn.CrossEntropyLoss(), ds_loader)
accuracy_fold1 = (y_pred == y_true).float().mean()

In [192]:
accuracy_fold1

tensor(0.5324)

In [51]:
a = torch.randn(7, 100, 12)
b = torch.randn(7, 12, 100)

In [53]:
(a @ b).shape

torch.Size([7, 100, 100])

In [61]:
(bmm(a, b) != (a @b)).sum()

tensor(0)