In [28]:
import torch
import torchaudio
import torch.nn as nn

class SpectrogramClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(SpectrogramClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.pool(x)

        x = x.view(x.size(0), -1)

        in_features = x.size(1)

        self.fc1 = nn.Linear(in_features, 128)
        
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def predict_audio(audio_path, model_path, device='cuda'):
    model = SpectrogramClassifier(num_classes=2)
    map_location = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
    # Load model weights with strict=False to ignore missing/extra keys
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict, strict=False)
    
    model.to(device)
    model.eval()
    
    sample_rate = 44100
    n_mels = 64
    
    waveform, sr = torchaudio.load(audio_path)
    
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=2048,
        hop_length=512,
        n_mels=n_mels
    )(waveform)
    
    mel_spec = mel_spec.unsqueeze(0)  
    
    mel_spec = mel_spec.to(device)
    
    with torch.no_grad():
        outputs = model(mel_spec)
        _, predicted = torch.max(outputs, 1)
        
    label = "Noisy" if predicted.item() == 1 else "Clean"
    
    probabilities = torch.nn.functional.softmax(outputs, dim=1)
    confidence = probabilities[0][predicted.item()].item()
    
    return label, confidence

if __name__ == "__main__":
    model_path = "/Users/johannasmriti/Downloads/Final working model/classifier/class-epoch-8.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    audio_file = "/Users/johannasmriti/Downloads/Final working model/denoiser/noisy_example.wav"
    
    label, confidence = predict_audio(audio_file, model_path, device)
    
    print(f"Audio file: {audio_file}")
    print(f"Prediction: {label}")
    print(f"Confidence: {confidence:.2%}")


Audio file: /Users/johannasmriti/Downloads/Final working model/denoiser/noisy_example.wav
Prediction: Noisy
Confidence: 65.89%


  state_dict = torch.load(model_path, map_location=map_location)
