In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import glob
import re
import seaborn as sns
from tqdm.notebook import tqdm
from collections import Counter
from functools import reduce

import sys
sys.path.append("..")
from data_preparation.data_preparation_pos import read_conll
import utils.utils as utils
import utils.pos_utils as pos_utils

In [2]:
def multi_merge(left, right, r):
    """Inner-join two tag tables on ``"Tag"``, keeping only columns matched by ``r``.

    ``r`` is a compiled regular expression that is matched against the column
    names of both frames. Overlapping column names coming from ``right`` get a
    numeric suffix equal to the number of matching columns already in ``left``,
    which keeps names unique when the function is folded over many frames.
    """
    kept_left = ["Tag"] + [name for name in left.columns if r.match(name)]
    kept_right = ["Tag"] + [name for name in right.columns if r.match(name)]
    right_suffix = "_{}".format(len(kept_left) - 1)
    return left[kept_left].merge(right[kept_right], on="Tag", suffixes=(None, right_suffix))

In [3]:
def calculate_total_tag_table(dfs):
    """Combine several per-dataset tag tables into one table of summed counts.

    The ``Count`` columns of all tables in ``dfs`` are merged on ``"Tag"`` and
    added up. The result is sorted by descending count and carries the columns
    ``Tag``, ``Count``, ``Count(%)`` and ``Cumulative(%)``.
    """
    count_pattern = re.compile("^Count($|_)")

    def merge_counts(acc, nxt):
        return multi_merge(acc, nxt, count_pattern)

    merged = reduce(merge_counts, dfs)
    # Overwrite the first count column with the row-wise sum of all count columns
    merged["Count"] = merged.iloc[:, 1:].sum(axis=1)
    ranked = merged.sort_values("Count", ascending=False).reset_index(drop=True)

    total = ranked.loc[:, ["Tag", "Count"]]
    share = total["Count"] / total["Count"].sum() * 100
    total["Count(%)"] = share
    total["Cumulative(%)"] = share.cumsum()

    return total

In [4]:
def tag_freq(info, output, lang_to_group):
    """Build the tag frequency table for one dataset file and store it in ``output``.

    Parameters
    ----------
    info : dict
        Must provide the keys ``"lang_name"``, ``"file_path"`` and ``"dataset"``.
    output : dict
        Nested mapping ``group -> language -> dataset -> DataFrame``. It is
        updated in place; every group key must already exist.
    lang_to_group : dict
        Maps a language name to its group key in ``output``.

    Returns
    -------
    dict
        The same ``output`` object. Once a language has exactly three dataset
        tables, a fourth entry ``"total"`` with the summed counts is added.
    """
    lang_name = info["lang_name"]
    file_path = info["file_path"]
    dataset = info["dataset"]
    group = lang_to_group[lang_name]

    conll_data = read_conll(file_path)
    # Assumes conll_data[2] holds one sequence of tags per sentence (TODO confirm
    # against read_conll). Flattening in plain Python avoids building a ragged
    # NumPy array and the quadratic list concatenation that summing lists implies.
    tags = [tag for sentence_tags in conll_data[2] for tag in sentence_tags]
    df = pd.DataFrame(list(Counter(tags).items()), columns=["Tag", "Count"])
    df = df.sort_values("Count", ascending=False).reset_index(drop=True)
    df["Count(%)"] = df["Count"] / df["Count"].sum() * 100
    df["Cumulative(%)"] = df["Count(%)"].cumsum()

    # Add missing tags: only tags of the tagset that never occur in this file.
    # A plain set difference is required here; a symmetric difference would also
    # pick up tags that occur in the data but are not part of the tagset and
    # append a duplicate zero-count row for them.
    tagset = pos_utils.get_ud_tags()
    missing_tags = sorted(set(tagset) - set(df["Tag"]))  # sorted for a reproducible row order
    if missing_tags:
        missing_data = {col: [100]*len(missing_tags) if "Cumulative" in col else [0]*len(missing_tags)
                        for col in df.columns[1:]}
        missing_rows = pd.DataFrame({"Tag": missing_tags, **missing_data},
                                    index=range(df.shape[0], df.shape[0] + len(missing_tags)))
        df = pd.concat([df, missing_rows])

    output[group].setdefault(lang_name, {})[dataset] = df

    # Calculate totals if all datasets are done
    if len(output[group][lang_name]) == 3:
        total = calculate_total_tag_table(output[group][lang_name].values())
        output[group][lang_name]["total"] = total

    return output

### Tag stats table

In [5]:
# Run tag_freq over the UD data directory; the result is kept in tag_tables,
# starting from one empty bucket per language group.
tag_tables = utils.run_through_data(
    "../data/ud/",
    tag_freq,
    lang_to_group=utils.make_lang_group_dict(),
    table={"Fusional": {}, "Isolating": {}, "Agglutinative": {}, "Introflexive": {}},
)

HBox(children=(FloatProgress(value=0.0, max=43.0), HTML(value='')))




