# Sinus als tiefes neuronales Netz

### Dependencies

In [1]:
// Notebook dependency setup: artifacts are resolved from the local Maven repository.
USE {
    repositories {
        mavenLocal()
    }

    dependencies {
        // Version 0.0.1 of the sk.ai.net "core" and "reflection" artifacts.
        implementation("sk.ai.net:core:0.0.1")
        implementation("sk.ai.net:reflection:0.0.1")
    }
}

### JetBrains-Bibliotheken

In [2]:
%use kandy
%use dataframe

### NN mit DSL definieren

In [3]:
import de.jugda.knanogpt.core.tensor.Shape
import de.jugda.knanogpt.core.tensor.Tensor
import org.skainet.activations.relu
import org.skainet.dsl.network
import org.skainet.nn.Module
import org.skainet.nn.NamedParameter

/**
 * Module wrapper around a DSL-defined network built from `input(1)`, two `dense(16)`
 * layers with `relu` activation, and a final `dense(1)` layer.
 *
 * The wrapper reports no parameters of its own ([params] is always empty); the layers of
 * the wrapped network are exposed through [modules].
 *
 * @property name module name, `"sin"` by default.
 */
class SineNN(override val name: String = "sin") : Module() {

    // The wrapped network; every call on this module is delegated to it.
    private val net = network {
        input(1)
        dense(16) { activation = relu }
        dense(16) { activation = relu }
        dense(1)
    }

    /** Always an empty list. */
    override val params: List<NamedParameter> get() = emptyList()

    /** The sub-modules of the wrapped network. */
    override val modules: List<Module> get() = net.modules

    /** Runs [input] through the wrapped network and returns its output. */
    override fun forward(input: Tensor): Tensor {
        return net.forward(input)
    }
}


## Hilfsfunktion

In [4]:
/**
 * Evaluates the network for a single [angle].
 *
 * The angle is wrapped in a one-element tensor of `Shape(1)` and passed to [SineNN.forward].
 *
 * @param angle input value fed to the network.
 * @return the tensor produced by the forward pass.
 */
fun SineNN.of(angle: Double): Tensor {
    val singleValue = Tensor(Shape(1), doubleArrayOf(angle))
    return forward(singleValue)
}

In [5]:
// Create the network instance with its default name ("sin").
val model = SineNN()


In [6]:
// Network output for PI/2, printed before the pre-trained values are assigned further below.
print(model.of(PI/2.0))

[0.0]

In [7]:
// Reference value: sin(PI/2) from the math library, for comparison with the network output.
print(sin(PI/2.0))

1.0

In [8]:


// Print each named parameter the model reports, one per line.
// SineNN.params returns an empty list, so this cell produces no output.
for (parameter in model.params) {
    println(parameter)
}

## Vortrainierte Werte laden

