<a href="https://colab.research.google.com/github/hanghae-plus-AI/AI-1-ssungz789/blob/main/w5/Text_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 실습: Zero-shot Classification

이번 실습에서는 open LLM을 가지고 zero-shot classification을 해봅니다. 먼저 필요한 library들을 설치합시다.

In [1]:
!pip -q install datasets

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/471.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━[0m [32m297.0/471.6 kB[0m [31m10.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/194.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━

그 다음 Gemma-2B를 사용하기 위해 다음과 같은 작업을 진행합니다:
1. huggingface.co 계정 만들고 로그인하기
2. https://www.kaggle.com/models/google/gemma/license/consent 에서 Gemma license 동의하기
3. 홈 화면으로 돌아와, `Profile > Settings > Access Tokens` 메뉴로 들어와 "Write" type의 token 생성하기
4. 생성한 토큰을 아래 "HF TOKEN"에 불여넣고 셀을 실행하기.

In [2]:
from huggingface_hub import login
from google.colab import userdata


login(userdata.get('HF_TOKEN'))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


정상적으로 token을 생성하고 Gemma license에 동의했다면 아래 코드로 tokenizer와 Gemma-2B 모델을 불러올 수 있습니다.

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [4]:
import torch

def zero_shot_classification(text, task_description, labels):  # text는 주어진 입력, task_description은 task에 대한 설명, labels은 class들을 text로 변환한 결과입니다.
    text_ids = tokenizer(task_description + text, return_tensors="pt").to("cuda")  # 먼저 task_description과 text를 이어붙인 후, tokenize합니다.
    probs = []
    for label in labels:  # 그 다음 각 text화된 label들을 tokenize하고 입력에 이어붙인 후, Gemma-2B에 넣어줍니다.
        label_ids = tokenizer(label, return_tensors="pt").to("cuda")
        n_label_tokens = label_ids['input_ids'].shape[-1] - 1  # text로 변환한 label의 token 수를 계산합니다.
        input_ids = {
            'input_ids': torch.concatenate([text_ids['input_ids'], label_ids['input_ids'][:, 1:]], axis=-1),  # concatenate 명령어를 통해 이어붙이는 모습입니다.
            'attention_mask': torch.concatenate([text_ids['attention_mask'], label_ids['attention_mask'][:, 1:]], axis=-1)
        }

        logits = model(**input_ids).logits  # Logit을 계산한 모습입니다.
        prob = 0
        n_total = input_ids['input_ids'].shape[-1]
        for i in range(n_label_tokens, 0, -1):  # 일반적으로 text로 변환한 label은 여러 token으로 이루어져있습니다. 이러한 label에 대한 logit은 구성하는 모든 token들의 logit들의 합으로 정의합니다.
            token = label_ids['input_ids'][0, i].item()
            prob += logits[0, n_total - i, token].item()
        probs.append(prob)

        del input_ids
        del logits
        torch.cuda.empty_cache()  # 위의 del과 empty_cache() 명령어를 통해 GPU를 제때 할당해제 해줍니다. 만약 GPU가 여유롭다면 지워주시는게 속도적으로 이득입니다.

    return probs

아래는 실제로 zero-shot classification을 해본 결과입니다.

In [5]:
probs = zero_shot_classification("I am happy!", "Is the sentence positive or negative?: ", ["positive", "negative"])
print(probs)

[-4.5151824951171875, -9.59005069732666]


보시다시피 우리는 Gemma를 별도로 학습하지 않았음에도 불구하고 주어진 문장이 긍정적이라는 것을 정확하게 예측하고 있습니다.

다음은 영화 리뷰 감정 분석 task에 적용해봅시다.
먼저 data를 불러옵니다.

In [41]:
from datasets import load_dataset
import random


#imdb = load_dataset("fancyzhx/ag_news", download_mode="force_redownload", split="test[:50]") ## ag_news dataset load

imdb = load_dataset("fancyzhx/ag_news")
shuffle_test = imdb["test"].shuffle(seed=42)
imdb = shuffle_test.select(range(50))  # 궁금해서 셔플을 추가했습니다.

#print(imdb)

#print("Dataset splits:", imdb.keys())
label_names = dataset['train'].features['label'].names
print("Label names:", label_names)
#label_names = imdb['train'].features["label"].names
#label_names = ["World", "Sports", "Business", "Science/Technology"]

def preprocess_function(examples):
    return tokenizer(examples["text"], max_length=200, truncation=True)

tokenized_imdb = imdb.map(preprocess_function, batched=True)

Label names: ['World', 'Sports', 'Business', 'Sci/Tech']


Map:   0%|          | 0/50 [00:00<?, ? examples/s]

그리고 `test` data에서 50개의 영화 리뷰에 대해 예측하는 코드는 다음과 같습니다.

In [44]:
import numpy as np
from tqdm import tqdm


n_corrects = 0
#task_description = f"Classify the following news article into one of these categories: {', '.join(label_names)}. "
label_descriptions = {
    "World": "News about international events, politics, and global affairs",
    "Sports": "News about various sports, athletes, and sporting events",
    "Business": "News about economics, companies, markets, and financial affairs",
    "Science/Technology": "News about scientific discoveries, technological advancements, and innovations"
}

task_description = f"""Classify the following news article into one of these categories:
{' | '.join([f"{k}: {v}" for k, v in label_descriptions.items()])}
Article: """

for i in tqdm(range(50)):
    text = imdb[i]['text']
    label = imdb[i]['label']
    probs = zero_shot_classification(
        text,
        task_description,
        label_names
        #labels = [f"Answer: {label_names[0]}", f"Answer: {label_names[1]}", f"Answer: {label_names[2]}", f"Answer: {label_names[3]}"]
    )

    pred = np.argmax(np.array(probs))
    if pred == label:
        n_corrects += 1

    print(f"Example {i+1}:")
    print(f"True label: {list(label_descriptions.keys())[label]}")
    print(f"Predicted label: {label_names[pred]}")
    print(f"Correct: {pred == label}\n")

print(n_corrects)
print(f"Accuracy: {n_corrects / 50:.2f}")

  2%|▏         | 1/50 [00:00<00:45,  1.08it/s]

Example 1:
True label: Sports
Predicted label: Sports
Correct: True



  4%|▍         | 2/50 [00:01<00:41,  1.17it/s]

Example 2:
True label: Business
Predicted label: Sports
Correct: False



  6%|▌         | 3/50 [00:02<00:40,  1.16it/s]

Example 3:
True label: Sports
Predicted label: World
Correct: False



  8%|▊         | 4/50 [00:03<00:38,  1.19it/s]

Example 4:
True label: Business
Predicted label: Sports
Correct: False



 10%|█         | 5/50 [00:04<00:43,  1.03it/s]

Example 5:
True label: Science/Technology
Predicted label: Sci/Tech
Correct: True



 12%|█▏        | 6/50 [00:05<00:42,  1.04it/s]

Example 6:
True label: World
Predicted label: Sci/Tech
Correct: False



 14%|█▍        | 7/50 [00:06<00:44,  1.03s/it]

Example 7:
True label: Business
Predicted label: World
Correct: False



 16%|█▌        | 8/50 [00:07<00:45,  1.07s/it]

Example 8:
True label: Business
Predicted label: Sci/Tech
Correct: False



 18%|█▊        | 9/50 [00:09<00:44,  1.09s/it]

Example 9:
True label: Science/Technology
Predicted label: Sports
Correct: False



 20%|██        | 10/50 [00:10<00:44,  1.11s/it]

Example 10:
True label: Business
Predicted label: Sports
Correct: False



 22%|██▏       | 11/50 [00:11<00:40,  1.03s/it]

Example 11:
True label: World
Predicted label: Sports
Correct: False



 24%|██▍       | 12/50 [00:12<00:40,  1.07s/it]

Example 12:
True label: Sports
Predicted label: Sci/Tech
Correct: False



 26%|██▌       | 13/50 [00:13<00:40,  1.10s/it]

Example 13:
True label: Sports
Predicted label: Sports
Correct: True



 28%|██▊       | 14/50 [00:14<00:40,  1.11s/it]

Example 14:
True label: World
Predicted label: Sci/Tech
Correct: False



 30%|███       | 15/50 [00:15<00:36,  1.04s/it]

Example 15:
True label: Business
Predicted label: Sci/Tech
Correct: False



 32%|███▏      | 16/50 [00:16<00:37,  1.09s/it]

Example 16:
True label: Science/Technology
Predicted label: Sci/Tech
Correct: True



 34%|███▍      | 17/50 [00:17<00:34,  1.04s/it]

Example 17:
True label: World
Predicted label: World
Correct: True



 36%|███▌      | 18/50 [00:18<00:31,  1.01it/s]

Example 18:
True label: Sports
Predicted label: Sci/Tech
Correct: False



 38%|███▊      | 19/50 [00:19<00:29,  1.04it/s]

Example 19:
True label: World
Predicted label: Sci/Tech
Correct: False



 40%|████      | 20/50 [00:20<00:30,  1.03s/it]

Example 20:
True label: Science/Technology
Predicted label: Sci/Tech
Correct: True



 42%|████▏     | 21/50 [00:21<00:31,  1.08s/it]

Example 21:
True label: World
Predicted label: Sports
Correct: False



 44%|████▍     | 22/50 [00:22<00:28,  1.01s/it]

Example 22:
True label: Science/Technology
Predicted label: World
Correct: False



 46%|████▌     | 23/50 [00:23<00:26,  1.02it/s]

Example 23:
True label: Sports
Predicted label: Sports
Correct: True



 48%|████▊     | 24/50 [00:24<00:27,  1.05s/it]

Example 24:
True label: Sports
Predicted label: World
Correct: False



 50%|█████     | 25/50 [00:25<00:25,  1.01s/it]

Example 25:
True label: World
Predicted label: Sports
Correct: False



 52%|█████▏    | 26/50 [00:26<00:23,  1.02it/s]

Example 26:
True label: Business
Predicted label: World
Correct: False



 54%|█████▍    | 27/50 [00:27<00:24,  1.04s/it]

Example 27:
True label: Business
Predicted label: Sci/Tech
Correct: False



 56%|█████▌    | 28/50 [00:28<00:21,  1.00it/s]

Example 28:
True label: Business
Predicted label: World
Correct: False



 58%|█████▊    | 29/50 [00:29<00:22,  1.06s/it]

Example 29:
True label: Business
Predicted label: Sci/Tech
Correct: False



 60%|██████    | 30/50 [00:30<00:20,  1.04s/it]

Example 30:
True label: Science/Technology
Predicted label: Sci/Tech
Correct: True



 62%|██████▏   | 31/50 [00:31<00:19,  1.00s/it]

Example 31:
True label: Business
Predicted label: Sports
Correct: False



 64%|██████▍   | 32/50 [00:32<00:17,  1.05it/s]

Example 32:
True label: Science/Technology
Predicted label: Sci/Tech
Correct: True



 66%|██████▌   | 33/50 [00:33<00:15,  1.08it/s]

Example 33:
True label: Business
Predicted label: Sci/Tech
Correct: False



 68%|██████▊   | 34/50 [00:34<00:16,  1.00s/it]

Example 34:
True label: Business
Predicted label: Sports
Correct: False



 70%|███████   | 35/50 [00:35<00:14,  1.02it/s]

Example 35:
True label: Science/Technology
Predicted label: Sports
Correct: False



 72%|███████▏  | 36/50 [00:36<00:13,  1.04it/s]

Example 36:
True label: World
Predicted label: Sports
Correct: False



 74%|███████▍  | 37/50 [00:37<00:12,  1.06it/s]

Example 37:
True label: World
Predicted label: Sci/Tech
Correct: False



 76%|███████▌  | 38/50 [00:38<00:11,  1.08it/s]

Example 38:
True label: Business
Predicted label: Sci/Tech
Correct: False



 78%|███████▊  | 39/50 [00:39<00:10,  1.08it/s]

Example 39:
True label: World
Predicted label: Sports
Correct: False



 80%|████████  | 40/50 [00:40<00:10,  1.03s/it]

Example 40:
True label: Sports
Predicted label: Sci/Tech
Correct: False



 82%|████████▏ | 41/50 [00:41<00:08,  1.01it/s]

Example 41:
True label: Sports
Predicted label: Sports
Correct: True



 84%|████████▍ | 42/50 [00:42<00:07,  1.03it/s]

Example 42:
True label: World
Predicted label: Sports
Correct: False



 86%|████████▌ | 43/50 [00:43<00:06,  1.04it/s]

Example 43:
True label: World
Predicted label: Sci/Tech
Correct: False



 88%|████████▊ | 44/50 [00:44<00:06,  1.07s/it]

Example 44:
True label: Science/Technology
Predicted label: Sci/Tech
Correct: True



 90%|█████████ | 45/50 [00:45<00:05,  1.02s/it]

Example 45:
True label: World
Predicted label: Sci/Tech
Correct: False



 92%|█████████▏| 46/50 [00:46<00:04,  1.10s/it]

Example 46:
True label: Sports
Predicted label: Sports
Correct: True



 94%|█████████▍| 47/50 [00:47<00:03,  1.05s/it]

Example 47:
True label: Business
Predicted label: World
Correct: False



 96%|█████████▌| 48/50 [00:48<00:02,  1.12s/it]

Example 48:
True label: Sports
Predicted label: World
Correct: False



 98%|█████████▊| 49/50 [00:50<00:01,  1.16s/it]

Example 49:
True label: Business
Predicted label: Sports
Correct: False



100%|██████████| 50/50 [00:51<00:00,  1.02s/it]

Example 50:
True label: Sports
Predicted label: Sports
Correct: True

13
Accuracy: 0.26



