## Import Statements

In [5]:
# Install bertviz and matplotlib
# !pip install bertviz matplotlib
import torch
import soundfile as sf
import numpy as np
import pandas as pd
import tqdm
import matplotlib.pyplot as plt
%matplotlib inline 
from bertviz import head_view, model_view

import torch.nn.functional as F
from fairseq.models.wav2vec import Wav2Vec2Model, Wav2Vec2Config

## Loading pretraining checkpoint

In [6]:
# Deserialize the checkpoint into host memory regardless of the device it was
# saved from; the weights are copied into the model below and the model is then
# moved to the GPU in a single step.
ckpt = torch.load("data/models/indicwav2vec-large.pt", map_location="cpu")

## Base model
# conf = Wav2Vec2Config(quantize_targets = True,
#   final_dim= 256,
#   encoder_layerdrop= 0.05,
#   dropout_input= 0.1,
#   dropout_features= 0.1,
#   feature_grad_mult= 0.1,
#   encoder_embed_dim= 768)

## Large model
# 24 encoder layers, 16 attention heads, 1024-dim embeddings. Every dropout /
# layerdrop value is set to zero in this configuration.
conf = Wav2Vec2Config(
    quantize_targets=True,
    extractor_mode='layer_norm',
    layer_norm_first=True,
    final_dim=768,
    encoder_embed_dim=1024,
    latent_temp=[2.0, 0.1, 0.999995],
    dropout=0.0,
    attention_dropout=0.0,
    conv_bias=True,
    encoder_layerdrop=0.00,
    dropout_input=0.0,
    dropout_features=0.0,
    encoder_layers=24,
    encoder_ffn_embed_dim=4096,
    encoder_attention_heads=16,
    feature_grad_mult=1.0,
)

model = Wav2Vec2Model.build_model(conf, task=None)
# The pretrained weights are stored under the 'model' key of the checkpoint.
model.load_state_dict(ckpt['model'])
model.to(device='cuda')
# Trailing semicolon keeps the notebook from echoing the module repr.
model.eval();

## Attach forward hook to capture attn weights

Note: The `MultiheadAttention` module and torch's `multi_head_attention_forward` function were modified so that the attention weights are returned in the format required by the bertviz package.

In [7]:
# Most recent forward output of each hooked module, keyed by module name.
activation = {}


def get_activation(name):
    """Return a forward hook that stores the hooked module's output.

    The hook writes the output into the module-level `activation` dict under
    the key `name`, overwriting whatever an earlier forward pass stored there.
    """
    def hook(module, inputs, output):
        # Only the output is kept; the module and its inputs are ignored.
        activation[name] = output

    return hook
    
from fairseq.models.wav2vec import TransformerSentenceEncoderLayer

# Attach a recording hook to every transformer encoder layer so that each
# forward pass leaves that layer's output in `activation` under the layer name.
# NOTE: the handles returned by register_forward_hook are not kept, so
# re-running this cell adds another hook to each layer.
for name, module in model.named_modules():
    if not isinstance(module, TransformerSentenceEncoderLayer):
        continue
    module.register_forward_hook(get_activation(name))

## Module to process audio file

In [8]:
def postprocess(feats, curr_sample_rate, normalize=True):
    """Turn a loaded waveform tensor into a 1-D (optionally normalized) signal.

    Params:
        feats: 1-D tensor, or 2-D tensor whose last axis is averaged away.
        curr_sample_rate: accepted for interface compatibility; not used here.
        normalize: if True, apply layer normalization over the whole signal.

    Returns:
        A 1-D tensor. An input that is not 1-D after the optional averaging
        fails the assertion.
    """
    signal = feats.mean(-1) if feats.dim() == 2 else feats

    assert signal.dim() == 1, signal.dim()

    if not normalize:
        return signal

    # Normalization is done without tracking gradients.
    with torch.no_grad():
        return F.layer_norm(signal, signal.shape)

## Load dataset with CTC level tokenization (pre-constructed)

In [25]:
# Language code of the examples; supported langs: hi, gu, ta, te, be, ne, kn and en.
# The path read below does not depend on it.
lang = 'hi'

# Tab-separated file without a header row; one example per line.
data = pd.read_csv(
    'examples/example_outputs.txt',
    sep='\t',
    header=None,
    names=["idx", "path", "out", "seq"],
)


## Main Function which returns weights (and token sequence) 

