In [1]:
!CUDA_LAUNCH_BLOCKING=1

In [2]:
from datasets import load_dataset
data_dir = "/home/mfclinton/Documents/Repos/iclr-discourse-dataset/review_rebuttal_pair_dataset/"
dataset = load_dataset("json", field="review_rebuttal_pairs", data_files={"train": data_dir + "traindev_train.json","validation": data_dir + "traindev_dev.json"})

Using custom data configuration default
Reusing dataset json (/home/mfclinton/.cache/huggingface/datasets/json/default-77beb580bf249ad2/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514)


In [3]:
dataset.shape

{'train': (2148, 9), 'validation': (727, 9)}

In [4]:
dataset.num_columns

{'train': 9, 'validation': 9}

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['index', 'review_sid', 'rebuttal_sid', 'review_text', 'rebuttal_text', 'title', 'review_author', 'forum', 'labels'],
        num_rows: 2148
    })
    validation: Dataset({
        features: ['index', 'review_sid', 'rebuttal_sid', 'review_text', 'rebuttal_text', 'title', 'review_author', 'forum', 'labels'],
        num_rows: 727
    })
})

In [6]:
# dataset["train"]["review_text"][0]
dataset["train"].format

{'type': None,
 'format_kwargs': {},
 'columns': ['index',
  'review_sid',
  'rebuttal_sid',
  'review_text',
  'rebuttal_text',
  'title',
  'review_author',
  'forum',
  'labels'],
 'output_all_columns': False}

In [7]:
meme = {}
def flatten_text(examples):
    data = []
    for block in examples["review_text"]:
        for paragraph in block:
            data += paragraph
    meme[examples["labels"]["rating"]] = None        
    if(not (1 <= examples["labels"]["rating"] <= 10)):
        print(examples["labels"]["rating"])
        1/0
    return {"review_text": data, "label": examples["labels"]["rating"]}

In [8]:
column_names = dataset["train"].column_names
column_names.remove("review_text")

updated_dataset = dataset.map(flatten_text, remove_columns=column_names)
print(meme)

Loading cached processed dataset at /home/mfclinton/.cache/huggingface/datasets/json/default-77beb580bf249ad2/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-98645dcb4474d212.arrow
Loading cached processed dataset at /home/mfclinton/.cache/huggingface/datasets/json/default-77beb580bf249ad2/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-6aaeb1c5aa0eeba3.arrow


{7: None, 6: None}


In [9]:
dataset["train"][0]["labels"]

{'rating': 7, 'confidence': 4}

In [10]:
model_checkpoint = 'bert-base-uncased'
num_labels = 11

In [11]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [12]:
# from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [13]:
def preprocess_function(examples):
    tokens = tokenizer(examples["review_text"], is_split_into_words=True, truncation=True)
    return tokens

In [14]:
# preprocess_function(updated_dataset['train'][:5])
encoded_dataset = updated_dataset.map(preprocess_function)

Loading cached processed dataset at /home/mfclinton/.cache/huggingface/datasets/json/default-77beb580bf249ad2/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-24e02996c516659f.arrow
Loading cached processed dataset at /home/mfclinton/.cache/huggingface/datasets/json/default-77beb580bf249ad2/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-45f7b225c6f3ed9f.arrow


In [15]:
len(encoded_dataset["train"]["review_text"][3])

442

In [16]:
encoded_dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'review_text', 'token_type_ids'],
        num_rows: 2148
    })
    validation: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'review_text', 'token_type_ids'],
        num_rows: 727
    })
})

In [17]:
from transformers import TrainingArguments, Trainer

batch_size = 3
args = TrainingArguments(
    "test-dir",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    metric_for_best_model="accuracy"
)

In [18]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [19]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss


In [None]:
trainer.evaluate()