In [1]:
from typing import Optional

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

In [3]:
import pytorch_lightning as pl

In [4]:
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.embeddings.base import Embeddings
from flair.data import Dictionary

In [5]:
class FlairData(pl.LightningDataModule):
    """LightningDataModule that serves a flair ``Corpus`` and embeds batches on transfer.

    The dataloaders yield plain Python lists of flair sentences (``collate_fn=list``).
    Embedding is deferred to :meth:`transfer_batch_to_device`, which turns each list into
    an ``(embeddings, label_indices)`` tensor pair on the target device. Pass the same
    embeddings instance the model trains, so batches are embedded with the weights being
    fine-tuned.

    Args:
        corpus: flair corpus providing the ``train`` / ``dev`` / ``test`` splits.
        document_embeddings: flair embeddings used to embed every sentence of a batch.
        label_dictionary: maps label strings to class indices.
        batch_size: number of sentences per batch.
        num_workers: DataLoader worker processes (default matches the previous
            hard-coded value).
    """

    def __init__(
        self,
        corpus: Corpus,
        document_embeddings: Embeddings,
        label_dictionary: Dictionary,
        batch_size: int = 128,
        num_workers: int = 8,
    ):
        super().__init__()
        self.corpus = corpus
        self.document_embeddings = document_embeddings
        self.label_dictionary = label_dictionary
        self.batch_size = batch_size
        self.num_workers = num_workers

    def transfer_batch_to_device(self, batch, device, dataloader_idx: Optional[int] = None):
        """Embed a list of sentences and return ``(embeddings, labels)`` on ``device``.

        Args:
            batch: list of flair sentences produced by this module's dataloaders.
            device: target device chosen by Lightning.
            dataloader_idx: accepted for Lightning versions that pass it; unused.

        Returns:
            ``(embeddings, labels)``: the per-sentence embeddings stacked along a new
            leading batch dimension, and a 1-D ``long`` tensor of label indices. Both are
            on ``device``.

        Raises:
            ValueError: if the batch does not carry exactly one label per sentence.
                ``CrossEntropyLoss`` needs one target per embedding row.
        """
        self.document_embeddings.embed(batch)
        embedding_names = self.document_embeddings.get_names()

        # stack(dim=0) == cat of unsqueeze(0) rows; also move to the target device like the labels.
        embeddings = torch.stack(
            [sentence.get_embedding(embedding_names) for sentence in batch], 0
        ).to(device)

        labels = torch.tensor(
            [
                self.label_dictionary.get_idx_for_item(label.value)
                for sentence in batch
                for label in sentence.get_labels()
            ],
            dtype=torch.long,
            device=device,
        )
        if labels.numel() != len(batch):
            raise ValueError(
                f"expected exactly one label per sentence, got {labels.numel()} labels "
                f"for {len(batch)} sentences"
            )

        return embeddings, labels

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader that yields plain lists of sentences (embedded later on transfer)."""
        # pin_memory is omitted: batches are lists of sentence objects, not tensors, so
        # pinning has nothing to pin.
        return DataLoader(
            dataset,
            collate_fn=list,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    # The splits are passed as-is, NOT via `.dataset`. When flair samples a missing dev
    # split from train (the corpus log shows `Dev: None`), train/dev are
    # `torch.utils.data.Subset`s. `Subset.dataset` is the *entire* parent dataset, so
    # `.dataset` made train and validation both iterate the full training set.
    def train_dataloader(self):
        return self._make_dataloader(self.corpus.train, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.corpus.dev, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.corpus.test, shuffle=False)

In [28]:
class FlairModel(pl.LightningModule):
    """Linear classification head on top of flair document embeddings.

    Batches are expected to arrive already embedded, as ``(embeddings, label_indices)``
    pairs (see the unpacking in :meth:`_shared_step`). ``forward`` only applies the
    linear decoder. ``document_embeddings`` is still stored as an attribute, so its
    parameters are part of ``self.parameters()`` and are optimized by
    :meth:`configure_optimizers`.

    Args:
        document_embeddings: flair embeddings; ``embedding_length`` sets the decoder input size.
        label_dictionary: label vocabulary; its length sets the number of output classes.
        learning_rate: Adam learning rate.
    """

    def __init__(self, document_embeddings, label_dictionary, learning_rate):
        super().__init__()

        self.learning_rate = learning_rate
        self.label_dictionary = label_dictionary
        self.document_embeddings = document_embeddings
        self.decoder = nn.Linear(
            self.document_embeddings.embedding_length, len(self.label_dictionary)
        )
        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, x):
        """Map embeddings to unnormalized class scores (logits)."""
        return self.decoder(x)

    def _shared_step(self, batch):
        """Return ``(loss, accuracy)`` for an ``(embeddings, labels)`` batch."""
        embeddings, labels = batch
        scores = self(embeddings)
        loss = self.loss_function(scores, labels)
        accuracy = (scores.argmax(dim=-1) == labels).float().mean()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        # Previously nothing was logged during training, so progress was invisible.
        self.log('train_loss', loss)
        self.log('train_acc', accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self._shared_step(batch)
        self.log('val_loss', loss)
        self.log('val_acc', accuracy)
        return {'val_loss': loss}

    def test_step(self, batch, batch_idx):
        # Pairs with FlairData.test_dataloader so `trainer.test(...)` works.
        loss, accuracy = self._shared_step(batch)
        self.log('test_loss', loss)
        self.log('test_acc', accuracy)
        return {'test_loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [29]:
# Seed python/numpy/torch RNGs *before* loading. TREC_6 has no dev file (the log shows
# `Dev: None`), so the dev split is presumably sampled at random from train. Seeding here
# makes that split, weight init and shuffling reproducible on Restart & Run All.
pl.seed_everything(42)
corpus = TREC_6()

2021-04-21 10:54:23,624 Reading data from /home/krasul/.flair/datasets/trec_6
2021-04-21 10:54:23,625 Train: /home/krasul/.flair/datasets/trec_6/train.txt
2021-04-21 10:54:23,626 Dev: None
2021-04-21 10:54:23,627 Test: /home/krasul/.flair/datasets/trec_6/test.txt


In [30]:
# DistilBERT document embeddings with fine_tune=True, so the transformer weights are
# trainable alongside the classifier head (66.4 M trainable params in the fit summary).
document_embeddings = TransformerDocumentEmbeddings(
    'distilbert-base-uncased',
    fine_tune=True,
    batch_size=128,
)

In [31]:
# Label vocabulary built from the corpus; its size sets the decoder's output dimension.
label_dictionary: Dictionary = corpus.make_label_dictionary()

2021-04-21 10:54:27,955 Computing label dictionary. Progress:


100%|██████████| 5407/5407 [00:00<00:00, 64961.01it/s]

2021-04-21 10:54:28,043 [b'DESC', b'ENTY', b'ABBR', b'HUM', b'NUM', b'LOC']





In [32]:
# Select CUDA device 0 explicitly. The string form `gpus="0"` means "device index 0" only in
# older Lightning releases; newer ones parse it as "zero GPUs" and silently train on CPU.
trainer = pl.Trainer(gpus=[0])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [33]:
# NOTE(review): 1e-3 is high for fine-tuning a transformer with Adam (1e-5 to 5e-5 is more
# typical); confirm this learning rate is intentional.
model = FlairModel(
    document_embeddings=document_embeddings,
    label_dictionary=label_dictionary,
    learning_rate=1e-3,
)

In [34]:
# Hand the data module the model's own embedding and label objects, so batches are
# embedded with the very weights being fine-tuned.
dm = FlairData(
    corpus=corpus,
    document_embeddings=model.document_embeddings,
    label_dictionary=model.label_dictionary,
)

In [35]:
# Train and validate using the loaders and device-transfer hook from the data module.
trainer.fit(model=model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                | Type                          | Params
----------------------------------------------------------------------
0 | document_embeddings | TransformerDocumentEmbeddings | 66.4 M
1 | decoder             | Linear                        | 4.6 K 
2 | loss_function       | CrossEntropyLoss              | 0     
----------------------------------------------------------------------
66.4 M    Trainable params
0         Non-trainable params
66.4 M    Total params
265.470   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1