In [None]:
# Fast install, might break in the future.
# The upper-bound pins below are presumably the versions compatible with
# moshi 0.2.11 — confirm before upgrading moshi.
!pip install 'safetensors<0.6'
!pip install 'sphn<0.2'
# --no-deps skips moshi's own requirements (per the note below they include
# torch/cuda), so this path assumes those are already present in the runtime.
!pip install --no-deps "moshi==0.2.11"
# Slow install (will download torch and cuda), but future proof.
# !pip install "moshi==0.2.11"

In [None]:
"""
Générateur audio TTS avec Kyutai/Moshi TTS
Adapté pour les conversations psychiatriques structurées
Installation (deux commandes : --no-deps ne doit s'appliquer qu'à moshi) :
    pip install 'safetensors<0.6' 'sphn<0.2'
    pip install --no-deps "moshi==0.2.11"
"""

import json
import numpy as np
import torch
from pathlib import Path
from scipy.io import wavfile
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel

# ============================================================================
# CONFIGURATION
# ============================================================================

# Input folder holding the conv_*.json conversations, and where WAVs are written.
GT_FOLDER: str = "gt"
OUTPUT_DIR: Path = Path("audio_output_kyutai")

# Voice names from the Kyutai voice repository
# (see https://huggingface.co/kyutai/dsm-tts-voices).
# Doctor = speaker_id 0; patient = any other speaker_id.
VOICE_MEDECIN: str = "expresso/ex03-ex01_neutral_001_channel1_334s.wav"
VOICE_PATIENT: str = "expresso/ex01-ex01_neutral_001_channel1_334s.wav"

# TTS parameters
DEVICE: torch.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
N_Q: int = 32          # forwarded as n_q to TTSModel.from_checkpoint_info
TEMP: float = 0.6      # sampling temperature forwarded as temp
CFG_COEF: float = 2.0  # forwarded as cfg_coef to make_condition_attributes
# Forwarded unchanged as padding_between to TTSModel.prepare_script.
# NOTE(review): often read as "seconds of silence between turns", but the unit
# is whatever moshi's prepare_script uses — confirm before tuning.
PADDING_BETWEEN: int = 1


