In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# sns.set(rc = {'figure.figsize':(8,8)})

# DFs

In [None]:
# Switch between the archived ("_OLD") size tables and the current ones.
old = False

path_dijbytes, path_dijtokens = (
    ("dataset_sizes_bytes_OLD.csv", "dataset_sizes_tokens_OLD.csv")
    if old
    else ("dataset_sizes_bytes.csv", "dataset_sizes_tokens.csv")
)

In [None]:
# CSV location for each input table, keyed by the short name used as the `dfs` key later on.
paths = dict(
    dijbytes=path_dijbytes,
    Wij="../SAMPLING_WEIGHTS/SAMPLING_WEIGHTS_real.csv",
    dijtokens=path_dijtokens,
)

In [None]:
# Load every table: the first CSV column becomes the row index, lines starting with '#' are skipped.
dfs = {name: pd.read_csv(location, index_col=0, comment='#') for name, location in paths.items()}

# Rescale the byte table by 1e9 (in place).
dfs["dijbytes"] /= 10**9

# Sizes multiplied by the sampling weights. The product is positional (raw arrays),
# so the weight table must have the same shape and ordering as the size tables.
dfs["dijbytes*Wij"] = pd.DataFrame(
    dfs["dijbytes"].to_numpy() * dfs["Wij"].to_numpy(),
    index=dfs["dijbytes"].index,
    columns=dfs["dijbytes"].columns,
)
dfs["dijtokens*Wij"] = pd.DataFrame(
    dfs["dijtokens"].to_numpy() * dfs["Wij"].to_numpy(),
    index=dfs["dijtokens"].index,
    columns=dfs["dijtokens"].columns,
)

# Each cell as a percentage of the grand total of its table (plain and weighted variants).
dfs["fijbytes"] = dfs["dijbytes"] / dfs["dijbytes"].to_numpy().sum() * 100
dfs["Fijbytes"] = dfs["dijbytes*Wij"] / dfs["dijbytes*Wij"].to_numpy().sum() * 100
dfs["fijtokens"] = dfs["dijtokens"] / dfs["dijtokens"].to_numpy().sum() * 100
dfs["Fijtokens"] = dfs["dijtokens*Wij"] / dfs["dijtokens*Wij"].to_numpy().sum() * 100

# Append a "total" column and a "total" row to every table.
# The weight table gets zeros there; all other tables get the row/column sums.
for key in dfs.keys():
    frame = dfs[key]
    reduce_fn = (lambda x: 0) if key == "Wij" else (lambda x: sum(x))
    frame["total"] = frame.apply(reduce_fn, axis=1)
    frame.loc["total"] = frame.apply(reduce_fn, axis=0)

In [None]:
def zero_total(_df):
    """Replace the "total" row and column of ``_df`` with zeros, in place.

    The existing margin is removed through ``drop_total`` first, so both the
    "total" row and the "total" column must be present on entry.
    """
    drop_total(_df)
    fill = 0
    _df["total"] = fill       # fresh "total" column
    _df.loc["total"] = fill   # fresh "total" row (spans the new column as well)

def drop_total(_df):
    """Remove the "total" row and the "total" column from ``_df`` in place.

    Nothing is returned; the frame passed in is modified directly.
    """
    _df.drop(index="total", inplace=True)
    _df.drop(columns="total", inplace=True)
    
def add_total(_df):
    """Append a "total" column (row sums) and then a "total" row (column sums) to ``_df`` in place.

    The row is computed after the column has been added, so the bottom-right
    cell holds the sum of the "total" column.
    """
    def _py_sum(values):
        # Python's builtin sum, applied to one row or one column at a time.
        return sum(values)

    row_sums = _df.apply(_py_sum, axis=1)
    _df["total"] = row_sums
    column_sums = _df.apply(_py_sum, axis=0)
    _df.loc["total"] = column_sums
    
