Import Packages

In [1]:
from whisper.brain import initialize_from_whisper, WillettDataset, mse_adjusted_temporal_gaussian_infonce_loss
import whisper
import torch
import pandas as pd
import numpy as np
import torch

KeyboardInterrupt: 

Overall Outline: Training WhisperBrain
- Load data
- Preprocess data
- Train Contrastive Alignment of speech embeddings and brain embeddings
- Train brain-to-text encoder-decoder.
- Evaluate on test set
- Save best checkpoint in a separate directory.
- Transcribe holdout set and save to text file for submission (follow guidelines from competition)? Or do in other notebook?

Reference for some finetuning code: https://colab.research.google.com/drive/1P4ClLkPmfsaKn2tBbRp0nVjGMRKR-EWz?usp=sharing

In [None]:
# check if cuda is available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {'CUDA' if DEVICE.type == 'cuda' else 'CPU'}.")
#model = initialize_from_whisper(name='tiny.en')
base_checkpoint_dir = "K:\ke\sta\data\Willett&EtAl2023\checkpoints"
model = load_checkpoint(base_checkpoint_dir + "\checkpoint_3459.pt", 'tiny.en', DEVICE)
model.to(DEVICE)
print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
dataset = WillettDataset("K:\ke\sta\data\Willett&EtAl2023\data\Willett&EtAl2023.h5", "train", device=DEVICE)
options = whisper.DecodingOptions(language="en", without_timestamps=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
model.toggle_freeze('audio')

Using CUDA.


NameError: name 'load_checkpoint' is not defined

# Aligning Speech and Brain Data

In [None]:
# Example training loop
def train(model, data_loader, optimizer, alpha=0.5, iqr=20, temperature=0.07, neg_sample_prop=None):
    model.train()
    batches = len(data_loader)
    i = 0
    for brain_data, mels, texts, signal_lengths in data_loader:
        optimizer.zero_grad()
        # Forward pass for brain and audio
        brain_embeddings = model.embed_brain(brain_data)
        audio_embeddings = model.embed_audio(mels)
        # Compute losses
        loss = mse_adjusted_temporal_gaussian_infonce_loss(brain_embeddings, audio_embeddings, signal_lengths, alpha=alpha, iqr=iqr, temperature=temperature, neg_sample_prop=neg_sample_prop)
        # Backpropagation
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        i += 1
        print(f"Loss: {loss.item()} ({i}/{batches})         ", end="\r")

In [None]:
train(model, dataloader, optimizer, alpha=1, iqr=5, temperature=0.2, neg_sample_prop=None)

  return F.conv3d(


Loss: 4.32578706741333 (14/3460)          

KeyboardInterrupt: 

In [None]:
# save model
checkpoint_dir = "K:\ke\sta\data\Willett&EtAl2023\checkpoints"
import os
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")
torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}.pt")

# Train Brain Encoder with Sequence-to-Sequence Model

In [None]:
loss_fn = torch.nn.CrossEntropyLoss()
tokenizer = whisper.tokenizer.get_tokenizer("en")
text = "The quick brown fox jumps over the lazy dog."
brain_data, mels, texts, signal_lengths = dataloader.__iter__().__next__()
def clean_texts(texts):
    new_texts = []
    for text in texts:
        text = text.strip()
        # remove all puncutation except apostrophes for truncated words
        text = text.translate(str.maketrans("", "", ".,!?()[]{};:<>"))
        # remove double spaces
        text = text.replace("  ", " ")
        new_texts.append(text.strip())
    return tuple(new_texts)
def encode_texts(texts):
    # remove white space from start and end
    sots = list(tokenizer.sot_sequence_including_notimestamps)
    eot = [tokenizer.eot]
    token_map = map(tokenizer.encode, texts)
    token_array = []
    for tokens in token_map:
        token_array.append(sots + tokens + eot)
    token_array = torch.tensor(token_array, dtype=torch.long, device=DEVICE)
    return token_array
cleaned_texts = clean_texts(texts)
tokens = encode_texts(cleaned_texts)
print(tokens)
print(cleaned_texts)

tensor([[50258, 50259, 50359, 50363, 22493,   749,   382,  2368,   382,   264,
          3295, 50257]], device='cuda:0')
('Boys as hot as the sun',)


In [None]:
import torch
import pandas as pd
import random
from collections import Counter

checkpoint_dir = "K:\ke\sta\data\Willett&EtAl2023\checkpoints"
seq2seq_loss = torch.nn.CrossEntropyLoss()

import torch.nn.functional as F

def wasserstein_loss(brain_logits, audio_logits):
    brain_probs = F.softmax(brain_logits, dim=-1)
    audio_probs = F.softmax(audio_logits, dim=-1)

    # Cumulative distribution function (CDF) for each probability distribution
    brain_cdf = torch.cumsum(brain_probs, dim=-1)
    audio_cdf = torch.cumsum(audio_probs, dim=-1)

    # Wasserstein distance is the L1 distance between the CDFs
    return torch.mean(torch.abs(brain_cdf - audio_cdf))

def logit_kld_loss(brain_logits, audio_logits):
    brain_probs = F.log_softmax(brain_logits, dim=-1)
    audio_probs = F.softmax(audio_logits, dim=-1)
    return F.kl_div(brain_probs, audio_probs, reduction='batchmean')

mse = torch.nn.MSELoss()

def alignment_loss(brain_embeddings, audio_embeddings):
    # mean square error loss
    mse_loss = F.mse_loss(brain_embeddings, audio_embeddings)
    # cosine similarity along the time axis
    brain_norm = F.normalize(brain_embeddings, dim=1)
    audio_norm = F.normalize(audio_embeddings, dim=1)
    cosine_sim = torch.einsum('nct,ncp->ntp', brain_norm, audio_norm)
    # mean of the diagonal
    time_alignment = cosine_sim.diagonal(dim1=1, dim2=2).mean()
    # cosine similarity along the feature axis
    brain_norm = F.normalize(brain_embeddings, dim=2)
    audio_norm = F.normalize(audio_embeddings, dim=2)
    cosine_sim = torch.einsum('nct,npt->nctp', brain_norm, audio_norm)
    # mean of the diagonal
    feature_alignment = cosine_sim.diagonal(dim1=2, dim2=3).mean()
    return mse_loss, time_alignment, feature_alignment

def train_step_with_scheduled_sampling(brain_data, mels, texts, optimizer, model, beta, sampling_prob=0.1):
    optimizer.zero_grad()
    # Forward pass for brain and audio
    brain_embeddings = model.embed_brain(brain_data)
    audio_embeddings = model.embed_audio(mels)
    # Compute losses
    align_mse, align_cos_time, align_cos_feat = alignment_loss(brain_embeddings, audio_embeddings)
    texts = clean_texts(texts)
    tokens = encode_texts(texts)

    # Start and end tokens
    start_token = tokens[0, 0].item()
    end_token = tokens[0, -1].item()

    brain_seq2seq_loss = 0
    audio_seq2seq_loss = 0

    brain_generated_tokens = [start_token]
    audio_generated_tokens = [start_token]

    logit_loss = 0
    logit_mse = 0
    logit_kld = 0

    for t in range(1, tokens.size(1)):
        if random.random() < sampling_prob and t > 1:
            brain_input = torch.tensor([brain_generated_tokens], device=brain_embeddings.device)
            audio_input = torch.tensor([audio_generated_tokens], device=audio_embeddings.device)
        else:
            brain_input = tokens[:, :t]
            audio_input = tokens[:, :t]

        brain_logits = model.logits(brain_input, brain_embeddings)
        audio_logits = model.logits(audio_input, audio_embeddings)

        logit_loss += wasserstein_loss(brain_logits[:, -1, :], audio_logits[:, -1, :])
        logit_mse += mse(brain_logits[:, -1, :], audio_logits[:, -1, :])
        logit_kld += logit_kld_loss(brain_logits[:, -1, :], audio_logits[:, -1, :])

        brain_next_token = brain_logits[:, -1].argmax(dim=-1).item()
        audio_next_token = audio_logits[:, -1].argmax(dim=-1).item()

        brain_generated_tokens.append(brain_next_token)
        audio_generated_tokens.append(audio_next_token)

        brain_seq2seq_loss += seq2seq_loss(brain_logits[:, -1, :], tokens[:, t])
        audio_seq2seq_loss += seq2seq_loss(audio_logits[:, -1, :], tokens[:, t])

    brain_seq2seq_loss /= tokens.size(1)
    audio_seq2seq_loss /= tokens.size(1)
    logit_loss /= tokens.size(1)
    logit_mse /= tokens.size(1)
    logit_kld /= tokens.size(1)

    # get token agreement between brain and audio (allows for different length sequences)
    end_a = min(len(brain_generated_tokens), len(audio_generated_tokens))
    size_a = max(len(brain_generated_tokens), len(audio_generated_tokens))
    token_disagreement = (sum([1 for b, a in zip(brain_generated_tokens[:end_a], audio_generated_tokens[:end_a]) if b != a]) + size_a - end_a) / size_a

    end_t = min(len(brain_generated_tokens), tokens.size(1))
    size_t = max(len(brain_generated_tokens), tokens.size(1))
    token_error = (sum([1 for b, t in zip(brain_generated_tokens[:end_t], tokens[0, :end_t].tolist()) if b != t]) + size_t - end_t) / size_t

    end_ta = min(len(audio_generated_tokens), tokens.size(1))
    size_ta = max(len(audio_generated_tokens), tokens.size(1))
    audio_token_error = (sum([1 for b, t in zip(audio_generated_tokens[:end_ta], tokens[0, :end_ta].tolist()) if b != t]) + size_ta - end_ta) / size_ta

    # Backpropagation
    loss = beta * align_mse + (1 - beta) * logit_kld * 10
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    optimizer.step()

    return loss, brain_seq2seq_loss, audio_seq2seq_loss, logit_loss, logit_mse, logit_kld, align_mse, align_cos_time, align_cos_feat, brain_generated_tokens, audio_generated_tokens, token_disagreement, token_error, audio_token_error


def train_full(model, data_loader, optimizer, beta, sampling_prob=0.1):
    timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")
    model.train()
    batches = len(data_loader)
    i = 0
    log = {
        "text": [],
        "length": [],
        "loss": [],
        "audio_seq_loss": [],
        "brain_seq_loss": [],
        "align_mse": [],
        "align_cos_feat": [],
        "align_cos_time": [],
        "logit_was": [],
        "logit_mse": [],
        "logit_kldiv": [],
        "token_disagreement": [],
        "token_error": [],
        "audio_token_error": []
    }
    for brain_data, mels, texts, signal_lengths in data_loader:
        loss, brain_seq2seq_loss, audio_seq2seq_loss, align_mse, align_cos_time, align_cos_feat, logit_loss, logit_mse, logit_kld, brain_generated_tokens, audio_generated_tokens, token_disagreement, token_error, audio_token_error = train_step_with_scheduled_sampling(
            brain_data=brain_data,
            mels=mels,
            texts=texts,
            optimizer=optimizer,
            model=model,
            beta=beta,
            sampling_prob=sampling_prob
        )
        # Check token diversity
        brain_token_counter = Counter(brain_generated_tokens)
        audio_token_counter = Counter(audio_generated_tokens)

        brain_unique_tokens = len(brain_token_counter)
        audio_unique_tokens = len(audio_token_counter)

        most_common_brain_token, brain_token_count = brain_token_counter.most_common(1)[0]
        most_common_audio_token, audio_token_count = audio_token_counter.most_common(1)[0]

        #print(f"Batch {i+1}/{batches}: Brain unique tokens: {brain_unique_tokens}, Most common: {most_common_brain_token} ({brain_token_count} / {len(brain_generated_tokens)})")
        #print(f"Batch {i+1}/{batches}: Audio unique tokens: {audio_unique_tokens}, Most common: {most_common_audio_token} ({audio_token_count} / {len(audio_generated_tokens)})")

        if i % 100 == 0:
            torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}-{i}.pt")
            try:
                df = pd.DataFrame(log)
                df.to_csv(checkpoint_dir + f"\{timestamp}-log.csv", mode='a', header=False, index=False)
            except Exception as e:
                message = f"Error saving log \n{log}\n {e}"
                raise Exception(message)
        log = {
            "text": [],
            "length": [],
            "loss": [],
            "audio_seq_loss": [],
            "brain_seq_loss": [],
            "align_mse": [],
            "align_cos_feat": [],
            "align_cos_time": [],
            "logit_was": [],
            "logit_mse": [],
            "logit_kldiv": [],
            "token_disagreement": [],
            "token_error": [],
            "audio_token_error": []
        }
        i += 1
        log["text"] += [texts[0]]
        log["length"] += [(signal_lengths.item()*20)/1000]
        log["loss"] += [loss.item()]
        log["audio_seq_loss"] += [audio_seq2seq_loss.item()]
        log["brain_seq_loss"] += [brain_seq2seq_loss.item()]
        log["align_mse"] += [align_mse]
        log["align_cos_feat"] += [align_cos_feat]
        log["align_cos_time"] += [align_cos_time]
        log["logit_was"] += [logit_loss.item()]
        log["logit_mse"] += [logit_mse.item()]
        log["logit_kldiv"] += [logit_kld.item()]
        log["token_disagreement"] += [token_disagreement]
        log["token_error"] += [token_error]
        log["audio_token_error"] += [audio_token_error]

        print(f"Loss: {loss.item():.4f} ({i}/{batches}) align: {align_mse.item():.4f} logit: {logit_kld*10:.4f}")
    df = pd.DataFrame(log)
    df.to_csv(checkpoint_dir + f"\{timestamp}-log.csv", mode='a', header=False, index=False)
    torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}-final.pt")
    return log, timestamp