In [9]:
// Pre-trained parameter tensors, keyed "fc<N>.weight" / "fc<N>.bias" for N = 1..3.
// The loading cell below looks these keys up by layer position when assigning parameter values.
val tensorMap = mapOf(
    // fc1.weight: declared as Shape(16, 1), 16 values.
    "fc1.weight" to Tensor(
        Shape(16,1),
        doubleArrayOf(
            -0.5437579154968262,
            -0.40618014335632324,
            -0.04907243698835373,
            -0.2054896354675293,
            0.7046658992767334,
            0.12716591358184814,
            -0.6933680772781372,
            -0.6911409497261047,
            0.7351927757263184,
            0.8775343298912048,
            0.03899011388421059,
            -0.5247653126716614,
            -0.907718300819397,
            0.7729036808013916,
            -0.584505021572113,
            0.8824521899223328
        )
    ),
    // fc1.bias: declared as Shape(16), 16 values.
    "fc1.bias" to Tensor(
        Shape(16),
        doubleArrayOf(
            -0.6030652523040771,
            0.8545130491256714,
            0.9554063677787781,
            -0.014425158500671387,
            0.4734141230583191,
            -0.8752269744873047,
            0.9116081595420837,
            0.29334962368011475,
            -0.03179183229804039,
            -0.4028461277484894,
            0.6525490880012512,
            0.6051297783851624,
            -0.40821588039398193,
            -0.6744815111160278,
            0.39602163434028625,
            0.2196938693523407
        )
    ),
    // fc2.weight: declared as Shape(16, 16), 256 values.
    "fc2.weight" to Tensor(
        Shape(16, 16), doubleArrayOf(
            -0.10257938504219055,
            0.17300540208816528,
            0.022153077647089958,
            0.22399896383285522,
            0.02398708462715149,
            0.18787217140197754,
            -0.026561839506030083,
            -0.11053970456123352,
            0.2762666940689087,
            0.2670552134513855,
            -0.09449589997529984,
            -0.37267953157424927,
            0.02065420150756836,
            -0.07705267518758774,
            -0.20018526911735535,
            0.25663721561431885,
            0.1135769784450531,
            -0.17329348623752594,
            -0.18831998109817505,
            -0.05843159556388855,
            0.1589454561471939,
            -0.12731316685676575,
            -0.137668639421463,
            -0.2233249843120575,
            -0.20866504311561584,
            -0.2232820689678192,
            0.014035061001777649,
            0.06416791677474976,
            0.19215050339698792,
            -0.23063619434833527,
            0.16741982102394104,
            0.061120595782995224,
            -0.05977308750152588,
            -0.1462109088897705,
            -0.03812804073095322,
            -0.14342406392097473,
            -0.13743525743484497,
            -0.1320720911026001,
            -0.19449175894260406,
            0.2660144865512848,
            -0.18535356223583221,
            -0.030827010050415993,
            -0.018952256068587303,
            0.2655194401741028,
            -0.14998266100883484,
            0.018038976937532425,
            0.031076736748218536,
            0.11761008948087692,
            0.23089301586151123,
            0.0871533453464508,
            -0.1649712324142456,
            0.011388391256332397,
            -0.22663500905036926,
            -0.21543896198272705,
            0.014795392751693726,
            -0.05541226267814636,
            -0.09872651100158691,
            -0.1683627963066101,
            -0.22862935066223145,
            0.11181080341339111,
            -0.06262826919555664,
            -0.02767014503479004,
            0.23583859205245972,
            -0.1363779604434967,
            0.028661668300628662,
            0.007924139499664307,
            -0.23774152994155884,
            -0.05389246344566345,
            -0.1980002522468567,
            0.04686683416366577,
            0.10875284671783447,
            -0.001697540283203125,
            0.11665266752243042,
            -0.2258722186088562,
            -0.18789243698120117,
            0.19521811604499817,
            0.12139546871185303,
            0.16123232245445251,
            -0.048430681228637695,
            0.12882336974143982,
            -0.06784370541572571,
            0.20136915147304535,
            0.21978528797626495,
            0.0007946193218231201,
            0.028842158615589142,
            0.13136360049247742,
            0.41717106103897095,
            0.17754682898521423,
            0.22757908701896667,
            -0.04781464859843254,
            0.25020548701286316,
            0.16513016819953918,
            -0.03443148732185364,
            0.5649763941764832,
            0.42032307386398315,
            -0.0171221811324358,
            -0.08848094940185547,
            0.13860763609409332,
            0.03668839484453201,
            -0.07173511385917664,
            0.27420082688331604,
            -0.18933850526809692,
            -0.1663854420185089,
            -0.3344656527042389,
            0.009426545351743698,
            -0.03151931241154671,
            0.2958306074142456,
            -0.5178188681602478,
            0.13185536861419678,
            0.10059252381324768,
            -0.1982472836971283,
            0.02664925903081894,
            0.09930169582366943,
            0.07500758767127991,
            -0.21454232931137085,
            -0.16600322723388672,
            -0.08857399225234985,
            0.04886987805366516,
            -0.010289937257766724,
            -0.10731151700019836,
            -0.12413057684898376,
            0.16528230905532837,
            0.07231757044792175,
            0.14234772324562073,
            0.16628128290176392,
            0.21260470151901245,
            0.20259436964988708,
            -0.16137951612472534,
            -0.09222900867462158,
            0.08151715993881226,
            0.24118076264858246,
            0.24719339609146118,
            -0.048659272491931915,
            0.14365407824516296,
            0.13242720067501068,
            0.21420586109161377,
            0.060788389295339584,
            -0.06129367649555206,
            0.07209084928035736,
            0.23200565576553345,
            0.13996049761772156,
            -0.32341986894607544,
            0.41574734449386597,
            -0.2561975121498108,
            0.15977880358695984,
            0.029257740825414658,
            -0.030059611424803734,
            0.005124181509017944,
            -0.10688148438930511,
            0.11103665828704834,
            -0.12384822964668274,
            -0.07741808891296387,
            -0.028035949915647507,
            0.177804633975029,
            -0.21912963688373566,
            -0.2997293770313263,
            0.06195428967475891,
            -0.3089802861213684,
            -0.09249170869588852,
            -0.04744498431682587,
            0.11955362558364868,
            -0.09306460618972778,
            0.08386632800102234,
            -0.20563799142837524,
            -0.16910019516944885,
            -0.23323941230773926,
            -0.12276723980903625,
            -0.000338822603225708,
            0.05349484086036682,
            0.01340574026107788,
            -0.13708370923995972,
            0.21299102902412415,
            0.09145370125770569,
            -0.011387556791305542,
            0.22622063755989075,
            0.20040494203567505,
            0.15085336565971375,
            -0.2059146910905838,
            -0.14172489941120148,
            0.15939223766326904,
            0.02557411603629589,
            -0.05468907952308655,
            0.041294731199741364,
            -0.003357824170961976,
            -0.02610093355178833,
            0.09087082743644714,
            0.08137725293636322,
            0.1261443942785263,
            -0.05682036280632019,
            -0.24741682410240173,
            0.13989342749118805,
            -0.06079234555363655,
            0.08340916037559509,
            0.00012913873069919646,
            0.3699108362197876,
            0.2098078727722168,
            0.03235418349504471,
            -0.06893250346183777,
            -0.24021713435649872,
            -0.48108887672424316,
            0.14382511377334595,
            0.26369708776474,
            0.31492164731025696,
            -0.42158377170562744,
            0.08641833066940308,
            0.04120161011815071,
            -0.19685231149196625,
            0.17141079902648926,
            -0.09313604235649109,
            -0.21633949875831604,
            -0.031940728425979614,
            0.020739346742630005,
            -0.12642377614974976,
            0.04184052348136902,
            -0.08950522541999817,
            0.20784395933151245,
            0.23021626472473145,
            -0.17811211943626404,
            -0.060421377420425415,
            -0.09906545281410217,
            -0.0205744206905365,
            0.21380892395973206,
            -0.13273027539253235,
            0.003771275281906128,
            0.011810332536697388,
            -0.034441642463207245,
            -0.2788986265659332,
            0.02576860785484314,
            0.048153020441532135,
            0.1212598979473114,
            -0.024378251284360886,
            0.3358549475669861,
            -0.1537722945213318,
            -0.06005091965198517,
            -0.035380907356739044,
            0.22935563325881958,
            0.02437269687652588,
            0.0868522971868515,
            0.04170968756079674,
            -0.19989392161369324,
            -0.10855808854103088,
            -0.04006728529930115,
            -0.1922602653503418,
            0.10681521892547607,
            0.09539484977722168,
            -0.0405639111995697,
            -0.07352346181869507,
            0.025782227516174316,
            -0.025375187397003174,
            0.016373634338378906,
            -0.044025421142578125,
            0.21171081066131592,
            -0.0293692946434021,
            0.17254146933555603,
            0.0502622127532959,
            -0.2063692808151245
        )
    ),
    // fc2.bias: declared as Shape(16), 16 values.
    "fc2.bias" to Tensor(
        Shape(16),
        doubleArrayOf(
            0.21671485900878906,
            -0.2539503574371338,
            -0.29396378993988037,
            -0.00276261568069458,
            -0.18057772517204285,
            -0.010082397609949112,
            0.3308062255382538,
            -0.04653894901275635,
            -0.016405299305915833,
            0.016453173011541367,
            -0.16926336288452148,
            -0.18714305758476257,
            0.27529406547546387,
            -0.09136846661567688,
            -0.06186753883957863,
            0.1439799666404724
        )
    ),
    // fc3.weight: declared as Shape(1, 16), 16 values.
    "fc3.weight" to Tensor(
        Shape( 1, 16),
        doubleArrayOf(
            0.14054463803768158,
            -0.1380414068698883,
            -0.03053985722362995,
            -0.10102230310440063,
            0.19149786233901978,
            -0.273008793592453,
            0.22432786226272583,
            -0.22113493084907532,
            -0.1693483591079712,
            -0.019474321976304054,
            -0.170209139585495,
            0.04850612208247185,
            0.30267849564552307,
            0.1791227161884308,
            -0.07639366388320923,
            0.002671569585800171
        )
    ),
    // fc3.bias: declared as Shape(1), a single value.
    "fc3.bias" to Tensor(Shape(1), doubleArrayOf(0.36608225107192993)),
)

