From 520c10fcfd7082168f60c6bb9047823fcc439a26 Mon Sep 17 00:00:00 2001
From: Moksh Jain
Date: Thu, 3 Oct 2019 21:40:37 +0530
Subject: [PATCH] fastgrnncuda: add sparsify support

---
 pytorch/edgeml_pytorch/graph/rnn.py | 54 ++++++++++++++++++++++++++++--
 1 file changed, 52 insertions(+), 2 deletions(-)

diff --git a/pytorch/edgeml_pytorch/graph/rnn.py b/pytorch/edgeml_pytorch/graph/rnn.py
index 08d521f0f..75dc999a7 100644
--- a/pytorch/edgeml_pytorch/graph/rnn.py
+++ b/pytorch/edgeml_pytorch/graph/rnn.py
@@ -319,8 +319,9 @@ class FastGRNNCUDACell(RNNCell):
     '''
 
     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
-                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity, 1, 1, 2, wRank, uRank)
+                 update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
+                                               1, 1, 2, wRank, uRank, wSparsity, uSparsity)
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
         NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
@@ -1166,6 +1167,55 @@ def getVars(self):
         Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
         return Vars
 
+    def get_model_size(self):
+        '''
+        Estimated model size in bytes at the target sparsity,
+        assuming 4-byte (float32) parameters.
+        '''
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+
+        totalnnz = 0
+        for i in range(0, endW):
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._wSparsity)
+        for i in range(endW, endU):
+            totalnnz += utils.countNNZ(mats[i].cpu(), self._uSparsity)
+        # Biases, zeta and nu are stored densely.
+        for i in range(endU, len(mats)):
+            totalnnz += utils.countNNZ(mats[i].cpu(), False)
+        return totalnnz * 4
+
+    def copy_previous_UW(self):
+        # Snapshot W and U so their support (non-zero pattern) can be
+        # re-imposed later by sparsifyWithSupport().
+        mats = self.getVars()
+        num_mats = self._num_W_matrices + self._num_U_matrices
+        if len(self.oldmats) != num_mats:
+            for i in range(num_mats):
+                self.oldmats.append(torch.FloatTensor())
+        for i in range(num_mats):
+            self.oldmats[i] = mats[i].detach().clone()
+
+    def sparsify(self):
+        # Hard-threshold W and U down to their target sparsity levels.
+        mats = self.getVars()
+        endW = self._num_W_matrices
+        endU = endW + self._num_U_matrices
+        for i in range(0, endW):
+            mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
+        for i in range(endW, endU):
+            mats[i] = utils.hardThreshold(mats[i], self._uSparsity)
+        self.copy_previous_UW()
+
+    def sparsifyWithSupport(self):
+        # Zero out any entries that have drifted off the support
+        # recorded by the last copy_previous_UW() call.
+        mats = self.getVars()
+        endU = self._num_W_matrices + self._num_U_matrices
+        for i in range(0, endU):
+            mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
+
 class SRNN2(nn.Module):
 
     def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,
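
Note: the methods above lean on three helpers from edgeml_pytorch.utils: countNNZ, hardThreshold and supportBasedThreshold. Their real implementations live in utils.py; as a rough mental model only (a sketch of the assumed semantics, not the library code), the two thresholding projections behave like this:

import torch

def hard_threshold(mat: torch.Tensor, s: float) -> torch.Tensor:
    # Keep the s-fraction of entries with the largest magnitude and zero
    # the rest (the projection step of iterated hard thresholding).
    # With s=1.0, the default sparsity, this is the identity.
    k = int(s * mat.numel())
    out = torch.zeros_like(mat)
    if k > 0:
        flat = mat.flatten()
        _, idx = flat.abs().topk(k)
        out.view(-1)[idx] = flat[idx]
    return out

def support_based_threshold(mat: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Restrict mat to the support (non-zero pattern) of a reference
    # snapshot, e.g. the one captured by copy_previous_UW().
    return mat * (ref != 0).to(mat.dtype)

countNNZ, analogously, is assumed to report how many values must actually be stored at a given sparsity; get_model_size() multiplies that count by 4 bytes per float32.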
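
These hooks are intended to be driven from the training loop using the usual FastGRNN three-phase recipe: train dense first, then call sparsify() after every optimizer step so the weights are repeatedly projected back onto the sparsity budget, and finish with sparsifyWithSupport() so fine-tuning cannot move weights off the learned support. A minimal sketch on a CUDA device (the epoch budgets, train_loader and compute_loss below are hypothetical placeholders, not part of this patch):

import torch

DENSE_EPOCHS, IHT_EPOCHS, FINETUNE_EPOCHS = 100, 100, 100  # assumed schedule

cell = FastGRNNCUDACell(input_size=32, hidden_size=64,
                        wSparsity=0.2, uSparsity=0.2)
optimizer = torch.optim.Adam(cell.parameters(), lr=0.01)

for epoch in range(DENSE_EPOCHS + IHT_EPOCHS + FINETUNE_EPOCHS):
    for batch in train_loader:            # hypothetical DataLoader
        optimizer.zero_grad()
        loss = compute_loss(cell, batch)  # hypothetical loss helper
        loss.backward()
        optimizer.step()
        if DENSE_EPOCHS <= epoch < DENSE_EPOCHS + IHT_EPOCHS:
            cell.sparsify()               # phase 2: reproject onto the budget
        elif epoch >= DENSE_EPOCHS + IHT_EPOCHS:
            cell.sparsifyWithSupport()    # phase 3: keep the support frozen

print('estimated model size (bytes):', cell.get_model_size())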