In [1]:
from pathlib import Path
from collections import Counter

import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from joblib import Parallel, delayed

from src.gen.util import read_gzip_data, write_gzip_data

datap = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/data/level/2")
dumpp = Path("/users/k21190024/study/fact-check-transfer-learning/scratch/dumps/explore/level/2")
if not dumpp.exists():
    dumpp.mkdir(parents=True)

# Load data

In [2]:
cla = "claims_lemma.pkl.gz"
cor = "corpus_lemma.pkl.gz"

# Explore

In [3]:
fls = [
    datap.joinpath("scifact", cla), datap.joinpath("scifact", cor)
    , datap.joinpath("fever", cla), datap.joinpath("fever", cor)
    , datap.joinpath("climatefever", cla), datap.joinpath("climatefever", cor)
]

## Lemma per doc

In [4]:
def lemma_per_doc(fp):
    ls = read_gzip_data(fp)
    res = []
    
    if isinstance(ls, list):
        res.append((f"{fp.parent.name}-{fp.name.split('_')[0]}", [len(l) for l in ls]))
    elif isinstance(ls, dict):
        for k, v in ls.items():
            if k == "ner":
                continue
            res.append((f"{fp.parent.name}-{fp.name.split('_')[0]}-{k}", [len(l) for l in v]))
    return res

res_ls = Parallel(n_jobs=len(fls))(delayed(lemma_per_doc)(fp) for fp in fls)

lemmapd = []
for r in res_ls:
    tmp = {}
    for doc, lemmas in r:
        lemmapd.append({
            "dataset": doc,
            "mean": np.mean(lemmas),
            "median": np.median(lemmas),
            "min": np.min(lemmas),
            "max": np.max(lemmas),
        })
df_lempd = pd.DataFrame(lemmapd)
df_lempd

Unnamed: 0,dataset,mean,median,min,max
0,scifact-claims-claims,1409.0,1409.0,1409,1409
1,scifact-corpus-title,5183.0,5183.0,5183,5183
2,scifact-corpus-evidence,5183.0,5183.0,5183,5183
3,fever-claims-claims,185445.0,185445.0,185445,185445
4,fever-corpus,48.791444,32.0,0,93907
5,climatefever-claims-claims,1535.0,1535.0,1535,1535
6,climatefever-corpus-title,1535.0,1535.0,1535,1535
7,climatefever-corpus-evidence,7675.0,7675.0,7675,7675


## Lemma histogram

In [5]:
def count_words(fp):
    ls = read_gzip_data(fp)
    res = []
    
    if isinstance(ls, list):
        flatls = []
        for i in ls:
            flatls += i
        res.append((f"{fp.parent.name}-{fp.name.split('_')[0]}", Counter(flatls)))
    elif isinstance(ls, dict):
        for k, v in ls.items():
            if k == "ner":
                continue
            flatls = []
            for i in v:
                flatls += i
            res.append((f"{fp.parent.name}-{fp.name.split('_')[0]}-{k}", Counter(flatls)))
    return res

if dumpp.joinpath("count_lemma.pkl.gz").exists():
    countwords = read_gzip_data(dumpp.joinpath("count_lemma.pkl.gz"))
else:
    res_ls = Parallel(n_jobs=len(fls))(delayed(count_words)(fp) for fp in fls)
    countwords = []
    for r in res_ls:
        countwords += r

    write_gzip_data(dumpp.joinpath("count_lemma.pkl.gz"), countwords)

TypeError: unhashable type: 'list'

In [None]:
df_words = pd.DataFrame(columns=["dataset", "lemma", "count"])
for k, v in countwords:
    tv = dict(v)
    tmpdf = pd.DataFrame(list(tv.items()), columns=["lemma", "count"])
    tmpdf["dataset"] = k
    df_words = pd.concat([df_words, tmpdf], axis=0)
df_words = df_words.reset_index(drop=True)
df_words = df_words.merge(df_words.groupby("dataset", as_index=False)["count"].sum().rename({"count": "total_count"}, axis=1), on="dataset")
df_words["normalised_count"] = df_words["count"] / df_words["total_count"] * 100
df_words

In [None]:
df_topwords = (
    df_words
    .sort_values(["dataset", "count"], ascending=[True, False])
    .groupby("dataset", as_index=False)
    .head(10)
)
df_topwords

In [None]:
g = sns.catplot(x="lemma", y="normalised_count", data=df_topwords[df_topwords["dataset"].str.contains("climate")], col="dataset", kind="bar", facet_kws={'sharey': False, 'sharex': False})
g.set_xticklabels(rotation=45)

In [None]:
g = sns.catplot(x="lemma", y="normalised_count", data=df_topwords.loc[(df_topwords["dataset"].str.contains("fever") & (~df_topwords["dataset"].str.contains("climate")))], col="dataset", kind="bar", facet_kws={'sharey': False, 'sharex': False})
g.set_xticklabels(rotation=45)

In [None]:
g = sns.catplot(x="lemma", y="normalised_count", data=df_topwords[df_topwords["dataset"].str.contains("scifact")], col="dataset", kind="bar", facet_kws={'sharey': False, 'sharex': False})
g.set_xticklabels(rotation=45)