## TAB PyTorch Extension Showcase

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

import TAB_CUDA as TAB

### List the APIs of TAB

In [2]:
# Enumerate every attribute the TAB extension module exposes.
tab_api_names = dir(TAB)
print(tab_api_names)

['Conv2d', 'Quantize', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__']


### Test the Quantization function

TAB.Quantize(torch::Tensor X, torch::Tensor thresholds, int bitwidth, int N, int H, int W, int C)

Return std::vector\<torch::Tensor\>: QW and BTN_W. 

BTN_W is only used in BTN mode; it is always returned so that the API is the same for every mode.

Ternarize: qx= +1, x>ths; qx= -1, x<-ths; qx= 0, otherwise

In [3]:
# Weight tensor dimensions: KN x KH x KW x KC
KN, KH, KW, KC = 64, 3, 3, 256

# bitwidth 2 selects ternarization (see the rule stated above this cell)
bitwidth = 2

# Uniform random weights and one threshold of 0.5 per entry along the first axis
w = torch.rand(KN, KH, KW, KC)
w_ths = torch.full([KN], 0.5)

# Quantize on the GPU; the call returns the quantized weights and BTN_W
QW, BTN_W = TAB.Quantize(
    w.cuda(),
    w_ths.cuda(),
    bitwidth,
    KN,
    KH,
    KW,
    KC,
)

Show the Size of the quantized tensor

The first bits of QW (index 0 along the last dimension) are all zeros, which means the quantized values contain only +1 and 0, because torch.rand() only produces values in [0, 1).

In [4]:
# Shapes of both outputs of the quantization call
print(QW.shape, BTN_W.shape)
# Element dtype and full tensor type string of the quantized weights
print(QW.dtype, QW.type())
# Everything stored for the first (n, h, w) position
print(QW[0, 0, 0])

torch.Size([64, 3, 3, 4, 2]) torch.Size([64])
torch.int64 torch.cuda.LongTensor
tensor([[                   0, -6706628582164491276],
        [                   0, -5642919529752776307],
        [                   0,  6859525101148296204],
        [                   0, -4685999609872530256]], device='cuda:0')


If we do binarization using the same data, then there will be +1 and -1 in the result

Binarize: qx=+1, x > ths; qx=-1 otherwise

In [5]:
# bitwidth 1 selects binarization; the input weights and thresholds are reused
bitwidth = 1
QW, BTN_W = TAB.Quantize(
    w.cuda(),
    w_ths.cuda(),
    bitwidth,
    KN,
    KH,
    KW,
    KC,
)

In [6]:
# Shape of the binarized weights
print("QW.size=", QW.shape)
# Element dtype and full tensor type string
print(QW.dtype, QW.type())
# Everything stored for the first (n, h, w) position
print(QW[0, 0, 0])

QW.size= torch.Size([64, 3, 3, 4, 1])
torch.int64 torch.cuda.LongTensor
tensor([[ 6706628582164491275],
        [ 5642919529752776306],
        [-6859525101148296205],
        [ 4685999609872530255]], device='cuda:0')


### Test the Conv2d function

TAB.Conv2d(torch::Tensor X, torch::Tensor QW, torch::Tensor thresholds, torch::Tensor btn, 
int type, int padding1, int padding2, int stride1, int stride2, int N,  int H, int W, int C, int KN, int KH, int KW)

Return the Conv2d result tensor

Type: 0: TNN, 1: TBN, 2: BTN, 3: BNN

TBN: Ternary-activation Binary-weight Network

BTN: Binary-activation Ternary-weight Network

In [7]:
# Configure the activation and weight tensor shapes
N=16
H=112
W=112
C=256

# Padding and stride passed to Conv2d (two values each)
pad1=1
pad2=1
str1=1
str2=1

# 0 selects TNN (see the type table above)
conv_type=0

# At this point QW holds the 1-bit (binarized) weights from the previous
# quantization cell. TNN is assumed to need ternary weights (by analogy with
# the TBN/BTN naming), so quantize again with bitwidth 2. The result goes into
# dedicated names so the notebook-level QW/BTN_W stay binary for later cells.
QW_tnn, BTN_W_tnn = TAB.Quantize(w.cuda(), w_ths.cuda(), 2, KN, KH, KW, KC)

# Random activations and one threshold of 0.5 per batch entry
x=torch.rand([N,H,W,C])
x_ths=0.5*torch.ones([N])
y=TAB.Conv2d(x.cuda(), QW_tnn.cuda(), x_ths.cuda(), BTN_W_tnn.cuda(), conv_type, pad1, pad2, str1, str2, N, H, W, C, KN, KH, KW)

Show the Size of the conv result

In [8]:
# Shape of the convolution output
print(y.shape)

torch.Size([16, 64, 112, 112])


### Other cases

In [9]:
# TBN mode (type 1): ternary activations with binary weights
conv_type = 1
y = TAB.Conv2d(
    x.cuda(),
    QW.cuda(),
    x_ths.cuda(),
    BTN_W.cuda(),
    conv_type,
    pad1,
    pad2,
    str1,
    str2,
    N,
    H,
    W,
    C,
    KN,
    KH,
    KW,
)
print(y.shape)

torch.Size([16, 64, 112, 112])


In [10]:
# BTN mode (type 2): binary activations with ternary weights
conv_type=2
# BTN consumes ternary weights, so quantize with bitwidth 2 here instead of
# reusing the 1-bit QW left over from the binarization cell. Dedicated names
# keep the notebook-level QW/BTN_W untouched for the other cells.
QW_btn, BTN_W_btn = TAB.Quantize(w.cuda(), w_ths.cuda(), 2, KN, KH, KW, KC)
y=TAB.Conv2d(x.cuda(), QW_btn.cuda(), x_ths.cuda(), BTN_W_btn.cuda(), conv_type, pad1, pad2, str1, str2, N, H, W, C, KN, KH, KW)
print(y.size())

torch.Size([16, 64, 112, 112])


In [11]:
# BNN mode (type 3): quantize the weights with bitwidth 1 first
bitwidth = 1
QW, BTN_W = TAB.Quantize(
    w.cuda(),
    w_ths.cuda(),
    bitwidth,
    KN,
    KH,
    KW,
    KC,
)
print(QW.shape)

conv_type = 3
y = TAB.Conv2d(
    x.cuda(),
    QW.cuda(),
    x_ths.cuda(),
    BTN_W.cuda(),
    conv_type,
    pad1,
    pad2,
    str1,
    str2,
    N,
    H,
    W,
    C,
    KN,
    KH,
    KW,
)
print(y.shape)
# One output value per batch entry, taken at index 0 of the remaining three axes
print(y[:, 0, 0, 0])

torch.Size([64, 3, 3, 4, 1])
torch.Size([16, 64, 112, 112])
tensor([ -6.,  20.,  32.,  22.,  38.,   8., -22.,  -4., -58.,  34., -62.,  10.,
          0., -26., -38., -36.], device='cuda:0')
