In [4]:
from fastprogress import master_bar, progress_bar
from fastai.vision.all import SimpleNamespace, set_seed
import wandb
import numpy as np
import whisper
import torch
from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity

# Hyperparameters for this run; passed to wandb.init below so they are
# recorded together with the run.
config = SimpleNamespace(
    seed=42,
    lr=0.0005,
    batch_size=1,
    epochs=5,
    dropout=0.2,
    weight_decay=0.01,
)

# Module-level constants.
# NOTE(review): BATCH_SIZE (2) differs from config.batch_size (1) — confirm
# which one is meant to drive the DataLoader.
SAMPLE_RATE = 16000
BATCH_SIZE = 2
TRAIN_RATE = 0.8

AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120

run = wandb.init(
    project="finetune-whisper",
    entity="ludeksvoboda",
    config=config,
    job_type="sparsify_test_run",
)

set_seed(config.seed)

# From here on, hyperparameters are read through the wandb-managed config.
config = wandb.config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mludeksvoboda[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [1]:
from datasets import Audio, DatasetDict, load_dataset
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizer

common_voice = DatasetDict()

# Only the first 10% of the train split is loaded here; the full alternative
# would be split="train+validation".
common_voice["train"] = load_dataset(
    "mozilla-foundation/common_voice_13_0",
    "cs",
    split="train[:10%]",
    use_auth_token=True,
)
common_voice["test"] = load_dataset(
    "mozilla-foundation/common_voice_13_0",
    "cs",
    split="test",
    use_auth_token=True,
)

# Drop the metadata columns that are not used later in this notebook.
common_voice = common_voice.remove_columns(
    [
        "accent",
        "age",
        "client_id",
        "down_votes",
        "gender",
        "locale",
        "path",
        "segment",
        "up_votes",
    ]
)

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small", language="cs", task="transcribe"
)

# Tokenizer round-trip check, kept for reference:
# input_str = common_voice["train"][0]["sentence"]
# labels = tokenizer(input_str).input_ids
# decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
# decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

# NOTE(review): the processor is loaded from "openai/whisper-tiny" while the
# feature extractor and tokenizer above use "openai/whisper-small" — confirm
# this is intended.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-tiny", language="cs", task="transcribe"
)

# Re-declare the audio column with a 16 kHz sampling rate.
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

def prepare_dataset(batch):
    """Attach model inputs and target token ids to one dataset example.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"`` entries)
    and ``batch["sentence"]``, writes ``batch["input_features"]`` and
    ``batch["labels"]``, and returns the same ``batch`` object.
    """
    sample = batch["audio"]

    # Run the feature extractor on the raw waveform and keep the first entry
    # of the returned ``input_features``.
    extracted = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    # Token ids of the target sentence.
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

# Apply prepare_dataset to every split using 4 worker processes; the columns
# listed for the "train" split are removed from the result.
common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=common_voice.column_names["train"],
    num_proc=4,
)



In [2]:
# Sanity check: show the container type of the prepared dataset.
type(common_voice)

datasets.dataset_dict.DatasetDict

In [5]:
# Decoding options for Czech with timestamps disabled.
woptions = whisper.DecodingOptions(language="cs", without_timestamps=True)

# Load the "small" checkpoint through the openai-whisper package.
model = whisper.load_model("small")

# Whisper's own multilingual tokenizer; the task is taken from the decoding options.
wtokenizer = whisper.tokenizer.get_tokenizer(
    multilingual=True,
    language="cs",
    task=woptions.task,
)

 12%|████▌                                 | 56.0M/461M [00:20<02:26, 2.90MiB/s]


RuntimeError: Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.

In [None]:
# Print the name of every layer that get_prunable_layers reports for the model.
for name, layer in get_prunable_layers(model):
    print(str(name))

In [8]:
class JvsSpeechDataset(torch.utils.data.Dataset):
    """Dataset wrapper that turns preprocessed rows into training examples.

    Each row of ``data_dict`` must provide ``'input_features'`` and
    ``'labels'``. An item is a dict with the features as a tensor
    (``"input_ids"``), the label sequence without its first element
    (``"labels"``) and the label sequence without its last element
    (``"dec_input_ids"``).
    """

    def __init__(self, data_dict) -> None:
        super().__init__()
        self.data_dict = data_dict

    def __len__(self):
        return len(self.data_dict)

    def __getitem__(self, id):
        row = self.data_dict[id]
        token_ids = row['labels']

        # Shift by one position: the decoder input drops the last token and
        # the target drops the first one.
        return dict(
            input_ids=torch.tensor(row['input_features']),
            labels=token_ids[1:],
            dec_input_ids=token_ids[:-1],
        )

