In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torchaudio librosa soundfile matplotlib tqdm scikit-learn pesq encodec

Collecting pesq
  Downloading pesq-0.0.4.tar.gz (38 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting encodec
  Downloading encodec-0.1.1.tar.gz (3.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.7/3.7 MB[0m [31m78.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-man

In [36]:
import os
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
from encodec.model import EncodecModel
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from pesq import pesq
from sklearn.metrics import mean_squared_error, mean_absolute_error
import librosa

In [37]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ Используем: {device}")

✅ Используем: cuda


In [38]:
class AudioFragmentsDataset(Dataset):
    def __init__(self, directory, sample_rate=24000):
        self.paths = []
        for root, _, files in os.walk(directory):
            for f in files:
                if f.endswith('.wav'):
                    self.paths.append(os.path.join(root, f))
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        waveform, sr = torchaudio.load(path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        return waveform.squeeze(0)


In [39]:
model_path = '/content/drive/MyDrive/models/encodec_24khz.th'

model = EncodecModel.encodec_model_24khz()
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

  WeightNorm.apply(module, name, dim)


In [40]:
train_dataset = AudioFragmentsDataset('/content/drive/MyDrive/dataset/train')
test_dataset = AudioFragmentsDataset('/content/drive/MyDrive/dataset/test')

CHECKPOINT_DIR = "/content/drive/MyDrive/checkpoints"
LOG_DIR = "/content/drive/MyDrive/logs"

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [41]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 100

train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

metrics_log = {
    'epoch': [],
    'train_loss': [],
    'test_loss': [],
    'mae': [],
    'mse': [],
    'pesq': [],
    'train_accuracy': [],
    'test_accuracy': []
}

early_stopping_patience = 5
best_loss = float('inf')
no_improve_counter = 0

In [42]:
def evaluate_metrics(model, dataloader, n_batches=5):
    model.eval()
    maes, mses, pesqs = [], [], []
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= n_batches:
                break
            batch = batch.to(device).unsqueeze(1)
            encoded = model.encode(batch)
            decoded = model.decode(encoded)

            for orig, recon in zip(batch, decoded):
                orig_np = orig.squeeze().cpu().numpy()
                recon_np = recon.squeeze().cpu().numpy()

                maes.append(mean_absolute_error(orig_np, recon_np))
                mses.append(mean_squared_error(orig_np, recon_np))
                try:
                    orig_resampled = librosa.resample(orig_np, orig_sr=24000, target_sr=16000)
                    recon_resampled = librosa.resample(recon_np, orig_sr=24000, target_sr=16000)

                    min_len = min(len(orig_resampled), len(recon_resampled))
                    orig_resampled = orig_resampled[:min_len]
                    recon_resampled = recon_resampled[:min_len]

                    pesq_score = pesq(16000, orig_resampled, recon_resampled, 'wb')
                    pesqs.append(pesq_score)
                except Exception as e:
                    print(f"[⚠️ PESQ error] {e}")
                    pesqs.append(np.nan)

    return np.mean(maes), np.mean(mses), np.nanmean(pesqs)

def compute_loss(original, reconstructed):
    return torch.nn.functional.mse_loss(reconstructed, original)

In [43]:
scaler = torch.cuda.amp.GradScaler()

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc=f"🧪 Обучение: эпоха {epoch+1}"):
        batch = batch.to(device)
        batch = batch.to(device).unsqueeze(1)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
            encoded = model.encode(batch)
            decoded = model.decode(encoded)
            loss = compute_loss(batch, decoded)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    train_accuracy = 1 - train_loss
    train_accuracies.append(train_accuracy)

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device).unsqueeze(1)
            encoded = model.encode(batch)
            decoded = model.decode(encoded)
            loss = compute_loss(batch, decoded)
            test_loss += loss.item()
    test_loss /= len(test_loader)
    test_losses.append(test_loss)

    test_accuracy = 1 - test_loss
    test_accuracies.append(test_accuracy)

    if test_loss < best_loss - 1e-5:
        best_loss = test_loss
        no_improve_counter = 0
    else:
        no_improve_counter += 1
        if no_improve_counter >= early_stopping_patience:
            print(f"🛑 Остановка обучения: test loss не улучшается {early_stopping_patience} эпох подряд.")
            break

    mae, mse, pesq_score = evaluate_metrics(model, test_loader, n_batches=5)
    metrics_log['epoch'].append(epoch+1)
    metrics_log['train_loss'].append(train_loss)
    metrics_log['test_loss'].append(test_loss)
    metrics_log['train_accuracy'].append(train_accuracy)
    metrics_log['test_accuracy'].append(test_accuracy)
    metrics_log['mae'].append(mae)
    metrics_log['mse'].append(mse)
    metrics_log['pesq'].append(pesq_score)

    print(f"📊 Epoch {epoch+1}: Train Loss = {train_loss:.4f} | Test Loss = {test_loss:.4f} | "
          f"Train Acc = {train_accuracy:.4f} | Test Acc = {test_accuracy:.4f} | "
          f"MAE = {mae:.6f} | MSE = {mse:.6f} | PESQ = {pesq_score:.3f}")

    ckpt_path = os.path.join(CHECKPOINT_DIR, f"epoch_{epoch+1}.pth")
    torch.save(model.state_dict(), ckpt_path)

    plt.figure(figsize=(10, 5))
    plt.plot(metrics_log['epoch'], metrics_log['train_loss'], label='Train Loss')
    plt.plot(metrics_log['epoch'], metrics_log['test_loss'], label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train/Test Loss')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(LOG_DIR, f"loss_plot_epoch_{epoch+1}.png"))
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(metrics_log['epoch'], metrics_log['train_accuracy'], label='Train Accuracy')
    plt.plot(metrics_log['epoch'], metrics_log['test_accuracy'], label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Train/Test Accuracy')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(LOG_DIR, f"accuracy_plot_epoch_{epoch+1}.png"))
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(metrics_log['epoch'], metrics_log['mae'], label='MAE', color='blue')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.title('Mean Absolute Error (MAE)')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(LOG_DIR, f"mae_plot_epoch_{epoch+1}.png"))
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(metrics_log['epoch'], metrics_log['mse'], label='MSE', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.title('Mean Squared Error (MSE)')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(LOG_DIR, f"mse_plot_epoch_{epoch+1}.png"))
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(metrics_log['epoch'], metrics_log['pesq'], label='PESQ', color='purple')
    plt.xlabel('Epoch')
    plt.ylabel('PESQ')
    plt.title('Perceptual Evaluation of Speech Quality (PESQ)')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(LOG_DIR, f"pesq_plot_epoch_{epoch+1}.png"))
    plt.close()

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 1: 100%|██████████| 903/903 [1:14:04<00:00,  4.92s/it]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 1: Train Loss = 0.0024 | Test Loss = 0.0022 | Train Acc = 0.9976 | Test Acc = 0.9978 | MAE = 0.007266 | MSE = 0.000225 | PESQ = 3.158


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 2: 100%|██████████| 903/903 [03:16<00:00,  4.60it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 2: Train Loss = 0.0023 | Test Loss = 0.0022 | Train Acc = 0.9977 | Test Acc = 0.9978 | MAE = 0.007168 | MSE = 0.000219 | PESQ = 3.084


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 3: 100%|██████████| 903/903 [03:15<00:00,  4.62it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 3: Train Loss = 0.0023 | Test Loss = 0.0022 | Train Acc = 0.9977 | Test Acc = 0.9978 | MAE = 0.007094 | MSE = 0.000215 | PESQ = 3.017


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 4: 100%|██████████| 903/903 [03:15<00:00,  4.61it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 4: Train Loss = 0.0022 | Test Loss = 0.0022 | Train Acc = 0.9978 | Test Acc = 0.9978 | MAE = 0.007066 | MSE = 0.000212 | PESQ = 2.942


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 5: 100%|██████████| 903/903 [03:13<00:00,  4.66it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 5: Train Loss = 0.0022 | Test Loss = 0.0021 | Train Acc = 0.9978 | Test Acc = 0.9979 | MAE = 0.007026 | MSE = 0.000210 | PESQ = 2.929


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 6: 100%|██████████| 903/903 [03:13<00:00,  4.66it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 6: Train Loss = 0.0022 | Test Loss = 0.0021 | Train Acc = 0.9978 | Test Acc = 0.9979 | MAE = 0.007007 | MSE = 0.000209 | PESQ = 2.902


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 7: 100%|██████████| 903/903 [03:14<00:00,  4.64it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 7: Train Loss = 0.0022 | Test Loss = 0.0021 | Train Acc = 0.9978 | Test Acc = 0.9979 | MAE = 0.006983 | MSE = 0.000207 | PESQ = 2.863


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 8: 100%|██████████| 903/903 [03:14<00:00,  4.65it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 8: Train Loss = 0.0022 | Test Loss = 0.0021 | Train Acc = 0.9978 | Test Acc = 0.9979 | MAE = 0.006954 | MSE = 0.000205 | PESQ = 2.842


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 9: 100%|██████████| 903/903 [03:16<00:00,  4.59it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 9: Train Loss = 0.0022 | Test Loss = 0.0021 | Train Acc = 0.9978 | Test Acc = 0.9979 | MAE = 0.006959 | MSE = 0.000205 | PESQ = 2.826


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 10: 100%|██████████| 903/903 [03:14<00:00,  4.63it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 10: Train Loss = 0.0022 | Test Loss = 0.0021 | Train Acc = 0.9978 | Test Acc = 0.9979 | MAE = 0.006925 | MSE = 0.000203 | PESQ = 2.813


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 11: 100%|██████████| 903/903 [03:13<00:00,  4.66it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 11: Train Loss = 0.0022 | Test Loss = 0.0021 | Train Acc = 0.9978 | Test Acc = 0.9979 | MAE = 0.006915 | MSE = 0.000202 | PESQ = 2.800


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 12: 100%|██████████| 903/903 [03:15<00:00,  4.62it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 12: Train Loss = 0.0021 | Test Loss = 0.0021 | Train Acc = 0.9979 | Test Acc = 0.9979 | MAE = 0.006903 | MSE = 0.000201 | PESQ = 2.779


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 13: 100%|██████████| 903/903 [03:16<00:00,  4.60it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 13: Train Loss = 0.0021 | Test Loss = 0.0021 | Train Acc = 0.9979 | Test Acc = 0.9979 | MAE = 0.006919 | MSE = 0.000201 | PESQ = 2.748


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 14: 100%|██████████| 903/903 [03:16<00:00,  4.59it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 14: Train Loss = 0.0021 | Test Loss = 0.0021 | Train Acc = 0.9979 | Test Acc = 0.9979 | MAE = 0.006885 | MSE = 0.000200 | PESQ = 2.751


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 15: 100%|██████████| 903/903 [03:12<00:00,  4.70it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 15: Train Loss = 0.0021 | Test Loss = 0.0021 | Train Acc = 0.9979 | Test Acc = 0.9979 | MAE = 0.006883 | MSE = 0.000199 | PESQ = 2.742


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 16: 100%|██████████| 903/903 [03:15<00:00,  4.63it/s]


[⚠️ PESQ error] b'No utterances detected'
📊 Epoch 16: Train Loss = 0.0021 | Test Loss = 0.0021 | Train Acc = 0.9979 | Test Acc = 0.9979 | MAE = 0.006875 | MSE = 0.000199 | PESQ = 2.732


  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
🧪 Обучение: эпоха 17: 100%|██████████| 903/903 [03:15<00:00,  4.63it/s]


🛑 Остановка обучения: test loss не улучшается 5 эпох подряд.


In [1]:
plt.figure(figsize=(10, 5))
plt.plot(metrics_log['epoch'], metrics_log['train_loss'], label='Train Loss')
plt.plot(metrics_log['epoch'], metrics_log['test_loss'], label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Final Train/Test Loss')
plt.legend()
plt.grid()
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(metrics_log['epoch'], metrics_log['train_accuracy'], label='Train Accuracy')
plt.plot(metrics_log['epoch'], metrics_log['test_accuracy'], label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Final Train/Test Accuracy')
plt.legend()
plt.grid()
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(metrics_log['epoch'], metrics_log['mae'], label='MAE')
plt.xlabel('Epoch')
plt.ylabel('MAE')
plt.title('Mean Absolute Error (MAE)')
plt.legend()
plt.grid()
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(metrics_log['epoch'], metrics_log['mse'], label='MSE')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.title('Mean Squared Error (MSE)')
plt.legend()
plt.grid()
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(metrics_log['epoch'], metrics_log['pesq'], label='PESQ')
plt.xlabel('Epoch')
plt.ylabel('PESQ')
plt.title('Perceptual Evaluation of Speech Quality (PESQ)')
plt.legend()
plt.grid()
plt.show()

NameError: name 'plt' is not defined