In [2]:
!pip install TTS
import os
import glob
import json
import logging
import subprocess
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import librosa
import torch
import requests
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig, XttsAudioConfig
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.models.xtts import XttsArgs
from pydub import AudioSegment

# -----------------------------------------------------
# Embedding utilities adapted from xtts_match_emidunno.py
# -----------------------------------------------------

# Register the XTTS configuration classes as safe globals for torch
# serialization.
torch.serialization.add_safe_globals(
    [XttsConfig, XttsAudioConfig, BaseDatasetConfig, XttsArgs]
)

# Root logging is configured at INFO; this module logs under "cort_xtts".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("cort_xtts")

# Names of the FFmpeg executables. Nothing in this part of the file reads
# them directly.
FFPROBE_PATH = "ffprobe"
FFMPEG_PATH = "ffmpeg"


def load_xtts_model(force_cpu: bool = False) -> Optional[TTS]:
    """Instantiate the XTTS v2 model, on CUDA when permitted and available.

    Args:
        force_cpu: When true, CUDA availability is never probed and the model
            is requested on CPU.

    Returns:
        The ``TTS`` object, or ``None`` if anything raised while loading (the
        failure is logged rather than propagated).
    """
    try:
        # Short-circuit keeps torch.cuda untouched when the caller forces CPU.
        use_cuda = (not force_cpu) and torch.cuda.is_available()
        model = TTS(
            model_name="tts_models/multilingual/multi-dataset/xtts_v2",
            progress_bar=False,
            gpu=use_cuda,
        )
        logger.info("XTTS model loaded on %s", "cuda" if use_cuda else "cpu")
        return model
    except Exception as exc:
        logger.error("Failed to load XTTS model: %s", exc)
        return None


def convert_mp3_to_wav(mp3_path: str, temp_files: List[str]) -> Optional[str]:
    """Decode an MP3 into a temporary mono 22050 Hz WAV file.

    The WAV is created with a unique name in the system temp directory rather
    than next to the source file, so it cannot overwrite (and the caller's
    cleanup cannot later delete) a real ``<name>_temp.wav`` that already lives
    in the input directory, and read-only input directories still work.

    Args:
        mp3_path: Path of the MP3 file to convert.
        temp_files: List the new WAV path is appended to; the caller is
            responsible for deleting every path in it.

    Returns:
        Path of the WAV file, or ``None`` if decoding or export failed (the
        failure is logged).
    """
    import tempfile

    try:
        base = os.path.splitext(os.path.basename(mp3_path))[0]
        # Decode first so a broken MP3 never leaves an empty temp file behind.
        audio = AudioSegment.from_mp3(mp3_path)
        fd, wav_path = tempfile.mkstemp(prefix=f"{base}_", suffix="_temp.wav")
        os.close(fd)
        # Register before exporting: a failed export may leave a partial file
        # that still has to be cleaned up.
        temp_files.append(wav_path)
        audio.export(wav_path, format="wav", parameters=["-ac", "1", "-ar", "22050"])
        return wav_path
    except Exception as e:
        logger.error("MP3 conversion failed for %s: %s", mp3_path, e)
        return None