In [9]:
class WhisperDataCollatorWhithPadding:
    """Collate dataset items into a padded batch.

    Each feature is a dict with ``"input_ids"`` (a tensor; all tensors in a
    batch must share one shape), ``"labels"`` and ``"dec_input_ids"``
    (integer sequences). Labels and decoder inputs are right-padded to the
    longest sequence found in either list.

    Args:
        label_pad_id: Value used to pad ``labels``. Defaults to -100.
        dec_pad_id: Value used to pad ``dec_input_ids``. Defaults to 50257,
            which the original code documents as the eot token id.
    """

    def __init__(self, label_pad_id=-100, dec_pad_id=50257):
        self.label_pad_id = label_pad_id
        self.dec_pad_id = dec_pad_id

    def __call__(self, features):
        """Build the batch dict with ``labels``, ``dec_input_ids`` and ``input_ids``."""
        # Stack the per-example tensors along a new leading batch dimension.
        input_ids = torch.stack([f["input_ids"] for f in features])

        labels = [f["labels"] for f in features]
        dec_input_ids = [f["dec_input_ids"] for f in features]

        # Both label-like sequences are padded to one common length.
        max_label_len = max(len(seq) for seq in labels + dec_input_ids)

        batch = {
            "labels": self._pad(labels, max_label_len, self.label_pad_id),
            "dec_input_ids": self._pad(dec_input_ids, max_label_len, self.dec_pad_id),
        }
        batch["input_ids"] = input_ids

        return batch

    @staticmethod
    def _pad(sequences, target_len, pad_value):
        """Right-pad every sequence to ``target_len`` and return one tensor."""
        padded = [
            np.pad(seq, (0, target_len - len(seq)), 'constant', constant_values=pad_value)
            for seq in sequences
        ]
        return torch.tensor(np.array(padded), requires_grad=False)

In [10]:
# Wrap the preprocessed train split and batch it with the padding collator.
# NOTE(review): batch_size is hard-coded to 1 here rather than read from
# config.batch_size — confirm this is intended.
dataset = JvsSpeechDataset(common_voice['train'])
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    collate_fn=WhisperDataCollatorWhithPadding(),
)


In [43]:
# Inspect only the first batch: print tensor shapes, then the decoded target
# and decoder-input text for every example in it.
for i, b in enumerate(loader):
    for key in ("labels", "input_ids", "dec_input_ids"):
        print(b[key].shape)

    for token, dec in zip(b["labels"], b["dec_input_ids"]):
        for ids in (token, dec):
            # -100 is not a valid token id, so replace it with eot (in place)
            # before handing the ids to the tokenizer.
            ids[ids == -100] = wtokenizer.eot
            text = tokenizer.decode(ids, skip_special_tokens=False)
            print(text)

    # Stop after the first batch.
    break

torch.Size([2, 31])
torch.Size([2, 80, 3000])
torch.Size([2, 31])
<|cs|><|transcribe|><|notimestamps|>Musí v mezinárodní politice zasahovat důsledněji a účinněji.<|endoftext|>
<|startoftranscript|><|cs|><|transcribe|><|notimestamps|>Musí v mezinárodní politice zasahovat důsledněji a účinněji.
<|cs|><|transcribe|><|notimestamps|>Nizozemské úřady dodnes nevyjádřily svůj postoj.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startoftranscript|><|cs|><|transcribe|><|notimestamps|>Nizozemské úřady dodnes nevyjádřily svůj postoj.<|endoftext|><|endoftext|><|endoftext|><|endoftext|>


In [11]:
from torch.nn import CrossEntropyLoss
import evaluate
from torch.optim import AdamW

In [12]:
# Positions whose target is -100 are excluded from the loss.
loss_fn = CrossEntropyLoss(ignore_index=-100)

# Word- and character-error-rate metrics, loaded in that order.
metrics_wer, metrics_cer = (
    evaluate.load(metric_name) for metric_name in ("wer", "cer")
)

In [13]:
# Name fragments that mark a parameter as exempt from weight decay.
# NOTE(review): confirm that this model's layer-norm parameters really
# contain "LayerNorm.weight" in their names; any that do not will end up in
# the decayed group.
no_decay = ["bias", "LayerNorm.weight"]


