### Assignment goal: practice loading data with a custom collate_fn and sampler

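This assignment uses the [IMDB](http://ai.stanford.edu/~amaas/data/sentiment/) dataset with PyTorch's Dataset and DataLoader to practice
customized data loading.
The downloaded data is split into train and test. Since the goal of this assignment is reading data, we use the train split for practice.
(Please download the data from the IMDB page first.)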
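
Before building the real pieces, it may help to see how they fit together. The sketch below is a deliberately simplified, pure-Python model of what a map-style DataLoader does with a sampler and a collate_fn: the sampler produces indices, indices are grouped into batches, each index is looked up with `dataset[i]`, and the list of samples is handed to `collate_fn`. It is not PyTorch's actual implementation (which adds worker processes, tensor conversion, prefetching and more), and the names `ToyDataset`, `toy_collate` and `toy_loader` are hypothetical.

```python
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple


class ToyDataset:
    """Hypothetical map-style dataset: only __len__ and __getitem__ are needed."""

    def __init__(self, items: Sequence[Tuple[List[int], int]]):
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Tuple[List[int], int]:
        return self.items[idx]


def toy_collate(batch):
    """Pad every sequence with 0 up to the longest sequence in the batch."""
    seqs, labels = zip(*batch)
    max_len = max(len(s) for s in seqs)
    padded = [list(s) + [0] * (max_len - len(s)) for s in seqs]
    return padded, list(labels), [len(s) for s in seqs]


def toy_loader(dataset, batch_size: int, sampler: Iterable[int],
               collate_fn: Callable) -> Iterator:
    """Simplified model: sampler -> index batches -> dataset[i] -> collate_fn."""
    batch_indices: List[int] = []
    for idx in sampler:
        batch_indices.append(idx)
        if len(batch_indices) == batch_size:
            yield collate_fn([dataset[i] for i in batch_indices])
            batch_indices = []
    if batch_indices:  # last, possibly smaller batch (like drop_last=False)
        yield collate_fn([dataset[i] for i in batch_indices])


data = ToyDataset([([5, 3], 1), ([7], 0), ([2, 9, 4], 1)])
# A sequential sampler keeps the output predictable;
# a shuffled list of indices would play the role of RandomSampler.
for x, y, lengths in toy_loader(data, batch_size=2, sampler=range(len(data)),
                                collate_fn=toy_collate):
    print(x, y, lengths)
# [[5, 3], [7, 0]] [1, 0] [2, 1]
# [[2, 9, 4]] [1] [3]
```

Swapping the sampler changes only which indices land in each batch; swapping the collate function changes only how a list of samples becomes one batch. That separation is what the rest of the assignment customizes.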

### Import packages

In [11]:
# Import torch and other required modules
import glob
import torch
import re
import nltk
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.nn.utils.rnn import pad_sequence
from sklearn.datasets import load_svmlight_file
from nltk.corpus import stopwords

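nltk.download('stopwords')  # download the stopword lists
nltk.download('punkt')      # download the corpus needed by word_tokenize
# On newer NLTK releases (3.9+), word_tokenize may look for 'punkt_tab' instead: nltk.download('punkt_tab')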

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/katnyeung/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/katnyeung/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

### Exploring and preprocessing the data
In this assignment we use the pos and neg reviews from the train split.


In [4]:
# Load the vocabulary file: every word that appears in the reviews
with open(os.path.join('aclImdb', 'imdb.vocab'), encoding='utf-8') as f:
    vocab = [line.strip() for line in f.readlines()]

# Remove stopwords with NLTK's list: they carry little useful information and can interfere with model training
print(f"vocab length before removing stopwords: {len(vocab)}")
en_stopwords = set(stopwords.words('english'))
vocab = [word for word in vocab if word not in en_stopwords]
print(f"vocab length after removing stopwords: {len(vocab)}")

# Convert the word list into a word -> index dictionary.
# Indices start at 1 so that 0 stays free for the padding value used in collate_fn.
vocab_dic = {word: idx for idx, word in enumerate(vocab, start=1)}

vocab length before removing stopwords: 89527
vocab length after removing stopwords: 89356
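
A small point about the index dictionary: `vocab_dic.get(word)` returns the index itself, so a word stored at index 0 looks like a miss in an `if` test, and 0 is also the value `pad_sequence` uses for padding later on. The toy vocabulary below (made-up words) contrasts a truthiness test with an explicit membership test, and shows why starting the indices at 1 keeps 0 free for padding.

```python
toy_vocab = ["movie", "film", "great"]
tokens = ["movie", "great", "unknownword"]

# Indices start at 0: "movie" -> 0
dic0 = {word: idx for idx, word in enumerate(toy_vocab)}
by_truthiness = [dic0[w] for w in tokens if dic0.get(w)]  # drops "movie": index 0 is falsy
by_membership = [dic0[w] for w in tokens if w in dic0]    # keeps "movie", but 0 now clashes with padding
print(by_truthiness)  # [2]
print(by_membership)  # [0, 2]

# Indices start at 1: 0 can safely mean "padding"
dic1 = {word: idx for idx, word in enumerate(toy_vocab, start=1)}
shifted = [dic1[w] for w in tokens if w in dic1]
print(shifted)        # [1, 3]
```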


In [5]:
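# Pack the data into (x, y) pairs, where x is the path of a review file and y is positive (1) or negative (0).
# x is a file path rather than the text itself so that you practice reading the data lazily instead of all at once.
# If your machine has enough memory (the data files are not large in total), you can read everything at once,
# which cuts I/O time during training and speeds it up.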

review_pairs = []
for folder, label in [('pos', 1), ('neg', 0)]:
    filepaths = glob.glob(os.path.join('aclImdb', 'train', folder, '*'))
    for filepath in filepaths:
        review_pairs.append((filepath, label))

print(review_pairs[:2])
print(f"Total reviews: {len(review_pairs)}")

[('aclImdb/train/pos/190_10.txt', 1), ('aclImdb/train/pos/3395_9.txt', 1)]
Total reviews: 25000
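
The comment in the cell above mentions that, with enough memory, every review could be read once up front instead of re-reading files inside `__getitem__`. A minimal sketch of that trade-off follows, assuming the same `pos`/`neg` folder layout as `aclImdb/train`; it builds a throw-away directory with `tempfile` so it runs anywhere, and `build_pairs` and `EagerReviews` are hypothetical names. The class is a plain duck-typed map-style dataset (only `__len__` and `__getitem__`); in the notebook you would subclass `torch.utils.data.Dataset` as the `dataset` class below does.

```python
import glob
import os
import tempfile


def build_pairs(root):
    """Return (filepath, label) pairs: 1 for pos, 0 for neg."""
    pairs = []
    for folder, label in [("pos", 1), ("neg", 0)]:
        for path in sorted(glob.glob(os.path.join(root, folder, "*.txt"))):
            pairs.append((path, label))
    return pairs


class EagerReviews:
    """Hypothetical map-style dataset that reads every file once, in __init__."""

    def __init__(self, pairs):
        self.texts = []
        self.labels = []
        for path, label in pairs:
            with open(path, encoding="utf-8") as f:
                self.texts.append(f.read())  # all text now lives in memory
            self.labels.append(label)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]  # no disk I/O here


with tempfile.TemporaryDirectory() as root:
    samples = {"pos": ["A great film."], "neg": ["Dull.", "Too long."]}
    for folder, texts in samples.items():
        os.makedirs(os.path.join(root, folder))
        for i, text in enumerate(texts):
            with open(os.path.join(root, folder, f"{i}.txt"), "w", encoding="utf-8") as f:
                f.write(text)

    pairs = build_pairs(root)
    eager = EagerReviews(pairs)

# The temporary directory is gone now, but the texts are already in memory.
print(len(eager))  # 3
print(eager[0])    # ('A great film.', 1)
```

The lazy version in this notebook pays a file read on every `__getitem__` call but keeps memory use small; the eager version pays once at construction and then indexes in-memory lists.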


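### Building the Dataset, DataLoader, Sampler, and collate_fn to load the data
We need two helper functions here: one reads and cleans a review (load_review), and the other turns it into a vector of word indices
(generate_vec). Note that this vector is produced simply by tokenizing the text and mapping each word to its index; we deliberately avoid BoW so that the resulting sequences have different lengths.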
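
To make the reason for skipping BoW concrete: a bag-of-words vector has one slot per vocabulary word, so every review maps to the same length, whereas a sequence of word indices is as long as the review itself. That varying length is exactly what the custom collate_fn later has to handle. The toy vocabulary, reviews and helper names below are made up for illustration.

```python
toy_vocab = {"great": 1, "movie": 2, "boring": 3, "plot": 4}  # 0 left free for padding

reviews = [["great", "movie"], ["boring", "plot", "boring", "movie", "great"]]


def to_indices(tokens, vocab):
    """Sequence of word indices: its length follows the review."""
    return [vocab[t] for t in tokens if t in vocab]


def to_bow(tokens, vocab):
    """Bag-of-words counts: one slot per index, so the length is fixed."""
    counts = [0] * (len(vocab) + 1)  # slot 0 unused, matching the padding index
    for t in tokens:
        if t in vocab:
            counts[vocab[t]] += 1
    return counts


for r in reviews:
    print(to_indices(r, toy_vocab), to_bow(r, toy_vocab))
# [1, 2] [0, 1, 1, 0, 0]
# [3, 4, 3, 2, 1] [0, 1, 1, 2, 1]
```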

In [6]:
def load_review(review_path):
    with open(review_path, encoding='utf-8') as f:
        review = f.read()

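    # Replace non-word characters (anything other than letters, digits and underscore) with spaces,
    # lowercase to match the all-lowercase vocabulary, then tokenize.
    # Stopwords are dropped later in generate_vec, because they were removed from vocab_dic.
    review = re.sub(r'\W', ' ', review).lower()
    review = nltk.word_tokenize(review)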
    
    return review


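def generate_vec(review, vocab_dic):
    # Keep only words found in the vocabulary. An explicit membership test is used because
    # a truthiness test on vocab_dic.get(word) would also drop any word whose index is 0.
    idx_vec = [vocab_dic[word] for word in review if word in vocab_dic]

    return idx_vec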
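
It can help to check what the cleaning step in load_review actually produces. The sketch below applies the same `re.sub(r'\W', ' ', ...)` to a made-up review and, to stay dependency-free, uses a plain whitespace split as a stand-in for `nltk.word_tokenize` (after the substitution only word characters and spaces remain, so the two mostly agree, though the NLTK tokenizer may split a few tokens differently). Two side effects show up: HTML line breaks survive as the token `br`, and contractions break into pieces. The optional `<br />` removal at the end is an assumption about what you might want, not part of the original pipeline.

```python
import re

raw = "Great movie!<br /><br />Don't miss it."

cleaned = re.sub(r"\W", " ", raw)  # every non-word character becomes a space
tokens = cleaned.split()           # stand-in for nltk.word_tokenize
print(tokens)
# ['Great', 'movie', 'br', 'br', 'Don', 't', 'miss', 'it']

print([t.lower() for t in tokens])  # lowercase to match a lowercase vocabulary
# ['great', 'movie', 'br', 'br', 'don', 't', 'miss', 'it']

# Optional: strip <br /> tags before cleaning so "br" never becomes a token
no_html = re.sub(r"<br\s*/?>", " ", raw)
print(re.sub(r"\W", " ", no_html).split())
# ['Great', 'movie', 'Don', 't', 'miss', 'it']
```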

In [7]:
# Build a custom Dataset

class dataset(Dataset):
    '''Custom dataset that loads reviews and labels.

    Parameters
    ----------
    data_dirs: list of tuple
        (review_path, label) pairs for all reviews
    vocab: dict
        mapping from word to index (vocab_dic)
    '''
    def __init__(self, data_dirs, vocab):
        self.data_dirs = data_dirs
        self.vocab = vocab

    def __len__(self):
        return len(self.data_dirs)

    def __getitem__(self, idx):
        review_path, label = self.data_dirs[idx]
        review = load_review(review_path)
        idx_vector = generate_vec(review, self.vocab)

        return idx_vector, label
    
    

# Custom collate_fn: pad the variable-length reviews with 0 so they all share one length
def collate_fn(batch):
    reviews, labels = zip(*batch)
    lengths = torch.LongTensor([len(review) for review in reviews])
    labels = torch.LongTensor(labels)
    reviews = pad_sequence([
        torch.LongTensor(review) for review in reviews
    ], batch_first=True, padding_value=0)

    return reviews, labels, lengths
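
RandomSampler shuffles indices without looking at lengths, so a batch can pair a very short review with a very long one and spend most of the padded tensor on zeros. One common alternative, sketched below under stated assumptions, is a batch sampler that shuffles the indices, sorts small pools of them by length, and then cuts each pool into batches. `LengthBucketBatchSampler` is a hypothetical name, the default pool size is an arbitrary choice, and it assumes each review's token count has been computed beforehand (for this dataset that means one pass over all files).

```python
import random
from typing import Iterator, List, Sequence


class LengthBucketBatchSampler:
    """Hypothetical batch sampler that groups indices of similar length.

    Each epoch: shuffle all indices, split them into pools of
    batch_size * pool_factor, sort each pool by length, cut the pool into
    batches, then shuffle the order of the batches.
    """

    def __init__(self, lengths: Sequence[int], batch_size: int,
                 pool_factor: int = 50, seed: int = 0):
        self.lengths = lengths
        self.batch_size = batch_size
        self.pool_size = batch_size * pool_factor
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        indices = list(range(len(self.lengths)))
        self.rng.shuffle(indices)
        batches = []
        for start in range(0, len(indices), self.pool_size):
            # Sorting only inside a pool keeps some randomness between pools
            pool = sorted(indices[start:start + self.pool_size],
                          key=lambda i: self.lengths[i])
            for b in range(0, len(pool), self.batch_size):
                batches.append(pool[b:b + self.batch_size])
        self.rng.shuffle(batches)  # don't always start with the shortest batch
        yield from batches

    def __len__(self) -> int:
        # Number of batches per epoch; at most one batch is smaller than batch_size
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size


lengths = [5, 120, 7, 118, 6, 119, 8, 121]  # made-up token counts
sampler = LengthBucketBatchSampler(lengths, batch_size=2, pool_factor=4)
for batch in sampler:
    print(batch, [lengths[i] for i in batch])
```

With PyTorch, an object like this is passed through the `batch_sampler` argument, which is mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`; for example `DataLoader(custom_dataset, batch_sampler=LengthBucketBatchSampler(lengths, batch_size=4), collate_fn=collate_fn)`, where `lengths` would hold the token count of each review in `review_pairs` order. Short reviews then travel together, so `pad_sequence` adds far fewer zeros.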

In [12]:
# Use PyTorch's RandomSampler to draw indices in random order and build the DataLoader
custom_dataset = dataset(review_pairs, vocab_dic)
custom_dataloader = DataLoader(custom_dataset, 
                               batch_size=4, 
                               sampler=RandomSampler(custom_dataset), 
                               collate_fn=collate_fn
)
next(iter(custom_dataloader))

(tensor([[  486,     4,   701,    25,     2,   884,    17, 75154, 75154,     6,
            392,  4207,  3993, 32461, 18275,  3993,  7926, 33012,  9537,   292,
           1953, 75154, 75154,    17,    74,    14,   268,  1783,    14,   466,
           1288,  3220,   106,    92,   404,  3046,    17,    34,    10,   277,
            599,  1266,   194,   101,  3993,  4207,    49,   787,   312,     8,
             94,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0],
         [ 4969,   101, 10457,  1069,   233,   295,   645,  1306,  1190, 40274,
           1306,   536,    43,  2779,   386,   128,  