In [None]:
# Colab-only setup: mount Google Drive and change the working directory to the
# project folder. `%cd` is an IPython magic, so this cell only runs inside
# IPython/Colab. Presumably the chdir is what lets relative paths used later
# (e.g. "results_1s_v5/...", "esiti_...") and the `aed_models` import resolve
# — TODO confirm.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/NASH


In [None]:
import warnings
import os
import io
import csv
import time
import shlex
import tempfile
import subprocess
from typing import List, Tuple, Optional

import requests
from requests.auth import HTTPBasicAuth
from urllib.parse import urljoin
from bs4 import BeautifulSoup

import torch
import torch.nn.functional as F
import torchaudio

#!pip install paramiko
import paramiko

from aed_models import *


warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Compute device: first CUDA GPU when available, CPU otherwise.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# === MODEL PARAMETERS ===
sample_rate = 16000                 # every loaded clip is resampled to this rate (Hz)
duration = 1                        # segment length in seconds
seg_len = duration * sample_rate    # samples per analysed segment
# Reconstruction-MSE cut-off: clips whose loss is strictly above it are "Anomalo".
threshold = 0.0010

In [None]:
# ==== REMOTE SSH CONFIG ====
# Each value can be overridden with an environment variable, so the host,
# account and key do not have to be hard-coded in a shared notebook.
# The defaults reproduce the previous hard-coded configuration exactly.
REMOTE_HOST = os.environ.get("NASH_REMOTE_HOST", "151.14.46.173")
REMOTE_USER = os.environ.get("NASH_REMOTE_USER", "milone")
SSH_PORT = int(os.environ.get("NASH_SSH_PORT", "2200"))
SSH_KEY = os.environ.get("NASH_SSH_KEY", "id_rsa")  # path to the SSH private key
UUID = os.environ.get("NASH_UUID", "8c69c0b3-ae12-4552-9b11-0aaa0304a06d")
REMOTE_BASE = f"/mnt/BB-S_STORAGE/{UUID}/recordings"


# ================== FUNZIONI AUDIO ==================
def load_audio(file_like, target_length: int = seg_len) -> Tuple[torch.Tensor, int]:
    """Load an audio file as a mono, fixed-length 1-D tensor at ``sample_rate``.

    Processing order: resample to ``sample_rate`` if the file's rate differs,
    downmix to one channel by averaging, then cut or right-pad with zeros to
    exactly ``target_length`` samples.

    Args:
        file_like: anything accepted by ``torchaudio.load`` (path or file object).
        target_length: desired number of output samples.

    Returns:
        ``(waveform, n_real)``: the 1-D tensor of ``target_length`` samples and
        how many of them come from the file rather than from zero padding.
    """
    signal, file_sr = torchaudio.load(file_like)
    if file_sr != sample_rate:
        signal = torchaudio.transforms.Resample(orig_freq=file_sr, new_freq=sample_rate)(signal)

    # Collapse multi-channel audio into a single channel.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)

    n_real = min(signal.shape[1], target_length)
    if signal.shape[1] > target_length:
        signal = signal[:, :target_length]
    else:
        signal = F.pad(signal, (0, target_length - signal.shape[1]))

    return signal.squeeze(0), n_real


