In [1]:
import torch
import torchaudio
import numpy as np
import torch.nn as nn

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import os

class VoiceDataset(Dataset):
    """Dataset that serves mel-spectrograms of the audio files in a directory.

    Every item is a ``(mel_spectrogram, label)`` pair. The label is the
    constant ``1`` (the directory is meant to hold recordings of the target
    voice only).

    Args:
        data_dir: Directory that holds the audio files.
        sample_rate: Sample rate that every waveform is resampled to.
        num_mel_bins: Number of mel filterbanks (``n_mels``) of the spectrogram.
        extensions: Optional file extension, or iterable of extensions
            (e.g. ``(".wav", ".flac")``), to keep; compared
            case-insensitively. ``None`` keeps every regular, non-hidden file.

    Note:
        The transforms are built once at construction time, so changing
        ``sample_rate`` / ``num_mel_bins`` on an existing instance has no
        effect on the spectrograms it produces.
    """

    def __init__(self, data_dir, sample_rate=16000, num_mel_bins=40, extensions=None):
        self.data_dir = data_dir
        self.sample_rate = sample_rate
        self.num_mel_bins = num_mel_bins
        self.file_paths = self._list_audio_files(data_dir, extensions)

        # Build the mel transform once instead of on every __getitem__ call.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=self.num_mel_bins
        )
        # One Resample module per source sample rate, created on first use.
        self._resamplers = {}

    @staticmethod
    def _list_audio_files(data_dir, extensions=None):
        """Return the sorted paths of the regular, non-hidden files in ``data_dir``.

        Sorting makes the index -> file mapping reproducible (``os.listdir``
        returns entries in arbitrary order). Sub-directories and hidden
        entries (names starting with ".") are skipped because they cannot be
        loaded as audio.
        """
        if isinstance(extensions, str):
            extensions = (extensions,)
        if extensions is not None:
            extensions = tuple(ext.lower() for ext in extensions)

        paths = []
        for name in sorted(os.listdir(data_dir)):
            if name.startswith("."):
                continue
            path = os.path.join(data_dir, name)
            if not os.path.isfile(path):
                continue
            if extensions is not None and not name.lower().endswith(extensions):
                continue
            paths.append(path)
        return paths

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(mel_spectrogram, label)``."""
        # Load the audio file.
        waveform, original_sample_rate = torchaudio.load(self.file_paths[idx])

        # Resample to the target rate when the file uses a different one.
        if original_sample_rate != self.sample_rate:
            resampler = self._resamplers.get(original_sample_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=original_sample_rate,
                    new_freq=self.sample_rate
                )
                self._resamplers[original_sample_rate] = resampler
            waveform = resampler(waveform)

        # Convert the waveform to a mel-spectrogram.
        mel_spectrogram = self.mel_transform(waveform)

        # Class label: always 1, i.e. "the target voice".
        label = 1

        return mel_spectrogram, label

# Build the dataset and the data loader.
data_dir = "data/my_voice"
voice_dataset = VoiceDataset(data_dir)


def pad_collate(batch):
    """Collate ``(mel_spectrogram, label)`` pairs whose spectrograms differ in length.

    Recordings of different duration yield spectrograms with a different
    number of time frames, which the default collate function cannot stack.
    Each spectrogram is zero-padded at the end of its last dimension (the
    time axis of a MelSpectrogram output) up to the longest one in the batch.
    When all lengths already match, the result equals the default collate.

    Returns:
        ``[spectrograms, labels]``: the stacked spectrogram tensor and a
        tensor of the labels (a list, like the default collate produces).
    """
    spectrograms, labels = zip(*batch)
    max_frames = max(spec.shape[-1] for spec in spectrograms)
    padded = [
        torch.nn.functional.pad(spec, (0, max_frames - spec.shape[-1]))
        for spec in spectrograms
    ]
    return [torch.stack(padded), torch.tensor(labels)]


data_loader = DataLoader(voice_dataset, batch_size=4, shuffle=True, collate_fn=pad_collate)


In [None]:
# Target sample rate; preprocess_audio resamples any file with a different rate to this.
SAMPLE_RATE = 16_000
# Number of mel filterbanks (n_mels) used for the mel-spectrogram.
NUM_MEL_BINS = 40

class VGGVoxModel(nn.Module):
    """Placeholder model class: no layers and no forward pass are defined yet.

    Until the architecture is filled in, the module has no parameters and
    calling it raises ``NotImplementedError`` instead of silently returning
    ``None`` (which would only fail later, inside ``torch.softmax``).
    """

    def __init__(self):
        super().__init__()
        # TODO: define the network layers here.

    def forward(self, x):
        """Run the forward pass on ``x`` (not implemented yet).

        Raises:
            NotImplementedError: Always, until the architecture is defined.
        """
        raise NotImplementedError(
            "VGGVoxModel.forward is not implemented yet: define the network "
            "layers in __init__ and the forward pass here."
        )


def preprocess_audio(audio_path):
    """Load an audio file and return its mel-spectrogram with a leading batch axis.

    The waveform is resampled to ``SAMPLE_RATE`` when the file uses another
    rate, converted to a mel-spectrogram with ``NUM_MEL_BINS`` mel bins, and
    given one extra dimension at position 0.

    Args:
        audio_path: Path of the audio file to load.

    Returns:
        The mel-spectrogram tensor with a new first dimension of size 1.
    """
    waveform, source_rate = torchaudio.load(audio_path)

    needs_resampling = source_rate != SAMPLE_RATE
    if needs_resampling:
        resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=SAMPLE_RATE)
        waveform = resampler(waveform)

    to_mel = torchaudio.transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=NUM_MEL_BINS)
    features = to_mel(waveform)

    # Prepend the batch axis so a single file can be fed to the model directly.
    return features.unsqueeze(0)


def infer(model, audio_path):
    """Classify one audio file with ``model``.

    Switches the model to evaluation mode, preprocesses the file with
    ``preprocess_audio`` and runs the model with gradient tracking disabled.

    Args:
        model: Module that is called on the preprocessed audio; softmax is
            applied over dimension 1 of its output.
        audio_path: Path of the audio file to classify.

    Returns:
        ``(predicted_class, probabilities)``: the index of the most probable
        class as a Python int, and the softmax probabilities tensor.
    """
    model.eval()

    with torch.no_grad():
        features = preprocess_audio(audio_path)
        logits = model(features)
        class_probs = torch.softmax(logits, dim=1)
        best_class = class_probs.argmax(dim=1).item()

    return best_class, class_probs


# Instantiate the model and restore its saved weights.
model = VGGVoxModel()
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the freshly created model lives on the CPU anyway.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True where the installed PyTorch
# supports it).
state_dict = torch.load("vggvox_model.pth", map_location="cpu")
model.load_state_dict(state_dict)

# Run inference on a single audio file and report the result.
audio_path = "path_to_audio_file.wav"
predicted_class, probabilities = infer(model, audio_path)

print("Predicted Class:", predicted_class)
print("Probabilities:", probabilities)
