In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Put the notebook's parent directory on the import path (presumably where the
# `lcube` package lives -- confirm against the repo layout). The membership
# check keeps the cell idempotent: re-running it no longer appends a duplicate
# ".." entry each time.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
import os
from argparse import Namespace

import torch
import transformers
import pytorch_lightning as pl

from lcube.data_module.data_lbox_open import LBoxOpenDataModule
from lcube.model.model_baseline import SeqToSeqBaseline

# Expose only GPU 0 to this process, then pick the device string used later
# when tensors are moved for inference.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 0. Set parameters

In [4]:
# Experiment configuration. Everything downstream (data module, model, trainer,
# inference) reads its settings from this single Namespace.
args = Namespace()

# dataset
args.dataset_card = "lbox/lbox_open"
# Exactly one task is active; the alternatives are kept for quick switching.
# args.task = "casename_classification"
args.task = "statute_classification"
# args.task = "summarization"

# Task-specific settings: which example field is the model input, the
# tokenized length limits, batching, and the metric used for model selection.
if args.task in ["casename_classification", "statute_classification"]:
    args.input_key = "facts"
    args.max_input_len = 512
    args.max_target_len = 64
    args.batch_size = 8
    args.accumulate_grad_batches = 1
    args.validation_metric = "exact_match"

elif args.task == "summarization":
    args.input_key = "precedent"
    args.max_input_len = 768
    args.max_target_len = 512
    # batch_size * accumulate_grad_batches is 8 here, the same product as in
    # the classification branch.
    args.batch_size = 1
    args.accumulate_grad_batches = 8
    args.validation_metric = "rougeL"

else:
    # Fail with an actionable message instead of a bare ValueError.
    raise ValueError(
        f"Unsupported task {args.task!r}; expected one of "
        "'casename_classification', 'statute_classification', 'summarization'."
    )

# Settings that are identical for every task, defined once after validation.
# model
args.model_card = "google/mt5-small"
# train
args.max_epochs = 10
args.learning_rate = 2e-4
# Evaluation uses twice the training batch size.
args.batch_size_eval = 2 * args.batch_size


# Load the fast mT5 tokenizer for the configured checkpoint and keep it on
# `args` so the data module, model and inference cells all use the same one.
args.tokenizer = transformers.MT5TokenizerFast.from_pretrained(
    args.model_card,
)

# 1. Load dataset

In [5]:
# Build the data module for the configured dataset card and task.
# Arguments are passed positionally, so their order must match the
# LBoxOpenDataModule constructor (not visible here -- verify before reordering).
data_module = LBoxOpenDataModule(
    args.dataset_card,
    args.task,
    args.tokenizer,
    args.max_input_len,  # token limit for the input text
    args.max_target_len,  # token limit for the target text
    args.batch_size,  # training batch size
    args.batch_size_eval,  # evaluation batch size (2x training, see config)
)

# 2. Load model

In [None]:
# Load the pretrained mT5 conditional-generation model named by
# args.model_card; it is handed to SeqToSeqBaseline as the backbone.
backbone = transformers.MT5ForConditionalGeneration.from_pretrained(args.model_card)

In [None]:
# Wrap the backbone in the project's seq2seq baseline. Arguments are
# positional, in the order the SeqToSeqBaseline constructor is called with
# here: task, backbone, tokenizer, learning rate, max target length, metric.
model = SeqToSeqBaseline(
    args.task,
    backbone,
    args.tokenizer,
    args.learning_rate,
    args.max_target_len,
    args.validation_metric,
)

# 3. Trainer

In [None]:
# Checkpointing: keep only the single best checkpoint under ./saved/0/<task>,
# ranked by the configured validation metric with higher values being better.
# NOTE(review): this assumes the model logs a metric under exactly the name in
# args.validation_metric -- confirm in SeqToSeqBaseline.
callbacks = pl.callbacks.ModelCheckpoint(
    monitor=args.validation_metric,
    dirpath=f"./saved/0/{args.task}",
    save_top_k=1,
    mode="max",
)
trainer = pl.Trainer(
    max_epochs=args.max_epochs,
    # NOTE(review): `gpus` is the older Lightning argument; newer releases use
    # accelerator=/devices= instead -- confirm against the installed version.
    gpus=torch.cuda.device_count(),
    accumulate_grad_batches=args.accumulate_grad_batches,
    # Full training run. Flip to True for a quick smoke test of the pipeline.
    fast_dev_run=False,
    callbacks=callbacks,
)

# 4. Train

In [None]:
# Train the model on the data module with the Trainer configured in the
# previous section (checkpointing on the validation metric).
trainer.fit(model, data_module)

# 5. Inspect a prediction on a test example

In [None]:
# Generate a prediction for one test example and print it next to its input.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The inputs are moved to `device` below, so the generator must be on the same
# device; otherwise generate() raises a device-mismatch error whenever the
# fitted model was left elsewhere. eval() switches off dropout so the
# prediction does not vary between runs.
model.model.to(device)
model.model.eval()

input_text = data_module.dataset["test"][40][args.input_key]
model_inputs = args.tokenizer(
    input_text,
    max_length=args.max_input_len,
    padding=True,
    truncation=True,
    return_tensors='pt',
)
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

# Pass the attention mask explicitly so the call stays correct if this cell is
# later extended to a padded batch of several examples.
pr_seqs = model.model.generate(
    input_ids=model_inputs["input_ids"],
    attention_mask=model_inputs["attention_mask"],
    max_length=args.max_target_len,
)
prs = args.tokenizer.batch_decode(pr_seqs, skip_special_tokens=True)
print(f"Input\n {input_text}\n\n")
print(f"Prediction\n {prs}")

In [None]:
# Round-trip the `statutes` field of test example 40 through the tokenizer:
# encode it, then decode the ids back (special tokens are kept) to see how the
# tokenizer represents it.
statutes_encoded = args.tokenizer(
    data_module.dataset["test"][40]["statutes"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)
args.tokenizer.batch_decode(statutes_encoded["input_ids"])

In [None]:
", ".join(data_module.dataset["test"][40]["statutes"])