In [None]:
# load checkpoint
checkpoint = torch.load(checkpoint_dir + '/20240529063848-final.pt')
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [None]:
model.toggle_freeze(part='whisper', unfreeze=False)
lr = 0.0001
sampling_prob = 0
for i in range(0, 11):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sampling_prob += i/20
    log, run_timestamp = train_full(model, dataloader, optimizer, beta=0.5, sampling_prob=sampling_prob)

Loss: 0.8632 (1/3460) align: 0.0427 logit: 0.0067
Loss: 0.9072 (2/3460) align: 0.0549 logit: 0.0069
Loss: 0.6970 (3/3460) align: 0.0338 logit: 0.0084
Loss: 1.8975 (4/3460) align: 0.1113 logit: 0.0087
Loss: 1.2560 (5/3460) align: 0.0995 logit: 0.0060
Loss: 0.8794 (6/3460) align: 0.0893 logit: 0.0066
Loss: 0.7782 (7/3460) align: 0.0694 logit: 0.0057
Loss: 0.8075 (8/3460) align: 0.0450 logit: 0.0088
Loss: 0.7468 (9/3460) align: 0.0352 logit: 0.0050
Loss: 0.6189 (10/3460) align: 0.0327 logit: 0.0068
Loss: 0.8085 (11/3460) align: 0.0313 logit: 0.0064
Loss: 0.6643 (12/3460) align: 0.0552 logit: 0.0079
Loss: 0.6293 (13/3460) align: 0.0316 logit: 0.0065
Loss: 0.8547 (14/3460) align: 0.0607 logit: 0.0073
Loss: 0.5996 (15/3460) align: 0.0274 logit: 0.0081
Loss: 0.5374 (16/3460) align: 0.0259 logit: 0.0081
Loss: 0.5252 (17/3460) align: 0.0209 logit: 0.0081
Loss: 0.7090 (18/3460) align: 0.0444 logit: 0.0044
Loss: 0.4631 (19/3460) align: 0.0292 logit: 0.0062
Loss: 0.8922 (20/3460) align: 0.0415 log

