In [1]:
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
import pandas as pd

In [2]:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wave2vec = Wav2Vec2ForSequenceClassification.\
    from_pretrained("facebook/wav2vec2-base-960h", num_labels=8)
    

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
import os
from torch.utils.data import Dataset, DataLoader, RandomSampler, TensorDataset
from typing import Tuple, Any, List
from tqdm import tqdm
import warnings



def preprocess_df(df: pd.DataFrame, filename: str, wavfile_base_path: str) -> pd.DataFrame:
    df.dropna(inplace=True)
    df['filename'] = filename
    df['label'] = df['labels_male'] if filename[5] == 'M' else df['label_feamle']
    df['filename'] = df['filename'].apply(lambda x: os.path.join(wavfile_base_path, filename + '.wav'))
    df.drop('Unnamed: 0', axis=1, inplace=True)    
    df.drop(columns=['labels_male', 'label_feamle'], inplace=True)
    return df

def load_file(path: str, wavfile_base_path: str) -> pd.DataFrame:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        df = pd.read_csv(path)
    filename = path.split('/')[-1].split('.')[0]
    if 'Ses01' in filename:
        session = 'Session1'
    elif 'Ses02' in filename:
        session = 'Session2'
    elif 'Ses03' in filename:
        session = 'Session3'
    elif 'Ses04' in filename:
        session = 'Session4'
    elif 'Ses05' in filename:
        session = 'Session5'
    else:
        raise ValueError("Invalid Session  Name")
    
    wavfile_base_path = wavfile_base_path.replace('SessionX', session)
        
    df = preprocess_df(df, filename, wavfile_base_path)
    return df

def load_files(files: List[str], wavfile_base_path: str):
    data = []
    for file in tqdm(files):
        data.append(load_file(file, wavfile_base_path))
    return pd.concat(data)


