In [1]:
import yaml, torchaudio, torch
import pandas as pd
import numpy as np
from IPython.display import Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import sys
import plotly.express as px
import plotly.graph_objects as go

sys.path.append("../")
from helper import DataArguments, ModelArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the experiment configuration. Decode explicitly as UTF-8 so the result
# does not depend on the platform's default locale encoding.
with open("../config.yml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [3]:
# Build the typed argument objects from the two config sections.
data_args, model_args = (
    DataArguments(**config["data_args"]),
    ModelArguments(**config["model_args"]),
)

In [4]:
# The processor (feature extractor + tokenizer) comes from `model_args.fi_pretrained`.
# The CTC weights come from the fine-tuned fold-0 checkpoint below.
CHECKPOINT_DIR = "../experiments/ex0_base/output_fold_0/checkpoint-5544/"

processor = Wav2Vec2Processor.from_pretrained(model_args.fi_pretrained)
model = Wav2Vec2ForCTC.from_pretrained(CHECKPOINT_DIR)
# Make every forward pass also return the per-layer attention maps.
model.config.output_attentions = True

In [66]:
# Restrict to rows with split == 0. This is presumably the held-out set for the
# fold-0 checkpoint loaded above; confirm against the training config.
df = pd.read_csv(data_args.csv_fi)
df = df[df.split == 0]

# Utterance to inspect. To pick a reproducible random one instead, use e.g.
# `np.random.default_rng(0).choice(df.index)`.
random_i = 520  # utterance used for the attention plot
# random_i = 176  # utterance used for the speech-example plot
random_path = df.loc[random_i, "recording_path"]
random_speech, sr = torchaudio.load(random_path)  # (channels, samples), sample rate
random_speech.size(1) / sr  # duration in seconds

3.7535

In [67]:
# Cut a fixed-length excerpt out of the loaded recording and play it.
# NOTE: this rebinds `random_speech`, so re-running the cell slices the
# already-trimmed tensor. Re-run the loading cell first.
SEGMENT_START = 41485  # first sample of the excerpt
SEGMENT_LEN = 16000    # excerpt length in samples (1 s when sr == 16000)
random_speech = random_speech[:, SEGMENT_START:SEGMENT_START + SEGMENT_LEN]
# Keep the player as the last expression so Jupyter actually renders it.
Audio(random_speech[0], rate=sr)

In [176]:
# Plot the excerpt's waveform with the model's predicted characters on a
# secondary (top) x-axis.
# NOTE: `x_tickvals` and `pred_ids` are produced by the CTC-decoding cells
# further down. Run those first.
n_samples = random_speech.size(1)
time_fractions = [0, 0.25, 0.5, 0.75, 1]

fig = go.Figure()

# Waveform, drawn against the top axis (x2), whose ticks carry the characters.
fig.add_trace(go.Scatter(y=random_speech.squeeze(), xaxis='x2', line=dict(color='black', width=1), showlegend=False))

# Invisible trace on the bottom axis (x). It spans the same sample range so
# that axis can carry the time ticks.
fig.add_trace(go.Scatter(
    x=list(range(n_samples)),
    y=[None] * n_samples,  # no visible points
    hoverinfo='none',
    showlegend=False,
))

fig.update_layout(
    xaxis=dict(
        # Bottom axis: sample positions labelled in seconds (derived from sr).
        tickvals=[f * n_samples for f in time_fractions],
        ticktext=[round(f * n_samples / sr, 3) for f in time_fractions],
        title='Time (second)',
    ),
    xaxis2=dict(
        # Top axis: decoded characters. Assumes exactly one character per
        # entry of x_tickvals.
        tickvals=x_tickvals,
        ticktext=list(processor.batch_decode(pred_ids)[0]),
        title='',
        overlaying='x',
        side='top',
        tickfont=dict(size=20)
    ),
    yaxis=dict(title="amplitude"),
    plot_bgcolor='white', width=900, height=400
)

fig.show()


In [170]:
# Tail of the (2-D) excerpt tensor, from sample 9723 onward.
random_speech[..., 9723:]

tensor([[-2.5024e-02, -1.0498e-02,  2.3682e-02,  ..., -9.1553e-05,
         -6.1035e-05, -2.1362e-04]])

In [35]:
# Convert the excerpt into model inputs (a batch of one) with the processor.
input_values = processor(
    audio=random_speech.squeeze(),
    sampling_rate=sr,
    return_tensors="pt",
)["input_values"]

In [14]:
# px.line(y=random_speech.squeeze())

In [15]:
# px.line(y=input_values.squeeze())

In [36]:
# Forward pass without autograd. Attentions are included in the output because
# `output_attentions` was enabled on the model config.
with torch.no_grad():
    outputs = model(input_values=input_values)

attentions = outputs.attentions  # tuple with one tensor per transformer layer

In [37]:
# Last layer, first batch item: (heads, frames, frames) per the Transformers docs.
attentions[-1][0].shape

torch.Size([16, 49, 49])

In [92]:
# Frame-level CTC labels and their sample positions, for the plot's top axis.
pred_ids = torch.argmax(outputs.logits, dim=-1)
# Wav2Vec2ForCTC uses the tokenizer's pad token as the CTC blank.
blank_id = processor.tokenizer.pad_token_id

# One label per frame: the decoded token, or "" for blank frames.
# Index the batch dimension explicitly so a 1-frame output still iterates.
frame_labels = [
    "" if token_id == blank_id else processor.decode(token_id)
    for token_id in pred_ids[0].tolist()
]

# Keep only the first frame of each run of identical labels, so every emitted
# character gets exactly one tick.
ticktext = []
prev = ""
for label in frame_labels:
    ticktext.append(label if label != prev else "")
    prev = label

# Map frame index -> sample position within the model input (not a fixed 16000).
n_frames = len(ticktext)
n_input_samples = input_values.size(-1)
x_tickvals = [
    int(i / n_frames * n_input_samples) + 1
    for i, label in enumerate(ticktext)
    if label
]
x_tickvals

[1633, 2286, 4572, 6531, 7184, 9470]

In [14]:
# Heat map of one attention head (first batch item) as a (frames, frames) matrix.
LAYER_IDX, HEAD_IDX = 2, 6
attention_matrix = attentions[LAYER_IDX][0][HEAD_IDX].squeeze().cpu().numpy()
fig = px.imshow(attention_matrix, text_auto=True, width=600, height=600)
fig.show()

In [22]:
# Take a 16000-sample window of the recording starting at `start_i` and play it.
# NOTE: the indices refer to the full recording. If `random_speech` has
# already been trimmed to the 16000-sample excerpt, this slice is empty.
# Reload the recording first.
start_i = 40700
kiitos = random_speech.squeeze()[start_i:start_i + 16000]
Audio(kiitos, rate=sr)

In [27]:
# Greedy (argmax) CTC prediction of the fine-tuned model on the `kiitos` window.
input_values = processor(audio=kiitos, sampling_rate=sr, return_tensors="pt")["input_values"]
with torch.no_grad():
    outputs = model(input_values=input_values)
pred_ids = outputs.logits.argmax(dim=-1)

In [64]:
# Per-frame probability over the vocabulary; bar plot for frame 10.
cond_prob = outputs.logits[0].softmax(dim=-1)
px.bar(y=cond_prob[10])

In [48]:
# Frame-level argmax token ids for the current prediction.
pred_ids

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0, 26, 26, 10,  0,  0,  0,  0,  0,  0, 10,
          0,  0,  0,  0,  0,  6,  8,  0,  0,  0,  0,  0,  0, 12, 12, 12,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [47]:
# Decode without merging repeated tokens (blanks are still dropped).
processor.tokenizer.batch_decode(pred_ids, group_tokens=False)

['kkiitosss']

In [12]:
processor.batch_decode(pred_ids)

['yksi kahvi kiitos']

In [13]:
# (batch, frames, vocab) per the Transformers CTC output docs.
outputs.logits.shape

torch.Size([1, 187, 35])

In [40]:
# Small, axis-free rendering of a short stretch of the waveform.
excerpt = random_speech[0][34500:38000]
fig = go.Figure(go.Scatter(y=excerpt, line=dict(color='black', width=0.9)))
fig.update_layout(plot_bgcolor='white', width=300, height=300)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.show()

In [6]:
from plotly.subplots import make_subplots

In [7]:
def softmax(sim_tar, sim_all, k=1):
    """Temperature-scaled softmax probability of the target similarity.

    Args:
        sim_tar: similarity of each query to its target, e.g. shape (n, 1).
        sim_all: similarity of each query to every candidate (target
            included), e.g. shape (n, m). Summed over axis 1.
        k: temperature; every similarity is divided by k before exponentiation.

    Returns:
        exp(sim_tar / k) / sum_j exp(sim_all[:, j] / k), broadcast to (n, 1).
    """
    numerator = np.exp(sim_tar / k)
    # The temperature must be applied inside the exponent in the denominator
    # too; otherwise the result is not a probability for k != 1.
    denominator = np.sum(np.exp(sim_all / k), axis=1, keepdims=True)
    return numerator / denominator


def contrastive(sim_tar, sim_all, k=1):
    """Contrastive loss: the negative log of `softmax` (same arguments)."""
    return -np.log(softmax(sim_tar, sim_all, k))

In [13]:
# Toy 2-D setup: a target vector, one distractor, and both stacked as candidates.
target = np.array([[1, 1]])          # (1, d)
distractor = np.array([[-1, -2]])    # (1, d)
Q = np.vstack((target, distractor))  # (2, d)

# Straight-line path from the target (alpha = 0) to the distractor (alpha = 1).
alpha = np.linspace(0, 1, 500)[np.newaxis, :]       # (1, n)
x = (1 - alpha).T @ target + alpha.T @ distractor   # (n, d)

In [14]:
# Cosine similarity of every interpolated point with the target, shape (n, 1),
# and with both candidates in Q, shape (n, 2). Norms are computed once and reused.
x_norm = np.linalg.norm(x, axis=1, keepdims=True)   # (n, 1)
target_norm = np.linalg.norm(target)                # scalar
Q_norm = np.linalg.norm(Q, axis=1, keepdims=True)   # (2, 1)

sim_tar = (x @ target.T) / (x_norm * target_norm)
sim_all = (x @ Q.T) / (x_norm @ Q_norm.T)

In [25]:
# Softmax output and contrastive loss along the interpolation path, first at
# the default temperature (k = 1) and then at k = SHARP_K.
# NOTE(review): dividing similarities by k > 1 usually flattens a softmax
# rather than sharpening it. Confirm the intended value/naming.
SHARP_K = 11

soft, y = (
    softmax(sim_tar, sim_all).squeeze(),
    contrastive(sim_tar, sim_all).squeeze(),
)
soft_sharp, y_sharp = (
    softmax(sim_tar, sim_all, SHARP_K).squeeze(),
    contrastive(sim_tar, sim_all, SHARP_K).squeeze(),
)

x_ticks = ["Target", "Distractor"]

In [26]:
def plot_softmax_and_loss(soft_values, loss_values, alpha):
    """Side-by-side plots of softmax output and contrastive loss versus alpha.

    Args:
        soft_values: softmax output at each interpolation step.
        loss_values: contrastive loss at each interpolation step.
        alpha: interpolation weights; squeezed to 1-D for the x values.

    Returns:
        The plotly Figure (not yet shown).
    """
    alpha_1d = alpha.squeeze()
    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.15)
    fig.add_trace(go.Scatter(y=soft_values, x=alpha_1d, showlegend=False), row=1, col=1)
    fig.add_trace(go.Scatter(y=loss_values, x=alpha_1d, showlegend=False), row=1, col=2)
    fig.update_xaxes(title=r"$\alpha$")
    fig.update_yaxes(title="Softmax output", row=1, col=1)
    fig.update_yaxes(title="Contrastive loss", row=1, col=2)
    fig.update_layout(plot_bgcolor="whitesmoke", height=500)
    return fig


