# Multi Class Classification with LightGbm

ML.NET trainers: https://learn.microsoft.com/en-us/dotnet/machine-learning/resources/tasks#multiclass-classification

Dataset: https://www.kaggle.com/datasets/rabieelkharoua/predict-online-gaming-behavior-dataset/data

Metrics reference (macro vs. micro average): https://www.evidentlyai.com/classification-metrics/multi-class-metrics#macro-vs-micro-average

In [None]:
#r "nuget: Microsoft.ML, 5.0.0"
#r "nuget: Microsoft.ML.FastTree, 5.0.0"
#r "nuget: Microsoft.ML.LightGbm, 5.0.0"


using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;

In [34]:
// Input and Output Models

/// <summary>
/// Input row schema for the dataset file. Each property is bound to a
/// zero-based column index of the text file through <c>LoadColumn</c>.
/// </summary>
public class PlayerEngagementData
{
    /// <summary>Column 0.</summary>
    [LoadColumn(0)]
    public string PlayerID { get; set; }

    /// <summary>Column 1.</summary>
    [LoadColumn(1)]
    public float Age { get; set; }

    /// <summary>Column 2.</summary>
    [LoadColumn(2)]
    public string Gender { get; set; }

    /// <summary>Column 3.</summary>
    [LoadColumn(3)]
    public string Location { get; set; }

    /// <summary>Column 4.</summary>
    [LoadColumn(4)]
    public string GameGenre { get; set; }

    /// <summary>Column 5.</summary>
    [LoadColumn(5)]
    public float PlayTimeHours { get; set; }

    /// <summary>Column 6.</summary>
    [LoadColumn(6)]
    public float InGamePurchases { get; set; }

    /// <summary>Column 7.</summary>
    [LoadColumn(7)]
    public string GameDifficulty { get; set; }

    /// <summary>Column 8.</summary>
    [LoadColumn(8)]
    public float SessionsPerWeek { get; set; }

    /// <summary>Column 9.</summary>
    [LoadColumn(9)]
    public float AvgSessionDurationMinutes { get; set; }

    /// <summary>Column 10.</summary>
    [LoadColumn(10)]
    public float PlayerLevel { get; set; }

    /// <summary>Column 11.</summary>
    [LoadColumn(11)]
    public float AchievementsUnlocked { get; set; }

    /// <summary>Column 12. This is the label column.</summary>
    [LoadColumn(12)]
    public string EngagementLevel { get; set; }
}

/// <summary>
/// Output row schema for predictions.
/// </summary>
public class PlayerEngagementPrediction
{
    /// <summary>
    /// Bound to the "PredictedLabel" column rather than a column named after the property.
    /// </summary>
    [ColumnName("PredictedLabel")] public string EngagementLevel { get; set; }

    /// <summary>
    /// Bound to the "Score" column (the default name-based binding).
    /// Presumably one value per class; verify against the trainer's output.
    /// </summary>
    public float[] Score { get; set; }
}

In [35]:
// Load the data from the CSV file.
// A fixed seed makes the MLContext environment deterministic, so operations that
// rely on randomness (such as the train/test split) give repeatable results
// every time the notebook is re-run.
MLContext mlContext = new MLContext(seed: 0);
IDataView data = mlContext.Data.LoadFromTextFile<PlayerEngagementData>(
    path: "online_gaming_behavior_dataset.csv",
    hasHeader: true,
    separatorChar: ',');

