In [1]:
from pathlib import Path
from typing import List

import torch
import torchaudio
import pandas as pd
import numpy as np

import IPython.display as ipd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TestSpotterDataset(torch.utils.data.Dataset):
    """Dataset over the wav files listed in a CSV manifest.

    The manifest is read with ``pd.read_csv`` and must contain a ``path``
    column; each entry is resolved relative to the manifest's own directory.
    The "label" returned for a sample is the manifest row index, not a class.

    Args:
        manifest_path: location of the manifest CSV. A plain string is
            accepted as well as a ``Path``.
        transform: callable applied to the loaded waveform tensor to produce
            the model features.
    """

    def __init__(self, manifest_path: Path, transform):
        super().__init__()

        self.transform = transform
        # Coerce so that `.parent` below also works for string paths.
        manifest_path = Path(manifest_path)
        manifest = pd.read_csv(manifest_path)
        self.wav_files = [
            manifest_path.parent / wav_path for wav_path in manifest["path"]
        ]
        # Row index of each manifest entry; used later as the sample identifier.
        self.labels = manifest.index.values

    def __len__(self):
        """Return the number of wav files listed in the manifest."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A tuple ``(first channel of the waveform, transform(waveform),
            manifest row index)``.
        """
        # The sample rate reported by the file is not used here.
        # NOTE(review): presumably every file already matches the rate the
        # transform was built for — confirm against the dataset.
        wav, _sample_rate = torchaudio.load(self.wav_files[idx])
        features = self.transform(wav)
        return wav[0], features, self.labels[idx]

In [3]:
# Load the trained network from a saved checkpoint file.
# NOTE(review): `model` is a project-local module and `weights` is an absolute,
# machine-specific path — both must exist for this cell to run on another machine.
from model import Conv1dNet
weights = '/home/eugeny/soundmipt/hw2/runs/Normalized-V4/Normalized-V4-epoch=153-step=7392-val_loss=0.3072.ckpt'
# Rebinds `model` to the restored Conv1dNet instance used by the following cells.
model = Conv1dNet.load_from_checkpoint(weights)

In [4]:
# Put the loaded model into evaluation mode before running inference.
model.eval()

class SpecScaler(torch.nn.Module):
    """Log-compress a tensor after clamping it to ``[1e-9, 1e9]``.

    The lower bound keeps ``log`` finite for zero-valued bins.
    """

    def forward(self, x):
        """Return ``log(clamp(x, 1e-9, 1e9))`` without modifying ``x``."""
        # Out-of-place clamp: the in-place variant would silently overwrite the
        # caller's tensor and fails on leaf tensors that require grad.
        return torch.log(x.clamp(1e-9, 1e9))

def collator(data):
    """Collate ``(wav, features, label)`` samples into a batch.

    The raw waveforms are dropped. Feature tensors are concatenated along
    dim 0 with ``torch.cat``, so they must agree in every other dimension.
    NOTE(review): this presumably relies on each feature tensor having a
    leading channel dimension of size 1 — confirm for multi-channel audio.

    Args:
        data: iterable of ``(wav, features, label)`` tuples.

    Returns:
        ``(specs, labels)`` where ``specs`` is the concatenated feature tensor
        and ``labels`` is an int64 tensor of the labels in input order.
    """
    samples = list(data)
    specs = torch.cat([features for _wav, features, _label in samples])
    # Build the integer tensor directly: going through torch.Tensor (float32)
    # first would round integer labels larger than 2**24.
    labels = torch.tensor(
        [label for _wav, _features, label in samples], dtype=torch.long
    )
    return specs, labels

# Feature pipeline applied to each waveform: mel spectrogram, then log scaling.
# The spectrogram settings are taken from the configuration stored on the model.
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=model.conf.sample_rate,
    **model.conf.features,
)
val_transform = torch.nn.Sequential(mel_spectrogram, SpecScaler())

# Test samples are served in manifest order (no shuffling), 64 per batch.
test_dataset = TestSpotterDataset(
    manifest_path=Path('/home/eugeny/Datasets/keyword-spotting/test/test/manifest.csv'),
    transform=val_transform,
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=collator,
)

In [5]:
# Run the model over the whole test set, collecting for every sample its
# manifest row index and the predicted keyword.
index_lst, label_lst = [], []

# Use one device for both the model and its inputs so they can never disagree.
device = torch.device('cuda')
model = model.to(device)

# Gradients are not needed for inference.
with torch.no_grad():
    for inputs, idx in test_dataloader:
        # Highest-scoring class per sample, as plain Python ints.
        preds = model(inputs.to(device)).argmax(-1).cpu().tolist()

        # Map class ids to keyword strings; `idx` holds the batch's row indices.
        label_lst.extend(model.conf.idx_to_keyword[class_idx] for class_idx in preds)
        index_lst.extend(idx.tolist())

In [6]:
# Write the predictions as a two-column CSV (`index`, `label`) without the
# DataFrame's own row index.
submission = pd.DataFrame({'index': index_lst, 'label': label_lst})
submission.to_csv('submit.csv', index=False)