# Data Preparation

In [None]:
import json
import glob
import os

In [None]:
# Root of the Harper Valley checkout. Set HARPER_VALLEY_ROOT to point elsewhere
# instead of editing the notebook; the default is the original machine-specific path.
root_dir = os.environ.get(
    "HARPER_VALLEY_ROOT", "/home/wonkyum/fc-asr/gridspace-stanford-harper-valley"
)
# glob returns paths in arbitrary filesystem order; sort them so the collected
# examples (and everything derived from them) come out in a reproducible order.
transcript_files = sorted(glob.glob(os.path.join(root_dir, "data/transcript", "*.json")))

In [None]:
# Build one example per utterance: the text is "<speaker_role>: <human_transcript>"
# and the target is that utterance's list of dialog acts.
all_data = []
_REQUIRED_KEYS = ("speaker_role", "human_transcript", "dialog_acts")
for jsonfile in transcript_files:
    # Context manager so each transcript file is closed as soon as it is parsed.
    with open(jsonfile, "r", encoding="utf-8") as f:
        jsondata = json.load(f)
    for datum in jsondata:
        # Require every key that is read below, so an incomplete record is
        # skipped instead of raising KeyError.
        if not all(key in datum for key in _REQUIRED_KEYS):
            continue
        text = datum["human_transcript"]
        # Drop short utterances (<= 20 characters), measured on the text we keep.
        if len(text) <= 20:
            continue
        all_data.append({
            "transcript": datum["speaker_role"] + ": " + text,
            "dialog_acts": datum["dialog_acts"],
        })

In [None]:
# Inspect the first example: {"transcript": "<speaker_role>: <text>", "dialog_acts": [...]}.
all_data[0]

In [None]:
# Flatten every example's dialog-act list into one list (duplicates kept),
# preserving example order and within-example order.
all_labels = [act for example in all_data for act in example["dialog_acts"]]

In [None]:
# Deduplicated label vocabulary in sorted order; this order fixes the column
# layout of the multi-hot targets and the classifier outputs.
unique_labels = sorted(set(all_labels))

In [None]:
# Show the sorted dialog-act label vocabulary.
unique_labels

In [None]:
# Number of distinct dialog-act labels; also used as the classifier's num_labels.
len(unique_labels)

# Dataset

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertModel, BertTokenizer

import torch

In [None]:
# Multi-hot encoder with the column order pinned to unique_labels, so column i
# of every encoded row corresponds to unique_labels[i].
mlb = MultiLabelBinarizer(classes=unique_labels)

In [None]:
# Sanity check: encode a single two-label example and eyeball the resulting multi-hot row.
mlb.fit_transform([["gridspace_bear_with_me", "gridspace_acknowledgement"]])

In [None]:
# Split the examples into parallel model inputs (speaker-tagged text) and
# targets (list of dialog-act names), index-aligned with all_data.
texts = [example["transcript"] for example in all_data]
labels = [example["dialog_acts"] for example in all_data]

In [None]:
# Peek at the raw label lists of the first three examples.
labels[0], labels[1], labels[2]

In [None]:
# Multi-hot encode every example's dialog acts; columns follow unique_labels order.
binary_labels=mlb.fit_transform(labels)

In [None]:
# Multi-hot rows for the first three examples (compare with the raw lists above).
binary_labels[0], binary_labels[1], binary_labels[2]

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenizer options shared by every call: pad to the longest sequence in the
# call, truncate to 128 tokens, and return PyTorch tensors.
_ENCODE_KWARGS = {"padding": True, "truncation": True, "max_length": 128, "return_tensors": "pt"}


def encode_texts(texts):
    """Tokenize a string or a list of strings for BERT.

    Returns the tokenizer's output mapping, from which callers read the
    'input_ids' and 'attention_mask' tensors.
    """
    return tokenizer(texts, **_ENCODE_KWARGS)


In [None]:
# Tokenize the whole corpus in one call (padded to the longest utterance, truncated at 128 tokens).
encoded_inputs = encode_texts(texts)
# Multi-hot label matrix as a tensor for TensorDataset; the training loop casts it to float for BCELoss.
binary_labels = torch.tensor(binary_labels)

In [None]:
# One (input_ids, attention_mask, multi-hot labels) triple per utterance.
dataset = TensorDataset(encoded_inputs['input_ids'], encoded_inputs['attention_mask'], binary_labels)
# Shuffle every epoch: examples were collected file by file (conversation by
# conversation), so unshuffled mini-batches would be strongly correlated.
# The seeded generator keeps the shuffle order reproducible between runs.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

# Modeling

In [None]:
import torch.nn as nn
from transformers import BertModel, BertTokenizer


