출처 : https://tutorials.pytorch.kr/beginner/text_sentiment_ngrams_tutorial.html

In [1]:
import os

import torch
import torchtext
from torchtext.datasets import text_classification  # the datasets are bundled in this module

# n-gram setting forwarded to the dataset loader below.
NGRAMS = 2

# Directory the loader uses as its root. exist_ok=True creates it only when
# missing, without the check-then-create race of isdir() followed by mkdir().
os.makedirs('./.data', exist_ok=True)

# Load the AG_NEWS train/test splits; vocab=None is passed through to the loader.
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ag_news_csv.tar.gz: 11.8MB [00:02, 5.28MB/s]
120000lines [00:15, 7621.54lines/s]
120000lines [00:38, 3133.77lines/s]
7600lines [00:02, 3725.88lines/s]


## 모델 정의 하기 

- nn.EmbeddingBag 은 각 텍스트(bag)에 속한 토큰 임베딩들을 하나의 벡터로 합쳐서(기본값: 평균) 계산한다
- 텍스트마다 길이가 다를 수 있으므로, 하나로 이어 붙인 텐서에서 각 텍스트가 시작하는 위치를 offsets 로 전달한다

In [7]:
import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    """Text classifier: an ``nn.EmbeddingBag`` followed by one linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: width of each embedding vector (and of the linear input).
        num_class: number of output scores produced per bag.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True makes the embedding table produce sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Fill both weight matrices uniformly in [-0.5, 0.5] and zero the bias."""
        bound = 0.5
        # Embedding first, then the linear layer, so random draws happen in
        # the same order on every construction.
        for layer in (self.embedding, self.fc):
            layer.weight.data.uniform_(-bound, bound)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return the linear layer's output for each bag of token indices.

        Args:
            text: token indices handed to the EmbeddingBag (presumably the
                concatenation of all texts in the batch -- confirm with caller).
            offsets: passed straight to the EmbeddingBag as the start position
                of each bag inside ``text``.
        """
        return self.fc(self.embedding(text, offsets))

## 인스턴스 생성하기

In [8]:
# Vocabulary size is the length of the vocab object returned by the
# training dataset's get_vocab().
VOCAB_SIZE = len(train_dataset.get_vocab())

# Width of each embedding vector (passed to TextSentiment as embed_dim).
EMBED_DIM = 32
# One output score per label returned by get_labels().
NUM_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

## 배치 생성을 위한 함수들 
- torch.utils.data.DataLoader 의 collate_fn 인자로 넘겨줌

In [10]:
# List the attributes of the dataset object (the output below includes
# _data, _labels, _vocab, get_labels and get_vocab).
dir(train_dataset)

['__add__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_data',
 '_labels',
 '_vocab',
 'get_labels',
 'get_vocab']

In [11]:
train_dataset._data 
# Each entry is stored as a (category number, tensor of token indices) pair.

[(2,
  tensor([    572,     564,       2,    2326,   49106,     150,      88,       3,
             1143,      14,      32,      15,      32,      16,  443749,       4,
              572,     499,      17,      10,  741769,       7,  468770,       4,
               52,    7019,    1050,     442,       2,   14341,     673,  141447,
           326092,   55044,    7887,     411,    9870,  628642,      43,      44,
              144,     145,  299709,  443750,   51274,     703,   14312,      23,
          1111134,  741770,  411508,  468771,    3779,   86384,  135944,  371666,
             4052])),
 (2,
  tensor([  55003,    1474,    1150,    1832,    7559,      14,      32,      15,
               32,      16,    1262,    1072,     436,   55003,     131,       4,
           142576,      33,       6,    8062,      12,     756,  475640,       9,
           991346,    3186,       8,       3,     698,     329,       4,      33,
             6764, 1040465,   13979,      11,     278,     483,   

In [14]:
train_dataset.get_labels() # the 4 category labels

{0, 1, 2, 3}

In [16]:
# Show the vocabulary object held by the dataset (a torchtext.vocab.Vocab
# according to the cell output).
train_dataset._vocab

<torchtext.vocab.Vocab object at 0x7f935eb122b0>


In [17]:
# List the attributes of the Vocab object (the output includes freqs, itos,
# stoi and vectors).
dir(train_dataset._vocab)

['UNK',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_default_unk_index',
 'extend',
 'freqs',
 'itos',
 'load_vectors',
 'set_vectors',
 'stoi',
 'unk_index',
 'vectors']

In [20]:
train_dataset._vocab.freqs # the frequency of each observed token is stored here

Counter({'laid his': 1,
         'music illegally': 6,
         'in bitdefender': 1,
         'trialing the': 1,
         'snow\\repeated on': 1,
         '41-20 ,': 1,
         'catastrophe .': 5,
         'school dormitory': 8,
         'still crunching': 1,
         'bankrupt\\united': 1,
         'irs pledges': 1,
         'witty': 4,
         'mcdonalds says': 1,
         'rename cnet': 1,
         'acquisition .': 15,
         'whole story': 2,
         'the mick': 1,
         'attendees say': 1,
         'postal rider': 1,
         'practice 3': 2,
         'rename companies': 1,
         'company 3com': 1,
         'software programmes': 1,
         'recalls cell': 1,
         '55-yard field': 4,
         'addresses rally': 1,
         're-draw the': 1,
         'the\\brink of': 2,
         'canada to': 55,
         'estimates that': 14,
         'posts bent': 1,
         "novak '": 1,
         'official sana': 3,
         'eyes london': 1,
         'rebels turned': 1,
        

In [None]:
# NOTE(review): `vectors` is not defined in any visible cell and this cell was
# never executed; presumably train_dataset._vocab.vectors was intended -- confirm.
vectors