In [1]:
import warnings
from torch import nn, optim

from data_loader import *
from trainer import trainer_
from aed_models import *

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
# --- Run configuration: everything a reader might want to tune lives here ---
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
h = 128                # only referenced by the commented-out classifier below
learning_rate = 1e-3   # optimizer step size
epochs = 15            # number of training epochs
batch_size = 256       # batch size of the autoencoder data loaders
batch_size_2 = 64      # not used anywhere in the visible cells

# Autoencoder under training, moved to the selected device.
ae_model = NS_1().to(device)
# clf_model = Wavegram_AttentionMap(h=h, lr=learning_rate).to(device)

# Take the step size from `learning_rate` instead of repeating the literal,
# so that editing the constant above actually changes the optimizer.
optimizer = optim.AdamW(ae_model.parameters(), lr=learning_rate, weight_decay=1e-5)
# Reconstruction objective: mean squared error.
criterion = nn.MSELoss()

In [4]:
### === UNSUPERVISED === ###
# Dataset of recordings under the "0_normal_1" folder, built with supervised=False.
ae_dataset = AudioDataset(
    audio_dir="audio_sources/dataset/background_nash/0_normal_1",
    supervised=False)

# Split unsupervised: 75% train, 15% validation, remainder (about 10%) test.
train_size = int(0.75 * len(ae_dataset))
val_size = int(0.15 * len(ae_dataset))
test_size = len(ae_dataset) - train_size - val_size

# Seed the split so that train/val/test membership is identical after a kernel
# restart. With an unseeded split, a checkpoint trained in one session could be
# evaluated on clips it was trained on in the next one.
split_seed = 42
ae_train, ae_val, ae_test = random_split(
    ae_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed))

# Only the training loader shuffles; validation and test keep a fixed order.
ae_train_loader = DataLoader(ae_train, batch_size=batch_size, shuffle=True)
ae_val_loader = DataLoader(ae_val, batch_size=batch_size, shuffle=False)
ae_test_loader = DataLoader(ae_test, batch_size=batch_size, shuffle=False)

In [5]:
### === INFO === ###
# Confirm that setup finished and report the size of the full (unsplit) dataset.
print("Ready!", f" - AE: {len(ae_dataset)} samples", sep="\n")

Ready!
 - AE: 958580 samples


In [None]:
# Train the autoencoder on the train/validation loaders.
# The epoch count comes from the `epochs` constant of the configuration cell
# instead of a second hard-coded 15, so the two can no longer drift apart.
# NOTE(review): presumably trainer_ writes the CSV log and the best checkpoint
# under `results_dir` — confirm in trainer.py.
trainer_(
    model=ae_model,
    train_loader=ae_train_loader,
    val_loader=ae_val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=epochs,
    results_dir="results_1s",
    csv_name="training_log_1s.csv",
    model_name="best_model_1s.pth",
)

In [None]:
# Audio framing and anomaly threshold for the detector.
# NOTE(review): the following cell assigns these same four values again and
# adds `severe_threshold`; this cell looks redundant.
sample_rate, duration = 16000, 5      # Hz, seconds
seg_len = duration * sample_rate      # 80000 samples per segment
threshold = 2e-3                      # MSE above this is reported as anomalous

In [None]:
# Global parameters for the detector
sample_rate = 16_000                  # target sampling rate in Hz
duration = 5                          # segment length in seconds
seg_len = duration * sample_rate      # 80000 samples per segment
threshold = 2e-3                      # MSE above this -> "Anomalo"
severe_threshold = 4e-3               # MSE above this -> "Anomalia Grave"

In [None]:
# === DETECTOR === #
def load_audio(filepath, target_length=seg_len):
    """Load an audio file as a fixed-length mono waveform.

    The file is read with torchaudio, resampled to the module-level
    ``sample_rate`` when its native rate differs, averaged down to one channel
    when it has several, and finally zero-padded on the right or truncated so
    that it holds exactly ``target_length`` samples.

    Args:
        filepath: Path of the audio file, handed to ``torchaudio.load``.
        target_length: Number of samples in the returned waveform. The default
            is the value ``seg_len`` had when this function was defined.

    Returns:
        A tuple ``(waveform, original_length)`` where ``waveform`` is the
        padded/truncated signal with the channel dimension squeezed out, and
        ``original_length`` is the number of real (non-padding) samples,
        capped at ``target_length``.
    """
    signal, native_rate = torchaudio.load(filepath)

    # Bring every file to the common sampling rate.
    if native_rate != sample_rate:
        signal = torchaudio.transforms.Resample(
            orig_freq=native_rate, new_freq=sample_rate)(signal)

    # Collapse multi-channel recordings to mono by averaging the channels.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # Number of real samples, remembered so callers can ignore the padding.
    n_real = signal.shape[1]

    if n_real >= target_length:
        # Long enough: keep only the first target_length samples.
        signal = signal[:, :target_length]
        n_real = target_length
    else:
        # Too short: append zeros up to target_length.
        signal = F.pad(signal, (0, target_length - n_real))

    return signal.squeeze(0), n_real


