In [2]:
import gc
import os

import numpy as np
import torch
from datasets import load_metric, load_dataset, load_from_disk
from knockknock import discord_sender
from transformers import DataCollatorWithPadding, TrainingArguments, Trainer, EarlyStoppingCallback, RobertaConfig
from transformers import RobertaForSequenceClassification, AutoTokenizer

from graphcodebert import GraphCodeBert

# LOAD TRAIN / VALIDATION DATASETS

In [3]:
# Read back both splits that were previously written with `Dataset.save_to_disk`.
train_dataset, valid_dataset = (
    load_from_disk('./data/train_dataset_lv1'),
    load_from_disk('./data/valid_dataset_lv1'),
)

# SET ARGS

In [None]:
MODEL = "microsoft/graphcodebert-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Truncate from the left: when an input exceeds the model's max length, the
# tokens at the END of the sequence are kept and the beginning is dropped.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenizer.truncation_side = 'left'

# NOTE(review): this config (num_labels = 1) is only handed to the GraphCodeBert
# wrapper; the RobertaForSequenceClassification encoder below is loaded with the
# checkpoint's own config. Confirm the wrapper (not the encoder's classifier head)
# determines the output width — `metric_fn` takes an argmax over the last axis.
config = RobertaConfig.from_pretrained(MODEL)
config.num_labels = 1

# Keep the pretrained HF model and the project wrapper under distinct names
# instead of rebinding `model` to a different type.
encoder = RobertaForSequenceClassification.from_pretrained(MODEL)
model = GraphCodeBert(encoder, config=config, tokenizer=tokenizer)
# Trailing `;` suppresses the very long module repr `.to()` would otherwise display.
model.to(device);

In [None]:
# Evaluation metric: the GLUE "sst2" configuration (reports accuracy).
_metric = load_metric("glue", "sst2")

# Pads every batch dynamically to the longest sequence it contains.
_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
def metric_fn(p, metric=None):
    """Compute evaluation metrics for the Hugging Face ``Trainer``.

    Args:
        p: ``(predictions, label_ids)`` pair — the ``EvalPrediction`` the Trainer
            passes to ``compute_metrics``. ``predictions`` may be a tuple when the
            model returns several outputs; its first entry is assumed to hold the
            per-class scores.
        metric: object exposing ``compute(references=..., predictions=...)``.
            Defaults to the notebook-level ``_metric`` so the Trainer can keep
            calling ``metric_fn(p)`` with a single argument.

    Returns:
        Whatever ``metric.compute`` returns (a dict of metric name -> value).
    """
    if metric is None:
        metric = _metric
    preds, labels = p
    # The Trainer returns a tuple of arrays when the model emits extra outputs
    # (e.g. hidden states); np.argmax cannot be applied to that tuple directly.
    if isinstance(preds, tuple):
        preds = preds[0]
    # NOTE(review): if the model emits a single output column (the model cell
    # sets config.num_labels = 1), argmax over it is always 0 — confirm width.
    return metric.compute(references=labels, predictions=np.argmax(preds, axis=-1))

In [None]:
args = TrainingArguments(
    output_dir='./models/',
    # Effective train batch size per device: 4 * 4 gradient-accumulation steps = 16.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    disable_tqdm=False,
    do_train=True,
    do_eval=True,
    logging_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=500,
    # load_best_model_at_end requires save_steps to be a round multiple of
    # eval_steps; pin it explicitly (500 is also the library default).
    save_strategy="steps",
    save_steps=500,
    learning_rate=1e-5,
    optim='adamw_torch',
    # Pick the best checkpoint by validation loss (what the library uses when
    # metric_for_best_model is left unset), stated explicitly.
    metric_for_best_model="loss",
    greater_is_better=False,
    save_total_limit=5,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=_collator,
    # The splits loaded with load_from_disk in the dataset cell.
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
    compute_metrics=metric_fn,
    # Stop once the tracked metric has failed to improve for 10 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

# DO TRAIN

In [None]:
# The Discord webhook URL is a credential: read it from the environment instead
# of committing it to the notebook. NOTE(review): the URL that was previously
# hard-coded in this cell has been exposed and should be revoked/rotated.
webhook_url = os.environ.get("DISCORD_WEBHOOK_URL")


def do_train():
    """Run garbage collection, release cached CUDA memory, then train ``trainer``."""
    gc.collect()
    # Hand memory held by PyTorch's CUDA caching allocator back before training.
    torch.cuda.empty_cache()
    trainer.train()


if webhook_url:
    # Equivalent to decorating with @discord_sender(webhook_url=...), but only
    # applied when a webhook is configured.
    do_train = discord_sender(webhook_url=webhook_url)(do_train)
else:
    print("DISCORD_WEBHOOK_URL is not set; training will run without Discord notifications.")

In [None]:
# Launch training with the trainer and arguments configured above.
do_train()