@torch.inference_mode()
def extract_speaker_embedding(audio_path: str, tts: TTS, temp_files: List[str]):
    """Compute a conditioning latent for one audio file.

    Args:
        audio_path: Path to a WAV or MP3 file. MP3 input is first converted to
            a temporary WAV via ``convert_mp3_to_wav``.
        tts: Loaded TTS object. The underlying model is taken from
            ``tts.tts_model`` when that attribute exists, otherwise from
            ``tts.synthesizer.tts_model``.
        temp_files: List that collects temporary file paths for later cleanup.

    Returns:
        The first value returned by ``get_conditioning_latents``, squeezed and
        converted to a NumPy array, or ``None`` on any failure (logged).
    """
    try:
        if audio_path.lower().endswith(".mp3"):
            wav = convert_mp3_to_wav(audio_path, temp_files)
            if not wav:
                return None
            audio_path = wav
        audio, sr = librosa.load(audio_path, sr=22050, mono=True)

        # Resolve the model object up front instead of wrapping the call in
        # ``except AttributeError``: that handler also caught AttributeErrors
        # raised *inside* the model call and silently retried on a different
        # attribute path, hiding the real failure.
        model = getattr(tts, "tts_model", None)
        if model is None:
            model = tts.synthesizer.tts_model

        # NOTE(review): the published Coqui XTTS examples call
        # get_conditioning_latents(audio_path=[...]) and unpack two values
        # (gpt_cond_latent, speaker_embedding). Confirm that the installed
        # version accepts a waveform plus sample rate and returns three values
        # as unpacked here; otherwise every call ends in the except branch.
        latent, _, _ = model.get_conditioning_latents(audio, sr)
        return latent.squeeze().cpu().numpy()
    except Exception as e:
        logger.error("Embedding extraction failed for %s: %s", audio_path, e)
        return None


def process_directory(directory: str, tts: TTS, lang_prefix: str, temp_files: List[str]):
    """Extract an embedding for every WAV/MP3 file under ``directory``.

    Args:
        directory: Root directory; it is searched recursively.
        tts: Loaded TTS object forwarded to ``extract_speaker_embedding``.
        lang_prefix: Prefix placed in front of each result key.
        temp_files: List that collects temporary file paths for later cleanup.

    Returns:
        Dict mapping ``"<lang_prefix>/<path relative to directory>"`` to the
        embedding array. Files whose extraction failed are omitted.

    Raises:
        FileNotFoundError: If ``directory`` is not an existing directory.
    """
    embeddings: Dict[str, np.ndarray] = {}
    # Validate the argument itself; the previous check evaluated an unrelated
    # bare expression and raised NameError on every call.
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Directory not found: {directory}")
    audio_files: List[str] = []
    for pattern in ("*.wav", "*.mp3"):
        audio_files += glob.glob(os.path.join(directory, "**", pattern), recursive=True)
    # glob order is filesystem-dependent; sort so result order is reproducible.
    for path in sorted(audio_files):
        emb = extract_speaker_embedding(path, tts, temp_files)
        if emb is not None:
            rel = f"{lang_prefix}/{os.path.relpath(path, directory)}"
            embeddings[rel] = emb
    return embeddings


def save_embeddings(embeddings: Dict[str, np.ndarray], filename: str = "voice_embeddings.npz"):
    """Write all embeddings to one ``.npz`` archive as float32 arrays.

    Args:
        embeddings: Mapping of array name to embedding array.
        filename: Destination passed to ``np.savez``.
    """
    as_float32 = {}
    for name, vector in embeddings.items():
        as_float32[name] = vector.astype(np.float32)
    np.savez(filename, **as_float32)
    logger.info("Saved embeddings to %s", filename)


def calculate_similarity(embeddings: Dict[str, np.ndarray]) -> pd.DataFrame:
    """Build the pairwise cosine-similarity matrix of the embeddings.

    Each embedding is flattened and scaled to unit length before the dot
    product. An all-zero embedding is left as zeros, so its similarity to
    everything (itself included) is 0.

    Args:
        embeddings: Mapping of name to embedding array. All flattened
            embeddings are expected to have the same length.

    Returns:
        Square DataFrame indexed and labelled by the embedding names; an empty
        DataFrame when ``embeddings`` is empty.
    """
    # With no rows np.array([]) is 1-D and norm(axis=1) raises AxisError, so an
    # empty directory would otherwise be reported as a processing failure.
    if not embeddings:
        return pd.DataFrame()
    files = list(embeddings.keys())
    arr = np.array([e.flatten() for e in embeddings.values()])
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero for all-zero vectors
    normed = arr / norms
    sim = np.dot(normed, normed.T)
    return pd.DataFrame(sim, index=files, columns=files)