In [10]:
// Sub-modules whose name starts with "li"; the i-th one is paired with the "fc<i+1>.*" entries.
val linear = model.modules.filter { it.name.startsWith("li") }
val wandb = tensorMap

// First pass assigns all weights, second pass all biases. For each layer, the first parameter
// whose name starts with the given prefix receives the tensor stored under "fc<i+1>.<suffix>".
// A layer without such a parameter, or a key missing from the map, is skipped silently.
for ((suffix, namePrefix) in listOf("weight" to "w", "bias" to "b")) {
    for ((position, layer) in linear.withIndex()) {
        val target = layer.params.firstOrNull { it.name.startsWith(namePrefix) } ?: continue
        val pretrained = wandb["fc${position + 1}.$suffix"] ?: continue
        target.value = pretrained
    }
}


In [11]:
// Network output for PI/2 again, now after the pre-trained values have been assigned.
print(model.of(PI/2.0))

[1.0088576282184099]

## Plot 

In [12]:
// Number of evenly spaced sample points on [0, PI / 2].
val sampleCount = 100

// Sample angles. The division is carried out in Double so the grid is not rounded
// through Float before being scaled by PI / 2.
val x_values = List(sampleCount) { index ->
    index.toDouble() / (sampleCount - 1) * (PI / 2)
}

