In [1]:
%use kotlin-dl
%use krangl

In [2]:
import org.jetbrains.kotlinx.dl.api.core.history.EpochTrainingEvent
import org.jetbrains.kotlinx.dl.api.core.history.TrainingHistory

In [3]:
fun normalize(df: DataFrame): DataFrame {
    val normalizedCols = mutableListOf<DataCol>()
    for (col in df.cols) {
        val min = col.min()!!
        val max = col.max()!!
        val normalized = (col - min) / (max - min)
        normalizedCols.add(normalized)
    }
    val normalizedDF = dataFrameOf(*normalizedCols.toTypedArray())
    normalizedDF.setNames(*df.names.toTypedArray())
    return normalizedDF
}

In [4]:
fun getXy(
    df: DataFrame, 
    label: String = "quality"
): Pair<Array<FloatArray>, FloatArray> {
    val features = df.remove(label)
    val nFeatures = features.ncol
    val normalizedFeatures = normalize(features)
    val columnsArray = normalizedFeatures.toFloatMatrix()
    
    val X = Array(features.nrow) { FloatArray(nFeatures) }

    for (col in 0 until nFeatures) {
        for (row in 0 until df.nrow) {
            X[row][col] = columnsArray[col][row]
        }
    }
    
    val labels = df.get(label).toDoubles().filterNotNull().map { it.toFloat() }
    
    val y = labels.toFloatArray()
    return Pair(X, y)
}

In [5]:
val df = DataFrame.readCSV("data/winequality-white-train.csv")
df

fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality
7.3,0.32,0.35,1.4,0.05,8.0,163.0,0.99244,3.24,0.42,10.7,5
7.0,0.31,0.26,7.4,0.069,28.0,160.0,0.9954,3.13,0.46,9.8,6
7.6,0.14,0.74,1.6,0.04,27.0,103.0,0.9916,3.07,0.4,10.8,7
5.0,0.29,0.54,5.7,0.035,54.0,155.0,0.98976,3.27,0.34,12.9,8
6.0,0.28,0.22,12.15,0.048,42.0,163.0,0.9957,3.2,0.46,10.1,5
9.8,0.93,0.45,8.6,0.052,34.0,187.0,0.9994,3.12,0.59,10.2,4


In [6]:
val (X, y) = getXy(df)

In [7]:
X[0].forEach { print("$it ") }
print(y[0])

0.31 0.23529412 0.21084337 0.012269938 0.12166172 0.020905923 0.3573086 0.102756895 0.47272727 0.23255815 0.45 5.0

In [8]:
val dataset = OnHeapDataset.create(X, y)

In [9]:
class PrintingCallback : Callback() {
    override fun onEpochEnd(epoch: Int, event: EpochTrainingEvent, logs: TrainingHistory) {
        println("Epoch: $epoch - loss: ${event.lossValue} - val loss: ${event.valLossValue}")
    }
}

In [10]:
val model = Sequential.of(
    Input(11),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(32),
    Dense(1, activation = Activations.Linear)
)


In [11]:
model.use {
    it.compile(
        optimizer = Adam(0.0001f),
        loss = Losses.MAE,
        metric = Metrics.MAE,
        callback = PrintingCallback(),
    )
    
    it.fit(
        dataset = dataset,
        validationRate = 0.1,
        epochs = 1000,
        trainBatchSize = 64,
        validationBatchSize = 1024,
    )
    
    val result = it.evaluate(dataset)
    println(result.lossValue)
}

