From b9e3e58471f2bdde80709230bd7340a1fc5014ca Mon Sep 17 00:00:00 2001 From: Mario Sessa <76259752+kode-git@users.noreply.github.com> Date: Mon, 9 May 2022 16:06:20 +0200 Subject: [PATCH] Debug closure function for SAM on loading mode in train_model --- ViT_Face_Emotion_Recognition.ipynb | 186 +++++++++++------------------ 1 file changed, 68 insertions(+), 118 deletions(-) diff --git a/ViT_Face_Emotion_Recognition.ipynb b/ViT_Face_Emotion_Recognition.ipynb index 4897383..b181fcc 100644 --- a/ViT_Face_Emotion_Recognition.ipynb +++ b/ViT_Face_Emotion_Recognition.ipynb @@ -72,7 +72,7 @@ "base_uri": "https://localhost:8080/" }, "id": "TBLS3rKGcIUm", - "outputId": "2aa30f9d-9da2-4666-d6b9-b0e6155df75d" + "outputId": "6d4ba545-2fdd-4aee-9b42-40b10486e043" }, "outputs": [ { @@ -81,9 +81,9 @@ "text": [ "Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (7.1.2)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (1.3.5)\n", - "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (2022.1)\n", - "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (2.8.2)\n", "Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (1.21.6)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas) (2022.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas) (1.15.0)\n" ] } @@ -140,13 +140,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0WafVw77sp2v", - "outputId": "b4999e84-35f9-4f94-ec0a-edc7108088f0" + "outputId": "4ca677a2-57c2-462b-86bf-c8dd674fd681" }, "outputs": [ { @@ -175,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "id": "mMiGrpYXLFNP" }, @@ -261,7 +261,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "id": "NfjxfLlujLZi" }, @@ -3811,13 +3811,13 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BItVISfStbBL", - "outputId": "bd87bb1a-7c63-4560-e390-5c6219c93aa5" + "outputId": "60863feb-4e78-4642-bce6-2279f1e2f1fc" }, "outputs": [ { @@ -3826,17 +3826,17 @@ "text": [ "Collecting timm\n", " Downloading timm-0.5.4-py3-none-any.whl (431 kB)\n", - "\u001b[?25l\r\u001b[K |▊ | 10 kB 34.6 MB/s eta 0:00:01\r\u001b[K |█▌ | 20 kB 42.8 MB/s eta 0:00:01\r\u001b[K |██▎ | 30 kB 49.4 MB/s eta 0:00:01\r\u001b[K |███ | 40 kB 34.8 MB/s eta 0:00:01\r\u001b[K |███▉ | 51 kB 38.4 MB/s eta 0:00:01\r\u001b[K |████▋ | 61 kB 43.6 MB/s eta 0:00:01\r\u001b[K |█████▎ | 71 kB 29.8 MB/s eta 0:00:01\r\u001b[K |██████ | 81 kB 31.5 MB/s eta 0:00:01\r\u001b[K |██████▉ | 92 kB 34.2 MB/s eta 0:00:01\r\u001b[K |███████▋ | 102 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████▍ | 112 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████▏ | 122 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████▉ | 133 kB 36.6 MB/s eta 0:00:01\r\u001b[K |██████████▋ | 143 kB 36.6 MB/s eta 0:00:01\r\u001b[K |███████████▍ | 153 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████▏ | 163 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████████ | 174 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████████▊ | 184 kB 36.6 MB/s eta 0:00:01\r\u001b[K |██████████████▍ | 194 kB 36.6 MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 204 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████ | 215 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 225 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████████████▌ | 235 kB 36.6 MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 245 kB 36.6 MB/s eta 0:00:01\r\u001b[K |███████████████████ | 256 kB 36.6 MB/s eta 0:00:01\r\u001b[K |███████████████████▊ | 266 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████████▌ | 276 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████████████████▎ | 286 kB 36.6 MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 296 kB 36.6 MB/s eta 0:00:01\r\u001b[K |██████████████████████▉ | 307 kB 36.6 MB/s eta 0:00:01\r\u001b[K |███████████████████████▌ | 317 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 327 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 337 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▉ | 348 kB 36.6 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▋ | 358 kB 36.6 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▍ | 368 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 378 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▉ | 389 kB 36.6 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 399 kB 36.6 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 409 kB 36.6 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▏| 419 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 430 kB 36.6 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 431 kB 36.6 MB/s \n", - "\u001b[?25hRequirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from timm) (0.12.0+cu113)\n", - "Requirement already satisfied: torch>=1.4 in /usr/local/lib/python3.7/dist-packages (from timm) (1.11.0+cu113)\n", + "\u001b[?25l\r\u001b[K |▊ | 10 kB 33.4 MB/s eta 0:00:01\r\u001b[K |█▌ | 20 kB 20.5 MB/s eta 0:00:01\r\u001b[K |██▎ | 30 kB 10.7 MB/s eta 0:00:01\r\u001b[K |███ | 40 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███▉ | 51 kB 4.2 MB/s eta 0:00:01\r\u001b[K |████▋ | 61 kB 5.0 MB/s eta 0:00:01\r\u001b[K |█████▎ | 71 kB 5.5 MB/s eta 0:00:01\r\u001b[K |██████ | 81 kB 5.4 MB/s eta 0:00:01\r\u001b[K |██████▉ | 92 kB 6.1 MB/s eta 0:00:01\r\u001b[K |███████▋ | 102 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████▍ | 112 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████▏ | 122 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████▉ | 133 kB 5.1 MB/s eta 0:00:01\r\u001b[K |██████████▋ | 143 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████████▍ | 153 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████▏ | 163 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████████ | 174 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████████▊ | 184 kB 5.1 MB/s eta 0:00:01\r\u001b[K |██████████████▍ | 194 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 204 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████ | 215 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 225 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████████████▌ | 235 kB 5.1 MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 245 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████████████████ | 256 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████████████████▊ | 266 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████████▌ | 276 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████████████████▎ | 286 kB 5.1 MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 296 kB 5.1 MB/s eta 0:00:01\r\u001b[K |██████████████████████▉ | 307 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████████████████████▌ | 317 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 327 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 337 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▉ | 348 kB 5.1 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▋ | 358 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▍ | 368 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 378 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▉ | 389 kB 5.1 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 399 kB 5.1 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 409 kB 5.1 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▏| 419 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 430 kB 5.1 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 431 kB 5.1 MB/s \n", + "\u001b[?25hRequirement already satisfied: torch>=1.4 in /usr/local/lib/python3.7/dist-packages (from timm) (1.11.0+cu113)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from timm) (0.12.0+cu113)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.4->timm) (4.2.0)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (7.1.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (2.23.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (1.21.6)\n", - "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (7.1.2)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (1.24.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2021.10.8)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2.10)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (3.0.4)\n", - "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (1.24.3)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2.10)\n", "Installing collected packages: timm\n", "Successfully installed timm-0.5.4\n" ] @@ -3848,7 +3848,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": { "id": "rZYiTth-y7yy" }, @@ -3865,7 +3865,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": { "id": "yk-iKrzguAQR" }, @@ -3878,13 +3878,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MW0qScRotgal", - "outputId": "0f23c317-7cc6-4f14-909e-7e8549feb43c" + "outputId": "201c02c2-8896-4aad-ccdb-77d4e54e828d" }, "outputs": [ { @@ -3932,13 +3932,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AkRMzh5uyj6n", - "outputId": "9183e034-941e-4da0-b3c6-631dcda7aff4" + "outputId": "35ee222f-e70e-4e6f-902d-4e43d6f59a6d" }, "outputs": [ { @@ -3966,7 +3966,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": { "id": "OY4wmSSvyR2r" }, @@ -3978,13 +3978,13 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kEh7xW8xycSX", - "outputId": "f5225b8e-f565-41a4-871b-4c87c259d397" + "outputId": "f54c349a-7465-4701-d0c8-fcbf63b8e9dc" }, "outputs": [ { @@ -4222,7 +4222,7 @@ ] }, "metadata": {}, - "execution_count": 16 + "execution_count": 13 } ], "source": [ @@ -4231,13 +4231,13 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "co_WIrPEyeGx", - "outputId": "766c17f0-7212-433a-986f-889786f67537" + "outputId": "0452f6c8-5738-4059-f1bb-ee85cee39d13" }, "outputs": [ { @@ -4475,7 +4475,7 @@ ] }, "metadata": {}, - "execution_count": 17 + "execution_count": 14 } ], "source": [ @@ -4485,7 +4485,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 15, "metadata": { "id": "MfPgq1CyyuRW" }, @@ -4513,6 +4513,7 @@ " is_loaded = False, load_state_ws=None, history_file_acc=\"history_accuracy\",\n", " history_file_loss=\"history_loss\", n_partial=0, model_folder=\"\", ):\n", " \n", + " \n", " history = {'val' : [], 'train' : []}\n", " loss_history = {'val' : [], 'train' : []}\n", "\n", @@ -4569,6 +4570,12 @@ " loss = criterion(outputs, labels)\n", "\n", " _, preds = torch.max(outputs, 1)\n", + " def closure():\n", + " outputs = model(inputs)\n", + " _, preds = torch.max(outputs, 1)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " return loss\n", "\n", " # backward + optimize only if in training phase\n", " if phase == 'train':\n", @@ -4576,14 +4583,7 @@ " if type(optimizer) != SAM:\n", " optimizer.step()\n", " else:\n", - " def closure():\n", - " outputs = model(inputs)\n", - " _, preds = torch.max(outputs, 1)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " return loss\n", - " \n", - " optimizer.step(closure)\n", + " optimizer.step(closure)\n", "\n", " \n", "\n", @@ -4644,9 +4644,9 @@ "base_uri": "https://localhost:8080/" }, "id": "BfLdUVTVLFav", - "outputId": "b0f6911e-8227-427d-8551-e69004c3052b" + "outputId": "7f289f9f-34bb-421b-afd8-83648fbcf662" }, - "execution_count": 19, + "execution_count": 16, "outputs": [ { "output_type": "stream", @@ -4657,7 +4657,7 @@ "remote: Counting objects: 100% (75/75), done.\u001b[K\n", "remote: Compressing objects: 100% (22/22), done.\u001b[K\n", "remote: Total 179 (delta 62), reused 53 (delta 53), pack-reused 104\u001b[K\n", - "Receiving objects: 100% (179/179), 650.16 KiB | 25.01 MiB/s, done.\n", + "Receiving objects: 100% (179/179), 650.16 KiB | 12.75 MiB/s, done.\n", "Resolving deltas: 100% (84/84), done.\n" ] } @@ -4671,18 +4671,18 @@ "metadata": { "id": "DmRq96ZqLHhy" }, - "execution_count": 20, + "execution_count": 17, "outputs": [] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ug7mEpCZzD6L", - "outputId": "11d59bdb-f130-4012-d05f-369330f7aa90" + "outputId": "f003d1fc-9632-4eed-95fe-12404ed72a5e" }, "outputs": [ { @@ -4887,9 +4887,9 @@ "base_uri": "https://localhost:8080/" }, "id": "TxoRhdLgLgjd", - "outputId": "4edbec2d-465d-4662-ef03-b474703fbd49" + "outputId": "6625fd69-bd68-442b-f893-4919570cde8a" }, - "execution_count": 22, + "execution_count": 19, "outputs": [ { "output_type": "execute_result", @@ -4909,13 +4909,13 @@ ] }, "metadata": {}, - "execution_count": 22 + "execution_count": 19 } ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 20, "metadata": { "id": "76ScGC4bVS1m" }, @@ -4958,73 +4958,9 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "vSN5a2TqzHMc", - "outputId": "3180803b-9aaf-43d6-b5b6-f8859d1ccb55" + "id": "vSN5a2TqzHMc" }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Starting Training\n", - "------------\n", - "Epoch 1/10\n", - "------------\n", - "1/10 - train step : 160020/160020 - train_accuracy : 0.546881 - train_loss : 1.190360\n", - "1/10 - val step : 5460/5460 - val_accuracy : 0.495588 - val_loss : 1.354128\n", - "\n", - "Epoch 1 complete in. 220m 6s with best local accuracy and with a learning rate of 0.01\n", - "------------\n", - "Epoch 2/10\n", - "------------\n", - "2/10 - train step : 160020/160020 - train_accuracy : 0.635938 - train_loss : 0.967213\n", - "2/10 - val step : 5460/5460 - val_accuracy : 0.520588 - val_loss : 1.271463\n", - "\n", - "Epoch 2 complete in. 187m 8s with best local accuracy and with a learning rate of 0.01\n", - "------------\n", - "Epoch 3/10\n", - "------------\n", - "3/10 - train step : 160020/160020 - train_accuracy : 0.669750 - train_loss : 0.876207\n", - "3/10 - val step : 5460/5460 - val_accuracy : 0.541544 - val_loss : 1.234829\n", - "\n", - "Epoch 3 complete in. 187m 20s with best local accuracy and with a learning rate of 0.01\n", - "------------\n", - "Epoch 4/10\n", - "------------\n", - "4/10 - train step : 160020/160020 - train_accuracy : 0.700075 - train_loss : 0.802861\n", - "4/10 - val step : 5460/5460 - val_accuracy : 0.550919 - val_loss : 1.214435\n", - "\n", - "Epoch 4 complete in. 187m 0s with best local accuracy and with a learning rate of 0.01\n", - "------------\n", - "Epoch 5/10\n", - "------------\n", - "5/10 - train step : 160020/160020 - train_accuracy : 0.726281 - train_loss : 0.735707\n", - "5/10 - val step : 5460/5460 - val_accuracy : 0.531985 - val_loss : 1.331294\n", - "\n", - "Epoch 5 complete in. 187m 8s and with a learning rate of 0.01\n", - "------------\n", - "Epoch 6/10\n", - "------------\n", - "6/10 - train step : 54300/160020 - train_accuracy : 0.747716 - train_loss : 0.680886" - ] - }, - { - "output_type": "error", - "ename": "KeyboardInterrupt", - "evalue": "ignored", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Train and evaluate\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m model, train_hist, val_hist = train_model(model, dataloaders_dict, criterion, optimizer_ft,scheduler, num_epochs=num_epochs, \n\u001b[0;32m----> 3\u001b[0;31m is_inception=False)\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;31m#Saving the updated model for the inference phase\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, dataloaders, criterion, optimizer, lr_scheduler, num_epochs, is_inception, is_loaded, load_state_ws, history_file_acc, history_file_loss, n_partial, model_folder)\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;31m# statistics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m \u001b[0mrunning_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 99\u001b[0m \u001b[0mrunning_corrects\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0mepoch_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrunning_loss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mtotalIm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "# Train and evaluate\n", "model, train_hist, val_hist = train_model(model, dataloaders_dict, criterion, optimizer_ft,scheduler, num_epochs=num_epochs, \n", @@ -5045,7 +4981,7 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "33dc1a7f-0efa-40a7-9cf4-dbf5d5d8dfc6" + "outputId": "5fdc518c-8bbc-46b1-de35-9caff64c835e" }, "outputs": [ { @@ -5056,7 +4992,8 @@ "Starting Training\n", "------------\n", "Epoch 1/5\n", - "------------\n" + "------------\n", + "1/5 - train step : 37560/160020 - train_accuracy : 0.734878 - train_loss : 0.710255" ] } ], @@ -5073,14 +5010,14 @@ "val_history = model_folder + name_model + \"_\" + \"history_val\"\n", "\n", "# changing starting lr\n", - "lr_in = 0.01\n", + "lr_in = 0.001\n", "optimizer_ft = optim.SGD(model.parameters(), lr=lr_in, momentum=momentum_in)\n", "scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)\n", "\n", "# Train and evaluate\n", "model, train_hist, val_hist = train_model(model, dataloaders_dict, criterion, optimizer_ft,scheduler, num_epochs=num_epochs, \n", " is_inception=False, is_loaded=True, model_folder= model_folder,\n", - " load_state_ws=\"/content/drive/MyDrive/Models/vfer_sam_5/vfer_sam_5\" )\n", + " load_state_ws=\"/content/drive/MyDrive/Models/vfer_sam_5/vfer_sam_10_4\" )\n", "\n", "\n", "#Saving the updated model for the inference phase\n", @@ -5417,7 +5354,20 @@ "accelerator": "GPU", "colab": { "collapsed_sections": [ - "A4NdJRd3L0dY" + "A4NdJRd3L0dY", + "-s6FVgkOL4BI", + "h_g6Johu3PIz", + "vAgvaOcbJeBj", + "ClgI0sjPvH5j", + "fp75RuDu_GVY", + "uMLuH1Ng4GxX", + "fiKUpawZS342", + "LPMuS0gLW0jc", + "KsT90RRpoB3Y", + "MlUAtTmmoI-M", + "bD8RQf6DLk8T", + "Yrza5UhRTQvn", + "nin-hhWX6wX_" ], "machine_shape": "hm", "name": "ViT_Face_Emotion_Recognition.ipynb",