Import Packages

In [1]:
from whisper.brain import initialize_from_whisper, WillettDataset, mse_adjusted_temporal_gaussian_infonce_loss
import whisper
import torch
import pandas as pd
import numpy as np
import torch

Overall Outline: Training WhisperBrain
- Load data
- Preprocess data
- Train Contrastive Alignment of speech embeddings and brain embeddings
- Train brain-to-text encoder-decoder.
- Evaluate on test set
- Save best checkpoint in a separate directory.
- Transcribe holdout set and save to text file for submission (follow guidelines from competition)? Or do in other notebook?

Reference for some finetuning code: https://colab.research.google.com/drive/1P4ClLkPmfsaKn2tBbRp0nVjGMRKR-EWz?usp=sharing

In [2]:
# check if cuda is available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {'CUDA' if DEVICE.type == 'cuda' else 'CPU'}.")
model = initialize_from_whisper(name='tiny.en')
model.to(DEVICE)
print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
dataset = WillettDataset("K:\ke\sta\data\Willett&EtAl2023\data\Willett&EtAl2023.h5", "train", device=DEVICE)
options = whisper.DecodingOptions(language="en", without_timestamps=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
model.toggle_freeze('audio')

Using CUDA.
Model is English-only and has 45,115,760 parameters.


# Aligning Speech and Brain Data

In [20]:
# Example training loop
def train(model, data_loader, optimizer, alpha=0.5, iqr=20, temperature=0.07, neg_sample_prop=None):
    model.train()
    batches = len(data_loader)
    i = 0
    for brain_data, mels, texts, signal_lengths in data_loader:
        optimizer.zero_grad()
        # Forward pass for brain and audio
        brain_embeddings = model.embed_brain(brain_data)
        audio_embeddings = model.embed_audio(mels)
        # Compute losses
        loss = mse_adjusted_temporal_gaussian_infonce_loss(brain_embeddings, audio_embeddings, signal_lengths, alpha=alpha, iqr=iqr, temperature=temperature, neg_sample_prop=neg_sample_prop)
        # Backpropagation
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        i += 1
        print(f"Loss: {loss.item()} ({i}/{batches})         ", end="\r")

In [30]:
train(model, dataloader, optimizer, alpha=0.1, iqr=1, temperature=0.2, neg_sample_prop=7)

Loss: 0.26037299633026123 (199/3460)

KeyboardInterrupt: 

In [None]:
# save model
checkpoint_dir = "K:\ke\sta\data\Willett&EtAl2023\checkpoints"
import os
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")
torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}.pth")

# Train Brain Encoder with Sequence-to-Sequence Model

In [3]:
loss_fn = torch.nn.CrossEntropyLoss()
tokenizer = whisper.tokenizer.get_tokenizer("en")
text = "The quick brown fox jumps over the lazy dog."
tokens = tokenizer.encode(text)
brain_data, mels, texts, signal_lengths = dataloader.__iter__().__next__()
def encode_texts(texts):
    # remove white space from start and end
    for sentence in texts:
        sentence = sentence.strip()
    sots = list(tokenizer.sot_sequence_including_notimestamps)
    eot = [tokenizer.eot]
    token_map = map(tokenizer.encode, texts)
    token_array = []
    for tokens in token_map:
        token_array.append(sots + tokens + eot)
    token_array = torch.tensor(token_array, dtype=torch.long, device=DEVICE)
    return token_array
tokens = encode_texts(texts)
print(tokens)

tensor([[50258, 50259, 50359, 50363,    40,   390, 11679,   538,   552,    13,
           220,   220,   220,   220,   220,   220,   220,   220,   220,   220,
           220,   220,   220,   220,   220,   220,   220,   220,   220,   220,
           220,   220,   220,   220,   220, 50257]], device='cuda:0')


In [32]:
def train_seq2seq(model, data_loader, optimizer, loss_fn, tokenizer):
    model.train()
    batches = len(data_loader)
    i = 0
    for brain_data, mels, texts, signal_lengths in data_loader:
        optimizer.zero_grad()
        # Forward pass for brain and audio
        brain_embeddings = model.embed_brain(brain_data)
        tokens = encode_texts(texts)
        logits = model.logits(tokens, brain_embeddings)
        loss = loss_fn(logits.view(-1, logits.size(-1)), tokens.view(-1))
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        i += 1
        print(f"Loss: {loss.item()} ({i}/{batches})", end="\r")