def _skips_weight_decay(param_name):
    """Return True when the parameter name contains a no-decay fragment."""
    return any(fragment in param_name for fragment in no_decay)


# Two optimizer groups: the first uses config.weight_decay, the second uses 0.
optimizer_grouped_parameters = [
    dict(
        params=[
            param
            for param_name, param in model.named_parameters()
            if not _skips_weight_decay(param_name)
        ],
        weight_decay=config.weight_decay,
    ),
    dict(
        params=[
            param
            for param_name, param in model.named_parameters()
            if _skips_weight_decay(param_name)
        ],
        weight_decay=0.0,
    ),
]

In [14]:
from sparseml.pytorch.optim import ScheduledModifierManager

# Base optimizer over the two parameter groups.
optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr)

# Load the sparsification recipe and let the manager wrap the model and
# optimizer; the loader length is supplied as the number of steps per epoch.
manager = ScheduledModifierManager.from_yaml('sparsify_recipe.yaml')
optimizer = manager.modify(
    model,
    optimizer,
    steps_per_epoch=len(loader),
)

2023-08-22 09:05:03 sparseml.pytorch.utils.logger INFO     Logging all SparseML modifier-level logs to sparse_logs/22-08-2023_09.05.03.log


In [15]:
# TODO (original note): cut a subset of the data before testing.
# Train for the number of epochs given by the sparsification manager.
mb = master_bar(range(manager.max_epochs))
for epoch in mb:
    for batch in progress_bar(loader, len(loader), parent=mb):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].cuda()

        labels = batch["labels"].long().cuda()
        dec_input_ids = batch["dec_input_ids"].long().cuda()

        # The encoder forward pass runs without gradient tracking, so the
        # loss does not backpropagate into the encoder.
        with torch.no_grad():
            audio_features = model.encoder(input_ids)

        out = model.decoder(dec_input_ids, audio_features)
        # Flatten the decoder output to (N, last_dim) and the labels to (N,)
        # for the loss function.
        loss = loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
        # Log a plain Python float instead of the tensor that is still
        # attached to the autograd graph.
        wandb.log({"train_loss": loss.item()})
manager.finalize(model)

# Save a checkpoint after the last epoch.
# NOTE(review): the path contains "tiny" while the model loaded above is
# "small" — confirm the target directory.
torch.save({'descripiton': """Quick training for testsing sarisfication
                        """,
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, '/home/lu/Models/checkpoints/FinetuneWhisper/tiny/' + 'comvoice_subset_sparse_final(' + str(epoch) + ').tar')

In [None]:
def validation_step(self, batch, batch_id):
    """Compute loss, CER and WER for one validation batch and log them.

    Args:
        self: Object providing ``model`` (with ``encoder`` and ``decoder``),
            ``loss_fn``, ``tokenizer`` (with ``eot`` and ``decode``),
            ``metrics_cer``, ``metrics_wer`` and ``log``.
        batch: Mapping with ``"input_ids"``, ``"labels"`` and
            ``"dec_input_ids"`` tensors. It is not modified.
        batch_id: Index of the batch; unused.

    Returns:
        Dict with the keys ``"cer"``, ``"wer"`` and ``"loss"``.
    """
    input_ids = batch["input_ids"]
    labels = batch["labels"].long()
    dec_input_ids = batch["dec_input_ids"].long()

    audio_features = self.model.encoder(input_ids)
    out = self.model.decoder(dec_input_ids, audio_features)

    loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))

    # Predicted token id per position. The decoder output holds scores, not
    # token ids, so it is never patched with the eot id.
    predicted_ids = torch.argmax(out, dim=-1)

    # -100 marks ignored label positions and cannot be decoded; swap it for
    # eot in a copy so the caller's batch tensor stays untouched.
    decodable_labels = labels.masked_fill(labels == -100, self.tokenizer.eot)

    o_list = [self.tokenizer.decode(o, skip_special_tokens=True) for o in predicted_ids]
    l_list = [self.tokenizer.decode(l, skip_special_tokens=True) for l in decodable_labels]

    cer = self.metrics_cer.compute(references=l_list, predictions=o_list)
    wer = self.metrics_wer.compute(references=l_list, predictions=o_list)

    self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True)
    self.log("val/cer", cer, on_step=True, prog_bar=True, logger=True)
    self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True)

    return {
        "cer": cer,
        "wer": wer,
        "loss": loss
    }