In [1]:
%%capture
!pip install torchaudio --quiet
!pip install soundfile --quiet
!pip install pydub --quiet
!apt-get install -y ffmpeg > /dev/null 2>&1


In [2]:
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
import tarfile
import json
import os
import io
import soundfile as sf
import tempfile
import subprocess
from tqdm.notebook import tqdm

# Mount Google Drive so the dataset files under /content/drive/MyDrive are readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# All homework data lives in one folder on the mounted Google Drive.
DATA_ROOT = "/content/drive/MyDrive/hw3_speech"

TRAIN_TAR = f"{DATA_ROOT}/train_data.tar"      # training .opus archive
TEST_TAR = f"{DATA_ROOT}/test_data.tar"        # test .opus archive (no labels read)
BOUNDS_JSON = f"{DATA_ROOT}/word_bounds.json"  # file id -> (start, end) keyword bounds

In [4]:
# --- Audio front-end ---
SAMPLE_RATE = 16000   # Hz; ffmpeg resamples every clip to this rate on decode
WINDOW_SIZE = 1.5     # seconds per analysis window
HOP_SIZE = 0.5        # seconds between consecutive window starts
N_MELS = 40           # mel bands per spectrogram frame
EPS = 1e-9            # added before log2 so silent frames do not produce -inf

# --- Training ---
BATCH_SIZE = 512      # audio *files* per batch; each file expands into several windows
EPOCHS = 1
LR = 3e-4

# --- Reproducibility ---
# Seeds PyTorch's default generator, which drives model weight init and
# random_split / DataLoader shuffling when no explicit generator is given.
SEED = 42
torch.manual_seed(SEED)


In [5]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print("Device:", DEVICE)

Device: cuda


In [6]:
import soundfile as sf

def load_opus_from_bytes(audio_bytes):
    """Decode an in-memory Opus clip into a float32 waveform tensor.

    The bytes are written to a temporary ``.opus`` file and transcoded by the
    ``ffmpeg`` CLI into a temporary WAV, resampled to ``SAMPLE_RATE`` and
    downmixed to one channel. Both temporary files are removed on every path,
    including failures.

    Args:
        audio_bytes: raw contents of an ``.opus`` file.

    Returns:
        (wav, sr): ``wav`` is a float32 tensor laid out channels-first,
        (channels, samples), which is (1, samples) given the ``-ac 1`` downmix.
        ``sr`` is the sample rate read back from the decoded WAV.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails. Its stderr is captured
            and available as the exception's ``stderr`` attribute.
    """
    opus_path = None
    wav_path = None
    try:
        # Write inside the try so a failed write still gets cleaned up below.
        with tempfile.NamedTemporaryFile(suffix=".opus", delete=False) as f:
            opus_path = f.name
            f.write(audio_bytes)

        # splitext swaps only the real extension; str.replace would also
        # rewrite any ".opus" substring elsewhere in the temp path.
        wav_path = os.path.splitext(opus_path)[0] + ".wav"

        subprocess.run(
            [
                "ffmpeg", "-y", "-loglevel", "error",
                "-i", opus_path,
                "-ar", str(SAMPLE_RATE),
                "-ac", "1",
                wav_path
            ],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
        )

        wav, sr = sf.read(wav_path, dtype="float32")
        wav = torch.from_numpy(wav)
        # soundfile returns (frames,) for mono and (frames, channels) otherwise;
        # normalise both to channels-first.
        wav = wav.unsqueeze(0) if wav.ndim == 1 else wav.T.contiguous()
    finally:
        for path in (opus_path, wav_path):
            if path is not None and os.path.exists(path):
                os.remove(path)

    return wav, sr


