In [1]:
import statistics
import torch
import timeit
import argparse

import spatha
import sten

from grouped_nmv_tensor import VenomTensor, venom_mask_sparsify
from torch.profiler import profile, record_function, ProfilerActivity

  return self.fget.__get__(instance, owner)()


In [2]:
# Sparsity configuration, read as module-level globals by the Venom* classes
# and conversion helpers defined in the following cells.
v, n, m = 128, 2, 8

# Inference only: autograd is switched off for the whole notebook.
torch.set_grad_enabled(False)


def _load_bert_large():
    """Fetch a fresh bert-large-uncased model through torch.hub."""
    return torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-large-uncased')


# Three independent copies of the same checkpoint. The first two stay on the
# default device for now; the dense reference goes straight to GPU in fp16.
sparse_model = _load_bert_large()
masked_sparse_model = _load_bert_large()
dense_model = _load_bert_large().to(device='cuda:0').half()

Using cache found in /home/roberto.lopez/.cache/torch/hub/huggingface_pytorch-transformers_main
Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using cache found in /home/roberto.lopez/.cache/torch/hub/huggingface_pytorch-transformers_m

In [3]:
class VenomSparsifier:
    """Callable that turns a dense 2-D tensor into a sten-wrapped ``VenomTensor``.

    The kept positions are neither learned nor magnitude-based: a fixed
    pattern is applied to every group of ``m`` consecutive elements.
    """

    def __init__(self, n, m, v):
        """Store the three sparsity parameters for later use in ``__call__``."""
        self.n, self.m, self.v = n, m, v

    @staticmethod
    def get_random_mask(tensor, m, v):
        """Build a 0/1 mask with the same shape as ``tensor``.

        Despite the name, the mask is deterministic: viewed in row-major
        order, every run of ``m`` consecutive elements receives
        ``[1, 0, 1, 0]`` followed by ``m - 4`` zeros. The mask is allocated
        on the default device, not on ``tensor.device``.
        """
        keep_pattern = torch.cat((torch.tensor([1, 0, 1, 0]), torch.zeros(m - 4)), 0)
        blank = torch.zeros(tensor.shape, dtype=tensor.dtype)
        # Broadcasting stamps the pattern onto the trailing length-m axis.
        stamped = blank.reshape(-1, v, m) + keep_pattern
        return stamped.reshape(tensor.shape)

    def __call__(self, tensor, grad_fmt=None):
        """Wrap ``tensor`` as a sparse sten tensor using the fixed mask.

        Args:
            tensor: 2-D dense tensor (its shape is unpacked as rows, cols).
            grad_fmt: forwarded unchanged to ``wrapped_from_dense``.

        Returns:
            The result of ``sten.SparseTensorWrapper.wrapped_from_dense``.
        """
        # random pruning (cuSparseLt-like approach) -> mask, columns
        nrows, ncols = tensor.shape
        row_groups = nrows // self.v
        slots_per_group = ncols // self.m * 4

        # Every consecutive group of four slots holds the indices 0, 1, 2, 3.
        slot_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
        columns = torch.zeros(row_groups, slots_per_group, dtype=torch.int32)
        columns = (columns.reshape((-1, 4)) + slot_ids).reshape((row_groups, slots_per_group))

        mask = VenomSparsifier.get_random_mask(tensor, self.m, self.v)

        venom = VenomTensor(self.n, self.m, self.v, tensor, mask, columns, tensor.device)
        return sten.SparseTensorWrapper.wrapped_from_dense(venom, tensor, grad_fmt)

