In [49]:
from datasets import load_dataset, load_from_disk
import os
import torchaudio
import torch
from transformers import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")

In [224]:
import random


class DataLoader:
    def __init__(self, model, cache, path):
        self.model = model
        self.cache = cache
        self.path = path

    def _prepare_dataset_custom(self, batch, input_text_prompt="", split_type="train", lang="en"):
        filename = batch[f"{lang}_wav"]
        speech, sampling_rate = torchaudio.load(f"{self.path}/wav/{split_type}/{filename}")
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16_000)
        batch["input_values"] = resampler.forward(speech.squeeze(0)).numpy()
        batch["lengths"] = len(batch["input_values"])
        sent = batch[f"{lang}_sentence"].lower()

        decoder_input, decoder_target = self._create_self_decoder_input(self.model.decoder_model, self.model.tokenizer,
                                                                        input_text_prompt + sent,
                                                                        next(self.model.parameters()).device)
        batch["input_text_prompt"] = input_text_prompt
        batch["text_input_ids"] = decoder_input
        batch["labels"] = decoder_target
        batch["labels"] += [self.model.tokenizer.eos_token_id]
        rnd = int(random.uniform(1, 10)) % 9 == 0
        if rnd:
            logger.info(batch)
            logger.info(decoder_input)
            logger.info(sent)

        return batch

    def _create_self_decoder_input(decoder_model, tokenizer, input_sent, device):
        gen_input = tokenizer(input_sent, add_special_tokens=True).input_ids
        predicted = [decoder_model.config.decoder_start_token_id]
        with torch.no_grad():
            decoder_model.eval()
            decoder_length = max(decoder_model.config.max_length, len(gen_input))
            for _ in range(decoder_length):
                max_item = torch.argmax(
                    decoder_model(input_ids=torch.tensor([gen_input], device=device),
                                  output_hidden_states=True,
                                  decoder_input_ids=torch.tensor(
                                      [predicted],
                                      device=device)).logits, -1)[:, -1].item()
                if decoder_model.config.eos_token_id == max_item:
                    break
                predicted.append(max_item)
        return gen_input, predicted[1:]


    def load_custom_datasets(self, set_name, lang, cache):
        dataset = None
        print(f"{self.path}/transformers/{set_name}_{lang}.data")

        if cache and os.path.isdir(self.path):
            logger.info("Getting cached files")
            dataset = load_from_disk(f"{self.path}/transformers/{set_name}_{lang}.data")
        else:
            logger.info("1. Loading custom uncached files")
            logger.info(f"{self.path}/transformers/{set_name}_{next(self.model.parameters()).device}_{lang}.data")
            json_ds = load_dataset("json", data_files=f"{self.path}/transformers/{set_name}.json", cache_dir="./.cache")
            logger.info("2. Creating custom uncached files")
            dataset = json_ds.map(self._prepare_dataset_custom,
                                  fn_kwargs={"input_text_prompt": "", "split_type": f"{set_name}",
                                             "lang": lang})
            logger.info("3. Saving to disk")
            dataset.save_to_disk(f"{self.path}/transformers/{set_name}_{next(self.model.parameters()).device}_{lang}.data")
        return dataset


In [225]:
#v1
input_args = {'speech_model_config': 'facebook/wav2vec2-base', 'nlp_model_config': 'facebook/bart-base', 'SpeechMixEED': False,
              'SpeechMixED': False, 'SpeechMixSelf': False, 'SpeechMixAdapter': False, 'SpeechMixGAN': False,
              'SpeechMixFixed': False, 'HFSpeechMixEED': True, 'HFSpeechMixED': False, 'HFSpeechMixSelf': False,
              'HFSpeechMixAdapter': False, 'HFSpeechMixGAN': False, 'HFSpeechMixFixed': False, 'cache': False,
              'field': 'clean', 'train_split': 'train.100', 'test_split': 'validation', 'notes': 'base',
              'grad_accum': 20, 'logging_steps': 10, 'warmup_steps': 500, 'unfreeze_warmup_steps': 1000,
              'save_total_limit': 2, 'max_grad_norm': 10, 'worker': 10, 'batch': 3, 'epoch': 30, 'lr': 4e-05,
              'eval_step': 700, 'share_layer_ratio': 0.0, 'down_scale': 2, 'weighted_sum': False,
              'fixed_parameters': False, 'custom_set_path': 'speechBSD', 'max_input_length_in_sec': 20,
              'group_by_length': False,
              'fixed_except': ['layer_norm', 'encoder_attn', 'enc_to_dec_proj', 'length_adapter', 'layernorm_embedding',
                               'attention', 'encoder'], 'fp16': False, 'wandb': True}