def run_inference(model, audio_tensor: torch.Tensor, original_length: int, threshold: float = threshold):
    """Reconstruct one audio segment with the autoencoder and measure the error.

    Args:
        model: autoencoder exposing ``preprocessing``, ``encoder``, ``bottleneck``,
            ``decoder``, ``transform_tf`` (with ``hop_length``) and ``frames``.
        audio_tensor: 1-D waveform, e.g. as returned by ``load_audio``.
        original_length: number of real (non-padded) samples in ``audio_tensor``.
        threshold: not used here; kept for backward compatibility. Callers
            compare the returned loss against the threshold themselves.

    Returns:
        ``(feature_vector, reconstructed, loss)``: feature windows and their
        reconstruction (both moved to CPU and trimmed to the windows that cover
        real audio) and the scalar MSE between them.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    model.transform_tf = model.transform_tf.to(device)

    # Add a batch dimension before feature extraction.
    batch = audio_tensor.unsqueeze(0).to(device)

    # Feature extraction is part of inference as well: no autograd graph needed.
    with torch.no_grad():
        feature_vector = model.preprocessing(batch)
        encoded = model.encoder(feature_vector)
        bottleneck = model.bottleneck(encoded)
        reconstructed = model.decoder(bottleneck)

    # Number of feature windows that are built from real (non-padded) audio.
    hop_length = model.transform_tf.hop_length
    real_t_bins = 1 + (original_length // hop_length)
    real_vector_array_size = real_t_bins - model.frames + 1
    # For very short clips the formula drops to zero or below. A negative slice
    # bound would silently cut windows from the *end*, and zero would give an
    # empty tensor (NaN loss), so always keep at least one window.
    real_vector_array_size = max(1, real_vector_array_size)

    feature_vector = feature_vector[:, :real_vector_array_size, :]
    reconstructed = reconstructed[:, :real_vector_array_size, :]

    loss = F.mse_loss(reconstructed, feature_vector).item()
    return feature_vector.cpu(), reconstructed.cpu(), loss

# ================== SSH HELPERS ==================
def _open_ssh():
    """Open an authenticated SSH connection to ``REMOTE_HOST``.

    Uses ``REMOTE_USER``, ``SSH_PORT`` and the private key file ``SSH_KEY``.

    Returns:
        A connected ``paramiko.SSHClient``. The caller must ``close()`` it.

    Raises:
        Whatever ``paramiko.SSHClient.connect`` raises (socket, authentication
        or SSH errors). The client is closed before the error propagates.
    """
    ssh = paramiko.SSHClient()
    # SECURITY NOTE(review): AutoAddPolicy accepts any unknown host key, so the
    # connection is not protected against man-in-the-middle attacks. Consider
    # ssh.load_system_host_keys() / a pinned known_hosts file with RejectPolicy.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh.connect(REMOTE_HOST, port=SSH_PORT, username=REMOTE_USER, key_filename=SSH_KEY)
    except Exception:
        # Do not leak the half-initialised client/transport on failure.
        ssh.close()
        raise
    return ssh

def _ssh_exec(remote_cmd: str, check: bool = True):
    """Run a shell command on the remote host and return its stdout as text.

    A fresh connection is opened per call and is always closed, even when the
    command execution or the reads raise.

    Args:
        remote_cmd: command line, interpreted by the remote user's shell.
            Quote any interpolated values (e.g. with ``shlex.quote``).
        check: when True, a non-zero exit status raises ``RuntimeError``.

    Returns:
        stdout decoded as UTF-8 (undecodable bytes are dropped).

    Raises:
        RuntimeError: if ``check`` is True and the exit status is non-zero.
        Connection errors from ``_open_ssh`` are raised regardless of ``check``.
    """
    ssh = _open_ssh()
    try:
        _stdin, stdout, stderr = ssh.exec_command(remote_cmd)
        out = stdout.read().decode("utf-8", errors="ignore")
        err = stderr.read().decode("utf-8", errors="ignore")
        code = stdout.channel.recv_exit_status()
    finally:
        ssh.close()
    if check and code != 0:
        raise RuntimeError(f"SSH error ({code}): {err.strip()}")
    return out

def _ssh_exec_bin(remote_cmd: str, check: bool = True):
    """Run a shell command on the remote host and return its raw stdout bytes.

    This is the binary-safe counterpart of ``_ssh_exec``: stdout is not
    decoded. A fresh connection is opened per call and is always closed.

    Args:
        remote_cmd: command line, interpreted by the remote user's shell.
        check: when True, a non-zero exit status raises ``RuntimeError``.

    Returns:
        stdout as ``bytes``.

    Raises:
        RuntimeError: if ``check`` is True and the exit status is non-zero.
    """
    ssh = _open_ssh()
    try:
        _stdin, stdout, stderr = ssh.exec_command(remote_cmd)
        out = stdout.read()
        err = stderr.read().decode("utf-8", errors="ignore")
        code = stdout.channel.recv_exit_status()
    finally:
        ssh.close()
    if check and code != 0:
        raise RuntimeError(f"SSH error ({code}): {err.strip()}")
    return out

def _scp_to_temp(remote_path: str) -> str:
    """Download ``remote_path`` over SFTP into a new local temporary file.

    The temporary file keeps the remote file's extension as its suffix.
    The connection is opened with ``_open_ssh``, so it uses the same host,
    user, port and key handling as every other remote call. The connection
    is always closed.

    Returns:
        Path of the local copy. The caller owns it and must delete it.

    Raises:
        Any SSH/SFTP/OS error from the connection or the transfer. In that case
        the temporary file is removed before the error propagates.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=os.path.splitext(remote_path)[1])
    os.close(fd)
    try:
        ssh = _open_ssh()
        try:
            sftp = ssh.open_sftp()
            try:
                sftp.get(remote_path, tmp_path)
            finally:
                sftp.close()
        finally:
            ssh.close()
    except BaseException:
        # The caller only cleans up a path that was actually returned, so a
        # failed download must remove its own temporary file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return tmp_path

