# Wav2vec 2.0

We use w2v2 as implemented in [HuggingFace](https://huggingface.co/docs/transformers/model_doc/wav2vec2) and create wrapper classes for fine-tuning on our speech classification task and for extracting embeddings.

Note: the [run.py](https://github.com/dwiepert/mayo-w2v2/blob/main/src/run.py) script can also run evaluation only.

Authors: Daniela Wiepert

To begin, you will need access to a Google Cloud Storage bucket, and the following packages must be installed on your system:

* albumentations (may run into issues in AIF)
* librosa
* torch, torchvision, torchaudio
* tqdm
* transformers
* pyarrow

(You can ignore the following if using AIF.)
* google-cloud-storage

The [requirements.txt](https://github.com/dwiepert/mayo-w2v2/blob/main/requirements.txt) can be used to set up this environment. 

To access data stored in GCS from your local machine, you will also need to run

```gcloud auth application-default login```

```gcloud auth application-default set-quota-project PROJECT_NAME```

In [1]:
#IMPORTS
#built-in
import argparse
import json
import os
import pickle

#third-party
import numpy as np
import pandas as pd
import pyarrow
import torch
from tqdm import tqdm

from google.cloud import storage
from torch.utils.data import DataLoader

#local
from utilities import *
from models import *
#from loops import *
from dataloader import W2V2Dataset

### Upload/Download functions & Data loading functions
These are defined in `utilities/load_utils.py`.

## Arguments
There are many configurable arguments when running w2v2. Please review the options and make sure every argument is set as intended.

An important argument is the model checkpoint. On AIF, models cannot be loaded directly from HuggingFace; instead, see `gs://ml-e107-phi-shared-aif-us-p/m144443/checkpoints` for available checkpoints. The default is `gs://ml-e107-phi-shared-aif-us-p/m144443/checkpoints/wav2vec2-base-960h`.


In particular, note the `weighted` and `layer` arguments, which change how the classifier's input is built.

`weighted` initializes an additional set of learnable parameters: one weight per hidden state, so the classifier receives a weighted combination of all layers and the learned weights indicate how much each layer contributes to classification.

`layer` sets which hidden state to use as input to the classifier.
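A minimal PyTorch sketch of the two options, assuming the base model returns all hidden states (HuggingFace does this when the model is called with `output_hidden_states=True`: the input to the transformer encoder plus the output of each of its 12 layers, i.e. 13 tensors for wav2vec2-base). `WeightedLayerSum` and `select_layer` are hypothetical names, not the classes in `models`, and softmax-normalizing the weights is an assumption about how `weighted` is implemented.

```python
import torch
import torch.nn as nn


class WeightedLayerSum(nn.Module):
    """Hypothetical sketch: one learnable scalar weight per hidden state."""

    def __init__(self, num_hidden_states: int):
        super().__init__()
        # Zero init -> uniform softmax, so training starts from a plain average of all layers
        self.layer_weights = nn.Parameter(torch.zeros(num_hidden_states))

    def forward(self, hidden_states):
        # hidden_states: tuple of (batch, time, hidden) tensors, one per layer
        stacked = torch.stack(hidden_states, dim=0)               # (layers, batch, time, hidden)
        weights = torch.softmax(self.layer_weights, dim=0)         # non-negative, sums to 1
        return (weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)    # (batch, time, hidden)


def select_layer(hidden_states, layer: int = -1):
    # Non-weighted alternative: use a single hidden state; -1 is the final layer
    return hidden_states[layer]


# Toy stand-in for the 13 hidden states of wav2vec2-base (batch=2, time=50, hidden=768)
hidden_states = tuple(torch.randn(2, 50, 768) for _ in range(13))
combined = WeightedLayerSum(len(hidden_states))(hidden_states)
final = select_layer(hidden_states, -1)
```

After training, applying a softmax to `layer_weights` gives a per-layer distribution that can be inspected to see which layers the classifier relies on.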

In [3]:
# argparse passes the raw command-line string to `type`, and bool('False') is True,
# so the boolean flags below use this converter instead of type=bool
def str2bool(value):
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', 't', 'yes', 'y', '1'):
        return True
    if value.lower() in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')

parser = argparse.ArgumentParser()
#Inputs
parser.add_argument('-i','--prefix',default='speech_ai/speech_lake/', help='Input directory or location in google cloud storage bucket containing files to load')
parser.add_argument("-s", "--study", choices = ['r01_prelim','speech_poc_freeze_1', None], default='speech_poc_freeze_1', help="specify study name")
parser.add_argument("-d", "--data_split_root", default='gs://ml-e107-phi-shared-aif-us-p/speech_ai/share/data_splits/amr_subject_dedup_594_train_100_test_binarized_v20220620/test.csv', help="specify file path where datasplit is located. Classification expects a directory; if you give a full file path, an error will be thrown. Evaluation and embedding extraction expect a single .csv file.")
parser.add_argument('-l','--label_txt', default='./labels.txt')
parser.add_argument('--lib', default=False, type=str2bool, help="Specify whether to load using librosa as compared to torch audio")
parser.add_argument("-c", "--checkpoint", default="gs://ml-e107-phi-shared-aif-us-p/m144443/checkpoints/wav2vec2-base-960h", help="specify path to pre-trained model weight checkpoint")
parser.add_argument("-mp", "--finetuned_mdl_path", default='/Users/m144443/Documents/GitHub/mayo-w2v2/experiments/weighted/amr_subject_dedup_594_train_100_test_binarized_v20220620_5_adam_epoch1_wav2vec2-base-960h_mdl_weighted.pt', help='If running eval-only or extraction, you have the option to load a fine-tuned model by specifying the save path here. If passed a gs:// file, will download to local machine.')
#GCS
parser.add_argument('-b','--bucket_name', default='ml-e107-phi-shared-aif-us-p', help="google cloud storage bucket name")
parser.add_argument('-p','--project_name', default='ml-mps-aif-afdgpet01-p-6827', help='google cloud platform project name')
parser.add_argument('--cloud', default=False, type=str2bool, help="Specify whether to save everything to cloud")
#output
parser.add_argument("--dataset", default=None,type=str, help="When saving, the dataset arg is used to set file names. If you do not specify, it will assume the lowest directory from data_split_root")
parser.add_argument("-o", "--exp_dir", default="./experiments/embeddings", help='specify LOCAL output directory')
parser.add_argument('--cloud_dir', default='m144443/temp_out/w2v2_ft_weighted', type=str, help="if saving to the cloud, you can specify a specific place to save to in the CLOUD bucket")
#Mode specific
parser.add_argument("--weighted", type=str2bool, default=True, help="specify whether to learn a weighted sum of layers for classification")
parser.add_argument("--layer", default=-1, type=int, help="specify which hidden state is being used. It can be between -1 and 12")
parser.add_argument("--freeze", type=str2bool, default=True, help='specify whether to freeze the base model')
parser.add_argument('--embedding_type', type=str, default='wt', help='specify whether embeddings should be extracted from classification head (ft), base pretrained model (pt), or weighted sum of layers (wt)', choices=['ft','pt','wt'])
#Audio transforms
parser.add_argument("--resample_rate", default=16000,type=int, help='resample rate for audio files')
parser.add_argument("--reduce", default=True, type=str2bool, help="Specify whether to reduce to monochannel")
parser.add_argument("--clip_length", default=160000, type=int, help="If truncating audio, specify clip length in # of frames. 0 = no truncation")
parser.add_argument("--trim", default=False, type=str2bool, help="trim silence")
#Model parameters
parser.add_argument("-pm", "--pooling_mode", default="mean", help="specify method of pooling last hidden layer", choices=['mean','sum','max'])
parser.add_argument("-bs", "--batch_size", type=int, default=8, help="specify batch size")
parser.add_argument("-nw", "--num_workers", type=int, default=0, help="specify number of parallel jobs to run for data loader")
parser.add_argument("-lr", "--learning_rate", type=float, default=0.0003, help="specify learning rate")
parser.add_argument("-e", "--epochs", type=int, default=1, help="specify number of training epochs")
parser.add_argument("--optim", type=str, default="adam", help="training optimizer", choices=["adam"])
parser.add_argument("--loss", type=str, default="BCE", help="the loss function for finetuning, depend on the task", choices=["MSE", "BCE"])
parser.add_argument("--scheduler", type=str, default=None, help="specify lr scheduler", choices=["onecycle", None])
parser.add_argument("--max_lr", type=float, default=0.01, help="specify max lr for lr scheduler")
#classification head parameters
parser.add_argument("--activation", type=str, default='relu', help="specify activation function to use for classification head")
parser.add_argument("--final_dropout", type=float, default=0.25, help="specify dropout probability for final dropout layer in classification head")
parser.add_argument("--layernorm", type=str2bool, default=False, help="specify whether to include the LayerNorm in classification head")
#OTHER
parser.add_argument("--debug", default=True, type=str2bool)
# In a notebook, sys.argv holds the kernel's own arguments, so parse an empty list to use the defaults;
# in a script (e.g. run.py), call parser.parse_args() instead
args = parser.parse_args(args=[])

## Setting up environment
The first step is to initialize the GCS bucket if a `bucket_name` is given. The list of target labels must also be set, and there are a few other arguments to check.

In the original implementation, the label list is passed on the command line as a `.txt` file (`--label_txt`). In this notebook, we set it directly as a list.

In [4]:
#CHECK GPU AVAILABLE
print('Cuda availability: ', torch.cuda.is_available())

Cuda availability:  False


In [5]:
# (1) Set up GCS
if args.bucket_name is not None:
    storage_client = storage.Client(project=args.project_name)
    bucket = storage_client.bucket(args.bucket_name)
else:
    bucket = None



In [6]:
# (2) if a study is given, append it to the prefix; otherwise the prefix is used as the full prefix
if args.study is not None:
    args.prefix = os.path.join(args.prefix, args.study)


In [None]:
# (3) get dataset name
if args.dataset is None:
    if '.csv' in args.data_split_root:
        args.dataset = '{}_{}'.format(os.path.basename(os.path.dirname(args.data_split_root)), os.path.basename(args.data_split_root[:-4]))
    else:
        args.dataset = os.path.basename(args.data_split_root)
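As a worked example of step (3), the hypothetical helper below repeats the same naming logic on made-up paths, so it can be checked without a bucket.

```python
import os


def derive_dataset_name(data_split_root):
    # Same logic as step (3): for a .csv path, join the parent directory name and the file stem
    if '.csv' in data_split_root:
        parent = os.path.basename(os.path.dirname(data_split_root))
        stem = os.path.basename(data_split_root[:-4])
        return '{}_{}'.format(parent, stem)
    # For a directory, use its last path component
    return os.path.basename(data_split_root)


csv_name = derive_dataset_name('gs://bucket/data_splits/my_split/test.csv')  # 'my_split_test'
dir_name = derive_dataset_name('gs://bucket/data_splits/my_split')           # 'my_split'
```

A directory path with a trailing slash yields an empty name, because `os.path.basename` returns `''` for paths ending in a separator; stripping the slash first (e.g. `data_split_root.rstrip('/')`) avoids this.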

In [7]:
# (4) set the list of target labels
args.target_labels = ['slow rate',
                      'irregular artic breakdowns',
                      'rapid rate',
                      'distortions',
                      'strained']

args.n_class = len(args.target_labels)

In [8]:
# (5) create the output directory if it does not exist (must be a LOCAL path, not gs://)
os.makedirs(args.exp_dir, exist_ok=True)

In [None]:
# (6) without truncation (clip_length == 0), waveforms differ in length and cannot be
# stacked into a batch, so force batch_size to 1
if args.clip_length == 0 and args.batch_size != 1:
    print('clip_length is 0 (no truncation): setting batch_size to 1 because waveforms differ in length')
    args.batch_size = 1
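Why batch size 1 is needed without truncation: PyTorch's default collate function builds a batch by stacking each field, and stacking fails when waveforms have different lengths. The toy samples below mimic the `waveform` and `targets` keys of a `W2V2Dataset` item (their sizes are made up). A padding `collate_fn` would be an alternative, but that is not what this notebook does.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two toy samples shaped like W2V2Dataset items, with different waveform lengths
short_clip = {'waveform': torch.zeros(1, 16000), 'targets': torch.zeros(5)}
long_clip = {'waveform': torch.zeros(1, 24000), 'targets': torch.zeros(5)}

# default_collate stacks tensors field by field, which requires equal shapes
try:
    default_collate([short_clip, long_clip])
    mixed_lengths_ok = True
except RuntimeError:
    mixed_lengths_ok = False

# With a single sample there is nothing to stack against, so any length works
single = default_collate([long_clip])
x = torch.squeeze(single['waveform'], dim=1)  # (1, 1, 24000) -> (1, 24000), as in the training loops
```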

In [None]:
# (7) check if checkpoint is stored in gcs bucket or confirm it exists on local machine
assert args.checkpoint is not None, 'Must give a model checkpoint for W2V2'
args.checkpoint = gcs_model_exists(args.checkpoint, args.bucket_name, args.exp_dir, bucket)

In [None]:
# (8) dump arguments
args_path = "%s/args.pkl" % args.exp_dir
with open(args_path, "wb") as f:
    pickle.dump(args, f)
#in case of error, everything is immediately uploaded to the bucket
if args.cloud:
    upload(args.cloud_dir, args_path, bucket)

In [None]:
#(9) add bucket to args
args.bucket = bucket

## Finetuning

We start with fine-tuning. Evaluating an already fine-tuned model without training is not covered here, but it is available in the full `run.py` script.

For fine-tuning, `data_split_root` must be a directory containing `train.csv` and `test.csv`, which list the train/test file names and their associated label data. (The default set in the arguments cell points to a single `test.csv`, so change it before running this section.)

We load the data, set up an audio configuration, set up `W2V2Dataset` objects (from `dataloader.py`), and set up the dataloaders.

The transforms in `W2V2Dataset` are set up using functions from `utilities/speech_utils.py`. 
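As a rough sketch of what `reduce` and `clip_length` control, the hypothetical helpers below average channels to mono and cut or zero-pad a waveform to a fixed number of samples. This is not the code in `utilities/speech_utils.py`: zero-padding short clips is an assumption (equal lengths are needed for batching), and resampling to `resample_rate` is omitted (it could be done with `torchaudio.functional.resample`).

```python
import torch
import torch.nn.functional as F


def reduce_to_mono(waveform):
    # (channels, samples) -> (1, samples) by averaging the channels
    return waveform.mean(dim=0, keepdim=True)


def fix_length(waveform, clip_length):
    # clip_length == 0 means no truncation: return the waveform unchanged
    if clip_length == 0:
        return waveform
    n = waveform.shape[-1]
    if n >= clip_length:
        return waveform[..., :clip_length]        # truncate
    return F.pad(waveform, (0, clip_length - n))  # zero-pad at the end (assumption)


stereo = torch.randn(2, 200000)                    # toy 2-channel clip
clip = fix_length(reduce_to_mono(stereo), 160000)  # the default --clip_length
short_clip = fix_length(torch.ones(1, 1000), 1600)
```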

The dataloaders take the datasets, the batch size, and the number of workers.

Please note that the resulting samples will be a dictionary with the keys `uid`, `waveform`, `targets`, `sample_rate`.

The `load_data` function randomly samples `train.csv` to create a validation set of 50 samples. See `load_utils.py` for the implementation.
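A sketch of carving a 50-sample validation set out of a training table with pandas. The toy DataFrame is invented, and removing the sampled rows from the training set is an assumption here; check `load_utils.py` for what `load_data` actually does.

```python
import pandas as pd

# Toy annotations table indexed by uid, standing in for train.csv
train_df = pd.DataFrame({'uid': [f'utt{i:03d}' for i in range(200)],
                         'slow rate': [i % 2 for i in range(200)]}).set_index('uid')

# Draw 50 rows for validation; a fixed random_state makes the split reproducible
val_df = train_df.sample(n=50, random_state=42)
# Drop those rows from the training set so the two sets do not overlap
train_only_df = train_df.drop(index=val_df.index)
```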

In [11]:
print('Running finetuning: ')
# (1) load data
assert '.csv' not in args.data_split_root, f'May have given a full file path, please confirm this is a directory: {args.data_split_root}'
train_df, val_df, test_df = load_data(args.data_split_root, args.target_labels, args.exp_dir, args.cloud, args.cloud_dir, args.bucket)




100


In [12]:
# (2) set up audio configuration for transforms
audio_conf = {'checkpoint': args.checkpoint, 'resample_rate':args.resample_rate, 'reduce': args.reduce,
                'trim': args.trim, 'clip_length': args.clip_length}
    

In [13]:
# (3) set up datasets and dataloaders
dataset_train = W2V2Dataset(train_df, target_labels = args.target_labels,
                            audio_conf = audio_conf, prefix=args.prefix, bucket=args.bucket, librosa=args.lib)
dataset_val = W2V2Dataset(val_df, target_labels = args.target_labels,
                            audio_conf = audio_conf, prefix=args.prefix, bucket=args.bucket, librosa=args.lib)
dataset_test = W2V2Dataset(test_df, target_labels = args.target_labels,
                            audio_conf = audio_conf, prefix=args.prefix, bucket=args.bucket, librosa=args.lib)


In [14]:
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
dataloader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=args.num_workers)
dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

In [15]:
# sanity check: confirm that a batch can be loaded
batch = next(iter(dataloader_train))

### Set up the model

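Set up the model using classes from `w2v2_models.py`. This includes a wrapper for a speech classification model that adds a classification head consisting of a dense layer, an activation (ReLU by default, set with `--activation`), dropout, and a final linear layer. The head can optionally include a LayerNorm (`--layernorm`).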
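A minimal PyTorch sketch of a head with that layer order, applied to already-pooled features of size 768 (the hidden size of wav2vec2-base). `ClassificationHeadSketch` is hypothetical, uses a fixed ReLU, and leaves out the optional LayerNorm, whose position in the real head is not described here.

```python
import torch
import torch.nn as nn


class ClassificationHeadSketch(nn.Module):
    """Hypothetical head: dense -> activation -> dropout -> linear on pooled features."""

    def __init__(self, hidden_size=768, num_labels=5, final_dropout=0.25):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(final_dropout)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled):
        # pooled: (batch, hidden_size), e.g. a time-pooled hidden state
        x = self.activation(self.dense(pooled))
        x = self.dropout(x)
        return self.out_proj(x)  # (batch, num_labels) unnormalized logits


head = ClassificationHeadSketch()
logits = head(torch.randn(4, 768))
```

The outputs are unnormalized logits, which is what `torch.nn.BCEWithLogitsLoss` (the `BCE` option of `--loss`) expects.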

In [20]:
# (4) initialize model
model = Wav2Vec2ForSpeechClassification(checkpoint=args.checkpoint, num_labels = args.n_class, pooling_mode = args.pooling_mode, 
                                        freeze=args.freeze, activation=args.activation, final_dropout=args.final_dropout, 
                                        layernorm=args.layernorm, weighted=args.weighted, layer=args.layer)    


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForSequenceClassification: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['classifier.bias', 'projector.bias', 'wav2vec2.masked_spec_embed', 'classifier.weight', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be 

### Run training, evaluation

The model training loops are originally implemented in `loops.py`, but we will include them here for context.

In [None]:
def validation(model, criterion, dataloader_val):
    '''
    Validation loop for finetuning the w2v2 classification head. 
    :param model: W2V2 model
    :param criterion: loss function
    :param dataloader_val: dataloader object with validation data
    :return validation_loss: list with validation loss for each batch
    '''
    validation_loss = list()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    with torch.no_grad():
        model.eval()
        for batch in tqdm(dataloader_val):
            x = torch.squeeze(batch['waveform'], dim=1)
            targets = batch['targets']
            x, targets = x.to(device), targets.to(device)
            o = model(x)
            val_loss = criterion(o, targets)
            validation_loss.append(val_loss.item())

    return validation_loss

In [None]:
def finetune(model, dataloader_train, dataloader_val = None, 
             optim='adamw', learning_rate=0.001, loss_fn='BCE',
             sched='onecycle', max_lr=0.01,
             epochs=10, exp_dir='', cloud=False, cloud_dir='', bucket=None):
    """
    Training loop for finetuning W2V2
    :param model: W2V2 model
    :param dataloader_train: dataloader object with training data
    :param dataloader_val: dataloader object with validation data
    :param optim: type of optimizer to initialize
    :param learning_rate: optimizer learning rate
    :param loss_fn: type of loss function to initialize
    :param sched: type of scheduler to initialize
    :param max_lr: max learning rate for onecycle scheduler
    :param epochs: number of epochs to run training
    :param exp_dir: output directory on local machine
    :param cloud: boolean indicating whether uploading to cloud
    :param cloud_dir: output directory in google cloud storage bucket
    :param bucket: initialized GCS bucket object
    :return model: finetuned W2V2 model
    """
    print('Training start')
    #send to gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    #loss
    if loss_fn == 'MSE':
        criterion = torch.nn.MSELoss()
    elif loss_fn == 'BCE':
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        raise ValueError(f"loss_fn must be 'MSE' or 'BCE', got {loss_fn!r}")
    #optimizer (only parameters that are not frozen are optimized)
    params = [p for p in model.parameters() if p.requires_grad]
    if optim == 'adam':
        optimizer = torch.optim.Adam(params, lr=learning_rate)
    elif optim == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=learning_rate)
    else:
        raise ValueError(f"optim must be 'adam' or 'adamw', got {optim!r}")
    
    #scheduler: OneCycleLR wraps the optimizer object (not its name) and is stepped once per batch
    if sched == 'onecycle':
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=len(dataloader_train), epochs=epochs)
    else:
        scheduler = None
    
    #train
    for e in range(epochs):
        training_loss = list()
        #t0 = time.time()
        model.train()
        for batch in tqdm(dataloader_train):
            x = torch.squeeze(batch['waveform'], dim=1)
            targets = batch['targets']
            x, targets = x.to(device), targets.to(device)
            optimizer.zero_grad()
            o = model(x)
            loss = criterion(o, targets)
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            loss_item = loss.item()
            training_loss.append(loss_item)

        if e % 10 == 0:
            #SET UP LOGS
            if scheduler is not None:
                lr = scheduler.get_last_lr()
            else:
                lr = learning_rate
            logs = {'epoch': e, 'optim':optim, 'loss_fn': loss_fn, 'lr': lr, 'scheduler':sched}
    
            logs['training_loss_list'] = training_loss
            training_loss = np.array(training_loss)
            logs['running_loss'] = np.sum(training_loss)
            logs['training_loss'] = np.mean(training_loss)

            print('RUNNING LOSS', e, np.sum(training_loss) )
            print(f'Training loss: {np.mean(training_loss)}')

            if dataloader_val is not None:
                print("Validation start")
                validation_loss = validation(model, criterion, dataloader_val)

                logs['val_loss_list'] = validation_loss
                validation_loss = np.array(validation_loss)
                logs['val_running_loss'] = np.sum(validation_loss)
                logs['val_loss'] = np.mean(validation_loss)
                
                print('RUNNING VALIDATION LOSS',e, np.sum(validation_loss) )
                print(f'Validation loss: {np.mean(validation_loss)}')
            
            #SAVE LOGS
            print(f'Saving epoch {e}')
            logs_path = os.path.join(exp_dir, 'logs_ft_epoch{}.json'.format(e))
            with open(logs_path, 'w') as outfile:
                # dump the dict directly; dumping json.dumps(logs) would write a double-encoded string
                json.dump(logs, outfile)
            
            #SAVE CURRENT MODEL
            
            mdl_path = os.path.join(exp_dir, 'w2v2_ft_mdl_epoch{}.pt'.format(e))
            torch.save(model.state_dict(), mdl_path)
            
            optim_path = os.path.join(exp_dir, 'w2v2_ft_optim_epoch{}.pt'.format(e))
            torch.save(optimizer.state_dict(), optim_path)

            if cloud:
                upload(cloud_dir, logs_path, bucket)
                #upload_from_memory(model.state_dict(), args.cloud_dir, mdl_path, args.bucket)
                upload(cloud_dir, mdl_path, bucket)
                upload(cloud_dir, optim_path, bucket)
    return model
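With `--scheduler onecycle`, `OneCycleLR` has to wrap the optimizer object and be stepped once per batch, for `epochs * len(dataloader_train)` steps in total. The toy example below (a single linear layer stands in for the trainable parameters, and the sizes are made up) records the learning rate at every step to show the schedule: it starts at `max_lr / 25` (the default `div_factor`), peaks at `max_lr`, and ends far below the starting rate.

```python
import math
import torch

model = torch.nn.Linear(768, 5)  # toy stand-in for the trainable classifier parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)

epochs, steps_per_epoch, max_lr = 2, 10, 0.01
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, steps_per_epoch=steps_per_epoch, epochs=epochs)

start_lr = scheduler.get_last_lr()[0]  # max_lr / div_factor
lrs = [start_lr]
for _ in range(epochs * steps_per_epoch - 1):
    optimizer.step()   # one optimizer update per batch...
    scheduler.step()   # ...followed by one scheduler step per batch
    lrs.append(scheduler.get_last_lr()[0])
```

Because the scheduler sets the learning rate itself, `--learning_rate` does not determine the starting rate when `onecycle` is selected. Stepping past the total number of steps raises an error, which is why `steps_per_epoch` must match the dataloader length.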

In [None]:
def evaluation(model, dataloader_eval, exp_dir, cloud=False, cloud_dir=None, bucket=None):
    """
    Start model evaluation
    :param model: W2V2 model
    :param dataloader_eval: dataloader object with evaluation data
    :param exp_dir: specify LOCAL output directory as str
    :param cloud: boolean to specify whether to save everything to google cloud storage
    :param cloud_dir: if saving to the cloud, you can specify a specific place to save to in the CLOUD bucket
    :param bucket: google cloud storage bucket object
    :return preds: model predictions
    :return targets: model targets (actual values)
    """
    print('Evaluation start')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    outputs = []
    t = []
    model = model.to(device)
    with torch.no_grad():
        model.eval()
        for batch in tqdm(dataloader_eval):
            x = torch.squeeze(batch['waveform'], dim=1)
            x = x.to(device)
            targets = batch['targets']
            targets = targets.to(device)
            o = model(x)
            outputs.append(o)
            t.append(targets)

    outputs = torch.cat(outputs).cpu().detach()
    t = torch.cat(t).cpu().detach()
    # SAVE PREDICTIONS AND TARGETS 
    print('Saving predictions')
    pred_path = os.path.join(exp_dir, 'w2v2_eval_predictions.pt')
    target_path = os.path.join(exp_dir, 'w2v2_eval_targets.pt')
    torch.save(outputs, pred_path)
    torch.save(t, target_path)

    if cloud:
        upload(cloud_dir, pred_path, bucket)
        upload(cloud_dir, target_path, bucket)

    print('Evaluation finished')
    return outputs, t

In [None]:

# (5) start fine-tuning classification
model = finetune(model, dataloader_train, dataloader_val,
                    args.optim, args.learning_rate, args.loss, 
                    args.scheduler, args.max_lr, args.epochs,
                    args.exp_dir, args.cloud, args.cloud_dir, args.bucket)


In [None]:
print('Saving final epoch')

if model.weighted:
    mdl_path = os.path.join(args.exp_dir, '{}_{}_{}_epoch{}_{}_mdl_weighted.pt'.format(args.dataset, args.n_class, args.optim, args.epochs, os.path.basename(args.checkpoint)))
else:
    # use a local label so args.layer stays an int for later steps (e.g. embedding extraction)
    layer_name = 'Final' if args.layer == -1 else args.layer
    mdl_path = os.path.join(args.exp_dir, '{}_{}_{}_layer{}_epoch{}_{}_mdl.pt'.format(args.dataset, args.n_class, args.optim, layer_name, args.epochs, os.path.basename(args.checkpoint)))
torch.save(model.state_dict(), mdl_path)

if args.cloud:
    upload(args.cloud_dir, mdl_path, args.bucket)


In [None]:
# (6) start evaluating
preds, targets = evaluation(model, dataloader_test, args.exp_dir, args.cloud, args.cloud_dir, args.bucket)
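With `--loss BCE`, training uses `BCEWithLogitsLoss`, so the model outputs saved by `evaluation` are presumably raw logits, one per target label. A sketch of turning them into binary predictions and per-label accuracy; the tensors are toy values and the 0.5 threshold is an assumption.

```python
import torch

# Toy stand-ins for the (preds, targets) returned by evaluation(): 3 samples x 5 labels
preds = torch.tensor([[ 2.0, -1.0,  0.5, -3.0, -0.2],
                      [-0.5,  1.5, -2.0,  0.2,  4.0],
                      [ 0.1, -0.1,  3.0, -0.3, -2.0]])
targets = torch.tensor([[1., 0., 1., 0., 0.],
                        [0., 1., 0., 0., 1.],
                        [1., 1., 1., 0., 0.]])

probs = torch.sigmoid(preds)         # independent per-label probabilities
pred_labels = (probs > 0.5).float()  # 0.5 threshold is an assumption; tune per label if needed
per_label_accuracy = (pred_labels == targets).float().mean(dim=0)
```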


## Get Embeddings

Embedding extraction works slightly differently. We load a single csv file, initialize the model and load fine-tuned weights, then run the embedding loop. Depending on `embedding_type`, the loop extracts a hidden state of the pretrained model ('pt'), the weighted sum of layers ('wt'), or the output of the first (dense) layer of the classification head ('ft'), which serves as a 768-dimensional embedding.

In [None]:

args.data_split_root = 'gs://ml-e107-phi-shared-aif-us-p/speech_ai/share/data_splits/amr_subject_dedup_594_train_100_test_binarized_v20220620/test.csv'
args.embedding_type='ft' # 'ft': classification head dense layer; 'pt': hidden state of the pretrained model only; 'wt': weighted sum of layers
#args.finetuned_mdl_path = None #TODO: set to a fine-tuned model path to load it and extract embeddings from it; None uses only the pretrained checkpoint

In [None]:
# Get original model arguments
model_args, args.finetuned_mdl_path = setup_mdl_args(args, args.finetuned_mdl_path)

In [None]:
print('Running Embedding Extraction: ')
# (1) load data to get embeddings for
assert '.csv' in args.data_split_root, f'A csv file is necessary for embedding extraction. Please make sure this is a full file path: {args.data_split_root}'
annotations_df = pd.read_csv(args.data_split_root, index_col = 'uid') #data_split_root should use the CURRENT arguments regardless of the finetuned model
annotations_df["distortions"]=((annotations_df["distorted Cs"]+annotations_df["distorted V"])>0).astype(int)


In [None]:
# (2) set up audio configuration for transforms
audio_conf = {'checkpoint': args.checkpoint, 'resample_rate':args.resample_rate, 'reduce': args.reduce,
                'trim': args.trim, 'clip_length': args.clip_length}


In [None]:
# (3) set up dataloaders
waveform_dataset = W2V2Dataset(annotations_df = annotations_df, target_labels = model_args.target_labels,
                                audio_conf = audio_conf, prefix=args.prefix, bucket=args.bucket, librosa=args.lib) #not super important for embeddings, but the dataset should be selecting targets based on the FINETUNED model

dataloader = DataLoader(waveform_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)


In [None]:
# (4) set up embedding model
model = Wav2Vec2ForSpeechClassification(checkpoint=model_args.checkpoint, num_labels = model_args.n_class, pooling_mode = model_args.pooling_mode, 
                                        freeze=model_args.freeze, activation=model_args.activation, final_dropout=model_args.final_dropout, 
                                        layernorm=model_args.layernorm, weighted=model_args.weighted, layer=model_args.layer)   #should look like the finetuned model (so using model_args). If pretrained model, will resort to current args

if args.finetuned_mdl_path is not None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sd = torch.load(args.finetuned_mdl_path, map_location=device)
    model.load_state_dict(sd, strict=False)
else:
    print(f'Extracting embeddings from only a pretrained model: {model_args.checkpoint}. Extraction method changed to pt.')
    args.embedding_type = 'pt' #manually change the type to 'pt' if not given a finetuned mdl path.


Again, the embedding extraction loop is implemented in `loops.py`, but we include it here for context.

In [None]:
def embedding_extraction(model, dataloader, embedding_type='ft', layer=-1, pooling_mode='mean'):
    """
    Run a specific subtype of evaluation for getting embeddings.
    :param model: W2V2 model
    :param dataloader: dataloader object with data to get embeddings for
    :param embedding_type: string specifying whether embeddings should be extracted from the classification head (ft), the base pretrained model (pt), or the weighted sum of layers (wt)
    :param layer: hidden state to extract; must be between -1 and 12 (-1 is the final layer)
    :param pooling_mode: how to pool over time ('mean', 'sum', or 'max')
    :return embeddings: an np array containing the embeddings
    """
    print('Getting embeddings')
    embeddings = []

    # send to gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    with torch.no_grad():
        model.eval()
        for batch in tqdm(dataloader):
            x = torch.squeeze(batch['waveform'], dim=1)
            x = x.to(device)
            e = model.extract_embedding(x, embedding_type, layer=layer, pooling_mode=pooling_mode)
            # collect per-batch arrays and concatenate once at the end (avoids repeated np.append copies)
            embeddings.append(e.cpu().numpy())

    if not embeddings:
        return np.array([])
    return np.concatenate(embeddings, axis=0)


Note the different arguments to `embedding_extraction`.
You can change the embedding type to:
* `ft`: extracts the embedding from the dense layer of the classification head
* `pt`: extracts the embedding from a hidden state (specified with `layer`)
* `wt`: extracts the embedding after weighted sum of layers

You can also change the pooling mode to one of 'mean', 'sum', 'max'.
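A toy illustration of the three pooling modes, assuming pooling collapses the time axis of a `(batch, time, features)` hidden state into one vector per utterance. `pool_over_time` is a hypothetical helper, not `extract_embedding`.

```python
import torch


def pool_over_time(hidden, pooling_mode='mean'):
    # hidden: (batch, time, features) -> (batch, features)
    if pooling_mode == 'mean':
        return hidden.mean(dim=1)
    if pooling_mode == 'sum':
        return hidden.sum(dim=1)
    if pooling_mode == 'max':
        return hidden.max(dim=1).values
    raise ValueError(f'Unsupported pooling mode: {pooling_mode}')


# One utterance, three time steps, two features
hidden = torch.tensor([[[1., 2.], [3., 0.], [2., 4.]]])
mean_emb = pool_over_time(hidden, 'mean')  # tensor([[2., 2.]])
sum_emb = pool_over_time(hidden, 'sum')    # tensor([[6., 6.]])
max_emb = pool_over_time(hidden, 'max')    # tensor([[3., 4.]])
```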

The main embedding extraction code is implemented in the `Wav2Vec2ForSpeechClassification` class in `w2v2_models.py`, under `extract_embedding(...)`.

In [None]:
# (5) get embeddings
embeddings = embedding_extraction(model, dataloader, args.embedding_type, 
                                  args.layer, args.pooling_mode)
      

In [None]:
df_embed = pd.DataFrame([[r] for r in embeddings], columns = ['embedding'], index=annotations_df.index)


### Save embeddings

In [None]:
# build file-name labels without overwriting args.layer / args.pooling_mode
if args.embedding_type == 'ft':
    layer_name, pool_name = 'NA', 'NA'
elif args.embedding_type == 'wt':
    layer_name, pool_name = 'NA', args.pooling_mode
else:
    layer_name = 'Final' if args.layer == -1 else args.layer
    pool_name = args.pooling_mode

try:
    pqt_path = '{}/{}_layer{}_{}_w2v2_{}_embeddings.pqt'.format(args.exp_dir, args.dataset, layer_name, pool_name, args.embedding_type)
    
    df_embed.to_parquet(path=pqt_path, index=True, engine='pyarrow') #TODO: fix

    if args.cloud:
        upload(args.cloud_dir, pqt_path, args.bucket)
except Exception as err:
    print(f'Unable to save as pqt ({err}), saving instead as csv')
    csv_path = '{}/{}_layer{}_{}_w2v2_{}_embeddings.csv'.format(args.exp_dir, args.dataset, layer_name, pool_name, args.embedding_type)
    
    df_embed.to_csv(csv_path, index=True)

    if args.cloud:
        upload(args.cloud_dir, csv_path, args.bucket)