Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed optimizer problem for classifer, enabled both batch_first args #173

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/pytorch/FastCells/fastcell_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def main():
assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
"Timesteps have to be integer"

timeSteps = int(dataDimension / inputDims)
Xtrain = Xtrain.reshape((-1, timeSteps, inputDims))
Xtest = Xtest.reshape((-1, timeSteps, inputDims))

if not batch_first:
Xtrain = np.swapaxes(Xtrain, 0, 1)
Xtest = np.swapaxes(Xtest, 0, 1)

currDir = helpermethods.createTimeStampDir(dataDir, cell)

helpermethods.dumpCommand(sys.argv, currDir)
Expand Down
59 changes: 35 additions & 24 deletions pytorch/edgeml_pytorch/trainer/fastTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
from edgeml_pytorch.graph.rnn import *
import numpy as np

class SimpleFC(nn.Module):
def __init__(self, input_size, num_classes, name="SimpleFC"):
super(SimpleFC, self).__init__()
self.FC = nn.Parameter(torch.randn([input_size, num_classes]))
self.FCbias = nn.Parameter(torch.randn([num_classes]))

def forward(self, input):
return torch.matmul(input, self.FC) + self.FCbias

class FastTrainer:

Expand Down Expand Up @@ -49,44 +57,43 @@ def __init__(self, FastObj, numClasses, sW=1.0, sU=1.0,
self.assertInit()
self.numMatrices = self.FastObj.num_weight_matrices
self.totalMatrices = self.numMatrices[0] + self.numMatrices[1]

self.optimizer = self.optimizer()

self.RNN = BaseRNN(self.FastObj, batch_first=self.batch_first).to(self.device)

self.FC = nn.Parameter(torch.randn(
[self.FastObj.output_size, self.numClasses])).to(self.device)
self.FCbias = nn.Parameter(torch.randn(
[self.numClasses])).to(self.device)
self.simpleFC = SimpleFC(self.FastObj.output_size, self.numClasses).to(self.device)

self.FastParams = self.FastObj.getVars()
self.optimizer = self.optimizer()

def classifier(self, feats):
'''
Can be raplaced by any classifier
TODO: Make this a separate class if needed
'''
return torch.matmul(feats, self.FC) + self.FCbias
return self.simpleFC(feats)

def computeLogits(self, input):
'''
Compute graph to unroll and predict on the FastObj
'''
if self.FastObj.cellType == "LSTMLR":
feats, _ = self.RNN(input)
logits = self.classifier(feats[-1, :])
else:
feats = self.RNN(input)
logits = self.classifier(feats[-1, :])

return logits, feats[:, -1]
if self.batch_first:
logits = self.classifier(feats[:, -1])
return logits, feats[:, -1]
else:
logits = self.classifier(feats[-1, :])
return logits, feats[-1, :]

def optimizer(self):
'''
Optimizer for FastObj Params
'''
paramList = list(self.FastObj.parameters()) + list(self.simpleFC.parameters())
optimizer = torch.optim.Adam(
self.FastObj.parameters(), lr=self.learningRate)
paramList, lr=self.learningRate)

return optimizer

Expand Down Expand Up @@ -168,12 +175,12 @@ def getModelSize(self):
hasSparse = hasSparse or sparseFlag

# Replace this with classifier class call
nnz, size, sparseFlag = utils.estimateNNZ(self.FC, 1.0)
nnz, size, sparseFlag = utils.estimateNNZ(self.simpleFC.FC, 1.0)
totalnnZ += nnz
totalSize += size
hasSparse = hasSparse or sparseFlag

nnz, size, sparseFlag = utils.estimateNNZ(self.FCbias, 1.0)
nnz, size, sparseFlag = utils.estimateNNZ(self.simpleFC.FCbias, 1.0)
totalnnZ += nnz
totalSize += size
hasSparse = hasSparse or sparseFlag
Expand Down Expand Up @@ -341,8 +348,8 @@ def saveParams(self, currDir):
np.save(os.path.join(currDir, "Bo.npy"),
self.FastParams[self.totalMatrices + 3].data.cpu())

np.save(os.path.join(currDir, "FC.npy"), self.FC.data.cpu())
np.save(os.path.join(currDir, "FCbias.npy"), self.FCbias.data.cpu())
np.save(os.path.join(currDir, "FC.npy"), self.simpleFC.FC.data.cpu())
np.save(os.path.join(currDir, "FCbias.npy"), self.simpleFC.FCbias.data.cpu())

def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
decayStep, decayRate, dataDir, currDir):
Expand All @@ -351,7 +358,13 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
'''
fileName = str(self.FastObj.cellType) + 'Results_pytorch.txt'
resultFile = open(os.path.join(dataDir, fileName), 'a+')
numIters = int(np.ceil(float(Xtrain.shape[0]) / float(batchSize)))
if self.batch_first:
self.timeSteps = Xtrain.shape[1]
self.numPoints = Xtrain.shape[0]
else:
self.timeSteps = Xtrain.shape[0]
self.numPoints = Xtrain.shape[1]
numIters = int(np.ceil(float(self.numPoints) / float(batchSize)))
totalBatches = numIters * totalEpochs

counter = 0
Expand All @@ -362,11 +375,6 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
ihtDone = 1
maxTestAcc = -10000
header = '*' * 20
self.timeSteps = int(Xtest.shape[1] / self.inputDims)
Xtest = Xtest.reshape((-1, self.timeSteps, self.inputDims))
Xtest = np.swapaxes(Xtest, 0, 1)
Xtrain = Xtrain.reshape((-1, self.timeSteps, self.inputDims))
Xtrain = np.swapaxes(Xtrain, 0, 1)

for i in range(0, totalEpochs):
print("\nEpoch Number: " + str(i), file=self.outFile)
Expand All @@ -376,7 +384,7 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.learningRate

shuffled = list(range(Xtrain.shape[1]))
shuffled = list(range(self.numPoints))
np.random.shuffle(shuffled)
trainAcc = 0.0
trainLoss = 0.0
Expand All @@ -389,7 +397,10 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
(header, msg, header), file=self.outFile)

k = shuffled[j * batchSize:(j + 1) * batchSize]
batchX = Xtrain[:, k, :]
if self.batch_first:
batchX = Xtrain[k, :, :]
else:
batchX = Xtrain[:, k, :]
batchY = Ytrain[k]

self.optimizer.zero_grad()
Expand Down