Epoch: 1 - loss: 1.2947943210601807 - val loss: 0.773218035697937
Epoch: 2 - loss: 0.6644443869590759 - val loss: 0.7584826946258545
Epoch: 3 - loss: 0.6489058136940002 - val loss: 0.7261059880256653
Epoch: 4 - loss: 0.6372713446617126 - val loss: 0.717948317527771
Epoch: 5 - loss: 0.6304600238800049 - val loss: 0.7115525603294373
Epoch: 6 - loss: 0.6276894211769104 - val loss: 0.7113010883331299
Epoch: 7 - loss: 0.6252312064170837 - val loss: 0.7048158049583435
Epoch: 8 - loss: 0.6236007809638977 - val loss: 0.705267071723938
Epoch: 9 - loss: 0.6222526431083679 - val loss: 0.7029525637626648
Epoch: 10 - loss: 0.621122419834137 - val loss: 0.7012268304824829
Epoch: 11 - loss: 0.6191408038139343 - val loss: 0.7001461386680603
Epoch: 12 - loss: 0.6182299256324768 - val loss: 0.6984177827835083
Epoch: 13 - loss: 0.617371678352356 - val loss: 0.6960034966468811
Epoch: 14 - loss: 0.6160566806793213 - val loss: 0.6974270939826965
Epoch: 15 - loss: 0.615060567855835 - val loss: 0.696746945381

Epoch: 122 - loss: 0.5390428900718689 - val loss: 0.6014110445976257
Epoch: 123 - loss: 0.5392025709152222 - val loss: 0.6021389961242676
Epoch: 124 - loss: 0.5386915802955627 - val loss: 0.6010087728500366
Epoch: 125 - loss: 0.5390177965164185 - val loss: 0.6011297106742859
Epoch: 126 - loss: 0.5386765003204346 - val loss: 0.6003568768501282
Epoch: 127 - loss: 0.538011372089386 - val loss: 0.5997979044914246
Epoch: 128 - loss: 0.537674069404602 - val loss: 0.5995385050773621
Epoch: 129 - loss: 0.5379180908203125 - val loss: 0.5995094776153564
Epoch: 130 - loss: 0.5395511984825134 - val loss: 0.6000112295150757
Epoch: 131 - loss: 0.5371657609939575 - val loss: 0.5984452962875366
Epoch: 132 - loss: 0.5362866520881653 - val loss: 0.5986142158508301
Epoch: 133 - loss: 0.5356083512306213 - val loss: 0.5977197885513306
Epoch: 134 - loss: 0.5383086800575256 - val loss: 0.5991058945655823
Epoch: 135 - loss: 0.5357279181480408 - val loss: 0.5983167886734009
Epoch: 136 - loss: 0.537799239158630

Epoch: 242 - loss: 0.5180151462554932 - val loss: 0.587979793548584
Epoch: 243 - loss: 0.5173366069793701 - val loss: 0.5886611938476562
Epoch: 244 - loss: 0.5182648301124573 - val loss: 0.5865639448165894
Epoch: 245 - loss: 0.5171990990638733 - val loss: 0.5878034830093384
Epoch: 246 - loss: 0.5180361866950989 - val loss: 0.58805251121521
Epoch: 247 - loss: 0.5173376202583313 - val loss: 0.5878274440765381
Epoch: 248 - loss: 0.5172377228736877 - val loss: 0.5888938307762146
Epoch: 249 - loss: 0.5169693231582642 - val loss: 0.5876220464706421
Epoch: 250 - loss: 0.5165412425994873 - val loss: 0.5880385041236877
Epoch: 251 - loss: 0.5161919593811035 - val loss: 0.5876519680023193
Epoch: 252 - loss: 0.5161256790161133 - val loss: 0.5877645611763
Epoch: 253 - loss: 0.5159105658531189 - val loss: 0.5869523286819458
Epoch: 254 - loss: 0.51690274477005 - val loss: 0.5875149965286255
Epoch: 255 - loss: 0.5194682478904724 - val loss: 0.587762713432312
Epoch: 256 - loss: 0.5195923447608948 - val