In [7]:
class KWSDataset(Dataset):
    """Labeled keyword-spotting dataset read directly from a tar archive.

    Each item is one audio file. The file is sliced into windows of
    ``WINDOW_SIZE`` seconds starting every ``HOP_SIZE`` seconds. Each window
    becomes a log2-mel spectrogram with a binary label:
    - 1 if the window overlaps the file's keyword bounds (inclusive);
    - 0 otherwise, or when the file has no bounds entry.

    Items are ``(log_mel, labels)`` with shapes ``(n_windows, 1, N_MELS, frames)``
    and ``(n_windows,)``. ``n_windows`` varies per file, so batching needs a
    concatenating collate function.

    NOTE(review): one tar handle is opened in ``__init__`` and reused by every
    ``__getitem__``. That is fine with ``num_workers=0``; reopen the archive per
    worker before enabling multi-process loading.
    """

    def __init__(self, tar_path, bounds_json):
        self.tar = tarfile.open(tar_path, "r")

        # Keep regular .opus payloads; skip macOS "._" AppleDouble stubs.
        self.members = [
            m for m in self.tar.getmembers()
            if m.isfile()
            and m.name.endswith(".opus")
            and not os.path.basename(m.name).startswith("._")
        ]

        print(f"Найдено {len(self.members)} аудиофайлов")

        # file id -> (start, end) of the keyword. Presumably in seconds, since
        # the bounds are compared against sample_index / SAMPLE_RATE below.
        with open(bounds_json) as f:
            self.bounds = json.load(f)

        print(f"Загружено {len(self.bounds)} разметок")

        self.mel = T.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_fft=1024,
            n_mels=N_MELS,
            f_max=8000
        )

        self.win_len = int(WINDOW_SIZE * SAMPLE_RATE)
        self.hop_len = int(HOP_SIZE * SAMPLE_RATE)

    def __len__(self):
        return len(self.members)

    def _load_waveform(self, member):
        """Read one tar member and return a mono (1, samples) waveform at SAMPLE_RATE."""
        with self.tar.extractfile(member) as f:
            audio_bytes = f.read()
        wav, sr = load_opus_from_bytes(audio_bytes)

        if sr != SAMPLE_RATE:
            wav = T.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)(wav)

        return wav.mean(0, keepdim=True)  # mono

    def _frame(self, wav):
        """Zero-pad short clips to one window, then slice into (n_windows, 1, win_len)."""
        if wav.shape[1] < self.win_len:
            wav = torch.nn.functional.pad(wav, (0, self.win_len - wav.shape[1]))
        # unfold -> (1, n_windows, win_len). Window starts are 0, hop, 2*hop, ...,
        # the same windows as range(0, len - win_len + 1, hop_len), and padding
        # guarantees at least one.
        return wav.unfold(1, self.win_len, self.hop_len).transpose(0, 1).contiguous()

    def _window_labels(self, uid, n_windows):
        """Per-window labels: 1 where [t0, t1] overlaps the keyword bounds, else 0."""
        start_end = self.bounds.get(uid, None)
        if start_end is None:
            return torch.zeros(n_windows, dtype=torch.long)

        s, e = start_end
        # float64 reproduces Python's start / SAMPLE_RATE arithmetic exactly.
        starts = torch.arange(n_windows, dtype=torch.float64) * self.hop_len
        t0 = starts / SAMPLE_RATE
        t1 = (starts + self.win_len) / SAMPLE_RATE
        return (~((t1 < s) | (t0 > e))).long()

    def __getitem__(self, idx):
        """Return (log-mel windows, window labels) for the idx-th audio file."""
        member = self.members[idx]
        uid = os.path.basename(member.name).replace(".opus", "")

        frames = self._frame(self._load_waveform(member))
        # MelSpectrogram maps over leading dims, so all windows go in one call.
        log_mel = (self.mel(frames) + EPS).log2()
        labels = self._window_labels(uid, frames.shape[0])
        return log_mel, labels


In [8]:
def collate_fn(batch):
    """Flatten per-file (windows, labels) pairs into one batch of windows.

    Files contribute different numbers of windows, so tensors are concatenated
    along dim 0 rather than stacked.
    """
    windows = [x for x, _ in batch]
    window_labels = [y for _, y in batch]
    return torch.cat(windows, dim=0), torch.cat(window_labels, dim=0)

full_ds = KWSDataset(TRAIN_TAR, BOUNDS_JSON)

# Hold out a fraction of the *files* for validation. The explicit generator
# makes the split identical across kernel restarts.
VAL_FRACTION = 0.1
SPLIT_SEED = 42

val_len = int(VAL_FRACTION * len(full_ds))
train_len = len(full_ds) - val_len

train_ds, val_ds = random_split(
    full_ds,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"Train: {train_len}, Val: {val_len}")

# batch_size counts files. collate_fn concatenates each file's windows, so the
# number of windows per batch varies. num_workers stays 0 because KWSDataset
# reads every item through one tar handle opened in its __init__.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0
)

Найдено 90000 аудиофайлов
Загружено 45000 разметок
Train: 81000, Val: 9000


### Модель

