From 43569cad30869b466d670129fad0680936a92a0e Mon Sep 17 00:00:00 2001 From: __mo_san__ <50895527+m0saan@users.noreply.github.com> Date: Thu, 22 Jun 2023 15:49:59 +0100 Subject: [PATCH] added Softmax, LSUV and fixed some typos :33 --- nbs/01_operators.ipynb | 12 +-- nbs/02_init.ipynb | 228 ++++++++++++++++++++++++++++++++++------- nbs/03_nn.ipynb | 64 +++++++----- nbs/04_optim.ipynb | 12 +-- 4 files changed, 237 insertions(+), 79 deletions(-) diff --git a/nbs/01_operators.ipynb b/nbs/01_operators.ipynb index 791ee23..8cea218 100644 --- a/nbs/01_operators.ipynb +++ b/nbs/01_operators.ipynb @@ -1201,7 +1201,7 @@ " Returns:\n", " The result of applying ReLU to a.\n", " \"\"\"\n", - " self.out = ARRAY_API.clip(a, a_min=0)\n", + " self.out = ARRAY_API.clip(a, a_min=0, a_max=None)\n", " return self.out\n", "\n", " def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor,]:\n", @@ -1891,11 +1891,11 @@ { "data": { "text/plain": [ - "tensor([[ 5.9966, -2.2463, -3.3931, 0.4814, -0.6434],\n", - " [ 3.0877, 2.2520, 4.4612, -0.9591, -2.2344],\n", - " [ 4.4212, 8.7814, 0.1132, -8.3712, -3.2276],\n", - " [-0.6944, -2.7023, -3.3642, 2.9205, 0.5669],\n", - " [-4.0364, 2.4787, -5.7720, -5.1817, 1.5952]])" + "tensor([[ 3.0506e+00, 4.4419e+00, -7.5836e-01, 7.2981e-01, -2.4619e-03],\n", + " [-3.8709e-01, 1.4354e+00, -2.3635e+00, -1.3456e+00, 2.4888e+00],\n", + " [-7.4363e-01, 8.4173e+00, -1.7585e+00, 6.0386e+00, -6.7308e-01],\n", + " [ 3.4582e+00, 4.3952e+00, -3.9762e+00, 4.0534e-01, 3.8527e+00],\n", + " [ 3.3561e+00, 7.8316e-01, 5.0760e+00, -1.6189e+00, -4.0100e+00]])" ] }, "execution_count": null, diff --git a/nbs/02_init.ipynb b/nbs/02_init.ipynb index f4337e0..dc8c9ea 100644 --- a/nbs/02_init.ipynb +++ b/nbs/02_init.ipynb @@ -32,7 +32,8 @@ "source": [ "#| export\n", "import math\n", - "import minima as mi" + "import minima as mi\n", + "from functools import partial" ] }, { @@ -92,16 +93,16 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[0.724743 0.236047 0.065364 0.855749 0.91864 ]\n", - " [0.027286 0.800982 0.64705 0.968984 0.489134]\n", - " [0.931952 0.758258 0.887548 0.867378 0.108525]\n", - " [0.03819 0.143269 0.210976 0.892318 0.069396]\n", - " [0.257424 0.554249 0.235325 0.064803 0.843057]\n", - " [0.696038 0.812699 0.54037 0.754445 0.385663]\n", - " [0.461943 0.538387 0.582451 0.802216 0.6077 ]\n", - " [0.045212 0.726626 0.886866 0.190699 0.00549 ]\n", - " [0.685753 0.342417 0.554111 0.813416 0.375196]\n", - " [0.170601 0.631679 0.474656 0.363225 0.162466]])" + "[[0.423019 0.831303 0.593536 0.464066 0.622164]\n", + " [0.519762 0.698 0.364592 0.593321 0.299263]\n", + " [0.330883 0.566039 0.327606 0.069224 0.077561]\n", + " [0.591434 0.092411 0.049555 0.729441 0.001867]\n", + " [0.60242 0.36611 0.162999 0.602054 0.684817]\n", + " [0.545608 0.415636 0.746867 0.923219 0.67769 ]\n", + " [0.809501 0.496377 0.527514 0.333276 0.479529]\n", + " [0.080732 0.63581 0.950788 0.387371 0.570476]\n", + " [0.677467 0.620451 0.702335 0.071747 0.067357]\n", + " [0.66082 0.372642 0.226082 0.687941 0.761832]])" ] }, "execution_count": null, @@ -210,11 +211,11 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[ 0.439627 -0.765065 -0.323372 -1.518769 -0.563107]\n", - " [-0.255756 0.96025 1.512503 -0.662302 -1.201184]\n", - " [ 0.650412 0.263193 1.310423 1.383127 1.237785]\n", - " [-0.008076 0.028429 1.874965 0.977454 -0.068408]\n", - " [-1.75604 0.546302 0.359429 0.864159 1.347796]])" + "[[ 0.934699 -1.883731 1.56695 0.929079 0.73024 ]\n", + " [ 2.066303 0.109121 1.161415 -1.184726 -1.753147]\n", + " [ 
0.339952 -1.125624 -0.740886 -0.808628 0.024874]\n", + " [-0.307566 1.072183 0.013086 0.407447 -0.705648]\n", + " [-0.956348 -0.291481 -0.1 0.70653 0.500862]])" ] }, "execution_count": null, @@ -605,16 +606,16 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[-0.362476 0.425931 -0.604767 -0.585322 -0.469232]\n", - " [-0.207212 -0.106902 -0.349397 -0.082876 0.145293]\n", - " [ 0.274891 -0.136795 -0.210589 0.275138 0.481523]\n", - " [ 0.258619 0.307068 0.458597 0.236569 0.462849]\n", - " [ 0.029899 -0.559897 0.514509 0.32062 0.208706]\n", - " [-0.488839 -0.433476 0.089545 0.466521 -0.407354]\n", - " [-0.243207 0.266691 0.27616 0.263078 -0.267017]\n", - " [-0.61643 -0.143201 0.083898 0.366265 0.022065]\n", - " [ 0.149755 -0.155406 -0.494278 0.481983 -0.509169]\n", - " [ 0.159343 0.597055 -0.36376 0.376093 -0.399417]])" + "[[-0.221499 -0.06155 -0.077118 0.56846 -0.418471]\n", + " [-0.149945 -0.553442 0.581115 -0.460948 -0.420142]\n", + " [-0.355427 -0.066154 0.355814 0.082557 -0.556673]\n", + " [-0.497098 -0.087087 -0.051234 -0.238323 -0.290452]\n", + " [-0.280464 0.334714 0.116377 -0.481387 -0.388613]\n", + " [ 0.110984 0.625096 -0.228138 -0.500467 0.502594]\n", + " [-0.400704 0.197745 0.166157 -0.479262 0.577242]\n", + " [-0.196405 -0.577416 -0.605291 -0.294985 -0.606795]\n", + " [ 0.321031 -0.098246 0.278399 0.047973 0.295106]\n", + " [-0.385705 0.34554 -0.519177 0.389492 0.040751]])" ] }, "execution_count": null, @@ -646,16 +647,16 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[ 0.176985 0.475322 -0.472049 -0.197204 -0.124735]\n", - " [-0.021712 0.164783 -0.76619 -0.109438 0.25406 ]\n", - " [ 0.747072 0.355753 -0.609391 0.107684 -0.27346 ]\n", - " [ 0.418525 -0.260188 -0.422199 -0.179245 -0.042915]\n", - " [ 0.167997 -0.004962 0.17853 -0.520827 0.349572]\n", - " [-0.019444 -0.406027 0.332068 -0.591041 -0.408733]\n", - " [-0.535821 -0.790127 -0.098206 0.25483 0.509668]\n", - " [-0.401008 -0.051625 0.376056 0.1105 -1.083598]\n", - " [ 0.06975 0.191631 0.233941 0.101632 0.235047]\n", - " [-0.034112 -0.465086 0.053479 0.444051 -0.654049]])" + "[[-0.063341 -0.374345 -0.352547 0.789786 -0.782333]\n", + " [ 0.195202 -0.227442 0.252176 0.225021 0.105454]\n", + " [-0.174635 -0.043868 0.195862 -0.917395 0.502817]\n", + " [ 0.177671 0.373282 0.601478 0.593381 -0.1945 ]\n", + " [ 0.056075 0.224731 -0.458342 -0.133477 -0.138557]\n", + " [ 0.393741 0.096873 0.514728 0.195166 -0.260037]\n", + " [-0.161321 -0.056462 0.609632 -0.470343 0.118147]\n", + " [-0.359394 -0.297816 0.383475 0.310443 0.510362]\n", + " [-0.433323 -0.49009 -0.362796 0.263523 -0.023001]\n", + " [-0.156537 0.312429 -0.113006 -0.195525 0.197912]])" ] }, "execution_count": null, @@ -814,6 +815,161 @@ "2. It performs better with ReLU and its variants because it accounts for the fact that the variance of the output of a neuron with a ReLU activation function is half the variance of its input." 
] },
+ {
+ "cell_type": "markdown",
+ "id": "b9febe08-dcf9-4421-9a1f-4e38a193646b",
+ "metadata": {},
+ "source": [
+ "### LSUV Initialization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fb0c6a3a-78e7-4b96-b32a-1c311bf9d70a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Hook():\n",
+ "    def __init__(self, layer, fn): self.hook = layer.register_forward_hook(partial(fn, self))\n",
+ "    def remove(self): self.hook.remove()\n",
+ "    def __del__(self): self.remove()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b5f46d88-3ab2-4bde-bb2e-1026351fb79e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def append_stats(hook, mod, inp, outp):\n",
+ "    if not hasattr(hook, 'stats'): hook.stats = ([], [])\n",
+ "    acts = outp  # TODO: move outp to the CPU when using an accelerator\n",
+ "    hook.stats[0].append(acts.numpy().mean())\n",
+ "    hook.stats[1].append(acts.numpy().std())\n",
+ "\n",
+ "#| export\n",
+ "def _lsuv_stats(hook, mod, inp, outp):\n",
+ "    acts = outp\n",
+ "    hook.mean = acts.numpy().mean()\n",
+ "    hook.std = acts.numpy().std()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fb14b7e1-9c0a-48fe-957a-d6f297059c28",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class LSUV:\n",
+ "\n",
+ "    def __init__(self, model, batch) -> None:\n",
+ "        self.model = model\n",
+ "        self.batch = batch\n",
+ "        self.params_layers = [m for m in model if hasattr(m, 'weight') and not isinstance(m, mi.nn.BatchNorm1d)]\n",
+ "        self.act_fns = [m for m in model if isinstance(m, mi.nn.ReLU)]\n",
+ "\n",
+ "        # Constants\n",
+ "        self.TOLERANCE = 1e-3\n",
+ "\n",
+ "    def lsuv_init(self):\n",
+ "        \"\"\"\n",
+ "        Layer-wise Sequential Unit Variance Initialization (LSUV).\n",
+ "        A method to help neural nets converge faster.\n",
+ "\n",
+ "        For each (parameter layer, activation) pair, repeatedly runs\n",
+ "        `self.batch` through the model (each forward pass fires the hook,\n",
+ "        refreshing `hook.mean` and `hook.std`), subtracting the activation\n",
+ "        mean from the bias and dividing the weights by the activation std,\n",
+ "        until the activations have roughly zero mean and unit variance.\n",
+ "        \"\"\"\n",
+ "        for params_layer, acts_layer in zip(self.params_layers, self.act_fns):\n",
+ "            hook = Hook(acts_layer, _lsuv_stats)\n",
+ "            while self.model(self.batch) is not None and (abs(hook.std - 1) > self.TOLERANCE or abs(hook.mean) > self.TOLERANCE):\n",
+ "                print(f'---> before: {hook.mean} -- {hook.std}')\n",
+ "                if params_layer.bias is not None: params_layer.bias -= mi.Tensor(hook.mean)\n",
+ "                params_layer.weight.data /= mi.Tensor(hook.std)\n",
+ "                print(f'-------------> after: {hook.mean} -- {hook.std}')\n",
+ "            hook.remove()"
+ ]
+ },
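+ {
+ "cell_type": "markdown",
+ "id": "2f8d0c41-7c1e-4f6a-9b3d-5c0a1e9d2b77",
+ "metadata": {},
+ "source": [
+ "Before wiring LSUV into a full model, here is a minimal sketch of how `Hook` and `append_stats` fit together, assuming minima modules expose `register_forward_hook` (the `Hook` class above relies on it). The toy layer and batch are hypothetical, purely for illustration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2f8d0c41-7c1e-4f6a-9b3d-5c0a1e9d2b78",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: attach `append_stats` to a toy layer and run one batch.\n",
+ "lin = mi.nn.Linear(4, 4)\n",
+ "hk = Hook(lin, append_stats)\n",
+ "_ = lin(mi.init.randn(8, 4))  # each forward pass appends to hk.stats\n",
+ "print(hk.stats)               # ([means...], [stds...]), one entry per pass\n",
+ "hk.remove()"
+ ]
+ },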
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "354013dd-f829-4860-8bd3-e409639fe8de",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import minima as mi\n",
+ "\n",
+ "# Number of samples\n",
+ "n_samples = 1000\n",
+ "\n",
+ "# Number of features (28x28 pixels for a grayscale image)\n",
+ "n_features = 784\n",
+ "\n",
+ "# Number of classes\n",
+ "n_classes = 10\n",
+ "\n",
+ "# Generate random inputs from a standard normal distribution\n",
+ "X = mi.init.randn(n_samples, n_features)\n",
+ "\n",
+ "# Generate random target classes\n",
+ "y = mi.Tensor(np.random.randint(0, n_classes, size=n_samples))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e3f22936-017b-4ebc-977d-f969041f883b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define the neural network architecture\n",
+ "model = mi.nn.Sequential(\n",
+ "    mi.nn.Linear(784, 128),\n",
+ "    mi.nn.ReLU(),\n",
+ "    mi.nn.Linear(128, 10)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3db16ba0-a9ac-4762-bea2-b2811a19da6a",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "'Sequential' object is not iterable",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m lsuv \u001b[38;5;241m=\u001b[39m \u001b[43mLSUV\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n",
+ "Cell \u001b[0;32mIn[26], line 6\u001b[0m, in \u001b[0;36mLSUV.__init__\u001b[0;34m(self, model, batch)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m model\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch \u001b[38;5;241m=\u001b[39m batch\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams_layers \u001b[38;5;241m=\u001b[39m [m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m model \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mweight\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(m, mi\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mBatchNorm1d)]\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fns \u001b[38;5;241m=\u001b[39m [m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m model \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(m, mi\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mReLU)]\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# Constants\u001b[39;00m\n",
+ "\u001b[0;31mTypeError\u001b[0m: 'Sequential' object is not iterable"
+ ]
+ }
+ ],
+ "source": [
+ "lsuv = LSUV(model, X)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "eaabe69c-6afc-4203-8dde-89edd339ba1f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# lsuv.lsuv_init()"
+ ]
+ },
{
"cell_type": "markdown",
"id": "1db18cd9-190f-4103-9866-6e7fe7f4863a",
diff --git a/nbs/03_nn.ipynb b/nbs/03_nn.ipynb
index cfebcff..c9e16c4 100644
--- a/nbs/03_nn.ipynb
+++ b/nbs/03_nn.ipynb
@@ -358,7 +358,17 @@
 "        \"\"\"\n",
 "        for module in self.modules:\n",
 "            x = module(x)\n",
- "        return x\n"
+ "        return x\n",
+ "\n",
+ "    def __iter__(self):\n",
+ "        self._iter_idx = 0\n",
+ "        return self\n",
+ "    def __next__(self):\n",
+ "        if self._iter_idx < len(self.modules):\n",
+ "            res = self.modules[self._iter_idx]\n",
+ "            self._iter_idx += 1\n",
+ "            return res\n",
+ "        raise StopIteration()\n"
 ]
},
{
@@ -614,7 +624,7 @@
 "CE = -z_c + \\log\\left(\\sum_{j=1}^{C} e^{z_j}\\right)\n",
 "$$\n",
 "\n",
- "First, `log_sum_exp_logits = ops.logsumexp(logits, axes=(1, )).sum()` computes the term $\\log\\left(\\sum_{j=1}^{C} e^{z_j}\\right)$. 
The function `logsumexp` computes the logarithm of the sum of exponentials in a numerically stable way, and then these values are summed over all samples.\n", + "First, `log_sum_exp_logits = operators.logsumexp(logits, axes=(1, )).sum()` computes the term $\\log\\left(\\sum_{j=1}^{C} e^{z_j}\\right)$. The function `logsumexp` computes the logarithm of the sum of exponentials in a numerically stable way, and then these values are summed over all samples.\n", "\n", "Second, `true_class_logits_sum = (logits * init.one_hot(logits.shape[1], y)).sum()` computes the $-z_c$ term for each sample. The function `init.one_hot(logits.shape[1], y)` creates a one-hot encoding of the true labels, and this is then multiplied with the logits to pick out the logits for the correct classes. These values are then summed over all samples.\n", "\n", @@ -836,7 +846,7 @@ " Returns:\n", " Tensor: A single tensor that is the average cross-entropy loss.\n", " \"\"\"\n", - " log_sum_exp_logits = ops.logsumexp(input, axes=(1, )).sum()\n", + " log_sum_exp_logits = operators.logsumexp(input, axes=(1, )).sum()\n", " true_class_logits_sum = (input * init.one_hot(input.shape[1], target)).sum()\n", " return (log_sum_exp_logits - true_class_logits_sum) / input.shape[0]" ] @@ -872,7 +882,7 @@ " ```\n", " \"\"\"\n", "\n", - " def forward(self, input: Tensor, target: Tensor) -> Tensor:\n", + " def forward(self, input: Tensor) -> Tensor:\n", " \"\"\"\n", " Computes the Cross Entropy Loss between the input logits and the target class indices.\n", "\n", @@ -885,8 +895,8 @@ " \"\"\"\n", "\n", "\n", - " exps = ops.exp(X - mi.autograd.ARRAY_API.max(input))\n", - " return exps / ops.summation(exps)" + " exps = operators.exp(input - mi.autograd.ARRAY_API.max(input))\n", + " return exps / operators.summation(exps)" ] }, { @@ -937,11 +947,11 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[0.635511 0.622535 0.600604 0.845242 0.904673 0.620263 0.821988 0.767772 0.048989 0.805934]\n", - " [0.051952 0.311615 0.118656 0.707399 0.975698 0.100632 0.054584 0.661575 0.017908 0.314077]\n", - " [0.377436 0.97437 0.142414 0.080208 0.403369 0.323072 0.473484 0.06406 0.937268 0.994772]\n", - " [0.615802 0.052481 0.256974 0.033538 0.724013 0.796708 0.429111 0.319218 0.445589 0.394347]\n", - " [0.424856 0.08904 0.158487 0.628804 0.45154 0.794178 0.462012 0.206399 0.786458 0.149655]])" + "[[0.960428 0.07369 0.761998 0.741508 0.397817 0.150828 0.498489 0.868482 0.291344 0.785897]\n", + " [0.503009 0.45372 0.195389 0.208309 0.249612 0.176545 0.849755 0.306317 0.829817 0.354018]\n", + " [0.545867 0.400977 0.8664 0.036632 0.902998 0.975624 0.195215 0.306032 0.93682 0.824157]\n", + " [0.855338 0.32871 0.561241 0.963338 0.39172 0.173132 0.537805 0.437422 0.069504 0.652756]\n", + " [0.858377 0.417763 0.011918 0.379052 0.548395 0.866817 0.673907 0.53375 0.268325 0.160612]])" ] }, "execution_count": null, @@ -986,7 +996,7 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[0.667351 0.33141 0.477045 0.406778 0.415143])" + "[0.553048 0.412649 0.599072 0.497097 0.471892])" ] }, "execution_count": null, @@ -1009,11 +1019,11 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[0.667351]\n", - " [0.33141 ]\n", - " [0.477045]\n", - " [0.406778]\n", - " [0.415143]])" + "[[0.553048]\n", + " [0.412649]\n", + " [0.599072]\n", + " [0.497097]\n", + " [0.471892]])" ] }, "execution_count": null, @@ -1036,11 +1046,11 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[0.667351 0.667351 0.667351 0.667351 0.667351 0.667351 0.667351 0.667351 0.667351 0.667351]\n", - " 
[0.33141 0.33141 0.33141 0.33141 0.33141 0.33141 0.33141 0.33141 0.33141 0.33141 ]\n", - " [0.477045 0.477045 0.477045 0.477045 0.477045 0.477045 0.477045 0.477045 0.477045 0.477045]\n", - " [0.406778 0.406778 0.406778 0.406778 0.406778 0.406778 0.406778 0.406778 0.406778 0.406778]\n", - " [0.415143 0.415143 0.415143 0.415143 0.415143 0.415143 0.415143 0.415143 0.415143 0.415143]])" + "[[0.553048 0.553048 0.553048 0.553048 0.553048 0.553048 0.553048 0.553048 0.553048 0.553048]\n", + " [0.412649 0.412649 0.412649 0.412649 0.412649 0.412649 0.412649 0.412649 0.412649 0.412649]\n", + " [0.599072 0.599072 0.599072 0.599072 0.599072 0.599072 0.599072 0.599072 0.599072 0.599072]\n", + " [0.497097 0.497097 0.497097 0.497097 0.497097 0.497097 0.497097 0.497097 0.497097 0.497097]\n", + " [0.471892 0.471892 0.471892 0.471892 0.471892 0.471892 0.471892 0.471892 0.471892 0.471892]])" ] }, "execution_count": null, @@ -1063,11 +1073,11 @@ "data": { "text/plain": [ "minima.Tensor(\n", - "[[-0.03184 -0.044816 -0.066747 0.177891 0.237322 -0.047088 0.154637 0.100421 -0.618362 0.138583]\n", - " [-0.279457 -0.019795 -0.212753 0.375989 0.644288 -0.230778 -0.276825 0.330166 -0.313502 -0.017333]\n", - " [-0.09961 0.497325 -0.334632 -0.396837 -0.073677 -0.153973 -0.003561 -0.412985 0.460223 0.517727]\n", - " [ 0.209024 -0.354297 -0.149804 -0.37324 0.317235 0.38993 0.022333 -0.08756 0.038811 -0.012431]\n", - " [ 0.009713 -0.326103 -0.256656 0.213661 0.036397 0.379035 0.046869 -0.208744 0.371316 -0.265487]])" + "[[ 0.40738 -0.479359 0.20895 0.18846 -0.155231 -0.40222 -0.054559 0.315434 -0.261704 0.232849]\n", + " [ 0.090359 0.041071 -0.21726 -0.20434 -0.163037 -0.236105 0.437106 -0.106332 0.417168 -0.058632]\n", + " [-0.053205 -0.198095 0.267328 -0.562441 0.303926 0.376552 -0.403858 -0.293041 0.337748 0.225084]\n", + " [ 0.358241 -0.168386 0.064144 0.466241 -0.105377 -0.323964 0.040708 -0.059675 -0.427593 0.155659]\n", + " [ 0.386485 -0.054129 -0.459974 -0.09284 0.076503 0.394926 0.202016 0.061858 -0.203567 -0.31128 ]])" ] }, "execution_count": null, diff --git a/nbs/04_optim.ipynb b/nbs/04_optim.ipynb index 94fcc7d..b03b1a4 100644 --- a/nbs/04_optim.ipynb +++ b/nbs/04_optim.ipynb @@ -318,7 +318,7 @@ " def __init__(\n", " self,\n", " params, # The parameters of the model to be optimized.\n", - " lr=0.01, # The initial learning rate.\n", + " lr=0.001, # The initial learning rate.\n", " wd=0.0, # The weight decay (L2 regularization).\n", " eps=1e-7, # A small constant for numerical stability.\n", " ):\n", @@ -561,7 +561,7 @@ " def __init__(\n", " self,\n", " params, # `params` is the list of parameters\n", - " lr=0.01, # `lr` is the learning rate $\\alpha$\n", + " lr=1e-5, # `lr` is the learning rate $\\alpha$\n", " beta1=0.9, # The exponential decay rate for the first moment estimates. Default is 0.9.\n", " beta2=0.999, # The exponential decay rate for the second moment estimates. Default is 0.999.\n", " eps=1e-8, # `eps` is $\\hat{\\epsilon}$ or $\\epsilon$ based on `optimized_update`\n", @@ -643,14 +643,6 @@ "#| hide\n", "import nbdev; nbdev.nbdev_export()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eae20858-2009-470d-a863-ac0bd45ed6ab", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {