Skip to content

Commit

Permalink
added Softmax, LSUV and fixed some typos :33
Browse files Browse the repository at this point in the history
  • Loading branch information
m0saan committed Jun 22, 2023
1 parent 5133e76 commit 43569ca
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 79 deletions.
12 changes: 6 additions & 6 deletions nbs/01_operators.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@
" Returns:\n",
" The result of applying ReLU to a.\n",
" \"\"\"\n",
" self.out = ARRAY_API.clip(a, a_min=0)\n",
" self.out = ARRAY_API.clip(a, a_min=0, a_max=None)\n",
" return self.out\n",
"\n",
" def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor,]:\n",
Expand Down Expand Up @@ -1891,11 +1891,11 @@
{
"data": {
"text/plain": [
"tensor([[ 5.9966, -2.2463, -3.3931, 0.4814, -0.6434],\n",
" [ 3.0877, 2.2520, 4.4612, -0.9591, -2.2344],\n",
" [ 4.4212, 8.7814, 0.1132, -8.3712, -3.2276],\n",
" [-0.6944, -2.7023, -3.3642, 2.9205, 0.5669],\n",
" [-4.0364, 2.4787, -5.7720, -5.1817, 1.5952]])"
"tensor([[ 3.0506e+00, 4.4419e+00, -7.5836e-01, 7.2981e-01, -2.4619e-03],\n",
" [-3.8709e-01, 1.4354e+00, -2.3635e+00, -1.3456e+00, 2.4888e+00],\n",
" [-7.4363e-01, 8.4173e+00, -1.7585e+00, 6.0386e+00, -6.7308e-01],\n",
" [ 3.4582e+00, 4.3952e+00, -3.9762e+00, 4.0534e-01, 3.8527e+00],\n",
" [ 3.3561e+00, 7.8316e-01, 5.0760e+00, -1.6189e+00, -4.0100e+00]])"
]
},
"execution_count": null,
Expand Down
228 changes: 192 additions & 36 deletions nbs/02_init.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
"source": [
"#| export\n",
"import math\n",
"import minima as mi"
"import minima as mi\n",
"from functools import partial"
]
},
{
Expand Down Expand Up @@ -92,16 +93,16 @@
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[0.724743 0.236047 0.065364 0.855749 0.91864 ]\n",
" [0.027286 0.800982 0.64705 0.968984 0.489134]\n",
" [0.931952 0.758258 0.887548 0.867378 0.108525]\n",
" [0.03819 0.143269 0.210976 0.892318 0.069396]\n",
" [0.257424 0.554249 0.235325 0.064803 0.843057]\n",
" [0.696038 0.812699 0.54037 0.754445 0.385663]\n",
" [0.461943 0.538387 0.582451 0.802216 0.6077 ]\n",
" [0.045212 0.726626 0.886866 0.190699 0.00549 ]\n",
" [0.685753 0.342417 0.554111 0.813416 0.375196]\n",
" [0.170601 0.631679 0.474656 0.363225 0.162466]])"
"[[0.423019 0.831303 0.593536 0.464066 0.622164]\n",
" [0.519762 0.698 0.364592 0.593321 0.299263]\n",
" [0.330883 0.566039 0.327606 0.069224 0.077561]\n",
" [0.591434 0.092411 0.049555 0.729441 0.001867]\n",
" [0.60242 0.36611 0.162999 0.602054 0.684817]\n",
" [0.545608 0.415636 0.746867 0.923219 0.67769 ]\n",
" [0.809501 0.496377 0.527514 0.333276 0.479529]\n",
" [0.080732 0.63581 0.950788 0.387371 0.570476]\n",
" [0.677467 0.620451 0.702335 0.071747 0.067357]\n",
" [0.66082 0.372642 0.226082 0.687941 0.761832]])"
]
},
"execution_count": null,
Expand Down Expand Up @@ -210,11 +211,11 @@
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[ 0.439627 -0.765065 -0.323372 -1.518769 -0.563107]\n",
" [-0.255756 0.96025 1.512503 -0.662302 -1.201184]\n",
" [ 0.650412 0.263193 1.310423 1.383127 1.237785]\n",
" [-0.008076 0.028429 1.874965 0.977454 -0.068408]\n",
" [-1.75604 0.546302 0.359429 0.864159 1.347796]])"
"[[ 0.934699 -1.883731 1.56695 0.929079 0.73024 ]\n",
" [ 2.066303 0.109121 1.161415 -1.184726 -1.753147]\n",
" [ 0.339952 -1.125624 -0.740886 -0.808628 0.024874]\n",
" [-0.307566 1.072183 0.013086 0.407447 -0.705648]\n",
" [-0.956348 -0.291481 -0.1 0.70653 0.500862]])"
]
},
"execution_count": null,
Expand Down Expand Up @@ -605,16 +606,16 @@
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[-0.362476 0.425931 -0.604767 -0.585322 -0.469232]\n",
" [-0.207212 -0.106902 -0.349397 -0.082876 0.145293]\n",
" [ 0.274891 -0.136795 -0.210589 0.275138 0.481523]\n",
" [ 0.258619 0.307068 0.458597 0.236569 0.462849]\n",
" [ 0.029899 -0.559897 0.514509 0.32062 0.208706]\n",
" [-0.488839 -0.433476 0.089545 0.466521 -0.407354]\n",
" [-0.243207 0.266691 0.27616 0.263078 -0.267017]\n",
" [-0.61643 -0.143201 0.083898 0.366265 0.022065]\n",
" [ 0.149755 -0.155406 -0.494278 0.481983 -0.509169]\n",
" [ 0.159343 0.597055 -0.36376 0.376093 -0.399417]])"
"[[-0.221499 -0.06155 -0.077118 0.56846 -0.418471]\n",
" [-0.149945 -0.553442 0.581115 -0.460948 -0.420142]\n",
" [-0.355427 -0.066154 0.355814 0.082557 -0.556673]\n",
" [-0.497098 -0.087087 -0.051234 -0.238323 -0.290452]\n",
" [-0.280464 0.334714 0.116377 -0.481387 -0.388613]\n",
" [ 0.110984 0.625096 -0.228138 -0.500467 0.502594]\n",
" [-0.400704 0.197745 0.166157 -0.479262 0.577242]\n",
" [-0.196405 -0.577416 -0.605291 -0.294985 -0.606795]\n",
" [ 0.321031 -0.098246 0.278399 0.047973 0.295106]\n",
" [-0.385705 0.34554 -0.519177 0.389492 0.040751]])"
]
},
"execution_count": null,
Expand Down Expand Up @@ -646,16 +647,16 @@
"data": {
"text/plain": [
"minima.Tensor(\n",
"[[ 0.176985 0.475322 -0.472049 -0.197204 -0.124735]\n",
" [-0.021712 0.164783 -0.76619 -0.109438 0.25406 ]\n",
" [ 0.747072 0.355753 -0.609391 0.107684 -0.27346 ]\n",
" [ 0.418525 -0.260188 -0.422199 -0.179245 -0.042915]\n",
" [ 0.167997 -0.004962 0.17853 -0.520827 0.349572]\n",
" [-0.019444 -0.406027 0.332068 -0.591041 -0.408733]\n",
" [-0.535821 -0.790127 -0.098206 0.25483 0.509668]\n",
" [-0.401008 -0.051625 0.376056 0.1105 -1.083598]\n",
" [ 0.06975 0.191631 0.233941 0.101632 0.235047]\n",
" [-0.034112 -0.465086 0.053479 0.444051 -0.654049]])"
"[[-0.063341 -0.374345 -0.352547 0.789786 -0.782333]\n",
" [ 0.195202 -0.227442 0.252176 0.225021 0.105454]\n",
" [-0.174635 -0.043868 0.195862 -0.917395 0.502817]\n",
" [ 0.177671 0.373282 0.601478 0.593381 -0.1945 ]\n",
" [ 0.056075 0.224731 -0.458342 -0.133477 -0.138557]\n",
" [ 0.393741 0.096873 0.514728 0.195166 -0.260037]\n",
" [-0.161321 -0.056462 0.609632 -0.470343 0.118147]\n",
" [-0.359394 -0.297816 0.383475 0.310443 0.510362]\n",
" [-0.433323 -0.49009 -0.362796 0.263523 -0.023001]\n",
" [-0.156537 0.312429 -0.113006 -0.195525 0.197912]])"
]
},
"execution_count": null,
Expand Down Expand Up @@ -814,6 +815,161 @@
"2. It performs better with ReLU and its variants because it accounts for the fact that the variance of the output of a neuron with a ReLU activation function is half the variance of its input."
]
},
{
"cell_type": "markdown",
"id": "b9febe08-dcf9-4421-9a1f-4e38a193646b",
"metadata": {},
"source": [
"### LSUV Initialization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb0c6a3a-78e7-4b96-b32a-1c311bf9d70a",
"metadata": {},
"outputs": [],
"source": [
"class Hook():\n",
" def __init__(self, layer, fn): self.hook = layer.register_forward_hook(partial(fn, self))\n",
" # def remove(self): self.hook.remove()\n",
" # def __del__(self): self.remove()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5f46d88-3ab2-4bde-bb2e-1026351fb79e",
"metadata": {},
"outputs": [],
"source": [
"def append_stats(hook, mod, inp, outp):\n",
" if not hasattr(hook,'stats'): hook.stats = ([],[])\n",
" acts = outp # TODO: move outp to cpu when USING ACCELERAOR!! :3\n",
" hook.stats[0].append(acts.numpy().mean())\n",
" hook.stats[1].append(acts.numpy().std())\n",
"\n",
"#| export\n",
"def _lsuv_stats(hook, mod, inp, outp):\n",
" acts = outp\n",
" hook.mean = acts.numpy().mean()\n",
" hook.std = acts.numpy().std()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb14b7e1-9c0a-48fe-957a-d6f297059c28",
"metadata": {},
"outputs": [],
"source": [
"class LSUV:\n",
"\n",
" def __init__(self, model, batch) -> None:\n",
" self.model = model\n",
" self.batch = batch\n",
" self.params_layers = [m for m in model if hasattr(m, 'weight') and not isinstance(m, mi.nn.BatchNorm1d)]\n",
" self.act_fns = [m for m in model if isinstance(m, mi.nn.ReLU)]\n",
" \n",
" # Constants\n",
" self.TOLERANCE = 1e-3\n",
" \n",
"\n",
" def lsuv_init(self):\n",
" \"\"\"\n",
" Layer-wise Sequential Unit Variance Initialization (LSUV).\n",
" A method to help neural nets converge faster.\n",
" \n",
" Args:\n",
" model : the model on which to perform LSUV initialization\n",
" param_module : the module with trainable parameters to which the Hook is to be registered\n",
" activation_module : the activation module to be initialized (ReLU, Sigmoid, etc.)\n",
" input_data : input data to be passed through the model\n",
" \"\"\"\n",
" for params_layer, acts_layer in zip(self.params_layers, self.act_fns):\n",
" hook = Hook(acts_layer, _lsuv_stats)\n",
" while self.model(self.batch) is not None and (abs(hook.std-1) > self.TOLERANCE or abs(hook.mean) > self.TOLERANCE):\n",
" print(f'---> before: {hook.mean} -- {hook.std}')\n",
" if params_layer.bias is not None: params_layer.bias -= mi.Tensor(hook.mean)\n",
" params_layer.weight.data /= mi.Tensor(hook.std)\n",
" print(f'-------------> after: {hook.mean} -- {hook.std}')\n",
" hook.remove()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "354013dd-f829-4860-8bd3-e409639fe8de",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import minima as mi\n",
"\n",
"# Number of samples\n",
"n_samples = 1000\n",
"\n",
"# Number of features (28x28 pixels for a grayscale image)\n",
"n_features = 784\n",
"\n",
"# Number of classes\n",
"n_classes = 10\n",
"\n",
"# Generate random inputs from a standard normal distribution\n",
"X = mi.init.randn(n_samples, n_features)\n",
"\n",
"# Generate random target classes\n",
"y = mi.Tensor(np.random.randint(0, n_classes, size=n_samples))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3f22936-017b-4ebc-977d-f969041f883b",
"metadata": {},
"outputs": [],
"source": [
"# Define the neural network architecture\n",
"model = mi.nn.Sequential(\n",
" mi.nn.Linear(784, 128),\n",
" mi.nn.ReLU(),\n",
" mi.nn.Linear(128, 10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3db16ba0-a9ac-4762-bea2-b2811a19da6a",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "'Sequential' object is not iterable",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m lsuv \u001b[38;5;241m=\u001b[39m \u001b[43mLSUV\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[26], line 6\u001b[0m, in \u001b[0;36mLSUV.__init__\u001b[0;34m(self, model, batch)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m model\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch \u001b[38;5;241m=\u001b[39m batch\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams_layers \u001b[38;5;241m=\u001b[39m [m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m model \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mweight\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(m, mi\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mBatchNorm1d)]\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fns \u001b[38;5;241m=\u001b[39m [m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m model \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(m, mi\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mReLU)]\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# Constants\u001b[39;00m\n",
"\u001b[0;31mTypeError\u001b[0m: 'Sequential' object is not iterable"
]
}
],
"source": [
"lsuv = LSUV(model, X)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eaabe69c-6afc-4203-8dde-89edd339ba1f",
"metadata": {},
"outputs": [],
"source": [
"# lsuv.lsuv_init()"
]
},
{
"cell_type": "markdown",
"id": "1db18cd9-190f-4103-9866-6e7fe7f4863a",
Expand Down
Loading

0 comments on commit 43569ca

Please sign in to comment.