In [4]:
def sparse_dense_mul_dispatch(sparse_values, sparse_indices, sparse_metadata, dense, nrows_sp, ncols_sp, ncols_d, m, n, v, nnz, bias):
    """Run the spatha SpMM kernel on a compressed sparse operand and a dense one.

    All tensor operands are moved to ``cuda:0`` before the call; the sparse
    values are additionally cast to half precision and the dense operand is
    made contiguous first.

    Note the ordering: this function receives ``m, n, v`` but the kernel is
    called with ``v, n, m``.

    Returns:
        Whatever ``spatha.spmm_128x64x32_32x64x32_16x8x32_2`` returns.
    """
    gpu = 'cuda:0'

    dense_ = dense.contiguous()

    # Stage each operand on the GPU in the order the kernel expects them.
    metadata_gpu = sparse_metadata.to(device=gpu)
    indices_gpu = sparse_indices.to(device=gpu)
    values_gpu = sparse_values.to(dtype=torch.half).to(device=gpu)
    rhs_gpu = dense_.to(device=gpu)
    bias_gpu = bias.to(device=gpu)

    return spatha.spmm_128x64x32_32x64x32_16x8x32_2(
        metadata_gpu,   # metadata
        indices_gpu,    # indices
        values_gpu,     # values
        rhs_gpu,        # rhs_matrix
        bias_gpu,       # bias
        nrows_sp,       # A_num_rows
        ncols_sp,       # A_num_cols
        ncols_d,        # B_num_cols
        v,              # V
        n,              # N
        m,              # M
        nnz,            # nnz
        0,              # seed
        32,             # mbrow
        4,              # brow
    )

In [5]:
class VenomSpmm(torch.nn.Module):
    """Replacement for ``torch.nn.Linear`` that runs through the spatha SpMM kernel.

    The original layer's weight is compressed with ``VenomSparsifier`` (using
    the module-level ``n``, ``m``, ``v``) and the resulting values, column
    indices and metadata are kept on the module. The original bias is handed
    to the kernel on every call.
    """

    def __init__(self, original: torch.nn.Linear):
        """Compress ``original.weight`` and keep a reference to ``original.bias``."""
        super().__init__()
        self.bias = original.bias

        # Convert weights from original module to SrNM
        w = VenomSparsifier(n, m, v)(original.weight).wrapped_tensor

        self.values = torch.nn.Parameter(w.values)
        self.columns = w.columns
        self.metadata = w.metadata

        self.nrows_sp = w.nrows
        self.ncols_sp = w.ncols
        self.nnz      = w.nnz

    def forward(self, input):
        """Apply the sparse layer to ``input``.

        Args:
            input: tensor whose last dimension is the feature dimension; all
                leading dimensions are flattened into one before the kernel
                call and restored afterwards.

        Returns:
            Tensor with the leading dimensions of ``input`` and the kernel's
            output features as the last dimension.
        """
        # Collapse every leading dimension so the kernel sees a 2-D operand.
        flattened_input = torch.flatten(input, start_dim=0, end_dim=-2)

        # Number of flattened rows; it is also the column count of the
        # transposed operand passed to the kernel.
        DM, _ = flattened_input.shape
        ncols_d = DM

        output = sparse_dense_mul_dispatch(self.values,
                                           self.columns,
                                           self.metadata,
                                           flattened_input.T,
                                           self.nrows_sp,
                                           self.ncols_sp,
                                           ncols_d,
                                           m,
                                           n,
                                           v,
                                           self.nnz,
                                           self.bias)

        # NOTE(review): the trailing slice keeps at most DM entries of the
        # *feature* axis. It is a no-op whenever DM >= number of output
        # features (as in this notebook) but would truncate otherwise --
        # confirm the intended bound against the kernel's output layout.
        output = output.reshape((*input.shape[0:-1], -1))[..., :DM]

        return output

