{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/tolluset/hh-ai-1/blob/w3/%08week3/3-basic.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# [Week 3 Basic Assignment] Training a News Article Classification Model with DistilBERT\n"
      ],
      "metadata": {
        "id": "sbgz49PvHhLt"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 89,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1LqgujQUbv6X",
        "outputId": "f8c15a7a-da76-4eb5-cecd-eca412ed759a",
        "collapsed": true
      },
      "outputs": [
        
      ],
      "source": [
        "!pip install tqdm boto3 requests regex sentencepiece sacremoses datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 90,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6lGiZUoPby6e",
        "outputId": "7c7c010f-3741-441d-b2e9-a67f3ae1e2ae"
      },
      "outputs": [

      ],
      "source": [
        "import torch\n",
        "from datasets import load_dataset\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'distilbert-base-uncased')"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## ✅ Preparing the AG_News dataset\n",
        "\n",
        "- ✅ Load `fancyzhx/ag_news` from Hugging Face Datasets.\n",
        "- ✅ Apply the following changes to the `collate_fn` function.\n",
        "    - ✅ Remove the parts related to truncation."
      ],
      "metadata": {
        "id": "r1oVPoix29Yj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "ds = load_dataset(\"fancyzhx/ag_news\")\n",
        "\n",
        "def collate_fn(batch):\n",
        "  # No truncation: each batch is padded to its longest sequence.\n",
        "  texts, labels = [], []\n",
        "  for row in batch:\n",
        "    labels.append(row['label'])\n",
        "    texts.append(row['text'])\n",
        "\n",
        "  encoding = tokenizer(texts, padding=True, return_tensors='pt')\n",
        "\n",
        "  texts, attention_mask = encoding['input_ids'], encoding['attention_mask']\n",
        "  labels = torch.tensor(labels)\n",
        "\n",
        "  return texts, attention_mask, labels\n",
        "\n",
        "\n",
        "train_loader = DataLoader(\n",
        "    ds['train'], batch_size=64, shuffle=True, collate_fn=collate_fn\n",
        ")\n",
        "test_loader = DataLoader(\n",
        "    ds['test'], batch_size=64, shuffle=False, collate_fn=collate_fn\n",
        ")"
      ],
      "metadata": {
        "id": "rE-y8sY9HuwP"
      },
      "execution_count": 91,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 92,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HJaUp2Vob0U-",
        "outputId": "84c7ec28-9be2-4447-92b4-d558a9cb5a35",
        "collapsed": true
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_main\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DistilBertModel(\n",
              "  (embeddings): Embeddings(\n",
              "    (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
              "    (position_embeddings): Embedding(512, 768)\n",
              "    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "    (dropout): Dropout(p=0.1, inplace=False)\n",
              "  )\n",
              "  (transformer): Transformer(\n",
              "    (layer): ModuleList(\n",
              "      (0-5): 6 x TransformerBlock(\n",
              "        (attention): MultiHeadSelfAttention(\n",
              "          (dropout): Dropout(p=0.1, inplace=False)\n",
              "          (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
              "          (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
              "          (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
              "          (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
              "        )\n",
              "        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "        (ffn): FFN(\n",
              "          (dropout): Dropout(p=0.1, inplace=False)\n",
              "          (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
              "          (activation): GELUActivation()\n",
              "        )\n",
              "        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "      )\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 92
        }
      ],
      "source": [
        "model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'distilbert-base-uncased')\n",
        "model"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# ✅ Changing the classifier output, loss function, and accuracy function\n",
        "- ✅ News article classification is a general multi-class classification problem, not binary classification. As in the MNIST assignment, add `nn.CrossEntropyLoss` and adjust the output dimension of `TextClassifier` so that it can solve the task.\n",
        "- Also modify the `accuracy` function so that it measures accuracy for multi-class classification."
      ],
      "metadata": {
        "id": "UIcTzq726b14"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 93,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xW7ETZQzzNp2",
        "outputId": "9a2b1fca-77ab-4f53-e089-6866fe507478"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_main\n"
          ]
        }
      ],
      "source": [
        "from torch import nn\n",
        "\n",
        "\n",
        "class TextClassifier(nn.Module):\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "\n",
        "    self.encoder = torch.hub.load('huggingface/pytorch-transformers', 'model', 'distilbert-base-uncased')\n",
        "    self.classifier = nn.Linear(768, 4)\n",
        "\n",
        "  def forward(self, input_ids, attention_mask):\n",
        "    x = self.encoder(input_ids, attention_mask)['last_hidden_state']\n",
        "    x = self.classifier(x[:, 0])\n",
        "\n",
        "    return x\n",
        "\n",
        "\n",
        "model = TextClassifier()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 94,
      "metadata": {
        "id": "uyTciaPZ0KYo"
      },
      "outputs": [],
      "source": [
        "for param in model.encoder.parameters():\n",
        "  param.requires_grad = False"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# ✅ Reporting the training results\n",
        "- ✅ As in the DistilBERT exercise, print the train loss at every epoch and report the test accuracy of the final model."
      ],
      "metadata": {
        "id": "TPH7Fa9n_p6M"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 95,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XvvaAEwCznt-",
        "outputId": "19ace1f8-69ad-4648-da39-50ae68986ed3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch   0 | Train Loss: 73.79941156506538\n",
            "Epoch   1 | Train Loss: 44.467991292476654\n",
            "Epoch   2 | Train Loss: 37.692022517323494\n",
            "Epoch   3 | Train Loss: 36.21875162422657\n",
            "Epoch   4 | Train Loss: 35.33278389275074\n",
            "Epoch   5 | Train Loss: 33.67221947014332\n",
            "Epoch   6 | Train Loss: 33.545470505952835\n",
            "Epoch   7 | Train Loss: 34.18277934193611\n",
            "Epoch   8 | Train Loss: 31.80599258840084\n",
            "Epoch   9 | Train Loss: 32.547775372862816\n"
          ]
        }
      ],
      "source": [
        "import itertools\n",
        "\n",
        "from torch.optim import Adam\n",
        "\n",
        "\n",
        "lr = 0.001\n",
        "model = model.to('cuda')\n",
        "loss_fn = nn.CrossEntropyLoss()\n",
        "\n",
        "optimizer = Adam(model.parameters(), lr=lr)\n",
        "n_epochs = 10\n",
        "\n",
        "for epoch in range(n_epochs):\n",
        "  total_loss = 0.\n",
        "  model.train()\n",
        "  for data in itertools.islice(train_loader, 100):\n",
        "    model.zero_grad()\n",
        "    inputs, attention_mask, labels = data\n",
        "    inputs, attention_mask, labels = inputs.to('cuda'), attention_mask.to('cuda'), labels.to('cuda')\n",
        "\n",
        "    preds = model(inputs, attention_mask)\n",
        "    loss = loss_fn(preds, labels)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    total_loss += loss.item()\n",
        "\n",
        "  print(f\"Epoch {epoch:3d} | Train Loss: {total_loss}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 96,
      "metadata": {
        "id": "DjphVwXL00E2",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3d9bff1f-7b26-4796-80af-26d8dfaec442"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "=========> Train acc: 0.897 | Test acc: 0.890\n"
          ]
        }
      ],
      "source": [
        "import itertools\n",
        "\n",
        "def accuracy(model, dataloader):\n",
        "  cnt = 0\n",
        "  acc = 0\n",
        "\n",
        "  # Evaluate on at most the first 100 batches to keep the run short.\n",
        "  for data in itertools.islice(dataloader, 100):\n",
        "    inputs, attention_mask, labels = data\n",
        "    inputs, attention_mask, labels = inputs.to('cuda'), attention_mask.to('cuda'), labels.to('cuda')\n",
        "\n",
        "    preds = model(inputs, attention_mask)\n",
        "    # Multi-class prediction: take the index of the largest logit.\n",
        "    preds = torch.argmax(preds, dim=-1)\n",
        "\n",
        "    cnt += labels.shape[0]\n",
        "    acc += (labels == preds).sum().item()\n",
        "\n",
        "  return acc / cnt\n",
        "\n",
        "\n",
        "with torch.no_grad():\n",
        "  model.eval()\n",
        "  train_acc = accuracy(model, train_loader)\n",
        "  test_acc = accuracy(model, test_loader)\n",
        "  print(f\"=========> Train acc: {train_acc:.3f} | Test acc: {test_acc:.3f}\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}