# Text Classification API (preview)

## What is text classification?

Text classification, as the name implies, is the process of applying labels or categories to text.

Common use cases include:

- Categorizing e-mail as spam or not spam
- Analyzing sentiment as positive or negative from customer reviews
- Applying labels to support tickets

## Solving text classification with machine learning

Classification is a common problem in machine learning, and there are a variety of algorithms you can use to train a classification model. Text classification is a subcategory of classification that deals specifically with raw text. Text poses interesting challenges because you have to account for the context and semantics in which it occurs, and encoding that meaning and context can be difficult. In recent years, deep learning models have emerged as a promising technique for solving natural language problems. More specifically, a type of neural network known as the transformer has become the predominant way of solving natural language problems like text classification, translation, summarization, and question answering.

Transformers were introduced in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). Some popular transformer architectures for natural language tasks include:

- Bidirectional Encoder Representations from Transformers (BERT)
- Robustly Optimized BERT Pretraining Approach (RoBERTa)
- Generative Pre-trained Transformer 2 (GPT-2)
- Generative Pre-trained Transformer 3 (GPT-3)

At a high level, transformers are a model architecture consisting of encoding and decoding layers. The encoder takes raw text as input and maps it to a numerical representation (including context) to produce features. The decoder uses information from the encoder to produce output, such as a category or label in the case of text classification. Not every model uses both halves: BERT and RoBERTa, for example, use only the encoder and add a classification layer on top of it for tasks like text classification. What sets these layers apart is attention: the idea of focusing on specific parts of an input based on how important they are in the context of the other inputs in a sequence. For example, let's say I'm categorizing news articles based on their headlines. Not all words in a headline are equally relevant. In a headline like "Auto sales are at an all-time high", a word like "sales" might get more attention and lead to labeling the article as business or finance.
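
To make the idea more concrete, here is a minimal sketch of how attention weights are computed in scaled dot-product attention, the mechanism described in the paper linked above. Each word's key vector is scored against a query with a dot product divided by the square root of the vector size. A softmax then turns those scores into weights that sum to 1. The two-dimensional vectors for "Auto", "sales", and "high" are made-up values, not real embeddings. A full attention layer would also use these weights to average a set of value vectors, which this sketch leaves out.

```csharp
using System;
using System.Linq;

// Hypothetical 2-dimensional key vectors for three words from the headline.
// Real transformers learn these vectors; these values are made up for illustration.
string[] tokens = { "Auto", "sales", "high" };
double[][] keys =
{
    new[] { 0.2, 0.1 },
    new[] { 0.9, 0.8 },
    new[] { 0.4, 0.3 },
};

// Hypothetical query vector that the keys are scored against.
double[] query = { 1.0, 1.0 };

// Scaled dot-product scores followed by a softmax, which turns the scores into weights that sum to 1.
double[] AttentionWeights(double[] q, double[][] k)
{
    double scale = Math.Sqrt(q.Length);
    double[] scores = k.Select(key => key.Zip(q, (a, b) => a * b).Sum() / scale).ToArray();

    // Subtract the maximum score before exponentiating for numerical stability.
    double max = scores.Max();
    double[] exps = scores.Select(s => Math.Exp(s - max)).ToArray();
    double total = exps.Sum();
    return exps.Select(e => e / total).ToArray();
}

double[] weights = AttentionWeights(query, keys);
for (int i = 0; i < tokens.Length; i++)
{
    Console.WriteLine($"{tokens[i]}: {weights[i]:F3}");
}
```

With these values, "sales" receives the largest weight because its key has the largest dot product with the query.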

As with most neural networks, training a transformer from scratch can be expensive because it requires large amounts of data and compute. However, you don't always have to train from scratch. Using a technique known as fine-tuning, you can take a pre-trained model and retrain the layers specific to your domain or problem using your own data. This gives you a model that's more tailored to your problem without having to train the entire model from scratch.
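
One simple way to picture this is to keep the pre-trained layers frozen and train only a new classification layer on top of their output. The sketch below does that with a logistic-regression head trained by gradient descent on log loss. The feature vectors stand in for a frozen encoder's output and are hypothetical values. Training only the last layer is just one flavor of fine-tuning, and it isn't necessarily what the Text Classification API does internally.

```csharp
using System;
using System.Linq;

// Hypothetical output of a frozen, pre-trained encoder: one feature vector per review.
// These values are made up and are never modified during training.
double[][] features =
{
    new[] { 0.9, 0.1 }, // positive review
    new[] { 0.8, 0.3 }, // positive review
    new[] { 0.2, 0.9 }, // negative review
    new[] { 0.1, 0.7 }, // negative review
};
int[] labels = { 1, 1, 0, 0 };

// Trainable head: logistic regression on top of the frozen features.
double[] weights = new double[2];
double bias = 0.0;
const double learningRate = 0.5;

double Predict(double[] x) =>
    1.0 / (1.0 + Math.Exp(-(weights[0] * x[0] + weights[1] * x[1] + bias)));

for (int epoch = 0; epoch < 500; epoch++)
{
    for (int i = 0; i < features.Length; i++)
    {
        // For log loss, the gradient with respect to the logit is (prediction - label).
        double error = Predict(features[i]) - labels[i];

        // Only the head's parameters are updated; the features stay fixed.
        weights[0] -= learningRate * error * features[i][0];
        weights[1] -= learningRate * error * features[i][1];
        bias -= learningRate * error;
    }
}

int[] predicted = features.Select(x => Predict(x) >= 0.5 ? 1 : 0).ToArray();
Console.WriteLine($"Predicted labels: {string.Join(", ", predicted)}");
```

Because the frozen features are reused as-is, each training step only touches three numbers. That is why adapting an existing model is much cheaper than training every layer.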

## The Text Classification API (preview)

Now that you have a general overview of how text classification problems can be solved using deep learning, let's take a look at how we've incorporated many of these techniques into the Text Classification API.

The Text Classification API is powered by [TorchSharp](https://github.com/dotnet/TorchSharp). TorchSharp is a .NET library that provides access to libtorch, the library that powers PyTorch. TorchSharp contains the building blocks for training neural networks from scratch in .NET. However, the TorchSharp components are low-level, and building neural networks from scratch has a steep learning curve. In ML.NET, we've abstracted some of that complexity to the scenario level.

## Install packages

To use the Text Classification API, you'll have to install the following packages:

- [`Microsoft.ML`](https://www.nuget.org/packages/Microsoft.ML/)
- [`Microsoft.ML.TorchSharp`](https://www.nuget.org/packages/Microsoft.ML.TorchSharp/)
- [`TorchSharp-cpu`](https://www.nuget.org/packages/TorchSharp-cpu/) if you're using a CPU or [`TorchSharp-cuda-windows`](https://www.nuget.org/packages/TorchSharp-cuda-windows/) / [`TorchSharp-cuda-linux`](https://www.nuget.org/packages/TorchSharp-cuda-linux/) if you're using a GPU.

To enable GPU support, you'll also have to install the CUDA dependencies. For more information, see the [GPU support guide](https://docs.microsoft.com/dotnet/machine-learning/how-to-guides/install-gpu-model-builder#install-dependencies).

```csharp
#i "nuget:https://pkgs.dev.azure.com/dnceng/public/_packaging/MachineLearning/nuget/v3/index.json"

#r "nuget:Microsoft.ML,2.0.0-preview.22324.1"
#r "nuget:Microsoft.ML.TorchSharp,0.20.0-preview.22324.1"
#r "nuget:TorchSharp-cpu,0.96.7"
#r "nuget:Microsoft.Data.Analysis,0.20.0-preview.22324.1"
```


## Add using statements

```csharp
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.Data.Analysis;
using Microsoft.ML.TorchSharp;
```

## Initialize MLContext

All ML.NET operations start in the `MLContext` class. Initializing `mlContext` creates a new ML.NET environment that can be shared across the model creation workflow objects. It's similar, conceptually, to `DbContext` in Entity Framework.

```csharp
var mlContext = new MLContext();
```

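### Download or locate data

The following code looks for the data file in two local locations: a `data` subdirectory of the current directory, then the current directory itself. If the file isn't found in either, the code downloads it from the dotnet/csharp-notebooks GitHub repository.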

```csharp
using System;
using System.IO;
using System.Net.Http;

string EnsureDataSetDownloaded(string fileName)
{
	// This is the path if the repo has been checked out.
	var filePath = Path.Combine(Directory.GetCurrentDirectory(), "data", fileName);

	if (!File.Exists(filePath))
	{
		// This is the path if the file has already been downloaded.
		filePath = Path.Combine(Directory.GetCurrentDirectory(), fileName);
	}

	if (!File.Exists(filePath))
	{
		// WebClient is obsolete in .NET 6 and later, so HttpClient is used to fetch the file.
		using var client = new HttpClient();
		var contents = client
			.GetByteArrayAsync($"https://raw.githubusercontent.com/dotnet/csharp-notebooks/main/machine-learning/data/{fileName}")
			.GetAwaiter()
			.GetResult();
		File.WriteAllBytes(filePath, contents);
		Console.WriteLine($"Downloaded {fileName} to: {filePath}");
	}
	else
	{
		Console.WriteLine($"{fileName} found here: {filePath}");
	}

	return filePath;
}
```

```csharp
var yelp_reviews = EnsureDataSetDownloaded("yelp_labelled.txt");
var columnNames = new [] {"Text", "Sentiment"};

// LoadCsv reads from a file path (LoadCsvFromString would parse the path itself as CSV text).
var df = DataFrame.LoadCsv(yelp_reviews, separator: '\t', header: false, columnNames: columnNames);
```

yelp_labelled.txt found here: C:\dev\csharp-notebooks\machine-learning\data\yelp_labelled.txt


Once the data is loaded, use the `Head` method to preview the first three rows.

```csharp
df.Head(3)
```

| index | Text | Sentiment |
|---|---|---|
| 0 | Wow... Loved this place. | 1 |
| 1 | Crust is not good. | 0 |
| 2 | Not tasty and the texture was just nasty. | 0 |


> This tutorial uses a dataset from 'From Group to Individual Labels using Deep Features', Kotzias et al., KDD 2015, hosted at the UCI Machine Learning Repository: Dua, D. and Karra Taniskidou, E. (2017). [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml). Irvine, CA: University of California, School of Information and Computer Science.

The dataset contains two columns:

- **Text:** The raw review text from Yelp
- **Sentiment:** A binary value to represent the sentiment of the review. 0 is negative and 1 is positive. 
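
The file is loaded with a tab separator and no header row. That means each line is expected to hold the review text, a tab, and the 0/1 label. The following minimal sketch parses one such line by hand to make that format explicit. `ParseReview` is a hypothetical helper for illustration and isn't needed for the rest of the notebook.

```csharp
using System;
using System.Globalization;

// Hypothetical helper: parses one line of yelp_labelled.txt, assuming the format the loading
// code implies (review text, a tab, then a 0/1 sentiment label, with no header row).
(string Text, int Sentiment) ParseReview(string line)
{
    int tab = line.LastIndexOf('\t');
    if (tab < 0)
    {
        throw new FormatException($"Expected a tab-separated line: {line}");
    }

    string text = line.Substring(0, tab).Trim();
    int sentiment = int.Parse(line.Substring(tab + 1).Trim(), CultureInfo.InvariantCulture);
    if (sentiment != 0 && sentiment != 1)
    {
        throw new FormatException($"Expected a 0/1 label, got {sentiment}");
    }

    return (text, sentiment);
}

var review = ParseReview("Crust is not good.\t0");
Console.WriteLine($"{review.Text} -> {(review.Sentiment == 1 ? "positive" : "negative")}");
```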

## Split the data into train and test sets

Split the original dataset into two subsets: train and test. The train set is what you'll use to learn the patterns in your data. The test set is used to evaluate the performance of your model using evaluation metrics for the classification task.

In this case, setting the `testFraction` parameter to `0.2` reserves 20% of the data for evaluation and testing. The remaining 80% is used for training.
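
To show what `testFraction` controls, here is a simplified sketch that shuffles row indices with a seeded Fisher-Yates shuffle and assigns the first 20% of them to the test set. The `SplitIndices` helper and the 1,000-row count are hypothetical. This is not how ML.NET implements `TrainTestSplit`, which assigns rows to the test set randomly. As a result, ML.NET's test set is only approximately 20% of the data; the confusion matrix later in this notebook counts 187 test rows.

```csharp
using System;
using System.Linq;

// Hypothetical helper: shuffle row indices, then take the first testFraction of them as the test set.
(int[] Train, int[] Test) SplitIndices(int rowCount, double testFraction, int seed)
{
    var random = new Random(seed);
    int[] indices = Enumerable.Range(0, rowCount).ToArray();

    // Fisher-Yates shuffle so every ordering of the rows is equally likely.
    for (int i = indices.Length - 1; i > 0; i--)
    {
        int j = random.Next(i + 1);
        (indices[i], indices[j]) = (indices[j], indices[i]);
    }

    int testCount = (int)Math.Round(rowCount * testFraction);
    return (indices.Skip(testCount).ToArray(), indices.Take(testCount).ToArray());
}

var (train, test) = SplitIndices(1000, 0.2, seed: 0);
Console.WriteLine($"Train rows: {train.Length}, test rows: {test.Length}");
```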

```csharp
var trainTestSplit = mlContext.Data.TrainTestSplit(df, testFraction:0.2);
```

## Define your training pipeline

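The Text Classification API is part of the multiclass classification catalog. To use it, add the `TextClassification` trainer to your pipeline. Multiclass trainers in ML.NET expect a key-typed label, so the pipeline first uses `MapValueToKey` to convert the `Sentiment` column into a `Label` key column. It finishes with `MapKeyToValue`, which converts the predicted key in `PredictedLabel` back to the original sentiment value. The `sentence1ColumnName` parameter tells the trainer which column holds the text to classify.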

```csharp
var pipeline =
		mlContext.Transforms.Conversion.MapValueToKey("Label","Sentiment")
			.Append(mlContext.MulticlassClassification.Trainers.TextClassification(sentence1ColumnName: "Text"))
			.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
```
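
Key types can be confusing at first, so here is a simplified model of what the two conversion transforms do. `MapValueToKey` assigns each distinct `Sentiment` value a key, and `MapKeyToValue` maps predicted keys back to the original values. This sketch assumes ML.NET's default behavior of numbering keys 1, 2, ... in order of first occurrence, with 0 reserved for missing values. It uses plain dictionaries, not the ML.NET API.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical Sentiment values from a few training rows.
int[] sentiments = { 1, 0, 0, 1, 0 };

// MapValueToKey (simplified): number distinct values 1, 2, ... in order of first occurrence.
// Key 0 is left unused because ML.NET reserves it for missing values.
var valueToKey = new Dictionary<int, uint>();
foreach (int value in sentiments)
{
    if (!valueToKey.ContainsKey(value))
    {
        valueToKey[value] = (uint)(valueToKey.Count + 1);
    }
}

// MapKeyToValue (simplified): invert the mapping to recover the original values.
Dictionary<uint, int> keyToValue = valueToKey.ToDictionary(pair => pair.Value, pair => pair.Key);

// The "Label" column the trainer sees, then the round trip back to Sentiment values.
uint[] labelKeys = sentiments.Select(v => valueToKey[v]).ToArray();
int[] restored = labelKeys.Select(k => keyToValue[k]).ToArray();

Console.WriteLine($"Label keys: {string.Join(", ", labelKeys)}");
Console.WriteLine($"Restored values: {string.Join(", ", restored)}");
```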

## Train the model

Call the `Fit` method on the training set to train your model.

```csharp
var model = pipeline.Fit(trainTestSplit.TrainSet);
```

## Use the model to make predictions

Use your model to make predictions on the test dataset by calling the `Transform` method. 

```csharp
var predictionIDV = model.Transform(trainTestSplit.TestSet);
```

The result of calling `Transform` is an `IDataView` with your predicted values. To make your predictions easier to view, convert the `IDataView` to a `DataFrame`. In this case, the only columns of interest are Text, Sentiment (actual value), and PredictedLabel (predicted value).

```csharp
var columnsToSelect = new [] {"Text", "Sentiment", "PredictedLabel"};

var predictions = predictionIDV.ToDataFrame(columnsToSelect);
```

Use the `Tail` method to preview the last three rows in your prediction `DataFrame`.

```csharp
predictions.Tail(3)
```

| index | Text | Sentiment | PredictedLabel |
|---|---|---|---|
| 0 | Oh this is such a thing of beauty, this restaurant. | 1 | 0 |
| 1 | A greasy, unhealthy meal. | 0 | 1 |
| 2 | The best place in Vegas for breakfast (just check out a Sat, or Sun. | 1 | 1 |
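
A confusion matrix summarizes exactly this row-by-row comparison of `Sentiment` and `PredictedLabel`. The following sketch tallies the three preview rows above into a 2×2 matrix, with rows for the actual class and columns for the predicted class. That is the same layout as the `Counts` in the evaluation output that follows. `BuildConfusionMatrix` is a hypothetical helper for illustration, not part of ML.NET.

```csharp
using System;

// Hypothetical helper: tallies (actual, predicted) pairs into a matrix,
// with rows = actual class and columns = predicted class.
int[,] BuildConfusionMatrix((int Actual, int Predicted)[] pairs, int numberOfClasses)
{
    var counts = new int[numberOfClasses, numberOfClasses];
    foreach (var (actual, predicted) in pairs)
    {
        counts[actual, predicted]++;
    }
    return counts;
}

// The three rows shown by predictions.Tail(3), as (Sentiment, PredictedLabel).
(int Actual, int Predicted)[] previewRows = { (1, 0), (0, 1), (1, 1) };
int[,] matrix = BuildConfusionMatrix(previewRows, numberOfClasses: 2);

for (int row = 0; row < 2; row++)
{
    Console.WriteLine($"actual {row}: predicted 0 = {matrix[row, 0]}, predicted 1 = {matrix[row, 1]}");
}

// Correct predictions sit on the diagonal.
int correctCount = matrix[0, 0] + matrix[1, 1];
Console.WriteLine($"Correct: {correctCount} of {previewRows.Length}");
```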


## Evaluate the model

There are a variety of metrics you can use to evaluate how well your model performs. Use the [Evaluate](https://docs.microsoft.com/dotnet/api/microsoft.ml.multiclassclassificationcatalog.evaluate?view=ml-dotnet) method to calculate the evaluation metrics for your model from the predictions `IDataView`.

```csharp
var evaluationMetrics = 
	mlContext
		.MulticlassClassification
		.Evaluate(predictionIDV);
```

Then, display the evaluation metrics. For more information on multiclass classification evaluation metrics, see the [ML.NET evaluation metrics guide](https://docs.microsoft.com/dotnet/machine-learning/resources/metrics#evaluation-metrics-for-multi-class-classification).

```csharp
evaluationMetrics
```

| Metric | Value |
|---|---|
| LogLoss | 10.53512863047496 |
| LogLossReduction | -14.199291365827746 |
| MacroAccuracy | 0.6737016700983757 |
| MicroAccuracy | 0.6737967914438503 |
| TopKAccuracy | 0 |
| TopKPredictionCount | 0 |
| TopKAccuracyForAllK | `<null>` |
| PerClassLogLoss | [ 10.940300196581468, 10.134267400178105 ] |
| ConfusionMatrix.PerClassPrecision | [ 0.6777777777777778, 0.6701030927835051 ] |
| ConfusionMatrix.PerClassRecall | [ 0.6559139784946236, 0.6914893617021277 ] |
| ConfusionMatrix.Counts | [ [ 61, 32 ], [ 29, 65 ] ] |
| ConfusionMatrix.NumberOfClasses | 2 |
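
As a sanity check on how these metrics relate to each other, the following sketch recomputes several of them from the confusion-matrix counts above. Rows are actual classes and columns are predicted classes.

- **Micro-accuracy** is the fraction of all test rows classified correctly.
- **Macro-accuracy** averages the per-class accuracy (recall).
- **Precision** divides each diagonal count by its column total.

The log-loss reduction calculation is an assumption: it takes the definition to be (prior log-loss - log-loss) / prior log-loss, where the prior always predicts the test set's class frequencies and uses the natural log.

```csharp
using System;
using System.Linq;

// Counts reported by Evaluate: rows = actual class, columns = predicted class.
int[,] counts = { { 61, 32 }, { 29, 65 } };
int classes = counts.GetLength(0);

double[] rowTotals = new double[classes]; // examples whose actual class is r
double[] colTotals = new double[classes]; // examples predicted as class c
for (int r = 0; r < classes; r++)
{
    for (int c = 0; c < classes; c++)
    {
        rowTotals[r] += counts[r, c];
        colTotals[c] += counts[r, c];
    }
}

double total = rowTotals.Sum();
double correct = Enumerable.Range(0, classes).Sum(c => counts[c, c]);

double[] recall = Enumerable.Range(0, classes).Select(c => counts[c, c] / rowTotals[c]).ToArray();
double[] precision = Enumerable.Range(0, classes).Select(c => counts[c, c] / colTotals[c]).ToArray();

// Micro-accuracy: fraction of all test rows classified correctly.
double microAccuracy = correct / total;
// Macro-accuracy: average of the per-class accuracies (recalls).
double macroAccuracy = recall.Average();

// Log-loss reduction under the assumed definition (prior log-loss - log-loss) / prior log-loss,
// where the prior always predicts the test set's class frequencies (natural log).
double logLoss = 10.53512863047496; // LogLoss value reported by Evaluate
double priorLogLoss = -rowTotals.Select(n => n / total).Sum(p => p * Math.Log(p));
double logLossReduction = (priorLogLoss - logLoss) / priorLogLoss;

Console.WriteLine($"MicroAccuracy: {microAccuracy:F4}, MacroAccuracy: {macroAccuracy:F4}");
Console.WriteLine($"PerClassPrecision: {string.Join(", ", precision.Select(p => p.ToString("F4")))}");
Console.WriteLine($"PerClassRecall: {string.Join(", ", recall.Select(r => r.ToString("F4")))}");
Console.WriteLine($"LogLossReduction: {logLossReduction:F2}");
```

The recomputed values agree with the ones `Evaluate` reported, including the log-loss reduction under the assumed definition. A strongly negative log-loss reduction means the model's log-loss is much worse than that of simply predicting the class frequencies, even though its accuracy is above chance.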