In [226]:
model_type = "HFSpeechMixEED"
model = speechmix.HFSpeechMixEED(**input_args)
device = torch.device('cuda', 0) if torch.cuda.is_available() else 'cpu'
model.to(device)

loading configuration file config.json from cache at /Users/alicjaharaszczuk/.cache/huggingface/hub/models--facebook--bart-base/snapshots/aadd2ab0ae0c8268c7c9693540e9904811f36177/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "

Before layer sharing num_speech_encoder_layers 12
After layer sharing  num_speech_encoder_layers 12 num_nlp_encoder_layers 6 share_layer_ratio 0.0 remove_layers 0


SpeechMixEED(
  (encoder_model): UpstreamExpert(
    (model): Wav2Vec2Model(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
            (3): GELU(approximate='none')
          )
          (1): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU(approximate='none')
          )
          (2): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
            (2): GELU(approximate='none')
          )
          (3): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (1): Dropout(p=0.0, inplace=False)
   

In [227]:
dl = DataLoader(model, False, "speechBSD")

In [228]:
dl.load_custom_datasets("test", "en")

1. Loading custom uncached files
speechBSD/transformers/test_cpu_en.data


speechBSD/transformers/test_en.data


Using custom data configuration default-9e03c3eac4c48589
Found cached dataset json (/Users/alicjaharaszczuk/Desktop/PROJECT/SpeechMix/./.cache/json/default-9e03c3eac4c48589/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

2. Creating custom uncached files


  0%|          | 0/2120 [00:00<?, ?ex/s]

PreTrainedTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_len=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
the japanese market has been very interested in our product.
[0, 627, 1236, 20948, 4468, 210, 34, 57, 182, 2509, 11, 84, 1152, 4, 2]
[2, 0, 627, 1236, 20948, 4468, 210, 34, 57, 182, 2509, 11, 84, 1152, 4]
cpu
{'no': 10, 'en_speaker': 'Mr. Wayne Willis', 'ja_speaker': 'ウェイン ウィリスさん', 'en_sentence': 'The Japanese market has been very interested in our product.', 'ja_sentence': '日本市場が当社製品に興味をもっているようです。', 'ja_spkid': '190315_E001_13_spk1_ja', 'en_spkid': '190315_E001_13_spk1_en', 'ja_wav': '190315_E001_13_spk1_no10_ja.wav', 'en_wav': '190315_E001_13_spk1_no10_en.wav', 'ja_spk_gender': 'M', 'en_spk

DatasetDict({
    train: Dataset({
        features: ['no', 'en_speaker', 'ja_speaker', 'en_sentence', 'ja_sentence', 'ja_spkid', 'en_spkid', 'ja_wav', 'en_wav', 'ja_spk_gender', 'en_spk_gender', 'ja_spk_prefecture', 'en_spk_state', 'input_values', 'lengths', 'input_text_prompt', 'text_input_ids', 'labels'],
        num_rows: 2120
    })
})

In [56]:
import datasets

test = datasets.load_from_disk(f"speechBSD/transformers/train_en.data")


In [57]:
id = int(random.uniform(1, 2000))
#print(test['train'][0])
#print(test['train'][0]['input_text_prompt'])
#print(test['train'][0]['input_values'])
print(test['train'][id]['labels'])

[0, 405, 74, 28, 10, 372, 7, 778, 7, 120, 7, 216, 103, 9, 84, 2539, 4, 2]


In [76]:
test['train'][id]

{'no': 23,
 'ja_speaker': '三木さん',
 'en_speaker': 'Miss Miki',
 'ja_sentence': 'クライアントと顔見知りになれるいいチャンスだと思います。',
 'en_sentence': 'It would be a great to chance to get to know some of our clients.',
 'ja_spkid': '190329_J03_15_spk2_ja',
 'en_spkid': '190329_J03_15_spk2_en',
 'ja_wav': '190329_J03_15_spk2_no23_ja.wav',
 'en_wav': '190329_J03_15_spk2_no23_en.wav',
 'ja_spk_gender': 'F',
 'en_spk_gender': 'M',
 'ja_spk_prefecture': '神奈川',
 'en_spk_state': 'CA',
 'input_values': [0.0,
  -3.0517578125e-05,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  -3.0517578125e-05,
  3.0517578125e-05,
  3.0517578125e-05,
  3.0517578125e-05,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  3.0517578125e-05,
  0.0,
  0.0,
  0.0,
  0.0,
  -3.0517578125e-05,
  0.0,
  0.0,
  0.0,
  3.0517578125e-05,
  0.0,
  -3.0517578125e-05,
  0.0,
  0.0,
  0.0,
  3.0517578125e-05,
  0.0,
  -3.0517578125e-05,
  -3.0517578125e-05,
  0.0,
  -3.0517578125e-05,

In [32]:

model.tokenizer('kurwa', add_special_tokens=True).input_ids

[0, 330, 710, 2739, 2]

In [64]:
from IPython.display import Audio

Audio(test['train'][1444]['input_values'], rate=20000)

In [115]:
import speechmix
!export PYTORCH_ENABLE_MPS_FALLBACK=1
args = {'speech_model_config': 'wav2vec2', 'nlp_model_config': 'facebook/bart-base', 'cache': False,
        'train_split': 'train.100', 'notes': 'base', 'grad_accum': 20, 'logging_steps': 10, 'warmup_steps': 500,
        'unfreeze_warmup_steps': 1000, 'save_total_limit': 2, 'max_grad_norm': 10, 'worker': 15, 'batch': 3,
        'epoch': 30, 'lr': 4e-05, 'eval_step': 700, 'share_layer_ratio': 0.5, 'down_scale': 2, 'weighted_sum': False,
        'fixed_parameters': False, 'custom_set_path': 'speechBSD', 'max_input_length_in_sec': 20,
        'group_by_length': False,
        'fixed_except': ['layer_norm', 'encoder_attn', 'enc_to_dec_proj', 'length_adapter', 'layernorm_embedding',
                         'attention', 'encoder'], 'fp16': False, 'wandb': True}

model = speechmix.SpeechMixEED(**args).to('cpu')

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


loading configuration file config.json from cache at /Users/alicjaharaszczuk/.cache/huggingface/hub/models--facebook--bart-base/snapshots/aadd2ab0ae0c8268c7c9693540e9904811f36177/config.json
Model config BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "

Before layer sharing num_speech_encoder_layers 12
After layer sharing  num_speech_encoder_layers 6 num_nlp_encoder_layers 6 share_layer_ratio 0.5 remove_layers 6


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [133]:
import numpy

device = torch.device("cuda" if torch.cuda.is_available() else "mps")
a = test['train'][1444]['input_values']
d = torch.from_numpy(numpy.array(a, dtype=numpy.float16)).to(device)

#print(d.dim)
output = model.encoder_model([torch.tensor(test['train'][1444]['input_values'], device='cpu')])

stacked_feature = torch.stack(output['hidden_states'], dim=0)
_, *origin_shape = stacked_feature.shape

<built-in method dim of Tensor object at 0x7fbc5fedb720>
tensor([[[[ 3.4398e-01,  3.7594e-01,  2.3876e-01,  ...,  1.9952e-01,
            7.5121e-01,  1.0692e-01],
          [ 8.9623e-01,  4.1402e-01, -3.8561e-03,  ..., -4.0359e-01,
            3.2028e-01, -1.1701e-01],
          [ 1.0949e-01, -1.9351e-01,  2.7562e-01,  ..., -0.0000e+00,
            2.6685e-01,  3.6859e-01],
          ...,
          [ 1.7444e-01,  6.3554e-02,  3.0391e-01,  ..., -2.2491e-02,
            5.5170e-01,  7.3816e-01],
          [ 5.6106e-01, -4.2777e-02,  0.0000e+00,  ..., -9.6225e-02,
            5.1907e-01,  1.7300e-01],
          [-2.5432e-02, -9.2782e-02,  5.0341e-01,  ..., -4.0247e-02,
            5.1748e-01,  3.3440e-01]]],


        [[[ 8.2866e-04,  6.0372e-01,  2.6215e-02,  ...,  4.7935e-01,
            2.9329e-01,  2.6431e-01],
          [ 5.6563e-01,  1.2847e-01, -2.3541e-02,  ..., -6.3465e-01,
           -1.9352e-02,  1.5068e-01],
          [ 1.1945e-01,  8.5310e-02,  1.8124e-01,  ..., -4.8132e-02,

In [135]:
print(origin_shape)

[1, 173, 768]


In [125]:
embeds = output['last_hidden_state']
print(embeds.shape)
print(embeds.transpose(1, 2).shape)

torch.Size([1, 173, 768])
torch.Size([1, 768, 173])


In [138]:
transposed_embeds = embeds.transpose(1, 2)
red = model.length_adapters(transposed_embeds).transpose(1, 2)
print(f"Reducing the length: {red.shape}")

Reducing the length: torch.Size([1, 86, 768])


In [139]:
projected = model.enc_to_dec_proj(red)

print(f"Projected: {projected.shape}")

Projected: torch.Size([1, 86, 768])


In [145]:
text_prompt = model.nlp_emb(model.tokenizer("heh", return_tensors="pt")['input_ids'])
text_prompt

tensor([[[ 0.0125,  0.0014, -0.0096,  ...,  0.0022,  0.1057,  0.0103],
         [-0.0212, -0.0169,  0.0043,  ..., -0.0465,  0.0558,  0.0256],
         [-0.0394, -0.0207,  0.0830,  ...,  0.0734,  0.1359,  0.0962],
         [ 0.0842, -0.0389,  0.0096,  ...,  0.0583,  0.0082,  0.0357]]],
       grad_fn=<EmbeddingBackward0>)

In [147]:
inputs_embeds = torch.cat((text_prompt, projected), 1)
inputs_embeds

tensor([[[ 0.0125,  0.0014, -0.0096,  ...,  0.0022,  0.1057,  0.0103],
         [-0.0212, -0.0169,  0.0043,  ..., -0.0465,  0.0558,  0.0256],
         [-0.0394, -0.0207,  0.0830,  ...,  0.0734,  0.1359,  0.0962],
         ...,
         [-0.0083,  0.0146,  0.0519,  ..., -0.0645, -0.0566, -0.0319],
         [ 0.0701,  0.0811, -0.0443,  ..., -0.0690, -0.0530, -0.0057],
         [ 0.2023, -0.0073, -0.0448,  ..., -0.0658,  0.0305, -0.0829]]],
       grad_fn=<CatBackward0>)

In [161]:
!pip3 install tensorboard

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting tensorboard
  Using cached tensorboard-2.11.0-py3-none-any.whl (6.0 MB)
Collecting markdown>=2.6.8
  Using cached Markdown-3.4.1-py3-none-any.whl (93 kB)
Collecting werkzeug>=1.0.1
  Using cached Werkzeug-2.2.2-py3-none-any.whl (232 kB)
Collecting absl-py>=0.4
  Using cached absl_py-1.3.0-py3-none-any.whl (124 kB)
Collecting google-auth-oauthlib<0.5,>=0.4.1
  Using cached google_auth_oauthlib-0.4.6-py2.py3-none-any.whl (18 kB)
Collecting tensorboard-data-server<0.7.0,>=0.6.0
  Using cached tensorboard_data_server-0.6.1-py3-none-macosx_10_9_x86_64.whl (3.5 MB)
Collecting tensorboard-plugin-wit>=1.6.0
  Using cached tensorboard_plugin_wit-1.8.1-py3-none-any.whl (781 kB)
Collecting grpc

In [244]:
test['train'][0]['input_values']

tensor([ 0.0000e+00,  0.0000e+00, -3.0518e-05,  ..., -9.1553e-05,
         9.7656e-04,  5.4932e-04])

In [16]:
#from torchviz import make_dot, make_dot_from_trace
import datasets
import torch

test = datasets.load_from_disk(f"speechBSD/transformers/test_cpu_en.data/train")
#a = test['train'].with_format("torch", device=torch.device("mps"))
print(test[18].keys())

#make_dot(model(a[0]['input_values']), params=dict(model.named_parameters()))

dict_keys(['no', 'en_speaker', 'ja_speaker', 'en_sentence', 'ja_sentence', 'ja_spkid', 'en_spkid', 'ja_wav', 'en_wav', 'ja_spk_gender', 'en_spk_gender', 'ja_spk_prefecture', 'en_spk_state', 'input_values', 'lengths', 'input_text_prompt', 'text_input_ids', 'labels'])


In [19]:
a = datasets.load_from_disk(f"speechBSD/transformers/test_cpu_en.data/train")
a

Dataset({
    features: ['no', 'en_speaker', 'ja_speaker', 'en_sentence', 'ja_sentence', 'ja_spkid', 'en_spkid', 'ja_wav', 'en_wav', 'ja_spk_gender', 'en_spk_gender', 'ja_spk_prefecture', 'en_spk_state', 'input_values', 'lengths', 'input_text_prompt', 'text_input_ids', 'labels'],
    num_rows: 2120
})