In [54]:
import yaml
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random
import numpy as np
import copy
from collections import Counter

In [55]:
# Load the training configuration from the YAML file one directory above this notebook.
# The encoding is pinned to UTF-8 so that non-ASCII values parse the same way on every
# platform instead of depending on the locale's default text encoding.
with open('../config.yml', encoding='utf-8') as file:
    train_config = yaml.safe_load(file)

In [56]:
# Location of the dataset CSV, taken from the `data_args` -> `csv_fi` entry of the loaded config.
csv_path = train_config['data_args']['csv_fi']
# Whole dataset as a DataFrame; the cells below group it by its `cefr_mean` column.
df = pd.read_csv(csv_path)

In [57]:
# avg_samples_per_gp = base.sum()*2/len(base)
# weights = np.round(avg_samples_per_gp - base)
# weights = weights/weights.index
# resample_weights = [weights[cefr] for cefr in df.cefr_mean]
# resample_weights

In [58]:
def resample(train_dataset,
             criterion: str = "rating",
             n_copy: int = 2,
             seed=None):
    """Draw extra rows that balance the classes of ``train_dataset``.

    The target size of every class is ``len(train_dataset) * n_copy / n_classes``.
    Rows are drawn with replacement, and each row is weighted so that the expected
    number of draws from a class equals that class's shortfall against the target.
    The returned frame holds only the *extra* rows: stacking it on top of the
    original data gives roughly ``n_copy`` times the original size with classes of
    roughly equal size.

    Args:
        train_dataset: pandas DataFrame to sample from. It is never modified.
        criterion: name of the column holding the class label (default ``"rating"``).
        n_copy: size of original + returned rows, as a multiple of the original size.
            The default of 2 returns exactly ``len(train_dataset)`` rows.
        seed: optional seed for a private random generator, making the draw
            reproducible. With ``None`` the module-level ``random`` state is used.

    Returns:
        A new DataFrame with ``(n_copy - 1) * len(train_dataset)`` rows taken from
        ``train_dataset`` (original index labels are kept, so labels may repeat).
        An empty input yields an empty frame.

    Raises:
        AssertionError: if any class already has at least the target number of rows,
            i.e. the data cannot be balanced by only adding ``n_copy - 1`` copies.
    """
    ratings = train_dataset[criterion].tolist()
    n_rows = len(ratings)
    if n_rows == 0:
        # Nothing to draw from; avoids dividing by a zero class count below.
        return train_dataset.iloc[[]]

    # calculate sampling rate
    group_counts = Counter(ratings)
    target_per_group = n_rows * n_copy / len(group_counts)
    shortfalls = {group: target_per_group - count for group, count in group_counts.items()}
    assert all(shortfall > 0 for shortfall in shortfalls.values()), \
        "This calculation does not work. Might have to re-design."

    # Dividing a class's shortfall by its size gives the weight of one row, so the
    # summed weight of a class is proportional to its shortfall. The exact ratio is
    # used (no rounding) so a small shortfall is never quantised down to weight 0.
    row_weight = {group: shortfall / group_counts[group] for group, shortfall in shortfalls.items()}
    resample_weights = [row_weight[rating] for rating in ratings]

    # resample data based on weights
    rng = random if seed is None else random.Random(seed)
    n_draws = int(round((n_copy - 1) * n_rows))
    positions = rng.choices(range(n_rows), weights=resample_weights, k=n_draws)

    # Positional indexing with a list builds a new frame, so the input stays untouched.
    return train_dataset.iloc[positions]

In [59]:
# Draw the extra rows for over-sampling, balancing on the `cefr_mean` column instead of
# the function's default "rating" column.
# NOTE(review): the draw is random and no seed is set in the cells shown here, so the
# rows displayed below will differ from run to run.
resampled = resample(df, criterion='cefr_mean')
resampled

