# Timm

We use timm model architectures and create a wrapper class for adding a classification head and extracting embeddings.

Note: [run.py](https://github.com/dwiepert/mayo-timm/blob/main/src/run.py) script can also do evaluation only. 

Authors: Daniela Wiepert

The environment must include the following packages, all of which can be downloaded with pip or conda:
* albumentations
* librosa
* torch, torchvision, torchaudio
* tqdm (this is essentially enumerate(dataloader) except it prints out a nice progress bar for you)
* speechbrain 
* pyarrow
* timm

If running on your local machine and not in a GCP environment, you will also need to install:
* google-cloud-storage

The [requirements.txt](https://github.com/dwiepert/mayo-ecapa-tdnn/blob/main/requirements.txt) can be used to set up this environment. 

To access data stored in GCS on your local machine, you will need to additionally run

```gcloud auth application-default login```

```gcloud auth application-default set-quota-project PROJECT_NAME```

Please note that if using GCS, the model expects arguments like model paths or directories to start with `gs://BUCKET_NAME/...` with the exception of defining an output cloud directory which should just be the prefix to save within a bucket. 

In [None]:
#IMPORTS
#built-in
import argparse
import os
import pickle

#third-party
import numpy as np
import torch
import pandas as pd
import pyarrow

from google.cloud import storage
from torch.utils.data import  DataLoader

#local
from utilities import *
from models import *
from dataloader import AudioDataset
from loops import *

### Upload/Download functions & Data loading functions
These are defined in `utilities/load_utils.py`.

## Arguments
There are many mutable arguments when running. Please explore the different options and make sure all arguments are set as you would like. 

In [None]:
def _str2bool(value):
    """
    Convert a command-line token into a real boolean.
    argparse's `type=bool` calls bool() on the raw string, so every non-empty value
    (including 'False' and '0') becomes True. This parser maps the usual spellings explicitly.
    :param value: raw command-line string (or an actual bool)
    :return: True or False
    :raises argparse.ArgumentTypeError: if the token is not a recognised boolean spelling
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string is kept as False because bool('') was already False with type=bool
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value (true/false), got: {value!r}')


parser = argparse.ArgumentParser()
#Inputs
parser.add_argument('-i','--prefix',default='speech_ai/speech_lake/', help='Input directory or location in google cloud storage bucket containing files to load')
parser.add_argument("-s", "--study", choices = ['r01_prelim','speech_poc_freeze_1', None], default='speech_poc_freeze_1', help="specify study name")
parser.add_argument("-d", "--data_split_root", default='gs://ml-e107-phi-shared-aif-us-p/speech_ai/share/data_splits/amr_subject_dedup_594_train_100_test_binarized_v20220620/test.csv', help="specify file path where datasplit is located. If you give a full file path to classification, an error will be thrown. On the other hand, evaluation and embedding expects a single .csv file.")
parser.add_argument('-l','--label_txt', default='src/labels.txt')
parser.add_argument('--lib', default=False, type=_str2bool, help="Specify whether to load using librosa as compared to torch audio")
parser.add_argument("--trained_mdl_path", default=None, help="specify path to trained model")
parser.add_argument("--model_type", default='efficientnet_b0', help='specify the timm model type to initialize')
#GCS
parser.add_argument('-b','--bucket_name', default='ml-e107-phi-shared-aif-us-p', help="google cloud storage bucket name")
parser.add_argument('-p','--project_name', default='ml-mps-aif-afdgpet01-p-6827', help='google cloud platform project name')
parser.add_argument('--cloud', default=False, type=_str2bool, help="Specify whether to save everything to cloud")
#output
parser.add_argument("--dataset", default=None,type=str, help="When saving, the dataset arg is used to set file names. If you do not specify, it will assume the lowest directory from data_split_root")
parser.add_argument("-o", "--exp_dir", default='./experiments', help='specify LOCAL output directory')
parser.add_argument('--cloud_dir', default='', type=str, help="if saving to the cloud, you can specify a specific place to save to in the CLOUD bucket")
#Mode specific
parser.add_argument("-m", "--mode", choices=['train','eval','extraction'], default='train')
parser.add_argument('--embedding_type', type=str, default='ft', help='specify whether embeddings should be extracted from classification head (ft) or base pretrained model (pt)', choices=['ft','pt'])
#Audio configuration parameters
parser.add_argument("--dataset_mean", default=-4.2677393, type=float, help="the dataset mean, used for input normalization")
parser.add_argument("--dataset_std", default=4.5689974, type=float, help="the dataset std, used for input normalization")
parser.add_argument("--target_length", default=1024, type=int, help="the input length in frames")
parser.add_argument("--num_mel_bins", default=128,type=int, help="number of input mel bins")
parser.add_argument("--resample_rate", default=16000,type=int, help='resample rate for audio files')
parser.add_argument("--reduce", default=True, type=_str2bool, help="Specify whether to reduce to monochannel")
# the default is the float 10.0, so the type is float as well (int strings such as "10" still parse)
parser.add_argument("--clip_length", default=10.0, type=float, help="If truncating audio, specify clip length in seconds. 0 = no truncation")
parser.add_argument("--tshift", default=0, type=float, help="Specify p for time shift transformation")
parser.add_argument("--speed", default=0, type=float, help="Specify p for speed tuning")
parser.add_argument("--gauss", default=0, type=float, help="Specify p for adding gaussian noise")
parser.add_argument("--pshift", default=0, type=float, help="Specify p for pitch shifting")
parser.add_argument("--pshiftn", default=0, type=float, help="Specify number of steps for pitch shifting")
parser.add_argument("--gain", default=0, type=float, help="Specify p for gain")
parser.add_argument("--stretch", default=0, type=float, help="Specify p for audio stretching")
parser.add_argument('--freqm', help='frequency mask max length', type=int, default=0)
parser.add_argument('--timem', help='time mask max length', type=int, default=0)
parser.add_argument("--mixup", type=float, default=0, help="how many (0-1) samples need to be mixup during training")
parser.add_argument("--noise", type=_str2bool, default=False, help="specify if augment noise in finetuning")
parser.add_argument("--skip_norm", type=_str2bool, default=False, help="specify whether to skip normalization on spectrogram")
#Model parameters
parser.add_argument("-bs", "--batch_size", type=int, default=8, help="specify batch size")
parser.add_argument("-nw", "--num_workers", type=int, default=0, help="specify number of parallel jobs to run for data loader")
parser.add_argument("-lr", "--learning_rate", type=float, default=0.0003, help="specify learning rate")
parser.add_argument("-e", "--epochs", type=int, default=1, help="specify number of training epochs")
parser.add_argument("--optim", type=str, default="adam", help="training optimizer", choices=["adam", "adamw"])
parser.add_argument("--weight_decay", type=float, default=.0001, help='specify weight decay for adamw')
parser.add_argument("--loss", type=str, default="BCE", help="the loss function for finetuning, depend on the task", choices=["MSE", "BCE"])
parser.add_argument("--scheduler", type=str, default="onecycle", help="specify lr scheduler", choices=["onecycle", None])
parser.add_argument("--max_lr", type=float, default=0.01, help="specify max lr for lr scheduler")
#classification head parameters
parser.add_argument("--activation", type=str, default='relu', help="specify activation function to use for classification head")
parser.add_argument("--final_dropout", type=float, default=0.3, help="specify dropout probability for final dropout layer in classification head")
parser.add_argument("--layernorm", type=_str2bool, default=False, help="specify whether to include the LayerNorm in classification head")
# parse_known_args tolerates argv entries that do not belong to this parser (for example the
# `-f <connection file>` pair a Jupyter kernel is started with), where parse_args() would exit.
# Anything ignored is printed so that a mistyped flag does not go unnoticed.
args, _ignored_argv = parser.parse_known_args()
if _ignored_argv:
    print('Ignoring unrecognized arguments: ', _ignored_argv)

## Setting up environment
The first step is to make sure the GCS bucket is initialized if given a `bucket_name`. Additionally, the list of target labels must be set. There are a few other arguments to consider as well.

In the original implementation, the list must be given as a `.txt` file to pass through the command line. In this implementation, we will set it as a list.

In [None]:
# Report the installed torch build, whether CUDA can be used, and the CUDA version torch reports.
for label, value in (
    ('Torch version: ', torch.__version__),
    ('Cuda availability: ', torch.cuda.is_available()),
    ('Cuda version: ', torch.version.cuda),
):
    print(label, value)

In [None]:
#variables
# (1) Set up GCS: `bucket` stays None unless a bucket name was supplied,
# in which case a storage client is created for the project and the bucket handle is taken from it.
bucket = None
if args.bucket_name is not None:
    storage_client = storage.Client(project=args.project_name)
    bucket = storage_client.bucket(args.bucket_name)

In [None]:
# (2) check if given study or if the prefix is the full prefix.
# When a study name is set, args.prefix is treated as a parent location and the study is appended
# as a sub-path; with study=None the prefix is used exactly as given.
if args.study is not None:
    args.prefix = os.path.join(args.prefix, args.study)


In [None]:
# (3) get dataset name (only derived when the user did not pass --dataset)
if args.dataset is None:
    if args.trained_mdl_path is not None and args.mode != 'train':
        # a trained model is being reused outside of training: name after the model file,
        # dropping its last 7 characters
        args.dataset = os.path.basename(args.trained_mdl_path)[:-7]
    elif '.csv' in args.data_split_root:
        # csv split file: "<parent directory>_<file name without its last 4 characters>"
        args.dataset = f'{os.path.basename(os.path.dirname(args.data_split_root))}_{os.path.basename(args.data_split_root[:-4])}'
    else:
        # split directory: use its lowest path component
        args.dataset = os.path.basename(args.data_split_root)

In [None]:
# (4) get target labels
# The labels are hard-coded here instead of being read from a .txt file.
args.target_labels = [
    'slow rate',
    'irregular artic breakdowns',
    'rapid rate',
    'distortions',
    'strained',
]

# number of classes follows directly from the label list
args.n_class = len(args.target_labels)

In [None]:

# (5) make sure the LOCAL output directory exists, SHOULD NOT BE A GS:// path
# exist_ok=True creates the directory (and missing parents) in one step and avoids the
# check-then-create gap of testing os.path.exists first.
os.makedirs(args.exp_dir, exist_ok=True)


In [None]:
# (6) check that clip length has been set
# clip_length == 0 means no truncation, so wav files can differ in length; that is only
# supported with a batch size of 1. The batch size is overridden with an explicit condition
# instead of an assert inside a bare try/except, which stops working when asserts are
# stripped (python -O).
if args.clip_length == 0 and args.batch_size != 1:
    print(f'clip_length is 0 (no truncation): overriding batch_size {args.batch_size} with 1')
    args.batch_size = 1


In [None]:
# (7) dump arguments: pickle the full argument namespace into the experiment directory
args_path = f"{args.exp_dir}/args.pkl"
with open(args_path, "wb") as args_file:
    pickle.dump(args, args_file)

# in case of error later on, the arguments are uploaded to the bucket right away
if args.cloud:
    upload(args.cloud_dir, args_path, bucket)


In [None]:
# (8) check if trained model is stored in gcs bucket or confirm it exists on local machine
# gcs_model_exists is a project helper that is not defined in this notebook; whatever it returns
# replaces args.trained_mdl_path (presumably a local path to the model file - TODO confirm).
if args.trained_mdl_path is not None:
    args.trained_mdl_path = gcs_model_exists(args.trained_mdl_path, args.bucket_name, args.exp_dir, bucket)


In [None]:
 #(9) add bucket to args
# Stores the bucket handle (or None when no bucket name was given) on args so that later cells
# can reach it as args.bucket.
args.bucket = bucket

## Training
We will start with the training option and not include the option for only evaluating an already trained model (which is available in the full .py script)

When loading data, we start with a data split root, which we expect to be a directory containing a `train.csv` file and `test.csv` file with file names for train/test and the associated label data.

We load the data, set up an audio configuration, set up `AudioDataset` objects (from `dataloader.py`), and set up the dataloaders.

The transforms in `AudioDataset` are set up using functions from `utilities/speech_utils.py`. 

The dataloaders take in the datasets and batch size + number of workers.

Please note that the resulting samples will be a dictionary with the keys `uid`, `waveform`, `targets`, `sample_rate`, `fbank`.

We randomly sample the train.csv within the load_data function to get a validation dataset of 50 samples. See `load_utils.py` for the implementation.

In [None]:
print('Running training: ')
# (1) load data
# Training needs a split DIRECTORY, so a data_split_root containing '.csv' is rejected here.
# NOTE(review): the default --data_split_root ends in test.csv, so this assertion fails unless
# data_split_root is overridden before this cell runs.
assert '.csv' not in args.data_split_root, f'May have given a full file path, please confirm this is a directory: {args.data_split_root}'
# load_data is a project helper (not defined in this notebook); it returns the train, validation
# and test annotation tables in that order.
train_df, val_df, test_df = load_data(args.data_split_root, args.target_labels, args.exp_dir, args.cloud, args.cloud_dir, args.bucket)


In [None]:
#(2) set audio configurations (val loader and eval loader will both use the eval_audio_conf)
train_audio_conf = dict(
    dataset=args.dataset, mode='train', resample_rate=args.resample_rate, reduce=args.reduce, clip_length=args.clip_length,
    tshift=args.tshift, speed=args.speed, gauss_noise=args.gauss, pshift=args.pshift, pshiftn=args.pshiftn, gain=args.gain, stretch=args.stretch,
    num_mel_bins=args.num_mel_bins, target_length=args.target_length, freqm=args.freqm, timem=args.timem, mixup=args.mixup, noise=args.noise,
    mean=args.dataset_mean, std=args.dataset_std, skip_norm=args.skip_norm,
)

# The evaluation configuration is the training one with only two entries changed:
# the mode, and mixup, which should always be 0 for the evaluation.
eval_audio_conf = {**train_audio_conf, 'mode': 'evaluation', 'mixup': 0}


In [None]:
 #(3) Generate audio dataset, note that if bucket not given, it assumes None and loads from local files
# arguments shared by the three datasets (librosa = True might need debugging)
_dataset_kwargs = dict(target_labels=args.target_labels, prefix=args.prefix, bucket=args.bucket, librosa=args.lib)

# training data uses the training audio configuration; validation and test share the evaluation one
train_dataset = AudioDataset(annotations_df=train_df, audio_conf=train_audio_conf, **_dataset_kwargs)
val_dataset = AudioDataset(annotations_df=val_df, audio_conf=eval_audio_conf, **_dataset_kwargs)
eval_dataset = AudioDataset(annotations_df=test_df, audio_conf=eval_audio_conf, **_dataset_kwargs)


In [None]:
#(4) set up data loaders (val loader always has batchsize 1)
# settings common to all three loaders
_loader_kwargs = dict(num_workers=args.num_workers, collate_fn=collate_fn)

# only the training loader shuffles; only the validation loader pins memory
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True, **_loader_kwargs)
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, **_loader_kwargs)


In [None]:
#TODO: TEST WHETHER YOU CAN LOAD A BATCH
# Pulls a single batch from the training loader as a smoke test of the dataset/collate pipeline;
# the batch itself is only kept in `batch` for inspection.
batch = next(iter(train_loader))

### Set up the model

Set up the model using classes from `timm_models.py`. This includes a wrapper for a speech classification model that adds on a classification head with a dense layer, ReLU, dropout, and a final linear layer

In [None]:
 # (4) initialize model
# timmForSpeechClassification comes from the project's models package (not defined here); it is
# given the timm architecture name, the number of classes, and the classification-head settings
# (activation, final dropout probability, whether to use LayerNorm).
model = timmForSpeechClassification(args.model_type, args.n_class, args.activation, args.final_dropout, args.layernorm)


### Run training, evaluation

The model training loops are originally implemented in `loops.py`, but we will include them here for context.

In [None]:
def validation(model, criterion, dataloader_val):
    '''
    Validation loop for training. Runs the model over the validation data without
    gradient tracking and records the loss of every batch.
    Note that the model is switched to eval mode and is left in that mode on return.
    :param model: model
    :param criterion: loss function, called as criterion(model_output, targets)
    :param dataloader_val: dataloader object with validation data; each batch must provide 'fbank' and 'targets'
    :return validation_loss: list with validation loss for each batch
    '''
    # pick the device once and move everything with .to(device) instead of trying .cuda()
    # and falling back through a bare except
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    validation_loss = []
    with torch.no_grad():
        model.eval()
        for batch in tqdm(dataloader_val):
            # insert a new axis at position 1 of the fbank batch before feeding the model
            x = torch.unsqueeze(batch['fbank'], 1).to(device)
            targets = batch['targets'].to(device)
            o = model(x)
            validation_loss.append(criterion(o, targets).item())

    return validation_loss

In [None]:
def train(model, dataloader_train, dataloader_val = None, 
             optim='adamw', learning_rate=0.001, weight_decay=0.001, 
             loss_fn='BCE', sched='onecycle', max_lr=0.01,
             epochs=10, exp_dir='', cloud=False, cloud_dir='', bucket=None):
    """
    Training loop for finetuning.
    Every 10th epoch (epochs 0, 10, 20, ...) the loop writes a json log, the model state dict and
    the optimizer state dict to exp_dir, runs validation if a validation loader was given, and
    optionally uploads the three files.
    :param model: model
    :param dataloader_train: dataloader object with training data; each batch must provide 'fbank' and 'targets'
    :param dataloader_val: dataloader object with validation data (None skips validation)
    :param optim: type of optimizer to initialize ('adam' or 'adamw')
    :param learning_rate: optimizer learning rate
    :param weight_decay: weight decay value for adamw optimizer
    :param loss_fn: type of loss function to initialize ('MSE' or 'BCE')
    :param sched: type of scheduler to initialize ('onecycle'; anything else means no scheduler)
    :param max_lr: max learning rate for onecycle scheduler
    :param epochs: number of epochs to run training
    :param exp_dir: output directory on local machine
    :param cloud: boolean indicating whether uploading to cloud
    :param cloud_dir: output directory in google cloud storage bucket
    :param bucket: initialized GCS bucket object
    :return model: finetuned model
    :raises ValueError: if loss_fn or optim is not one of the supported values
    """
    # json is needed for the logs but is not imported at the top of the notebook
    import json

    print('Finetuning start')
    # pick the device once and move everything with .to(device) instead of trying .cuda()
    # and falling back through a bare except
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    #loss
    if loss_fn == 'MSE':
        criterion = torch.nn.MSELoss()
    elif loss_fn == 'BCE':
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        raise ValueError(f'Given loss function ({loss_fn}) not supported. Must be either MSE or BCE')

    #optimizer (only parameters that require gradients are optimized)
    if optim == 'adam':
        optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=learning_rate)
    elif optim == 'adamw':
        optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=learning_rate, weight_decay=weight_decay)
    else:
        raise ValueError(f'Given optimizer ({optim}) not supported. Must be either adam or adamw')

    #scheduler (stepped once per batch, hence steps_per_epoch = number of training batches)
    if sched == 'onecycle':
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=len(dataloader_train), epochs=epochs)
    else:
        scheduler = None

    #train
    for e in range(epochs):
        training_loss = []
        model.train()  # validation() leaves the model in eval mode, so reset it every epoch
        for batch in tqdm(dataloader_train):
            # insert a new axis at position 1 of the fbank batch before feeding the model
            x = torch.unsqueeze(batch['fbank'], 1).to(device)
            targets = batch['targets'].to(device)
            optimizer.zero_grad()
            o = model(x)
            loss = criterion(o, targets)
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            training_loss.append(loss.item())

        # logging, validation and checkpointing only happen on every 10th epoch
        if e % 10 != 0:
            continue

        #SET UP LOGS
        if scheduler is not None:
            lr = scheduler.get_last_lr()
        else:
            lr = learning_rate
        logs = {'epoch': e, 'optim':optim, 'loss_fn': loss_fn, 'lr': lr, 'scheduler':sched}

        logs['training_loss_list'] = training_loss
        training_loss = np.array(training_loss)
        logs['running_loss'] = np.sum(training_loss)
        logs['training_loss'] = np.mean(training_loss)

        print('RUNNING LOSS', e, np.sum(training_loss) )
        print(f'Training loss: {np.mean(training_loss)}')

        if dataloader_val is not None:
            print("Validation start")
            validation_loss = validation(model, criterion, dataloader_val)

            logs['val_loss_list'] = validation_loss
            validation_loss = np.array(validation_loss)
            logs['val_running_loss'] = np.sum(validation_loss)
            logs['val_loss'] = np.mean(validation_loss)

            print('RUNNING VALIDATION LOSS',e, np.sum(validation_loss) )
            print(f'Validation loss: {np.mean(validation_loss)}')

        #SAVE LOGS
        # NOTE(review): the logs are encoded twice (dumps, then dump of the resulting string), so
        # the file holds a JSON string that itself contains JSON. Kept as-is because existing
        # readers of these files may decode twice - confirm before switching to json.dump(logs, ...).
        json_string = json.dumps(logs)
        logs_path = os.path.join(exp_dir, 'logs_ft_epoch{}.json'.format(e))
        with open(logs_path, 'w') as outfile:
            json.dump(json_string, outfile)

        #SAVE CURRENT MODEL
        print(f'Saving epoch {e}')
        mdl_path = os.path.join(exp_dir, 'timm_ft_mdl_epoch{}.pt'.format(e))
        torch.save(model.state_dict(), mdl_path)

        optim_path = os.path.join(exp_dir, 'timm_ft_optim_epoch{}.pt'.format(e))
        torch.save(optimizer.state_dict(), optim_path)

        if cloud:
            upload(cloud_dir, logs_path, bucket)
            upload(cloud_dir, mdl_path, bucket)
            upload(cloud_dir, optim_path, bucket)

    print('Finetuning finished')
    return model

In [None]:
def evaluation(model, dataloader_eval):
    """
    Start model evaluation. Runs the model over the evaluation data without gradient
    tracking and gathers the raw model outputs together with the targets.
    Note that the model is switched to eval mode and is left in that mode on return.
    :param model: model
    :param dataloader_eval: dataloader object with evaluation data; each batch must provide 'fbank' and 'targets'
    :return outputs: model outputs for all batches, concatenated and moved to the cpu
    :return t: model targets (actual values) for all batches, concatenated and moved to the cpu
    """
    print('Evaluation start')
    # pick the device once and move everything with .to(device) instead of trying .cuda()
    # and falling back through a bare except
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    outputs = []
    t = []
    with torch.no_grad():
        model.eval()
        for batch in tqdm(dataloader_eval):
            # insert a new axis at position 1 of the fbank batch before feeding the model
            x = torch.unsqueeze(batch['fbank'], 1).to(device)
            targets = batch['targets'].to(device)
            outputs.append(model(x))
            t.append(targets)

    # stitch the per-batch tensors together and bring them back to the cpu
    outputs = torch.cat(outputs).cpu().detach()
    t = torch.cat(t).cpu().detach()
    print('Evaluation finished')
    return outputs, t

In [None]:
#(5) run model
# every setting is passed by keyword so the mapping onto train()'s parameters is explicit
model = train(
    model=model,
    dataloader_train=train_loader,
    dataloader_val=val_loader,
    optim=args.optim,
    learning_rate=args.learning_rate,
    weight_decay=args.weight_decay,
    loss_fn=args.loss,
    sched=args.scheduler,
    max_lr=args.max_lr,
    epochs=args.epochs,
    exp_dir=args.exp_dir,
    cloud=args.cloud,
    cloud_dir=args.cloud_dir,
    bucket=args.bucket,
)

In [None]:
 
print('Saving final model')
# the file name encodes dataset, number of classes, optimizer, epoch count and timm model type
mdl_path = os.path.join(
    args.exp_dir,
    f'{args.dataset}_{args.n_class}_{args.optim}_epoch{args.epochs}_{args.model_type}_mdl.pt',
)
torch.save(model.state_dict(), mdl_path)

# optionally mirror the saved state dict to the cloud bucket
if args.cloud:
    upload(args.cloud_dir, mdl_path, args.bucket)

In [None]:
 # (6) start evaluating
# Runs the finetuned model over the test loader; preds are the raw model outputs and targets the
# matching labels, both returned on the cpu by evaluation().
preds, targets = evaluation(model, eval_loader)
    

In [None]:
print('Saving predictions and targets')
# both files share one stem: dataset, number of classes, optimizer, epoch count and timm model type
_run_tag = f'{args.dataset}_{args.n_class}_{args.optim}_epoch{args.epochs}_{args.model_type}'
pred_path = os.path.join(args.exp_dir, f'{_run_tag}_predictions.pt')
target_path = os.path.join(args.exp_dir, f'{_run_tag}_targets.pt')

torch.save(preds, pred_path)
torch.save(targets, target_path)

# optionally mirror both files to the cloud bucket, predictions first
if args.cloud:
    for _saved_path in (pred_path, target_path):
        upload(args.cloud_dir, _saved_path, args.bucket)

## Get Embeddings

Embedding extraction is a slightly different process. We instead load in one csv file, initialize and load a trained model, then run the embedding loop, which extracts the embedding from the output of the base model if specified as 'pt' or from the output of the first layer of the classification head if specified as 'ft'.

In [None]:
# Embedding extraction works on a single csv file, so data_split_root is pointed at one here.
args.data_split_root = 'gs://ml-e107-phi-shared-aif-us-p/speech_ai/share/data_splits/amr_subject_dedup_594_train_100_test_binarized_v20220620/test.csv'
# 'ft' takes embeddings from the classification head, 'pt' from the base pretrained model only
# (these are the two values accepted by --embedding_type)
args.embedding_type='ft'
#args.trained_mdl_path = None #TODO: must set to a finetuned model if you want it to load and get embeddings in that way.
# not a required step, but embeddings do not require target labels to run.
args.target_labels = []
args.n_class = 0

In [None]:
print('Running Embedding Extraction: ')
# a trained model is mandatory here: the embeddings are produced by loading its weights
assert args.trained_mdl_path is not None, 'must give a model to load for embedding extraction. '
# Get the arguments of the original (trained) model. load_args is a project helper that is not
# defined in this notebook; model_args is used later to rebuild the model architecture.
model_args = load_args(args, args.trained_mdl_path)

In [None]:
# (1) load data to get embeddings for
assert '.csv' in args.data_split_root, f'A csv file is necessary for embedding extraction. Please make sure this is a full file path: {args.data_split_root}'
# data_split_root should use the CURRENT arguments regardless of the finetuned model
annotations_df = pd.read_csv(args.data_split_root, index_col = 'uid')

# derive the binary "distortions" label when it is requested but missing from the csv:
# 1 if the two distortion columns sum to more than 0, otherwise 0
if 'distortions' in args.target_labels:
    if 'distortions' not in annotations_df.columns:
        annotations_df["distortions"] = (
            (annotations_df["distorted Cs"] + annotations_df["distorted V"]) > 0
        ).astype(int)


In [None]:
#(2) set audio configurations
args.mixup = 0  # mixup should always be 0 for embedding extraction
audio_conf = {
    'dataset': args.dataset,
    'mode': 'evaluation',
    'resample_rate': args.resample_rate,
    'reduce': args.reduce,
    'clip_length': args.clip_length,
    'tshift': args.tshift,
    'speed': args.speed,
    'gauss_noise': args.gauss,
    'pshift': args.pshift,
    'pshiftn': args.pshiftn,
    'gain': args.gain,
    'stretch': args.stretch,
    'num_mel_bins': args.num_mel_bins,
    'target_length': args.target_length,
    'freqm': args.freqm,
    'timem': args.timem,
    'mixup': args.mixup,
    'noise': args.noise,
    'mean': args.dataset_mean,
    'std': args.dataset_std,
    'skip_norm': args.skip_norm,
}


In [None]:
# (3) set up dataloader with current args
dataset = AudioDataset(
    annotations_df=annotations_df,
    target_labels=args.target_labels,
    audio_conf=audio_conf,
    prefix=args.prefix,
    bucket=args.bucket,
    librosa=args.lib,
)
# no shuffling, so the batches follow the order of the dataset
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    collate_fn=collate_fn,
)


In [None]:
# (4) initialize model
# The architecture is rebuilt from the ORIGINAL model's arguments (model_args), not the current ones.
model = timmForSpeechClassification(model_args.model_type, model_args.n_class, model_args.activation, model_args.final_dropout, model_args.layernorm)

# load the saved weights onto whichever device is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sd = torch.load(args.trained_mdl_path, map_location=device)
# NOTE(review): strict=False means keys that do not match between the checkpoint and the model
# are skipped rather than raising, so a wrong checkpoint can load without any error - consider
# inspecting the value returned by load_state_dict.
model.load_state_dict(sd, strict=False)

Again, the embedding extraction loop is implemented in `loops.py`, but we will include it here for context

In [None]:
def embedding_extraction(model, dataloader, embedding_type='ft'):
    """
    Run a specific subtype of evaluation for getting embeddings.
    The embeddings of all batches are collected and joined along the first axis.
    Note that the model is switched to eval mode and is left in that mode on return.
    :param model: model; must provide extract_embedding(x, embedding_type)
    :param dataloader: dataloader object with data to get embeddings for; each batch must provide 'fbank'
    :param embedding_type: string specifying whether embeddings should be extracted from classification head (ft) or base pretrained model (pt)
    :return embeddings: an np array containing the embeddings (an empty array if the dataloader yields no batches)
    """

    print('Calculating Embeddings')
    # pick the device once and move everything with .to(device) instead of trying .cuda()
    # and falling back through a bare except
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # batches are gathered in a list and joined once at the end; appending to a growing array
    # inside the loop copies every previously collected embedding on each step
    chunks = []
    with torch.no_grad():
        model.eval()
        for batch in tqdm(dataloader):
            # insert a new axis at position 1 of the fbank batch before feeding the model
            x = torch.unsqueeze(batch['fbank'], 1).to(device)
            e = model.extract_embedding(x, embedding_type)
            # accept tensors as well as arrays so the result is always a numpy array
            if torch.is_tensor(e):
                e = e.detach().cpu().numpy()
            chunks.append(e)

    if not chunks:
        return np.array([])
    return np.concatenate(chunks, axis=0)

def calc_auc(preds, targets, target_labels,
         exp_dir, cloud, cloud_dir, bucket):
    """
    Get AUC scores per label, save them to <exp_dir>/aucs.csv and return them.
    Rows whose targets contain a NaN are dropped from both preds and targets first.
    :param preds: model predictions (raw outputs; a sigmoid is applied here)
    :param targets: model targets (actual values)
    :param target_labels: list of label names, one per target column
    :param exp_dir: output directory on local machine
    :param cloud: boolean indicating whether uploading to cloud
    :param cloud_dir: output directory in google cloud storage bucket
    :param bucket: initialized GCS bucket object
    :return data: DataFrame with the columns 'Label' and 'AUC'
    """
    # keep only rows without NaN targets; the same mask must be applied to BOTH tensors,
    # otherwise preds and targets no longer line up row by row
    keep = targets.isnan().sum(1) == 0
    preds = preds[keep]
    targets = targets[keep]

    #get AUC score and all data for ROC curve
    pred_mat = torch.sigmoid(preds).numpy()
    target_mat = targets.numpy()
    aucs = roc_auc_score(target_mat, pred_mat, average = None) #TODO: this doesn't work when there is an array with all labels as 0???
    print(aucs)

    data = pd.DataFrame({'Label':target_labels, 'AUC':aucs})
    auc_path = os.path.join(exp_dir, 'aucs.csv')
    data.to_csv(auc_path, index=False)
    if cloud:
        upload(cloud_dir, auc_path, bucket)

    return data


In [None]:
# (5) get embeddings
# One embedding per sample in `loader`, in loader order (the loader was built with shuffle=False).
embeddings = embedding_extraction(model, loader, args.embedding_type)

In [None]:
# Wrap every embedding in its own one-element row so the whole vector lands in a single
# 'embedding' column, indexed by the index of the annotations table (read with index_col='uid').
df_embed = pd.DataFrame([[r] for r in embeddings], columns = ['embedding'], index=annotations_df.index)


In [None]:
# Save the embeddings as parquet, falling back to csv if the parquet write fails.
# Only the parquet write is inside the try, so an upload error is no longer reported as a
# parquet failure; the reason for a fallback is printed instead of being swallowed.
pqt_path = '{}/{}_{}_embeddings.pqt'.format(args.exp_dir, args.dataset, args.embedding_type)
try:
    df_embed.to_parquet(path=pqt_path, index=True, engine='pyarrow')
    saved_path = pqt_path
except Exception as err:
    print('Unable to save as pqt, saving instead as csv')
    print(f'Reason: {err}')
    csv_path = '{}/{}_{}_embeddings.csv'.format(args.exp_dir, args.dataset, args.embedding_type)
    df_embed.to_csv(csv_path, index=True)
    saved_path = csv_path

# upload whichever file was actually written
if args.cloud:
    upload(args.cloud_dir, saved_path, args.bucket)