# -----------------------------------------------------
# DeepSeek assistant from untitled18.py
# -----------------------------------------------------

# Request body for POST /api/initialize.
class AssistantConfig(BaseModel):
    # Forwarded to DeepSeekAssistant(api_key=...) when the session is created.
    api_key: str

# Request body shared by POST /api/analyze and POST /api/test.
class CodeRequest(BaseModel):
    # Key into the module-level code_assistants registry.
    session_id: str
    # Source code to analyze or execute.
    code: str
    # Optional error text and goal; /api/analyze forwards both to
    # DeepSeekAssistant.analyze_problem, /api/test does not read them.
    error: Optional[str] = None
    objective: Optional[str] = None

# Request body shared by POST /api/save_session and POST /api/load_session.
class SessionRequest(BaseModel):
    # Key into the module-level code_assistants registry.
    session_id: str
    # Optional file to save to / load from; when omitted the endpoints fall
    # back to the name stored in session_files for this session.
    filename: Optional[str] = None

# Request body for POST /api/process_choice.
class UserChoice(BaseModel):
    # Key into the module-level code_assistants registry.
    session_id: str
    # Action selector. The endpoint handles "a" (accept the latest suggestion
    # and test it), "m" (store modified_code) and "t" (test the current code);
    # anything else is rejected with HTTP 400.
    choice: str
    # Replacement code; only read when choice == "m".
    modified_code: Optional[str] = None

# Request body for POST /api/extract_embeddings.
class EmbeddingRequest(BaseModel):
    # Key into the module-level code_assistants registry.
    session_id: str
    # Directory handed to process_directory for the audio file search.
    directory: str
    # Prefix placed in front of every embedding key.
    lang_prefix: str = "en"
    # File name handed to save_embeddings.
    output_file: str = "voice_embeddings.npz"