# Default temperature (k = 1).
plot_softmax_and_loss(soft, y, alpha).show()

# Temperature used for the "sharp" variant.
plot_softmax_and_loss(soft_sharp, y_sharp, alpha).show()

In [7]:
# Rows whose transcript contains "moi" as a literal substring (so words such as
# "moikka" also match). `na=False` treats missing transcripts as non-matches
# instead of making the boolean mask fail.
# NOTE(review): this output shows transcripts and file paths. Check for
# personal data before sharing the notebook.
df[df.transcript.str.contains("moi", regex=False, na=False)]

Unnamed: 0.1,Unnamed: 0,sample,student,task_id,transcript,recording_path,accuracy_mean,range_mean,fluency_mean,cefr_mean,...,pronunciation_mean,split,transcript_normalized,ASR_transcript,cefr_mean_original,pronunciation_mean_original,fluency_mean_original,accuracy_mean_original,range_mean_original,task_completion_mean_original
18,18,177,22,1,tämä paikka on minulle tärkeä koska siellä tun...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,2,3,2,5,...,4,0,tämä paikka on minulle tärkeä koska siellä tun...,,5.0,3.5,2.5,2.5,2.5,3.0
72,72,1105,8,18,<bgnoise> moi tääl on <bgnoise> on onni<name> ...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,3,2,3,4,...,4,0,moi tääl on on onni lehtonen mä olin eilen ill...,,4.0,4.0,2.5,3.0,2.0,3.0
77,77,1785,52,19,<garbage> moi matti<name> tässä öö<hesitation>...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,3,2,3,4,...,3,0,moi matti tässä öö kiitos hei että tota öö sai...,,4.0,3.0,2.5,3.0,2.0,3.0
87,87,1281,29,1,paikka on minulle tärkeä koska olen ollu siell...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,4,3,4,7,...,4,0,paikka on minulle tärkeä koska olen ollu siell...,,6.5,4.0,4.0,4.0,3.0,2.5
176,176,497,40,14,moi minä olen anna<name> sjökvist<name> öö<hes...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,2,1,4,3,...,3,0,moi minä olen anna sjökvist öö minulle kuuluu ...,,3.0,3.5,4.0,2.5,1.5,2.5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1944,1944,279,106,21,moikka minun nimi on steven<name> minun *paiva...,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,1,2,2,2,...,1,0,moikka minun nimi on steven minun paiva amulla...,,2.0,1.5,2.0,1.5,1.5,2.0
1975,1975,175,69,23,moi minä olen *kahvilasa*,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,3,2,2,3,...,3,0,moi minä olen kahvilasa,,3.0,3.0,2.5,3.0,1.5,2.5
1991,1991,83,261,29,moi anna<name> ihan hyvää entä sulle,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,3,2,3,2,...,4,0,moi anna ihan hyvää entä sulle,,2.5,3.5,3.5,3.5,2.0,3.0
2055,2055,1805,164,14,moi maria<name> olen hyvää kiitos,/m/teamwork/t40511_asr/c/digitala/DigiTala_201...,2,1,2,2,...,3,0,moi maria olen hyvää kiitos,,2.0,3.0,2.0,1.5,1.5,2.0


