In [1]:
import torch

# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

True


## load_model

In [2]:
from peft import LoraConfig, TaskType
from peft.peft_model import PeftModelForSequenceClassification
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Single source of truth for the checkpoint so the tokenizer and the model
# are always loaded from the same place.
MODEL_NAME = "llm-jp/llm-jp-3-1.8b"
# Fixed seed so that results are comparable from run to run.
SEED = 42

# Seed *before* the model is built: the classification head (`score.weight`)
# is not part of the checkpoint and is randomly initialised at load time.
set_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at llm-jp/llm-jp-3-1.8b and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Copy the tokenizer's padding token id into the model config.
# NOTE(review): presumably needed so the classification head can locate the
# last non-padding token of each sequence — confirm against the model docs.
model.config.pad_token_id = tokenizer.pad_token_id

In [4]:
class PatchedPeftModelForSequenceClassification(PeftModelForSequenceClassification):
    """`PeftModelForSequenceClassification` whose `add_adapter` tolerates `low_cpu_mem_usage`.

    NOTE(review): presumably a workaround for a peft release in which the
    parent's `add_adapter` does not accept the `low_cpu_mem_usage` keyword —
    confirm against the installed peft version and drop the patch once fixed.
    """

    def add_adapter(self, adapter_name, peft_config, low_cpu_mem_usage: bool = False):
        """Register an adapter, ignoring `low_cpu_mem_usage`.

        Args:
            adapter_name: Name under which the adapter is registered.
            peft_config: PEFT configuration describing the adapter.
            low_cpu_mem_usage: Accepted for signature compatibility only; it is
                deliberately not forwarded to the parent implementation.
        """
        # The first parameter was previously misspelt (`adaper_name`), so the
        # reference below raised NameError on every call.
        super().add_adapter(adapter_name, peft_config)

In [5]:
# LoRA settings for the sequence-classification task: rank-1 adapters,
# built in training (not inference) mode.
peft_config = LoraConfig(r=1, inference_mode=False, task_type=TaskType.SEQ_CLS)

In [6]:
# Wrap the base classifier with the LoRA configuration, using the patched PEFT class.
peft_model = PatchedPeftModelForSequenceClassification(model, peft_config)

## load_dataset

In [7]:
from datasets import load_dataset

In [8]:
# Read the three tab-separated splits with the "csv" loader and rename the
# "label" column to "labels".
ds = load_dataset(
    "csv",
    delimiter="\t",
    data_files=dict(
        train="../data/train.tsv",
        valid="../data/valid.tsv",
        test="../data/test.tsv",
    ),
).rename_column("label", "labels")

## training

In [9]:
import datetime

# Timestamp of this run (local time, second resolution); used to build a
# per-run output directory further down.
yyyymmddhhmmss = f"{datetime.datetime.now():%Y%m%d%H%M%S}"
yyyymmddhhmmss

'20241201051941'

In [10]:
class TokenizeCollator:
    """Data collator that tokenizes raw text when a batch is assembled.

    Each example is a mapping holding a text field and a label field. The
    texts are tokenized together (padded to the longest sequence in the
    batch, truncated to `max_length`) and the labels are stacked into a tensor.

    Args:
        tokenizer: Callable tokenizer; invoked with a list of strings plus
            `padding`, `truncation`, `max_length` and `return_tensors`, and
            expected to return a mapping with "input_ids" and "attention_mask".
        text_column: Key of the text field in each example (default "poem").
        label_column: Key of the label field in each example (default "labels").
        max_length: Truncation length passed to the tokenizer (default 200).
    """

    def __init__(self, tokenizer, text_column="poem", label_column="labels", max_length=200):
        self.tokenizer = tokenizer
        self.text_column = text_column
        self.label_column = label_column
        self.max_length = max_length

    def __call__(self, examples):
        """Build one batch from a list of examples.

        Returns:
            dict with "input_ids" and "attention_mask" taken from the tokenizer
            output, and "labels" as a tensor built from the label field.
        """
        texts = [ex[self.text_column] for ex in examples]
        encoding = self.tokenizer(
            texts,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        labels = torch.tensor([ex[self.label_column] for ex in examples])
        # Only these three keys are returned; any other tokenizer output is dropped.
        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "labels": labels,
        }

In [11]:
import evaluate

# Metric objects used by compute_metrics below; loaded once at module level.
roc_auc_evaluate = evaluate.load("roc_auc")
acc_evaluate = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute ROC AUC and accuracy from a (logits, labels) pair.

    Args:
        eval_pred: Two-element iterable of (logits, labels); each element is
            converted with `torch.tensor`. Logits are indexed as
            [example, class] and column 1 is treated as the positive class.

    Returns:
        dict merging the results of the ROC AUC metric and the accuracy metric.
    """
    logits_t, labels_t = (torch.tensor(part) for part in eval_pred)

    # Probability assigned to label 1.
    positive_scores = torch.softmax(logits_t, dim=1)[:, 1]
    # Predicted label = class with the highest logit.
    hard_predictions = logits_t.argmax(dim=1)

    metrics = {}
    metrics.update(
        roc_auc_evaluate.compute(prediction_scores=positive_scores, references=labels_t)
    )
    metrics.update(
        acc_evaluate.compute(predictions=hard_predictions, references=labels_t)
    )
    return metrics

In [12]:
training_args = TrainingArguments(
    # One output directory per run, keyed by the start timestamp.
    output_dir=f"../results/{yyyymmddhhmmss}",
    num_train_epochs=10,
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # NOTE(review): 1.0 is an unusually strong weight decay — confirm it is intentional.
    weight_decay=1.0,
    # `evaluation_strategy` is the deprecated spelling of this argument.
    eval_strategy="epoch",
    logging_strategy="epoch",
    # Keep the raw text column: TokenizeCollator reads it when building batches.
    remove_unused_columns=False,
)



In [13]:
# Fine-tune the LoRA-wrapped classifier; validation metrics come from compute_metrics.
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["valid"],
    # `tokenizer=` is deprecated on Trainer; `processing_class=` is its replacement.
    processing_class=tokenizer,
    # Tokenization happens per batch inside the collator, not as a dataset map.
    data_collator=TokenizeCollator(tokenizer),
    compute_metrics=compute_metrics,
)

trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Roc Auc,Accuracy
1,1.1623,0.371995,0.949013,0.925926
2,0.4494,0.490962,0.962993,0.925926
3,0.4034,0.310664,0.967105,0.944444
4,0.3103,0.382418,0.972862,0.925926
5,0.2625,0.374812,0.972862,0.925926
6,0.2323,0.381211,0.978618,0.925926
7,0.2002,0.311052,0.978618,0.944444
8,0.1574,0.325294,0.978618,0.925926
9,0.1432,0.312975,0.978618,0.925926
10,0.1332,0.313705,0.978618,0.925926


TrainOutput(global_step=1620, training_loss=0.34542546684359327, metrics={'train_runtime': 99.7472, 'train_samples_per_second': 16.241, 'train_steps_per_second': 16.241, 'total_flos': 195863262167040.0, 'train_loss': 0.34542546684359327, 'epoch': 10.0})

In [21]:
# Score the test split: predictions from the trained model are paired with
# the dataset's "labels" column and passed through the same metric function
# used during training.
compute_metrics((trainer.predict(ds["test"]).predictions, ds["test"]["labels"]))

{'roc_auc': np.float64(0.9835526315789473), 'accuracy': 0.9074074074074074}