def load_data(base_path: str, wavfile_base_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    train_files = [os.path.join(base_path, file) for file in os.listdir(base_path) if 'Ses05' not in file]
    test_files = [os.path.join(base_path, file) for file in os.listdir(base_path) if 'Ses05' in file]
    train = load_files(train_files, wavfile_base_path)
    test = load_files(test_files, wavfile_base_path)
    return (train, test)

In [4]:
path = './iemocap_processed/merged/Ses01F_impro01.csv'
wav_file_path_template = '/media/cv/Extreme Pro/IEMOCAP/IEMOCAP_full_release/SessionX/dialog/wav'

df = load_file(path, wav_file_path_template)
df.head()

Unnamed: 0,Frame#,Time,X01,Y01,Z01,X02,Y02,Z02,X03,Y03,...,Y53,Z53,X60,Y60,Z60,X61,Y61,Z61,filename,label
908,909,7.58266,-27.5677,34.97618,-54.40529,-0.23704,25.24322,-61.08478,23.64563,36.2148,...,24.1399,-45.0436,55.95958,56.6439,121.62669,-45.31238,52.59035,129.44393,/media/cv/Extreme Pro/IEMOCAP/IEMOCAP_full_rel...,Neutral
909,910,7.591,-27.59504,34.97534,-54.53862,-0.14238,25.29796,-61.08746,23.63664,36.28221,...,24.20146,-45.19563,55.99112,56.57716,121.61237,-45.29489,52.42889,129.39733,/media/cv/Extreme Pro/IEMOCAP/IEMOCAP_full_rel...,Neutral
910,911,7.59934,-27.62559,34.94876,-54.6067,-0.15016,25.26522,-61.18179,23.5933,36.19091,...,24.20199,-45.38675,56.02899,56.47296,121.5813,-45.28643,52.32635,129.36415,/media/cv/Extreme Pro/IEMOCAP/IEMOCAP_full_rel...,Neutral
911,912,7.60768,-27.66087,34.82575,-54.76922,-0.11912,25.22126,-61.34704,23.66278,36.07449,...,24.22021,-45.6033,56.02719,56.37018,121.55001,-45.29266,52.30411,129.36309,/media/cv/Extreme Pro/IEMOCAP/IEMOCAP_full_rel...,Neutral
912,913,7.61602,-27.59027,34.67632,-54.79646,-0.14865,25.1628,-61.38764,23.70279,35.97943,...,24.12938,-45.68056,56.03284,56.31125,121.55555,-45.30947,52.3553,129.35845,/media/cv/Extreme Pro/IEMOCAP/IEMOCAP_full_rel...,Neutral


In [5]:
base_path = './iemocap_processed/merged/'
wav_file_path_template = '/media/cv/Extreme Pro/IEMOCAP/IEMOCAP_full_release/SessionX/dialog/wav'

# train, test = load_data(base_path, wav_file_path_template)

save_train = './iemocap_processed/pickled/train_with_audio_file.pkl'
save_test = './iemocap_processed/pickled/test_with_audio_file.pkl'
# train.to_pickle(save_train)
# test.to_pickle(save_test)

In [6]:

save_train = './iemocap_processed/pickled/train_with_audio_file.pkl'
save_test = './iemocap_processed/pickled/test_with_audio_file.pkl'

train_data = pd.read_pickle(save_train)
test_data = pd.read_pickle(save_test)

In [7]:
import librosa as lb
import torch
import numpy as np
import warnings

class AudioDataset(Dataset):
    def __init__(self, df, audio_processor, sr=16_000, duration: float=2):
        self.df = df
        self.df = self.df.loc[ self.df['label'] != 'Other']
        
        self.label2id = {
            'Frustration':0,
            'Anger':1,
            'Excited':2,
            'Neutral':3,
            'Happiness':4,
            'Sadness':5,
            'Fear':6,
            'Surprise':7,
        }

        self.id2label = {
            0: 'Frustration',
            1: 'Anger',
            2: 'Excited',
            3: 'Neutral',
            4: 'Happiness',
            5: 'Sadness',
            6: 'Fear',
            7: 'Surprise',
        }


        self.processor = audio_processor
        self.duration = duration
        self.sampling_rate = sr
        self.samples = self.duration * self.sampling_rate
    
    def __getitem__(self, index) -> Any:
        file_path = self.df.iloc[index]['filename']
        offset = max(0, self.df.iloc[index]['Time'] - self.duration)
    
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            audio_input ,sr = lb.load(file_path, sr=self.sampling_rate, mono=True, offset=offset, duration=self.duration)
        audio_input = self.pad_or_truncate(audio_input)
        input_values = self.processor(audio_input, sampling_rate=sr, return_tensors='pt').input_values
        labels = torch.tensor(self.label2id[self.df.iloc[index]['label']])
        
        return (input_values.squeeze(), labels)
    
    def __len__(self):
        return len(self.df)

    def pad_or_truncate(self, arr):
        if len(arr) > self.samples:
            arr = arr[:self.samples]
        else:
            arr = np.pad(arr, (0, self.samples - len(arr)), 'constant', constant_values=(0))
        return arr

def tensor_data_to_dataloader(dataset: TensorDataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size = batch_size,
        sampler=RandomSampler(dataset)
    )

def build_dataset(pickle_path: str, audio_processor, batch_size=32, sampling_frac=0.5):
    data = pd.read_pickle(pickle_path)
    data = data.groupby('label', group_keys=False).apply(lambda x: x.sample(frac=sampling_frac))
    data = tensor_data_to_dataloader(AudioDataset(data, audio_processor), batch_size=batch_size)
    return data


batch_size = 32
trainloader = build_dataset(save_train, processor, batch_size=batch_size)
testloader = build_dataset(save_test, processor, batch_size=batch_size)
# i, temp = next(enumerate(train_audio))

In [9]:
i, temp = next(enumerate(trainloader))

In [13]:
import os
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
import datetime, time, random
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

class TrainingArgs:
    def __init__(self, device='cuda', learning_rate=2e-5, epsilon=1e-8, epochs=4) -> None:
        self.device = device
        self.lr = learning_rate
        self.epsilon = epsilon
        self.epochs = epochs
        self.warmup_steps = 0
        self.seed = 1024

    
class TrainerWave2Vec:
    
    def __init__(self,
                 model: nn.Module,
                 trainloader: DataLoader,
                 testloader: DataLoader,
                 out_dir: str,
                 args: TrainingArgs) -> None:
        
        self.args = args
        self.num_epochs = self.args.epochs
        self.device = self.args.device
        self.out_dir = out_dir
        
        self.train_data = trainloader
        self.val_data = testloader

        self.best_acc = 0
        
        self.training_stats = []

        self.model = model.to(self.device)
        self.optimizer = AdamW(model.parameters(), lr=self.args.lr, eps=self.args.epsilon)
        self.schedular = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps= self.args.warmup_steps, num_training_steps=len(trainloader) * self.args.epochs            
        )
        
        self.fix_seeds(self.args.seed)
    
    def fix_seeds(self, seed_val):
        random.seed(seed_val)
        np.random.seed(seed_val)
        torch.manual_seed(seed_val)
        torch.cuda.manual_seed_all(seed_val)
        
    def validate(self):
        print("")
        print("Running Validation...")
        t0 = time.time()
        # Put the model in evaluation mode--the dropout layers behave differently
        # during evaluation.
        self.model.eval()
        # Tracking variables 
        total_eval_loss = 0
        
        preds = []
        labels = []
        
        for _, batch in enumerate(tqdm(self.val_data)):
            b_input_ids = batch[0].to(self.device)
            b_labels = batch[1].to(self.device)

            with torch.no_grad():        
                output= self.model(b_input_ids, 
                                    labels=b_labels)
            loss = output.loss
            total_eval_loss += loss.item()
            # Move logits and labels to CPU if we are using GPU
            logits = output.logits
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.cpu().numpy()
            
            preds.extend(np.argmax(logits, axis=1))
            labels.extend(label_ids)

        # Report the final accuracy for this validation run.
        preds, labels = np.array(preds), np.array(labels)
        avg_val_accuracy = f1_score(preds.flatten(), labels.flatten(), average='weighted')# total_eval_accuracy / len(dataloader)
        print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(self.val_data)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)
        if avg_val_accuracy > self.best_acc:
            torch.save(self.model.state_dict(), os.path.join(self.out_dir, 'wave2vec.pt'))
            self.best_acc = avg_val_accuracy

        return (avg_val_accuracy, avg_val_loss, validation_time)
    
    def train_epoch(self, epoch_i):
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, self.num_epochs))
        print('Training...')
        t0 = time.time()
        total_train_loss = 0
        self.model.train()
        for step, batch in enumerate(tqdm(self.train_data)):
        
            b_input_ids = batch[0].to(self.device)
            b_labels = batch[1].to(self.device)
            self.optimizer.zero_grad()
            output = self.model(b_input_ids, 
                                labels=b_labels)        
            loss = output.loss
            total_train_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
    
            self.optimizer.step()
            self.schedular.step()
            

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(self.train_data)            
        
        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)
        print("")
        print("  Average training loss: {0:.2f}".format(avg_train_loss))
        print("  Training epcoh took: {:}".format(training_time))    
        return (avg_train_loss, training_time)
    
    def train(self):
        for epoch in range(self.num_epochs):
            train_loss, train_time = self.train_epoch(epoch)
            val_acc, val_loss, val_time = self.validate()
            self.training_stats.append(
                {
                    'epoch': epoch + 1,
                    'Training Loss': train_loss,
                    'Valid. Loss': val_loss,
                    'Valid. Accur.': val_acc,
                    'Training Time': train_time,
                    'Validation Time': val_time
                }
            )
            