class DeepSeekAssistant:
    """Client for the DeepSeek chat-completions API plus per-session state.

    An instance keeps a bounded history of analysis attempts, the code that is
    currently being worked on, and helpers to run code, diff it and persist
    the session to JSON.
    """

    # Environment variable consulted when no key is passed explicitly. A key
    # literal must never appear in source: any key that was ever committed
    # here has to be treated as compromised and rotated.
    API_KEY_ENV_VAR = "DEEPSEEK_API_KEY"

    # (connect, read) timeouts in seconds for API requests. Without a timeout
    # a stalled connection would block the caller forever.
    REQUEST_TIMEOUT = (10, 300)

    def __init__(self, api_key: Optional[str] = None):
        """Create an assistant.

        Args:
            api_key: DeepSeek API key. When empty or ``None`` the key is read
                from the ``DEEPSEEK_API_KEY`` environment variable.
        """
        self.api_key = api_key or os.getenv(self.API_KEY_ENV_VAR)
        if not self.api_key:
            logger.warning(
                "No DeepSeek API key supplied and %s is not set; API calls will fail",
                self.API_KEY_ENV_VAR,
            )
        self.base_url = "https://api.deepseek.com/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self.code_history: List[Dict] = []
        self.session_history: List[Dict] = []
        self.max_history = 5  # number of attempts kept in code_history
        self.current_code = ""

    def _call_api(self, messages: List[Dict], model: str = "deepseek-chat", temperature: float = 0.7, max_tokens: int = 10000) -> str:
        """POST a chat-completion request and return the reply text.

        Returns:
            The stripped content of the first choice. On any failure (network,
            HTTP status, unexpected payload) the error is logged and a string
            starting with ``"API Error: "`` is returned instead of raising.
        """
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        try:
            resp = requests.post(
                self.base_url,
                headers=self.headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT,
            )
            resp.raise_for_status()
            data = resp.json()
            return data["choices"][0]["message"]["content"].strip()
        except Exception as e:
            logger.error("API error: %s", e)
            return f"API Error: {e}"

    def analyze_problem(self, code: str, error: Optional[str] = None, objective: Optional[str] = None) -> Dict:
        """Ask the API to analyze ``code`` and then to propose a solution.

        Two requests are made: the prompt alone to ``deepseek-reasoner``, then
        the prompt plus that analysis to ``deepseek-chat``. The attempt is
        appended to ``code_history``, which is trimmed to ``max_history``.

        Args:
            code: Python source to analyze.
            error: Optional error text to include in the prompt.
            objective: Optional goal; when absent a generic request for
                improvements is appended instead.

        Returns:
            Dict with ``analysis``, ``solution`` and ``attempt_id`` (index of
            the attempt in the trimmed ``code_history``).
        """
        prompt = f"I'm working on this Python code:\n```python\n{code}\n```\n"
        if error:
            prompt += f"I encountered this error:\n\n{error}\n"
        if objective:
            prompt += f"My objective is:\n{objective}\n"
        else:
            prompt += "Please analyze this code and suggest improvements or fixes."

        messages = [{"role": "user", "content": prompt}]
        analysis = self._call_api(messages, model="deepseek-reasoner", temperature=0.3, max_tokens=6000)
        solution = self._call_api(messages + [{"role": "assistant", "content": analysis}], model="deepseek-chat", temperature=0.2, max_tokens=4000)
        attempt = {
            "timestamp": datetime.now().isoformat(),
            "code": code,
            "error": error,
            "analysis": analysis,
            "solution": solution,
        }
        self.code_history.append(attempt)
        if len(self.code_history) > self.max_history:
            self.code_history.pop(0)
        return {"analysis": analysis, "solution": solution, "attempt_id": len(self.code_history) - 1}

    @staticmethod
    def get_code_from_response(response: str) -> str:
        """Return the contents of the first fenced code block in ``response``.

        A ```` ```python ```` fence is preferred over a bare ```` ``` ````
        fence; when there is no fence the response is returned unchanged.
        """
        for fence in ("```python", "```"):
            if fence in response:
                return response.split(fence)[1].split("```")[0]
        return response

    def test_code(self, code: str) -> tuple[bool, str]:
        """Run ``code`` in a child Python process with a 10 second limit.

        The script goes into a uniquely named file in the current working
        directory (so imports of sibling modules keep working), is run with
        the interpreter that runs this process, and is deleted afterwards.
        A unique name keeps concurrent sessions from overwriting each other's
        script.

        NOTE(review): this executes arbitrary caller-supplied code with the
        privileges of the server process; expose it only to trusted clients.

        Returns:
            ``(True, stdout)`` on exit status 0, otherwise ``(False, message)``
            where message is stderr, a timeout notice or the error text.
        """
        import sys
        import tempfile

        script_path = None
        try:
            fd, script_path = tempfile.mkstemp(prefix="_temp_code_", suffix=".py", dir=os.getcwd())
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(code)
            result = subprocess.run([sys.executable, script_path], capture_output=True, text=True, timeout=10)
            if result.returncode == 0:
                return True, result.stdout
            return False, result.stderr
        except subprocess.TimeoutExpired:
            return False, "Timeout: Code took too long to execute"
        except Exception as e:
            return False, str(e)
        finally:
            if script_path is not None:
                try:
                    os.remove(script_path)
                except OSError:
                    # Best effort: a leftover temp script is harmless.
                    pass

    @staticmethod
    def show_diff(old_code: str, new_code: str) -> str:
        """Return a unified diff from ``old_code`` to ``new_code``.

        The two sides are labelled ``original`` and ``suggested``.
        """
        import difflib
        diff = difflib.unified_diff(
            old_code.splitlines(keepends=True),
            new_code.splitlines(keepends=True),
            fromfile="original",
            tofile="suggested",
        )
        return "".join(diff)

    def save_session(self, filename: str) -> str:
        """Write both histories and the current code to ``filename`` as JSON.

        Returns:
            ``filename``, unchanged.
        """
        data = {
            "timestamp": datetime.now().isoformat(),
            "code_history": self.code_history,
            "session_history": self.session_history,
            "current_code": self.current_code,
        }
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        return filename

    def load_session(self, filename: str) -> None:
        """Replace the in-memory state with the JSON stored in ``filename``.

        Keys missing from the file fall back to empty values.
        """
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.code_history = data.get("code_history", [])
        self.session_history = data.get("session_history", [])
        self.current_code = data.get("current_code", "")