// Reference curve: sin evaluated at every sample angle.
val y_values = x_values.map { angle -> sin(angle) }

// Network curve: first element of the model output for every sample angle.
val y_nn_values = x_values.map { angle -> model.of(angle).elements[0] }

// Long-format frame: the x grid appears twice, once per curve, labelled by "mode".
val df = dataFrameOf(
    "x" to x_values + x_values,
    "y" to y_values + y_nn_values,
    "mode" to List(sampleCount) { "sin" } + List(sampleCount) { "nn" },
)

In [13]:
// Render the data frame as the cell's output.
df

x,y,mode
0.000000,0.000000,sin
0.015867,0.015866,sin
0.031733,0.031728,sin
0.047600,0.047582,sin
0.063467,0.063424,sin
0.079333,0.079250,sin
0.095200,0.095056,sin
0.111066,0.110838,sin
0.126933,0.126592,sin
0.142800,0.142315,sin


In [14]:
// Line plot of y over x; the "mode" column decides which color a row is drawn in.
df.plot {
    line {
        x("x")
        y("y")
        color("mode") {
            // Fixed category colors: rows labelled "sin" in purple, rows labelled "nn" in orange.
            scale = categorical("sin" to Color.PURPLE, "nn" to Color.ORANGE)
        }
        width = 1.5
    }
}