In [None]:
from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments
from nlp import load_dataset
import nlp
import torch
import numpy as np
import random
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
# The tokenizer and the classification model are both loaded from the same
# pretrained checkpoint; keeping the name in one place stops them drifting apart.
PRETRAINED_CHECKPOINT = 'bert-base-chinese'
tokenizer = BertTokenizerFast.from_pretrained(PRETRAINED_CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(PRETRAINED_CHECKPOINT)

In [None]:
# Run configuration shared by the cells below.
RANDOM_SEED = 50  # seed handed to dataset.shuffle
MAX_LEN = 512     # max_length used when tokenizing
EPOCHS = 4        # num_train_epochs for TrainingArguments

# BATCH_SIZE is passed to TrainingArguments as per_device_train_batch_size /
# per_device_eval_batch_size, i.e. it is a *per-device* size. The Trainer
# already multiplies it by the number of visible GPUs, so scaling it by
# torch.cuda.device_count() here as well would make the effective batch grow
# with the square of the GPU count. Keep it constant per device.
BATCH_SIZE = 8

In [None]:
# Load the reviews CSV and carve it into train (80%), test (10%) and
# validation (10%) splits.
dataset = load_dataset('csv', data_files='./reviews.csv', split='train')
shuffled_ds = dataset.shuffle(seed=RANDOM_SEED)

# train_test_split shuffles again by default; without an explicit seed the
# split membership would change on every run, so the seed is passed here too.
split_ds = shuffled_ds.train_test_split(test_size=0.2, seed=RANDOM_SEED)
train_dataset = split_ds['train']
test_val_dataset = split_ds['test']

# Halve the held-out 20% into the test and validation splits.
split_tv = test_val_dataset.train_test_split(test_size=0.5, seed=RANDOM_SEED)
test_dataset = split_tv['train']
val_dataset = split_tv['test']


In [None]:
def sample(ds, ratio, seed=None):
    """Draw a random subset of row indices from ``ds`` without replacement.

    Args:
        ds: Any sized collection; only ``len(ds)`` is used.
        ratio: Fraction of rows to keep, between 0 and 1 inclusive.
            ``int(len(ds) * ratio)`` indices are drawn (truncated toward zero).
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the same indices are returned on every call. When ``None``
            (the default) the module-level ``random`` state is used, so callers
            can still control reproducibility with ``random.seed``.

    Returns:
        A list of distinct integer indices in ``range(len(ds))``, in random
        order.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= ratio <= 1:
        raise ValueError("ratio must be between 0 and 1, got {!r}".format(ratio))
    size = len(ds)
    rng = random if seed is None else random.Random(seed)
    return rng.sample(range(size), int(size * ratio))
# Shrink every split to a fixed fraction of its rows.
# The RNG is seeded first so the same rows are selected on every
# Restart & Run All; otherwise each run would train on a different subset.
# NOTE: re-running this cell shrinks the already-shrunk splits again.
SAMPLE_RATIO = 0.1
random.seed(RANDOM_SEED)
train_dataset = train_dataset.select(indices=sample(train_dataset, SAMPLE_RATIO))
test_dataset = test_dataset.select(indices=sample(test_dataset, SAMPLE_RATIO))
val_dataset = val_dataset.select(indices=sample(val_dataset, SAMPLE_RATIO))


In [None]:
def tokenize(batch):
    """Encode the ``'content'`` entry of ``batch`` with the notebook's tokenizer.

    Args:
        batch: Mapping whose ``'content'`` entry holds the texts to encode.

    Returns:
        The tokenizer's output for those texts, produced with
        ``max_length=MAX_LEN``, ``padding=True`` and ``truncation=True``.
    """
    texts = batch['content']
    encoding_options = dict(max_length=MAX_LEN, padding=True, truncation=True)
    return tokenizer(texts, **encoding_options)

# Tokenize each split in a single batch (batch_size equals the split length),
# so padding=True in tokenize() pads every row of a split to one common length.
train_dataset, test_dataset, val_dataset = [
    split.map(tokenize, batched=True, batch_size=len(split))
    for split in (train_dataset, test_dataset, val_dataset)
]

# Expose only the columns listed here, as torch-formatted values.
MODEL_COLUMNS = ['input_ids', 'attention_mask', 'label']
for _split in (train_dataset, test_dataset, val_dataset):
    _split.set_format('torch', columns=MODEL_COLUMNS)

In [None]:
def compute_metrics(pred):
    """Turn a prediction object into accuracy / F1 / precision / recall.

    Args:
        pred: Object exposing ``label_ids`` (the reference labels) and
            ``predictions`` (scores whose argmax over the last axis is taken
            as the predicted class).

    Returns:
        dict with the keys ``'accuracy'``, ``'f1'``, ``'precision'`` and
        ``'recall'``. Precision, recall and F1 are computed with
        ``average='binary'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    # The fourth element returned by sklearn (support) is not reported.
    prf_scores = precision_recall_fscore_support(y_true, y_pred, average='binary')
    precision, recall, f1 = prf_scores[:3]

    metrics = {'accuracy': accuracy_score(y_true, y_pred)}
    metrics['f1'] = f1
    metrics['precision'] = precision
    metrics['recall'] = recall
    return metrics

# Settings handed to the Trainer, grouped by concern.
training_args = TrainingArguments(
    # output locations ('./logs' is also what the tensorboard cell reads)
    output_dir='./results',
    logging_dir='./logs',
    # optimisation schedule
    num_train_epochs=EPOCHS,
    warmup_steps=500,
    weight_decay=0.01,
    # batch sizes are specified per device
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # run evaluation while training, not only at the end
    evaluate_during_training=True,
)

# Wire the model, data and metric function into the Trainer.
# Evaluation during training runs on the validation split, which was created
# and tokenized for exactly this purpose; the test split is kept out of the
# training loop. For a final score on it, pass test_dataset explicitly to
# trainer.evaluate(...).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the model using the settings in training_args.
trainer.train()

In [None]:
# Report the final metrics on the test split, named explicitly so the result
# does not depend on which dataset the Trainer was constructed with.
test_metrics = trainer.evaluate(eval_dataset=test_dataset)
test_metrics

In [None]:
# Open TensorBoard inline, pointed at the same './logs' directory that is
# passed as logging_dir to TrainingArguments.
%load_ext tensorboard
%tensorboard --logdir './logs'