# -----------------------------------------------------
# FastAPI setup and endpoints
# -----------------------------------------------------

app = FastAPI(
    title="CoRT XTTS Embedder",
    description="Embedding extraction with DeepSeek assistance",
)

# NOTE(review): every origin, method and header is allowed *and* credentials
# are enabled. The FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True; confirm whether any client
# actually needs credentialed cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# In-memory registries keyed by session id; both are filled by
# initialize_assistant.
session_files: Dict[str, str] = {}
code_assistants: Dict[str, DeepSeekAssistant] = {}

@app.post("/api/initialize")
async def initialize_assistant(config: AssistantConfig):
    """Create a new assistant session and register it.

    The session id combines a second-resolution timestamp with 4 random bytes
    in hex. The assistant and its default session file name are stored in the
    module-level ``code_assistants`` and ``session_files`` registries.

    Returns:
        Dict with the new ``session_id`` and ``status`` ``"initialized"``.
    """
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    session_id = f"session_{stamp}_{os.urandom(4).hex()}"
    code_assistants[session_id] = DeepSeekAssistant(api_key=config.api_key)
    # session_id already starts with "session_"; prefixing it again produced
    # file names of the form "session_session_<...>.json".
    session_files[session_id] = f"{session_id}.json"
    return {"session_id": session_id, "status": "initialized"}

@app.post("/api/analyze")
async def analyze_code(request: CodeRequest):
    """Analyze the submitted code and return the suggested replacement.

    Stores the submitted code as the session's current code, asks the
    assistant for an analysis and a solution, and extracts the code block
    from the solution.

    Raises:
        HTTPException: 404 when the session id is unknown.
    """
    import asyncio

    if request.session_id not in code_assistants:
        raise HTTPException(status_code=404, detail="Session not found")
    assistant = code_assistants[request.session_id]
    assistant.current_code = request.code
    # analyze_problem performs two synchronous HTTP requests; run it in a
    # worker thread so the event loop keeps serving other clients meanwhile.
    result = await asyncio.to_thread(
        assistant.analyze_problem, request.code, request.error, request.objective
    )
    suggested = assistant.get_code_from_response(result["solution"])
    return {
        "session_id": request.session_id,
        "analysis": result["analysis"],
        "suggested_code": suggested,
        "attempt_id": result["attempt_id"],
    }

@app.post("/api/test")
async def test_current_code(request: CodeRequest):
    """Execute the submitted code and report success and captured output.

    Raises:
        HTTPException: 404 when the session id is unknown.
    """
    import asyncio

    if request.session_id not in code_assistants:
        raise HTTPException(status_code=404, detail="Session not found")
    assistant = code_assistants[request.session_id]
    # test_code blocks on a child process for up to 10 seconds; keep that off
    # the event loop.
    success, output = await asyncio.to_thread(assistant.test_code, request.code)
    return {"session_id": request.session_id, "success": success, "output": output}