In [10]:
def get_aggregate(tensor, dim, aggregation='max'):
    """Reduce `tensor` along `dim`, keeping that axis with length 1.

    Params:
        tensor: tensor to reduce.
        dim: axis to reduce over.
        aggregation: reduction to apply, one of 'max', 'sum' or 'mean'.

    Returns:
        A tensor with the same number of dimensions as `tensor`, where axis
        `dim` has length 1.

    Raises:
        ValueError: if `aggregation` is not a supported name. Without this the
            function would fall through and return None, which only surfaces
            later as an obscure failure when the results are concatenated.
    """
    if aggregation == 'max':
        # torch.max over a dim returns (values, indices); only values are used.
        return torch.max(tensor, dim=dim, keepdim=True).values
    if aggregation == 'sum':
        return torch.sum(tensor, dim=dim, keepdim=True)
    if aggregation == 'mean':
        return torch.mean(tensor, dim=dim, keepdim=True)
    raise ValueError(
        f"unsupported aggregation {aggregation!r}; expected 'max', 'sum' or 'mean'"
    )

In [11]:
def _collapse_word_spans(attn, df, axis, aggregation, remove_bcl):
    """Collapse `axis` of `attn` into one slot per span delimited by 'bcl' rows of `df`."""

    def reduce_span(start, stop=None):
        # Slice [start:stop) along `axis` only, then reduce it to length 1.
        index = [slice(None)] * attn.dim()
        index[axis] = slice(start, stop)
        return get_aggregate(attn[tuple(index)], axis, aggregation)

    pieces = []
    prev = -1  # position where the pending word starts; -1 means nothing pending
    last = len(df) - 1
    for i in range(len(df)):
        if df['char'][i] != 'bcl':
            continue
        new = int(df['cid'][i])
        nex = int(df['cid'][i + 1]) if i < last else -1
        if prev != -1:
            # The word that ended at this 'bcl'.
            pieces.append(reduce_span(prev, new))
        if not remove_bcl:
            # The 'bcl' itself: up to the next entry, or to the end for the last row.
            pieces.append(reduce_span(new, nex) if nex != -1 else reduce_span(new))
        prev = nex
    return torch.cat(pieces, dim=axis)


def get_viz_params_word(rowno, aggregation='max', remove_bcl=False):
    """Compute word-level attention weights and tokens for one dataset row.

    Runs the audio file of the row through `model`, stacks the second element
    of every hooked layer output recorded in `activation`, and collapses axes
    3 and 4 of that stack into one slot per word (plus one per 'bcl' entry
    unless `remove_bcl` is set). Requires a CUDA device.

    Params:
        rowno: row index of the loaded pandas dataset (`data`).
        aggregation: type of aggregation used to club multiple attention
                     weights, supported types: 'max', 'mean', 'sum'.
        remove_bcl: if True, removes all instances of bcls from the
                    visualization.

    Returns:
        (final, tokens): `final` is the collapsed tensor after a float32
        softmax over its last axis; `tokens` is the space-split `out` field of
        the row, with 'bcl' placed before, between and after the words unless
        `remove_bcl` is True.
    """
    row = data.iloc[rowno]
    wav, sr = sf.read(row['path'])
    feats = torch.from_numpy(wav).float()
    out = postprocess(feats, sr).unsqueeze(0)
    out = out.to(device='cuda')

    with torch.no_grad():
        # The forward pass is run only for its side effect on `activation`.
        model(out)
        acts = torch.stack([ac[1] for ac in activation.values()])

        # NOTE(review): eval() executes whatever text is stored in the 'seq'
        # column; only use dataset files from a trusted source.
        res = eval(row['seq'])
        # Force a '|' entry at the start and at the end, and drop 'pad' entries.
        lst = [[res[0][0], '|']] + [en for en in res[1:-1] if en[1] != 'pad']
        if lst[-1][1] != '|':
            lst += [[res[-1][0], '|']]
        df = pd.DataFrame(lst, columns=['cid', 'char'])
        # Assign the result back: an in-place replace on the selected column
        # does not update `df` when pandas Copy-on-Write is active.
        df['char'] = df['char'].replace({'|': 'bcl'})

        # Collapse axis 3 first, then axis 4; both passes use the original
        # `cid` positions because each pass leaves the other axis untouched.
        rows_collapsed = _collapse_word_spans(acts, df, 3, aggregation, remove_bcl)
        final = _collapse_word_spans(rows_collapsed, df, 4, aggregation, remove_bcl)

    tokens = row['out'].split(' ')
    if not remove_bcl:
        new_st = (' bcl '.join(tokens)).strip()
        tokens = ['bcl'] + new_st.split(' ') + ['bcl']
    final = F.softmax(final, dim=-1, dtype=torch.float32)

    return final, tokens