def remote_delete_first_line(list_path: str) -> None:
    """Remove the first line of a remote text file in place (``sed -i '1d'``).

    Raises ``RuntimeError`` (via ``_ssh_exec``) when the remote command exits
    with a non-zero status.

    NOTE(review): GNU ``sed -i`` writes a temporary copy and renames it over
    the original. Lines appended by another process while it runs may
    therefore be lost.
    """
    quoted_path = shlex.quote(list_path)
    _ssh_exec("sed -i '1d' " + quoted_path, check=True)

def remote_head_first_line(list_path: str) -> str:
    """Return the first line of a remote text file, stripped, with CRs removed.

    A non-zero exit status of ``head`` (e.g. missing file) is not treated as an
    error. The function then returns whatever was printed on stdout,
    typically ``""``. Connection failures still raise.
    """
    raw_output = _ssh_exec("head -n 1 " + shlex.quote(list_path), check=False)
    first_line = raw_output.strip()
    return first_line.replace("\r", "")


# ================== ANALISI FILE REMOTI ==================
def analyze_remote_file(remote_path: str, model) -> float:
    """Score one remote audio file and return its reconstruction MSE.

    The file is copied over SFTP into a local temporary file (``_scp_to_temp``),
    loaded with ``load_audio`` and passed to ``run_inference``. The local copy
    is deleted afterwards, whether or not loading/inference succeeded.
    """
    local_copy = None
    try:
        local_copy = _scp_to_temp(remote_path)
        audio, n_real = load_audio(local_copy)
        _features, _reconstruction, score = run_inference(model, audio, n_real)
        return score
    finally:
        if local_copy and os.path.exists(local_copy):
            os.remove(local_copy)

# ================== QUEUE ==================
def list_remote_audio_files(giorno: str, device: str) -> List[str]:
    """List every ``.wav``/``.flac`` file (case-insensitive) for a day and device.

    Searches ``{REMOTE_BASE}/{giorno}/device/{device}`` recursively on the
    remote host.

    Returns:
        Absolute remote paths, one per file. The list is empty when the
        directory does not exist or nothing matches.

    NOTE(review): the URLs built by ``_parse_queue_line_to_base_url`` contain
    ``.../{giorno}/FLAC/device/{device}``, while this searches without the
    ``FLAC`` segment. Confirm which layout the storage actually uses.
    """
    base_dir = shlex.quote(f"{REMOTE_BASE}/{giorno}/device/{device}")
    # The grouping parentheses for find must be escaped. An unescaped "(" is a
    # shell syntax error, and with check=False that made this function silently
    # return no files at all.
    find_cmd = (
        f"test -d {base_dir} && "
        f"find {base_dir} -type f \\( -iname '*.wav' -o -iname '*.flac' \\) || true"
    )
    out = _ssh_exec(find_cmd, check=False)
    return [ln.strip() for ln in (out or "").splitlines() if ln.strip()]