@app.post("/api/process_choice")
async def process_choice(request: UserChoice):
    """Apply the user's decision about the latest suggestion.

    Choices:
        ``"a"``: adopt the code from the most recent suggestion and test it.
        ``"m"``: replace the current code with ``modified_code``.
        ``"t"``: test the current code.

    Raises:
        HTTPException: 404 for an unknown session; 400 when ``"a"`` is chosen
            with no suggestions recorded, when ``"m"`` comes without code, or
            for any other choice.
    """
    import asyncio

    if request.session_id not in code_assistants:
        raise HTTPException(status_code=404, detail="Session not found")
    assistant = code_assistants[request.session_id]
    if request.choice == "a":
        if not assistant.code_history:
            raise HTTPException(status_code=400, detail="No suggestions available")
        latest = assistant.code_history[-1]
        assistant.current_code = assistant.get_code_from_response(latest["solution"])
        # test_code blocks on a child process; run it off the event loop.
        success, output = await asyncio.to_thread(assistant.test_code, assistant.current_code)
        return {"session_id": request.session_id, "action": "accepted", "success": success, "output": output, "code": assistant.current_code}
    elif request.choice == "m" and request.modified_code:
        assistant.current_code = request.modified_code
        return {"session_id": request.session_id, "action": "modified", "code": assistant.current_code}
    elif request.choice == "t":
        success, output = await asyncio.to_thread(assistant.test_code, assistant.current_code)
        return {"session_id": request.session_id, "action": "tested", "success": success, "output": output}
    else:
        raise HTTPException(status_code=400, detail="Invalid choice or missing code")

@app.post("/api/save_session")
async def save_session_endpoint(request: SessionRequest):
    """Persist the session to ``request.filename`` or its default file.

    Raises:
        HTTPException: 404 for an unknown session; 400 when the file cannot
            be written.
    """
    if request.session_id not in code_assistants:
        raise HTTPException(status_code=404, detail="Session not found")
    assistant = code_assistants[request.session_id]
    # NOTE(review): filename comes straight from the request body and is
    # opened for writing as-is, so a client can overwrite any file the server
    # process may write. Consider restricting it to a dedicated directory.
    filename = request.filename or session_files[request.session_id]
    try:
        saved = assistant.save_session(filename)
    except OSError as exc:
        # Report an unwritable path as a client error instead of a bare 500.
        raise HTTPException(status_code=400, detail=f"Could not save session: {exc}") from exc
    return {"status": "saved", "filename": saved}

@app.post("/api/load_session")
async def load_session_endpoint(request: SessionRequest):
    """Restore the session from ``request.filename`` or its default file.

    Raises:
        HTTPException: 404 for an unknown session or a missing file; 400 when
            the file cannot be read or does not contain valid JSON.
    """
    if request.session_id not in code_assistants:
        raise HTTPException(status_code=404, detail="Session not found")
    assistant = code_assistants[request.session_id]
    # NOTE(review): filename comes straight from the request body, so a client
    # can make the server open any readable path. Consider restricting it to a
    # dedicated directory.
    filename = request.filename or session_files[request.session_id]
    try:
        assistant.load_session(filename)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=f"Session file not found: {filename}") from exc
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and undecodable bytes.
        raise HTTPException(status_code=400, detail=f"Could not load session: {exc}") from exc
    return {"status": "loaded", "filename": filename}

@app.get("/api/session_history/{session_id}")
async def session_history(session_id: str):
    """Return a condensed view of the session's analysis attempts.

    Each entry carries the attempt timestamp plus the error and analysis
    texts, each cut to 200 characters followed by ``"..."`` when longer.

    Raises:
        HTTPException: 404 when the session id is unknown.
    """
    def _clip(text):
        # Keep at most 200 characters, marking the cut with an ellipsis.
        return text[:200] + "..." if len(text) > 200 else text

    if session_id not in code_assistants:
        raise HTTPException(status_code=404, detail="Session not found")

    summary = []
    for entry in code_assistants[session_id].code_history:
        err = entry.get("error")
        summary.append({
            "timestamp": entry["timestamp"],
            # A missing or empty error is passed through untouched.
            "error": _clip(err) if err else err,
            "analysis": _clip(entry["analysis"]),
        })
    return {"session_id": session_id, "code_history": summary}

def _read_own_source() -> str:
    """Return this module's source text, or "" when it cannot be read.

    ``__file__`` is not defined when the code runs inside a notebook kernel,
    so it is looked up defensively instead of being referenced directly.
    """
    path = globals().get("__file__")
    if not path:
        return ""
    try:
        with open(path, "r", encoding="utf-8") as fh:
            return fh.read()
    except (OSError, UnicodeDecodeError):
        return ""