def _plot_reconstruction_error(reconstruction_error, diff_mel, hop_length, threshold):
    """Show the per-frame error curve and the time x mel absolute-error map."""
    reconstruction_error_np = reconstruction_error.cpu().numpy()
    num_frames = reconstruction_error_np.shape[0]
    # Frame index -> time in seconds.
    frame_times = torch.arange(num_frames) * hop_length / sample_rate

    # --- 1D plot: error over time, with the anomaly threshold as reference ---
    plt.figure(figsize=(12, 4))
    plt.plot(frame_times, reconstruction_error_np, label="Errore di ricostruzione")
    plt.axhline(y=threshold, color='red', linestyle='--', label=f"Soglia ({threshold:.4f})")
    plt.title("Errore di Ricostruzione per Frame")
    plt.xlabel("Tempo [s]")
    plt.ylabel("MSE")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # --- 2D plot: absolute error over time and mel bin ---
    plt.figure(figsize=(12, 6))
    plt.imshow(diff_mel.cpu().T, aspect="auto", origin="lower", cmap="viridis",
               extent=[frame_times[0].item(), frame_times[-1].item(), 0, 128])
    plt.title("Errore di Ricostruzione (Tempo × Frequenza)")
    plt.xlabel("Tempo [s]")
    plt.ylabel("Mel Frequency Bin")
    plt.colorbar(label="|Errore|")
    plt.tight_layout()
    plt.show()


def run_inference(model, audio_tensor, original_length, threshold=threshold):
    """Score one waveform with the autoencoder and plot its reconstruction error.

    The model is moved to CUDA when available (CPU otherwise) and put in eval
    mode. The waveform is turned into feature vectors by ``model.preprocessing``
    and reconstructed through ``model.encoder`` -> ``model.bottleneck`` ->
    ``model.decoder``. Only the vectors that belong to real audio (not to the
    zero padding) are kept; their mean squared reconstruction error is the
    anomaly score, which is printed together with its label and then plotted.

    Args:
        model: Autoencoder exposing ``transform_tf`` (with ``hop_length``),
            ``preprocessing``, ``encoder``, ``bottleneck``, ``decoder`` and
            ``frames``.
        audio_tensor: Un-batched waveform; a batch dimension is added here.
        original_length: Number of real (non-padding) samples in
            ``audio_tensor``.
        threshold: Score above which the clip is labelled "Anomalo". The
            default is the module-level ``threshold`` captured when this
            function was defined, whereas ``severe_threshold`` is read from
            module scope at call time.

    Returns:
        ``(feature_vector, reconstructed, loss)``: the trimmed input features
        and their reconstruction, both on CPU, and the MSE score as a float.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    model.transform_tf = model.transform_tf.to(device)

    audio_tensor = audio_tensor.unsqueeze(0).to(device)  # add the batch dimension

    # Inference only: keep feature extraction and reconstruction out of autograd.
    with torch.no_grad():
        feature_vector = model.preprocessing(audio_tensor)
        encoded = model.encoder(feature_vector)
        bottleneck = model.bottleneck(encoded)
        reconstructed = model.decoder(bottleneck)

    # Number of feature vectors that come from real audio rather than padding.
    hop_length = model.transform_tf.hop_length
    real_t_bins = 1 + (original_length // hop_length)
    real_vector_array_size = real_t_bins - model.frames + 1
    # For a clip shorter than one context window the count above is <= 0.
    # Zero would produce empty slices (NaN score, silently labelled "Normale")
    # and a negative value would trim from the end instead, so always keep at
    # least the first vector.
    real_vector_array_size = max(1, real_vector_array_size)

    # Trim to the real duration.
    feature_vector = feature_vector[:, :real_vector_array_size, :]
    reconstructed = reconstructed[:, :real_vector_array_size, :]

    # Global error over all kept vectors.
    loss = F.mse_loss(reconstructed, feature_vector).item()

    # Label the clip according to the two thresholds.
    if loss > severe_threshold:
        result = "Anomalia Grave"
    elif loss > threshold:
        result = "Anomalo"
    else:
        result = "Normale"

    print(f"Anomaly Score (MSE): {loss:.6f} → {result}")

    # === Visualization === #
    with torch.no_grad():
        # squeeze(0) removes only the batch dimension, so the time axis
        # survives even when a single vector is left after trimming.
        reconstruction_error = torch.mean((feature_vector - reconstructed) ** 2, dim=2).squeeze(0)
        diff = (feature_vector - reconstructed).abs().squeeze(0)
        # NOTE(review): reshape(-1, 128) lays each vector's 128-wide chunks out
        # as separate rows, so the map has more rows than time steps whenever
        # the feature width exceeds 128 (presumably 5 frames x 128 mel bins —
        # confirm). An earlier variant averaged the chunks instead:
        # diff.view(-1, 5, 128).mean(dim=1).
        diff_mel = diff.reshape(-1, 128)

    _plot_reconstruction_error(reconstruction_error, diff_mel, hop_length, threshold)

    return feature_vector.cpu(), reconstructed.cpu(), loss


In [None]:
# === RUN === #
# Clip to score and the trained checkpoint to score it with.
audio_path = "test/200501.flac"
model_path = "results_5s/best_model_5s.pth"

# Build the network and restore its trained weights on CPU;
# run_inference moves the model to the GPU when one is available.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model = NS_5()
checkpoint_weights = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint_weights)

# Load the clip, then run the detector on it.
waveform, original_length = load_audio(audio_path)
original, reconstructed, loss = run_inference(
    model=model, audio_tensor=waveform, original_length=original_length)


In [None]:
# Reload the checkpoint file on CPU to inspect what was actually saved in it.
state_dict = torch.load(model_path, map_location="cpu")
print(type(state_dict))  # If this is <class 'dict'> with a 'state_dict' entry, the file holds a whole model


In [None]:
# List the name of every entry stored in the loaded checkpoint, one per line.
for entry_name in state_dict:
    print(entry_name)