In [14]:


trainer = TrainerWave2Vec(
    model=wave2vec,
    trainloader=trainloader,
    testloader=testloader,
    out_dir='./saved_weights_iemocap_audio',
    args=TrainingArgs()

)

In [None]:
trainer.train()

In [21]:
class AudioLandmarkDataset(Dataset):
    def __init__(self, df, audio_processor, sr=16_000, duration: float=2):
        self.df = df
        self.df = self.df.loc[ self.df['label'] != 'Other']
        
        self.label2id = {
            'Frustration':0,            'Anger':1,            'Excited':2,            'Neutral':3,
            'Happiness':4,            'Sadness':5,            'Fear':6,            'Surprise':7,
        }

        self.id2label = {
            0: 'Frustration',            1: 'Anger',            2: 'Excited',            3: 'Neutral',
            4: 'Happiness',            5: 'Sadness',            6: 'Fear',            7: 'Surprise',
        }
        
        self.columns_to_extract = [
            'X15', 'Y15', 'Z15', 'X16', 'Y16', 'Z16', 'X18', 'Y18', 'Z18', 'X07',
            'Y07', 'Z07', 'X08', 'Y08', 'Z08', 'X10', 'Y10', 'Z10', 'X01', 'Y01',
            'Z01', 'X02', 'Y02', 'Z02', 'X03', 'Y03', 'Z03', 'X46', 'Y46', 'Z46',
            'X47', 'Y47', 'Z47', 'X48', 'Y48', 'Z48', 'X49', 'Y49', 'Z49', 'X50',
            'Y50', 'Z50', 'X51', 'Y51', 'Z51', 'X52', 'Y52', 'Z52', 'X53', 'Y53',
            'Z53', 'X29', 'Y29', 'Z29', 'X27', 'Y27', 'Z27', 'X28', 'Y28', 'Z28'
            ]
        

        self.processor = audio_processor
        self.duration = duration
        self.sampling_rate = sr
        self.samples = self.duration * self.sampling_rate
    
    def __getitem__(self, index) -> Any:
        file_path = self.df.iloc[index]['filename']
        offset = max(0, self.df.iloc[index]['Time'] - self.duration)
    
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            audio_input ,sr = lb.load(file_path, sr=self.sampling_rate, mono=True, offset=offset, duration=self.duration)
        audio_input = self.pad_or_truncate(audio_input)
        input_values = self.processor(audio_input, sampling_rate=sr, return_tensors='pt').input_values
        labels = torch.tensor(self.label2id[self.df.iloc[index]['label']])
        pose = torch.tensor(self.df.iloc[index][self.columns_to_extract].astype(float).to_numpy(), dtype=torch.float32)
        
        return (input_values.squeeze(), pose, labels)
    
    def __len__(self):
        return len(self.df)

    def pad_or_truncate(self, arr):
        if len(arr) > self.samples:
            arr = arr[:self.samples]
        else:
            arr = np.pad(arr, (0, self.samples - len(arr)), 'constant', constant_values=(0))
        return arr