In [20]:
def _collapse_char_spans(attn, df, axis, aggregation, remove_bcl):
    """Collapse `axis` of `attn` into one slot per row of `df` (optionally without 'bcl' rows)."""

    def reduce_span(start, stop=None):
        # Slice [start:stop) along `axis` only, then reduce it to length 1.
        index = [slice(None)] * attn.dim()
        index[axis] = slice(start, stop)
        return get_aggregate(attn[tuple(index)], axis, aggregation)

    pieces = []
    prev = -1        # start position of the pending span; -1 means nothing pending
    skipped = False  # True right after a 'bcl' row
    last = len(df) - 1
    for i in range(len(df)):
        new = int(df['cid'][i])
        nex = int(df['cid'][i + 1]) if i < last else -1

        if df['char'][i] == 'bcl':
            if prev != -1:
                # The character that ended at this 'bcl'.
                pieces.append(reduce_span(prev, new))
            if not remove_bcl:
                # The 'bcl' itself: up to the next entry, or to the end for the last row.
                pieces.append(reduce_span(new, nex) if nex != -1 else reduce_span(new))
            prev = nex
            skipped = True
        elif skipped:
            # First character after a 'bcl': `prev` already holds its start,
            # so nothing is emitted until the following row closes the span.
            skipped = False
        else:
            if prev != -1:
                pieces.append(reduce_span(prev, new))
            prev = new
    return torch.cat(pieces, dim=axis)


def get_viz_params_char(rowno, aggregation='max', remove_bcl=False):
    """Compute character-level attention weights and tokens for one dataset row.

    Runs the audio file of the row through `model`, stacks the second element
    of every hooked layer output recorded in `activation`, and collapses axes
    3 and 4 of that stack into one slot per character entry of the row's `seq`
    field (plus one per 'bcl' entry unless `remove_bcl` is set). Requires a
    CUDA device.

    Params:
        rowno: row index of the loaded pandas dataset (`data`).
        aggregation: type of aggregation used to club multiple attention
                     weights, supported types: 'max', 'mean', 'sum'.
        remove_bcl: if True, removes all instances of bcls from the
                    visualization.

    Returns:
        (final, tokens): `final` is the collapsed tensor after a float32
        softmax over its last axis; `tokens` is the list of character labels
        taken from the parsed `seq` field, without the 'bcl' labels when
        `remove_bcl` is True.
    """
    row = data.iloc[rowno]
    wav, sr = sf.read(row['path'])
    feats = torch.from_numpy(wav).float()
    out = postprocess(feats, sr).unsqueeze(0)
    out = out.to(device='cuda')

    with torch.no_grad():
        # The forward pass is run only for its side effect on `activation`.
        model(out)
        acts = torch.stack([ac[1] for ac in activation.values()])

        # NOTE(review): eval() executes whatever text is stored in the 'seq'
        # column; only use dataset files from a trusted source.
        res = eval(row['seq'])
        # Force a '|' entry at the start and at the end, and drop 'pad' entries.
        lst = [[res[0][0], '|']] + [en for en in res[1:-1] if en[1] != 'pad']
        if lst[-1][1] != '|':
            lst += [[res[-1][0], '|']]
        df = pd.DataFrame(lst, columns=['cid', 'char'])
        # Assign the result back: an in-place replace on the selected column
        # does not update `df` when pandas Copy-on-Write is active.
        df['char'] = df['char'].replace({'|': 'bcl'})

        # Collapse axis 3 first, then axis 4; both passes use the original
        # `cid` positions because each pass leaves the other axis untouched.
        rows_collapsed = _collapse_char_spans(acts, df, 3, aggregation, remove_bcl)
        final = _collapse_char_spans(rows_collapsed, df, 4, aggregation, remove_bcl)

    final = F.softmax(final, dim=-1, dtype=torch.float32)
    if remove_bcl:
        tokens = df[df['char'] != 'bcl']['char'].tolist()
    else:
        tokens = df['char'].tolist()

    return final, tokens
    