class BertForMultiLabelClassification(nn.Module):
    """Pre-trained 'bert-base-uncased' encoder with a linear multi-label head.

    ``forward`` returns per-label probabilities (sigmoid applied to the
    logits), so it is meant to be trained with ``nn.BCELoss``, not
    ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, num_labels):
        super().__init__()
        # Submodules registered in the same order as before: encoder, then head.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        scores = self.classifier(encoded.pooler_output)
        # Independent sigmoid per label: each dialog act is a separate yes/no decision.
        return torch.sigmoid(scores)


# Train the model

In [None]:
def count_parameters(model):
    """Return the total number of scalar parameters in *model* that require gradients."""
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue  # frozen weights are not counted
        total += tensor.numel()
    return total


In [None]:
import torch.optim as optim

model = BertForMultiLabelClassification(num_labels=len(unique_labels))

# Only the linear classification head is trained: freeze every BERT weight.
for bert_weight in model.bert.parameters():
    bert_weight.requires_grad_(False)

# Relocate the partially frozen model to the GPU.
model = model.to("cuda")

trainable_count = count_parameters(model)
print("Number of parameters: %s" % trainable_count)



In [None]:
# Display the module tree: BERT encoder plus the linear classifier head.
model

In [None]:
# Hand the optimizer only the parameters that are still trainable (the
# classifier head); the frozen BERT weights never receive gradients anyway.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=5e-5
)
# The model already applies a sigmoid, so plain BCE on probabilities is used.
loss_function = nn.BCELoss()

num_epochs = 10
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        optimizer.zero_grad()
        # `batch_labels` rather than `labels`, so the notebook-level `labels`
        # list built in the Dataset section is not overwritten by a tensor.
        input_ids, attention_mask, batch_labels = [b.to("cuda") for b in batch]
        outputs = model(input_ids, attention_mask)
        loss = loss_function(outputs, batch_labels.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the epoch, not just the last (noisy) batch;
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}, Loss: {mean_loss}")

# ASR + Dialog Acts

In [None]:
import pathlib
from argparse import ArgumentParser

import sentencepiece as spm

import torch
import torchaudio
from lightning import ConformerRNNTModule
from transforms import get_data_module
import json
from torchaudio.models import Hypothesis, RNNTBeamSearch
from typing import List, Tuple
import math
from IPython.display import Audio


In [None]:
# SentencePiece tokenizer for the ASR model; post_process_hypos uses it to turn token ids back into text.
sp_model = spm.SentencePieceProcessor(model_file='/home/wonkyum/fc-asr/spm_unigram_1023.model')

In [None]:
# Machine-specific path to the Conformer RNN-T Lightning checkpoint.
checkpoint_path = '/home/wonkyum/fc-asr/exp/checkpoints/epoch=21-step=1451337.ckpt'

In [None]:
# Restore the ASR module from the checkpoint (sp_model is forwarded as a keyword
# argument to load_from_checkpoint) and switch it to eval mode for inference.
rnnt_module = ConformerRNNTModule.load_from_checkpoint(checkpoint_path, sp_model=sp_model).eval()

In [None]:
# Move the RNN-T network to the GPU (nn.Module.to moves parameters in place, so
# the return value does not need to be kept).
rnnt_module.model.to("cuda")
# Beam-search decoder over the RNN-T network; 1023 is passed as the blank token id.
# NOTE(review): presumably the blank sits just past a 1023-piece SentencePiece
# vocab (as the spm_unigram_1023 file name suggests) — confirm against training config.
decoder = RNNTBeamSearch(rnnt_module.model, 1023)

In [None]:
def post_process_hypos(
    hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, List[float], List[int]]]:
    """Convert RNN-T beam-search hypotheses into ``(text, [exp(score)], token_ids)``.

    Args:
        hypos: hypotheses from ``RNNTBeamSearch``. Element 0 of each hypothesis
            is read as its token-id list and element 3 as its score.
        sp_model: SentencePiece model used to detokenize the ids.

    Returns:
        One tuple per hypothesis, in input order. The score is exponentiated
        and wrapped in a one-element list. ``token_ids`` still contains
        unk/eos/pad ids; those are removed only before building ``text``.
    """
    tokens_idx = 0
    score_idx = 3
    # Special ids that must not leak into the decoded text (set: O(1) lookups).
    ids_to_remove = {sp_model.unk_id(), sp_model.eos_id(), sp_model.pad_id()}
    # [1:] skips each hypothesis' first token.
    # NOTE(review): presumably the blank/start token that seeds the beam search — confirm.
    filtered_hypo_tokens = [
        [token_index for token_index in h[tokens_idx][1:] if token_index not in ids_to_remove]
        for h in hypos
    ]
    hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
    hypos_ids = [h[tokens_idx][1:] for h in hypos]
    # Scores are exponentiated, i.e. treated as log-domain values.
    hypos_score = [[math.exp(h[score_idx])] for h in hypos]

    return list(zip(hypos_str, hypos_score, hypos_ids))

In [None]:
def _piecewise_linear_log(x):
    """Compress gain-scaled spectrogram values with a piecewise log/linear map.

    ``x`` is first multiplied by the module-level ``_gain``. That constant is
    assigned in a later cell, which must have run before this function is
    called. Entries above e are then replaced by their natural log, and
    entries at or below e are divided by e.

    NOTE(review): the two masked assignments run one after the other, so the
    second mask is evaluated on the already-updated tensor. Values originally
    in (e, e**e] become log(x) <= e and are then divided by e as well. This
    looks unintended, but it presumably has to match the front end the ASR
    checkpoint was trained with — check the training pipeline before changing it.

    Returns a new tensor: ``x * _gain`` allocates, so the caller's tensor is
    not modified in place.
    """
    x = x * _gain
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


class FunctionalModule(torch.nn.Module):
    """Adapt a plain callable so it can be used as a stage in ``torch.nn.Sequential``."""

    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        transform = self.functional
        return transform(input)

class GlobalStatsNormalization(torch.nn.Module):
    """Normalize features with precomputed statistics: ``(input - mean) * invstddev``.

    The statistics come from a JSON file with ``"mean"`` and ``"invstddev"``
    entries, each converted to a tensor once at construction time.
    """

    def __init__(self, global_stats_path):
        super().__init__()

        with open(global_stats_path) as stats_file:
            stats = json.load(stats_file)

        self.mean = torch.tensor(stats["mean"])
        self.invstddev = torch.tensor(stats["invstddev"])

    def forward(self, input):
        centered = input - self.mean
        return centered * self.invstddev

In [None]:
# Front-end constants for the ASR features.
# _decibel = 40 * log10(32767), so _gain = 10 ** (0.05 * _decibel) = 32767 ** 2
# (int16 max squared, ~1.07e9); _piecewise_linear_log multiplies features by it.
# NOTE(review): presumably rescales float-waveform power to the int16 power range — confirm.
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
# 80-bin mel spectrogram for 16 kHz audio: 400-sample FFT (25 ms) with a 160-sample hop (10 ms).
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)



def run_decoder(waveform, sample_rate=None, beam_width=20):
    """Transcribe one recording with the Conformer RNN-T beam-search decoder.

    Args:
        waveform: audio tensor as returned by ``torchaudio.load``; only
            ``waveform[0]`` (the first channel) is decoded.
        sample_rate: sample rate of ``waveform``. When given and different
            from the front end's rate (``_spectrogram_transform.sample_rate``),
            the audio is resampled first. ``None`` (default) assumes the audio
            already matches, which is the previous behavior.
        beam_width: beam size for the RNN-T beam search (previously fixed at 20).

    Returns:
        The decoded text of the first hypothesis returned by the decoder.
    """
    # Pure inference: skip autograd bookkeeping during feature extraction and decoding.
    with torch.inference_mode():
        audio = waveform[0].squeeze()
        target_rate = _spectrogram_transform.sample_rate
        if sample_rate is not None and sample_rate != target_rate:
            audio = torchaudio.functional.resample(audio, sample_rate, target_rate)

        extra_pipeline = torch.nn.Sequential(
            FunctionalModule(_piecewise_linear_log),
            GlobalStatsNormalization('./global_stats.json'),
        )
        # (n_mels, frames) -> (frames, n_mels). No padding is needed for a
        # single utterance; the previous pad_sequence call returned the same values.
        mel_f = _spectrogram_transform(audio).transpose(1, 0)
        feats = extra_pipeline(mel_f)
        lengths = torch.tensor(feats.shape[0])
        hypotheses = decoder(feats.to("cuda"), lengths.to("cuda"), beam_width)
        result = post_process_hypos(hypotheses, sp_model)
    return result[0][0]






In [None]:
# Load a test recording (torchaudio.load returns waveform and sample rate) and play it inline.
# NOTE(review): the ASR front end is configured for 16 kHz audio — confirm help.wav is 16 kHz.
my_wave_form, samplerate=torchaudio.load('/home/wonkyum/help.wav')
Audio(my_wave_form.numpy(), rate=samplerate)

In [None]:
# Transcribe the recording with RNN-T beam search and print the decoded text.
asr_text=run_decoder(my_wave_form)
print(asr_text)

In [None]:
# Classify the ASR transcript with the dialog-act model.
# The "agent: " prefix mimics the "<speaker_role>: <text>" format used in training.
# NOTE(review): this assumes the recording is an agent turn — confirm.
# The training cell leaves the model in train mode; switch to eval mode so
# train-only behaviour (e.g. dropout) is off, and skip gradient tracking.
model.eval()
with torch.no_grad():
    # Separate name so the tokenized training corpus in `encoded_inputs` is kept intact.
    query_inputs = encode_texts("agent: " + asr_text)
    dialog_act_probs = model(
        query_inputs['input_ids'].to("cuda"),
        query_inputs['attention_mask'].to("cuda"),
    )
# Per-label probabilities, columns in unique_labels order.
dialog_act_probs