In [36]:
// Prepare the data pipeline: drop the ID column, map the label to a key,
// one-hot encode the string columns, assemble the feature vector, train,
// and finally map the predicted key back to its original value.
var pipeline = mlContext.Transforms
    .DropColumns(nameof(PlayerEngagementData.PlayerID))
    // Conversion of the string label to a key type.
    .Append(mlContext.Transforms.Conversion.MapValueToKey(
        outputColumnName: "Label",
        inputColumnName: nameof(PlayerEngagementData.EngagementLevel)))
    // Each string input column gets its own "<Name>Encoded" output column.
    .Append(mlContext.Transforms.Categorical.OneHotEncoding(
        outputColumnName: "GenderEncoded",
        inputColumnName: nameof(PlayerEngagementData.Gender)))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding(
        outputColumnName: "LocationEncoded",
        inputColumnName: nameof(PlayerEngagementData.Location)))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding(
        outputColumnName: "GameGenreEncoded",
        inputColumnName: nameof(PlayerEngagementData.GameGenre)))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding(
        outputColumnName: "GameDifficultyEncoded",
        inputColumnName: nameof(PlayerEngagementData.GameDifficulty)))
    // Numeric columns first, then the encoded columns, combined into "Features".
    .Append(mlContext.Transforms.Concatenate(
        "Features",
        nameof(PlayerEngagementData.Age),
        nameof(PlayerEngagementData.PlayTimeHours),
        nameof(PlayerEngagementData.InGamePurchases),
        nameof(PlayerEngagementData.SessionsPerWeek),
        nameof(PlayerEngagementData.AvgSessionDurationMinutes),
        nameof(PlayerEngagementData.PlayerLevel),
        nameof(PlayerEngagementData.AchievementsUnlocked),
        "GenderEncoded",
        "LocationEncoded",
        "GameGenreEncoded",
        "GameDifficultyEncoded"))
    // Add the trainer.
    .Append(mlContext.MulticlassClassification.Trainers.LightGbm(
        labelColumnName: "Label",
        featureColumnName: "Features"))
    // Convert the predicted label key back to the original string value.
    .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

In [37]:
// Hold out 20% of the rows for testing; the rest is used for training.
DataOperationsCatalog.TrainTestData split = mlContext.Data.TrainTestSplit(data, testFraction: 0.2);

// Fit the pipeline on the training portion.
var model = pipeline.Fit(split.TrainSet);

// Run the fitted model over the held-out portion.
IDataView prediction = model.Transform(split.TestSet);

In [38]:
// Evaluate the model on the transformed test set.
MulticlassClassificationMetrics metrics = mlContext.MulticlassClassification.Evaluate(
    prediction,
    labelColumnName: "Label",
    scoreColumnName: "Score");

In [39]:
/// <summary>
/// Writes the headline multiclass metrics and the confusion matrix counts
/// to the console.
/// </summary>
/// <param name="metrics">Evaluation result to print.</param>
void PrintMetrics(MulticlassClassificationMetrics metrics)
{
    Console.WriteLine("===== MULTICLASS METRICS =====");
    Console.WriteLine($"MicroAccuracy: {metrics.MicroAccuracy:0.###}");
    Console.WriteLine($"MacroAccuracy: {metrics.MacroAccuracy:0.###}");
    Console.WriteLine($"LogLoss: {metrics.LogLoss:0.###}");
    Console.WriteLine($"LogLossReduction: {metrics.LogLossReduction:0.###}");
    Console.WriteLine();

    var confusion = metrics.ConfusionMatrix;
    int classCount = confusion.NumberOfClasses;
    Console.WriteLine("===== CONFUSION MATRIX =====");

    // One console line per matrix row; every count is right-aligned in an
    // 8-character field.
    for (int row = 0; row < classCount; row++)
    {
        var cells = new string[classCount];
        for (int col = 0; col < classCount; col++)
        {
            cells[col] = $"{confusion.Counts[row][col],8}";
        }

        Console.WriteLine(string.Concat(cells));
    }
}

In [40]:
// Write the evaluation results to the cell output.
PrintMetrics(metrics: metrics);

===== MULTICLASS METRICS =====
MicroAccuracy: 0.915
MacroAccuracy: 0.905
LogLoss: 0.348
LogLossReduction: 0.668

===== CONFUSION MATRIX =====
    3685      79     112
     198    1783      59
     175      58    1885
