In [1]:
import qtorch
import torch

In [2]:
torch.cuda.is_available = lambda : False

In [3]:
from qtorch import FloatingPoint
from qtorch.quant import Quantizer

# define floating point format
bit_8 = FloatingPoint(exp=5, man=2)
# create a quantizer
factor_Q = Quantizer(forward_number=bit_8, forward_rounding="nearest")

No CUDA runtime is found, using CUDA_HOME='/usr'


In [4]:
import torch.nn as nn
import torch.nn.functional as F

class QTensorFusion(nn.Module):

    def __init__(self, input_sizes, output_size, dropout=0.0, bias=True, device=None, dtype=None):

        super().__init__()

        self.input_sizes = input_sizes
        self.output_size = output_size
        self.dropout = nn.Dropout(dropout)
        
        # initialize weight tensor
        tensorized_shape = input_sizes + (output_size,)
        self.weight_tensor = nn.Parameter(torch.empty(tensorized_shape, device=device, dtype=dtype))
        nn.init.xavier_normal_(self.weight_tensor)

        # initialize bias
        if bias:
            self.bias = nn.Parameter(torch.full((output_size,), 0.1, device=device, dtype=dtype))
        else:
            self.bias = None

    def forward(self, inputs):

        fusion_tensor = inputs[0]
        for x in inputs[1:]:
            fusion_tensor = factor_Q(torch.einsum('n...,na->n...a', fusion_tensor, x))
        
        fusion_tensor = self.dropout(fusion_tensor)

        output = factor_Q(torch.einsum('n...,...o->no', fusion_tensor, self.weight_tensor))

        if self.bias is not None:
            output = factor_Q(output + self.bias)
        

        output = F.relu(output)
        output = self.dropout(output)

        return output

In [5]:
from tensor_fusion.module import TensorFusion

fusion_layer = TensorFusion((10, 20, 30), 10)
x1 = torch.rand((4,10))
x2 = torch.rand((4,20))
x3 = torch.rand((4,30))
y = fusion_layer([x1, x2, x3])



In [6]:
y

tensor([[0.0000, 0.0860, 0.0023, 0.0323, 0.0000, 0.0000, 0.0000, 0.2301, 0.0045,
         0.3891],
        [0.0000, 0.3390, 0.1816, 0.0022, 0.0000, 0.0000, 0.1642, 0.1613, 0.0999,
         0.0976],
        [0.1569, 0.3022, 0.2595, 0.0000, 0.1094, 0.0000, 0.2626, 0.0370, 0.2038,
         0.2950],
        [0.3092, 0.5297, 0.3148, 0.0000, 0.2240, 0.0000, 0.0000, 0.4008, 0.0000,
         0.0000]], grad_fn=<ReluBackward0>)

In [7]:
q_fusion_layer = QTensorFusion((10, 20, 30), 10)
q_fusion_layer.weight_tensor = fusion_layer.weight_tensor
q_fusion_layer.bias = fusion_layer.bias
x1 = torch.rand((4,10))
x2 = torch.rand((4,20))
x3 = torch.rand((4,30))
y = fusion_layer([x1, x2, x3])

In [8]:
y

tensor([[0.0000, 0.1949, 0.1740, 0.0000, 0.0779, 0.2369, 0.1970, 0.3585, 0.0000,
         0.1788],
        [0.1584, 0.0937, 0.0861, 0.1034, 0.0461, 0.0000, 0.1266, 0.0000, 0.0000,
         0.3539],
        [0.0000, 0.2110, 0.2696, 0.0000, 0.0000, 0.0000, 0.0700, 0.2734, 0.3019,
         0.2641],
        [0.1048, 0.2091, 0.1398, 0.2077, 0.3056, 0.1765, 0.2426, 0.1022, 0.0000,
         0.2376]], grad_fn=<ReluBackward0>)

In [9]:
import numpy as np

