In [9]:
%matplotlib inline
import torch
import torchtext
from torchtext.datasets import text_classification
# Maximum n-gram size used when tokenizing the text.
NGRAMS = 2
import os
# Make sure the local data directory exists. `exist_ok=True` replaces the
# check-then-create sequence, which could fail if the directory appeared
# between the check and the mkdir call.
os.makedirs('./.data', exist_ok=True)
# The commented-out code below is what the original tutorial uses; it downloads
# the required data automatically. From within mainland China the connection
# can fail for network reasons, so the data is loaded from a local archive
# later on instead.
# train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
#     root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [10]:
import logging
from torchtext.utils import extract_archive,unicode_csv_reader
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets.text_classification import *
from torchtext.datasets.text_classification import _csv_iterator,_create_data_from_iterator

def _setup_datasets(dataset_tar='./.data/ag_news_csv.tar.gz', dataset_name="AG_NEWS", root='./.data', ngrams=NGRAMS, vocab=None, include_unk=False):
    """Build the train and test datasets from a local archive.

    The archive is extracted with ``extract_archive`` and the extracted files
    whose names end in ``train.csv`` and ``test.csv`` are used as the training
    and testing sources.

    Args:
        dataset_tar: Path of the local archive to extract.
        dataset_name: Accepted for signature compatibility; not used here.
        root: Accepted for signature compatibility; not used here.
        ngrams: Forwarded to ``_csv_iterator`` for every pass over the CSVs.
        vocab: An existing ``Vocab`` to reuse. When ``None`` a vocabulary is
            built from the training CSV.
        include_unk: Forwarded to ``_create_data_from_iterator``.

    Returns:
        A ``(train, test)`` tuple of ``TextClassificationDataset`` objects
        that share the same vocabulary.

    Raises:
        FileNotFoundError: If the archive does not contain a ``train.csv`` or
            a ``test.csv`` file.
        TypeError: If ``vocab`` is given but is not a ``Vocab`` instance.
        ValueError: If the training and testing label collections differ.
    """
    # Start from None so a missing file is reported explicitly below instead
    # of surfacing later as an UnboundLocalError.
    train_csv_path = None
    test_csv_path = None
    for fname in extract_archive(dataset_tar):
        if fname.endswith("train.csv"):
            train_csv_path = fname
        elif fname.endswith("test.csv"):
            test_csv_path = fname

    missing = [
        suffix
        for suffix, path in (("train.csv", train_csv_path), ("test.csv", test_csv_path))
        if path is None
    ]
    if missing:
        raise FileNotFoundError(
            "Archive {} does not contain: {}".format(dataset_tar, ", ".join(missing)))

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    elif not isinstance(vocab, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))

    # Create the training data.
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)

    # Create the testing data.
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)

    # A non-empty result of `^` means some label appears in only one of the
    # two splits (presumably both are sets of labels -- confirm against
    # _create_data_from_iterator).
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")

    # Return the dataset instances.
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))

In [11]:
# Build both datasets using the defaults of _setup_datasets: the local archive
# './.data/ag_news_csv.tar.gz', n-grams of size NGRAMS, and a vocabulary built
# from the training CSV (since no vocab is passed).
train_dataset, test_dataset = _setup_datasets()

120000lines [00:06, 17840.67lines/s]
120000lines [00:12, 9830.40lines/s] 
7600lines [00:00, 10043.87lines/s]


In [12]:
import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    """Text classifier made of an ``nn.EmbeddingBag`` followed by a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector.
        num_class: Number of output classes (size of the final layer).
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Initialise embedding and linear weights uniformly in [-0.5, 0.5]; zero the bias."""
        init_range = 0.5
        nn.init.uniform_(self.embedding.weight, -init_range, init_range)
        nn.init.uniform_(self.fc.weight, -init_range, init_range)
        nn.init.zeros_(self.fc.bias)

    # Backward-compatible alias for the original (misspelled) method name.
    init_weigths = init_weights

    def forward(self, text, offsets=None):
        """Compute class scores for a batch of token sequences.

        Args:
            text: Token indices. Either a 1-D tensor holding all sequences
                concatenated (then ``offsets`` is required), or a 2-D tensor
                of fixed-length sequences (then ``offsets`` must stay ``None``).
            offsets: 1-D tensor with the start index of every sequence inside
                a 1-D ``text``.

        Returns:
            Tensor of shape ``(batch, num_class)`` with unnormalised scores.
        """
        # The original signature was forward(self, x) but the body used
        # `text` and `offsets`, which were never defined in this scope.
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

In [13]:
# Number of entries in the vocabulary attached to the training dataset.
VOCAB_SIZE = len(train_dataset.get_vocab())
# Size of each embedding vector.
EMBED_DIM = 32
# Number of labels reported by the training dataset.
# NOTE(review): the name is spelled NUN_CLASS (not NUM_CLASS); it is kept
# unchanged here because other cells may reference it.
NUN_CLASS = len(train_dataset.get_labels())
# Instantiate the classifier and move it to the selected device.
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)

In [14]:
def generate_batch(batch):
    """Collate ``(label, token_tensor)`` entries into flat tensors.

    Args:
        batch: Sequence of entries where ``entry[0]`` is the label and
            ``entry[1]`` is a 1-D tensor of token ids.

    Returns:
        A ``(text, offsets, label)`` tuple: ``text`` is every token tensor
        concatenated, ``offsets`` holds the start index of each entry inside
        ``text``, and ``label`` is a tensor of the labels.
    """
    labels = []
    sequences = []
    # Leading 0 so that, after dropping the last length, the running sum
    # yields the start position of every sequence.
    lengths = [0]
    for entry in batch:
        labels.append(entry[0])
        tokens = entry[1]
        sequences.append(tokens)
        lengths.append(len(tokens))

    label = torch.tensor(labels)
    starts = torch.tensor(lengths[:-1])
    offsets = starts.cumsum(dim=0)
    text = torch.cat(sequences)
    return text, offsets, label