In [13]:
def get_viz_params_token(rowno, num_words=1):
    """Compute unaggregated attention weights and per-position labels for one row.

    Runs the audio file of the row through `model` and stacks the second
    element of every hooked layer output recorded in `activation`. No
    aggregation is applied; the result is cropped to the positions covered by
    the first `num_words` words. Requires a CUDA device.

    Params:
        rowno: row index of the loaded pandas dataset (`data`).
        num_words: number of words to keep; scanning stops at the
                   `num_words`-th 'bcl' entry found after the first row. If
                   fewer are present, everything up to the last entry is kept.

    Returns:
        (final, tok_seq): `final` is the float32 softmax of the stacked tensor
        over its last axis, cropped afterwards to the first `new` positions of
        axes 3 and 4 (so cropped rows need not sum to 1); `tok_seq` repeats
        each entry's label once per position up to the next entry.
    """
    row = data.iloc[rowno]
    wav, sr = sf.read(row['path'])
    feats = torch.from_numpy(wav).float()
    out = postprocess(feats, sr).unsqueeze(0)
    out = out.to(device='cuda')

    with torch.no_grad():
        # The forward pass is run only for its side effect on `activation`.
        model(out)
        acts = torch.stack([ac[1] for ac in activation.values()])

        # NOTE(review): eval() executes whatever text is stored in the 'seq'
        # column; only use dataset files from a trusted source.
        res = eval(row['seq'])
        # Force a '|' entry at the start and at the end, and drop 'pad' entries.
        lst = [[res[0][0], '|']] + [en for en in res[1:-1] if en[1] != 'pad']
        if lst[-1][1] != '|':
            lst += [[res[-1][0], '|']]
        df = pd.DataFrame(lst, columns=['cid', 'char'])
        # Assign the result back: an in-place replace on the selected column
        # does not update `df` when pandas Copy-on-Write is active.
        df['char'] = df['char'].replace({'|': 'bcl'})

        tok_seq = []
        prev = 0
        # Initialised so the crop below is defined even when `df` has a single
        # row and the loop body never runs (the result is then empty).
        new = 0
        for i in range(1, len(df)):
            new = int(df['cid'][i])
            # Positions prev..new-1 carry the label of the previous entry.
            tok_seq.extend([df['char'][i - 1]] * (new - prev))
            if df['char'][i] == 'bcl':
                num_words -= 1
                if num_words == 0:
                    break
            prev = new

    final = F.softmax(acts, dim=-1, dtype=torch.float32)
    final = final[:, :, :, :new, :new]
    # final = acts[:,:,:,:new,:new]

    return final, tok_seq

In [21]:
# Settings that identify which visualization to show; they are interpolated
# into the cached .npy file names.
lang = 'hi'
rowno = 2
toktype = 'char'  # use 'char' for best visualization
aggregation = 'max'
remove_bcl = False

In [27]:
# For generating tokens for new speech segments
# if toktype=='speech':
#     final, tokens = get_viz_params_token(rowno, num_words=2)
# elif toktype=='char':
#     final, tokens = get_viz_params_char(rowno, aggregation, remove_bcl)
# elif toktype=='word':
#     final, tokens = get_viz_params_word(rowno, aggregation, remove_bcl)

## Load generated outputs and tokens
# Cached attention weights and tokens for the settings chosen above.
final = torch.tensor(
    np.load(f'examples/viz/att_pre_{lang}_{rowno}_{toktype}_{aggregation}_{remove_bcl}.npy'),
    device='cuda',
)
tokens = np.load(f'examples/viz/tokens_{lang}_{rowno}_{toktype}_{aggregation}_{remove_bcl}.npy')

# model_view is given a list with one tensor per index of the leading axis.
list_att = list(final.unbind(0))

# Uncomment to (re)write the cache from freshly computed `final` / `tokens`.
# np.save(f'examples/viz/att_pre_{lang}_{rowno}_{toktype}_{aggregation}_{remove_bcl}.npy', final.cpu().numpy())
# np.save(f'examples/viz/tokens_{lang}_{rowno}_{toktype}_{aggregation}_{remove_bcl}.npy', np.array(tokens))

# Show four consecutive layers starting at index 12, head 14 only.
include_layers = list(range(12, 16))
model_view(list_att, tokens, include_layers=include_layers, include_heads=[14], display_mode="light")

<IPython.core.display.Javascript object>