# Introduction

This notebook shows how to use roboquant in combination with machine learning based strategies. In this particular notebook, XGBoost is used.

The strategy will predict the next day return based on a configurable number of previous day returns. When the predicted next day return is above or below a certain configurable percentage, a BUY or SELL signal will be generated.

In [None]:
// Load version 1.6.0 of the roboquant library into the notebook and call its Welcome() function
%use roboquant(1.6.0)
Welcome()

In [None]:
// Load the XGBoost library and import the package 
@file:DependsOn("ml.dmlc:xgboost4j_2.12:1.7.1")
import ml.dmlc.xgboost4j.java.*

In [None]:
/**
 * Mutable list of floats that holds the return between consecutive CLOSE prices.
 *
 * All regular list operations are delegated to [data]. On top of that, [add] accepts a price action and
 * appends the relative change of its CLOSE price compared to the CLOSE price of the previous call.
 */
class ReturnHistory(private val data: MutableList<Float> = mutableListOf()) : MutableList<Float> by data {

    // CLOSE price of the most recent price action; NaN as long as no price has been recorded
    private var lastClose = Float.NaN

    /**
     * Record the CLOSE price of [priceAction]. Nothing is appended when the previously recorded price is
     * not finite, which is the case for the first call and for the first call after [clear].
     */
    fun add(priceAction: PriceAction) {
        val close = priceAction.getPrice("CLOSE").toFloat()
        if (lastClose.isFinite()) data.add((close - lastClose) / lastClose)
        lastClose = close
    }

    /**
     * Remove all stored returns and forget the last recorded price.
     */
    override fun clear() {
        data.clear()
        lastClose = Float.NaN
    }

}


In [None]:
/**
 * Example XGBoost based strategy that predicts the next day return based on previous day returns.
 *
 * Returns are collected during the MAIN phase and a model is trained on them when that phase ends. As long
 * as a trained model is available, every new price action leads to a prediction that is compared against
 * [minPercentage] to decide if a [Signal] should be generated.
 *
 * @param asset the asset to use
 * @property windowSize the number of previous returns to take into account, default is 10
 * @property minPercentage the minimum predicted change required to generate a [Signal], default is 1%
 */
class XGBoostStrategy(
    asset: Asset,
    private val windowSize: Int = 10,
    private val minPercentage: Double = 0.01) : SingleAssetStrategy(asset)
{
    private val data = ReturnHistory()

    /**
     * The trained model, or null when no model is available
     */
    var model: Booster? = null

    init {
        require(windowSize > 0) { "windowSize must be positive, found $windowSize" }
    }

    /**
     * Record the return of the provided [priceAction] and, if a trained model and at least [windowSize]
     * returns are available, generate a BUY or SELL signal when the predicted change exceeds [minPercentage].
     *
     * @see SingleAssetStrategy.generate
     */
    override fun generate(priceAction: PriceAction, time: Instant): Signal? {
        // The return is always recorded; during training it becomes training data, afterwards a feature
        data.add(priceAction)

        val booster = model ?: return null
        if (data.size < windowSize) return null

        val predictedChange = predict(booster)
        return when {
            predictedChange > minPercentage -> Signal(asset, Rating.BUY)
            predictedChange < -minPercentage -> Signal(asset, Rating.SELL)
            else -> null
        }
    }

    /**
     * At the start of the MAIN phase we delete the model and clear the data
     *
     * @param runPhase
     */
    override fun start(runPhase: RunPhase) {
        if (runPhase == RunPhase.MAIN) {
            data.clear()
            model = null
        }
    }

    /**
     * At the end of the MAIN phase we train a new model with the data we collected that then can be used
     * during the validation phase
     */
    override fun end(runPhase: RunPhase) {
        if (runPhase == RunPhase.MAIN) {
            train()
            data.clear()
        }
    }

    /**
     * Predict the next day return with [booster], using the last [windowSize] returns as the single feature row.
     */
    private fun predict(booster: Booster): Float {
        val feature = data.takeLast(windowSize).toFloatArray()
        val dMatrix = DMatrix(feature, 1, windowSize, Float.NaN)
        try {
            return booster.predict(dMatrix).last().last()
        } finally {
            // A DMatrix is created for every prediction, so release it right away
            dMatrix.dispose()
        }
    }

    /**
     * Train a new model on the collected returns. Every run of [windowSize] consecutive returns is a feature
     * row and the return that directly follows it is the label. If there are not enough returns to create
     * at least one row, no model is trained and [model] is left unchanged.
     */
    private fun train() {
        // Row r uses data[r until r + windowSize] as features and data[r + windowSize] as label, so the
        // last valid row starts at data.lastIndex - windowSize, giving data.size - windowSize rows in total
        val sampleCount = data.size - windowSize
        if (sampleCount <= 0) return

        // Row-major feature matrix: the element at (row, col) is data[row + col]
        val features = FloatArray(sampleCount * windowSize) { i -> data[i / windowSize + i % windowSize] }
        val labels = FloatArray(sampleCount) { row -> data[row + windowSize] }

        val params: Map<String, Any> = mapOf(
            "eta" to 0.3,
            "max_depth" to 6,
            "objective" to "reg:squarederror",
            "eval_metric" to "rmse"
        )

        val trainData = DMatrix(features, sampleCount, windowSize, Float.NaN)
        try {
            trainData.label = labels
            model = XGBoost.train(trainData, params, 1000, mapOf<String, DMatrix>(), null, null)
        } finally {
            trainData.dispose()
        }
    }
}


In [None]:
// Get the feed returned by AvroFeed.sp500() and look up the asset with symbol TSLA in it
val feed = AvroFeed.sp500()
val asset = feed.assets.getBySymbol("TSLA")

In [None]:
// Create the strategy for the selected asset using a window size of 20 previous returns
val strategy =  XGBoostStrategy(asset, 20)
val roboquant = Roboquant(strategy, AccountMetric())

// Run the back test
// The feed's timeframe is split in two parts; the first is passed to run() as the main timeframe and the
// second as the validation timeframe.
// NOTE(review): 0.3 is presumably the fraction of the timeframe reserved for the second part; confirm
// against the splitTrainTest documentation.
val (train, validation) = feed.timeframe.splitTrainTest(0.3)
roboquant.run(feed, train, validation)

In [None]:
// Show the full summary of the broker's account
roboquant.broker.account.fullSummary()

In [None]:
// Chart the "account.equity" metric, keeping only the entries that were logged during the VALIDATE phase
val equity = roboquant.logger.getMetric("account.equity").filter { it.info.phase == RunPhase.VALIDATE }
TimeSeriesChart(equity)