In [1]:
import numpy as np
import torch.optim as optim
import numpy as np
import pickle
from torch.utils.data import DataLoader, Dataset
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift


from typing import Optional
import os
import math
import gc
import tarfile
import numpy as np
import pandas as pd
import scipy
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from io import BytesIO
from urllib.request import urlopen
from scipy.io import wavfile
from scipy.stats import pearsonr
from scipy.signal import resample
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold, ParameterGrid
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    Wav2Vec2ForCTC, 
    Wav2Vec2Processor, 
    Wav2Vec2Model, 
    Wav2Vec2PreTrainedModel, 
    Wav2Vec2Config, 
    WavLMModel, 
    WavLMConfig, 
    HubertModel, 
    HubertConfig, 
    HubertPreTrainedModel
)
## local inports 
from models import *
from utils import *
from losses import *
from dataset import *
import json



class Trainer:
    def __init__(self, config, model_classes, criterion, device, log_dir, model_save_dir, bert_config):
        self.bert_config = bert_config
        self.config = config
        self.model_classes = model_classes
        self.criterion = criterion
        self.device = device
        self.writer = SummaryWriter(log_dir)
        self.model_save_dir = model_save_dir
        os.makedirs(model_save_dir, exist_ok=True)
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")

    def train(self, train_data, val_data, test_data, epochs, batch_size, patience):
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=1, shuffle=False)
        test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

        results = {}

        for model_name, model_class in self.model_classes.items():
            print(f"Training {model_name}...")
            model_config = self.config[model_name]
            model_config['output_size'] = train_data.get_output_shape()
            model = model_class(bert_config = self.bert_config,config = model_config).to(self.device)
            
            optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-5)
            early_stopping = EarlyStopping(patience=patience, mode='min')

            best_val_loss = float('inf')

            for epoch in range(epochs):
                train_loss, train_acc = self._train_epoch(model, train_loader, optimizer, scheduler, epoch, epochs)
                val_loss, val_acc, val_acc_flat = self._evaluate(model, val_loader, val_data)
                test_loss, test_acc, test_acc_flat = self._evaluate(model, test_loader, test_data)

                self._log_metrics(model_name, epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc, val_acc_flat, test_acc_flat)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_model(model, f"{model_name}_best_model_epoch_{epoch}.pt")

                if early_stopping(val_loss):
                    print(f"Early stopping triggered for {model_name} at epoch {epoch+1}")
                    break

            results[model_name] = {
                'val_loss': best_val_loss,
                'test_loss': test_loss,
                'test_acc': test_acc,
                'test_acc_flat': test_acc_flat
            }

        self.writer.close()
        return results

    def _train_epoch(self, model, dataloader, optimizer, scheduler, epoch, total_epochs):
        model.train()
        total_loss = 0.0
        total_acc = 0.0
        
        progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}/{total_epochs}")
        
        for batch_idx, (input_values, labels, _) in enumerate(progress_bar):
            optimizer.zero_grad()

            input_values = self.processor(input_values, return_tensors="pt", padding="longest", sampling_rate = 16000).input_values
            input_values = input_values.reshape(input_values.shape[0], input_values.shape[-1])
            print(input_values.shape)
            input_values, labels = input_values.to(self.device), labels.to(self.device)
            predictions = model(input_values)
            loss = self.criterion(predictions.float(), labels.float())
            
            loss.backward()
            optimizer.step()
            scheduler.step(epoch + batch_idx / len(dataloader))
            
            total_loss += loss.item()
            total_acc += 1.0 - loss.item()  # Assuming accuracy is 1 - loss for this task
            
            progress_bar.set_description(f"Training Epoch {epoch+1}/{total_epochs}, Avg Loss: {total_loss/(batch_idx+1):.4f}, Acc: {total_acc/(batch_idx+1):.4f}")
            
            del input_values, labels, predictions, loss
            torch.cuda.empty_cache()
        
        return total_loss / len(dataloader), total_acc / len(dataloader)

    def _evaluate(self, model, dataloader, dataset):
        model.eval()
        total_loss = 0.0
        total_acc = 0.0
        total_acc_flat = 0.0
        
        with torch.no_grad():
            for input_values, labels, ground_truth_names in dataloader:
                input_values = input_values.to(self.device)
                
                ground_truth_labels = self._get_ground_truth_labels(ground_truth_names, dataset)
                
                predictions = self._process_sequences(model, input_values)
                loss = self.criterion(predictions.to("cpu"), labels)
                
                total_loss += loss.item()
                total_acc += 1.0 - loss.item() 
                
                average = self._unsplit_data_ogsize(predictions.cpu().numpy(), dataset.window_size, dataset.step_size, dataset.data_points_per_second, ground_truth_labels.shape[-1])
                total_acc_flat += self._calculate_flattened_accuracy(average, ground_truth_labels)
                
                del input_values, labels, predictions, loss
                torch.cuda.empty_cache()
        
        num_samples = len(dataloader.dataset)
        return total_loss / num_samples, total_acc / num_samples, total_acc_flat / num_samples

    def _get_ground_truth_labels(self, ground_truth_names, dataset):
        ground_truth_labels = []
        for batch_name in ground_truth_names:
            ground_truth_label = dataset.choose_real_labs_only_with_filenames([batch_name])
            ground_truth_labels.append(ground_truth_label)
        return np.array(ground_truth_labels)[:, :, -1].astype(np.float32)

    def _process_sequences(self, model, input_values):
        predictions = []
        for i in range(input_values.size(1)):
            input_slice = input_values[:, i, :]
            pred = model(input_slice.float())
            predictions.append(pred)
        return torch.stack(predictions, dim=1)

    def _calculate_flattened_accuracy(self, average, ground_truth_labels):
        s_acc = 0
        for b in range(len(ground_truth_labels)):
            s, _ = scipy.stats.pearsonr(average[b], ground_truth_labels[b])
            s_acc += s
        return s_acc / len(ground_truth_labels)

    def _unsplit_data_ogsize(windowed_data, window_size, step_size, data_points_per_second, original_length):
        batch_size, num_windows, prediction_size = windowed_data.shape
        window_size_points = window_size * data_points_per_second
        step_size_points = step_size * data_points_per_second
        original_data = np.zeros((batch_size, original_length))
        overlap_count = np.zeros((batch_size, original_length))
        
        for b in range(batch_size):
            for i in range(num_windows):
                start = i * step_size_points
                end = start + window_size_points
                if end > original_length:
                    end = original_length
                segment_length = end - start
                original_data[b, start:end] += windowed_data[b, i, :segment_length]
                overlap_count[b, start:end] += 1
        
        # Average the overlapping regions
        original_data = np.divide(original_data, overlap_count, where=overlap_count != 0)
        
        # Trim the data to match the original length
        original_data = original_data[:, :original_length]
        
        return original_data
    
    def _log_metrics(self, model_name, epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc, val_acc_flat, test_acc_flat):
        self.writer.add_scalar(f'{model_name}/Loss/train', train_loss, epoch)
        self.writer.add_scalar(f'{model_name}/Loss/val', val_loss, epoch)
        self.writer.add_scalar(f'{model_name}/Loss/test', test_loss, epoch)
        self.writer.add_scalar(f'{model_name}/Accuracy/train', train_acc, epoch)
        self.writer.add_scalar(f'{model_name}/Accuracy/val', val_acc, epoch)
        self.writer.add_scalar(f'{model_name}/Accuracy/test', test_acc, epoch)
        self.writer.add_scalar(f'{model_name}/Accuracy/val_flat', val_acc_flat, epoch)
        self.writer.add_scalar(f'{model_name}/Accuracy/test_flat', test_acc_flat, epoch)

    def _save_model(self, model, filename):
        path = os.path.join(self.model_save_dir, filename)
        torch.save(model.state_dict(), path)
        print(f"Saved model to {path}")


