In [41]:
import torch
from torch import nn
import torchtext
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchtext.data.utils import get_tokenizer, ngrams_iterator
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import torchtext.transforms as T


In [65]:
torch.__version__

'2.1.0+cu118'

In [64]:
torchtext.__version__

'0.16.0+cpu'

In [32]:
unk_idex = 0
bos_idx = 1
eos_idx = 2
padding_idx = 3
min_seq_len = 5
max_seq_len = 256

In [13]:
batch_size = 128
min_frequency = 5

In [10]:
tokenizer = get_tokenizer('basic_english')

In [161]:
def datapipe_factory(datapipe, transform = None):
    datapipe = (
        datapipe
        .map(lambda item: (item[0], item[1].strip()) )
        .map(lambda item: (item[0], item[1].lower()) )
        .filter(lambda item: len(item[1]) > min_seq_len)
        .map(lambda item: (item[0], tokenizer(item[1])) )
        # .map(lambda text: drop_keywords(text, stopwords.words('english')))
        # .filter(lambda tokens: drop_short_text(tokens, context_size))
    )

    if transform:
      datapipe = datapipe.map(lambda item: (item[0], transform(item[1])))

    return datapipe

In [285]:
train_datapipe, test_datapipe = torchtext.datasets.IMDB("./data")

In [286]:
vocab = build_vocab_from_iterator(map(lambda x: x[1], datapipe_factory(test_datapipe)), specials=["<unk>", "<bos>", "<eos>", "<pad>"], min_freq=min_frequency,)
vocab.set_default_index(vocab["<unk>"])



In [287]:
text_transform = T.Sequential(
    T.VocabTransform(vocab),
    T.Truncate(max_seq_len - 2),
    T.AddToken(token=bos_idx, begin=True),
    T.AddToken(token=eos_idx, begin=False),
    T.ToTensor(),
    T.PadTransform(max_seq_len, padding_idx),
)

In [288]:
train_datapipe = datapipe_factory(
    train_datapipe,
    text_transform,
)

test_datapipe = datapipe_factory(
    test_datapipe,
    text_transform,
)

In [289]:
def collate_fn(batch: list):
    text_placeholder, label_placeholder = [], []
    for (label, text) in batch:
        text_placeholder.append(text)
        label_placeholder.append(label)

    return torch.stack(text_placeholder), torch.Tensor(label_placeholder)

In [290]:
train_dataloader = DataLoader(
    train_datapipe,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=True,
)

test_dataloader = DataLoader(
    test_datapipe,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

In [303]:
for text, label in train_dataloader:
    print(text.shape, label.shape)
    break

torch.Size([128, 256]) torch.Size([128])


In [304]:
vocab_size = len(vocab)
vocab_size

29575

In [331]:
class LSTM_Classifier(nn.Module):
    
    def __init__(self, embed_dim:int, hidden_dim:int):
        super(LSTM_Classifier,self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, 1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
    
    def forward(self, x:torch.Tensor) -> torch.Tensor:
        output = self.embed(x) # batch_size, seq_len, embed_dim 
        # ot: batch_size, seq_len, bidirectional*hidden_dim
        # ht: bidirectional * layer_size, batch_size, hidden_dim 
        _, (ht, _) = self.rnn(output)
        output = self.fc(ht[-1]) # the last layer's hidden state represents the paragraph
        
        return output

In [332]:
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")

In [333]:
embed_dim = 128
hidden_dim = 128
model = LSTM_Classifier(embed_dim, hidden_dim).to(device)

In [334]:
## 測試 forward
for batch in train_dataloader:
    text, label = batch
    text = text.to(device)
    output = model(text)
    print(output.shape)
    break

torch.Size([128, 1])


In [337]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())

In [338]:
epoches = 20

In [None]:
model.train()
training_loss = []
for epoch in range(epoches):
    running_loss = 0.0
    batch_idx = 0
    for batch in tqdm(test_dataloader):    
        optimizer.zero_grad()

        text, label = batch
        text = text.to(device)
        label = label.to(device)
        output = model(text).squeeze(1)
        loss = criterion(output, label-1)
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batch_idx += 1

    mean_running_loss = running_loss/batch_idx
    training_loss.append(mean_running_loss)
    print(f'epoch {epoch+1} : {mean_running_loss}')

In [None]:
plt.plot(training_loss)

In [None]:
model.eval()
with torch.no_grad():
    running_hit = 0.0
    data_size = 0
    for batch in tqdm(test_dataloader):
        text, label = batch
        text = text.to(device)
        label = label.to(device)
        
        output = model(text).squeeze(1)
        pred = output.sigmoid() > 0.5
        running_hit += (pred == (label-1)).sum().item()
        data_size += text.size(0)

    print(running_hit/data_size)

In [278]:
torch.save(model, 'lstm.pth')

In [329]:
model(text_transform(tokenizer("It was such a good movie I have ever seen")).unsqueeze(0).to(device))

tensor([[-0.3088]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [330]:
model(text_transform(tokenizer("So bad a movie can be, waste my money")).unsqueeze(0).to(device))

tensor([[-0.3088]], device='cuda:0', grad_fn=<AddmmBackward0>)