# 1. Download the IMDB dataset

In [None]:
!pip install datasets



In [None]:
from datasets import load_dataset
# Load the 'imdb' dataset and keep a handle on each of the two splits
# that the rest of the notebook works with.
imdb = load_dataset('imdb')
train_data, test_data = imdb['train'], imdb['test']

README.md: 0.00B [00:00, ?B/s]

# 2. Build the vocabulary and the SentencePiece model

In [None]:
!pip install sentencepiece



In [None]:
"""
I chose SentencePiece as the tokenizer because it is a subword-based tokenizer.
This significantly reduces the occurrence of the <unk> token.
"""

import sentencepiece as spm
# Dump the training reviews into a UTF-8 text file, each review followed by a
# newline. Only the training split is written, so the tokenizer never sees
# test-set text. The file is overwritten on every run of this cell.
with open('/content/drive/MyDrive/github/imdb-sentiment-comparison-rnn-transformer/src/sentencepiece/imdb.txt', 'w', encoding='utf-8') as corpus_file:
  corpus_file.writelines(review['text'] + '\n' for review in train_data)

In [None]:
"""
The reason I chose 20,000 as the vocabulary size is that the IMDb dataset has long reviews and a wide variety of expressions.
"""
# Ids reserved for the special tokens. The padding id chosen here is the value
# that `sp.pad_id()` reports when batches are padded further down.
special_token_ids = {'unk_id': 0, 'pad_id': 1, 'bos_id': 2, 'eos_id': 3}

# Train the tokenizer on the dumped corpus. `model_prefix` decides where the
# trained model is written; the loading cell reads `<model_prefix>.model`.
spm.SentencePieceTrainer.train(
    input='/content/drive/MyDrive/github/imdb-sentiment-comparison-rnn-transformer/src/sentencepiece/imdb.txt',
    model_prefix='/content/drive/MyDrive/github/imdb-sentiment-comparison-rnn-transformer/src/sentencepiece/imdb',
    vocab_size=20_000,
    **special_token_ids,
)

In [None]:
# Create the tokenizer used by the Dataset / collate function below and load the
# trained model into it. `load` is left as the last expression on purpose so the
# cell displays its return value.
sp = spm.SentencePieceProcessor()
sp.load(model_file='/content/drive/MyDrive/github/imdb-sentiment-comparison-rnn-transformer/src/sentencepiece/imdb.model')

True

# 3. Make the Dataset & DataLoader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class imdbDataset(Dataset):
  """Map-style dataset yielding one tokenized review and its label per index.

  Args:
    data: Indexable, sized collection of rows. Each row must expose a 'text'
      string and a 'label' value.
    tokenizer: Optional object with an ``encode(text)`` method returning a list
      of token ids. When omitted, the notebook-level ``sp`` processor is looked
      up each time an item is requested, so existing ``imdbDataset(data)``
      calls behave as before.
  """

  def __init__(self, data, tokenizer=None):
    self.data = data
    self.tokenizer = tokenizer

  def __len__(self):
    """Return the number of rows in the wrapped data."""
    return len(self.data)

  def __getitem__(self, idx):
    """Return ``(token_ids, label)`` for the row at ``idx``.

    Returns:
      A tuple of a 1-D ``torch.long`` tensor with the token ids of the review
      (its length varies per review) and a scalar ``torch.float`` label tensor.
    """
    # Look the row up once; indexing the backing data twice builds the same
    # row twice for every sample.
    row = self.data[idx]
    # Resolve the fallback lazily so the global `sp` is only needed when no
    # tokenizer was injected.
    tokenizer = self.tokenizer if self.tokenizer is not None else sp
    text = torch.tensor(tokenizer.encode(row['text']), dtype=torch.long)
    # BCEWithLogitsLoss is used for training, so the target dtype must be float.
    label = torch.tensor(row['label'], dtype=torch.float)
    return text, label

def collate_fn(batch, pad_id=None):
  """Merge ``(token_ids, label)`` samples into padded batch tensors.

  Args:
    batch: Non-empty sequence of ``(token_ids, label)`` pairs, where
      ``token_ids`` is a 1-D tensor and ``label`` is a scalar tensor.
    pad_id: Value used to right-pad shorter sequences. When omitted, it is read
      from the notebook-level tokenizer via ``sp.pad_id()``, so passing this
      function directly as ``DataLoader(collate_fn=collate_fn)`` is unchanged.

  Returns:
    texts: ``(batch, max_len)`` tensor of token ids, padded with ``pad_id``.
    lengths: ``(batch,)`` ``torch.long`` tensor holding each sequence's
      length before padding.
    labels: ``(batch,)`` tensor of the stacked labels.
  """
  texts, labels = zip(*batch)
  # Record the true lengths before padding hides them.
  lengths = torch.tensor([len(text) for text in texts], dtype=torch.long)
  if pad_id is None:
    pad_id = sp.pad_id()
  texts = pad_sequence(texts, batch_first=True, padding_value=pad_id)
  labels = torch.stack(labels)
  return texts, lengths, labels

# Training loader: batches of 64, reshuffled on every pass over the data.
train_loader = DataLoader(
  dataset=imdbDataset(train_data),
  batch_size=64,
  shuffle=True,
  collate_fn=collate_fn,
)

# Test loader: batches of 32, no shuffling requested.
test_loader = DataLoader(
  dataset=imdbDataset(test_data),
  batch_size=32,
  collate_fn=collate_fn,
)