@app.post("/api/extract_embeddings")
async def extract_embeddings(request: EmbeddingRequest):
    """Extract embeddings for a directory, save them and return similarities.

    On success the response holds the number of embeddings and the similarity
    matrix as a nested dict. On failure the error is handed to the session's
    assistant and the response holds the error text, the assistant's analysis
    and the suggested code. Temporary files are removed in both cases.

    Raises:
        HTTPException: 404 when the session id is unknown.
    """
    import asyncio

    if request.session_id not in code_assistants:
        raise HTTPException(status_code=404, detail="Session not found")
    assistant = code_assistants[request.session_id]
    temp_files: List[str] = []

    def _run():
        tts = load_xtts_model(force_cpu=False)
        if tts is None:
            raise RuntimeError("XTTS model failed to load")
        embeddings = process_directory(request.directory, tts, request.lang_prefix, temp_files)
        save_embeddings(embeddings, request.output_file)
        return embeddings, calculate_similarity(embeddings).to_dict()

    try:
        # Model loading and inference are long, blocking operations; run them
        # in a worker thread so the event loop stays responsive.
        embeddings, similarity = await asyncio.to_thread(_run)
        return {"status": "completed", "embeddings": len(embeddings), "similarity": similarity}
    except Exception as e:
        # NOTE(review): this sends the module's own source to the external
        # DeepSeek API; make sure nothing sensitive lives in this file.
        analysis = await asyncio.to_thread(
            assistant.analyze_problem, _read_own_source(), str(e), "Fix embedding extraction"
        )
        suggestion = assistant.get_code_from_response(analysis["solution"])
        return {"status": "error", "error": str(e), "analysis": analysis["analysis"], "suggestion": suggestion}
    finally:
        for f in temp_files:
            if os.path.exists(f):
                os.remove(f)

