In [None]:
import random

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from torchtext.data import Field
from torchtext.datasets import IMDB

# Review text: tokenised into sequences padded/truncated to 200 tokens.
# include_lengths=True makes each batch's `.text` a (token_ids, lengths) pair,
# which is why later cells read `batch.text[0]`.
text_field = Field(
    sequential=True,
    fix_length=200,
    include_lengths=True,
)
# Sentiment label: a single, non-sequential token per review.
label_field = Field(sequential=False)

train, test = IMDB.splits(text_field=text_field, label_field=label_field)

In [None]:
from torchtext.vocab import FastText

# Label vocabulary: the distinct sentiment labels seen in the training split.
label_field.build_vocab(train)
# Text vocabulary from the training split, with each entry initialised from the
# pretrained fastText 'simple' (Simple English Wikipedia) vectors.
text_field.build_vocab(train, vectors=FastText(language='simple'))

In [None]:
from torchtext.data import BucketIterator

# Seed every RNG in play so Restart & Run All reproduces the same results:
# torch drives weight initialisation, and torchtext's iterators snapshot
# Python's `random` state when they are constructed (used for shuffling).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32

# Batches are created directly on `device`, so the model must live there too.
train_iter, test_iter = BucketIterator.splits(
    (train, test),
    batch_size=batch_size,
    device=device,
)

In [None]:
from pytorch_lightning.core.lightning import LightningModule

class classifier(LightningModule):
    """LSTM sentiment classifier on top of frozen pretrained word vectors.

    Args:
        embedding: 2-D table of pretrained vectors, one row per vocabulary
            index (e.g. ``text_field.vocab.vectors``). It is wrapped in a
            frozen ``nn.Embedding`` so it is registered on the module. It then
            moves to the model's device together with the other weights, and
            the optimizer does not update it. Note it is therefore part of
            ``state_dict`` and of saved checkpoints.
        lstm_input_size: width of one embedding row; must match ``embedding``.
        lstm_hidden_size: size of the LSTM hidden state.
        output_size: number of classes, i.e. logits produced per example.
        lr: learning rate for Adam.
    """

    def __init__(self, embedding, lstm_input_size=300, lstm_hidden_size=100,
                 output_size=3, lr=0.01):
        super().__init__()
        # Registered as a module rather than stored as a bare tensor attribute.
        # A bare tensor stays on the CPU when the model is moved to the GPU,
        # and indexing it with CUDA token ids then fails.
        self.embedding = nn.Embedding.from_pretrained(embedding, freeze=True)
        # batch_first=True lets the LSTM consume (batch, seq, features) directly,
        # so no permutes are needed around it.
        self.lstm = nn.LSTM(lstm_input_size, lstm_hidden_size, batch_first=True)
        self.lin = nn.Linear(lstm_hidden_size, output_size)
        self.loss_function = nn.CrossEntropyLoss()
        self.lr = lr

    def forward(self, X: torch.Tensor):
        """Return class logits of shape (batch, output_size).

        Args:
            X: token ids laid out batch first, (batch, seq_len). The step
               methods below transpose the (seq_len, batch) tensor they get
               from ``batch.text[0]`` before calling this.
        """
        x = self.embedding(X)         # (batch, seq_len, lstm_input_size)
        x, _ = self.lstm(x)           # (batch, seq_len, lstm_hidden_size)
        x = self.lin(F.elu(x))        # (batch, seq_len, output_size)
        # NOTE: per-timestep logits are summed over every position, including
        # padding; sequence lengths (batch.text[1]) are not used.
        return x.sum(dim=1)

    def _shared_step(self, batch):
        """Compute the cross-entropy loss for one torchtext batch."""
        # batch.text is a (token_ids, lengths) pair; ids arrive as
        # (seq_len, batch) and are transposed to batch first for forward().
        x, y = batch.text[0].T, batch.label
        return self.loss_function(self(x), y)

    def training_step(self, batch, batch_idx):
        """Return the training loss and log it as ``train_loss``."""
        loss = self._shared_step(batch)
        return dict(
            loss=loss,
            log=dict(
                train_loss=loss
            )
        )

    def configure_optimizers(self):
        """Adam over the trainable parameters (the frozen embedding is skipped)."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return Adam(trainable, lr=self.lr)

    def train_dataloader(self):
        """Return the notebook-level ``train_iter`` BucketIterator."""
        return train_iter

    def test_step(self, batch, batch_idx):
        """Return the test loss for one batch and log it as ``test_loss``."""
        loss = self._shared_step(batch)
        return dict(
            test_loss=loss,
            log=dict(
                test_loss=loss
            )
        )

    def test_epoch_end(self, outputs):
        """Average the per-batch ``test_loss`` values collected by test_step."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = dict(
            test_loss=avg_loss
        )
        return dict(
            avg_test_loss=avg_loss,
            log=tensorboard_logs
        )

    def test_dataloader(self):
        """Return the notebook-level ``test_iter`` BucketIterator."""
        return test_iter

In [None]:
# Smoke test: push one training batch through a freshly built model before the
# real training run. It uses its own model (the trained one is created in the
# next cell), and runs on the CPU so it works regardless of where `train_iter`
# places its batches.
sample_batch = next(iter(train_iter))
smoke_model = classifier(text_field.vocab.vectors)
with torch.no_grad():
    sample_logits = smoke_model(sample_batch.text[0].T.cpu())
# Expect one row of logits per review and one column per label-vocabulary entry.
assert sample_logits.shape == (sample_batch.batch_size, len(label_field.vocab))
sample_logits.shape

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
model = classifier(text_field.vocab.vectors)
logger = TensorBoardLogger('tb_logs', name='my_model')
trainer = Trainer(
    # Use a GPU only when the iterators were built to emit CUDA batches;
    # a hard-coded gpus=1 fails on CPU-only machines.
    gpus=1 if device == 'cuda' else 0,
    logger=logger,
    max_epochs=10,
)
trainer.fit(model)

In [None]:
# Run the test loop over the model's test_dataloader(). This call passes no
# model, so it relies on the trainer's state from trainer.fit(model) above.
# NOTE(review): whether this evaluates the in-memory model or a saved "best"
# checkpoint depends on the installed PyTorch Lightning version — confirm.
trainer.test()
