In [1]:
from Transformer_Models import MashableBertModel
import torch
import torch.nn as nn

import numpy as np
import pandas as pd

In [2]:
import datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [3]:
articles = datasets.load_dataset('online_news_popularity_data')

Found cached dataset online_news_popularity_data (/home/leeparkuky/.cache/huggingface/datasets/online_news_popularity_data/online_news_popularity_data/1.0.0/63eb244b62e86df6ad3ae3034fcbddd6ed2840885e607a97d5e8f49afab926e0)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
def concatenate_fernandes_variables(examples):
    fernandes = [val for key, val in examples.items() if key not in ['title','content','shares','shares_class']]
    fernandes = np.array(fernandes).T.tolist()
    return {'fernandes': fernandes}

In [5]:
articles_concat = articles.map(concatenate_fernandes_variables, batched = True, batch_size = 64, num_proc = 16,
                              remove_columns = [x for x in articles.column_names['train'] if x not in ['title','content','shares','shares_class']] )

Loading cached processed dataset at /home/leeparkuky/.cache/huggingface/datasets/online_news_popularity_data/online_news_popularity_data/1.0.0/63eb244b62e86df6ad3ae3034fcbddd6ed2840885e607a97d5e8f49afab926e0/cache-d040c91994313900_*_of_00016.arrow


In [6]:
def tokenize(examples):
    text = [title + content for title, content in zip(examples['title'], examples['content'])]
    return tokenizer(text, max_length = 512, truncation = True, padding = True)

In [7]:
articles_tokenized = articles_concat.map(tokenize, batched = True, batch_size = 64, num_proc = 16,
                   remove_columns = ['shares','title','content'])

Loading cached processed dataset at /home/leeparkuky/.cache/huggingface/datasets/online_news_popularity_data/online_news_popularity_data/1.0.0/63eb244b62e86df6ad3ae3034fcbddd6ed2840885e607a97d5e8f49afab926e0/cache-2a956d417e53e673_*_of_00016.arrow


In [8]:
articles_tokenized.set_format('pt')

In [9]:
articles_tokenized = articles_tokenized['train'].train_test_split(.2)

In [10]:
articles_tokenized

DatasetDict({
    train: Dataset({
        features: ['shares_class', 'fernandes', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 31686
    })
    test: Dataset({
        features: ['shares_class', 'fernandes', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 7922
    })
})

## Testing Base Model

In [11]:
from Transformer_Models import MashableBertForClassification

model = MashableBertForClassification('bert-base-uncased', 2)
model.base_model_load_weight('Model Weights/MashableBertModel_Pretrained.pth')

In [16]:
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, recall_score, precision_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds)
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds)
    recall = recall_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    
    return {"accuracy": accuracy,
           "precision": precision,
           "recall" : recall,
           "f1" : f1,
           "auc" : auc}

In [17]:
from transformers import TrainingArguments, Trainer
import torch

training_args = TrainingArguments(
    output_dir="finetuning-mashablebert",
    overwrite_output_dir = True,
    evaluation_strategy="epoch",
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    fp16 = True,
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    push_to_hub=False,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim = 'adafactor'
)


trainer = Trainer(
    model=model.to(torch.device('cuda')),
    args=training_args,
    train_dataset=articles_tokenized["train"].shuffle(),
    eval_dataset=articles_tokenized["test"].shuffle(),
#     compute_metrics = compute_metrics
)



In [24]:
trainer.train()

Epoch,Training Loss,Validation Loss
0,0.7039,No log
1,0.6992,No log
2,0.6962,No log
4,0.6934,No log



KeyboardInterrupt