In [6]:
class VenomSpmmMasked(torch.nn.Module):
    """Dense reference for ``VenomSpmm``: same pruned weight, ordinary matmul.

    The original layer's weight is compressed with ``VenomSparsifier`` (using
    the module-level ``n``, ``m``, ``v``) and then expanded back with
    ``to_dense()``; ``forward`` multiplies by that dense matrix instead of
    calling the sparse kernel.
    """

    def __init__(self, original: torch.nn.Linear):
        """Compress ``original.weight`` and cache its dense reconstruction."""
        super().__init__()
        self.bias = original.bias

        # Convert weights from original module to SrNM
        w = VenomSparsifier(n, m, v)(original.weight).wrapped_tensor

        self.values = torch.nn.Parameter(w.values)
        self.columns = w.columns
        self.metadata = w.metadata

        self.nrows_sp = w.nrows
        self.ncols_sp = w.ncols
        self.nnz      = w.nnz

        # Dense reconstruction of the compressed weight, used by forward().
        self.dense = w.to_dense()

    def forward(self, input):
        """Return ``input @ dense.T`` plus the bias when one is present.

        Args:
            input: tensor whose last dimension matches the column count of
                ``self.dense``.

        Returns:
            Tensor with the leading dimensions of ``input`` and one entry per
            row of ``self.dense`` in the last dimension.
        """
        # matmul broadcasts over the leading dimensions, so no flattening or
        # transposed copy of the activations is needed.
        output = input @ self.dense.T

        if self.bias is not None:
            # In-place add is safe: `output` is a fresh tensor from the matmul.
            output += self.bias.unsqueeze(0).expand_as(output)

        return output

In [7]:
def linear_to_spmm(mod, weights_to_sparsify):
    """Recursively swap every ``torch.nn.Linear`` under ``mod`` for a ``VenomSpmm``.

    Args:
        mod: module to convert. If it is itself a ``Linear`` a new
            ``VenomSpmm`` is returned; otherwise its children are replaced
            in place.
        weights_to_sparsify: accepted and passed down the recursion, but not
            consulted -- every ``Linear`` found is converted.

    Returns:
        The converted module (``mod`` itself unless it was a ``Linear``).
    """
    if isinstance(mod, torch.nn.Linear):
        return VenomSpmm(mod)

    for child_name, child in mod.named_children():
        # Skip layers that are already converted, and guard against a module
        # that lists itself as a child.
        if isinstance(child, VenomSpmm) or child is mod:
            continue

        if isinstance(child, torch.nn.Linear):
            setattr(mod, child_name, VenomSpmm(child))
        else:
            linear_to_spmm(child, weights_to_sparsify)

    return mod

In [8]:
def linear_to_masked_spmm(mod, weights_to_sparsify):
    """Recursively swap every ``torch.nn.Linear`` under ``mod`` for a ``VenomSpmmMasked``.

    Args:
        mod: module to convert. If it is itself a ``Linear`` a new
            ``VenomSpmmMasked`` is returned; otherwise its children are
            replaced in place.
        weights_to_sparsify: accepted and passed down the recursion, but not
            consulted -- every ``Linear`` found is converted.

    Returns:
        The converted module (``mod`` itself unless it was a ``Linear``).
    """
    if isinstance(mod, torch.nn.Linear):
        return VenomSpmmMasked(mod)

    for child_name, child in mod.named_children():
        # Skip layers that are already converted, and guard against a module
        # that lists itself as a child.
        if isinstance(child, VenomSpmmMasked) or child is mod:
            continue

        if isinstance(child, torch.nn.Linear):
            setattr(mod, child_name, VenomSpmmMasked(child))
        else:
            linear_to_masked_spmm(child, weights_to_sparsify)

    return mod

In [9]:
def linear_to_masked(model):
    """Prune, in place, the weight of every encoder ``Linear`` layer of ``model``.

    A layer qualifies when it is a ``torch.nn.Linear`` whose qualified name
    contains ``"encoder.layer"``. Its weight is compressed with
    ``VenomSparsifier`` (module-level ``n``, ``m``, ``v``), expanded back to a
    dense tensor, cast to half precision and re-assigned as the layer's
    ``weight`` parameter. Nothing is returned.
    """
    for module_name, module in model.named_modules():
        if "encoder.layer" not in module_name:
            continue
        if not isinstance(module, torch.nn.modules.linear.Linear):
            continue

        pruned = VenomSparsifier(n, m, v)(module.weight).wrapped_tensor
        dense_fp16 = pruned.to_dense().to(dtype=torch.half)
        module.weight = torch.nn.Parameter(dense_fp16)