KeyboardInterrupt: 

In [None]:
# load from checkpoint
checkpoint = torch.load(checkpoint_dir + "\\20240528124126-final.pt")
model.toggle_freeze(part='audio', unfreeze=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
log, run_timestamp = train_full(model, dataloader, optimizer, loss_fn, mse_adjusted_temporal_gaussian_infonce_loss, beta=0.8, alpha=0.05, iqr=5, temperature=0.2, neg_sample_prop=7)
tough_sentences = log.loc[(log["ASL"] > 0.5) | (log["BSL"] > 0.5)]
tough_sentences

In [None]:
# save model any time
checkpoint_dir = "K:\ke\sta\data\Willett&EtAl2023\checkpoints"
import os
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")
torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}-final.pt")

# Evaluate on Test Set

In [None]:
import torch
from whisper import DecodingOptions
import pandas as pd
import jiwer
from whisper.normalizers import EnglishTextNormalizer
from whisper.decoding import DecodingTask

@torch.no_grad()
def decode_function(model, brain_data, options):
    result = DecodingTask(model, options).run(brain_data)
    return result

model.decode = decode_function

testset = WillettDataset("K:\ke\sta\data\Willett&EtAl2023\data\Willett&EtAl2023.h5", "test", device=DEVICE)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)

options = DecodingOptions(
    language="en",
    without_timestamps=True,
    fp16=False
    )
