In [None]:
! pip install datasets transformers[torch] evaluate

In [None]:
from datasets import load_dataset
from transformers import( 
    AutoTokenizer,
    LongformerForSequenceClassification,
    Trainer,
    TrainingArguments
)

import numpy as np
import evaluate


In [None]:
# Identifier handed to load_dataset() to load the data.
dataset_name = "cw1521/ember2018-malware"
# Pretrained checkpoint used for both the model and the tokenizer.
model_checkpoint = "allenai/longformer-base-4096"
# Passed as the first positional argument to TrainingArguments.
model_name = "ma-ember-1"

In [None]:
# Load the dataset identified by `dataset_name`.
dataset = load_dataset(dataset_name)

Display the dataset

In [None]:
# Show the loaded dataset's repr as the cell output.
dataset

In [None]:
# Columns dropped from the dataset before tokenization.
# NOTE(review): this cell rebinds `dataset`, so re-running it on the same
# kernel will try to remove columns that are already gone — confirm that is
# acceptable or re-run from the loading cell.
cols = ["subset", "sha256", "appeared", "x", "label", "avclass"]

dataset = dataset.remove_columns(cols)
dataset

In [None]:
# Inspect the first record of the "train" split.
dataset["train"][0]

In [None]:
# Hold out 20% of the "train" split for validation. The fixed seed makes the
# shuffle, and therefore the split, reproducible across kernel restarts.
dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)

train_ds = dataset["train"]
valid_ds = dataset["test"]

In [None]:

# Load the pretrained checkpoint with a sequence-classification head.
# NOTE(review): num_labels=1 gives the head a single output; transformers
# presumably treats that as regression unless problem_type is set — confirm
# this matches the dtype and range of the "y" labels.
model = LongformerForSequenceClassification.from_pretrained(model_checkpoint, num_labels=1)

In [None]:
# Load the tokenizer from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Set max_input_length, max_output_length, and batch_size

In [None]:
# Tokenized sequence length; inputs are truncated/padded to this value in
# process_data_to_model_inputs.
max_input_length = 4096
# NOTE(review): not referenced anywhere else in the visible notebook — confirm
# whether it is still needed.
max_output_length = 512
# Used both as the Dataset.map batch size and as the per-device train/eval
# batch size in TrainingArguments.
batch_size = 1

In [None]:
def process_data_to_model_inputs(batch):
    """Tokenize a batch and attach attention masks and labels for the model.

    Args:
        batch: Mapping of column name to a list of values, as supplied by
            ``Dataset.map(..., batched=True)``. The ``"input"`` and ``"y"``
            columns are read.

    Returns:
        The same ``batch`` mapping with ``"input_ids"``, ``"attention_mask"``,
        ``"global_attention_mask"`` and ``"labels"`` added.
    """
    inputs = tokenizer(
        batch["input"],
        truncation=True,
        padding="max_length",
        max_length=max_input_length,
    )
    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    # One independent mask row per sample: 1 at position 0, 0 everywhere else.
    # Building each row separately avoids relying on aliased inner lists and
    # also works for an empty batch.
    batch["global_attention_mask"] = [
        [int(position == 0) for position in range(len(token_ids))]
        for token_ids in batch["input_ids"]
    ]
    batch["labels"] = batch["y"]
    return batch

Tokenize the train and validation splits, then convert them to torch format

In [None]:
# Run the preprocessing function over the training split in batches and drop
# the raw "input" and "y" columns from the result.
train = train_ds.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["input", "y"],
)

In [None]:
# Run the preprocessing function over the validation split in batches and drop
# the raw "input" and "y" columns from the result.
valid = valid_ds.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["input", "y"],
)

In [None]:
# Switch both processed splits to the "torch" format, exposing only the
# columns listed here.
for tokenized_split in (train, valid):
    tokenized_split.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
    )

Metrics

In [None]:
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy from the Trainer's evaluation predictions.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` is the
            model output array (or a tuple whose first element is that array).

    Returns:
        The result of ``accuracy.compute`` for the predicted classes.
    """
    predictions, labels = eval_pred
    # Some models return several outputs; the scores are the first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)

    if predictions.ndim == 1 or predictions.shape[-1] == 1:
        # Single-output head (num_labels=1): argmax over one column is always
        # 0, so every sample would be scored as class 0. Threshold the score
        # instead.
        # NOTE(review): 0.5 assumes the output is regressed towards 0/1
        # labels — confirm against the "y" column.
        predicted_classes = (predictions.reshape(-1) >= 0.5).astype(int)
        references = np.rint(np.asarray(labels).reshape(-1)).astype(int)
        return accuracy.compute(predictions=predicted_classes, references=references)

    # Multi-class head: pick the highest-scoring class per sample.
    predicted_classes = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=labels)

Train the Model

In [None]:
# Training configuration; `model_name` is passed positionally as the first
# argument.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and Trainer's `tokenizer` to `processing_class` — confirm
# against the installed version.
training_args = TrainingArguments(
    model_name,
    # Schedule
    num_train_epochs=1,
    evaluation_strategy="epoch",
    # Batching and memory
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    fp16=True,
    # Optimisation
    learning_rate=1e-4,
    weight_decay=0.001,
    # Logging and checkpoints
    logging_dir='./logs',
    save_steps=100,
    save_total_limit=3,
)

# Tie together the model, the processed splits, the tokenizer and the metric
# callback.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=valid,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training with the Trainer configured in the previous cell.
trainer.train()
# Save the trained model (no explicit path is given here).
trainer.save_model()
# Save the Trainer's state alongside it.
trainer.save_state()