Epoch: 362 - loss: 0.5060458779335022 - val loss: 0.577431321144104
Epoch: 363 - loss: 0.5063998699188232 - val loss: 0.5771939158439636
Epoch: 364 - loss: 0.5069701075553894 - val loss: 0.5774621963500977
Epoch: 365 - loss: 0.5063921213150024 - val loss: 0.575785219669342
Epoch: 366 - loss: 0.50632643699646 - val loss: 0.5772136449813843
Epoch: 367 - loss: 0.5064504742622375 - val loss: 0.5773503184318542
Epoch: 368 - loss: 0.5063719153404236 - val loss: 0.5768935084342957
Epoch: 369 - loss: 0.5055451989173889 - val loss: 0.5765202045440674
Epoch: 370 - loss: 0.5060255527496338 - val loss: 0.5768101215362549
Epoch: 371 - loss: 0.5056856274604797 - val loss: 0.5769665241241455
Epoch: 372 - loss: 0.5066598653793335 - val loss: 0.5757459998130798
Epoch: 373 - loss: 0.5063762664794922 - val loss: 0.5755559802055359
Epoch: 374 - loss: 0.5055550932884216 - val loss: 0.5753026008605957
Epoch: 375 - loss: 0.5053505301475525 - val loss: 0.5764802694320679
Epoch: 376 - loss: 0.5051464438438416 

Epoch: 481 - loss: 0.49701988697052 - val loss: 0.57159823179245
Epoch: 482 - loss: 0.49745213985443115 - val loss: 0.5727750062942505
Epoch: 483 - loss: 0.49796271324157715 - val loss: 0.5706468224525452
Epoch: 484 - loss: 0.49707305431365967 - val loss: 0.571935772895813
Epoch: 485 - loss: 0.49786147475242615 - val loss: 0.5690455436706543
Epoch: 486 - loss: 0.49739760160446167 - val loss: 0.5744540095329285
Epoch: 487 - loss: 0.4983966052532196 - val loss: 0.5744167566299438
Epoch: 488 - loss: 0.4979715645313263 - val loss: 0.569561779499054
Epoch: 489 - loss: 0.49664390087127686 - val loss: 0.5707952380180359
Epoch: 490 - loss: 0.4966282248497009 - val loss: 0.5715231895446777
Epoch: 491 - loss: 0.49718937277793884 - val loss: 0.5699900388717651
Epoch: 492 - loss: 0.4964349865913391 - val loss: 0.5710732936859131
Epoch: 493 - loss: 0.4963509142398834 - val loss: 0.5692489147186279
Epoch: 494 - loss: 0.496379554271698 - val loss: 0.5698548555374146
Epoch: 495 - loss: 0.4959201812744

Epoch: 600 - loss: 0.4905746579170227 - val loss: 0.5625706911087036
Epoch: 601 - loss: 0.49066483974456787 - val loss: 0.5640580058097839
Epoch: 602 - loss: 0.4910518229007721 - val loss: 0.5644656419754028
Epoch: 603 - loss: 0.4913097321987152 - val loss: 0.565102756023407
Epoch: 604 - loss: 0.49138209223747253 - val loss: 0.5648853182792664
Epoch: 605 - loss: 0.49043622612953186 - val loss: 0.5640375018119812
Epoch: 606 - loss: 0.4899921119213104 - val loss: 0.5631625652313232
Epoch: 607 - loss: 0.4901605546474457 - val loss: 0.5645331144332886
Epoch: 608 - loss: 0.49013373255729675 - val loss: 0.564435601234436
Epoch: 609 - loss: 0.4910995662212372 - val loss: 0.5633825063705444
Epoch: 610 - loss: 0.48947057127952576 - val loss: 0.5630766153335571
Epoch: 611 - loss: 0.49007338285446167 - val loss: 0.5640974640846252
Epoch: 612 - loss: 0.48973751068115234 - val loss: 0.5631048083305359
Epoch: 613 - loss: 0.4899410307407379 - val loss: 0.5633941292762756
Epoch: 614 - loss: 0.49019140

