In [1]:
import os

import numpy as np
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers.csv_logs import CSVLogger

from dataset import CustomDataModule
from factory import read_yaml
from lightning_module import CustomLitModule


In [11]:
# --- Configuration -----------------------------------------------------------
cfg = read_yaml(fpath="./configs/sample.yaml")
output_path = "../output"  # shared by the CSV logger and the checkpoint callback

seed_everything(cfg.General.seed)
debug = True  # True: short smoke run (3 epochs on a fraction of the batches)
fold = cfg.Data.dataset.fold

# Fraction of train/val batches used per epoch; reduced only in debug mode.
debug_batch_fraction = 0.2
batch_fraction = debug_batch_fraction if debug else 1.0

# --- Logging and callbacks ---------------------------------------------------
logger = CSVLogger(save_dir=output_path, name=f"fold_{fold}")

# Stop when val_loss has not improved by at least 0.05 for 3 validation runs.
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=0.05, patience=3, mode="min"
)
# Required so that the trained weights are saved to disk.
checkpoint_callback = ModelCheckpoint(
    dirpath=output_path,
    filename=f"sample_fold_{fold}",
    verbose=True,
    monitor="val_loss",
    mode="min",
)

# --- Trainer -----------------------------------------------------------------
# NOTE: amp_backend="native" and auto_select_gpus=False were dropped from this
# call. Both only restated the Trainer defaults, and newer Lightning releases
# no longer accept these arguments.
trainer = Trainer(
    max_epochs=3 if debug else cfg.General.epoch,
    accelerator="gpu",
    devices=1,
    deterministic=True,
    benchmark=False,
    default_root_dir=os.getcwd(),
    limit_train_batches=batch_fraction,
    limit_val_batches=batch_fraction,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=[logger],
)

# Lightning module and start training
model = CustomLitModule(cfg)
datamodule = CustomDataModule(cfg)
trainer.fit(model, datamodule=datamodule)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-char-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassificati

Sanity Checking: 0it [00:00, ?it/s]

ValueError: Number of classes in y_true not equal to the number of columns in 'y_score'

In [9]:
# Run the datamodule's setup step, then show the distinct values of the
# "resB_label" column in the training dataframe.
datamodule.setup()
(
    datamodule
    .get_dataframe("train")["resB_label"]
    .unique()
)

array([3, 2, 1, 0, 4])

In [10]:
# Distinct values of the "resB_label" column in the validation dataframe.
(
    datamodule
    .get_dataframe("valid")["resB_label"]
    .unique()
)

array([3, 1, 2, 0, 4])

In [None]:
# Evaluate on the test split using the checkpoint that ModelCheckpoint
# recorded as best (lowest monitored val_loss).
result = trainer.test(
    ckpt_path=checkpoint_callback.best_model_path,
    datamodule=datamodule,
)

Restoring states from the checkpoint path at /home/workspace/labo/defamation_detection/output/sample_fold_0-v11.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from checkpoint at /home/workspace/labo/defamation_detection/output/sample_fold_0-v11.ckpt


Testing: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.6777777777777778
        test_loss           1.3027324676513672
      test_macro_f1         0.18964525407478425
     test_precision         0.23636363636363633
       test_recall          0.2133879781420765
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
# Predict on the test dataloader with the best checkpoint.
# NOTE(review): the following cell issues the same predict() call again, so
# this cell looks redundant -- consider removing one of the two.
logits = trainer.predict(
    ckpt_path=checkpoint_callback.best_model_path,
    dataloaders=datamodule.test_dataloader(),
)

Restoring states from the checkpoint path at /home/workspace/labo/defamation_detection/output/sample_fold_0-v11.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from checkpoint at /home/workspace/labo/defamation_detection/output/sample_fold_0-v11.ckpt


Predicting: 9it [00:00, ?it/s]

In [None]:
# `torch` (and `numpy as np`) are imported in the notebook's top import cell.

# Predict with the best checkpoint. The result is a sequence of per-batch
# tensors, which is concatenated into a single tensor.
batch_logits = trainer.predict(
    ckpt_path=checkpoint_callback.best_model_path,
    dataloaders=datamodule.test_dataloader(),
)
logits = torch.cat(batch_logits)

test_df = datamodule.get_dataframe("test")
labels = torch.from_numpy(test_df["resB_label"].to_numpy())

# Convert the predicted class indices to a NumPy array explicitly instead of
# relying on pandas to coerce a torch tensor.
pred = logits.argmax(dim=1).cpu().numpy()
assert len(pred) == len(test_df), (
    f"got {len(pred)} predictions for {len(test_df)} test rows"
)

# assign() returns a new frame, so the dataframe handed out by the datamodule
# is not mutated and re-running this cell is safe.
test_df = test_df.assign(pred=pred)
display(test_df)

Restoring states from the checkpoint path at /home/workspace/labo/defamation_detection/output/sample_fold_0-v11.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from checkpoint at /home/workspace/labo/defamation_detection/output/sample_fold_0-v11.ckpt


Predicting: 9it [00:00, ?it/s]

Unnamed: 0,textDisplay,resB_label,pred
0,ウクライナ側は、納得しないだろ。ウクライナの領土を勝手に武力で侵犯しといて、併合は、ないよ。,3,3
1,田口くんを陥れるために市役所職員がしかけた罠なんだろ、なんで田口くんが責められるん？？犯罪な...,0,3
2,同意の誤信ってこのレベルの話でそんなんだったら、罪犯した人みんなそういう発言する人が増えるん...,3,3
3,まともな人なら返すからな。嘘ついたり、俺は悪くないとかアホな発言しているところをを見るとネッ...,3,3
4,こういう犯罪者の証拠はしっかりと残して、後にさばくべき,3,3
...,...,...,...
85,いちばん悪いのは 役所だ｡あってはならないミスを犯し 問題を この方に すり替えている。そも...,3,3
86,何か胡散臭いですね。まだ持ってるやろ（笑）,3,3
87,ハンドルロックやタイヤにロックしても切断して盗むから自宅の防犯設備のガーレージで駐車しないと...,3,3
88,マスクするしないは自由だが、搭乗するしないも被告の自由。搭乗者は、機内の規則を遵守することを...,3,3