normalizer = EnglishTextNormalizer()

def decode_seq2seq(model, dataloader, options, sample_size=10):
    model.eval()

    brain_hypotheses = []
    audio_hypotheses = []
    references = []
    batches = len(dataloader)
    i = 0
    for brain_data, mels, texts, signal_lengths in dataloader:

        print(f"Batch {i+1}/{batches} ({len(brain_hypotheses)} hypotheses so far)                        ", end="\r")

        # Decode the neural data
        model.toggle_mode("brain")
        neural_results = decode_function(model, brain_data, options)
        brain_hypotheses.extend([result.text for result in neural_results])
        model.toggle_mode("audio")
        audio_results = decode_function(model, mels, options)
        audio_hypotheses.extend([result.text for result in audio_results])
        references.extend(texts)

        if i == sample_size:
            break
        i += 1

    data = pd.DataFrame(dict(
        brain_hypotheses=brain_hypotheses,
        audio_hypotheses=audio_hypotheses,
        reference=references))
    data["brain_hypotheses_clean"] = [normalizer(text) for text in data["brain_hypotheses"]]
    data["audio_hypotheses_clean"] = [normalizer(text) for text in data["audio_hypotheses"]]
    data["reference_clean"] = [normalizer(text) for text in data["reference"]]
    return data

