In [1]:
import pandas as pd
import os
import re
import numpy as np
from params.paths import ROOT_DIR
from api_requests.meeting_convo_collector import MeetingConvoCollector
from file_handling.file_read_writer import read_json, write_json, create_dir, write_file

In [2]:
#params
DATA_DIR = os.path.join(ROOT_DIR, 'data')
RESOURCE_DIR = os.path.join(ROOT_DIR, 'resource')

In [3]:
#loading the labels corresponding to the label ids
label2id = read_json(os.path.join(RESOURCE_DIR, "labels.json"))
id2label = {v:k for k, v in label2id.items()}
label2id = {k: int(v) for k, v in label2id.items()}
for k, v in label2id.items():
    print(f"{k}=>{v}")
print("___________________________________________________________")
for k, v in id2label.items():
	print(f"{k}=>{v}")

事実文=>0
質問文=>1
説明文=>2
意見文=>3
その他=>4
___________________________________________________________
0=>事実文
1=>質問文
2=>説明文
3=>意見文
4=>その他


# Preparing dataset

In [4]:
from datasets import load_dataset
PATH_TO_DATA_FILE = os.path.join(DATA_DIR, "labelled_data_77.csv")
dataset = load_dataset('csv', data_files=PATH_TO_DATA_FILE, split="train").train_test_split(test_size=0.2)
dataset = dataset.filter(lambda example: bool(example["speech"])).filter(lambda example: len(example['speech'])<1000)
dataset.rename_column("label", "label_name")
dataset = dataset.rename_columns({"label":"label_name", "label_id":"label"})
print(dataset)
print(dataset["train"][199])

Found cached dataset csv (/home/ubuntu/.cache/huggingface/datasets/csv/default-fd254dbca646b115/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


Filter:   0%|          | 0/1150 [00:00<?, ? examples/s]

Filter:   0%|          | 0/288 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1142 [00:00<?, ? examples/s]

Filter:   0%|          | 0/285 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['speech', 'label_name', 'label', 'record_position'],
        num_rows: 1137
    })
    test: Dataset({
        features: ['speech', 'label_name', 'label', 'record_position'],
        num_rows: 285
    })
})
{'speech': 'こうした国民の皆様の悲痛な声に胸が痛みませんか', 'label_name': 'その他', 'label': 4, 'record_position': 27}


In [5]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v3")

def tokenize_function(examples):
    result = tokenizer(examples["speech"], truncation=True, padding=True)
    return result

tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/1137 [00:00<?, ? examples/s]

Map:   0%|          | 0/285 [00:00<?, ? examples/s]

# Training

In [6]:
import evaluate

accuracy = evaluate.load("accuracy")

In [7]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [8]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(
    "cl-tohoku/bert-base-japanese-v3", num_labels=5, id2label=id2label, label2id=label2id
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
training_args = TrainingArguments(
    output_dir="jp-speech-classifier",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    label_names=['labels'],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.104806,0.677193
2,No log,1.189505,0.705263


TrainOutput(global_step=144, training_loss=0.10924084319008721, metrics={'train_runtime': 64.1028, 'train_samples_per_second': 35.474, 'train_steps_per_second': 2.246, 'total_flos': 458096908471776.0, 'train_loss': 0.10924084319008721, 'epoch': 2.0})

In [30]:
trainer.push_to_hub()

Cloning https://huggingface.co/kkatodus/jp-speech-classifier into local empty directory.


Upload file pytorch_model.bin:   0%|          | 1.00/424M [00:00<?, ?B/s]

Upload file training_args.bin:   0%|          | 1.00/3.87k [00:00<?, ?B/s]

To https://huggingface.co/kkatodus/jp-speech-classifier
   7f8d406..8e83e6c  main -> main

To https://huggingface.co/kkatodus/jp-speech-classifier
   8e83e6c..14060e9  main -> main



'https://huggingface.co/kkatodus/jp-speech-classifier/commit/8e83e6c518f52feacec869c6a8b4a562775aa410'

# Testing on single samples

In [21]:
from datasets import Dataset
"""
事実文=>0
質問文=>1
説明文=>2
意見文=>3
その他=>4
___________________________________________________________
0=>事実文
1=>質問文
2=>説明文
3=>意見文
4=>その他
"""
['speech', 'label_name', 'label', 'record_position'],
ds = Dataset.from_dict({'speech':['バナナはおやつに入りますか', "私は防衛費を増額するべきだと思います", "日本の人口はここ10年右肩下がりです", "暇暇暇暇暇なんだよ", "私は鎌田の最も適切な移籍先はバルセロナだと思っています"], 'label':[1, 3, 0, 4, 3]})
tokenized_ds = ds.map(tokenize_function, batched=True)
predictions = trainer.predict(tokenized_ds)


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Dataset({
    features: ['speech', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 5
})


PredictionOutput(predictions=array([[-1.0185171 ,  2.1610246 , -0.7538159 , -0.53144586, -0.10285759],
       [-1.98192   , -0.6809551 , -0.59336185,  3.3562343 , -1.572317  ],
       [ 1.3550638 , -2.3506474 ,  1.9632782 , -0.73973805, -0.94014865],
       [-0.5290267 , -0.9143827 , -0.14919591, -0.92953813,  2.5128245 ],
       [-0.4566768 , -1.661125  ,  0.0658606 ,  0.6682328 ,  0.3735632 ]],
      dtype=float32), label_ids=array([1, 3, 0, 4, 3]), metrics={'test_loss': 0.515464723110199, 'test_accuracy': 0.8, 'test_runtime': 0.0141, 'test_samples_per_second': 354.5, 'test_steps_per_second': 70.9})