class EarlyStopping:
    def __init__(self, patience=7, mode='min', delta=0):
        self.patience = patience
        self.counter = 0
        self.mode = mode
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def __call__(self, score):
        if self.mode == 'min':
            score = -score
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
        return self.early_stop

def prepare_data_model(audio_interspeech_norm, breath_interspeech_folder, window_size, step_size, fold):
    # Load and prepare data
    train_data, train_labels, train_dict, frame_rate = load_data(audio_interspeech_norm, breath_interspeech_folder, 'train')
    devel_data, devel_labels, devel_dict, _ = load_data(audio_interspeech_norm, breath_interspeech_folder, 'devel')
    test_data, test_labels, test_dict, _ = load_data(audio_interspeech_norm, breath_interspeech_folder, 'test')
    
    # Prepare data
    prepared_train_data, prepared_train_labels, _ = prepare_data(train_data, train_labels, train_dict, frame_rate, window_size * 16000, step_size * 16000)
    prepared_devel_data, prepared_devel_labels, _ = prepare_data(devel_data, devel_labels, devel_dict, frame_rate, window_size * 16000, step_size * 16000)
    prepared_test_data, prepared_test_labels, _= prepare_data(test_data, test_labels, test_dict, frame_rate, window_size * 16000, step_size * 16000)

    # Create custom datasets
    train_dataset = CustomDataset(prepared_train_data, prepared_train_labels, train_dict)
    val_dataset = CustomDataset(prepared_devel_data, prepared_devel_labels, devel_dict)
    test_dataset = CustomDataset(prepared_test_data, prepared_test_labels, test_dict)
    train_dataset.print_shapes()
    val_dataset.print_shapes()
    number_of_test_samples = int((len(val_dataset) + len(train_dataset))/ fold)
    print(number_of_test_samples)
    new_val_dataset_item = val_dataset.pop_first_n(number_of_test_samples)
    new_val_dataset = CustomDataset(new_val_dataset_item[0], new_val_dataset_item[1], new_val_dataset_item[2])
    new_val_dataset.print_shapes()

    combined_train_data = np.concatenate((train_dataset.data, val_dataset.data), axis=0)
    combined_train_labels = np.concatenate((train_dataset.labels, val_dataset.labels), axis=0)
    combined_train_dict = np.concatenate((train_dataset.name, val_dataset.name), axis=0)
    combined_train_data, combined_train_labels = flatten_data_for_model(combined_train_data, combined_train_labels)
    combined_train_dataset = CustomDataset(combined_train_data, combined_train_labels, [])

    return combined_train_dataset, new_val_dataset, test_dataset