def _parse_queue_line_to_base_url(line: str, uuid: str = UUID) -> Optional[str]:
    """Turn one queue-file line into a recordings base URL, or ``None``.

    Accepted inputs:
      * a line that already is an ``http://``/``https://`` URL (returned as-is);
      * a comma-separated line ``<path>,<giorno>[,...]`` whose path contains a
        ``device/<name>`` segment. The URL is built from ``uuid``, ``giorno``
        and that device name.

    Anything else yields ``None``. That covers an empty line, fewer than two
    fields, an empty ``giorno``, or no ``device/<name>`` segment.
    """
    text = (line or "").strip()
    if not text:
        return None
    if text.startswith(("http://", "https://")):
        return text

    fields = [field.strip() for field in text.split(",")]
    if len(fields) < 2:
        return None

    path_field, day = fields[0], fields[1]
    segments = [seg for seg in path_field.split("/") if seg]
    device_name = None
    if "device" in segments:
        pos = segments.index("device")
        if pos + 1 < len(segments):
            device_name = segments[pos + 1]

    if not (day and device_name):
        return None
    return f"https://lys-ai.it/recordings/{uuid}/recordings/{day}/FLAC/device/{device_name}"

def _day_and_device_from_url(url: str) -> Optional[Tuple[str, str]]:
    """Extract ``(giorno, device)`` from a ``.../recordings/<giorno>/.../device/<device>`` URL.

    URLs built by ``_parse_queue_line_to_base_url`` look like
    ``https://host/recordings/<uuid>/recordings/<giorno>/FLAC/device/<device>``.
    "/recordings/" occurs twice there, so the *last* occurrence is the one
    followed by the day. Returns ``None`` when either part is missing.
    """
    if "/recordings/" not in url:
        return None
    tail = url.rsplit("/recordings/", 1)[1]
    giorno = tail.split("/")[0]
    if "/device/" not in tail:
        return None
    device_name = tail.split("/device/")[1].split("/")[0]
    if not giorno or not device_name:
        return None
    return giorno, device_name


def _queue_has_non_blank_line(list_path: str) -> bool:
    """Best-effort check: True if the remote queue file has a non-blank line.

    Any failure (missing file, SSH error, unexpected output) yields False, so
    the caller simply waits and polls again.
    """
    try:
        out = _ssh_exec(f"grep -c '[^[:space:]]' {shlex.quote(list_path)} || true", check=False)
        return int(out.strip() or "0") > 0
    except Exception:
        return False


def _process_queue_line(line: str, model) -> Tuple[str, List[Tuple[str, str, float]], str]:
    """Analyse the audio referenced by one queue line.

    Returns:
        ``(giorno, anomalies, outcome)``, where ``outcome`` is one of:
          * ``"done"``: every referenced file was analysed; consume the line;
          * ``"retry"``: at least one analysis failed; keep it for a retry;
          * ``"drop"``: the line can never be processed (unparseable); consume it.
    """
    parts = [p.strip() for p in line.split(",")]
    anomalies: List[Tuple[str, str, float]] = []

    # Case 1: explicit path to a single remote file, "<path>,<giorno>,..."
    if len(parts) >= 2 and parts[0] and (parts[0].startswith("/") or parts[0].startswith("mnt")):
        wav_path = parts[0] if parts[0].startswith("/") else "/" + parts[0]
        giorno = parts[1] or "unknown"
        print(f"[SSH] Analisi singolo file: {wav_path}")
        try:
            loss = analyze_remote_file(wav_path, model)
        except Exception as e:
            print(f"❌ Errore su {wav_path}: {e}")
            return giorno, anomalies, "retry"
        status = "Anomalo" if loss > threshold else "Normale"
        print(f" --> {status} (MSE: {loss:.6f})")
        if status == "Anomalo":
            anomalies.append((wav_path, status, loss))
        return giorno, anomalies, "done"

    # Case 2: URL (given directly or rebuilt from "<path>,<giorno>"); analyse
    # every audio file of that day/device.
    url = _parse_queue_line_to_base_url(line)
    if not url:
        print(f"[SSH] Riga non interpretabile: {line!r}")
        return "unknown", anomalies, "drop"

    day_device = _day_and_device_from_url(url)
    if day_device is None:
        print(f"[SSH] Impossibile estrarre giorno/device da URL: {url}")
        return "unknown", anomalies, "drop"
    giorno, device_name = day_device

    paths = list_remote_audio_files(giorno, device_name)
    if not paths:
        print(f"[SSH] Nessun file trovato per {giorno}/{device_name}")
        return giorno, anomalies, "done"

    print(f"[SSH] Trovati {len(paths)} file. Avvio analisi...")
    outcome = "done"
    for rp in paths:
        try:
            loss = analyze_remote_file(rp, model)
        except Exception as e:
            print(f"  ❌ Errore su {rp}: {e}")
            outcome = "retry"
            continue
        status = "Anomalo" if loss > threshold else "Normale"
        print(f"  {os.path.basename(rp)} -> {status} (MSE: {loss:.6f})")
        if status == "Anomalo":
            anomalies.append((rp, status, loss))
    return giorno, anomalies, outcome