def tt_times_matrix_fwd(tensor, matrix, return_saved_tensors):
    """
    This function takes the input tensor "tensor", the input matrix "matrix"
    and returns tensor times matrix as well as any extra tensors you decide to save
    for the backward pass
    """
    #Author Alvin Liu

    ndims = tensor.order
    d = int(ndims / 2)
    tt_shape = tensor.shape
    tt_ranks = tensor.rank[1:-1]
    tt_shape_row = tt_shape[:d]
    tt_shape_col = tt_shape[d:]
    tt_rows = np.prod(tt_shape_row)
    tt_cols = np.prod(tt_shape_col)
    matrix_rows = matrix.shape[0]
    matrix_cols = matrix.shape[1]
    if tt_rows is not None and matrix_rows is not None:
        if tt_rows != matrix_rows:
            raise ValueError(
                'Arguments shapes should align got %s and %s instead.' %
                ((tt_rows, tt_cols), (matrix_rows, matrix_cols)))

    # Matrix: M * K, tensor: M * N = (i_0, i_1, ..., i_d-1) * (j_0, j_1, ..., j_d-1)
    # The shape of data is 1 * i_0 * (i_1, i_2, ..., i_d-1, K)
    data = matrix
    data = data.reshape(1, tt_shape_row[0], -1)
    saved_tensors = [matrix] if return_saved_tensors else None

    for k in range(d):
        # The shape of data is r_k * i_k * (i_k+1, ..., i_d-1, K)
        # The shape of curr_core (core_k) is r_k * i_k * r_k+1
        # After einsum() the shape of data is r_k+1 * (i_k+1, ..., i_d-1, K)
        curr_core = tensor.factors[k]
        data = torch.einsum('ria, rib->ba', [data, curr_core])

        if k < d - 1:
            # After reshape the data, the shape is r_k+1 * i_k+1 * (i_k+2, ..., i_d-1, K)
            data = data.reshape(tt_ranks[k], tt_shape_row[k + 1], -1)

        saved_tensors.append(data) if return_saved_tensors else None

    # Now the shape of data is r_d * K
    for k in range(d):
        # The shape of data is r_d+k * (K, j_0, ..., j_k-1)
        # The shape of curr_core (core_d+k) is r_d+k * j_k * r_d+k+1
        # After einsum() the shape of data is r_d+k+1 * (K, j_0, ..., j_k-1) * j_k
        curr_core = tensor.factors[k + d]
        data = torch.einsum('ra, rjb->baj', [data, curr_core])

        if k < d - 1:
            saved_tensors.append(data.reshape(data.shape[0], matrix_cols, -1)) if return_saved_tensors else None
            # After reshape the data, the shape is r_d+k+1 * (K, j_0, ..., j_k)
            data = data.reshape(tt_ranks[k + d], -1)

    # The shape of data is 1 * (K, j_0, ..., j_d-2) * j_d-1
    # The shape of output is K * (j_0, ..., j_d-1)
    output = data.reshape(matrix_cols, tt_cols)

    if return_saved_tensors:
        return output, saved_tensors
    else:
        return output

In [10]:
from tensor_fusion.low_rank_tensor import TT
weight_tensor = TT(128, 64, 5)

In [11]:
matrix = torch.rand(5, 128)
out = tt_times_matrix_fwd(weight_tensor.tensor, matrix.T, False)

In [12]:
out