In [9]:
class TemporalCNN(nn.Module):
    """Small 2-D CNN that scores one log-mel window as keyword (1) vs. not (0).

    Input is (batch, 1, n_mels, frames). The network has three
    conv -> BatchNorm -> ReLU stages; the last two halve the mel axis with
    stride (2, 1). These are followed by global average pooling over both
    spatial axes and a two-layer MLP head that emits 2 logits.

    Submodules are registered as conv.0 ... conv.8 and fc.0 ... fc.3.
    """

    def __init__(self):
        super().__init__()

        layers = []
        # (in_channels, out_channels, stride): only the mel axis is downsampled.
        for c_in, c_out, stride in ((1, 32, 1), (32, 64, (2, 1)), (64, 128, (2, 1))):
            layers += [
                nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*layers)

        hidden = 64
        self.fc = nn.Sequential(
            nn.Linear(128, hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, 2),
        )

    def forward(self, x):
        feature_map = self.conv(x)
        pooled = feature_map.mean(dim=[2, 3])  # global average over (mel, time)
        return self.fc(pooled)

model = TemporalCNN().to(DEVICE)
print(model)
print(f"Параметров: {sum(p.numel() for p in model.parameters()):,}")

# Shape smoke test on random input shaped like a batch of windows:
# (batch, 1 channel, N_MELS bands, 47 frames). A 1.5 s window is 24000 samples,
# and MelSpectrogram's default hop of n_fft // 2 = 512 (centered) gives 47 frames.
# Run it in eval mode without grad so BatchNorm's running statistics are not
# updated from random noise before training starts.
model.eval()
with torch.no_grad():
    test_input = torch.randn(8, 1, N_MELS, 47).to(DEVICE)
    test_output = model(test_input)
model.train()

assert tuple(test_output.shape) == (8, 2), test_output.shape
print(f"Test input shape: {test_input.shape}")
print(f"Test output shape: {test_output.shape}")

TemporalCNN(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=64, out_features=2, bias=True)
  )
)
Параметров: 101,506
Test input shape: torch.Size([8, 1, 40, 47])
Test output shape: torch.Size([8, 2])


### Обучение

In [10]:
opt = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
# The inference cell reloads the checkpoint from this path.
BEST_MODEL_PATH = '/content/best_model.pt'