def consume_queue_full_ssh(device: str,
                           model,
                           model_type: str,
                           csv_name: str,
                           poll_every_seconds: int = 10):
    """Poll the remote queue file for ``device`` forever and analyse each entry.

    The queue is ``{REMOTE_BASE}/BB-AI_{device}_LIST.txt``. Each iteration reads
    its first line and analyses the audio it references. The line is removed
    from the remote file once it is finished, or once it turns out it can
    never be processed. Anomalies (loss > ``threshold``) are appended to
    ``esiti_{csv_name}/{model_type}_esito_{giorno}.csv``.

    Accepted line formats:
      * ``<path starting with "/" or "mnt">,<giorno>[,...]``: a single file;
      * an http(s) URL, or ``<path>,<giorno>`` with a ``device/<name>`` path
        segment: every .wav/.flac of that day/device.

    Failed analyses keep the line in the queue; it is retried after
    ``poll_every_seconds``. CSV rows are written only when the line is
    consumed, so retries do not append duplicate rows. Transient SSH errors
    while reading the queue are logged and do not stop the loop.

    Raises:
        ValueError: if ``device`` is empty.
    """
    if not device:
        raise ValueError("DEVICE non può essere vuoto")
    list_path = f"{REMOTE_BASE}/BB-AI_{device}_LIST.txt"
    print(f"📖 [QUEUE][SSH] In ascolto su: {list_path}")

    out_dir = f"esiti_{csv_name}"
    os.makedirs(out_dir, exist_ok=True)

    # One CSV per day, appended to as new anomalies arrive.
    def write_anomalies(giorno: str, anomalies: List[Tuple[str, str, float]]):
        csv_filename = os.path.join(out_dir, f"{model_type}_esito_{giorno}.csv")
        file_exists = os.path.exists(csv_filename)

        with open(csv_filename, mode='a', newline='') as f:
            w = csv.writer(f)
            # New file: write the metadata block followed by the column header.
            if not file_exists:
                w.writerow(["modello", "tipo_file", "durata_segmento", "verificato1", "verificato2"])
                w.writerow(["NS V1.01S_V5", "WAV/FLAC", "1", "0", "1"])
                w.writerow(["path", "status", "loss"])
            for pth, st, ls in anomalies:
                w.writerow([pth, st, f"{ls:.6f}"])

        print(f"[SSH] Salvate {len(anomalies)} anomalie in '{csv_filename}'")

    def delete_head_line():
        try:
            remote_delete_first_line(list_path)
        except Exception as e:
            print(f"⚠️ Impossibile cancellare la riga su {list_path}: {e}")

    while True:
        try:
            line = remote_head_first_line(list_path)
        except Exception as e:
            # Network/SSH hiccup: keep the long-running consumer alive.
            print(f"⚠️ [QUEUE][SSH] Lettura della coda fallita: {e}")
            time.sleep(poll_every_seconds)
            continue

        if not line:
            if _queue_has_non_blank_line(list_path):
                # A blank first line would otherwise block the queue forever:
                # `head -n 1` keeps returning "" and nothing behind it is read.
                print("[QUEUE][SSH] Riga vuota in testa alla coda, la rimuovo.")
                delete_head_line()
            else:
                print("[QUEUE][SSH] Nessuna riga trovata, attendo...")
            time.sleep(poll_every_seconds)
            continue

        print(f"👉 Riga da processare: {line}")

        try:
            giorno_for_csv, anomalies, outcome = _process_queue_line(line, model)

            if outcome == "done":
                if anomalies:
                    write_anomalies(giorno_for_csv, anomalies)
                else:
                    print(f"[SSH] Nessuna anomalia trovata per {giorno_for_csv}, non scrivo nulla.")
            elif outcome == "retry":
                # Results are persisted only when the line is consumed; writing
                # them now would append the same anomalies again on each retry.
                print(f"[SSH] Analisi incompleta, la riga verrà ritentata: {line!r}")
            else:
                print(f"[SSH] Riga scartata: {line!r}")

            if outcome in ("done", "drop"):
                delete_head_line()

        except Exception as outer:
            print(f"❌ Errore generale sulla riga: {outer}")

        time.sleep(poll_every_seconds)