In [10]:
# Every Linear layer of `sparse_model` whose qualified name places it inside
# the encoder stack.
weights_to_sparsify = [
    layer
    for layer_name, layer in sparse_model.named_modules()
    if "encoder.layer" in layer_name and isinstance(layer, torch.nn.modules.linear.Linear)
]

In [11]:
# Seed the RNG so the random token ids -- and therefore every comparison in
# the cells below -- are reproducible across kernel restarts.
torch.manual_seed(0)

# Batch of 32 sequences x 512 random integer token ids in [0, 100).
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells reference it.
input = torch.randint(low=0, high=100, size=(32, 512))
input = input.to(device='cuda:0')

# Dense reference: prune the encoder Linear weights in place.
linear_to_masked(dense_model)

# Sparse model: move to GPU / fp16 first, then swap Linear layers for VenomSpmm.
sparse_model = sparse_model.to(device='cuda:0').half()
sparse_model = linear_to_spmm(sparse_model, weights_to_sparsify)

# Masked model: same procedure with the dense-matmul replacement layer.
masked_sparse_model = masked_sparse_model.to(device='cuda:0').half()
masked_sparse_model = linear_to_masked_spmm(masked_sparse_model, weights_to_sparsify)

In [12]:
#sp_output = sparse_model(input, output_hidden_states=True)
# Full forward pass of the pruned dense reference model; hidden states are
# requested so single layers can be fed identical activations in later cells.
output = dense_model(input,  output_hidden_states=True)

In [13]:
# Query projection of encoder layer 0 in the dense reference model, applied
# to output.hidden_states[0].
dense_out = dense_model.encoder.layer[0].attention.self.query(output.hidden_states[0])

#print(dense_out)

In [14]:
# Same query projection through the converted sparse model (a VenomSpmm
# layer after linear_to_spmm), fed the identical hidden state.
sparse_out = sparse_model.encoder.layer[0].attention.self.query(output.hidden_states[0])

# NOTE: the bare string below is an expression, not a comment, so the
# notebook echoes it as this cell's output.
""" print("real output shape", sparse_out.shape)
print(sparse_out)
print(output.hidden_states[0].shape) """

' print("real output shape", sparse_out.shape)\nprint(sparse_out)\nprint(output.hidden_states[0].shape) '

In [15]:
# Same query projection through the masked model (a VenomSpmmMasked layer
# after linear_to_masked_spmm), fed the identical hidden state.
sparse_masked_out = masked_sparse_model.encoder.layer[0].attention.self.query(output.hidden_states[0])

print("real output shape", sparse_masked_out.shape)
#print(sparse_masked_out)

real output shape torch.Size([32, 512, 1024])


In [16]:
# Sparse-kernel output vs. masked dense-matmul output for the same layer.
# Note the absolute tolerance here (0.5) is far looser than the 0.005 used
# in the comparisons against dense_out below.
torch.allclose( sparse_out, sparse_masked_out, atol=0.5)

True

In [17]:
# Sanity check: both projections should have the same shape.
print( dense_out.shape, sparse_out.shape )

torch.Size([32, 512, 1024]) torch.Size([32, 512, 1024])


In [18]:
# Dense reference vs. masked dense-matmul layer, absolute tolerance 0.005.
torch.allclose(dense_out, sparse_masked_out, atol=0.005)

True

In [19]:
# Dense reference vs. sparse-kernel layer, absolute tolerance 0.005.
torch.allclose(dense_out, sparse_out, atol=0.005)

True

In [20]:
# End-to-end forward pass of the pruned dense reference model.
out_dense = dense_model(input)

In [21]:
# End-to-end forward pass of the model whose Linear layers were replaced by VenomSpmm.
out_sparse = sparse_model(input)

In [22]:
# Compare the first element of each model's output end to end, with an
# absolute tolerance of 0.05.
print( torch.allclose(out_dense[0], out_sparse[0], atol=0.05) )

True


: 