In [1]:
import numpy as np
from transformers import DataCollatorWithPadding, TrainingArguments, Trainer, EarlyStoppingCallback
from datasets import load_metric, load_dataset, load_from_disk
import torch
from transformers import RobertaForSequenceClassification, AutoTokenizer

In [2]:
# Load the lv3 train/validation splits that were saved to disk earlier.
# NOTE(review): they are fed to the Trainer as-is, so they are presumably
# already tokenized — confirm against the preprocessing notebook.
train_dataset, valid_dataset = (
    load_from_disk('./data/train_dataset_lv3'),
    load_from_disk('./data/valid_dataset_lv3'),
)

In [3]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer comes from the base GraphCodeBERT hub id ...
MODEL = "microsoft/graphcodebert-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Over-long inputs are truncated from the left, i.e. the tail of the
# sequence is what gets kept.
tokenizer.truncation_side = 'left'

# ... while the classifier weights are restored from a local checkpoint
# (presumably the result of an earlier "lv1" fine-tuning stage — confirm).
model = RobertaForSequenceClassification.from_pretrained(
    "./models/lv1/checkpoint-8000"
).to(device)

In [4]:
# Metric object used by the evaluation hook (GLUE / SST-2 configuration).
_metric = load_metric("glue", "sst2")

# Pads every batch dynamically using the tokenizer's padding settings.
_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [5]:
def metric_fn(p):
    """Score one evaluation pass with the module-level ``_metric``.

    Args:
        p: A two-item iterable ``(predictions, labels)``. Presumably the
            ``EvalPrediction`` handed over by the Trainer — confirm.

    Returns:
        Whatever ``_metric.compute`` returns for the arg-maxed predictions
        against the reference labels.
    """
    scores, gold = p
    # Collapse the last axis to a single predicted class id per example.
    predicted = np.argmax(scores, axis=-1)
    return _metric.compute(predictions=predicted, references=gold)

In [6]:
args = TrainingArguments(
    # --- where checkpoints go / how many are kept -------------------------
    output_dir='./models/lv1-lv3/',
    save_strategy="steps",
    save_total_limit=5,
    # --- what to run -------------------------------------------------------
    do_train=True,
    do_eval=True,
    num_train_epochs=3,
    # --- batching: 4 examples per device, gradients accumulated over 4 steps
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=32,
    # --- optimisation ------------------------------------------------------
    optim='adamw_torch',
    learning_rate=1e-5,
    # --- evaluation / logging cadence -------------------------------------
    evaluation_strategy="steps",
    eval_steps=500,
    logging_strategy="steps",
    disable_tqdm=False,
    # --- best-model selection ----------------------------------------------
    # metric_for_best_model is intentionally left unset (an "f1" setting was
    # tried and disabled), so the Trainer's default selection criterion applies.
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=_collator,
    compute_metrics=metric_fn,
    # Stop once 10 consecutive evaluations bring no improvement.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

In [7]:
import gc
import os

from knockknock import discord_sender

# SECURITY: the webhook URL is a credential — anyone holding it can post to
# the channel — so it must not live in the notebook. It is read from the
# environment instead; export DISCORD_WEBHOOK_URL before starting the kernel.
# NOTE(review): a URL that was ever committed should be treated as leaked and
# regenerated in Discord.
webhook_url = os.environ.get("DISCORD_WEBHOOK_URL")


def do_train():
    """Free cached memory, then run ``trainer.train()``.

    Garbage is collected and the CUDA caching allocator is emptied first so
    the run starts with as much free GPU memory as possible. Returns ``None``.
    """
    gc.collect()
    torch.cuda.empty_cache()
    trainer.train()


if webhook_url:
    # Wrap the function so knockknock posts to the Discord webhook around the call.
    do_train = discord_sender(webhook_url=webhook_url)(do_train)
else:
    # Training must still be runnable on machines without the secret.
    print("DISCORD_WEBHOOK_URL is not set; training will run without Discord notifications.")

In [8]:
# Start fine-tuning; this call blocks until do_train() (and the
# trainer.train() run inside it) finishes or is interrupted.
do_train()

***** Running training *****
  Num examples = 90000
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 4
  Total optimization steps = 8436


Step,Training Loss,Validation Loss,Accuracy
500,0.4774,2.855442,0.502778
1000,0.0974,4.111362,0.519111


***** Running Evaluation *****
  Num examples = 9000
  Batch size = 64
Saving model checkpoint to ./models/lv1-lv3/checkpoint-500
Configuration saved in ./models/lv1-lv3/checkpoint-500/config.json
Model weights saved in ./models/lv1-lv3/checkpoint-500/pytorch_model.bin
tokenizer config file saved in ./models/lv1-lv3/checkpoint-500/tokenizer_config.json
Special tokens file saved in ./models/lv1-lv3/checkpoint-500/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 9000
  Batch size = 64
Saving model checkpoint to ./models/lv1-lv3/checkpoint-1000
Configuration saved in ./models/lv1-lv3/checkpoint-1000/config.json
Model weights saved in ./models/lv1-lv3/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in ./models/lv1-lv3/checkpoint-1000/tokenizer_config.json
Special tokens file saved in ./models/lv1-lv3/checkpoint-1000/special_tokens_map.json


KeyboardInterrupt: 

In [None]:
# Upper bound on the tokenized length of a (code1, code2) pair.
MAX_LEN = 512


def example_fn(examples):
    """Tokenize a code pair and carry its label along when one is present.

    Args:
        examples: Mapping with ``'code1'`` and ``'code2'`` entries and,
            optionally, a ``'similar'`` entry holding the label.

    Returns:
        The tokenizer output for the pair, truncated to ``MAX_LEN``; a
        ``"labels"`` entry is added only if ``'similar'`` was supplied.
    """
    encoded = tokenizer(
        examples['code1'],
        examples['code2'],
        truncation=True,
        max_length=MAX_LEN,
        padding=True,
    )
    # Unlabelled data (e.g. the test set) has nothing to copy over.
    if 'similar' not in examples:
        return encoded
    encoded["labels"] = examples["similar"]
    return encoded

In [None]:
import gc
import os

import pandas as pd

TEST = "./data/test.csv"
SUB = "./data/sample_submission.csv"
SUBMISSION_OUT = "./submissions/submission_lv3.csv"

# Create the output directory *before* the expensive prediction pass:
# DataFrame.to_csv does not create missing parent directories, and failing
# at the very end would throw the predictions away.
os.makedirs(os.path.dirname(SUBMISSION_OUT), exist_ok=True)

# Release whatever memory training left behind before running inference.
gc.collect()
torch.cuda.empty_cache()

# load_dataset("csv", ...) exposes a single file under the "train" split key.
test_dataset = load_dataset("csv", data_files=TEST)["train"]
# Tokenize each pair and drop the raw source columns.
test_dataset = test_dataset.map(example_fn, remove_columns=["code1", "code2"])

# NOTE(review): predict() uses the weights the trainer currently holds. If
# training was interrupted (as in the recorded run), the best checkpoint has
# presumably not been reloaded — confirm which weights are intended here.
predictions = trainer.predict(test_dataset)

# Fill the sample submission's "similar" column with the arg-max class ids.
df = pd.read_csv(SUB)
df["similar"] = np.argmax(predictions.predictions, axis=-1)
df.to_csv(SUBMISSION_OUT, index=False)