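# Argument Quality Scorer

Fine-tune argument quality scorers on IBM-ArgQ-Rank-30k. Train and development topics follow the ArgKP-2021 split. The scorers are then compared against the Debater API on held-out test topics.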

## Import packages

In [1]:
import pandas as pd
import torch
from argsum.tools import train_quality_scorer, eval_quality_scorer

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.
loading configuration file config.json from cache at /Users/moritz/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/12040accade4e8a0f71eabdb258fecc2e7e948be/config.json
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_attentions": true,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.33.3",
  "voc

## Load datasets

In [2]:
# ArgKP-2021 defines the topic split: collect the unique topics of each of its sets.
ArgKP21 = pd.read_csv('data/ArgKP-2021/dataset_splits_scores.csv')
ArgKP21_train_topics, ArgKP21_dev_topics, ArgKP21_test_topics = (
    ArgKP21.loc[ArgKP21['set'] == split_name, 'topic'].unique()
    for split_name in ('train', 'dev', 'test')
)

# IBM-ArgQ-Rank-30k: train/dev keep every argument whose topic belongs to the
# matching ArgKP-2021 split (ArgQ's own 'set' column is not consulted here).
ArgQ = pd.read_csv('data/IBM-ArgQ-Rank-30kArgs/dataset_scores.csv')
is_train_topic = ArgQ['topic'].isin(ArgKP21_train_topics)
is_dev_topic = ArgQ['topic'].isin(ArgKP21_dev_topics)
ArgQ_train = ArgQ[is_train_topic]
ArgQ_dev = ArgQ[is_dev_topic]

# Test: ArgQ's own test rows, minus any topic already used for training or development.
test_set_mask = (ArgQ['set'] == 'test') & ~(is_train_topic | is_dev_topic)
ArgQ_test = ArgQ[test_set_mask]

## Train, development and test topics

In [6]:
# List the training topics (ArgKP-2021 train topics found in ArgQ).
print('Train Topics:\n')

train_topic_names = ArgQ_train['topic'].unique()
for topic_name in train_topic_names:
    print(topic_name)

Train Topics:

We should abandon marriage
Assisted suicide should be a criminal offence
We should prohibit women in combat
We should abolish capital punishment
We should legalize sex selection
We should ban human cloning
We should fight urbanization
We should introduce compulsory voting
We should fight for the abolition of nuclear weapons
We should adopt libertarianism
We should prohibit flag burning
We should legalize cannabis
We should adopt atheism
We should ban private military companies
Homeschooling should be banned
We should legalize prostitution
We should end mandatory retirement
We should abolish intellectual property rights
The vow of celibacy should be abandoned
We should subsidize vocational education
We should close Guantanamo Bay detention camp
We should ban the use of child actors
We should subsidize journalism
We should subsidize space exploration


In [7]:
# List the development topics (ArgKP-2021 dev topics found in ArgQ).
# The heading previously said 'Train Topics:' (copy-paste from the cell above).
print('Dev Topics:\n')

for topic in ArgQ_dev['topic'].unique():
    print(topic)

Train Topics:

We should abolish the right to keep and bear arms
We should abandon the use of school uniform
We should adopt an austerity regime
We should end affirmative action


In [8]:
# List the test topics: ArgQ test-set topics not used for training or development.
# The heading previously said 'Train Topics:' (copy-paste from the train cell).
print('Test Topics:\n')

for topic in ArgQ_test['topic'].unique():
    print(topic)

Train Topics:

Holocaust denial should be a criminal offence
The use of public defenders should be mandatory
We should ban factory farming
We should abolish the three-strikes laws
We should prohibit school prayer
We should ban targeted killing
Social media brings more harm than good
We should ban algorithmic trading
We should ban missionary work
We should abolish the Olympic Games


In [10]:
# The three splits must be topic-disjoint. Check every PAIR of splits: the
# triple intersection train & dev & test is empty even when, e.g., train and
# dev share a topic, so it cannot detect leakage between two splits.
train_topics = set(ArgQ_train['topic'].unique())
dev_topics = set(ArgQ_dev['topic'].unique())
test_topics = set(ArgQ_test['topic'].unique())

assert not (train_topics & dev_topics), f'train/dev topic overlap: {train_topics & dev_topics}'
assert not (train_topics & test_topics), f'train/test topic overlap: {train_topics & test_topics}'
assert not (dev_topics & test_topics), f'dev/test topic overlap: {dev_topics & test_topics}'

# Display all pairwise overlaps; expected to be set().
(train_topics & dev_topics) | (train_topics & test_topics) | (dev_topics & test_topics)

set()

## Train a model for each setting (pooling on/off × MACE-P/WA target)

In [5]:
# Fine-tune one scorer per setting: pooling on, then off, each crossed with the
# MACE-P and WA targets, in that order.
# NOTE(review): no random seed is set in this notebook; confirm that
# train_quality_scorer seeds its RNGs, otherwise reruns will not reproduce.
for use_pooling in (True, False):
    for target in ('MACE-P', 'WA'):
        train_quality_scorer(ArgQ_train, ArgQ_dev,
                             target_name=target,
                             pooling=use_pooling)

Map:   0%|          | 0/10324 [00:00<?, ? examples/s]

Map:   0%|          | 0/1775 [00:00<?, ? examples/s]

{'loss': 0.1541, 'learning_rate': 1.7523219814241487e-05, 'epoch': 0.62}
{'eval_loss': 0.1050514429807663, 'eval_rmse': 0.34775906801223755, 'eval_p_corr': 0.47317079134091194, 'eval_s_corr': 0.44035769803739805, 'eval_runtime': 6.8048, 'eval_samples_per_second': 260.847, 'eval_steps_per_second': 8.23, 'epoch': 0.62}
{'loss': 0.0996, 'learning_rate': 1.5046439628482974e-05, 'epoch': 1.24}
{'eval_loss': 0.09904874861240387, 'eval_rmse': 0.3462628424167633, 'eval_p_corr': 0.49913317661069634, 'eval_s_corr': 0.4594649161342128, 'eval_runtime': 6.4327, 'eval_samples_per_second': 275.934, 'eval_steps_per_second': 8.706, 'epoch': 1.24}
{'loss': 0.089, 'learning_rate': 1.256965944272446e-05, 'epoch': 1.86}
{'eval_loss': 0.10298730432987213, 'eval_rmse': 0.3452208936214447, 'eval_p_corr': 0.4911572815096422, 'eval_s_corr': 0.4573436804802352, 'eval_runtime': 8.4149, 'eval_samples_per_second': 210.936, 'eval_steps_per_second': 6.655, 'epoch': 1.86}
{'loss': 0.0702, 'learning_rate': 1.0092879256

Map:   0%|          | 0/10324 [00:00<?, ? examples/s]

Map:   0%|          | 0/1775 [00:00<?, ? examples/s]

{'loss': 0.0847, 'learning_rate': 1.7523219814241487e-05, 'epoch': 0.62}
{'eval_loss': 0.028995277360081673, 'eval_rmse': 0.21972905099391937, 'eval_p_corr': 0.49464174560467067, 'eval_s_corr': 0.42978126113975124, 'eval_runtime': 8.0713, 'eval_samples_per_second': 219.914, 'eval_steps_per_second': 6.938, 'epoch': 0.62}
{'loss': 0.0291, 'learning_rate': 1.5046439628482974e-05, 'epoch': 1.24}
{'eval_loss': 0.02699718251824379, 'eval_rmse': 0.21290089190006256, 'eval_p_corr': 0.5137796709638508, 'eval_s_corr': 0.44535175357629303, 'eval_runtime': 8.2756, 'eval_samples_per_second': 214.486, 'eval_steps_per_second': 6.767, 'epoch': 1.24}
{'loss': 0.0268, 'learning_rate': 1.256965944272446e-05, 'epoch': 1.86}
{'eval_loss': 0.026930633932352066, 'eval_rmse': 0.2128283679485321, 'eval_p_corr': 0.5117600348155743, 'eval_s_corr': 0.44420088938984204, 'eval_runtime': 7.5506, 'eval_samples_per_second': 235.081, 'eval_steps_per_second': 7.417, 'epoch': 1.86}
{'loss': 0.0216, 'learning_rate': 1.009

Map:   0%|          | 0/10324 [00:00<?, ? examples/s]

Map:   0%|          | 0/1775 [00:00<?, ? examples/s]

{'loss': 0.1135, 'learning_rate': 1.7523219814241487e-05, 'epoch': 0.62}
{'eval_loss': 0.10458976030349731, 'eval_rmse': 0.32340338826179504, 'eval_p_corr': 0.46305680101238283, 'eval_s_corr': 0.42896611670875895, 'eval_runtime': 8.1778, 'eval_samples_per_second': 217.05, 'eval_steps_per_second': 6.848, 'epoch': 0.62}
{'loss': 0.0941, 'learning_rate': 1.5046439628482974e-05, 'epoch': 1.24}
{'eval_loss': 0.10314378142356873, 'eval_rmse': 0.32116004824638367, 'eval_p_corr': 0.49076064534666997, 'eval_s_corr': 0.44830432480488647, 'eval_runtime': 7.4569, 'eval_samples_per_second': 238.035, 'eval_steps_per_second': 7.51, 'epoch': 1.24}
{'loss': 0.0808, 'learning_rate': 1.256965944272446e-05, 'epoch': 1.86}
{'eval_loss': 0.10578076541423798, 'eval_rmse': 0.32523953914642334, 'eval_p_corr': 0.487155324613833, 'eval_s_corr': 0.44520598982936155, 'eval_runtime': 7.9739, 'eval_samples_per_second': 222.6, 'eval_steps_per_second': 7.023, 'epoch': 1.86}
{'loss': 0.0637, 'learning_rate': 1.00928792

Map:   0%|          | 0/10324 [00:00<?, ? examples/s]

Map:   0%|          | 0/1775 [00:00<?, ? examples/s]

{'loss': 0.0319, 'learning_rate': 1.7523219814241487e-05, 'epoch': 0.62}
{'eval_loss': 0.027680937200784683, 'eval_rmse': 0.16637589037418365, 'eval_p_corr': 0.48942309773966325, 'eval_s_corr': 0.4316377085346035, 'eval_runtime': 8.1486, 'eval_samples_per_second': 217.828, 'eval_steps_per_second': 6.872, 'epoch': 0.62}
{'loss': 0.0252, 'learning_rate': 1.5046439628482974e-05, 'epoch': 1.24}
{'eval_loss': 0.02707560360431671, 'eval_rmse': 0.1645466536283493, 'eval_p_corr': 0.5041953715495292, 'eval_s_corr': 0.42840610164722454, 'eval_runtime': 7.5742, 'eval_samples_per_second': 234.347, 'eval_steps_per_second': 7.393, 'epoch': 1.24}
{'loss': 0.0229, 'learning_rate': 1.256965944272446e-05, 'epoch': 1.86}
{'eval_loss': 0.027422746643424034, 'eval_rmse': 0.16559815406799316, 'eval_p_corr': 0.5000947781225905, 'eval_s_corr': 0.42685444864246236, 'eval_runtime': 7.967, 'eval_samples_per_second': 222.795, 'eval_steps_per_second': 7.029, 'epoch': 1.86}
{'loss': 0.0181, 'learning_rate': 1.00928

## Compare model performance on the test topics

In [3]:
# Model directories; the names presumably encode the training setting
# (p/np = pooling on/off, mace-p/wa = target) — TODO confirm against the training run.
model_dirs = ['models/quality_scorer/bert_ft_topic_np_mace-p/2024-May-29_15-58-03',
              'models/quality_scorer/bert_ft_topic_np_wa/2024-May-29_16-10-44',
              'models/quality_scorer/bert_ft_topic_p_mace-p/2024-May-29_15-30-54',
              'models/quality_scorer/bert_ft_topic_p_wa/2024-May-29_15-43-23']

# Evaluate each model on the held-out test topics; reports go to its own directory.
# The loop variable is `model_dir`, not `dir`: a module-level `dir` would
# shadow the builtin dir() for the rest of the kernel session.
for model_dir in model_dirs:
    # The loaded object is passed straight to eval_quality_scorer as a model,
    # so it is presumably a fully pickled model rather than a state_dict.
    # Unpickling it requires weights_only=False; torch >= 2.6 defaults to True
    # and would refuse to load it.
    # SECURITY: unpickling runs arbitrary code — only load checkpoints you trust.
    model = torch.load(model_dir + '/best_model.pt', weights_only=False)
    eval_quality_scorer(model, ArgQ_test, output_dir=model_dir)

                                                                     4.54it/s]

In [4]:
# Score the same test set with the Debater API baseline. Its eval_report.csv
# lands in this directory and is read back by the comparison cell.
debater_api_dir = 'models/quality_scorer/debater_api'
eval_quality_scorer('debater_api', ArgQ_test, output_dir=debater_api_dir)

ArgumentQualityClient: 100%|██████████| 4216/4216 [00:48<00:00, 86.23it/s] 


In [11]:
model_dirs = ['models/quality_scorer/bert_ft_topic_np_mace-p/2024-May-29_15-58-03',
              'models/quality_scorer/bert_ft_topic_np_wa/2024-May-29_16-10-44',
              'models/quality_scorer/bert_ft_topic_p_mace-p/2024-May-29_15-30-54',
              'models/quality_scorer/bert_ft_topic_p_wa/2024-May-29_15-43-23']

# Map each row label to its report directory in ONE place, so a label can never
# drift out of sync with the report it labels (previously two parallel lists).
report_dirs = {
    'np_mace-p': model_dirs[0],
    'np_wa': model_dirs[1],
    'p_mace-p': model_dirs[2],
    'p_wa': model_dirs[3],
    'debater_api': 'models/quality_scorer/debater_api',
}

# Load evaluation reports
reports = {name: pd.read_csv(path + '/eval_report.csv') for name, path in report_dirs.items()}

# Relabelling below assumes one summary row per report; fail loudly instead of
# silently mislabelling rows if a report ever has a different shape.
for name, report in reports.items():
    assert len(report) == 1, f'{name}: expected 1 row in eval_report.csv, got {len(report)}'

# Keep the per-model frames and label list available under their original names.
np_macep, np_wa, p_macep, p_wa, debater_api = reports.values()
model_names = list(reports)

# Show the concatenated evaluation reports, one row per model
results = pd.concat(list(reports.values()))
results.index = model_names
results.round(3)

Unnamed: 0,macep_rmse_eval,macep_p_corr_eval,macep_s_corr_eval,wa_rmse_eval,wa_p_corr_eval,wa_s_corr_eval,runtime
np_mace-p,0.33,0.494,0.467,0.242,0.479,0.427,31.761
np_wa,0.4,0.471,0.442,0.169,0.474,0.409,30.719
p_mace-p,0.355,0.482,0.457,0.23,0.469,0.418,30.874
p_wa,0.374,0.459,0.434,0.212,0.455,0.396,31.284
debater_api,0.427,0.502,0.492,0.176,0.518,0.461,48.896