if __name__ == "__main__":
    bert_config = HubertConfig.from_pretrained("facebook/hubert-base-ls960")

    
    config = {
        "VRBModel": {
            "model_name": "facebook/hubert-large-ls960",
            "hidden_units": 64,
            "n_gru": 3,
            "output_size": None  
        },
        "Wav2Vec2ConvLSTMModel": {
            "model_name": "facebook/wav2vec2-base",
            "hidden_units": 128,
            "n_lstm": 2,
            "output_size": None 
        },
        "RespBertLSTMModel": {
            "bert_config": "wav2vec2",
            "hidden_units": 128,
            "n_lstm": 2,
            "output": None  
        },
        "RespBertAttionModel": {
            "bert_config": "hubert",
            "hidden_units": 128,
            "n_attion": 1,
            "output": None  
        }
    }
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = PearsonLoss()
    
    model_classes = {
        "Wav2Vec2ConvLSTMModel": Wav2Vec2ConvLSTMModel,
        "VRBModel": VRBModel,
        "RespBertLSTMModel": RespBertLSTMModel,
        "RespBertAttionModel": RespBertAttionModel,
    }
    
    # Prepare data
    train_data, val_data, test_data = prepare_data_model(
        "/home/glenn/Downloads/ComParE2020_Breathing/wav/",
        "/home/glenn/Downloads/ComParE2020_Breathing/lab/",
        window_size=32,
        step_size=4,
        fold=5
    )
    
    trainer = Trainer(config, model_classes, criterion, device, "logs", "models", bert_config)

    # Train all models
    results = trainer.train(train_data, val_data, test_data, epochs=200, batch_size=1, patience=50)

    # Print results
    for model_name, model_results in results.items():
        print(f"\nResults for {model_name}:")
        for metric, value in model_results.items():
            print(f"{metric}: {value}")