class KyutaiTTSGenerator:
    """Kyutai/Moshi TTS wrapper that renders conversation JSON files to WAV.

    The model is loaded once in ``__init__`` and reused by every call to
    ``generate_from_json``.
    """

    def __init__(self, device=DEVICE, n_q=N_Q, temp=TEMP):
        """Load the Kyutai TTS model (done once per generator).

        Args:
            device: torch device the model is loaded on.
            n_q: forwarded as ``n_q`` to ``TTSModel.from_checkpoint_info``.
            temp: sampling temperature forwarded as ``temp``.
        """
        print("🔄 Chargement du modèle Kyutai TTS...")
        print(f"   Device: {device}")

        checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
        self.tts_model = TTSModel.from_checkpoint_info(
            checkpoint_info, n_q=n_q, temp=temp, device=device
        )

        print("✅ Modèle chargé avec succès\n")
        print(f"Voix disponibles: https://huggingface.co/{DEFAULT_DSM_TTS_VOICE_REPO}")

    @staticmethod
    def _build_turns(dialogue, voice_medecin, voice_patient):
        """Collapse dialogue items into an alternating script plus per-speaker voices.

        NOTE(review): moshi's multi-speaker ``prepare_script`` assigns speakers
        by turn position (turn 0 -> slot 0, turn 1 -> slot 1, turn 2 -> slot 0,
        ...). ``make_condition_attributes`` takes one voice per speaker slot,
        not one per turn. Verify both against the installed moshi version.

        To keep the script alternating:
          * blank turns are dropped, since they would still consume a position;
          * consecutive turns of the same role are merged into one.
        The returned voices list has at most two entries, ordered by first
        appearance.

        Args:
            dialogue: list of {"text", "speaker_id"} items; speaker_id 0 is
                the doctor, any other value the patient.
            voice_medecin: voice name for the doctor.
            voice_patient: voice name for the patient.

        Returns:
            (texts, voices): alternating turn texts, and the voice names for
            speaker slot 0 and (if present) slot 1.

        Raises:
            KeyError: if an item lacks "text" or "speaker_id".
            ValueError: if no item has non-blank text.
        """
        texts: list[str] = []
        roles: list[int] = []  # 0 = doctor, 1 = patient; one entry per merged turn
        for item in dialogue:
            text = item["text"].strip()
            if not text:
                continue
            role = 0 if item["speaker_id"] == 0 else 1
            if roles and roles[-1] == role:
                texts[-1] = f"{texts[-1]} {text}"
            else:
                texts.append(text)
                roles.append(role)

        if not texts:
            raise ValueError("Le dialogue ne contient aucun texte à synthétiser")

        role_voices = (voice_medecin, voice_patient)
        voices = [role_voices[role] for role in roles[:2]]
        return texts, voices

    def generate_from_json(
        self,
        json_file: Path,
        voice_medecin: str = VOICE_MEDECIN,
        voice_patient: str = VOICE_PATIENT,
        cfg_coef: float = CFG_COEF,
        padding_between: float = PADDING_BETWEEN
    ) -> Path:
        """Render one conversation JSON file to a WAV file in OUTPUT_DIR.

        Args:
            json_file: path (``Path`` or ``str``) of the conversation JSON. It
                must contain a "dialogue" list of {"text", "speaker_id"} items.
                "metadata", "participants" and "cas_clinique" are used only for
                display and the output file name, and may be missing.
            voice_medecin: voice for speaker_id == 0 (doctor).
            voice_patient: voice for every other speaker_id (patient).
            cfg_coef: forwarded to ``make_condition_attributes``.
            padding_between: forwarded to ``prepare_script``.

        Returns:
            Path of the written ``audio_<conversation_id>.wav``. The JSON file
            stem is used when metadata.conversation_id is absent.

        Raises:
            KeyError: if "dialogue" or a turn's "text"/"speaker_id" is missing.
            ValueError: if no turn has non-blank text.
            RuntimeError: if the model produced no decodable frame.
        """
        json_file = Path(json_file)
        with open(json_file, 'r', encoding='utf-8') as f:
            dialogue_data = json.load(f)

        # Display-only fields use .get so that missing metadata cannot abort
        # a generation (or fail it after the WAV has already been written).
        metadata = dialogue_data.get("metadata", {})
        participants = dialogue_data.get("participants", {})
        medecin = participants.get("medecin", {})
        patient = participants.get("patient", {})
        dialogue = dialogue_data["dialogue"]
        conv_id = metadata.get("conversation_id", json_file.stem)

        print("\n🎙️  Génération Kyutai TTS")
        print(f"    Fichier source : {json_file.name}")
        print(f"    Conversation : {conv_id}")
        print(f"    Médecin : {medecin.get('nom', 'N/A')} → {voice_medecin}")
        print(f"    Patient : {patient.get('prenom', 'N/A')} {patient.get('nom', 'N/A')} → {voice_patient}")
        print(f"    Segments : {len(dialogue)}\n")

        texts, voices = self._build_turns(dialogue, voice_medecin, voice_patient)

        print(f"📝 Préparation du script ({len(texts)} tours)...")
        entries = self.tts_model.prepare_script(texts, padding_between=padding_between)

        print("🎤 Préparation des voix...")
        # One voice per speaker slot (not per turn), in order of first appearance.
        voice_paths = [self.tts_model.get_voice_path(v) for v in voices]
        condition_attributes = self.tts_model.make_condition_attributes(
            voice_paths, cfg_coef=cfg_coef
        )

        print("🔊 Génération audio...")
        pcms = []

        def on_frame(frame):
            # Frames still containing -1 are skipped; presumably -1 marks codes
            # not yet available (TODO confirm in moshi).
            if (frame != -1).all():
                # Channel 0 is dropped before Mimi decoding (presumably the
                # text stream — confirm); only batch 0 / channel 0 is kept.
                pcm = self.tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(np.clip(pcm[0, 0], -1, 1))
            print(f"  Étape {len(pcms)}", end="\r")

        with self.tts_model.mimi.streaming(1):
            self.tts_model.generate(
                [entries],
                [condition_attributes],
                on_frame=on_frame
            )

        if not pcms:
            raise RuntimeError(f"Aucun audio généré pour {json_file.name}")

        print("\n🔧 Assemblage final...")
        audio = np.concatenate(pcms, axis=-1)

        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        output_file = OUTPUT_DIR / f"audio_{conv_id}.wav"

        sample_rate = int(self.tts_model.mimi.sample_rate)
        # Samples were clipped to [-1, 1] above, so scaling to int16 cannot overflow.
        audio_int16 = (audio * 32767).astype(np.int16)
        wavfile.write(output_file, sample_rate, audio_int16)

        diagnostic = dialogue_data.get("cas_clinique", {}).get("diagnostic_principal", "N/A")
        print(f"\n✅ Audio généré : {output_file}")
        print(f"⏱️  Durée : {len(audio) / sample_rate:.1f}s")
        print(f"🔊 Sample rate : {sample_rate} Hz")
        print(f"💊 Diagnostic : {diagnostic}")

        return output_file


def lister_conversations_disponibles(gt_dir: Path) -> list[Path]:
    """Return the ``conv_*.json`` files directly inside *gt_dir*, sorted by path.

    A directory that does not exist yields an empty list instead of an error.
    """
    found = list(gt_dir.glob("conv_*.json")) if gt_dir.exists() else []
    found.sort()
    return found


