In [1]:
from torchtext.data import Field 
from torchtext.datasets import IMDB

import torch

# Review text: a token sequence fixed to 200 positions; with include_lengths=True
# each batch's text is a (token ids, lengths) pair rather than a bare tensor.
text_field = Field(
    sequential=True,
    include_lengths=True,
    fix_length=200,
)

# Sentiment label: one non-sequential value per example.
label_field = Field(sequential=False)

# IMDB train / test splits, with the text and label columns bound to the fields above.
train, test = IMDB.splits(text_field, label_field)

In [2]:
from torchtext.vocab import FastText

# Pretrained FastText vectors for the 'simple' language code. Handing them to
# build_vocab attaches them to the vocabulary (read later as text_field.vocab.vectors).
pretrained_vectors = FastText('simple')
text_field.build_vocab(train, vectors=pretrained_vectors)

# Labels only need an index mapping, so no vectors are passed.
label_field.build_vocab(train)

In [3]:
from torchtext.data import BucketIterator

batch_size = 32

# Place batches on the GPU when CUDA is usable, otherwise keep them on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# One iterator per split, returned in the same order as the datasets tuple.
train_iter, test_iter = BucketIterator.splits(
    (train, test),
    device=device,
    batch_size=batch_size,
)

In [4]:
from pytorch_lightning.core.lightning import LightningModule
from torch import nn

class classifier(LightningModule):
    """LSTM text classifier on top of a fixed table of pretrained word vectors.

    Token ids are looked up in ``embedding``, run through a single LSTM, passed
    through ELU and a linear layer at every time step, and the per-step scores
    are summed over the sequence to give one score vector per example.

    Args:
        embedding: 2-D tensor of word vectors indexed by token id. It is kept
            as a plain tensor attribute (not a parameter, buffer or
            ``nn.Embedding``), so it is never updated by the optimizer and it
            stays on whatever device it was created on.
        lstm_input_size: size of one word vector (last dim of ``embedding``).
        lstm_hidden_size: hidden size of the LSTM.
        output_size: number of scores produced per example.
            NOTE(review): IMDB has two sentiment classes; the default of 3
            presumably leaves room for an extra ``<unk>`` entry in the label
            vocabulary -- confirm against ``label_field.vocab``.
    """

    def __init__(self, embedding, lstm_input_size=300, lstm_hidden_size=100, output_size=3):
        super().__init__()
        self.embedding = embedding
        self.lstm = nn.LSTM(lstm_input_size, lstm_hidden_size)
        self.lin = nn.Linear(lstm_hidden_size, output_size)
        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, X: torch.Tensor):
        """Score a batch of token ids.

        Args:
            X: integer tensor of token ids, batch first: ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` with unnormalised scores.
        """
        # The lookup table is a plain tensor that does not move with the module,
        # so index it on its own device: indexing a CPU tensor with CUDA indices
        # raises a RuntimeError on current PyTorch.
        token_ids = X.to(self.embedding.device)
        # (batch, seq, emb) -> (seq, batch, emb): nn.LSTM is sequence-first by default.
        x = self.embedding[token_ids].to(self.device).permute(1, 0, 2)
        x, _ = self.lstm(x)
        # Back to batch-first before the per-step projection.
        x = torch.nn.functional.elu(x.permute(1, 0, 2))
        x = self.lin(x)
        # Collapse the time axis by summing the per-step scores.
        # NOTE(review): every position contributes to this sum, including any
        # padding; the sequence lengths carried by the batch are not used.
        return x.sum(dim=1)

    def _batch_loss(self, batch):
        """Return the cross-entropy loss of the model on one torchtext batch."""
        # batch.text is a (token ids, lengths) pair; the ids are transposed so
        # that forward() receives them batch-first.
        token_ids, labels = batch.text[0].T, batch.label
        return self.loss_function(self(token_ids), labels)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log it as ``train_loss``."""
        loss = self._batch_loss(batch)
        return dict(
            loss=loss,
            log=dict(
                train_loss=loss
            )
        )

    def configure_optimizers(self):
        """Optimise all module parameters with Adam (learning rate 0.01)."""
        return torch.optim.Adam(self.parameters(), lr=0.01)

    def train_dataloader(self):
        """Return the notebook-level training iterator ``train_iter``."""
        return train_iter

    def test_step(self, batch, batch_idx):
        """Compute the loss for one test batch and log it as ``test_loss``."""
        loss = self._batch_loss(batch)
        return dict(
            test_loss=loss,
            log=dict(
                test_loss=loss
            )
        )

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses collected from ``test_step``.

        Args:
            outputs: list of the dicts returned by ``test_step``.

        Returns:
            Dict with the mean loss under ``avg_test_loss`` and, for the logger,
            the same value under ``log['test_loss']``.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = dict(
            test_loss=avg_loss
        )
        return dict(
            avg_test_loss=avg_loss,
            log=tensorboard_logs
        )

    def test_dataloader(self):
        """Return the notebook-level test iterator ``test_iter``."""
        return test_iter

In [8]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
# Build the classifier on the pretrained vectors attached to the text vocabulary.
model = classifier(text_field.vocab.vectors)
logger = TensorBoardLogger('tb_logs', name='text_classification')
trainer = Trainer(
    # Request a GPU only when CUDA is actually available, mirroring how the
    # batch iterators pick their device; a hard-coded gpus=1 fails on
    # CPU-only machines.
    gpus=1 if torch.cuda.is_available() else None,
    logger=logger,
    max_epochs=10
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | lstm          | LSTM             | 160 K 
1 | lin           | Linear           | 303   
2 | loss_function | CrossEntropyLoss | 0     
Epoch 9: 100%|██████████| 782/782 [00:49<00:00, 15.65it/s, loss=0.347, v_num=0]


1

In [9]:
# Run the trainer's test loop and display the metrics it returns.
test_results = trainer.test()
test_results


Testing: 100%|█████████▉| 779/782 [00:24<00:00, 30.33it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': tensor(0.5879, device='cuda:0'),
 'test_loss': tensor(0.5879, device='cuda:0')}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 782/782 [00:25<00:00, 31.24it/s]


[{'test_loss': 0.5878623127937317, 'avg_test_loss': 0.5878623127937317}]