In [25]:
import torch
import pandas as pd
import random
from collections import Counter

checkpoint_dir = "K:\ke\sta\data\Willett&EtAl2023\checkpoints"

def train_step_with_scheduled_sampling(brain_data, mels, texts, signal_lengths, optimizer, model, seq2seq_loss, beta, alpha=0.5, iqr=20, temperature=0.07, neg_sample_prop=None, max_length=150, sampling_prob=0.1):
    optimizer.zero_grad()
    # Forward pass for brain and audio
    brain_embeddings = model.embed_brain(brain_data)
    audio_embeddings = model.embed_audio(mels)
    # Compute losses
    alignment_loss = mse_adjusted_temporal_gaussian_infonce_loss(brain_embeddings, audio_embeddings, signal_lengths, alpha=alpha, iqr=iqr, temperature=temperature, neg_sample_prop=neg_sample_prop)
    tokens = encode_texts(texts)

    # Start and end tokens
    start_token = tokens[0, 0].item()
    end_token = tokens[0, -1].item()

    brain_seq2seq_loss = 0
    audio_seq2seq_loss = 0

    brain_generated_tokens = [start_token]
    audio_generated_tokens = [start_token]

    for t in range(1, tokens.size(1)):
        if random.random() < sampling_prob and t > 1:
            brain_input = torch.tensor([brain_generated_tokens], device=brain_embeddings.device)
            audio_input = torch.tensor([audio_generated_tokens], device=audio_embeddings.device)
        else:
            brain_input = tokens[:, :t]
            audio_input = tokens[:, :t]

        brain_logits = model.logits(brain_input, brain_embeddings)
        audio_logits = model.logits(audio_input, audio_embeddings)

        brain_next_token = brain_logits[:, -1].argmax(dim=-1).item()
        audio_next_token = audio_logits[:, -1].argmax(dim=-1).item()

        brain_generated_tokens.append(brain_next_token)
        audio_generated_tokens.append(audio_next_token)

        brain_seq2seq_loss += seq2seq_loss(brain_logits[:, -1, :], tokens[:, t])
        audio_seq2seq_loss += seq2seq_loss(audio_logits[:, -1, :], tokens[:, t])

    brain_seq2seq_loss /= tokens.size(1)
    audio_seq2seq_loss /= tokens.size(1)

    # Backpropagation
    loss = beta * (brain_seq2seq_loss + audio_seq2seq_loss) / 2 + (1 - beta) * alignment_loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.5)
    optimizer.step()

    return loss, brain_seq2seq_loss, audio_seq2seq_loss, alignment_loss, brain_generated_tokens, audio_generated_tokens