In [6]:
def calculate_group_avg_tables(group_tables):
    """Average the per-language tag shares of one language group.

    Parameters
    ----------
    group_tables : dict
        Mapping ``language -> dataset -> DataFrame``. Each table needs a
        ``"Tag"`` column and a ``"Count(%)"`` column.

    Returns
    -------
    dict
        One table per dataset name (``train``, ``dev``, ``test``, ``total``)
        with the columns ``Tag``, ``MeanCount(%)`` and ``Cumulative(%)``,
        sorted by descending mean share. Languages lacking a dataset are left
        out of that dataset's average; if no language has it, ``reduce`` raises
        ``TypeError``.
    """
    freq_pattern = re.compile(r"^Count\(%\)")
    group_avgs = {}

    for dataset in ["train", "dev", "test", "total"]:
        lang_tables = [tables[dataset] for tables in group_tables.values() if dataset in tables]
        merged = reduce(lambda left, right: multi_merge(left, right, freq_pattern), lang_tables)
        # With a single language, reduce() hands back that table untouched, so the
        # count and cumulative columns are still present. Select the share columns
        # explicitly so that they never leak into the mean.
        freq_cols = [col for col in merged.columns if freq_pattern.match(col)]
        table = pd.DataFrame({"Tag": merged["Tag"],
                              "MeanCount(%)": merged[freq_cols].mean(axis=1)})
        table = table.sort_values("MeanCount(%)", ascending=False).reset_index(drop=True)
        table["Cumulative(%)"] = table["MeanCount(%)"].cumsum()
        group_avgs[dataset] = table

    return group_avgs

### Export to Excel

In [7]:
# Hex colour assigned to each language group.
group_to_color = dict(
    Fusional="#95c78f",
    Isolating="#f79d97",
    Agglutinative="#abaff5",
    Introflexive="#fffecc",
)

In [16]:
def _write_dataset_tables(writer, sheet_name, tables, datasets, pct_offset, merge_format, percentage_format):
    """Write ``tables[dataset]`` for every dataset side by side on one sheet.

    Each table starts on the second row under a merged header cell carrying the
    upper-cased dataset name, and is followed by one empty spacer column.
    ``pct_offset`` is the position, inside each table, of the first of the two
    columns that receive ``percentage_format``.

    Returns the worksheet (``None`` if ``datasets`` is empty) and the largest
    row count among the written tables.
    """
    worksheet = None
    max_rows = 0
    for i, dataset in enumerate(datasets):
        table = tables[dataset]
        dcol = table.shape[1] + 1  # table width plus the spacer column
        first_col = i * dcol
        table.to_excel(writer, index=False, sheet_name=sheet_name, startcol=first_col, startrow=1)
        worksheet = writer.sheets[sheet_name]
        worksheet.merge_range(0, first_col, 0, first_col + dcol - 2, dataset.upper(), merge_format)
        worksheet.set_column(first_col + pct_offset, first_col + pct_offset + 1, 15, percentage_format)
        max_rows = max(max_rows, table.shape[0])
    return worksheet, max_rows


# NOTE: the PNG files inserted below are written by the plotting cells further
# down in this notebook; run those first so that the images exist.
for group in tag_tables.keys():
    # The formatting calls below are XlsxWriter API, so the engine is requested
    # explicitly. The context manager closes the writer even when an error
    # interrupts the export.
    with pd.ExcelWriter("tag_stats/tag_stats_{}.xlsx".format(group.lower()), engine="xlsxwriter") as writer:
        workbook = writer.book

        # Formats
        # Raw string: the backslash belongs to the Excel number format (literal "%"
        # after the number; the stored values are already percentages).
        percentage_format = workbook.add_format({"num_format": r"0.00\%"})
        merge_format = workbook.add_format({
            "bold": 1,
            "border": 1,
            "align": "center",
            "fg_color": group_to_color[group],
            "font_size": 14
        })

        # Sheet for every language
        langs = utils.order_table(pd.DataFrame(tag_tables[group].keys(), columns=["Language"])).iloc[:,0].values # Order as usual

        for lang in langs:
            lang_tables = tag_tables[group][lang]
            if len(lang_tables) > 1:
                # Fixed left-to-right order; datasets a language does not have are skipped
                datasets = [d for d in ["train", "dev", "test", "total"] if d in lang_tables]
            else:
                datasets = list(lang_tables.keys())

            worksheet, n_rows = _write_dataset_tables(writer, lang, lang_tables, datasets, 2,
                                                      merge_format, percentage_format)
            if worksheet is None:
                continue  # nothing was written for this language

            # Insert plot below the longest table so that it never covers table rows
            worksheet.insert_image(n_rows + 2, 0,
                                   "tag_stats/plots/langs/pos_tags_plot_{}.png".format(lang.lower()),
                                   {"x_scale": 0.75, "y_scale": 0.75})

        # Group average sheet
        group_avgs = calculate_group_avg_tables(tag_tables[group])
        worksheet, n_rows = _write_dataset_tables(writer, group, group_avgs, ["train", "dev", "test", "total"], 1,
                                                  merge_format, percentage_format)

        worksheet.set_tab_color(group_to_color[group]) # Special color for group sheet
        # Insert plot
        worksheet.insert_image(n_rows + 2, 0,
                               "tag_stats/plots/groups/pos_tags_plot_{}.png".format(group.lower()),
                               {"x_scale": 0.75, "y_scale": 0.75})

