In [1]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained('openai/whisper-small',
                                             language='Chinese',
                                             task='transcribe')

#定长编码 -> [80, 3000]
# processor.feature_extractor([0.1] * 10000,
#                             sampling_rate=16000,
#                             return_tensors='pt').input_features[0]

#变长编码 -> [1, 14]
# processor.tokenizer('东部联盟球队', return_tensors='pt').input_ids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch


class Dataset(torch.utils.data.Dataset):

    def __init__(self, split):
        from datasets import load_dataset, Audio
        dataset = load_dataset(path='mozilla-foundation/common_voice_11_0',
                               name='zh-CN',
                               split=split)

        size = 5000 if split == 'train' else 100
        dataset = dataset.shuffle(seed=0).select(range(size))
        dataset = dataset.cast_column('audio', Audio(sampling_rate=16000))

        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        data = self.dataset[i]

        speech = processor.feature_extractor(
            data['audio']['array'], sampling_rate=16000,
            return_tensors='pt').input_features[0]

        text = processor.tokenizer(data['sentence'],
                                   return_tensors='pt').input_ids[0]

        return {'speech': speech, 'text': text}


dataset = Dataset(split='train')

# len(dataset), dataset[0]

Using the latest cached version of the module from /root/.cache/huggingface/modules/datasets_modules/datasets/mozilla-foundation--common_voice_11_0/3f27acf10f303eac5b6fbbbe02495aeddb46ecffdb0a2fe3507fcfbf89094631 (last modified on Mon Jul 31 14:13:56 2023) since it couldn't be found locally at mozilla-foundation/common_voice_11_0., or remotely on the Hugging Face Hub.


In [3]:
def collate_fn(data):
    speech = [{'input_features': i['speech']} for i in data]
    speech = processor.feature_extractor.pad(
        speech, return_tensors='pt').input_features

    text = [{'input_ids': i['text']} for i in data]
    text = processor.tokenizer.pad(text, return_tensors='pt').input_ids

    return {'speech': speech, 'text': text}


loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=4,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# len(loader), next(iter(loader))

In [4]:
class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

        from transformers import WhisperForConditionalGeneration
        pretrained = WhisperForConditionalGeneration.from_pretrained(
            'openai/whisper-small')

        %run 1.encoder.ipynb
        self.encoder = load_encoder(pretrained.model.encoder)
        
        %run 2.decoder.ipynb
        self.decoder = load_decoder(pretrained.model.decoder)

        self.fc_out = torch.nn.Linear(768, 51865, bias=False)
        self.fc_out.load_state_dict(pretrained.proj_out.state_dict())

    def forward(self, speech, text):
        #向右偏移一位
        text = torch.cat([text[:, :1], text], dim=1)[:, :-1]

        kv = self.encoder(speech)
        out = self.decoder(x=text, kv=kv)

        return self.fc_out(out)


model = Model()

# model(speech=torch.randn(2, 80, 3000), text=torch.ones(2, 55).long()).shape

In [5]:
def train():
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.train()
    model.to(device)

    for epoch in range(1):
        for i, data in enumerate(loader):
            for k in data.keys():
                data[k] = data[k].to(device)
            out = model(**data)

            loss = loss_fn(out.flatten(end_dim=-2), data['text'].flatten()) / 4
            loss.backward()
            if (i + 1) % 4 == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.zero_grad()
                optimizer.zero_grad()

            if (i + 1) % 100 == 0:
                print(epoch, i, loss.item())

    model.to('cpu')


train()

0 99 0.09121968597173691
0 199 0.08079050481319427
0 299 0.028110790997743607
0 399 0.05658833310008049
0 499 0.12785246968269348
0 599 0.05768570676445961
0 699 0.030965985730290413
0 799 0.03860100358724594
0 899 0.11119092255830765
0 999 0.1324881762266159
0 1099 0.0743391141295433
0 1199 0.06170809641480446


In [6]:
%run 3.forward_generate.ipynb

@torch.no_grad()
def generate(speech):
    text = torch.LongTensor([[50258]])
    cache_kv = None
    kv = model.encoder(speech.unsqueeze(0))
    generate = [text.item()]

    for _ in range(100):
        text, cache_kv = forward_decoder(model.decoder,
                                         x=text,
                                         kv=kv,
                                         cache_kv=cache_kv)

        text = model.fc_out(text).argmax(dim=2)
        generate.append(text.item())

        if text.item() == 50257:
            break

    return processor.decode(generate)


generate(dataset[0]['speech'])

'<|startoftranscript|><|startoftranscript|><|zh|><|transcribe|><|notimestamps|>为了弄出有酷有炫的效果。<|endoftext|>'

In [7]:
@torch.no_grad()
def test():
    from IPython.display import Audio, display
    dataset_test = Dataset('test')
    for i in range(5):
        display(Audio(dataset_test.dataset[i]['audio']['array'], rate=16000))
        print(generate(dataset_test[i]['speech']))


test()

Using the latest cached version of the module from /root/.cache/huggingface/modules/datasets_modules/datasets/mozilla-foundation--common_voice_11_0/3f27acf10f303eac5b6fbbbe02495aeddb46ecffdb0a2fe3507fcfbf89094631 (last modified on Mon Jul 31 14:13:56 2023) since it couldn't be found locally at mozilla-foundation/common_voice_11_0., or remotely on the Hugging Face Hub.


<|startoftranscript|><|startoftranscript|><|zh|><|transcribe|><|notimestamps|>皮拉罗科人口变化图示<|endoftext|>


<|startoftranscript|><|startoftranscript|><|zh|><|transcribe|><|notimestamps|>后来科学家再没有观察到同样的衰变活动。<|endoftext|>


<|startoftranscript|><|startoftranscript|><|zh|><|transcribe|><|notimestamps|>很多设备都是为了使用兰冈宝弹而进行修改。<|endoftext|>


<|startoftranscript|><|startoftranscript|><|zh|><|transcribe|><|notimestamps|>李西九，安平人。<|endoftext|>


<|startoftranscript|><|startoftranscript|><|zh|><|transcribe|><|notimestamps|>牧章长南馆和并而来。<|endoftext|>