def afficher_conversations(conv_files: list[Path]):
    """Print a numbered summary (id, people, diagnosis, turn count) of each file.

    Listing is best-effort. A file that cannot be read or parsed is shown with
    its error instead of aborting the whole listing. Missing fields print as
    'N/A', and the diagnosis is cut to 80 characters ("..." only when cut).
    """
    print("\n📚 CONVERSATIONS DISPONIBLES\n")
    print("=" * 80)

    for idx, conv_file in enumerate(conv_files, 1):
        try:
            with open(conv_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            metadata = data.get('metadata', {})
            participants = data.get('participants', {})
            patient = participants.get('patient', {})

            conv_id = metadata.get('conversation_id', 'N/A')
            medecin = participants.get('medecin', {}).get('nom', 'N/A')
            patient_prenom = patient.get('prenom', 'N/A')
            patient_nom = patient.get('nom', 'N/A')
            # `or` also covers an explicit null diagnosis in the JSON.
            diagnostic = str(data.get('cas_clinique', {}).get('diagnostic_principal') or 'N/A')
            nb_segments = len(data.get('dialogue', []))
        except Exception as e:  # deliberate per-file boundary: keep listing the others
            print(f"\n[{idx}] {conv_file.name} - ⚠️ Erreur: {e}")
            continue

        # Only mark the diagnosis as truncated when it actually was.
        if len(diagnostic) > 80:
            diagnostic = diagnostic[:80] + "..."

        print(f"\n[{idx}] {conv_file.name}")
        print(f"    ID: {conv_id}")
        print(f"    Médecin: Dr. {medecin}")
        print(f"    Patient: {patient_prenom} {patient_nom}")
        print(f"    Diagnostic: {diagnostic}")
        print(f"    Segments: {nb_segments}")

    print("\n" + "=" * 80)


def choisir_conversation(conv_files: list[Path]):
    """Show the available conversations and ask the user to pick one.

    Args:
        conv_files: candidate conversation files, numbered from 1 on screen.

    Returns:
        The chosen Path, or None if the user quits with 'q', stdin is closed,
        or there is nothing to choose from.
    """
    # With an empty list the prompt would ask for "1-0" and never accept input.
    if not conv_files:
        print("\n❌ Aucune conversation disponible")
        return None

    afficher_conversations(conv_files)
    count = len(conv_files)

    while True:
        try:
            choix = input(f"\nChoisissez une conversation (1-{count}) ou 'q' pour quitter : ").strip().lower()
        except EOFError:  # stdin closed: behave like 'q' instead of crashing
            return None

        if choix == 'q':
            return None

        try:
            idx = int(choix)
        except ValueError:
            print("❌ Entrée invalide")
            continue

        if 1 <= idx <= count:
            return conv_files[idx - 1]
        print(f"❌ Choisissez un nombre entre 1 et {count}")


def main():
    """Interactive entry point.

    Lists the conversations in GT_FOLDER and loads the TTS model once. It then
    generates audio for each conversation the user picks until they quit.
    Failures while generating one conversation are reported, with a
    traceback, without ending the session.
    """
    import traceback

    print("=" * 80)
    print("🎬 GÉNÉRATEUR AUDIO KYUTAI TTS")
    print("   Pour conversations psychiatriques structurées")
    print("=" * 80)

    # Check CUDA availability (informational only).
    if torch.cuda.is_available():
        print(f"✅ CUDA disponible : {torch.cuda.get_device_name(0)}")
    else:
        print("⚠️  CUDA non disponible, utilisation du CPU (plus lent)")

    # Check the ground-truth folder.
    gt_dir = Path(GT_FOLDER)
    if not gt_dir.exists():
        print(f"\n❌ Le dossier '{GT_FOLDER}' n'existe pas!")
        return

    conv_files = lister_conversations_disponibles(gt_dir)
    if not conv_files:
        print(f"\n❌ Aucune conversation trouvée dans '{GT_FOLDER}'")
        return

    print(f"\n📊 {len(conv_files)} conversations trouvées")

    # Load the TTS model once for the whole session.
    try:
        generator = KyutaiTTSGenerator()
    except Exception as e:
        print(f"\n❌ Erreur lors du chargement du modèle: {e}")
        print("\n💡 Assurez-vous que les dépendances sont installées:")
        # Two commands: --no-deps must apply to moshi only.
        print("   pip install 'safetensors<0.6' 'sphn<0.2'")
        print("   pip install --no-deps 'moshi==0.2.11'")
        return

    # Generation loop.
    while True:
        conv_file = choisir_conversation(conv_files)

        if conv_file is None:
            print("\n👋 Au revoir!")
            break

        try:
            generator.generate_from_json(conv_file)
        except Exception as e:  # keep the session alive after one failed file
            print(f"\n❌ Erreur lors de la génération: {e}")
            traceback.print_exc()
        else:
            print("\n🎉 Succès!\n")

        try:
            continuer = input("\nGénérer une autre conversation? (o/n) : ").strip().lower()
        except EOFError:  # stdin closed: treat as "no"
            continuer = "n"
        if continuer != 'o':
            print("\n👋 Au revoir!")
            break


# Start the interactive CLI when run as a script. In a notebook cell __name__
# is also "__main__", so executing the cell starts main() as well.
if __name__ == "__main__":
    main()