In [25]:
from transformers import BartModel, BartTokenizer

MODEL_NAME = "lucadiliello/bart-small"
DATASET_NAME = "evgenesh/java-obfuscation"
TRAIN_COEF = 0.2

In [26]:
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)

model = BartModel.from_pretrained(MODEL_NAME)

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1'}. The number of labels wil be overwritten to 2.
Some weights of the model checkpoint at lucadiliello/bart-small were not used when initializing BartModel: ['final_logits_bias', 'lm_head.weight']
- This IS expected if you are initializing BartModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [27]:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME)["train"]

Found cached dataset json (/Users/eshevlyakov/.cache/huggingface/datasets/evgenesh___json/evgenesh--java-obfuscation-4846253e475242bb/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/1 [00:00<?, ?it/s]

In [28]:
def create_inputs(dataset, tokenizer, test_size=TRAIN_COEF):
    train_dataset, test_dataset = dataset.train_test_split(test_size).values()

    tokenize = lambda x: tokenizer(
        text=x["decompiled"],
        text_target=x["source"],
        padding="max_length",
        truncation=True
    )

    tokenized_train_dataset = train_dataset.map(tokenize)
    tokenized_test_dataset = test_dataset.map(tokenize)

    return tokenized_train_dataset, tokenized_test_dataset

train_dataset, test_dataset = create_inputs(dataset, tokenizer)
print(len(train_dataset), len(test_dataset))

Map:   0%|          | 0/324 [00:00<?, ? examples/s]

Map:   0%|          | 0/81 [00:00<?, ? examples/s]

324 81


## Специфика обучения

In [29]:
OUTPUT_DIR = "bart-small"
LOG_DIR = "log-bart-small"
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
EPOCHS = 5

In [30]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    logging_dir=LOG_DIR,
    logging_steps=100,
    evaluation_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    save_total_limit=1,
)

In [31]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

In [32]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `BartModel.forward` and have been ignored: decompiled, source, labels. If decompiled, source, labels are not expected by `BartModel.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 324
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 205
  Number of trainable parameters = 70402560


ValueError: The model did not return a loss from the inputs, only the following keys: last_hidden_state,past_key_values,encoder_last_hidden_state. For reference, the inputs it received are input_ids,attention_mask.

## Оценка модели

In [None]:
trainer.save_model(OUTPUT_DIR)

In [None]:
result = trainer.evaluate()
print(result)

## Картинки

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
training_logs = pd.read_csv(training_args.logging_dir + "/train_log.txt", sep="\t")
training_logs.head(5)

In [None]:
plt.plot(training_logs["epoch"], training_logs["eval_loss"], label="Validation Loss")
plt.plot(training_logs["epoch"], training_logs["eval_rouge2"], label="ROUGE-2 Score")
plt.xlabel("Epoch")
plt.ylabel("Score")
plt.legend()
plt.show()