def tensor_data_to_dataloader(dataset: TensorDataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size = batch_size,
        sampler=RandomSampler(dataset)
    )

def build_dataset(pickle_path: str, audio_processor, batch_size=32, sampling_frac=0.5):
    data = pd.read_pickle(pickle_path)
    data = data.groupby('label', group_keys=False).apply(lambda x: x.sample(frac=sampling_frac))
    data = tensor_data_to_dataloader(AudioLandmarkDataset(data, audio_processor), batch_size=batch_size)
    return data



  data = data.groupby('label', group_keys=False).apply(lambda x: x.sample(frac=sampling_frac))
  data = data.groupby('label', group_keys=False).apply(lambda x: x.sample(frac=sampling_frac))


In [25]:
class MLPLandmark(nn.Module):
    def __init__(self, inp_dim: int, out_dim: int, layers: list):
        super().__init__()
        
        self.inp_dim = inp_dim
        
        self.mlp = nn.ModuleList()
        
        self.mlp.append(nn.Linear(self.inp_dim, layers[0]))
        
        for i in range(1, len(layers)):
            self.mlp.append(
                nn.Sequential(
                    nn.Linear(layers[i-1], layers[i]),
                    nn.ReLU(),
                    nn.LayerNorm(layers[i])
                )
            )
        
        self.mlp.append(nn.Linear(layers[-1], out_dim))
        
    def forward(self, x: torch.Tensor):
        x = x.view(x.shape[0], self.inp_dim)
        
        for layer in self.mlp:
            x = layer(x)
        
        return x
    
class LandmarkAudioModel(nn.Module):
    def __init__(self,
                 inp_dim_landmark : int,
                 feat_dim: int = 256,
                 layers: List[int] = [128, 256, 512],
                 audio_model_path: str = "facebook/wav2vec2-base-960h",
                 num_labels: int = 8,
                 criterion=nn.CrossEntropyLoss()):
        super().__init__()
        self.audio_model_path = audio_model_path
        self.num_labels = num_labels

        self.landmark_model = MLPLandmark(inp_dim_landmark , feat_dim, layers)
        
        self.audio_model = wave2vec = Wav2Vec2ForSequenceClassification.\
                                        from_pretrained(audio_model_path, num_labels=feat_dim)
                                        
        self.fc = nn.Linear(2*feat_dim, num_labels)

        self.criterion = criterion

    def forward(self, audio_input, landmarks, labels=None):
        audio = self.audio_model(audio_input).logits
        text = self.landmark_model(landmarks)

        out = self.fc(torch.cat((audio, text), dim=1))
        print(out.shape)
        if labels is None:
            loss = None
        else:
            loss = self.criterion(out, labels)

        return {'logits': out, 'loss': loss}

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [42]:
class TrainingArgs:
    def __init__(self, device='cuda', learning_rate=2e-5, epsilon=1e-8, epochs=4) -> None:
        self.device = device
        self.lr = learning_rate
        self.epsilon = epsilon
        self.epochs = epochs
        self.warmup_steps = 0
        self.seed = 1024

    
