Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,251 +7,20 @@
using System.Linq;
using System.Reflection;
using Microsoft.ML.Data;
using Microsoft.ML.Legacy.Data;
using Microsoft.ML.Legacy.Trainers;
using Microsoft.ML.Legacy.Transforms;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.EntryPoints.Tests
{
#pragma warning disable 612
public class CollectionDataSourceTests : BaseTestClass
public class CollectionsDataViewTest : BaseTestClass
{
public CollectionDataSourceTests(ITestOutputHelper output)
public CollectionsDataViewTest(ITestOutputHelper output)
: base(output)
{
}

[Fact]
public void CheckConstructor()
{
Assert.NotNull(CollectionDataSource.Create(new List<Input>() { new Input { Number1 = 1, String1 = "1" } }));
Assert.NotNull(CollectionDataSource.Create(new Input[1] { new Input { Number1 = 1, String1 = "1" } }));
Assert.NotNull(CollectionDataSource.Create(new Input[1] { new Input { Number1 = 1, String1 = "1" } }.AsEnumerable()));

bool thrown = false;
try
{
CollectionDataSource.Create(new List<Input>());
}
catch
{
thrown = true;
}
Assert.True(thrown);

thrown = false;
try
{
CollectionDataSource.Create(new Input[0]);
}
catch
{
thrown = true;
}
Assert.True(thrown);
}

[Fact]
public void CanSuccessfullyApplyATransform()
{
var collection = CollectionDataSource.Create(new List<Input>() { new Input { Number1 = 1, String1 = "1" } });
var environment = new MLContext();
Experiment experiment = environment.CreateExperiment();
Legacy.ILearningPipelineDataStep output = (Legacy.ILearningPipelineDataStep)collection.ApplyStep(null, experiment);

Assert.NotNull(output.Data);
Assert.NotNull(output.Data.VarName);
Assert.Null(output.Model);
}

[Fact]
public void CanSuccessfullyEnumerated()
{
var collection = CollectionDataSource.Create(new List<Input>() {
new Input { Number1 = 1, String1 = "1" },
new Input { Number1 = 2, String1 = "2" },
new Input { Number1 = 3, String1 = "3" }
});

var environment = new MLContext();
Experiment experiment = environment.CreateExperiment();
Legacy.ILearningPipelineDataStep output = collection.ApplyStep(null, experiment) as Legacy.ILearningPipelineDataStep;

experiment.Compile();
collection.SetInput(environment, experiment);
experiment.Run();

IDataView data = experiment.GetOutput(output.Data);
Assert.NotNull(data);

using (var cursor = data.GetRowCursor((a => true)))
{
var IDGetter = cursor.GetGetter<float>(0);
var TextGetter = cursor.GetGetter<ReadOnlyMemory<char>>(1);

Assert.True(cursor.MoveNext());

float ID = 0;
IDGetter(ref ID);
Assert.Equal(1, ID);

ReadOnlyMemory<char> Text = new ReadOnlyMemory<char>();
TextGetter(ref Text);
Assert.Equal("1", Text.ToString());

Assert.True(cursor.MoveNext());

ID = 0;
IDGetter(ref ID);
Assert.Equal(2, ID);

Text = new ReadOnlyMemory<char>();
TextGetter(ref Text);
Assert.Equal("2", Text.ToString());

Assert.True(cursor.MoveNext());

ID = 0;
IDGetter(ref ID);
Assert.Equal(3, ID);

Text = new ReadOnlyMemory<char>();
TextGetter(ref Text);
Assert.Equal("3", Text.ToString());

Assert.False(cursor.MoveNext());
}
}

[Fact]
public void CanTrain()
{
var pipeline = new Legacy.LearningPipeline();
var data = new List<IrisData>() {
new IrisData { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
new IrisData { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
new IrisData { SepalLength = 1.2f, SepalWidth = 0.5f, PetalLength=0.3f, PetalWidth=5.1f, Label=0}
};
var collection = CollectionDataSource.Create(data);

pipeline.Add(collection);
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
var model = pipeline.Train<IrisData, IrisPrediction>();

IrisPrediction prediction = model.Predict(new IrisData()
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth = 5.1f,
});

pipeline = new Legacy.LearningPipeline();
collection = CollectionDataSource.Create(data.AsEnumerable());
pipeline.Add(collection);
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
model = pipeline.Train<IrisData, IrisPrediction>();

prediction = model.Predict(new IrisData()
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth = 5.1f,
});

}

[Fact]
public void CanTrainProperties()
{
var pipeline = new Legacy.LearningPipeline();
var data = new List<IrisData>() {
new IrisData { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
new IrisData { SepalLength = 1f, SepalWidth = 1f, PetalLength=0.3f, PetalWidth=5.1f, Label=1},
new IrisData { SepalLength = 1.2f, SepalWidth = 0.5f, PetalLength=0.3f, PetalWidth=5.1f, Label=0}
};
var collection = CollectionDataSource.Create(data);

pipeline.Add(collection);
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
var model = pipeline.Train<IrisData, IrisPredictionProperties>();

IrisPredictionProperties prediction = model.Predict(new IrisData
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth = 5.1f,
});

pipeline = new Legacy.LearningPipeline();
collection = CollectionDataSource.Create(data.AsEnumerable());
pipeline.Add(collection);
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
model = pipeline.Train<IrisData, IrisPredictionProperties>();

prediction = model.Predict(new IrisData
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth = 5.1f,
});

}

public class Input
{
[LoadColumn(0)]
public float Number1;

[LoadColumn(1)]
public string String1;
}

public class IrisData
{
[LoadColumn(0)]
public float Label;

[LoadColumn(1)]
public float SepalLength;

[LoadColumn(2)]
public float SepalWidth;

[LoadColumn(3)]
public float PetalLength;

[LoadColumn(4)]
public float PetalWidth;
}

public class IrisPrediction
{
[ColumnName("Score")]
public float[] PredictedLabels;
}

public class IrisPredictionProperties
{
private float[] _PredictedLabels;
[ColumnName("Score")]
public float[] PredictedLabels { get { return _PredictedLabels; } set { _PredictedLabels = value; } }
}

public class ConversionSimpleClass
{
public int fInt;
Expand Down