### Make plots

In [16]:
def make_plot_table(lang_tables, n):
    """Build the table that feeds the stacked bar plot.

    Parameters
    ----------
    lang_tables : dict
        Mapping ``dataset -> DataFrame``. Every table needs a ``"Tag"`` column
        and a column whose name ends in ``Count(%)``.
    n : int
        Number of individual tags to show; all remaining tags are summed into
        an ``"OTHERS"`` row.

    Returns
    -------
    pandas.DataFrame
        A ``Tag`` column (the tag ``"_"`` is shown as ``"MULTI"``) plus one
        column of percentages per dataset, named after the capitalised dataset.
    """
    def get_tag_order(lang_tables, n):
        tag_order = []
        # .loc slices by label and includes the end label, so up to n + 1 leading
        # rows of every table are walked in parallel, rank by rank.
        for row_tags in zip(*[df.loc[:n, "Tag"] for df in lang_tables.values()]):
            # Within one rank, tags shared by more tables come first
            for tag, _ in Counter(row_tags).most_common():
                if tag not in tag_order:
                    tag_order.append(tag)
        return tag_order[:n]

    def find_freq_col(table):
        r = re.compile(r".*Count\(%\)")
        return list(filter(r.match, table.columns))[0]

    top_tags = get_tag_order(lang_tables, n)
    plot_table = pd.DataFrame({"Tag": top_tags + ["OTHERS"]})
    freq_col = find_freq_col(list(lang_tables.values())[0])

    if len(lang_tables) > 1:
        # Fixed column order; datasets that are not available are skipped
        datasets = [d for d in ["total", "test", "dev", "train"] if d in lang_tables]
    else:
        datasets = list(lang_tables.keys())
    for dataset in datasets:
        # A tag that does not occur in this dataset's table counts as 0 %
        shares = lang_tables[dataset].set_index("Tag")[freq_col].reindex(top_tags).fillna(0)
        plot_table[dataset.capitalize()] = [*shares, 100 - shares.sum()]

    plot_table = plot_table.replace("_", "MULTI")
    return plot_table

In [42]:
def make_plot(table_plot, lang, path):
    """Draw a stacked horizontal bar chart of tag shares and save it as a PNG.

    Parameters
    ----------
    table_plot : pandas.DataFrame
        First column ``"Tag"``, followed by one numeric column per bar. Each
        row becomes one segment of every bar.
    lang : str
        Used as the plot title and, lower-cased, in the output file name.
    path : str
        Directory prefix (including the trailing separator) the file
        ``pos_tags_plot_<lang>.png`` is written to.
    """
    sns.set()
    sns.set_style("ticks")
    plt.rc('xtick', labelsize=16)
    plt.rc('ytick', labelsize=16)
    plt.rc("axes", labelsize=16)

    g = table_plot.set_index("Tag").rename(columns={"Total": "TOTAL"}).T.plot(kind="barh", stacked=True, colormap="crest_r",
                                                                              figsize=(18, table_plot.shape[1]), xlim=(0, 100))
    shares = table_plot.iloc[:, 1:]
    # Left edge of each segment: running total of the rows above it, computed once
    left_edges = shares.cumsum().shift(1, fill_value=0)
    for row_pos, tag in enumerate(table_plot["Tag"]):
        for y in range(shares.shape[1]):
            # Explicit positional access; integer keys on a label-indexed Series are deprecated
            share = shares.iloc[row_pos, y]
            if share != 0:
                x = left_edges.iloc[row_pos, y] + share / 2  # horizontal centre of the segment
                g.text(x=x, y=y+0.375, s=tag, fontsize=16, fontstretch="condensed",
                       horizontalalignment="center", verticalalignment="center")
                g.text(x=x, y=y, s="{:.1f}%".format(share), fontsize=16, fontstretch="condensed",
                       horizontalalignment="center", verticalalignment="center",
                       bbox=dict(boxstyle="round, pad=0.15",
                                 fc=(1, 1, 1, 0.5),
                                 ec="none"))
    g.set_xticks(range(0, 101, 10))
    g.set(xlabel = "Cumulative Frequency (%)")
    g.legend().remove()
    g.set_title(lang, fontsize=28, pad=10, color="grey")
    sns.despine(ax=g)
    g.figure.savefig(path + "pos_tags_plot_{}.png".format(lang.lower()), dpi=400)
    # Close the figure so that figures do not pile up when called in a loop
    plt.close(g.figure)

In [44]:
# One chart per language, followed by one chart of the group averages.
for group in tqdm(tag_tables.keys()):
    for lang, lang_tables in tag_tables[group].items():
        table_plot = make_plot_table(lang_tables, 5)
        make_plot(table_plot, lang, "tag_stats/plots/langs/")

    group_tables = calculate_group_avg_tables(tag_tables[group])
    table_plot = make_plot_table(group_tables, 5)
    make_plot(table_plot, group, "tag_stats/plots/groups/")

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


