<a href="https://colab.research.google.com/github/imstaHub/hanghae99/blob/master/week3_basic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Week3 Basic Homework

In [None]:
!pip install tqdm boto3 requests regex sentencepiece sacremoses datasets

In [None]:
import torch

# Device selection: prefer Apple-Silicon MPS, then CUDA, otherwise fall back to the CPU.
# torch.cuda.is_available() checks that a usable GPU is present at runtime, whereas
# torch.backends.cuda.is_built() only reports whether this PyTorch binary was compiled
# with CUDA support and would select "cuda" on a machine that has no GPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

In [None]:
def accuracy(model, dataloader, device):
  """Return the fraction of samples in `dataloader` that `model` classifies correctly.

  The model's train/eval mode is not changed here; the caller is expected to
  set it (e.g. `model.eval()`) before calling.

  Args:
    model: Callable mapping a batch of inputs to per-class scores, with the
      class axis last (the prediction is the argmax over that axis).
    dataloader: Iterable yielding `(inputs, labels)` tensor pairs.
    device: Device the batches are moved to before the forward pass.

  Returns:
    Number of correct predictions divided by the number of labels seen, as a
    float. Returns 0.0 when the dataloader yields no samples.
  """
  cnt = 0
  acc = 0

  # Scoring never needs gradients, so disable autograd regardless of the caller's context.
  with torch.no_grad():
    for inputs, labels in dataloader:
      inputs, labels = inputs.to(device), labels.to(device)

      preds = torch.argmax(model(inputs), dim=-1)

      cnt += labels.shape[0]
      acc += (labels == preds).sum().item()

  # Guard against an empty dataloader instead of raising ZeroDivisionError.
  if cnt == 0:
    return 0.0

  return acc / cnt

## [MY CODE] tokenizer, dataset, model 로드
- tokenizer: distilbert
- dataset: fancyzhx/ag_news
- model: distilbert

In [None]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

# Fetch the pretrained tokenizer for distilbert-base-uncased through torch.hub
tokenizer = torch.hub.load(
    'huggingface/pytorch-transformers',
    'tokenizer',
    'distilbert-base-uncased',
)

In [None]:
# AG News dataset; the splits used below are ds['train'] and ds['test']
ds = load_dataset(path="fancyzhx/ag_news")


def collate_fn(batch):
  """Collate dataset rows into a token-id tensor and a label tensor.

  Each row is read through its 'text' and 'label' keys. The texts are
  tokenized together, padded to the longest sequence in the batch and
  truncated to at most 512 tokens. Per the original author's note, 512 was
  chosen to match the input size of the model's positional-encoding layer,
  and no sample in the data exceeded that length.

  Args:
    batch: Sequence of rows, each indexable by 'text' and 'label'.

  Returns:
    Tuple `(texts, labels)` of LongTensors: the padded token ids and the
    class labels, in batch order.
  """
  raw_texts = [row['text'] for row in batch]
  raw_labels = [row['label'] for row in batch]

  encoded = tokenizer(raw_texts, padding=True, truncation=True, max_length=512)

  return torch.LongTensor(encoded.input_ids), torch.LongTensor(raw_labels)


# Both loaders share the batch size and collate function; only the training split is shuffled.
loader_kwargs = dict(batch_size=64, collate_fn=collate_fn)

train_loader = DataLoader(ds['train'], shuffle=True, **loader_kwargs)
test_loader = DataLoader(ds['test'], shuffle=False, **loader_kwargs)

In [None]:
# Load the pretrained DistilBERT model and display it to inspect its layers and parameters
model = torch.hub.load(
    'huggingface/pytorch-transformers',
    'model',
    'distilbert-base-uncased',
)
model

## [MY CODE] dataset 확인

In [None]:
# News classification dataset: show the summary of the training split
ds['train']

In [None]:
# News contents: preview the first 10 articles.
# Slice the rows first so only 10 rows are read instead of the whole 'text' column.
ds['train'][:10]['text']

In [None]:
# News categories: number of training samples per label.
# The label column is read once and counted in a single pass; Counter keeps
# first-seen order, so the labels print in the same order as before.
from collections import Counter

label_counts = Counter(ds['train']['label'])
for i, cnt in label_counts.items():
    print(f'{i} : {cnt}')

In [None]:
# Check the maximum token length over the training set
tmp_list = []
for inputs, labels in train_loader:
  # Token id 0 is treated as padding here, so the per-row count of non-zero ids
  # is that sample's token length; compute it for the whole batch at once.
  tmp_list.extend((inputs != 0).sum(dim=1).tolist())

print(max(tmp_list))

## [MY CODE] model 정의, 다중분류

In [None]:
from torch import nn
from torch.optim import Adam
import numpy as np

class TextClassifier(nn.Module):
  """Pretrained DistilBERT encoder followed by a linear classification head.

  The prediction is computed from the encoder's last hidden state at the first
  token position of each sequence.

  Args:
    output_dims: Number of classes, i.e. the size of the output logits.
    pad_token_id: Token id used for padding in the input batches. Positions
      holding this id are masked out of the encoder's attention. Defaults to
      0, the [PAD] id of distilbert-base-uncased.
  """

  def __init__(self, output_dims, pad_token_id=0):
    super().__init__()

    self.encoder = torch.hub.load('huggingface/pytorch-transformers', 'model', 'distilbert-base-uncased')
    # One logit per news category; 768 is expected to match the encoder's hidden size.
    self.classifier = nn.Linear(768, output_dims)
    self.pad_token_id = pad_token_id

  def forward(self, x):
    """Map a batch of padded token ids `x` (batch, seq) to class logits (batch, output_dims)."""
    # Without a mask the encoder attends to padding tokens as well, so the
    # first-token feature would depend on how much padding the batch happens to contain.
    attention_mask = (x != self.pad_token_id).long()

    x = self.encoder(x, attention_mask=attention_mask)['last_hidden_state']
    x = self.classifier(x[:, 0])

    return x

# The classifier needs one output per distinct label found in the training split
output_dims = len(set(ds['train']['label']))
model = TextClassifier(output_dims=output_dims)

# Freeze the encoder (everything up to the last hidden state); only the head keeps gradients
model.encoder.requires_grad_(False)

model = model.to(device)

# Multi-class classification, so the loss is cross-entropy
loss_fn = nn.CrossEntropyLoss()

lr = 0.001
optimizer = Adam(params=model.parameters(), lr=lr)

## [MY CODE] 학습 진행

In [None]:
import time

n_epochs = 10

for epoch in range(n_epochs):
  model.train()
  total_loss = 0.

  # Time one full pass over the training data
  start = time.time()
  for inputs, labels in train_loader:
    inputs = inputs.to(device)
    labels = labels.to(device).long()

    # Clear gradients left over from the previous step before the new backward pass
    model.zero_grad()
    preds = model(inputs)
    loss = loss_fn(preds, labels)
    loss.backward()
    optimizer.step()

    # Accumulate the summed (not averaged) batch loss for the epoch report
    total_loss += loss.item()

  end = time.time() - start
  print(f"Epoch {epoch:3d} | Time : {end} | Train Loss: {total_loss}")


In [None]:
# Switch to evaluation mode, then score both splits without tracking gradients
model.eval()
with torch.no_grad():
  train_acc = accuracy(model, train_loader, device)
  test_acc = accuracy(model, test_loader, device)

print(f"=========> Train acc: {train_acc:.3f} | Test acc: {test_acc:.3f}")

## [LOG] train, test 둘 다 정확도가 89%로, 학습이 잘 되었다.