<a href="https://colab.research.google.com/github/GAEGAE2675/01SW/blob/main/base.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q datasets

In [None]:
from datasets import load_dataset

# Load the three LBox Open task datasets in one pass, in this order:
#   data_cn   - case name classification
#   data_st   - statute classification
#   data_summ - case summarization
# The raw case corpus is intentionally not loaded here; to fetch it, call
# load_dataset("lbox/lbox_open", "case_corpus").
data_cn, data_st, data_summ = (
    load_dataset("lbox/lbox_open", task_name)
    for task_name in (
        "casename_classification",
        "statute_classification",
        "summarization",
    )
)

In [None]:
!pip install --quiet transformers
!pip install --quiet sentencepiece
!pip install --quiet datasets
!pip install --quiet rouge_score
!pip install --quiet pytorch-lightning

In [None]:
!git clone https://github.com/lbox-kr/lbox_open.git --branch v0.1
%cd lbox_open

In [None]:
import os
from argparse import Namespace

import torch
import transformers
import pytorch_lightning as pl

from lcube.data_module.data_lbox_open import LBoxOpenDataModule
from lcube.model.model_baseline import SeqToSeqBaseline

# Device used by the inference cell: the GPU when CUDA is usable, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Experiment configuration. Everything downstream (data module, model,
# checkpointing, trainer, inference) reads its settings from `args`.
args = Namespace()

# --- dataset ---
args.dataset_card = "lbox/lbox_open"
args.task = "casename_classification"  # comment and uncomment following lines depending on the task
# args.task = "statute_classification"
# args.task = "summarization"

# --- settings shared by every task ---
args.model_card = "google/mt5-small"
args.max_epochs = 10
args.learning_rate = 2e-4

# --- task-specific settings ---
# batch_size * accumulate_grad_batches is 8 in both branches.
if args.task in ("casename_classification", "statute_classification"):
    args.input_key = "facts"          # dataset field fed to the model
    args.max_input_len = 512
    args.max_target_len = 64          # targets are short labels
    args.batch_size = 8
    args.accumulate_grad_batches = 1
    args.validation_metric = "exact_match"
elif args.task == "summarization":
    args.input_key = "precedent"      # dataset field fed to the model
    args.max_input_len = 768
    args.max_target_len = 512         # targets are full summaries
    args.batch_size = 1
    args.accumulate_grad_batches = 8
    args.validation_metric = "rougeL"
else:
    raise ValueError(
        f"Unsupported task {args.task!r}; expected 'casename_classification', "
        "'statute_classification' or 'summarization'."
    )

# Evaluation batches are twice the size of training batches.
args.batch_size_eval = 2 * args.batch_size

args.tokenizer = transformers.MT5TokenizerFast.from_pretrained(args.model_card)

# Seed the random number generators for reproducibility. `pl.seed_everything`
# is the public entry point; the `pl.utilities.seed` path used previously is
# deprecated and absent from newer PyTorch Lightning releases.
pl.seed_everything(seed=1, workers=False)

In [None]:
# Build the data module for the selected task from the settings in `args`.
# The arguments are passed positionally, so their order must match the
# constructor of LBoxOpenDataModule (imported from
# lcube.data_module.data_lbox_open; its signature is not shown in this
# notebook).
data_module = LBoxOpenDataModule(
    args.dataset_card,
    args.task,
    args.tokenizer,
    args.max_input_len,
    args.max_target_len,
    args.batch_size,
    args.batch_size_eval,
)

In [None]:
# Load the pretrained mT5 checkpoint named by `args.model_card` as the backbone.
backbone = transformers.MT5ForConditionalGeneration.from_pretrained(args.model_card)
# Wrap the backbone in the baseline seq2seq module from lcube.model.model_baseline.
# The arguments are passed positionally, so their order must match the
# SeqToSeqBaseline constructor (its signature is not shown in this notebook).
# The inference cell later reaches the wrapped network through `model.model`.
model = SeqToSeqBaseline(
    args.task,
    backbone,
    args.tokenizer,
    args.learning_rate,
    args.max_target_len,
    args.validation_metric

)

In [None]:
# Checkpointing: keep only the single best checkpoint, where "best" means the
# highest value of the task's validation metric.
callbacks = pl.callbacks.ModelCheckpoint(
    dirpath=f"./saved/0/{args.task}",
    monitor=args.validation_metric,
    mode="max",
    save_top_k=1,
)

# Trainer: use every visible GPU (0 GPUs means CPU training).
# NOTE(review): the `gpus` argument only exists in PyTorch Lightning < 2.0.
trainer = pl.Trainer(
    callbacks=callbacks,
    gpus=torch.cuda.device_count(),
    max_epochs=args.max_epochs,
    accumulate_grad_batches=args.accumulate_grad_batches,
    fast_dev_run=False,  # set to True for a quick smoke test of the pipeline
)

In [None]:
# Run training with the model, data module and trainer built in the cells above.
trainer.fit(model, data_module)

In [None]:
# Inference demo: run the model on one example and print its prediction.
#
# `input_text` and `model_inputs` were previously used without ever being
# defined, so this cell raised NameError on a fresh kernel. Both are now
# built here.
demo_dataset = load_dataset(args.dataset_card, args.task)
# Prefer the "test" split when the dataset has one; otherwise use whichever
# split is listed first.
demo_split = "test" if "test" in demo_dataset else next(iter(demo_dataset))
# Replace this line with any string of your own to try custom input.
input_text = demo_dataset[demo_split][0][args.input_key]

# Tokenize with the same length limit used for training and move the tensors
# to the device the model will run on.
model_inputs = args.tokenizer(
    input_text,
    max_length=args.max_input_len,
    truncation=True,
    return_tensors="pt",
).to(device)

model.model = model.model.to(device)
model.model.eval()  # disable dropout so generation is deterministic
pr_seqs = model.model.generate(model_inputs["input_ids"],
                               attention_mask=model_inputs["attention_mask"],
                               max_length=args.max_target_len)
prs = args.tokenizer.batch_decode(pr_seqs, skip_special_tokens=True)
print(f"판례를 입력해주세요\n {input_text}\n\n")
print(f"예상되는 위반법\n {prs}")