2024-09-19 17:43:13.847513: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-19 17:43:13.864350: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-19 17:43:13.869410: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-19 17:43:13.881348: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


data : (17, 53, 512000), labels: (17, 53, 800), names : (17,)
data : (16, 53, 512000), labels: (16, 53, 800), names : (16,)
6
data : (6, 53, 512000), labels: (6, 53, 800), names : (6,)




Training Wav2Vec2ConvLSTMModel...


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training Epoch 1/200:   0%|          | 0/1431 [00:00<?, ?it/s]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training Epoch 1/200, Avg Loss: 0.9964, Acc: 0.0036:   0%|          | 1/1431 [00:03<1:27:06,  3.66s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9922, Acc: 0.0078:   0%|          | 2/1431 [00:06<1:22:39,  3.47s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0050, Acc: -0.0050:   0%|          | 3/1431 [00:10<1:21:27,  3.42s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9886, Acc: 0.0114:   0%|          | 4/1431 [00:13<1:20:43,  3.39s/it] It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9945, Acc: 0.0055:   0%|          | 5/1431 [00:17<1:20:18,  3.38s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9979, Acc: 0.0021:   0%|          | 6/1431 [00:20<1:20:01,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9933, Acc: 0.0067:   0%|          | 7/1431 [00:23<1:19:51,  3.36s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9904, Acc: 0.0096:   1%|          | 8/1431 [00:27<1:19:49,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9969, Acc: 0.0031:   1%|          | 9/1431 [00:30<1:19:53,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9981, Acc: 0.0019:   1%|          | 10/1431 [00:33<1:19:52,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 0.9993, Acc: 0.0007:   1%|          | 11/1431 [00:37<1:19:47,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0056, Acc: -0.0056:   1%|          | 12/1431 [00:40<1:19:40,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0043, Acc: -0.0043:   1%|          | 13/1431 [00:43<1:19:28,  3.36s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0033, Acc: -0.0033:   1%|          | 14/1431 [00:47<1:19:21,  3.36s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0056, Acc: -0.0056:   1%|          | 15/1431 [00:50<1:19:14,  3.36s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0046, Acc: -0.0046:   1%|          | 16/1431 [00:54<1:19:11,  3.36s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0071, Acc: -0.0071:   1%|          | 17/1431 [00:57<1:19:03,  3.35s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0091, Acc: -0.0091:   1%|▏         | 18/1431 [01:00<1:18:57,  3.35s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0112, Acc: -0.0112:   1%|▏         | 19/1431 [01:04<1:18:56,  3.35s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0116, Acc: -0.0116:   1%|▏         | 20/1431 [01:07<1:18:57,  3.36s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0097, Acc: -0.0097:   1%|▏         | 21/1431 [01:10<1:19:05,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


Training Epoch 1/200, Avg Loss: 1.0102, Acc: -0.0102:   2%|▏         | 22/1431 [01:14<1:19:08,  3.37s/it]It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


torch.Size([1, 512000])
Shape after wav2vec2: torch.Size([1, 1599, 768])
Shape after permute: torch.Size([1, 768, 1599])
Shape after conv: torch.Size([1, 768, 1599])
Shape after relu: torch.Size([1, 768, 1599])
Shape after second permute: torch.Size([1, 1599, 768])
Shape after lstm: torch.Size([1, 1599, 128])
Shape after selecting last time step: torch.Size([1, 128])
Shape after embedding: torch.Size([1, 128])
Shape after output: torch.Size([1, 800])
Shape after tanh: torch.Size([1, 800])
Shape after flatten: torch.Size([1, 800])


In [None]:
import torch
torch.cuda.is_available()
print(torch.cuda.get_device_name())

NVIDIA GeForce GTX 1050 Ti with Max-Q Design