def _train_one_epoch(model, loader, opt, criterion, desc):
    """Run one optimisation pass over `loader`; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc=desc)
    for x, y in pbar:
        x, y = x.to(DEVICE), y.to(DEVICE)

        loss = criterion(model(x), y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    # An empty loader yields 0.0 instead of raising ZeroDivisionError.
    return total_loss / num_batches if num_batches else 0.0


def _evaluate(model, loader):
    """Window-level validation metrics.

    Returns (accuracy, FRR, FAR, score), where
    - FRR = FN / positives;
    - FAR = FP / negatives;
    - score is the harmonic mean of (1 - FRR) and (1 - FAR).
    """
    model.eval()
    tp = fp = fn = tn = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            pred = model(x).argmax(1)

            tp += ((pred == 1) & (y == 1)).sum().item()
            fp += ((pred == 1) & (y == 0)).sum().item()
            fn += ((pred == 0) & (y == 1)).sum().item()
            tn += ((pred == 0) & (y == 0)).sum().item()

    # Labels and the 2-class argmax are both in {0, 1}, so the four counts
    # partition every window and correct predictions are tp + tn.
    total = tp + fp + fn + tn
    acc = (tp + tn) / total if total else 0.0

    num_pos = tp + fn
    num_neg = tn + fp
    frr = fn / num_pos if num_pos > 0 else 0
    far = fp / num_neg if num_neg > 0 else 0

    keep_pos, keep_neg = 1 - frr, 1 - far
    if keep_pos + keep_neg > 0:
        score = 2 * keep_pos * keep_neg / (keep_pos + keep_neg)
    else:
        score = 0
    return acc, frr, far, score


best_val_acc = 0
best_score = -1.0  # below any reachable score, so the first epoch always checkpoints

for epoch in range(EPOCHS):
    avg_loss = _train_one_epoch(model, train_loader, opt, criterion,
                                desc=f"Epoch {epoch+1}/{EPOCHS}")
    val_acc, frr, far, score = _evaluate(model, val_loader)

    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Val Acc={val_acc:.4f}, "
          f"FRR={frr:.4f}, FAR={far:.4f}, Score={score:.4f}")

    # Select the checkpoint by score rather than accuracy. With imbalanced
    # classes, accuracy can rise while FRR collapses: a previous run logged
    # Val Acc 0.69 alongside FRR 0.89.
    # NOTE(review): these metrics are per window, while inference labels whole
    # files by their max window probability. A file-level validation would
    # track the test setting more closely.
    if score > best_score:
        best_score = score
        best_val_acc = val_acc
        torch.save(model.state_dict(), BEST_MODEL_PATH)

print(f"\nBest validation score: {best_score:.4f} "
      f"(validation accuracy at that checkpoint: {best_val_acc:.4f})")

Epoch 1/1:   0%|          | 0/159 [00:00<?, ?it/s]

Epoch 1: Loss=0.6059, Val Acc=0.6911, FRR=0.8948, FAR=0.0343, Score=0.1897

Best validation accuracy: 0.6911


### Тестовый Dataset

In [11]:
class KWSTestDataset(Dataset):
    """Test archive reader: yields ``(file_id, log_mel_windows)`` per audio file.

    Windowing and features match the training dataset: ``WINDOW_SIZE``-second
    windows every ``HOP_SIZE`` seconds, each turned into a log2 mel spectrogram.
    ``log_mel_windows`` therefore has shape ``(n_windows, 1, N_MELS, frames)``,
    with at least one window per file.
    """

    def __init__(self, tar_path):
        self.tar = tarfile.open(tar_path, "r")

        # Keep regular .opus payloads; skip macOS "._" AppleDouble stubs.
        self.members = [
            m for m in self.tar.getmembers()
            if m.isfile()
            and m.name.endswith(".opus")
            and not os.path.basename(m.name).startswith("._")
        ]

        print(f"Найдено {len(self.members)} тестовых файлов")

        self.mel = T.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_fft=1024,
            n_mels=N_MELS,
            f_max=8000
        )

        self.win_len = int(WINDOW_SIZE * SAMPLE_RATE)
        self.hop_len = int(HOP_SIZE * SAMPLE_RATE)

    def __len__(self):
        return len(self.members)

    def _load_waveform(self, member):
        """Read one tar member and return a mono (1, samples) waveform at SAMPLE_RATE."""
        with self.tar.extractfile(member) as f:
            audio_bytes = f.read()
        wav, sr = load_opus_from_bytes(audio_bytes)

        if sr != SAMPLE_RATE:
            wav = T.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)(wav)

        return wav.mean(0, keepdim=True)

    def _frame(self, wav):
        """Zero-pad short clips to one window, then slice into (n_windows, 1, win_len)."""
        if wav.shape[1] < self.win_len:
            wav = torch.nn.functional.pad(wav, (0, self.win_len - wav.shape[1]))
        # unfold -> (1, n_windows, win_len) with starts 0, hop, 2*hop, ...
        # Padding guarantees at least one window, so no empty-case fallback is needed.
        return wav.unfold(1, self.win_len, self.hop_len).transpose(0, 1).contiguous()

    def __getitem__(self, idx):
        """Return (file id, log-mel windows) for the idx-th audio file."""
        member = self.members[idx]
        uid = os.path.basename(member.name).replace(".opus", "")

        frames = self._frame(self._load_waveform(member))
        # MelSpectrogram maps over leading dims, so all windows go in one call.
        return uid, (self.mel(frames) + EPS).log2()


### Предсказание на тестовых данных

In [12]:
# A file is labeled positive when its most keyword-like window exceeds this probability.
THRESHOLD = 0.5

# map_location lets a checkpoint saved on one device load on another.
# weights_only restricts unpickling to tensors/containers: a state_dict needs
# nothing more, and arbitrary objects in the file are not executed.
model.load_state_dict(
    torch.load('/content/best_model.pt', map_location=DEVICE, weights_only=True)
)
model.eval()

test_ds = KWSTestDataset(TEST_TAR)

results = []

with torch.no_grad():
    for idx in tqdm(range(len(test_ds)), desc="Inference"):
        uid, segments = test_ds[idx]

        probs = torch.softmax(model(segments.to(DEVICE)), dim=1)

        # File-level decision: max over windows of P(keyword).
        max_prob = probs[:, 1].max().item()
        label = int(max_prob > THRESHOLD)

        results.append({'id': uid, 'label': label})

Найдено 27000 тестовых файлов


Inference:   0%|          | 0/27000 [00:00<?, ?it/s]

In [13]:
import pandas as pd

output_path = '/content/submission.csv'

# One row per test file; columns follow the dict key order: id, label.
df = pd.DataFrame(results)
df.to_csv(output_path, index=False)

In [14]:
# Colab helper: send the submission CSV to the browser as a file download.
from google.colab import files
files.download(output_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>