def train_full(model, data_loader, optimizer, seq2seq_loss, alignment_loss, beta, alpha=0.5, iqr=20, temperature=0.07, neg_sample_prop=None, max_length=150, sampling_prob=0.1):
    timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")
    model.train()
    batches = len(data_loader)
    i = 0
    log = {
        "text": [],
        "length": [],
        "loss": [],
        "ASL": [],
        "BSL": [],
        "ACL": [],
    }
    for brain_data, mels, texts, signal_lengths in data_loader:
        loss, brain_seq2seq_loss, audio_seq2seq_loss, alignment_loss, brain_generated_tokens, audio_generated_tokens = train_step_with_scheduled_sampling(
            brain_data=brain_data,
            mels=mels,
            texts=texts,
            signal_lengths=signal_lengths,
            optimizer=optimizer,
            model=model,
            seq2seq_loss=seq2seq_loss,
            beta=beta,
            alpha=alpha,
            iqr=iqr,
            temperature=temperature,
            neg_sample_prop=neg_sample_prop,
            max_length=150,
            sampling_prob=sampling_prob
        )
        # Check token diversity
        brain_token_counter = Counter(brain_generated_tokens)
        audio_token_counter = Counter(audio_generated_tokens)

        brain_unique_tokens = len(brain_token_counter)
        audio_unique_tokens = len(audio_token_counter)

        most_common_brain_token, brain_token_count = brain_token_counter.most_common(1)[0]
        most_common_audio_token, audio_token_count = audio_token_counter.most_common(1)[0]

        print(f"Batch {i+1}/{batches}: Brain unique tokens: {brain_unique_tokens}, Most common: {most_common_brain_token} ({brain_token_count} / {len(brain_generated_tokens)})")
        print(f"Batch {i+1}/{batches}: Audio unique tokens: {audio_unique_tokens}, Most common: {most_common_audio_token} ({audio_token_count} / {len(audio_generated_tokens)})")

        if i % 100 == 0:
            torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}-{i}.pth")
            df = pd.DataFrame(log)
            df.to_csv(checkpoint_dir + f"\{timestamp}-log.csv", mode='a', header=False, index=False)
            log = {
                "text": [],
                "length": [],
                "loss": [],
                "ASL": [],
                "BSL": [],
                "ACL": [],
            }
        i += 1
        log["text"] += [texts[0]]
        log["length"] += [(signal_lengths.item()*20)/1000]
        log["loss"] += [loss.item()]
        log["ASL"] += [audio_seq2seq_loss.item()]
        log["BSL"] += [brain_seq2seq_loss.item()]
        log["ACL"] += [alignment_loss.item()]
        print(f"Loss: {loss.item()} ({i}/{batches}, BSL: {brain_seq2seq_loss}, ASL: {audio_seq2seq_loss} ACL: {alignment_loss})                          \nLEN: {(signal_lengths.item()*20)/1000} s                        \nTXT: {texts[0]}                                                                      ")
    df = pd.DataFrame(log)
    df.to_csv(checkpoint_dir + f"\{timestamp}-log.csv", mode='a', header=False, index=False)
    torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}-final.pth")
    return log, timestamp


In [18]:
# load checkpoint
#checkpoint = torch.load("K:\ke\sta\data\Willett&EtAl2023\checkpoints\\20240527155310.pth")
#model.load_state_dict(checkpoint)

In [27]:
model.toggle_freeze(part='audio', unfreeze=False)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
log, run_timestamp = train_full(model, dataloader, optimizer, loss_fn, mse_adjusted_temporal_gaussian_infonce_loss, beta=0.5, alpha=0.1, iqr=10, temperature=0.2, neg_sample_prop=7, max_length=150, sampling_prob=0.1)
tough_sentences = log.loc[(log["ASL"] > 0.5) | (log["BSL"] > 0.5)]
tough_sentences

Batch 1/3460: Brain unique tokens: 7, Most common: 220 (31 / 42)
Batch 1/3460: Audio unique tokens: 7, Most common: 220 (32 / 42)


Loss: 1.0050287246704102 (1/3460, BSL: 1.6450364589691162, ASL: 1.6509501934051514 ACL: 0.3620641231536865)                          
LEN: 6.46 s                        
Batch 2/3460: Brain unique tokens: 7, Most common: 220 (35 / 49)                                                                      
Batch 2/3460: Audio unique tokens: 7, Most common: 220 (41 / 49)
Loss: 1.3242324590682983 (2/3460, BSL: 2.221036672592163, ASL: 1.9877849817276 ACL: 0.5440540313720703)                          
LEN: 6.24 s                        
Batch 3/3460: Brain unique tokens: 8, Most common: 13 (9 / 18)                                                                                        
Batch 3/3460: Audio unique tokens: 8, Most common: 13 (9 / 18)
Loss: 3.69403076171875 (3/3460, BSL: 6.870942115783691, ASL: 6.8098464012146 ACL: 0.5476672649383545)                          
LEN: 9.1 s                        
Batch 4/3460: Brain unique tokens: 7, Most common: 220 (47 / 60)                       

KeyboardInterrupt: 