def totex(_df, _name, header, tail="\\end{tabular}}"):
    r"""Render ``_df`` as the body of a LaTeX ``tabular`` and write it to ``tables/{_name}.tex``.

    The frame is serialised with ``DataFrame.to_csv`` and the CSV text is then
    rewritten: commas become `` & ``, line ends become `` \\ ``, underscores are
    escaped, and the labels "commoncrawl" / "conversational" are shortened to
    "cc" / "conv". An ``\hline`` is added after the first row terminator (the
    column-header row) and after the last remaining one, i.e. right before the
    final row. The final row itself carries no terminator.

    Parameters
    ----------
    _df : pandas.DataFrame
        Table to export; cells appear as ``to_csv`` formats them.
    _name : str
        Stem of the output file inside the ``tables`` directory.
    header : str
        Text placed right after the opening ``\scalebox{\tabscale}{``
        (callers pass the ``\begin{tabular}{...}`` line).
    tail : str
        Closing text appended at the end; the default closes the tabular and the scalebox.

    Returns
    -------
    None
        The result is written to disk; the ``tables`` directory is created if missing.
    """
    import os

    t = "\\scalebox{\\tabscale}{"
    t += header + " \n"
    t += (
        _df.to_csv()
        .replace(",", " & ")
        .replace("commoncrawl", "cc")
        .replace("conversational", "conv")
        .replace("\n", " \\\\ \n")
        .replace("_", "\\_")  # valid escape; same value as the former "\_"
        .replace("\\\\", "\\\\ \\hline", 1)  # rule below the column-header row
    )
    # Drop the row terminator that follows the very last row.
    if t.endswith(" \\\\ \n"):
        t = t[:-len(" \\\\ \n")]
        t += " \n"
    t = "\\\\ \\hline".join(t.rsplit("\\\\", 1))  # replace last "\\" by "\\ \hline"
    t += tail

    path = f"tables/{_name}.tex"
    # Make sure the target directory exists so a fresh checkout can run the notebook.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.write(t)

### Step 0: Plain dataset sizes

