# AG News Classification

### classes
- world
- sports
- business
- science

## imports

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from pytorch_lightning import Trainer, LightningDataModule, LightningModule
from pytorch_lightning.callbacks import EarlyStopping
import argparse

## arguments

In [2]:
# Notebook-wide settings gathered in one namespace so every tunable lives in one place.
args = argparse.Namespace(
    # Directory expected to contain train.csv and test.csv.
    data_path="../dataset/ag-news",
    # Optimisation hyperparameters.
    lr=1e-4,
    max_epochs=200,
    batch_size=128,
)

## 데이터 불러오기

In [3]:
# Read the train and test CSV files from the configured dataset directory.
train_df, test_df = (
    pd.read_csv(os.path.join(args.data_path, file_name))
    for file_name in ("train.csv", "test.csv")
)

# Report the (rows, columns) of each frame before previewing the first rows.
print(f"train_df.shape : {train_df.shape}")
print(f"test_df.shape : {test_df.shape}")

train_df.head()

train_df.shape : (120000, 3)
test_df.shape : (7600, 3)


Unnamed: 0,Class Index,Title,Description
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."


### Title + Description와 Class Index로 분리

In [4]:
# Model input: each row's Title and Description joined by a single space.
train_sentences, test_sentences = (
    frame["Title"].str.cat(frame["Description"], sep=" ")
    for frame in (train_df, test_df)
)

# Targets: the "Class Index" column of each frame, left unchanged.
train_labels, test_labels = train_df["Class Index"], test_df["Class Index"]

train_sentences, train_labels

(0         Wall St. Bears Claw Back Into the Black (Reute...
 1         Carlyle Looks Toward Commercial Aerospace (Reu...
 2         Oil and Economy Cloud Stocks' Outlook (Reuters...
 3         Iraq Halts Oil Exports from Main Southern Pipe...
 4         Oil prices soar to all-time record, posing new...
                                 ...                        
 119995    Pakistan's Musharraf Says Won't Quit as Army C...
 119996    Renteria signing a top-shelf deal Red Sox gene...
 119997    Saban not going to Dolphins yet The Miami Dolp...
 119998    Today's NFL games PITTSBURGH at NY GIANTS Time...
 119999    Nets get Carter from Raptors INDIANAPOLIS -- A...
 Name: Title, Length: 120000, dtype: object,
 0         3
 1         3
 2         3
 3         3
 4         3
          ..
 119995    1
 119996    2
 119997    2
 119998    2
 119999    2
 Name: Class Index, Length: 120000, dtype: int64)

## vocabulary 구성

- train data 사용
- 공백 단위로 word 추출
- 소문자만 이용 : lower()

word_to_id, id_to_word

In [5]:
# Every distinct lowercased, whitespace-delimited token of the training sentences, sorted.
vocabulary = sorted(
    {token.lower() for sentence in train_sentences for token in sentence.split()}
)

# Ids 0 and 1 are reserved for the special tokens; the sorted words follow from id 2.
id_to_word = ["[PAD]", "[UNK]", *vocabulary]
word_to_id = {word: idx for idx, word in enumerate(id_to_word)}

In [6]:
# Number of distinct lowercased words; the "[PAD]" and "[UNK]" entries of word_to_id are not counted here.
print(f"len(vocabulary) : {len(vocabulary)}")

len(vocabulary) : 158715


## Dataset

In [7]:
class SequenceDataset(Dataset):
    """Map-style dataset that turns raw sentences into tensors of word ids.

    Each sentence is split on whitespace, every token is lowercased and looked
    up in ``word_to_id``; tokens missing from the mapping become ``UNK_ID``.

    Args:
        word_to_id: Mapping from lowercased word to integer id.
        sentences: The raw sentences (a pandas Series or any indexable sequence).
        labels: Optional labels aligned with ``sentences``. When omitted, items
            are returned as a bare sequence tensor instead of a
            ``(sequence, label)`` pair.
    """

    # Id substituted for words that are not present in ``word_to_id``.
    UNK_ID = 1

    def __init__(self, word_to_id: dict, sentences: pd.Series, labels: pd.Series = None):
        super().__init__()
        self.sentences = sentences
        self.labels = labels
        self.word_to_id = word_to_id

    @staticmethod
    def _item_at(values, position):
        """Return the element at integer ``position``.

        pandas objects are indexed with ``.iloc`` so that a Series whose index
        is not 0..n-1 (e.g. after filtering or shuffling) is still read by
        position rather than by label.
        """
        if hasattr(values, "iloc"):
            return values.iloc[position]
        return values[position]

    def __getitem__(self, index):
        """Return ``(sequence, label)`` for ``index``, or ``sequence`` alone if unlabeled.

        ``sequence`` is a 1-D ``torch.long`` tensor of word ids and ``label`` a
        0-D ``torch.long`` tensor holding the stored label value unchanged.
        """
        sentence = self._item_at(self.sentences, index)
        word_ids = [
            self.word_to_id.get(token.lower(), self.UNK_ID)
            for token in sentence.split()
        ]
        sequence = torch.tensor(word_ids, dtype=torch.long)

        if self.labels is None:
            return sequence

        label = torch.tensor(self._item_at(self.labels, index), dtype=torch.long)
        return sequence, label

    def __len__(self):
        """Number of samples; falls back to the sentence count when there are no labels."""
        if self.labels is None:
            return len(self.sentences)
        return len(self.labels)

## DataModule

In [8]:
class SequenceDataModule(LightningDataModule):
    """Lightning data module that serves padded word-id batches.

    Args:
        train_sentences: Sentences used for the train/validation split.
        train_labels: Labels aligned with ``train_sentences``.
        test_sentences: Sentences of the held-out test set.
        test_labels: Labels aligned with ``test_sentences``.
        word_to_id: Mapping from lowercased word to integer id.
        batch_size: Batch size used by every dataloader.
        split_seed: Optional seed for the 80/20 train/validation split. With
            the default ``None`` the split draws from the global torch RNG.
    """

    def __init__(self, train_sentences, train_labels, test_sentences, test_labels, word_to_id, batch_size, split_seed=None):
        super().__init__()
        self.train_sentences = train_sentences
        self.train_labels = train_labels
        self.test_sentences = test_sentences
        self.test_labels = test_labels
        self.word_to_id = word_to_id
        self.batch_size = batch_size
        self.split_seed = split_seed

        # Populated lazily by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``"fit"`` and ``"validate"`` both need the train/validation split;
        ``None`` prepares everything. The split is created only once so that
        repeated calls (e.g. fit followed by validate) keep using the same
        validation samples instead of re-drawing them.
        """
        if stage in (None, "fit", "validate") and self.train_dataset is None:
            full_dataset = SequenceDataset(self.word_to_id, self.train_sentences, self.train_labels)
            split_kwargs = {}
            if self.split_seed is not None:
                split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
            self.train_dataset, self.val_dataset = random_split(full_dataset, [0.8, 0.2], **split_kwargs)
        if stage in (None, "test"):
            self.test_dataset = SequenceDataset(self.word_to_id, self.test_sentences, self.test_labels)

    def train_dataloader(self):
        """Shuffled loader over the training part of the split."""
        return DataLoader(self.train_dataset, self.batch_size, shuffle=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        """Loader over the validation part of the split (no shuffling)."""
        return DataLoader(self.val_dataset, self.batch_size, collate_fn=self.collate_fn)

    def test_dataloader(self):
        """Loader over the test set (no shuffling)."""
        return DataLoader(self.test_dataset, self.batch_size, collate_fn=self.collate_fn)

    def collate_fn(self, batch):
        """Combine ``(sequence, label)`` samples into one padded batch.

        Sequences are right-padded with 0 up to the longest sequence in the
        batch and stacked batch-first; labels are stacked into one tensor.

        Returns:
            ``[sequences, labels]``.
        """
        sequences, labels = zip(*batch)
        sequences = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
        labels = torch.stack(labels)

        return [sequences, labels]

## Model

### Embedding + FC

In [9]:
class SequenceModel(LightningModule):
    """Embedding + max-pooling + 3-layer MLP text classifier.

    Args:
        n_vocab: Number of rows of the embedding table (vocabulary size
            including special tokens).
        lr: Learning rate for the Adam optimizer.
        embed_dim: Width of the word embeddings.
        n_classes: Number of output classes.

    Batches are ``(x, y)`` pairs where ``y`` holds 1-based class labels; they
    are shifted to 0-based targets before the loss is computed.
    """

    def __init__(self, n_vocab, lr, embed_dim=1000, n_classes=4):
        super().__init__()
        self.embed = nn.Embedding(n_vocab, embed_dim)
        self.fc1 = nn.Linear(embed_dim, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, n_classes)
        self.lr = lr

    def forward(self, x):
        """Return class logits for a batch of word-id sequences."""
        x = self.embed(x)
        # Max-pool over dim 1 (the sequence axis for batch-first input).
        # NOTE(review): the max runs over every position of x, so any padding
        # ids supplied by the caller take part in the pooling as well — confirm
        # this is intended before masking, since trained weights depend on it.
        x, _ = torch.max(x, dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def _loss_and_accuracy(self, batch):
        """Compute cross-entropy loss and accuracy for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        # 1~4 => 0~3, computed out of place so the caller's batch is left untouched.
        targets = y - 1

        loss = F.cross_entropy(logits, targets)
        # Stay in torch: no per-element Python conversion of the comparison result.
        acc = (torch.argmax(logits, dim=1) == targets).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss, _ = self._loss_and_accuracy(batch)
        self.log("training_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return ``val_loss`` / ``val_acc`` for one batch."""
        loss, acc = self._loss_and_accuracy(batch)
        metrics = {"val_loss": loss, "val_acc": acc}
        self.log_dict(metrics)

        return metrics

    def test_step(self, batch, batch_idx):
        """Log and return ``test_loss`` / ``test_acc`` for one batch."""
        loss, acc = self._loss_and_accuracy(batch)
        metrics = {"test_loss": loss, "test_acc": acc}
        self.log_dict(metrics)

        return metrics

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

## 학습

In [10]:
# Data module over the train/test sentences, sharing the vocabulary built above.
ag_news_data = SequenceDataModule(
    train_sentences,
    train_labels,
    test_sentences,
    test_labels,
    word_to_id,
    args.batch_size,
)

# Freshly initialised model; the embedding table has one row per word_to_id entry.
model = SequenceModel(len(word_to_id), args.lr)

# Stop once "val_acc" has not improved for 3 consecutive validation checks.
early_stopping = EarlyStopping(monitor="val_acc", mode="max", patience=3, verbose=True)
trainer = Trainer(max_epochs=args.max_epochs, callbacks=[early_stopping])

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### 초기 상태 25% 정확도 (클래스 4개 중 1개)

In [11]:
# Baseline: evaluate the model on the test set before any weights are trained or loaded.
trainer.test(model, datamodule=ag_news_data)

  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.24947368421052632
        test_loss           1.4024487733840942
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.4024487733840942, 'test_acc': 0.24947368421052632}]

### GPU 없이 학습하려니 너무 오래 걸려서 서버에서 학습하고 weights만 이용

In [12]:
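# Training is intentionally disabled in this notebook: it was run on a separate server,
# and only the resulting weights are loaded in the next cell.
# trainer.fit(model, ag_news_data)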

In [13]:
# Rebuild the architecture, then restore the parameters that were trained outside this notebook.
model = SequenceModel(len(word_to_id), args.lr)

# map_location="cpu" lets weights saved on a GPU machine load on a CPU-only one.
# weights_only=True restricts unpickling to tensors and plain containers, because
# torch.load is pickle-based and should not execute arbitrary objects from a file.
state_dict = torch.load("2_weights.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

trainer.test(model, ag_news_data)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8735526315789474
        test_loss            0.404371052980423
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.404371052980423, 'test_acc': 0.8735526315789474}]