@app.websocket("/ws/{session_id}")
async def interactive_session(websocket: WebSocket, session_id: str):
    """Serve one interactive assistant session over a WebSocket.

    Incoming JSON messages are dispatched on their ``type``:

    * ``"code"``: analyze ``code`` (optional ``error``/``objective``) and
      reply with the analysis, the suggested code and a diff.
    * ``"choice"``: ``"a"`` adopts and tests the latest suggestion, ``"m"``
      stores the supplied ``code``, ``"t"`` tests ``code`` or the current
      code, ``"embed"`` extracts embeddings for ``directory``.
    * ``"save"``: write the session to ``filename`` or its default file.

    Messages with any other ``type`` or ``choice`` are ignored. Malformed
    messages get an error reply and the session keeps running.
    """
    import asyncio

    def _own_source() -> str:
        # __file__ is undefined inside a notebook kernel; fall back to "".
        path = globals().get("__file__")
        if not path:
            return ""
        try:
            with open(path, "r", encoding="utf-8") as fh:
                return fh.read()
        except (OSError, UnicodeDecodeError):
            return ""

    def _embed(directory, prefix, temp_files) -> int:
        # Blocking model load + inference; always called via a worker thread.
        tts = load_xtts_model(force_cpu=False)
        if tts is None:
            raise RuntimeError("XTTS model failed to load")
        embeddings = process_directory(directory, tts, prefix, temp_files)
        save_embeddings(embeddings)
        return len(embeddings)

    await websocket.accept()
    if session_id not in code_assistants:
        await websocket.send_json({"error": "Session not found"})
        await websocket.close()
        return
    assistant = code_assistants[session_id]
    try:
        await websocket.send_json({"type": "status", "message": "Interactive session started."})
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                await websocket.send_json({"type": "error", "message": "Messages must be JSON objects"})
                continue
            # .get() so a message without "type"/"choice" no longer ends the
            # whole session with a KeyError.
            msg_type = data.get("type")
            if msg_type == "code":
                if "code" not in data:
                    await websocket.send_json({"type": "error", "message": "Code not provided"})
                    continue
                result = await asyncio.to_thread(
                    assistant.analyze_problem, data["code"], data.get("error"), data.get("objective")
                )
                suggested = assistant.get_code_from_response(result["solution"])
                await websocket.send_json({
                    "type": "analysis",
                    "analysis": result["analysis"],
                    "suggested_code": suggested,
                    "diff": assistant.show_diff(data["code"], suggested),
                })
            elif msg_type == "choice":
                choice = data.get("choice")
                if choice == "a":
                    if not assistant.code_history:
                        await websocket.send_json({"type": "error", "message": "No suggestions available"})
                        continue
                    latest = assistant.code_history[-1]
                    assistant.current_code = assistant.get_code_from_response(latest["solution"])
                    success, output = await asyncio.to_thread(assistant.test_code, assistant.current_code)
                    await websocket.send_json({"type": "test_result", "success": success, "output": output, "code": assistant.current_code})
                elif choice == "m":
                    if "code" not in data:
                        await websocket.send_json({"type": "error", "message": "Modified code not provided"})
                        continue
                    assistant.current_code = data["code"]
                    await websocket.send_json({"type": "status", "message": "Code modified", "code": assistant.current_code})
                elif choice == "t":
                    code_to_test = data.get("code", assistant.current_code)
                    success, output = await asyncio.to_thread(assistant.test_code, code_to_test)
                    await websocket.send_json({"type": "test_result", "success": success, "output": output})
                elif choice == "embed":
                    directory = data.get("directory")
                    if not directory:
                        await websocket.send_json({"type": "error", "message": "Directory not provided"})
                        continue
                    prefix = data.get("lang_prefix", "en")
                    temp_files: List[str] = []
                    try:
                        count = await asyncio.to_thread(_embed, directory, prefix, temp_files)
                    except Exception as e:
                        # NOTE(review): this sends the module's own source to
                        # the external DeepSeek API.
                        analysis = await asyncio.to_thread(
                            assistant.analyze_problem, _own_source(), str(e), "Fix embedding extraction"
                        )
                        await websocket.send_json({"type": "error", "message": str(e), "analysis": analysis["analysis"]})
                    else:
                        await websocket.send_json({"type": "embed_result", "count": count})
                    finally:
                        for f in temp_files:
                            if os.path.exists(f):
                                os.remove(f)
            elif msg_type == "save":
                filename = data.get("filename", session_files[session_id])
                try:
                    saved_file = assistant.save_session(filename)
                except OSError as e:
                    await websocket.send_json({"type": "error", "message": f"Could not save session: {e}"})
                    continue
                await websocket.send_json({"type": "status", "message": f"Session saved to {saved_file}"})
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected: %s", session_id)
    except Exception as e:
        logger.error("WebSocket error: %s", e)
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception as send_error:
            # The connection may already be closed; reporting is best effort.
            logger.debug("Could not report error to client: %s", send_error)

@app.get("/")
async def root():
    """Return a fixed status message for the root URL."""
    status = {"message": "CoRT XTTS Embedder is running"}
    return status

if __name__ == "__main__":
    import uvicorn

    # NOTE(review): the import string expects this code to be importable as
    # the module "cort_xtts_embedder". When the cell is executed inside a
    # notebook no such module exists for the reloader to import; the recorded
    # run shows the reloader process starting and then stopping. Confirm how
    # this is meant to be launched.
    uvicorn.run(
        "cort_xtts_embedder:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )



INFO:     Will watch for changes in these directories: ['/content']
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     Started reloader process [1653] using StatReload
INFO:     Stopping reloader process [1653]


In [2]:
# Mount Google Drive at /content/drive (presumably requires a Colab runtime,
# since the helper comes from the google.colab package).
# NOTE(review): the recorded run of this cell ended in "MessageError: Error:
# credential propagation was unsuccessful", and this cell carries the same
# execution count as the previous one. Re-run from a fresh kernel and complete
# the authorization prompt before relying on Drive paths.
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful