In [6]:
%%capture
!git clone https://github.com/hudeven/text
!pip install --upgrade -e ./text;

In [1]:
import os

import torch
from torch.optim import AdamW
from pytorch_lightning import Trainer
from stl_text.ops.utils.arrow import convert_csv_to_arrow
from stl_text.datamodule import DocClassificationDataModule
from stl_text.models import RobertaModel
from task import DocClassificationTask

In [2]:
# convert csv to arrow format (only required for the first time)
data_path = "./glue_sst2_tiny"
for split in ("train.tsv", "valid.tsv", "test.tsv"):
    split_path = os.path.join(data_path, split)
    convert_csv_to_arrow(split_path)

# setup datamodule
datamodule = DocClassificationDataModule(data_path=data_path, batch_size=8, drop_last=True)
datamodule.setup("fit")

# build task
model = RobertaModel(
    vocab_size=1000,
    embedding_dim=1000,
    num_attention_heads=1,
    num_encoder_layers=1,
    output_dropout=0.4,
    out_dim=2,
)
optimizer = AdamW(model.parameters(), lr=0.01)
task = DocClassificationTask(
    datamodule=datamodule,
    model=model,
    optimizer=optimizer,
)

# train model
trainer = Trainer(max_epochs=5, fast_dev_run=True)
trainer.fit(task, datamodule=datamodule)

# test model
trainer.test(task, datamodule=datamodule)

# export task(transform + model) to TorchScript
export_path = "/tmp/doc_classification_task.pt1"
task.to_torchscript(export_path)

# deploy task to server and inference
with open(export_path, "rb") as f:
    ts_module = torch.load(f)
    print(ts_module(text_batch=["hello world", "attention is all your need!"]))

converted to arrow and saved to ./glue_sst2_tiny/train
converted to arrow and saved to ./glue_sst2_tiny/valid
converted to arrow and saved to ./glue_sst2_tiny/test
Loading cached processed dataset at ./glue_sst2_tiny/train/cache-f1e83dde37d3c060.arrow
Loading cached processed dataset at ./glue_sst2_tiny/train/cache-e464112f4fa64676.arrow
Loading cached processed dataset at ./glue_sst2_tiny/train/cache-3301217692ea7b7c.arrow
Loading cached processed dataset at ./glue_sst2_tiny/valid/cache-d7e81757a0546a02.arrow
Loading cached processed dataset at ./glue_sst2_tiny/valid/cache-daf7dc6f6a37bdc1.arrow
Loading cached processed dataset at ./glue_sst2_tiny/valid/cache-a1235ce082fa5b2a.arrow
Loading cached processed dataset at ./glue_sst2_tiny/test/cache-4a8ddec11e8e0938.arrow
Loading cached processed dataset at ./glue_sst2_tiny/test/cache-a6e5e3e4a7b8c81f.arrow
Loading cached processed dataset at ./glue_sst2_tiny/test/cache-c33ac2d9284c0aa3.arrow
GPU available: False, used: False
TPU available

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…












--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.5000), 'test_loss': tensor(9.0768)}
--------------------------------------------------------------------------------
tensor([[-9.4245,  8.7131],
        [-9.6996,  9.0748]], grad_fn=<AddBackward0>)


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

In [3]:
print(ts_module)

RecursiveScriptModule(
  original_name=DocClassificationTask
  (text_transform): RecursiveScriptModule(original_name=WhitespaceTokenizer)
  (model): RecursiveScriptModule(
    original_name=RobertaModel
    (encoder): RecursiveScriptModule(
      original_name=RobertaEncoder
      (transformer): RecursiveScriptModule(
        original_name=Transformer
        (token_embedding): RecursiveScriptModule(original_name=Embedding)
        (layers): RecursiveScriptModule(
          original_name=ModuleList
          (0): RecursiveScriptModule(
            original_name=TransformerLayer
            (dropout): RecursiveScriptModule(original_name=Dropout)
            (attention): RecursiveScriptModule(
              original_name=MultiheadSelfAttention
              (dropout): RecursiveScriptModule(original_name=Dropout)
              (input_projection): RecursiveScriptModule(original_name=Linear)
              (output_projection): RecursiveScriptModule(original_name=Linear)
            )
       