# Task 1 Normalization Analysis

In [None]:
import re
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from utils import load_data

# Global figure styling for every plot in this notebook: STIX font family and
# a large base font size.
matplotlib.rcParams.update({
    'font.family': 'STIXGeneral',
    'font.size': 28,
})

In [None]:
# Dataset location handed to utils.load_data.
cfg = dict(data_dir="../data/PubMed_200k_RCT")

# Every split is loaded with the same config and cached=True; only the
# split_type differs between the three calls.
load_kwargs = dict(cfg=cfg, cached=True)
train_df = load_data(split_type='train', **load_kwargs)
dev_df = load_data(split_type='dev', **load_kwargs)
test_df = load_data(split_type='test', **load_kwargs)

In [None]:
# A standalone run of digits (word boundaries on both sides).
matcher_num = re.compile(r"\b[0-9]+\b")


def count_nums_binary(s):
    """Return 1 if ``s`` contains a standalone number anywhere, else 0.

    Args:
        s: Text to scan.

    Returns:
        1 when at least one digit run bounded by word boundaries occurs in
        ``s``, otherwise 0.
    """
    # search(), not match(): match() only tries position 0, so a number in
    # the middle of the text would be missed.
    return 1 if matcher_num.search(s) is not None else 0

# Lowercase prefix followed by digits, e.g. "nct01234567" (presumably
# clinical-trial registration numbers; the patterns assume lowercased text).
matcher_nct = re.compile(r"nct[0-9]+")
matcher_isrctn = re.compile(r"isrctn[0-9]+")
matcher_ntr = re.compile(r"ntr[0-9]+")


def count_paper_ids_binary(s):
    """Return 1 if ``s`` contains an ``nct``/``isrctn``/``ntr`` + digits ID, else 0.

    Args:
        s: Text to scan.

    Returns:
        1 when any of the three ID patterns occurs anywhere in ``s``,
        otherwise 0.
    """
    matchers = (matcher_nct, matcher_isrctn, matcher_ntr)
    # search(), not match(): match() only tries position 0, so an ID in the
    # middle of the text would be missed. The generator stops at the first hit.
    return int(any(m.search(s) is not None for m in matchers))

# Digits immediately followed (no space) by one of the listed unit suffixes.
matcher_units = re.compile(r"[0-9]+(mg|ml|mw|ms|mmhg|min|months|years)")


def count_units_binary(s):
    """Return 1 if ``s`` contains a number glued to a unit suffix, else 0.

    Args:
        s: Text to scan.

    Returns:
        1 when a token such as ``5mg`` or ``120mmhg`` occurs anywhere in
        ``s``, otherwise 0.
    """
    # search(), not match(): match() only tries position 0, so a unit in the
    # middle of the text would be missed.
    return 1 if matcher_units.search(s) is not None else 0

# Letters followed by at least one digit, optionally with leading digits and
# trailing letters/digits, e.g. "il6" or "hba1c".
matcher_technical_abbr = re.compile(r"[0-9]*[a-z]+[0-9]+[a-z]*[0-9]*")


def count_technical_abbr_binary(s):
    """Return 1 if ``s`` contains a mixed letter/digit token, else 0.

    Args:
        s: Text to scan.

    Returns:
        1 when a letters-then-digits token occurs anywhere in ``s``,
        otherwise 0.
    """
    # search(), not match(): match() only tries position 0, so a token in
    # the middle of the text would be missed.
    return 1 if matcher_technical_abbr.search(s) is not None else 0

# Per-row indicator columns, one per token group. Despite the "_count" suffix
# each value is 0 or 1, as returned by the count_*_binary functions.
feature_columns = {
    "num_count": count_nums_binary,
    "paper_id_count": count_paper_ids_binary,
    "units_count": count_units_binary,
    "technical_abbr_count": count_technical_abbr_binary,
}
for column_name, detector in feature_columns.items():
    train_df[column_name] = train_df["cleaned_text"].apply(detector)

In [None]:
def by_class_dist(df, column):
    """Return how the total of ``column`` is shared across labels.

    Sums ``column`` within each ``label`` group and divides every per-label
    sum by the overall total. For a 0/1 indicator column this is the class
    distribution over the rows where the indicator is 1.

    Args:
        df: DataFrame with a ``label`` column and a numeric ``column``.
        column: Name of the numeric column to aggregate.

    Returns:
        Dict mapping each label present in ``df`` to its share of the column
        total. If the total is 0 (no row matched), every label maps to 0.0
        instead of raising ``ZeroDivisionError``.
    """
    per_label = df.groupby("label")[column].sum().to_dict()
    total = sum(per_label.values())
    if total == 0:
        # Nothing matched: report empty shares rather than dividing by zero.
        return {label: 0.0 for label in per_label}
    return {label: count / total for label, count in per_label.items()}

In [None]:
# Stacked horizontal bar chart: one bar per token group, each split into the
# share belonging to each class. The "Baseline" bar is the class distribution
# over all rows of train_df.
fig, ax = plt.subplots(figsize=(10, 6))
fig.set_facecolor("white")

# Class order follows first appearance in train_df; the legend shows only the
# first character of each class name.
classes_caps = train_df["label"].unique().tolist()
classes = [s[0] for s in classes_caps]
# Reversed because barh draws the first category at the bottom, which puts
# "Baseline" on top.
dist_labels = ["Baseline", "Numbers", "Paper IDs", "Units", "Tech. Abbr."][::-1]
bar_width = 0.6

class_dist = train_df["label"].value_counts(normalize=True).to_dict()
num_class_dist = by_class_dist(train_df, "num_count")
paper_id_class_dist = by_class_dist(train_df, "paper_id_count")
units_class_dist = by_class_dist(train_df, "units_count")
technical_abbr_class_dist = by_class_dist(train_df, "technical_abbr_count")

# Same (reversed) order as dist_labels, so row i of the chart uses row_dists[i].
row_dists = [
    class_dist,
    num_class_dist,
    paper_id_class_dist,
    units_class_dist,
    technical_abbr_class_dist,
][::-1]

# Each class's segment starts where the previous class's segment ended.
left = np.zeros(len(dist_labels))
for class_name, short_name in zip(classes_caps, classes):
    widths = np.array([dist[class_name] for dist in row_dists])
    ax.barh(dist_labels, widths, bar_width, left=left, label=short_name)
    left = left + widths

ax.set_xlabel('Relative Document Frequency')
ax.set_xticks(np.linspace(0.0, 1.0, 6))
# Rotate the existing categorical labels in place. set_yticklabels would need
# fixed ticks and otherwise emits a matplotlib warning.
ax.tick_params(axis="y", labelrotation=45)
ax.set_title('Token Group Frequencies')
ax.legend(loc="lower right", ncol=len(classes), bbox_to_anchor=(-0.15, -0.45, 1.2, 0.2), mode="expand")

plt.show()