Epoch: 719 - loss: 0.4848114550113678 - val loss: 0.5617040991783142
Epoch: 720 - loss: 0.48472344875335693 - val loss: 0.5618671774864197
Epoch: 721 - loss: 0.48433011770248413 - val loss: 0.5634036064147949
Epoch: 722 - loss: 0.48416993021965027 - val loss: 0.5626177787780762
Epoch: 723 - loss: 0.4848422110080719 - val loss: 0.5632941722869873
Epoch: 724 - loss: 0.4838466942310333 - val loss: 0.5653904676437378
Epoch: 725 - loss: 0.48453694581985474 - val loss: 0.5637677311897278
Epoch: 726 - loss: 0.48401859402656555 - val loss: 0.5630796551704407
Epoch: 727 - loss: 0.48457521200180054 - val loss: 0.5627955794334412
Epoch: 728 - loss: 0.48448696732521057 - val loss: 0.56276535987854
Epoch: 729 - loss: 0.48485785722732544 - val loss: 0.5617749094963074
Epoch: 730 - loss: 0.4843849241733551 - val loss: 0.5629785656929016
Epoch: 731 - loss: 0.4836250841617584 - val loss: 0.5615133047103882
Epoch: 732 - loss: 0.4845007061958313 - val loss: 0.5626646280288696
Epoch: 733 - loss: 0.4842387

Epoch: 838 - loss: 0.4794350564479828 - val loss: 0.5580634474754333
Epoch: 839 - loss: 0.4784400165081024 - val loss: 0.5597252249717712
Epoch: 840 - loss: 0.4796716570854187 - val loss: 0.560910165309906
Epoch: 841 - loss: 0.47862160205841064 - val loss: 0.5593725442886353
Epoch: 842 - loss: 0.47908779978752136 - val loss: 0.5574738383293152
Epoch: 843 - loss: 0.47872278094291687 - val loss: 0.5609896779060364
Epoch: 844 - loss: 0.4786641001701355 - val loss: 0.5589511394500732
Epoch: 845 - loss: 0.4786810576915741 - val loss: 0.5591294169425964
Epoch: 846 - loss: 0.47809281945228577 - val loss: 0.5603122115135193
Epoch: 847 - loss: 0.4793647527694702 - val loss: 0.5581806302070618
Epoch: 848 - loss: 0.47719138860702515 - val loss: 0.5577546954154968
Epoch: 849 - loss: 0.4777541756629944 - val loss: 0.557982861995697
Epoch: 850 - loss: 0.47802993655204773 - val loss: 0.560425341129303
Epoch: 851 - loss: 0.47847089171409607 - val loss: 0.5581309795379639
Epoch: 852 - loss: 0.477638810

Epoch: 957 - loss: 0.47491517663002014 - val loss: 0.5572133660316467
Epoch: 958 - loss: 0.4745539724826813 - val loss: 0.5535887479782104
Epoch: 959 - loss: 0.47513502836227417 - val loss: 0.5591223835945129
Epoch: 960 - loss: 0.47471389174461365 - val loss: 0.5569760203361511
Epoch: 961 - loss: 0.4745137095451355 - val loss: 0.5562378168106079
Epoch: 962 - loss: 0.475140243768692 - val loss: 0.5591639876365662
Epoch: 963 - loss: 0.4744998812675476 - val loss: 0.5561349391937256
Epoch: 964 - loss: 0.4757486581802368 - val loss: 0.5574395060539246
Epoch: 965 - loss: 0.47423413395881653 - val loss: 0.5542408227920532
Epoch: 966 - loss: 0.4752291142940521 - val loss: 0.5574755072593689
Epoch: 967 - loss: 0.47407278418540955 - val loss: 0.5549706220626831
Epoch: 968 - loss: 0.47383564710617065 - val loss: 0.5555716753005981
Epoch: 969 - loss: 0.4746467173099518 - val loss: 0.5577113032341003
Epoch: 970 - loss: 0.4747871160507202 - val loss: 0.5566912889480591
Epoch: 971 - loss: 0.47458729