In [141]:
import yaml
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random
import numpy as np
import copy
from collections import Counter

In [142]:
# Load the training configuration from the YAML file one directory up.
# The encoding is pinned to UTF-8 so the file is decoded the same way on every
# platform instead of with the locale's default codec.
with open('../config.yml', encoding='utf-8') as file:
    train_config = yaml.safe_load(file)

In [143]:
# Look up the dataset CSV location in the config (data_args -> csv_fi)
# and read it into a DataFrame.
data_args = train_config['data_args']
csv_path = data_args['csv_fi']
df = pd.read_csv(csv_path)

In [144]:
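# NOTE(review): commented-out draft of a per-class weighting; it appears to be
# superseded by resample() in the next cell — confirm before deleting.
# It reads `base`, which is only assigned in a later cell, so it would not run
# at this position on a fresh kernel.
# avg_samples_per_gp = base.sum()*2/len(base)
# weights = np.round(avg_samples_per_gp - base)
# weights = weights/weights.index
# resample_weights = [weights[cefr] for cefr in df.cefr_mean]
# resample_weights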

In [145]:
def resample(train_dataset: "pd.DataFrame",
             criterion: str = "rating",
             n_copy: float = 2,
             seed: "int | None" = None) -> "pd.DataFrame":
    """Draw a class-balancing sample (with replacement) from ``train_dataset``.

    Rows are grouped by the value in column ``criterion``. Every group gets a
    target size of ``len(train_dataset) * n_copy / n_groups``; the shortfall of
    a group relative to that target (``target - count``) is spread evenly over
    its rows as an integer sampling weight, ``round(100 * shortfall / count)``.
    Rows of under-represented groups are therefore drawn more often.

    The number of rows drawn is ``round(len(train_dataset) * (n_copy - 1))``,
    i.e. the sum of all shortfalls. With the default ``n_copy=2`` this equals
    ``len(train_dataset)``, so appending the result to the original data
    roughly doubles it while evening out the group sizes.

    Args:
        train_dataset: Data to sample from. It is not modified.
        criterion: Name of the column holding the group label
            (default ``"rating"``).
        n_copy: Size of original-plus-resampled data as a multiple of the
            original size (default 2).
        seed: If given, the draw uses a private ``random.Random(seed)`` and is
            reproducible without touching the global ``random`` state. If
            ``None`` (default), the module-level ``random`` generator is used.

    Returns:
        A new DataFrame containing the drawn rows. Rows keep their original
        index labels, so labels can repeat. An empty input yields an empty copy.

    Raises:
        AssertionError: If any group already has at least the target number of
            rows, in which case the weighting scheme is not applicable.
    """
    ratings = train_dataset[criterion].tolist()
    if not ratings:
        # Nothing to balance; also avoids dividing by a group count of zero.
        return train_dataset.copy()

    # Shortfall of each group relative to the balanced target size.
    group_counts = Counter(ratings)
    target_per_group = len(ratings) * n_copy / len(group_counts)
    shortfalls = {group: target_per_group - count
                  for group, count in group_counts.items()}
    assert all(n > 0 for n in shortfalls.values()), \
        "This calculation does not work. Might have to re-design."

    # Per-row weight: the group's shortfall shared among its rows (scaled by
    # 100 and rounded to an integer).
    weights = {group: round(100 * shortfall / group_counts[group])
               for group, shortfall in shortfalls.items()}
    resample_weights = [weights[r] for r in ratings]

    # Weighted draw of row positions, with replacement.
    rng = random.Random(seed) if seed is not None else random
    n_draws = round(len(ratings) * (n_copy - 1))
    positions = rng.choices(range(len(ratings)), weights=resample_weights, k=n_draws)

    # Indexing with a list of positions builds a new frame, so the input is
    # left untouched without copying it up front.
    return train_dataset.iloc[positions]

In [146]:
# resample() draws rows at random; seed the global generator first so that
# Restart & Run All reproduces the same over-sampled set (and the same figure).
random.seed(42)
resampled = resample(df, criterion='cefr_mean')
resampled

Unnamed: 0.1,Unnamed: 0,sample,student,task_id,transcript,recording_path,accuracy_mean,range_mean,fluency_mean,cefr_mean,...,pronunciation_mean,split,transcript_normalized,ASR_transcript,cefr_mean_original,pronunciation_mean_original,fluency_mean_original,accuracy_mean_original,range_mean_original,task_completion_mean_original
1434,1434,1075,120,27,yksi kahvi kiitos,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,1,4,1,...,3,0,yksi kahvi kiitos,,1.5,3.0,3.5,4.0,1.0,2.5
984,984,1436,176,15,geenimuunneltu ruoka ei ole minulle aiheena tu...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,7,...,4,3,geenimuunneltu ruoka ei ole minulle aiheena tu...,,7.0,4.0,4.0,4.0,3.0,3.0
970,970,1923,76,2,meidän yliopiston kirjastossa on paljon erilai...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,3,3,3,5,...,3,1,meidän yliopiston kirjastossa on paljon erilai...,,4.5,3.0,3.0,2.5,3.0,2.5
1446,1446,698,129,22,kuinka paljon kahvi maksaa,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,1,3,1,...,4,3,kuinka paljon kahvi maksaa,,1.5,3.5,3.5,4.0,1.5,3.0
1303,1303,521,121,22,mikä on kahvin maksaa,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,2,1,3,1,...,2,2,mikä on kahvin maksaa,,1.5,2.5,2.5,1.5,1.5,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
698,698,563,182,3,joo eli nämä kuvat nii tota ne herättää kaikki...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,6,...,4,2,joo eli nämä kuvat nii tota ne herättää kaikki...,,6.5,4.0,4.0,4.0,3.0,3.0
865,865,1243,176,19,heippa lotta<name> mä olin *huomannu* et sä ol...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,6,...,4,3,heippa lotta mä olin huomannu et sä olit tuonu...,,6.5,4.0,3.5,4.0,3.0,3.0
1969,1969,1374,262,23,moikka junassa,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,3,1,3,2,...,4,3,moikka junassa,,2.0,3.5,2.5,3.5,1.0,2.0
1218,1218,369,168,18,hei täällä maija<name> meikäläinen<name> olimm...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,3,4,...,4,2,hei täällä maija meikäläinen olimme eilen ystä...,,4.5,4.0,3.5,4.0,2.5,3.0


In [216]:
# Number of rows per CEFR class in the original data (Series indexed by class).
base = df.groupby('cefr_mean').size()

# "Same" baseline: the original class sizes twice, once labelled 'original' and
# once 'copy', in long format (columns: cefr_mean, variable, value) for px.bar.
# Built from the class sizes directly rather than from positional column
# slices, so it does not depend on the column order of the CSV.
base_same = pd.DataFrame({
    'cefr_mean': base.index,
    'original': base.values,
    'copy': base.values,
}).melt(id_vars=['cefr_mean'], value_vars=['original', 'copy'])

# Number of rows per CEFR class in the over-sampled draw.
base_os = resampled.groupby('cefr_mean').size()

In [222]:
# One colour per bar series; the 'original' and 'copy' keys match the
# `variable` values that px.bar colours by.
colour_map = {
    'original': '#6096B4',
    'copy': '#A7D2CB',
    'resampled': '#F2D388',
}

# Original class sizes (hidden from the legend).
bar_base = go.Bar(
    x=base.index,
    y=base.values,
    marker={'color': colour_map['original']},
    showlegend=False,
)
# Original + copy, one trace per `variable` value.
bar_same = px.bar(
    base_same,
    x='cefr_mean',
    y='value',
    color='variable',
    color_discrete_map=colour_map,
)
# Class sizes of the over-sampled draw.
bar_os = go.Bar(
    x=base_os.index,
    y=base_os.values,
    marker={'color': colour_map['resampled']},
    name='over-sampled',
)

# Three side-by-side panels sharing one y axis.
fig = make_subplots(
    rows=1,
    cols=3,
    shared_yaxes=True,
    subplot_titles=(
        r'$\textit{BASE}$',
        r'$\textit{BASE_same}$',
        r'$\textit{BASE_OS}$',
    ),
)

# Panel 1: original only. Panel 2: original + copy. Panel 3: original + over-sampled.
fig.add_trace(bar_base, row=1, col=1)
fig.add_traces(bar_same.data, rows=1, cols=2)
fig.add_traces([bar_base, bar_os], rows=1, cols=3)

# barmode='stack' puts the traces of a panel on top of each other.
fig.update_layout(
    barmode='stack',
    plot_bgcolor='whitesmoke',
    yaxis={'title': 'Number of samples'},
    width=800,
    height=400,
)
fig.update_xaxes(dtick=1, title='Class (CEFR score)')
fig.show()