In [None]:
# load from checkpoint
checkpoint = torch.load(checkpoint_dir + "\\20240528124126-final.pth")
model.toggle_freeze(part='audio', unfreeze=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
log, run_timestamp = train_full(model, dataloader, optimizer, loss_fn, mse_adjusted_temporal_gaussian_infonce_loss, beta=0.8, alpha=0.05, iqr=5, temperature=0.2, neg_sample_prop=7)
tough_sentences = log.loc[(log["ASL"] > 0.5) | (log["BSL"] > 0.5)]
tough_sentences

In [45]:
# save model any time
checkpoint_dir = "K:\ke\sta\data\Willett&EtAl2023\checkpoints"
import os
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")
torch.save(model.state_dict(), checkpoint_dir + f"\{timestamp}-final.pth")

# Evaluate on Test Set

In [53]:
import torch
from whisper import DecodingOptions
import pandas as pd
import jiwer
from whisper.normalizers import EnglishTextNormalizer
from whisper.decoding import DecodingTask

@torch.no_grad()
def decode_function(model, brain_data, options):
    result = DecodingTask(model, options).run(brain_data)
    return result

model.decode = decode_function

testset = WillettDataset("K:\ke\sta\data\Willett&EtAl2023\data\Willett&EtAl2023.h5", "test", device=DEVICE)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)

options = DecodingOptions(
    language="en",
    without_timestamps=True,
    fp16=False
    )
normalizer = EnglishTextNormalizer()

def decode_seq2seq(model, dataloader, options, sample_size=10):
    model.eval()

    brain_hypotheses = []
    audio_hypotheses = []
    references = []
    batches = len(dataloader)
    i = 0
    for brain_data, mels, texts, signal_lengths in dataloader:

        print(f"Batch {i+1}/{batches} ({len(brain_hypotheses)} hypotheses so far)                        ", end="\r")

        # Decode the neural data
        model.toggle_mode("brain")
        neural_results = decode_function(model, brain_data, options)
        brain_hypotheses.extend([result.text for result in neural_results])
        model.toggle_mode("audio")
        audio_results = decode_function(model, mels, options)
        audio_hypotheses.extend([result.text for result in audio_results])
        references.extend(texts)

        if i == sample_size:
            break
        i += 1

    data = pd.DataFrame(dict(
        brain_hypotheses=brain_hypotheses,
        audio_hypotheses=audio_hypotheses,
        reference=references))
    data["brain_hypotheses_clean"] = [normalizer(text) for text in data["brain_hypotheses"]]
    data["audio_hypotheses_clean"] = [normalizer(text) for text in data["audio_hypotheses"]]
    data["reference_clean"] = [normalizer(text) for text in data["reference"]]
    return data

def wer(hypothesis, reference):
    return jiwer.wer(reference, hypothesis)

In [54]:
results = decode_seq2seq(model, testloader, options, sample_size=100)
results

Batch 101/880 (100 hypotheses so far)                        

Unnamed: 0,brain_hypotheses,audio_hypotheses,reference,brain_hypotheses_clean,audio_hypotheses_clean,reference_clean
0,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,Reports and papers and that sort of thing.,,,reports and papers and that sort of thing
1,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,That's what I was doing before.,,,that is what i was doing before
2,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,Everyone knows they're going to get it. ...,,,everyone knows they are going to get it
3,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,I thought that that was what celery seed was for.,,,i thought that that was what celery seed was for
4,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,I'm glad it's done. ...,,,i am glad it is done
...,...,...,...,...,...,...
96,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,Solving the variables in the equation. ...,,,solving the variables in the equation
97,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,He's in first grade now. ...,,,he is in 1st grade now
98,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,Since this time last year I've changed.,,,since this time last year i have changed
99,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,<|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh|><|zh...,Two hour classes.,,,2 hour classes


In [56]:
brain_wer = jiwer.wer(list(results["reference_clean"]), list(results["brain_hypotheses_clean"]))
audio_wer = jiwer.wer(list(results["reference_clean"]), list(results["audio_hypotheses_clean"]))

print(f"Brain WER: {brain_wer * 100:.2f} %")
print(f"Audio WER: {audio_wer * 100:.2f} %")

Brain WER: 100.00 %
Audio WER: 100.00 %