Unnamed: 0.1,Unnamed: 0,sample,student,task_id,transcript,recording_path,accuracy_mean,range_mean,fluency_mean,cefr_mean,...,pronunciation_mean,split,transcript_normalized,ASR_transcript,cefr_mean_original,pronunciation_mean_original,fluency_mean_original,accuracy_mean_original,range_mean_original,task_completion_mean_original
1460,1460,1532,259,28,mikä sun nimi on kuinka vanha sä olet,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,2,3,1,...,3,1,mikä sun nimi on kuinka vanha sä olet,,1.5,2.5,3.0,4.0,1.5,2.5
2098,2098,870,87,18,hei maija<name> soittaa olimme eilen illalla *...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,6,...,4,0,hei maija soittaa olimme eilen illalla teiän k...,,5.5,4.0,4.0,4.0,3.0,3.0
70,70,1519,18,17,<garbage> hei juuri tällä hetkellä minul ei ol...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,3,2,3,4,...,4,0,hei juuri tällä hetkellä minul ei ole mahdolli...,,4.5,4.0,3.5,3.5,2.0,3.0
368,368,1970,58,19,hei <paral> tuhannet kiitokset *siittä* että l...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,3,7,...,4,2,hei tuhannet kiitokset siittä että löysit mun ...,,7.0,4.0,3.0,4.0,3.0,3.0
623,623,2365,92,19,hei *tääl* on matti meikäläinen<name> ja oho k...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,5,...,4,3,hei tääl on matti meikäläinen ja oho kikiitos ...,,5.5,3.5,4.0,4.0,3.0,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1064,1064,1242,155,3,kaikki kolme kuvaa ovat aika hienoja ainakin m...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,6,...,4,0,kaikki kolme kuvaa ovat aika hienoja ainakin m...,,6.0,4.0,3.5,4.0,3.0,3.0
128,128,1011,19,15,geenimuunneltu ruoka ei ollu sillee tuttu aihe...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,3,6,...,3,2,geenimuunneltu ruoka ei ollu sillee tuttu aihe...,,6.0,3.5,3.5,4.0,3.0,3.0
985,985,92,161,15,tämä teema on minulle erittäin aiheellinen sil...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,7,...,4,2,tämä teema on minulle erittäin aiheellinen sil...,,6.5,4.0,4.0,4.0,3.0,3.0
1482,1482,251,136,30,asun helsingissä<name>,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,2,3,1,...,2,0,asun helsingissä,,1.5,2.5,3.5,4.0,1.5,3.0


In [60]:
# BASE: number of rows per CEFR class in the original data.
# `size()` is the direct way to count rows per group; it does not call a Python
# lambda on every group the way `apply(lambda g: len(g))` does.
base = df.groupby('cefr_mean').size()

# BASE_same: the original data plus an identical copy, in long format for px.bar.
# Both series are taken from the group sizes themselves rather than from whichever
# columns happen to sit at positions 1:3, so the result does not depend on the CSV's
# column order or on missing values in unrelated columns.
base_same = (
    base.rename('original')
    .reset_index()
    .assign(**{'copy': lambda counts: counts['original']})
)
base_same = pd.melt(base_same, id_vars=['cefr_mean'], value_vars=['original', 'copy'])

# BASE_OS: number of rows per CEFR class among the rows drawn by `resample`.
base_os = resampled.groupby('cefr_mean').size()

In [61]:
# Palette keyed by legend label; reused by the figures further down.
colour_map = {
    'original': '#6096B4',
    'copy': '#A7D2CB',
    'resampled': '#F2D388',
    'augmented': '#F9B572',
}

# Three panels side by side on a common y-axis: original data, original + plain copy,
# original + over-sampled rows.
fig = make_subplots(
    rows=1,
    cols=3,
    shared_yaxes=True,
    subplot_titles=(
        r'$\textit{BASE}$',
        r'$\textit{BASE_same}$',
        r'$\textit{BASE_OS}$',
    ),
)

# Traces for each panel. `bar_base` is hidden from the legend here because the
# 'original' entry already comes from the px.bar traces of the middle panel.
bar_base = go.Bar(
    x=base.index,
    y=base.values,
    marker={'color': colour_map['original']},
    showlegend=False,
)
bar_same = px.bar(
    base_same,
    x='cefr_mean',
    y='value',
    color='variable',
    color_discrete_map=colour_map,
)
bar_os = go.Bar(
    x=base_os.index,
    y=base_os.values,
    marker={'color': colour_map['resampled']},
    name='over-sampled',
)

fig.add_trace(bar_base, row=1, col=1)
fig.add_traces(bar_same.data, rows=1, cols=2)
# Stacked: the original counts at the bottom, the over-sampled rows on top.
fig.add_traces([bar_base, bar_os], rows=1, cols=3)

