Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions pytorch/edgeml_pytorch/graph/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,9 @@ class FastGRNNCUDACell(RNNCell):

'''
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity, 1, 1, 2, wRank, uRank)
update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity,
1, 1, 2, wRank, uRank, wSparsity, uSparsity)
if utils.findCUDA() is None:
raise Exception('FastGRNNCUDA is supported only on GPU devices.')
NON_LINEARITY = {"sigmoid": 0, "relu": 1, "tanh": 2}
Expand Down Expand Up @@ -1166,6 +1167,54 @@ def getVars(self):
Vars.extend([self.bias_gate, self.bias_update, self.zeta, self.nu])
return Vars

def get_model_size(self):
    '''
    Function to get aimed model size

    Sums the counts reported by utils.countNNZ over every tensor returned
    by getVars() and multiplies the total by 4 (presumably bytes per
    stored value -- TODO confirm the intended element width).

    The first _num_W_matrices entries are counted with self._wSparsity,
    the next _num_U_matrices entries with self._uSparsity, and all
    remaining entries with False as the sparsity argument.

    Each tensor is handed to countNNZ through .cpu(), which returns a CPU
    tensor and leaves the parameter itself on its original device, so
    nothing needs to be moved back afterwards.
    '''
    mats = self.getVars()
    endW = self._num_W_matrices
    endU = endW + self._num_U_matrices

    # NOTE(review): getVars() appears to end with zeta and nu, so the last
    # loop below may count them a second time -- confirm against countNNZ.
    totalnnz = 2  # For Zeta and Nu
    for i in range(0, endW):
        totalnnz += utils.countNNZ(mats[i].cpu(), self._wSparsity)
    for i in range(endW, endU):
        totalnnz += utils.countNNZ(mats[i].cpu(), self._uSparsity)
    for i in range(endU, len(mats)):
        totalnnz += utils.countNNZ(mats[i].cpu(), False)
    return totalnnz * 4

def copy_previous_UW(self):
    '''
    Snapshot the current W and U matrices into self.oldmats.

    The first (_num_W_matrices + _num_U_matrices) tensors returned by
    getVars() are copied with detach().clone(), so each snapshot is
    independent of the live parameter, carries no autograd history, and
    keeps the parameter's own device and dtype.

    The snapshots overwrite the leading entries of self.oldmats in place
    (the list object itself is kept), growing the list when it is shorter
    than the number of tracked matrices.
    '''
    mats = self.getVars()
    num_mats = self._num_W_matrices + self._num_U_matrices
    # Index explicitly (rather than slicing mats) so a getVars() result
    # shorter than num_mats still raises IndexError instead of silently
    # producing fewer snapshots.
    snapshots = [mats[i].detach().clone() for i in range(num_mats)]
    # Slice assignment replaces entries [0, num_mats) and extends the list
    # if needed, without appending stale placeholder tensors.
    self.oldmats[:num_mats] = snapshots

def sparsify(self):
    '''
    Hard-threshold the W and U matrices, then snapshot them.

    Every W matrix is passed to utils.hardThreshold with self._wSparsity
    and every U matrix with self._uSparsity. Each result is stored back
    into the list returned by getVars(). Finally copy_previous_UW() is
    called to record the resulting matrices.
    '''
    weights = self.getVars()
    w_stop = self._num_W_matrices
    u_stop = w_stop + self._num_U_matrices
    # Walk the W indices first, then the U indices; the index alone tells
    # which sparsity level applies.
    for idx in (*range(w_stop), *range(w_stop, u_stop)):
        level = self._wSparsity if idx < w_stop else self._uSparsity
        weights[idx] = utils.hardThreshold(weights[idx], level)
    self.copy_previous_UW()

def sparsifyWithSupport(self):
    '''
    Threshold the W and U matrices against their stored snapshots.

    For each of the first (_num_W_matrices + _num_U_matrices) tensors
    returned by getVars(), utils.supportBasedThreshold is called with the
    tensor and the matching entry of self.oldmats, and the result is
    stored back into the list returned by getVars().
    '''
    weights = self.getVars()
    tracked = self._num_W_matrices + self._num_U_matrices
    idx = 0
    while idx < tracked:
        weights[idx] = utils.supportBasedThreshold(weights[idx],
                                                   self.oldmats[idx])
        idx += 1

class SRNN2(nn.Module):

def __init__(self, inputDim, outputDim, hiddenDim0, hiddenDim1, cellType,
Expand Down