def wer(hypothesis, reference):
    return jiwer.wer(reference, hypothesis)

In [None]:
results = decode_seq2seq(model, testloader, options, sample_size=10)
results

Batch 11/880 (10 hypotheses so far)                        

Unnamed: 0,brain_hypotheses,audio_hypotheses,reference,brain_hypotheses_clean,audio_hypotheses_clean,reference_clean
0,Thank you.,Well tell me this.,Well tell me this.,thank you,well tell me this,well tell me this
1,Thank you.,Clothing and gas,Clothing and gas.,thank you,clothing and gas,clothing and gas
2,Thank you.,What are we doing?,What are we doing?,thank you,what are we doing,what are we doing
3,Thank you.,I don't know how my father did it.,I don't know how my father did it.,thank you,i do not know how my father did it,i do not know how my father did it
4,Thank you.,You can do it.,You could do it.,thank you,you can do it,you could do it
5,Thank you.,I'll try that next week.,I'll try that next week. ...,thank you,i will try that next week,i will try that next week
6,Thank you.,Friday afternoon at 5.30,Friday afternoon at five thirty. ...,thank you,friday afternoon at 5.30,friday afternoon at 530
7,Thank you.,Very well persuaded.,Very well persuaded. ...,thank you,very well persuaded,very well persuaded
8,Thank you.,Those answers will be straightforward if you t...,Those answers will be straightforward if you t...,thank you,those answers will be straightforward if you t...,those answers will be straightforward if you t...
9,Thank you.,Just about any kind of music.,Just about any kind of music.,thank you,just about any kind of music,just about any kind of music


In [None]:
brain_wer = jiwer.wer(list(results["reference_clean"]), list(results["brain_hypotheses_clean"]))
audio_wer = jiwer.wer(list(results["reference_clean"]), list(results["audio_hypotheses_clean"]))

print(f"Brain WER: {brain_wer * 100:.2f} %")
print(f"Audio WER: {audio_wer * 100:.2f} %")

Brain WER: 97.01 %
Audio WER: 2.99 %