class TrainerWave2Vec:
    
    def __init__(self,
                 model: nn.Module,
                 trainloader: DataLoader,
                 testloader: DataLoader,
                 out_dir: str,
                 args: TrainingArgs) -> None:
        
        self.args = args
        self.num_epochs = self.args.epochs
        self.device = self.args.device
        self.out_dir = out_dir
        
        self.train_data = trainloader
        self.val_data = testloader

        self.best_acc = 0
        
        self.training_stats = []

        self.model = model.to(self.device)
        self.optimizer = AdamW(model.parameters(), lr=self.args.lr, eps=self.args.epsilon)
        self.schedular = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps= self.args.warmup_steps, num_training_steps=len(trainloader) * self.args.epochs            
        )
        
        self.fix_seeds(self.args.seed)
    
    def fix_seeds(self, seed_val):
        random.seed(seed_val)
        np.random.seed(seed_val)
        torch.manual_seed(seed_val)
        torch.cuda.manual_seed_all(seed_val)
        
    def validate(self):
        print("")
        print("Running Validation...")
        t0 = time.time()
        # Put the model in evaluation mode--the dropout layers behave differently
        # during evaluation.
        self.model.eval()
        # Tracking variables 
        total_eval_loss = 0
        
        preds = []
        labels = []
        
        for _, batch in enumerate(tqdm(self.val_data)):
            b_audio = batch[0].to(self.device)
            b_landmarks = batch[1].to(self.device)
            b_labels = batch[2].to(self.device)

            with torch.no_grad():        
                output= self.model(b_audio, b_landmarks, 
                                    labels=b_labels)
            loss = output['loss']
            total_eval_loss += loss.item()
            # Move logits and labels to CPU if we are using GPU
            logits = output['logits']
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.cpu().numpy()
            
            preds.extend(np.argmax(logits, axis=1))
            labels.extend(label_ids)

        # Report the final accuracy for this validation run.
        preds, labels = np.array(preds), np.array(labels)
        avg_val_accuracy = f1_score(preds.flatten(), labels.flatten(), average='weighted')# total_eval_accuracy / len(dataloader)
        print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(self.val_data)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)
        if avg_val_accuracy > self.best_acc:
            torch.save(self.model.state_dict(), os.path.join(self.out_dir, 'wave2vec.pt'))
            self.best_acc = avg_val_accuracy

        return (avg_val_accuracy, avg_val_loss, validation_time)
    
    def train_epoch(self, epoch_i):
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, self.num_epochs))
        print('Training...')
        t0 = time.time()
        total_train_loss = 0
        self.model.train()
        for step, batch in enumerate(tqdm(self.train_data)):
        
            b_audio = batch[0].to(self.device)
            b_landmarks = batch[1].to(self.device)
            
            b_labels = batch[1].to(self.device)
            self.optimizer.zero_grad()
            
            print(b_audio.shape, b_landmarks.shape, b_labels.shape)
            output = self.model(b_audio, b_landmarks, 
                                labels=b_labels)  
            print(output.shape)      
            loss = output['loss']
            total_train_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
    
            self.optimizer.step()
            self.schedular.step()
            

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(self.train_data)            
        
        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)
        print("")
        print("  Average training loss: {0:.2f}".format(avg_train_loss))
        print("  Training epcoh took: {:}".format(training_time))    
        return (avg_train_loss, training_time)
    
    def train(self):
        for epoch in range(self.num_epochs):
            train_loss, train_time = self.train_epoch(epoch)
            val_acc, val_loss, val_time = self.validate()
            self.training_stats.append(
                {
                    'epoch': epoch + 1,
                    'Training Loss': train_loss,
                    'Valid. Loss': val_loss,
                    'Valid. Accur.': val_acc,
                    'Training Time': train_time,
                    'Validation Time': val_time
                }
            )
            

In [39]:

batch_size = 32
trainloader = build_dataset(save_train, processor, batch_size=batch_size)
testloader = build_dataset(save_test, processor, batch_size=batch_size)
model = LandmarkAudioModel(inp_dim_landmark=60)

  data = data.groupby('label', group_keys=False).apply(lambda x: x.sample(frac=sampling_frac))
  data = data.groupby('label', group_keys=False).apply(lambda x: x.sample(frac=sampling_frac))
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [43]:
trainer = TrainerWave2Vec(
    model=wave2vec,
    trainloader=trainloader,
    testloader=testloader,
    out_dir='./saved_weights_iemocap_audio_text',
    args=TrainingArgs()

)

In [44]:
trainer.train()


Training...


  0%|          | 0/7651 [00:00<?, ?it/s]

torch.Size([32, 32000]) torch.Size([32, 60]) torch.Size([32, 60])


  0%|          | 0/7651 [00:00<?, ?it/s]


ValueError: Expected input batch_size (32) to match target batch_size (1920).