## Convert models from pytorch to tensorflow

### **install dependencies**

In [None]:
# NOTE(review): the next two lines install the same project (gsoc-wav2vec2)
# from two different forks (mnansary and vasudevgupta7), both at `main`.
# Presumably the later install replaces the earlier one -- confirm which fork
# is actually intended.
!pip install git+https://github.com/mnansary/gsoc-wav2vec2.git@main
!pip install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
!pip install -U kaggledatasets
# fsspec / gcsfs: filesystem packages, presumably needed for gs:// access -- confirm.
!pip install fsspec
!pip install gcsfs

* **[there is code in this repo to convert the weights](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/convert_torch_to_tf.py)**

**However, this script actually fails and won't serve our purpose**

* the script only covers the following model conversion : 
```python
ACCEPTABLE_HF_IDS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-base", "facebook/wav2vec2-large-robust", "facebook/wav2vec2-large-xlsr-53"]
```

**WE CAN HOWEVER REUSE THE FUNCTIONS WITH SOME CHANGES**



# Model Selection and conversion
* The model we want to convert is **[arijitx/wav2vec2-xls-r-300m-bengali](https://huggingface.co/arijitx/wav2vec2-xls-r-300m-bengali)**
* we inspect these two configs [tensorflow config](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/wav2vec2/config.py) and [hugging face config](https://huggingface.co/arijitx/wav2vec2-xls-r-300m-bengali/blob/main/config.json) and spot the differences

| Tensorflow |Huggingface |
|:---:|:---:|
|num_heads: int = 12|"num_attention_heads": 16,|
|num_layers: int = 12|"num_hidden_layers": 24,|
|conv_bias: bool = False|"conv_bias": true,|
|feature_extractor_norm_type: bool = "group"|"feat_extract_norm": "layer",|
|hidden_size: int = 768|"hidden_size": 1024,|
|intermediate_size: int = 3072|"intermediate_size": 4096,|

**Note: we can safely ignore differences such as dropout during conversion**

The changes can be executed by :

```python
config = Wav2Vec2Config()
config.num_heads=16
config.num_layers=24
config.conv_bias=True
config.feature_extractor_norm_type="layer"
config.hidden_size=1024
config.intermediate_size=4096
```

**However, to avoid complexity we can use `RobustWav2Vec2Config`**

In [4]:
%tensorflow_version 2.x

UsageError: Line magic function `%tensorflow_version` not found.


In [3]:
from typing import Union
import tensorflow as tf
import transformers
import numpy as np
from tqdm.auto import tqdm
from wav2vec2 import Wav2Vec2Config, RobustWav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model


# Tensor-index suffix appended to every generated TensorFlow variable name.
SUFFIX = ":0"

# Ordered (PyTorch fragment, TensorFlow fragment) substitutions applied by
# `replace`. Order matters: the layer-norm rules run before the generic
# "weight" rule, and the final "." -> "/" rule also turns the intermediate
# "layer_norm.beta" into "layer_norm/beta".
MAPPING = (
    ("layer_norm.weight", "layer_norm/gamma"),
    ("layer_norm.bias", "layer_norm.beta"),
    ("weight", "kernel"),
    (".", "/"),
)

# PyTorch state_dict keys listed here are skipped during conversion.
KEYS_TO_IGNORE = []

# HuggingFace model ids that the conversion function accepts.
ACCEPTABLE_HF_IDS = [
    "facebook/wav2vec2-base-960h",
    "facebook/wav2vec2-base",
    "facebook/wav2vec2-large-robust",
    "facebook/wav2vec2-large-xlsr-53",
    "arijitx/wav2vec2-xls-r-300m-bengali",
]

# The positional-conv weight_g / weight_v parameters do not follow MAPPING and
# are therefore mapped explicitly, once for the model with a CTC head ...
PREFIX_WITH_HEAD = "wav2vec2-ctc/"
SPECIAL_MAPPING_WITH_HEAD = {
    f"wav2vec2.encoder.pos_conv_embed.conv.{param}":
        f"{PREFIX_WITH_HEAD}wav2vec2/encoder/pos_conv_embed/conv/{param}:0"
    for param in ("weight_g", "weight_v")
}

# ... and once for the bare model without a head.
PREFIX_WITHOUT_HEAD = "wav2vec2/"
SPECIAL_MAPPING_WITHOUT_HEAD = {
    f"encoder.pos_conv_embed.conv.{param}":
        f"{PREFIX_WITHOUT_HEAD}encoder/pos_conv_embed/conv/{param}:0"
    for param in ("weight_g", "weight_v")
}


def replace(k: str, prefix) -> str:
    """
    Translate a PyTorch state_dict key into a TensorFlow variable name.

    Every (old, new) pair of `MAPPING` is applied in order to `k`; the result
    is returned with `prefix` in front of it and `SUFFIX` behind it.
    """
    tf_name = k
    for torch_fragment, tf_fragment in MAPPING:
        tf_name = tf_name.replace(torch_fragment, tf_fragment)
    return "".join((prefix, tf_name, SUFFIX))


def get_tf_pretrained_model(
    config: Wav2Vec2Config,
    hf_model_id: str,
    verbose=False,
    with_lm_head=True,
) -> tuple:
    """
    Converts HuggingFace PyTorch weights to TensorFlow compatible weights.

    Args:
        config (:obj: `Wav2Vec2Config`):
            Configuration of TF model.
        hf_model_id (:obj: `str`):
            model_id of HuggingFace PyTorch model. Must be one of
            `ACCEPTABLE_HF_IDS`.
        verbose (:obj: `bool`, default=False):
            Whether to print every `PyTorch key -> TF variable name` pair
            that gets copied.
        with_lm_head (:obj: `bool`, default=True):
            Whether to build Wav2Vec2ForCTC (True) or Wav2Vec2Model (False).

    Returns:
        A `(tf_model, hf_model)` tuple: the TensorFlow model with the
        converted weights assigned, and the HuggingFace PyTorch model the
        weights were read from.

    Raises:
        AssertionError: if `hf_model_id` is not in `ACCEPTABLE_HF_IDS`.
    """
    assert hf_model_id in ACCEPTABLE_HF_IDS, f"{hf_model_id} is not acceptable"

    # Pick the TF model, its variable-name prefix and the matching special
    # mapping together, so the lookup below can never index the wrong dict.
    if with_lm_head:
        tf_model = Wav2Vec2ForCTC(config)
        prefix = PREFIX_WITH_HEAD
        special_mapping = SPECIAL_MAPPING_WITH_HEAD
        hf_model = transformers.Wav2Vec2ForCTC.from_pretrained(hf_model_id)
    else:
        tf_model = Wav2Vec2Model(config)
        # NOTE(review): presumably builds the model's variables so they can be
        # enumerated below -- confirm; the CTC variant is not initialised
        # explicitly in this function.
        tf_model._init(input_shape=(1, 2048))
        prefix = PREFIX_WITHOUT_HEAD
        special_mapping = SPECIAL_MAPPING_WITHOUT_HEAD
        hf_model = transformers.Wav2Vec2Model.from_pretrained(hf_model_id)

    hf_state_dict = hf_model.state_dict()

    # TF variables indexed by their full name for O(1) lookup.
    tf_variables_dict = {v.name: v for v in tf_model.variables}

    tf_weights = []
    extra_keys = []
    for k in tqdm(hf_state_dict, desc="hf -> tf"):
        if k in KEYS_TO_IGNORE:
            continue

        is_special = k in special_mapping
        new_k = special_mapping[k] if is_special else replace(k, prefix=prefix)

        # PyTorch parameters without a TF counterpart are reported and skipped.
        if new_k not in tf_variables_dict:
            extra_keys.append(k)
            print(f"SKIPPING {k}")
            continue

        if verbose:
            print(k, "->", new_k)

        array = hf_state_dict[k].numpy()

        # transpose the PyTorch weights for correct loading in TF-2
        # Weights corresponding to the special mapping are 3D arrays while other
        # kernels are 2D, so the 3D weights get an explicit axis permutation.
        if is_special:
            array = np.transpose(array, axes=(2, 1, 0))
        elif "kernel" in new_k:
            array = np.transpose(array)

        tf_weights.append((tf_variables_dict[new_k], array))

    print("EXTRA KEYS:\n", extra_keys)

    # Assign all converted arrays to their TF variables in one batch.
    tf.keras.backend.batch_set_value(tf_weights)

    return tf_model, hf_model




ModuleNotFoundError: No module named 'tensorflow'

In [None]:
###########################
# conversion settings
is_robust = True
with_lm_head = True
model_id = "tf-wav2vec2-xls-r-300m-bengali"
hf_model_id = "arijitx/wav2vec2-xls-r-300m-bengali"
###########################
# The "robust" config is used when is_robust is set, the base config otherwise.
config = RobustWav2Vec2Config() if is_robust else Wav2Vec2Config()
# NOTE(review): 112 is presumably the vocabulary size of the HF checkpoint's
# CTC head -- confirm against the checkpoint's config.json.
config.vocab_size = 112
tf_model, hf_model = get_tf_pretrained_model(
    config,
    hf_model_id,
    verbose=True,
    with_lm_head=with_lm_head,
)

# Data Access

* **PER_REPLICA_BATCH_SIZE**  global batch size while training will be **8 times the PER_REPLICA_BATCH_SIZE** we provide 

* **REC_SIZE=256** simply means while creating the tfrecords , we stored 256 audio files with their labels in one tfrecord

* for params
```python
PER_REPLICA_BATCH_SIZE  = 32      # this is a safe batch size 
EPOCHS                  = 50      # change this as needed .. keep the kaggle allowed TPU limit of 9 hours in mind    
```
* to use the full-dataset

```python
TRAIN_GCS_PATTERNS      = [os.path.join(GCS_PATH,"voted","*/*.tfrecord"),
                           os.path.join(GCS_PATH,"unverified","*/*.tfrecord"),]

```

In [None]:
# Mount Google Drive at /content/drive (Colab-only helper module).
from google.colab import drive
drive.mount('/content/drive')

# Initialize TPU
* we initialize the TPU cluster so that it can be used for training
* based on number of **replicas** or devices we fix:
    * BATCH_SIZE
    * STEPS_PER_EPOCH
    * and evaluation steps within an epoch (EVAL_STEPS)

In [None]:
#----------------------------------------------------------
# Detect hardware, return appropriate distribution strategy
#----------------------------------------------------------
import tensorflow as tf
print("Tensorflow version " + tf.__version__)

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError as err:
    # RuntimeError (instead of a bare BaseException) can be handled by ordinary
    # `except Exception` blocks, and `from err` keeps the resolver's own
    # message in the traceback.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err

# Connect to the TPU workers and initialise the TPU system before building
# the strategy that every later `strategy.scope()` block relies on.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
# NOTE(review): newer TensorFlow releases expose this class as
# tf.distribute.TPUStrategy; the experimental alias is kept here so the cell
# behaves as before -- confirm against the installed TF version.
strategy = tf.distribute.experimental.TPUStrategy(tpu)

In [None]:
#-------------------------------------
# batching , strategy and steps
#-------------------------------------
# Global batch size = per-replica batch size times the number of replicas kept
# in sync (with a single replica this is just PER_REPLICA_BATCH_SIZE).
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

# Step counts are derived from the number of tfrecord files and the number of
# examples per file; train_recs, eval_recs and REC_SIZE are defined in other
# cells of this notebook. Evaluation covers half of the eval examples per epoch.
STEPS_PER_EPOCH = (len(train_recs) * REC_SIZE) // BATCH_SIZE
EVAL_STEPS = (len(eval_recs) * REC_SIZE) // (2 * BATCH_SIZE)

print("Batch Size:", BATCH_SIZE)
print("Steps:", STEPS_PER_EPOCH)
print("Eval Steps:", EVAL_STEPS)

In [None]:
#-------------------------------
# imports
#-------------------------------
import os  # required by the os.environ assignment right below
# TF_CPP_MIN_LOG_LEVEL controls the verbosity of TensorFlow's native logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import random
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np 
from tqdm.auto import tqdm
from IPython.display import display,Audio
from wav2vec2 import Wav2Vec2Config,CTCLoss
# Register tqdm's progress-bar helpers with pandas.
tqdm.pandas()

#--------------------------
# GCS Paths and tfrecords
#-------------------------
train_recs = []
eval_recs = []


def get_tfrecs(gcs_pattern):
    """Return the files matching `gcs_pattern` in shuffled order.

    The number of matches is printed before the list is returned.
    """
    matches = tf.io.gfile.glob(gcs_pattern)
    random.shuffle(matches)
    print("found ", len(matches), "tfrecords")
    return matches


# Gather every training tfrecord, pattern by pattern.
for gcs in TRAIN_GCS_PATTERNS:
    print("Looking into gcs path:", gcs)
    train_recs.extend(get_tfrecs(gcs))

# Same for the evaluation tfrecords.
for gcs in EVAL_GCS_PATTERNS:
    print(gcs)
    eval_recs.extend(get_tfrecs(gcs))

print("Total Eval-recs:", len(eval_recs))
print("Total Train-recs:", len(train_recs))
#------------------------------------------------
# change config
#------------------------------------------------
# Default wav2vec2 config with the vocabulary size set to one more than the
# number of entries in VOCAB.
config = Wav2Vec2Config()
config.vocab_size = len(VOCAB) + 1
config

In [None]:
import os 
#------------------------------
# change able params
#------------------------------
# Training reads both the "voted" and the "unverified" folders under GCS_PATH.
TRAIN_GCS_PATTERNS      = [os.path.join(GCS_PATH, split, "*/*.tfrecord")
                           for split in ("voted", "unverified")]

# Evaluation reads only the "eval" folder.
EVAL_GCS_PATTERNS       = [os.path.join(GCS_PATH, "eval", "*/*.tfrecord")]

PER_REPLICA_BATCH_SIZE  = 32      # this is a safe batch size
EPOCHS                  = 25      # change this as needed .. keep the kaggle allowed TPU limit of 9 hours in mind

#------------------------------
# fixed params while creating the tfrecords
#------------------------------
# Number of (audio, label) examples stored in each tfrecord file; the
# step-count computation multiplies the number of files by this value.
REC_SIZE=256  
# Output symbols addressed by list position. The first three entries are the
# special tokens pad / start / end; '\u200d' is the zero-width joiner. The
# visualisation cell decodes label ids by indexing into this list, so the
# order of the entries must not change.
VOCAB   =[ 'pad','start','end','\u200d',
        ' ','!',"'",',','-','.',':',';','=','?','।',
        'ঁ','ং','ঃ',
        'অ','আ','ই','ঈ','উ','ঊ','ঋ','এ','ঐ','ও','ঔ',
        'ক','খ','গ','ঘ','ঙ',
        'চ','ছ','জ','ঝ','ঞ',
        'ট','ঠ','ড','ঢ','ণ',
        'ত','থ','দ','ধ','ন',
        'প','ফ','ব','ভ','ম',
        'য','র','ল',
        'শ','ষ','স','হ',
        'া','ি','ী','ু','ূ','ৃ','ে','ৈ','ো','ৌ','্',
        'ৎ','ড়','ঢ়','য়',
        '০','১','২','৩','৪','৫','৬','৭','৮','৯']


We import the needed libraries here and collect the tfrecord paths that can be fed into the [tf.data API](https://www.tensorflow.org/api_docs/python/tf/data/Dataset), which is the official way to consume tfrecords.

### Imports and data 

In [None]:
# Root GCS location of the tfrecords; the train/eval glob patterns are built
# by joining sub-folder names onto this path.
GCS_PATH='gs://kds-90328aa8d26e17c5bffb9a7f73013580f05a9bfecda822e30cc04946'

# Data Loader 
* cfg = our data config, which also stores a few constants
* config = the actual wav2vec2 model config

In [None]:
class cfg:
    # Data-pipeline settings shared by the input function and the model builder.
    audio_shape      =  (246000,)                   # fixed (padded) audio length in samples: 246000 / 16000 Hz is about 15.4 s
    label_shape      =  (250,)                      # fixed (padded) label length used when batching
    sample_rate      =  16000                       # playback rate used by the visualisation cell
    shuffle_buffer   =  1024                        # buffer size passed to dataset.shuffle
    batch_size       =  BATCH_SIZE                  # global batch size computed from the strategy
    vocab_len        =  len(VOCAB)+1                # one id more than VOCAB has entries; per the original note the extra id can account for <UNK>
    

In [None]:
#------------------------------
# parsing tfrecords 
#------------------------------
def normalize(x):
    """Standardize `x` along its last axis and squeeze size-1 dimensions.

    The mean is subtracted and the result divided by sqrt(variance + 1e-5).
    """
    centered = x - tf.reduce_mean(x, axis=-1, keepdims=True)
    variance = tf.math.reduce_variance(x, axis=-1, keepdims=True)
    scaled = centered / tf.sqrt(variance + 1e-5)
    return tf.squeeze(scaled)

def read_raw_audio(audio):
    """Decode WAV bytes as a single channel and return the samples flattened to 1-D."""
    # The decoded sample rate is not needed here.
    samples, _ = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1)
    return tf.reshape(samples, shape=[-1])
    
def preprocess_example(audio,label):
    """Turn one raw (audio, label) record into (normalized signal, int32 label ids).

    `label` is split on whitespace and each piece is parsed as an int32.
    All ops are pinned to the CPU.
    """
    with tf.device("/CPU:0"):
        waveform = read_raw_audio(audio)
        signal = normalize(waveform)
        pieces = tf.strings.split(label)
        label_ids = tf.strings.to_number(pieces, out_type=tf.int32)
        return signal, label_ids

def data_input_fn(recs):
    '''
      Build the batched tf.data pipeline for the given tfrecord files.
      * Every record holds two string features, 'audio' and 'label', which are
        parsed and handed to preprocess_example.
      * Examples are shuffled, repeated indefinitely, padded to the fixed
        shapes in cfg, batched and prefetched; elements that raise an error
        are dropped at the end of the pipeline.
    '''
    feature_spec = {
        'audio': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string),
    }

    def _parser(example):
        parsed = tf.io.parse_single_example(example, feature_spec)
        return preprocess_example(audio=parsed['audio'], label=parsed['label'])

    padded_shapes = (cfg.audio_shape[0], cfg.label_shape[0])

    # fixed code (for almost all tfrec training)
    return (
        tf.data.TFRecordDataset(recs)
        .map(_parser)
        .shuffle(cfg.shuffle_buffer, reshuffle_each_iteration=True)
        .repeat()
        .padded_batch(cfg.batch_size, padded_shapes=padded_shapes, padding_values=(0.0, 0))
        .prefetch(tf.data.experimental.AUTOTUNE)
        .apply(tf.data.experimental.ignore_errors())
    )

In [None]:
# Build the training and evaluation pipelines from the collected tfrecord
# paths. Both go through the same input function, so both are shuffled and
# repeated indefinitely.
train_ds=data_input_fn(train_recs)
eval_ds =data_input_fn(eval_recs)

### Visualize

In [None]:
#------------------------------
# view data
#------------------------------
# Pull one batch from the eval pipeline and inspect its first example.
for x, y in eval_ds.take(1):
    signal = x[0].numpy()
    display(Audio(data=signal, rate=cfg.sample_rate))
    label = y[0].numpy()
    # Only ids above the "end" token are decoded, which drops pad/start/end.
    last_special = VOCAB.index("end")
    chars = [VOCAB[int(i)] for i in label if i > last_special]
    sen = "".join(chars)
    print("label:", sen)
    print("input shape:", x.shape)
    print("output shape:", y.shape)

# Modeling

In [None]:
def create_model(cfg):
    """Build a Keras model: TF-Hub wav2vec2 layer followed by a Dense layer.

    `cfg.audio_shape` sets the input shape and `cfg.vocab_len` the number of
    units of the Dense layer. The hub layer is loaded as trainable, with the
    load options' I/O device set to '/job:localhost'.
    """
    load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    backbone = hub.KerasLayer(
        "https://tfhub.dev/vasudevgupta7/wav2vec2/1",
        load_options=load_options,
        trainable=True,
    )
    audio_in = tf.keras.Input(shape=cfg.audio_shape)
    hidden = backbone(audio_in)
    logits = tf.keras.layers.Dense(cfg.vocab_len)(hidden)
    return tf.keras.Model(inputs=audio_in, outputs=logits)

**model weights can be loaded from saved ones to continue training**
```python
model.load_weights("path to previously trained weights")
```

In [None]:
# Create the model inside the distribution strategy's scope.
with strategy.scope():
    model=create_model(cfg)
    # Uncomment to resume from previously saved weights:
    # model.load_weights("model.h5")
model.summary()

# Training
* some ideas to extend: 
    * use different schedulers
    * use callbacks to track some metrics
    * reduce learning rate on plateau, early stopping setup might need some inspection 

In [None]:
    
# early stopping
# NOTE(review): `early_stopping` is created here but is not included in the
# `callbacks` list below, so it currently has no effect -- confirm whether
# that is intentional.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=10, 
                                                  verbose=1, 
                                                  mode = 'auto') 