# ================== UTILITY ==================
def prepend_metadata_to_csv(csv_filename: str,
                            modello: str,
                            tipo_file: str,
                            durata_segmento: str,
                            verificato1: str,
                            verificato2: str) -> None:
    """Insert a two-row metadata block (header + values) at the top of a CSV file.

    The file is left untouched in two cases. One is when it has at most one
    line, i.e. only the "path,status,loss" header and no anomalies. The other
    is when it already starts with the metadata header, which makes the call
    idempotent. The rewrite goes through a temporary file in the same
    directory plus ``os.replace``, so a failure midway never leaves the
    original truncated.

    Errors are reported with ``print`` and not raised.

    Args:
        csv_filename: path of the CSV to modify.
        modello, tipo_file, durata_segmento, verificato1, verificato2: values of
            the metadata row, in header order.
    """
    header = ["modello", "tipo_file", "durata_segmento", "verificato1", "verificato2"]
    metadata = [modello, tipo_file, durata_segmento, verificato1, verificato2]

    tmp_path = None
    try:
        # newline='' keeps the existing line terminators unchanged (csv.writer
        # files use "\r\n"), consistent with the rows csv.writer adds below.
        with open(csv_filename, mode='r', newline='') as original:
            original_lines = original.readlines()

        # Only the "path,status,loss" header (or nothing): no anomalies, no metadata.
        if len(original_lines) <= 1:
            print(f"⚠️ Nessuna anomalia nel file '{csv_filename}', salto meta-intestazione.")
            return

        if next(csv.reader([original_lines[0]]), []) == header:
            print(f"⚠️ Meta-intestazione già presente in '{csv_filename}', salto.")
            return

        directory = os.path.dirname(os.path.abspath(csv_filename))
        fd, tmp_path = tempfile.mkstemp(suffix=".tmp", dir=directory)
        with os.fdopen(fd, mode='w', newline='') as modified:
            writer = csv.writer(modified)
            writer.writerow(header)
            writer.writerow(metadata)
            modified.writelines(original_lines)
        # Keep the original file's permission bits (mkstemp creates 0600).
        os.chmod(tmp_path, os.stat(csv_filename).st_mode & 0o7777)
        os.replace(tmp_path, csv_filename)
        tmp_path = None

        print(f"✅ Meta-intestazione inserita in '{csv_filename}'")
    except Exception as e:
        print(f"❌ Errore durante l'inserimento della meta-intestazione in {csv_filename}: {e}")
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)



In [None]:
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = "results_1s_v5/best_model_1s_v5.pth"

    # Instantiate the model and load the trained weights.
    model = NS_1()
    # The checkpoint is a state_dict (it is fed to load_state_dict), so restrict
    # unpickling to tensors/containers instead of arbitrary Python objects.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model_type = "1s"
    csv_name = "2308_v5"

    # Start the SSH queue consumer (runs forever).
    consume_queue_full_ssh(
        device="RSP1-MIC1",   # name of the recording device to monitor
        model=model,          # pass the model instance, not the class
        model_type=model_type,
        csv_name=csv_name,
        poll_every_seconds=5
    )
