22 changes: 15 additions & 7 deletions src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
@@ -171,7 +171,7 @@ public SymSgdClassificationTrainer(IHostEnvironment env,
_args.FeatureColumn = featureColumn;
_args.LabelColumn = labelColumn;

Info = new TrainerInfo(supportIncrementalTrain:true);
Info = new TrainerInfo(supportIncrementalTrain: true);
}

/// <summary>
@@ -201,7 +201,7 @@ protected override BinaryPredictionTransformer<TPredictor> MakeTransformer(TPred
=> new BinaryPredictionTransformer<TPredictor>(Host, model, trainSchema, FeatureColumn.Name);

public BinaryPredictionTransformer<TPredictor> Train(IDataView trainData, TPredictor initialPredictor = null)
=> TrainTransformer(trainData, initPredictor: initialPredictor);
=> TrainTransformer(trainData, initPredictor: initialPredictor);

protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
@@ -690,7 +690,8 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParame
entry => entry.SetProgress(0, state.PassIteration, _args.NumberOfIterations));
// If fully loaded, call SymSGDNative and do not return until all iterations have been learned.
Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
_args.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize, stateGCHandle);
_args.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize,
stateGCHandle, ch.Info);
shouldInitialize = false;
}
else
@@ -711,7 +712,8 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParame
numPassesForThisBatch = Math.Max(1, numPassesForThisBatch);
state.PassIteration = iter;
Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize, stateGCHandle);
numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _args.Tolerance, _args.Shuffle, shouldInitialize,
stateGCHandle, ch.Info);
shouldInitialize = false;

// Check if we are done with going through the data
@@ -760,10 +762,14 @@ private static unsafe class Native

internal const string NativePath = "SymSgdNative";
internal const string MklPath = "MklImports";

public delegate void ChannelCallBack(string message);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern void LearnAll(int totalNumInstances, int* instSizes, int** instIndices,
float** instValues, float* labels, bool tuneLR, ref float lr, float l2Const, float piw, float* weightVector, ref float bias,
int numFeatres, int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, State* state);
int numFeatres, int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize,
State* state, ChannelCallBack info);

/// <summary>
/// This method puts all of the buffered instances in an array of pointers to pass them to SymSGDNative.
@@ -784,9 +790,10 @@ private static extern void LearnAll(int totalNumInstances, int* instSizes, int**
/// <param name="needShuffle">Specifies if data needs to be shuffled</param>
/// <param name="shouldInitialize">Specifies if this is the first time to run SymSGD</param>
/// <param name="stateGCHandle"></param>
/// <param name="info"></param>
public static void LearnAll(InputDataManager inputDataManager, bool tuneLR,
ref float lr, float l2Const, float piw, Span<float> weightVector, ref float bias, int numFeatres, int numPasses,
int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, GCHandle stateGCHandle)
int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, GCHandle stateGCHandle, ChannelCallBack info)
{
inputDataManager.PrepareCursoring();

@@ -827,7 +834,8 @@ public static void LearnAll(InputDataManager inputDataManager, bool tuneLR,
fixed (float* pInstLabels = &instLabels[0])
{
LearnAll(totalNumInstances, pInstSizes, pIndicesPointer, pValuesPointer, pInstLabels, tuneLR, ref lr, l2Const, piw,
pweightVector, ref bias, numFeatres, numPasses, numThreads, tuneNumLocIter, ref numLocIter, tolerance, needShuffle, shouldInitialize, (State*)stateGCHandle.AddrOfPinnedObject());
pweightVector, ref bias, numFeatres, numPasses, numThreads, tuneNumLocIter, ref numLocIter, tolerance, needShuffle,
shouldInitialize, (State*)stateGCHandle.AddrOfPinnedObject(), info);
}
}

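The core of the managed change above is passing a delegate (ChannelCallBack) through P/Invoke so the native trainer can log through ch.Info instead of printf. A minimal, self-contained sketch of that pattern follows; the library name FakeSymSgdNative and the export LearnAllSketch are hypothetical stand-ins, not the actual SymSgdNative surface.

using System;
using System.Runtime.InteropServices;

// Sketch of marshaling a managed logging method to native code as a callback.
// The delegate shape mirrors ChannelCallBack above and the C-side typedef
// void (*ChannelFunc)(const char*).
internal static class NativeLoggingSketch
{
    public delegate void ChannelCallBack(string message);

    // Hypothetical native export; the real SymSgdNative entry point takes
    // many more parameters (see LearnAll above).
    [DllImport("FakeSymSgdNative")]
    private static extern void LearnAllSketch(int numPasses, ChannelCallBack info);

    public static void Run()
    {
        // The delegate instance stays reachable for the duration of this
        // synchronous native call, so no extra GCHandle is required here.
        ChannelCallBack info = message => Console.WriteLine(message);
        LearnAllSketch(5, info);
    }
}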
24 changes: 15 additions & 9 deletions src/Native/SymSgdNative/SymSgdNative.cpp
@@ -10,6 +10,7 @@
#include <omp.h>
#endif
#include <unordered_map>
#include <string>
#include "../Stdafx.h"
#include "Macros.h"
#include "SparseBLAS.h"
@@ -177,7 +178,7 @@ float MaxPossibleAlpha(float alpha, float l2Const, int totalNumInstances)
}

void TuneAlpha(float& alpha, float l2Const, int totalNumInstances, int* instSizes, int** instIndices,
float** instValues, int numFeat, int numThreads)
float** instValues, int numFeat, int numThreads, ChannelFunc info_print_func = NULL)
{
alpha = 1e0f;
int logSqrtNumInst = (int) round(log10(sqrt(totalNumInstances)))-3;
@@ -198,8 +199,8 @@ void TuneAlpha(float& alpha, float l2Const, int totalNumInstances, int* instSize
// alpha < (1-10^(-6/totalNumInstances))/l2Const.
alpha = MIN(alpha, MaxPossibleAlpha(alpha, l2Const, totalNumInstances));
}

printf("Initial learning rate is tuned to %f\n", alpha);
if (info_print_func != NULL)
info_print_func(("Initial learning rate is tuned to " + std::to_string(alpha)).c_str());
}


@@ -221,17 +222,18 @@ void TuneNumLocIter(int& numLocIter, int totalNumInstances, int* instSizes, int
// required memory for SymSGD learners.
void InitializeState(int totalNumInstances, int* instSizes, int** instIndices, float** instValues,
int numFeat, bool tuneNumLocIter, int& numLocIter, int numThreads, bool tuneAlpha, float& alpha,
float l2Const, SymSGDState* state)
float l2Const, SymSGDState* state, ChannelFunc info_print_func = NULL)
{
if (tuneAlpha)
{
TuneAlpha(alpha, l2Const, totalNumInstances, instSizes, instIndices, instValues, numFeat, numThreads);
TuneAlpha(alpha, l2Const, totalNumInstances, instSizes, instIndices, instValues, numFeat, numThreads, info_print_func);
} else
{
// Check if user alpha is too large because of l2Const. Check the comment about positive l2Const in TuneAlpha.
float maxPossibleAlpha = MaxPossibleAlpha(alpha, l2Const, totalNumInstances);
if (alpha > maxPossibleAlpha)
printf("Warning: learning rate is too high! Try using a value < %e instead\n", maxPossibleAlpha);
if (info_print_func != NULL)
info_print_func(("Warning: learning rate is too high! Try using a value < " + std::to_string(maxPossibleAlpha) + " instead").c_str());
}

if (tuneNumLocIter)
@@ -246,7 +248,11 @@ void InitializeState(int totalNumInstances, int* instSizes, int** instIndices, f
state->TotalInstancesProcessed = 0;
ComputeRemapping(totalNumInstances, instSizes, instIndices, numFeat,
numLocIter, numThreads, state, state->NumFrequentFeatures);
printf("Number of frequent features: %d\nNumber of features: %d\n", state->NumFrequentFeatures, numFeat);
if (info_print_func != NULL)
{
info_print_func(("Number of frequent features: " + std::to_string(state->NumFrequentFeatures)).c_str());
info_print_func(("Number of features: " + std::to_string(numFeat)).c_str());
}

state->NumLearners = numThreads;
state->Learners = new SymSGD*[numThreads];
@@ -288,11 +294,11 @@ float Loss(int instSize, int* instIndices, float* instValues,
EXPORT_API(void) LearnAll(int totalNumInstances, int* instSizes, int** instIndices, float** instValues,
float* labels, bool tuneAlpha, float& alpha, float l2Const, float piw, float* weightVector,
float& bias, int numFeat, int numPasses, int numThreads, bool tuneNumLocIter, int& numLocIter, float tolerance,
bool needShuffle, bool shouldInitialize, SymSGDState* state)
bool needShuffle, bool shouldInitialize, SymSGDState* state, ChannelFunc info_print_func = NULL)
{
// If this is the first time LearnAll is called, initialize it.
if (shouldInitialize)
InitializeState(totalNumInstances, instSizes, instIndices, instValues, numFeat, tuneNumLocIter, numLocIter, numThreads, tuneAlpha, alpha, l2Const, state);
InitializeState(totalNumInstances, instSizes, instIndices, instValues, numFeat, tuneNumLocIter, numLocIter, numThreads, tuneAlpha, alpha, l2Const, state, info_print_func);
float& weightScaling = state->WeightScaling;

float totalAverageLoss = 0.0f; // Reserved for total loss computation
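On the managed side, the const char* strings built with std::to_string above arrive as .NET strings because the ChannelCallBack parameter is declared as string, so the default marshaler performs the conversion. A sketch of the explicit equivalent, which would only be needed if the delegate were declared with an IntPtr parameter (hypothetical names, not part of this change):

using System;
using System.Runtime.InteropServices;

// Sketch: receiving the native message as a raw pointer and converting it
// manually with the same ANSI semantics the default marshaler applies.
public delegate void RawChannelCallBack(IntPtr message);

public static class RawCallbackSketch
{
    public static void Print(IntPtr message)
    {
        string text = Marshal.PtrToStringAnsi(message);
        Console.WriteLine(text ?? string.Empty);
    }
}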
5 changes: 4 additions & 1 deletion src/Native/SymSgdNative/SymSgdNative.h
@@ -142,4 +142,7 @@ struct SymSGDState
int NumFrequentFeatures;
int PassIteration;
float WeightScaling;
};
};

// Logging callback signature for ML.NET
typedef void(*ChannelFunc)(const char*);
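ChannelFunc is a plain C function pointer, which is cdecl by default on most compilers, while .NET delegate marshaling defaults to the platform convention; the two only diverge on 32-bit Windows. A sketch of making the convention explicit on the managed delegate, shown as an illustration rather than something this change does:

using System.Runtime.InteropServices;

// Sketch only: pin down the unmanaged calling convention so the managed
// delegate matches a cdecl function pointer such as ChannelFunc on x86.
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
public delegate void ChannelCallBack(string message);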
@@ -1,9 +1,11 @@
maml.exe CV tr=SymSGD{nt=1} threads=- norm=No dout=%Output% data=%Data% seed=1
Not adding a normalizer.
Data fully loaded into memory.
Initial learning rate is tuned to 100.000000
Not training a calibrator because it is not needed.
Not adding a normalizer.
Data fully loaded into memory.
Initial learning rate is tuned to 100.000000
Not training a calibrator because it is not needed.
Warning: The predictor produced non-finite prediction values on 8 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable.
TEST POSITIVE RATIO: 0.3785 (134.0/(134.0+220.0))
@@ -1,6 +1,7 @@
maml.exe TrainTest test=%Data% tr=SymSGD{nt=1} norm=No dout=%Output% data=%Data% out=%Output% seed=1
Not adding a normalizer.
Data fully loaded into memory.
Initial learning rate is tuned to 100.000000
Not training a calibrator because it is not needed.
Warning: The predictor produced non-finite prediction values on 16 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable.
TEST POSITIVE RATIO: 0.3499 (239.0/(239.0+444.0))

This file was deleted.

This file was deleted.