# Lower the learning rate after 3 epochs without improvement.
lr_reducer=tf.keras.callbacks.ReduceLROnPlateau( patience=3)
# Keep only the best weights (weights only, not the full model) in model.h5.
model_save=tf.keras.callbacks.ModelCheckpoint("model.h5",
                                                save_best_only=True,
                                                save_weights_only=True,
                                                verbose=1)
callbacks = [lr_reducer,model_save]

# Loss and optimizer are created inside the strategy scope. The CTC loss is
# given the per-replica input shape and divides by the per-replica batch size.
with strategy.scope():
    loss_fn = CTCLoss(config, (PER_REPLICA_BATCH_SIZE,cfg.audio_shape[0]), division_factor=PER_REPLICA_BATCH_SIZE)
    model.compile(optimizer=tf.keras.optimizers.Adam(5e-5),
                  loss=loss_fn)

In [None]:
# Both datasets repeat indefinitely, so the number of steps per epoch and the
# number of validation steps are passed explicitly.
history=model.fit(train_ds,
                  epochs=EPOCHS,
                  steps_per_epoch=STEPS_PER_EPOCH,
                  verbose=1,
                  validation_data=eval_ds,
                  validation_steps=EVAL_STEPS, 
                  callbacks=callbacks)

In [None]:
# pandas is imported here because no other visible cell binds the name `pd`.
import pandas as pd

# history.history maps each metric name to its list of per-epoch values, so a
# copy of it becomes one DataFrame column per metric and one row per epoch.
curves = pd.DataFrame(dict(history.history))
curves.to_csv("history.csv", index=False)

In [None]:
curves