In [None]:
totex(dfs["dijbytes"].applymap(lambda x: f"{x:.1f}"), "dijbytes", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["dijbytes"]

In [None]:
totex(dfs["fijbytes"].applymap(lambda x: f"{x:.2f}"), "fijbytes", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["fijbytes"]

### Step 1: Tokenizer Training

In [None]:
dfs["Wij"]

In [None]:
totex(dfs["Wij"], "Wij", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")

In [None]:
dfs["dijbytes*Wij"] 

In [None]:
totex(dfs["Fijbytes"].applymap(lambda x: f"{x:.2f}"), "FFijbytes", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["Fijbytes"]

### Step 2: Tokenizer Application

In [None]:
totex(dfs["dijtokens"], "dijtokens_not_rounded", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
totex(dfs["dijtokens"].applymap(lambda x: f"{x/10**9:.2f}"), "dijtokens", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["dijtokens"]

In [None]:
T = dfs["dijtokens"].max().max()
T

In [None]:
totex(dfs["fijtokens"].applymap(lambda x: f"{x:.2f}"), "fijtokens", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["fijtokens"]

In [None]:
T = dfs["dijtokens"].loc["total"]["total"]
T

In [None]:
dfs["rij"] = dfs["dijtokens"]/dfs["dijbytes"]/10**9
dfs["rij"]

In [None]:
m1, m2 = min([elem for val in dfs["rij"].values for elem in val]), max([elem for val in dfs["rij"].values for elem in val])
1/m1, 1/m2

### Step 3: Model Training

In [None]:
totex(dfs["Fijtokens"].applymap(lambda x: f"{x:.2f}"), "FFijtokens", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["Fijtokens"]

In [None]:
# E_ij at t = T: T * F_ij / 100 divided by the tokens available in d_ij.
# Cells where the division yields NaN are set to 0.
dfs["Eij"] = (T*dfs["Fijtokens"]/100/dfs["dijtokens"]).fillna(0)
zero_total(dfs["Eij"])  # the "total" margin is zeroed rather than summed for this table
totex(dfs["Eij"].applymap("{:.2f}".format), "Eij", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["Eij"]

In [None]:
maxE = dfs["Eij"].max().max()
maxE

In [None]:
# Round every E_ij up to the next integer, then strip the "total" margin and store as int.
dfs["Eij_rounded"] = dfs["Eij"].applymap(np.ceil)
drop_total(dfs["Eij_rounded"])
dfs["Eij_rounded"] = dfs["Eij_rounded"].fillna(0).astype(int)
# totex(dfs["Eij_rounded"].applymap(lambda x: f"{x:.0f}"), "Eij_rounded", header="\\begin{tabular}{c||c|c|c|c|c|c||c}")
dfs["Eij_rounded"]

In [None]:
# Rounded E_ij times available tokens, divided by the weighted fraction F_ij.
# "Eij_rounded" has no "total" margin, so label alignment leaves NaN there; only the column is dropped.
dfs["Tijmax"] = (
    dfs["Eij_rounded"]*dfs["dijtokens"]/(dfs["Fijtokens"]/100)
).drop(columns="total")
# dfs["Tijmax"] = dfs["Tijmax"].fillna(0)
dfs["Tijmax"]

In [None]:
_t = [value for array in dfs["Tijmax"].values for value in array if value > 0]

Tmax = min(_t)
Tmax

### Plot

In [None]:
I = [elem for elem in dfs["Eij"].index.to_list() if elem != "total"]
J = [elem for elem in dfs["Eij"].columns.to_list() if elem != "total"]
I, J

In [None]:
Tthr = T/maxE # 98.8*10**9
Tthr

In [None]:
def plot_data_overview(_dfs, _field, T_thr: bool = True):
    """Draw one horizontal band per (category, language) pair, stacked along the y axis.

    Each band's height is ``_dfs[_field].iloc[i, j] / 100``; the x axis is in
    units of 10^9 tokens. For every band with a positive height:

    * a black vertical line marks ``E_1 = dijtokens / Fijtokens`` (legend "E_ij = 1"),
    * the span ``0 .. E_1`` is filled green,
    * if ``E_1 < T`` the span ``E_1 .. T`` is filled orange, otherwise the span
      ``T .. E_1`` is filled red.

    Bands taller than 0.05 are additionally labelled with "category, language"
    and with the value ``T / E_1 / 10**9``.

    The function reads the module-level names ``I``, ``J``, ``T`` and ``Tthr``;
    they must be defined before it is called.

    Parameters
    ----------
    _dfs : dict
        Must contain the frames "dijtokens", "Tijmax" and ``_field``. Cells are
        read positionally with ``.iloc[i, j]``.
        NOTE(review): this presumably relies on the rows/columns being ordered like ``I``/``J`` - confirm.
    _field : str
        Key of the percentage table that supplies the band heights.
    T_thr : bool
        When true, also draw a dotted vertical line at ``Tthr``.

    Returns
    -------
    (fig, ax)
        The figure and the two-element list ``[Axes, None]`` built below.
    """

    verbose = 0
    xlim = 500

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    # plt.subplots(1, 1) yields a single Axes object, so it is wrapped to allow ax[0] indexing.
    if not isinstance(ax, list):
        ax = [ax, None]

    ax[0].set_xlim([0, xlim])
    ax[0].set_ylim([0, 1])
    ax[0].set_xlabel("t [10^9 tokens]", fontsize=14)
    _ = ax[0].plot()

    # Running lower edge of the next band.
    y = 0
    for i, category in enumerate(I):
        for j, language in enumerate(J):
            # NOTE(review): `clr` is only used in the verbose print and `Tijmax` is never used below.
            clr = "r" if (2*i+j)%2 == 0 else "green"
            dijtokens = _dfs["dijtokens"].iloc[i, j]/10**9
            Fijtokens = _dfs[_field].iloc[i, j]/100
            Tijmax = _dfs["Tijmax"].iloc[i, j]/10**9

            # Lower and upper band edges, repeated so they pair with the 2-point x ranges below.
            length = 2
            y1 = [y]*length
            y2 = [y + Fijtokens]*length

            #########
            if Fijtokens > 0:
                # x position (in 10^9 tokens) marked with the "E_ij = 1" line.
                E_1 = dijtokens / Fijtokens
                if verbose:
                    print(E_1)
                    print(category, language, f"{Fijtokens:.2f}", clr)
                    
                # Only the first band carries the legend label, so the entry appears once.
                _ = ax[0].plot([E_1, E_1], 
                               [y1[0], y2[0]], 
                               linestyle="-", 
                               color="k", 
                               label="E_ij = 1" if i == 0 and j == 0 else None)
            
                x_unique = np.linspace(0, E_1, length)
                _ = ax[0].fill_between(x_unique, y1, y2, color="green", alpha=0.5)
                
                if E_1 < T/10**9:
                    # Orange: from E_1 up to T.
                    x_repeated = np.linspace(min(E_1, T/10**9), T/10**9, length) 
                    _ = ax[0].fill_between(x_repeated, y1, y2, color="orange", alpha=0.5)
                else:
                    # Red: from T up to E_1 (drawn on top of the green span).
                    x_discarded = np.linspace(T/10**9, E_1, length)                   
                    _ = ax[0].fill_between(x_discarded, y1, y2, color="red", alpha=0.5)

            
            #########
            # Label only the larger bands; E_1 is defined here because 0.05 > 0.
            if Fijtokens > 0.05:
                _ = ax[0].text(10, y + 0.02, f"{category}, {language}")
                _ = ax[0].text(T/10**9 + 10, y + 0.02, f"{T/E_1/10**9:.2f}")

            y += Fijtokens

    # Vertical reference lines spanning the full height.
    ax[0].plot([T/10**9, T/10**9], [0, 1], linestyle="--", color="k", label="T")
    if T_thr:
        ax[0].plot([Tthr/10**9, Tthr/10**9], [0, 1], linestyle=":", color="k", label="T_thr")
    _ = ax[0].legend(loc="upper right")
    return fig, ax

In [None]:
fig, ax = plot_data_overview(dfs, "fijtokens", T_thr=False)
plt.savefig("./figs/data_overview_0.png", facecolor='w')

In [None]:
fig, ax = plot_data_overview(dfs, "Fijtokens")
plt.savefig("./figs/data_overview.png", facecolor='w')

### Dataset Details

In [None]:
used = dfs["Eij"]*dfs["dijtokens"]
# used

In [None]:
existing = dfs["dijtokens"]

In [None]:
def get_unique_repeated_discarded(_used, _existing):
    """Split used token counts into unique, repeated and discarded parts, cell by cell.

    For each pair ``(used, existing)`` at the same position in the two frames:

    * unique    = ``min(used, existing)``
    * repeated  = ``max(0, used - existing)``
    * discarded = ``max(0, existing - used)``

    Cells are paired by position (via ``.values``), not by label.

    Returns
    -------
    tuple of three lists of lists
        ``(unique, repeated, discarded)``, each with one inner list per row.
    """
    _unique_lists, _repeated_lists, _discarded_lists = [], [], []
    for used_row, existing_row in zip(_used.values, _existing.values):
        cell_pairs = list(zip(used_row, existing_row))
        _unique_lists.append([min(u, e) for u, e in cell_pairs])
        _repeated_lists.append([max(0, u - e) for u, e in cell_pairs])
        _discarded_lists.append([max(0, e - u) for u, e in cell_pairs])
    return _unique_lists, _repeated_lists, _discarded_lists

In [None]:
unique_lists, repeated_lists, discarded_lists = get_unique_repeated_discarded(used, existing)

In [None]:
unique = pd.DataFrame(unique_lists, 
                      columns=used.columns, 
                      index=used.index)
drop_total(unique)
# unique

In [None]:
repeated = pd.DataFrame(repeated_lists, 
                        columns=used.columns, 
                        index=used.index)
drop_total(repeated)
# repeated

In [None]:
discarded = pd.DataFrame(discarded_lists, 
                         columns=used.columns, 
                         index=used.index)
drop_total(discarded)
# discarded

In [None]:
# Horizontal bar chart: one row per (category, language) pair, with the
# unique / repeated / discarded token counts (in 10^9) placed end to end.
fig, ax = plt.subplots(1, 1, figsize=(8, 10))
# plt.subplots(1, 1) yields a single Axes object, so it is wrapped to allow ax[0] indexing.
if not isinstance(ax, list):
    ax = [ax, None]
    
# Passed as the third positional argument of barh, i.e. the bar *height*.
width = 0.1

for i, category in enumerate(I):
    for j, language in enumerate(J):
        u = unique.iloc[i, j]/10**9
        r = repeated.iloc[i, j]/10**9
        d = discarded.iloc[i, j]/10**9
        # Segments are chained through `left=`; only the first pair carries legend labels.
        ax[0].barh(f"{category}, {language}", u, width, color="green", label='unique' if i == 0 and j == 0 else None)
        ax[0].barh(f"{category}, {language}", r, width, left=u, color="orange", label='repeated' if i == 0 and j == 0 else None)
        ax[0].barh(f"{category}, {language}", d, width, left=u+r, color="red", label='discarded' if i == 0 and j == 0 else None)
ax[0].legend(loc="lower right")
ax[0].set_xlabel("tokens [10^9]")
# Flip the y axis so the first pair is drawn at the top.
plt.gca().invert_yaxis()

plt.tight_layout()
plt.savefig("./figs/data_overview_2.png", facecolor='w')

In [None]:
tokens_unique = unique.sum().sum()
tokens_repeated = repeated.sum().sum()
tokens_discarded = discarded.sum().sum()
tokens_all = tokens_unique + tokens_repeated

tokens_unique, tokens_repeated, tokens_discarded, tokens_all

In [None]:
fraction_unique = tokens_unique / tokens_all
fraction_exchanged = tokens_repeated / tokens_all

fraction_unique, fraction_exchanged

#### vary T

In [None]:
factors = np.linspace(0, 1, 100)

In [None]:
# One accumulator per scaling factor, all starting at 0 (values end up in 10^9 tokens).
f_unique_total = dict.fromkeys(factors, 0)
f_repeated_total = dict.fromkeys(factors, 0)
f_discarded_total = dict.fromkeys(factors, 0)
f_used_total = dict.fromkeys(factors, 0)

for f in factors:
    # Scale the used-token table by the factor and split it against what exists.
    f_existing = dfs["dijtokens"]
    f_used = f*dfs["Eij"]*dfs["dijtokens"]
    f_unique_lists, f_repeated_lists, f_discarded_lists = get_unique_repeated_discarded(f_used, f_existing)

    # Rebuild labelled frames (labels taken from `used`) and strip the "total" margin.
    f_unique, f_repeated, f_discarded = (
        pd.DataFrame(_cells, columns=used.columns, index=used.index)
        for _cells in (f_unique_lists, f_repeated_lists, f_discarded_lists)
    )
    drop_total(f_unique)
    drop_total(f_repeated)
    drop_total(f_discarded)

    f_unique_total[f] += f_unique.sum().sum()/10**9
    f_repeated_total[f] += f_repeated.sum().sum()/10**9
    f_discarded_total[f] += f_discarded.sum().sum()/10**9
    f_used_total[f] += f_unique.sum().sum()/10**9 + f_repeated.sum().sum()/10**9

# f_unique_total, f_repeated_total, f_discarded_total, f_used_total

In [None]:
# Per-factor totals as plain lists, in the same order as `factors`.
unique_total = list(f_unique_total.values())
repeated_total = list(f_repeated_total.values())
# Fixed: this list used to be built from f_repeated_total, duplicating `repeated_total`.
discarded_total = list(f_discarded_total.values())

In [None]:
# Figure: unique vs. repeated tokens as the factor in `factors` scales T.
# Left axis: stacked areas in 10^9 tokens; right axis: unique share in percent.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
# plt.subplots(1, 1) yields a single Axes object, so it is wrapped to allow ax[0] indexing.
if not isinstance(ax, list):
    ax = [ax, None]
  
# x positions: each factor times T, in 10^9 tokens.
factorsT = [elem*T/10**9 for elem in factors]
# NOTE(review): for factor 0 both totals are 0, so this evaluates 0/0 - confirm that is acceptable.
unique_percent = [u/(u+r)*100 for u, r in zip(unique_total, repeated_total)]
# _ = ax[0].plot(factorsT, unique_total, marker="o", linestyle="", color="green", label="unique")
# _ = ax[0].plot(factorsT, [a+b for a,b in zip(unique_total, repeated_total)], marker="o", linestyle="", color="orange", label="repeated")
_ = ax[0].fill_between(factorsT, 0, unique_total, color="green", alpha=0.5, label="unique")
# The repeated area sits on top of the unique area (from unique up to unique + repeated).
_ = ax[0].fill_between(factorsT, unique_total, [a+b for a,b in zip(unique_total, repeated_total)], color="orange", alpha=0.5, label="repeated")
ax[0].set_xlim([0, T/10**9])
ax[0].set_ylim([0, T/10**9])
ax[0].set_xlabel("t [10^9 tokens]")
ax[0].set_ylabel("t [10^9 tokens]")
_ = ax[0].legend(loc="upper left")



# Secondary y axis sharing the same x axis, used for the percentage curve and the T_thr marker.
ax2 = ax[0].twinx()
ax2.plot(factorsT, unique_percent, color="green", label="unique percentage")
ax2.set_ylim([0, 100])
ax2.plot([Tthr/10**9, Tthr/10**9], [0, 100], linestyle=":", color="k", label="T_thr")
_ = ax2.set_ylabel("unique percentage [%]")
_ = ax2.legend()

plt.tight_layout()
plt.savefig("./figs/data_overview_3.png", facecolor='w')

### Merged Datasets

In [None]:
def merge(_dfs, _name):
    """Build a coarser copy of ``_dfs[_name]`` and store it under ``f"{_name}_MERGED"``.

    Steps, applied to a copy of the source frame:

    1. Columns "sv", "no", "da", "is" are summed into a new column "nd" and removed.
    2. The "nd" cells of rows "books_hq" and "conversational" are moved into a
       new row "books_conv" (its "en", "cd" and "total" cells are 0).
    3. The "nd" and "en" cells of row "wiki" are moved into a new column "all".
    4. Columns are ordered nd, en, cd, all, total; the new row is placed second.

    The source frame is left untouched. ``_dfs`` is updated in place and also returned.
    """
    key = f"{_name}_MERGED"
    merged = _dfs[_name].copy()
    _dfs[key] = merged

    # 1. merge sv, no, da, is -> nd
    merged['nd'] = merged['sv']
    merged.drop("sv", axis=1, inplace=True)
    for lang in ("no", "da", "is"):
        merged['nd'] += merged[lang]
        merged.drop(lang, axis=1, inplace=True)

    # 2. merge books_hq and conversational for lang=nd
    books_nd = merged.loc["books_hq", "nd"]
    conv_nd = merged.loc["conversational", "nd"]
    merged.loc["books_conv"] = {"nd": books_nd + conv_nd, "en": 0, "cd": 0, "total": 0}
    merged.loc["books_hq", "nd"] = 0
    merged.loc["conversational", "nd"] = 0

    # 3. merge wiki languages
    wiki_nd = merged.loc["wiki", "nd"]
    wiki_en = merged.loc["wiki", "en"]
    _categories = merged.index.tolist()
    merged["all"] = [0.0] * len(_categories)
    merged.loc["wiki", "all"] = wiki_nd + wiki_en
    merged.loc["wiki", "nd"] = 0
    merged.loc["wiki", "en"] = 0

    # 4a. reorder columns
    _dfs[key] = merged[['nd', 'en', 'cd', 'all', 'total']]

    # 4b. reorder rows: the row appended last ("books_conv") moves to second position
    _categories_new = [_categories[0], _categories[-1], *_categories[1:-1]]
    _dfs[key] = _dfs[key].reindex(_categories_new)

    # return
    return _dfs
    

In [None]:
dfs = merge(dfs, "dijtokens")
# dfs["dijtokens_MERGED"]

In [None]:
dfs = merge(dfs, "Fijtokens")
drop_total(dfs["Fijtokens_MERGED"])
add_total(dfs["Fijtokens_MERGED"])
totex(dfs["Fijtokens_MERGED"].applymap(lambda x: f"{x:.2f}"), "Fijtokens_MERGED", header="\\begin{tabular}{c||c|c|c|c||c}")
dfs["Fijtokens_MERGED"]

# NUMBERS SUMMARY

In [None]:
# number of tokens
T

In [None]:
# range of rij
1/m1, 1/m2

In [None]:
# maximum epochs for t=T
maxE

In [None]:
# maximum number of tokens possible with NeMo Megatron
Tmax

In [None]:
# number of tokens at which data starts to be repeated
Tthr

In [None]:
# tokens unique/repeated/discarded for t=T, absolute
tokens_unique, tokens_repeated, tokens_discarded

In [None]:
# tokens unique/repeated/discarded for t=T, relative
fraction_unique, fraction_exchanged

# OLD

### Minimum Hypothesis

# HEATMAPS

In [None]:
ax = sns.heatmap(dfs["dijbytes"], annot=True)
ax.set_title('dataset sizes [bytes]')
plt.show()

In [None]:
ax = plt.axes()
sns.heatmap(dfs["Wij"], annot=True)
ax.set_title('weights')
plt.show()

In [None]:
ax = plt.axes()
sns.heatmap(dfs["fijbytes"], annot=True)
ax.set_title('dataset_size [%]')
plt.show()

In [None]:
ax = plt.axes()
sns.heatmap(dfs["Fijbytes"], annot=True)
ax.set_title('dataset_size weighted [%]')
plt.show()