tensor([[-2.8147e-01, -4.6852e-01,  2.6628e+00, -7.1536e-01, -9.8626e-02,
          4.9763e-01,  2.9270e-01, -7.0394e-01, -3.2957e-01, -4.6288e-01,
          1.6270e-01, -3.6290e-01,  2.3649e-01,  7.0094e-01,  1.3807e+00,
          3.5507e-01, -7.4049e-02, -4.8794e-01, -2.2668e-01, -7.4344e-01,
          1.1270e-01,  3.2861e-01,  3.2756e-01, -2.7821e-01, -4.0464e-01,
         -3.8620e-02,  1.1700e+00, -5.8933e-01, -3.2322e-01, -5.5988e-01,
          1.0103e-01, -5.2083e-01, -4.8731e-01, -3.8730e-01,  1.1940e+00,
         -6.7260e-01,  9.0283e-02,  1.4722e-01, -1.1062e+00, -1.0455e-01,
          2.5446e-01, -1.6950e-01, -7.0476e-01,  1.8621e-01, -2.3740e-01,
         -6.5926e-01,  1.5572e-01, -5.6310e-01,  8.0796e-01,  9.5150e-01,
         -8.8637e-01,  1.2115e+00, -4.0965e-01, -1.4050e-01,  2.5246e+00,
         -1.0110e-01, -1.0198e+00, -1.7574e-01,  2.3754e+00, -1.0628e+00,
          5.5995e-01,  1.1560e+00, -2.2987e-01,  1.0955e+00],
        [-1.8364e-01, -2.5672e-01,  8.0541e-01, -3

In [13]:
import numpy as np

def q_tt_times_matrix_fwd(tensor, matrix, return_saved_tensors):
    """
    This function takes the input tensor "tensor", the input matrix "matrix"
    and returns tensor times matrix as well as any extra tensors you decide to save
    for the backward pass
    """
    #Author Alvin Liu

    ndims = tensor.order
    d = int(ndims / 2)
    tt_shape = tensor.shape
    tt_ranks = tensor.rank[1:-1]
    tt_shape_row = tt_shape[:d]
    tt_shape_col = tt_shape[d:]
    tt_rows = np.prod(tt_shape_row)
    tt_cols = np.prod(tt_shape_col)
    matrix_rows = matrix.shape[0]
    matrix_cols = matrix.shape[1]
    if tt_rows is not None and matrix_rows is not None:
        if tt_rows != matrix_rows:
            raise ValueError(
                'Arguments shapes should align got %s and %s instead.' %
                ((tt_rows, tt_cols), (matrix_rows, matrix_cols)))

    # Matrix: M * K, tensor: M * N = (i_0, i_1, ..., i_d-1) * (j_0, j_1, ..., j_d-1)
    # The shape of data is 1 * i_0 * (i_1, i_2, ..., i_d-1, K)
    data = matrix
    data = data.reshape(1, tt_shape_row[0], -1)
    saved_tensors = [matrix] if return_saved_tensors else None

    for k in range(d):
        # The shape of data is r_k * i_k * (i_k+1, ..., i_d-1, K)
        # The shape of curr_core (core_k) is r_k * i_k * r_k+1
        # After einsum() the shape of data is r_k+1 * (i_k+1, ..., i_d-1, K)
        curr_core = tensor.factors[k]
        data = factor_Q(torch.einsum('ria, rib->ba', [data, curr_core]))

        if k < d - 1:
            # After reshape the data, the shape is r_k+1 * i_k+1 * (i_k+2, ..., i_d-1, K)
            data = data.reshape(tt_ranks[k], tt_shape_row[k + 1], -1)

        saved_tensors.append(data) if return_saved_tensors else None

    # Now the shape of data is r_d * K
    for k in range(d):
        # The shape of data is r_d+k * (K, j_0, ..., j_k-1)
        # The shape of curr_core (core_d+k) is r_d+k * j_k * r_d+k+1
        # After einsum() the shape of data is r_d+k+1 * (K, j_0, ..., j_k-1) * j_k
        curr_core = tensor.factors[k + d]
        data = factor_Q(torch.einsum('ra, rjb->baj', [data, curr_core]))

        if k < d - 1:
            saved_tensors.append(data.reshape(data.shape[0], matrix_cols, -1)) if return_saved_tensors else None
            # After reshape the data, the shape is r_d+k+1 * (K, j_0, ..., j_k)
            data = data.reshape(tt_ranks[k + d], -1)

    # The shape of data is 1 * (K, j_0, ..., j_d-2) * j_d-1
    # The shape of output is K * (j_0, ..., j_d-1)
    output = data.reshape(matrix_cols, tt_cols)

    if return_saved_tensors:
        return output, saved_tensors
    else:
        return output

In [14]:
out = q_tt_times_matrix_fwd(weight_tensor.tensor, matrix.T, False)

In [15]:
out

tensor([[-3.1250e-01, -4.3750e-01,  3.0000e+00, -6.2500e-01, -1.9531e-02,
          5.0000e-01,  1.5625e-01, -6.2500e-01, -3.7500e-01, -4.3750e-01,
          1.5625e-01, -3.7500e-01,  2.5000e-01,  7.5000e-01,  1.5000e+00,
          3.1250e-01, -1.5625e-01, -6.2500e-01, -3.1250e-01, -8.7500e-01,
          1.2500e-01,  3.7500e-01,  2.5000e-01, -3.1250e-01, -3.7500e-01,
         -2.7344e-02,  1.0000e+00, -5.0000e-01, -4.3750e-01, -7.5000e-01,
          6.2500e-02, -6.2500e-01, -5.0000e-01, -4.3750e-01,  1.2500e+00,
         -8.7500e-01,  1.0938e-01,  1.5625e-01, -1.2500e+00, -1.2500e-01,
          2.5000e-01, -1.8750e-01, -6.2500e-01,  1.5625e-01, -2.5000e-01,
         -7.5000e-01,  6.2500e-02, -6.2500e-01,  7.5000e-01,  8.7500e-01,
         -6.2500e-01,  1.2500e+00, -4.3750e-01, -1.8750e-01,  2.5000e+00,
         -1.8750e-01, -1.0000e+00, -2.1875e-01,  2.5000e+00, -1.0000e+00,
          5.0000e-01,  1.2500e+00, -1.0938e-01,  1.0000e+00],
        [-2.1875e-01, -2.1875e-01,  1.0000e+00, -3