fig.update_layout(
    barmode='stack',
    plot_bgcolor='whitesmoke',
    yaxis={'title': 'Number of samples'},
    width=800,
    height=400,
)
fig.update_xaxes(dtick=1, title='Class (CEFR score)')
fig.show()

In [79]:
# Per-class row counts of the original data, duplicated under the label 'augmented',
# in long format for px.bar. The counts come from the group sizes directly instead of
# from the columns at positions 1:3, so this does not depend on the CSV's column order
# or on missing values in unrelated columns.
base_augmented = (
    df.groupby('cefr_mean').size()
    .rename('original')
    .reset_index()
    .assign(augmented=lambda counts: counts['original'])
)
base_augmented = pd.melt(base_augmented, id_vars=['cefr_mean'], value_vars=['original', 'augmented'])

bar_augmented = px.bar(base_augmented, x='cefr_mean', y='value', color='variable', color_discrete_map=colour_map)

# Stand-alone stacked bar chart: original counts with the 'augmented' counts on top.
fig_2 = go.Figure()
fig_2.add_traces(bar_augmented.data)
fig_2.update_layout(barmode='stack', width=400,
                    # legend placed inside the plot area, top-right corner
                    legend=dict(yanchor='top', xanchor='right', y=0.99, x=0.99),
                    xaxis=dict(title='Class (CEFR score)', dtick=1),
                    yaxis=dict(title='Number of samples'))
fig_2.show()

In [78]:
# Re-label the two existing traces for the stand-alone figure below.
# These updates change `bar_os` and `bar_base` in place, so any cell that shows those
# objects afterwards sees the new colour, name and legend settings.
bar_os.update(
    marker={'color': colour_map['augmented']},
    name='augmented',
    showlegend=True,
)
bar_base.update(
    name='original',
    showlegend=True,
)

# Stacked bar chart: original counts at the bottom, over-sampled rows on top.
fig_3 = go.Figure(data=[bar_base, bar_os])
fig_3.update_layout(
    barmode='stack',
    width=400,
    # legend placed inside the plot area, top-right corner
    legend={'yanchor': 'top', 'xanchor': 'right', 'y': 0.99, 'x': 0.99},
    xaxis={'title': 'Class (CEFR score)', 'dtick': 1},
    yaxis={'title': 'Number of samples'},
    yaxis_range=[0, 780],
)
fig_3.show()

In [64]:
# Show the current state of the over-sampled trace. Other cells modify this object in
# place, so what is displayed depends on which cells ran before this one.
bar_os

Bar({
    'legendgroup': 'augmented',
    'marker': {'color': '#F9B572'},
    'name': 'over-sampled',
    'showlegend': True,
    'x': array([1, 2, 3, 4, 5, 6, 7]),
    'y': array([568, 270,  82, 183, 184, 317, 508])
})

(Bar({
     'alignmentgroup': 'True',
     'hovertemplate': 'variable=original<br>cefr_mean=%{x}<br>value=%{y}<extra></extra>',
     'legendgroup': 'original',
     'marker': {'color': '#6096B4', 'pattern': {'shape': ''}},
     'name': 'original',
     'offsetgroup': 'original',
     'orientation': 'v',
     'showlegend': True,
     'textposition': 'auto',
     'x': array([1, 2, 3, 4, 5, 6, 7]),
     'xaxis': 'x',
     'y': array([ 33, 350, 512, 425, 416, 285,  91]),
     'yaxis': 'y'
 }),
 Bar({
     'alignmentgroup': 'True',
     'hovertemplate': 'variable=augmented<br>cefr_mean=%{x}<br>value=%{y}<extra></extra>',
     'legendgroup': 'augmented',
     'marker': {'color': '#F9B572', 'pattern': {'shape': ''}},
     'name': 'augmented',
     'offsetgroup': 'augmented',
     'orientation': 'v',
     'showlegend': True,
     'textposition': 'auto',
     'x': array([1, 2, 3, 4, 5, 6, 7]),
     'xaxis': 'x',
     'y': array([ 33, 350, 512, 425, 416, 285,  91]),
     'yaxis': 'y'
 }))

In [70]:
# Show the current state of the original-data trace. Other cells modify this object in
# place, so what is displayed depends on which cells ran before this one.
bar_base

Bar({
    'marker': {'color': '#6096B4'},
    'showlegend': False,
    'x': array([1, 2, 3, 4, 5, 6, 7]),
    'y': array([ 33, 350, 512, 425, 416, 285,  91])
})