In [7]:
# Checkpoint named by `model_args.fi_pretrained`, presumably the model before
# fine-tuning, loaded for comparison with `model`.
pre_trained_model = Wav2Vec2ForCTC.from_pretrained(
    pretrained_model_name_or_path=model_args.fi_pretrained
)

In [8]:
# Load recording 943 of `df` and play it. Note: this rebinds `sr`.
test_path = df.loc[943, "recording_path"]
test_speech, sr = torchaudio.load(test_path)
Audio(test_speech[0], rate=sr)

In [9]:
# Run the checkpoint from `model_args.fi_pretrained` on the test recording.
input_values = processor(audio=test_speech.squeeze(), sampling_rate=sr, return_tensors="pt")["input_values"]
with torch.no_grad():
    outputs = pre_trained_model(input_values=input_values)

In [10]:
pred_ids = outputs.logits.argmax(dim=-1)  # greedy per-frame token ids

In [17]:
# Find the frame whose distribution is most spread out. Count, per frame, the
# tokens with probability > 0.01 and take the frame with the largest count.
cond_prob = outputs.logits[0].softmax(dim=-1)
mask = (cond_prob > 0.01).float().sum(dim=-1)
i = mask.argmax()

# Label every bar with its token string, placed at its vocabulary id.
vocab = processor.tokenizer.get_vocab()
ticktext = list(vocab.keys())
tickvals = list(vocab.values())

fig = go.Figure(go.Bar(y=cond_prob[i], marker=dict(color="#337CCF")))
fig.update_layout(plot_bgcolor="#F5F5F5", width=600)
fig.update_xaxes(ticktext=ticktext, tickvals=tickvals, title="Tokens")
fig.update_yaxes(title=r"$P(l'|t, X)$")
fig.show()

In [24]:
# TODO: unfinished cell. A bare `speech_sample = ` is a SyntaxError that stops
# "Run All". Fill in the intended value before re-enabling:
# speech_sample = ...

22.627416997969522

torch.Size([323, 35])