In [22]:
import json
import os
from datetime import datetime, timedelta, time
import matplotlib.pyplot as plt
import yfinance as yf
import pandas as pd
import pytz
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from collections import defaultdict
from matplotlib.lines import Line2D
import glob
from bisect import bisect_left

# Display name -> ticker symbol handed to ``yf.Ticker``. Names that are not
# listed here are used as the ticker symbol verbatim by the callers.
companies = {
    "Apple": "AAPL",
    "Microsoft": "MSFT",
    "Amazon": "AMZN",
    "Alphabet": "GOOGL",
    "Meta": "META",
    "Tesla": "TSLA",
    "NVIDIA": "NVDA",
    "Samsung": "005930.KS",
    "Tencent": "TCEHY",
    "Alibaba": "BABA",
    "IBM": "IBM",
    "Intel": "INTC",
    "Oracle": "ORCL",
    "Sony": "SONY",
    "Adobe": "ADBE",
    "Netflix": "NFLX",
    "AMD": "AMD",
    "Qualcomm": "QCOM",
    "Cisco": "CSCO",
    "JP Morgan": "JPM",
    "Goldman Sachs": "GS",
    "Visa": "V",
    "Johnson & Johnson": "JNJ",
    "Pfizer": "PFE",
    "ExxonMobil": "XOM",
    "ASML": "ASML.AS",
    "SAP": "SAP.DE",
    "Siemens": "SIE.DE",
    "Louis Vuitton (LVMH)": "MC.PA",
    "TotalEnergies": "TTE.PA",
    "Shell": "SHEL.L",
    "Baidu": "BIDU",
    "JD.com": "JD",
    "BYD": "BYDDY",
    "ICBC": "1398.HK",
    "Toyota": "TM",
    "SoftBank": "9984.T",
    "Nintendo": "NTDOY",
    "Hyundai": "HYMTF",
    "Reliance Industries": "RELIANCE.NS",
    "Tata Consultancy Services": "TCS.NS",
}

def log(message):
    """Print *message* to stdout, prefixed with the current local time in brackets."""
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"[{stamp}] {message}")

def clean_text(text):
    """Strip surrounding whitespace, then turn every remaining CR/LF into a space."""
    trimmed = text.strip()
    # One translate pass instead of two chained replace() calls.
    return trimmed.translate({ord("\n"): " ", ord("\r"): " "})

def convert_utc_to_ny(timestamp_str):
    """Convert an ISO-8601 timestamp string to New York time, floored to the hour.

    A ``Z`` suffix is rewritten to ``+00:00`` before parsing so that
    ``datetime.fromisoformat`` accepts it. A string that carries no UTC offset
    at all is treated as UTC.

    Args:
        timestamp_str: ISO-8601 timestamp, e.g. ``"2025-03-01T14:37:00Z"``.

    Returns:
        A timezone-aware ``datetime`` in ``America/New_York`` with minutes,
        seconds and microseconds set to zero.

    Raises:
        ValueError: if the string is not a format ``fromisoformat`` accepts.
    """
    utc_dt = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
    if utc_dt.tzinfo is None:
        # astimezone() interprets a naive datetime in the machine's local zone,
        # which would silently shift the result on any non-UTC host.
        utc_dt = utc_dt.replace(tzinfo=pytz.utc)
    ny_dt = utc_dt.astimezone(pytz.timezone("America/New_York"))
    return ny_dt.replace(minute=0, second=0, microsecond=0)

def get_texts_timestamps(news_data):
    """Flatten the per-day news mapping into parallel text and timestamp lists.

    Args:
        news_data: mapping whose values are lists of article dicts. Each
            article must have a ``publishedAt`` string and may have ``title``
            and ``description``.

    Returns:
        ``(texts, timestamps)``: two lists of equal length. ``texts[i]`` is the
        cleaned ``"title description"`` string and ``timestamps[i]`` the
        result of ``convert_utc_to_ny`` for the same article.

    Raises:
        KeyError: if an article has no ``publishedAt`` key.
    """
    texts = []
    timestamps = []
    for day_articles in news_data.values():
        for article in day_articles:
            ts = convert_utc_to_ny(article['publishedAt'])
            # `.get(key, "")` still yields None when the key is present with a
            # null value; `or ""` covers both the missing and the null case.
            title = article.get("title") or ""
            description = article.get("description") or ""
            texts.append(clean_text(title + " " + description))
            timestamps.append(ts)
    return texts, timestamps

def get_sentiments(model_path, texts):
    """Classify every text with the sequence-classification model at *model_path*.

    The tokenizer always comes from the ``ProsusAI/finbert`` checkpoint,
    whatever *model_path* is. Texts are run one at a time with gradients
    disabled.

    Args:
        model_path: path or identifier passed to
            ``BertForSequenceClassification.from_pretrained``.
        texts: iterable of strings.

    Returns:
        A list with one integer per text: the index of the largest logit.
    """
    log(f"Chargement du modèle depuis {model_path}")
    tokenizer = BertTokenizer.from_pretrained("ProsusAI/finbert")
    classifier = BertForSequenceClassification.from_pretrained(model_path)
    classifier.eval()

    def predict(text):
        # Tokenise a single text and return the arg-max class index.
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = classifier(**encoded).logits
        return torch.argmax(logits, dim=1).item()

    return [predict(text) for text in texts]

def align_timestamps(timestamps):
    """Snap each timestamp onto an hourly slot of a New York trading day.

    Rules, applied to the timestamp converted to ``America/New_York``:

    * Trading day, 09:30 <= time < 15:00: floor to the hour of the same day.
    * Trading day, time >= 15:00: the 15:00 slot of the same day.
    * Anything else (before 09:30, weekend, holiday): the 15:00 slot of the
      most recent earlier trading day.

    A trading day is a weekday that is not in pandas'
    ``USFederalHolidayCalendar``.
    NOTE(review): federal holidays only approximate an exchange calendar;
    confirm whether exchange-specific closures need to be handled.

    Args:
        timestamps: iterable of timezone-aware ``datetime`` objects.

    Returns:
        A list of timezone-aware New York datetimes, one per input, in the
        same order. Empty input gives an empty list.
    """
    from pandas.tseries.holiday import USFederalHolidayCalendar

    ny_tz = pytz.timezone("America/New_York")
    local_stamps = [ts.astimezone(ny_tz) for ts in timestamps]
    if not local_stamps:
        return []

    # Cover every year present in the input, plus the year before so that
    # walking back across New Year still sees the relevant holidays.
    years = [ts.year for ts in local_stamps]
    calendar = USFederalHolidayCalendar()
    holidays = {
        h.date()
        for h in calendar.holidays(start=f"{min(years) - 1}-01-01", end=f"{max(years)}-12-31")
    }

    market_open = time(9, 30)
    last_slot = time(15, 0)
    one_day = timedelta(days=1)

    def is_trading_day(day):
        # Monday..Friday and not a listed holiday.
        return day.weekday() < 5 and day not in holidays

    aligned = []
    for ts in local_stamps:
        local_date = ts.date()

        if is_trading_day(local_date):
            local_time = ts.time()
            if market_open <= local_time < last_slot:
                aligned.append(ts.replace(minute=0, second=0, microsecond=0))
                continue
            if local_time >= last_slot:
                aligned.append(ts.replace(hour=15, minute=0, second=0, microsecond=0))
                continue

        # Pre-market, weekend or holiday: step back until a real trading day
        # is reached (a plain "minus one day" can land on a Sunday or holiday).
        previous_day = local_date - one_day
        while not is_trading_day(previous_day):
            previous_day -= one_day

        # Build the slot as a naive datetime, then localise it so that the
        # correct New York offset for that date is attached.
        aligned.append(ny_tz.localize(datetime.combine(previous_day, last_slot)))

    return aligned


def plot_comparison(df, sentiments_a, sentiments_b, timestamps, title_a, title_b):
    """Draw two stacked price charts, each overlaid with one model's sentiment dots.

    Timestamps are first aligned with ``align_timestamps``. For every aligned
    slot the first price bar at or after it is used, provided it lies within
    90 minutes; otherwise the slot is not drawn. Several labels on the same
    slot are stacked vertically in steps of 0.5 price units.

    Args:
        df: price frame with a ``Close`` column and either a ``Datetime``
            column or the time axis as its first column.
        sentiments_a: labels for the upper chart (keys of the colour map: 0, 1, 2).
        sentiments_b: labels for the lower chart.
        timestamps: one timezone-aware datetime per label.
        title_a: title of the upper chart.
        title_b: title of the lower chart.
    """
    slots = align_timestamps(timestamps)

    palette = {0: "red", 1: "orange", 2: "green"}
    step = 0.5
    tolerance = timedelta(minutes=90)

    def bucket(labels):
        # Collect the labels that share an aligned slot.
        by_slot = defaultdict(list)
        for slot, label in zip(slots, labels):
            by_slot[slot].append(label)
        return by_slot

    def draw_panel(frame, axis, by_slot, heading):
        # Price line plus one coloured dot per label on the matching bar.
        time_col = "Datetime" if "Datetime" in frame.columns else frame.columns[0]
        prices = frame.set_index(time_col)
        bars = prices.index.to_list()
        axis.plot(prices.index, prices["Close"], label="Price", color="black")

        for slot, labels in by_slot.items():
            pos = bisect_left(bars, slot)
            if pos == len(bars):
                continue  # slot lies after the last available bar
            bar = bars[pos]
            within = abs(bar - slot) <= tolerance
            if not within:
                continue
            price = prices.loc[bar]["Close"]
            for rank, label in enumerate(labels):
                axis.scatter(bar, price + rank * step, color=palette[label], s=60)

        axis.set_title(heading)
        axis.set_ylabel("Price")
        axis.grid(True)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 10), sharex=True)
    draw_panel(df, ax1, bucket(sentiments_a), title_a)
    draw_panel(df, ax2, bucket(sentiments_b), title_b)

    def dot(label, colour):
        # Legend proxy for a sentiment marker.
        return Line2D([0], [0], marker='o', color='w', label=label,
                      markerfacecolor=colour, markersize=10)

    legend_elements = [
        dot('Positive', 'green'),
        dot('Neutral', 'orange'),
        dot('Negative', 'red'),
        Line2D([0], [0], color='black', lw=2, label='Price'),
    ]
    ax2.legend(handles=legend_elements)
    plt.tight_layout()
    plt.show()

def run_analysis(company, json_path, model_path_a, model_path_b):
    """Download hourly prices for *company*, score its news with two models and plot both.

    Args:
        company: company name. It is looked up in ``companies`` as given and,
            failing that, with underscores replaced by spaces (names derived
            from ``*_news.json`` file names use underscores). If neither
            matches, the name itself is used as the ticker symbol.
        json_path: path to the JSON news file for the company.
        model_path_a: model path for the upper chart.
        model_path_b: model path for the lower chart.

    Returns:
        None. When no price data comes back the function logs that and returns
        before any model is loaded.
    """
    log(f"Téléchargement des prix pour {company}")
    ticker_symbol = companies.get(
        company, companies.get(company.replace("_", " "), company)
    )
    ticker = yf.Ticker(ticker_symbol)
    df = ticker.history(start="2025-01-01", end="2025-04-15", interval="60m")
    if df.empty:
        # Nothing to plot against; stop before the slow inference step instead
        # of failing later inside the plotting code.
        log(f"Aucune donnée de prix pour {ticker_symbol}, analyse ignorée")
        return
    df = df.reset_index() if 'Datetime' not in df.columns else df

    log(f"Chargement des news depuis {json_path}")
    with open(json_path, 'r', encoding='utf-8') as f:
        news_data = json.load(f)

    texts, timestamps = get_texts_timestamps(news_data)
    sentiments_a = get_sentiments(model_path_a, texts)
    sentiments_b = get_sentiments(model_path_b, texts)
    plot_comparison(df, sentiments_a, sentiments_b, timestamps, "Model A", "Model B")

def count_news_per_company(json_dir):
    """Count the articles in every ``*_news.json`` file directly under *json_dir*.

    Each file is expected to hold a JSON object whose values are sized
    collections; the article count is the sum of their lengths. The company
    name is the file name with ``_news.json`` removed.

    Returns:
        A dict ``{company: (article_count, file_path)}`` ordered from the
        largest count to the smallest. Empty when no file matches.
    """
    pattern = os.path.join(json_dir, "*_news.json")
    summary = {}
    for path in glob.glob(pattern):
        with open(path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        name = os.path.basename(path).replace("_news.json", "")
        summary[name] = (sum(map(len, payload.values())), path)

    ranked = sorted(summary.items(), key=lambda entry: entry[1][0], reverse=True)
    return dict(ranked)
def test_run_analysis():
    """End-to-end self-check on the company with the most news in ``JSONS``.

    Steps:
      1. Load the news file with the highest article count and score the texts
         with both models.
      2. Assert that texts, timestamps and both label lists have equal length
         and that every label is 0, 1 or 2.
      3. Align the timestamps, download hourly prices and, for model A only,
         draw one dot per label on an off-screen figure. A slot is matched to
         the bar at or after it, else to the bar before it, each within
         60 minutes.
      4. Report unmatched slots, or a success summary when every point was drawn.

    Each point is drawn and counted exactly once, so ``total_points`` can be
    compared directly with the number of labels.

    Raises:
        RuntimeError: if ``JSONS`` contains no ``*_news.json`` file.
        AssertionError: if a length or label check fails.
    """
    import matplotlib
    matplotlib.use('Agg')  # Headless backend: run without opening a window
    import matplotlib.pyplot as plt
    from collections import Counter

    news_counts = count_news_per_company("JSONS")
    if not news_counts:
        raise RuntimeError("Aucun fichier JSON valide trouvé dans le dossier 'JSONS'.")

    company, (_, json_path) = next(iter(news_counts.items()))
    # File-derived names use underscores where the `companies` keys use spaces.
    ticker_symbol = companies.get(
        company, companies.get(company.replace("_", " "), company)
    )
    log(f"⚙️ Test sur {company} ({ticker_symbol})")

    # Load the news data
    with open(json_path, 'r', encoding='utf-8') as f:
        news_data = json.load(f)
    texts, timestamps = get_texts_timestamps(news_data)

    sentiments_a = get_sentiments("./ProsusAI_finetuned", texts)
    sentiments_b = get_sentiments("./finbert_finetuned", texts)

    assert len(texts) == len(timestamps) == len(sentiments_a) == len(sentiments_b), \
        "❌ Mismatch entre les longueurs des listes."

    for s in sentiments_a + sentiments_b:
        assert s in {0, 1, 2}, f"❌ Label de sentiment invalide : {s}"

    aligned_ts = align_timestamps(timestamps)
    grouped = defaultdict(list)
    for t, s in zip(aligned_ts, sentiments_a):  # only model A is checked here
        grouped[t].append(s)

    # Load the prices
    df = yf.Ticker(ticker_symbol).history(start="2025-01-01", interval="60m")
    df = df.reset_index() if 'Datetime' not in df.columns else df
    df = df.set_index("Datetime" if "Datetime" in df.columns else df.columns[0])
    index_list = df.index.to_list()

    # Draw on a hidden figure
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(df.index, df["Close"], color="black")

    colors = {0: "red", 1: "orange", 2: "green"}
    tolerance = timedelta(minutes=60)
    total_points = 0
    expected = len(sentiments_a)
    missing_points = []

    for t, s_list in grouped.items():
        pos = bisect_left(index_list, t)
        nearest = None
        if pos < len(index_list) and abs(index_list[pos] - t) <= tolerance:
            nearest = index_list[pos]
        elif pos > 0 and abs(index_list[pos - 1] - t) <= tolerance:
            nearest = index_list[pos - 1]

        if nearest is not None:
            price = df.loc[nearest]["Close"]
            for i, s in enumerate(s_list):
                ax.scatter(nearest, price + i * 0.5, color=colors[s], s=60)
                total_points += 1
        else:
            # Remember every label whose slot has no price bar nearby
            for s in s_list:
                missing_points.append((t, s))

    # The figure is only a drawing target; release it once everything is drawn.
    plt.close(fig)

    # Report an error when some points could not be placed
    if len(missing_points) > 0:
        print(f"❌ Seulement {total_points}/{expected} points ont été affichés.")
        print("❌ Timestamps sans correspondance dans les données de prix (±60 minutes) :")
        for t, s in missing_points:
            print(f"   - {t} (sentiment: {s})")
    else:
        log(f"✅ Tous les {expected} points ont bien été affichés.")

        log("✔ Résumé des prédictions :")
        print("  ProsusAI:", Counter(sentiments_a))
        print("  Fine-tuné:", Counter(sentiments_b))
        log(f"✅ Tous les tests sont passés et {total_points}/{expected} points ont bien été affichés.")


if __name__ == "__main__":
    # To run the self-check instead of the analysis, call test_run_analysis() here.

    # Rank the companies by number of articles and print the ranking.
    news_counts = count_news_per_company("JSONS")
    print("Entreprises avec le plus de news :")
    for company, (count, _) in news_counts.items():
        print(f"{company}: {count} news")

    # Run the two-model comparison for the two companies with the most news.
    model_paths = {
        "model_path_a": "./ProsusAI_finetuned",
        "model_path_b": "./finbert_finetuned",
    }
    top_2_companies = list(news_counts.items())[:2]
    for company, (_, path) in top_2_companies:
        log(f"Analyse pour {company}")
        run_analysis(company=company, json_path=path, **model_paths)


Entreprises avec le plus de news :
Tesla: 90 news
Shell: 83 news
Amazon: 52 news
Apple: 49 news
NVIDIA: 47 news
Meta: 28 news
Microsoft: 25 news
Visa: 21 news
Goldman_Sachs: 17 news
Intel: 16 news
Samsung: 15 news
SAP: 15 news
Oracle: 14 news
BYD: 13 news
Alphabet: 10 news
SoftBank: 8 news
TotalEnergies: 8 news
Cisco: 7 news
Toyota: 7 news
Alibaba: 6 news
AMD: 6 news
Hyundai: 5 news
IBM: 5 news
Siemens: 5 news
ASML: 4 news
Adobe: 3 news
Nintendo: 3 news
Sony: 3 news
ExxonMobil: 2 news
Johnson_&_Johnson: 2 news
Reliance_Industries: 2 news
Tencent: 2 news
Baidu: 1 news
Pfizer: 1 news
Qualcomm: 1 news
Tata_Consultancy_Services: 1 news
ICBC: 0 news
JP_Morgan: 0 news
Louis_Vuitton_(LVMH): 0 news
[2025-05-17 12:52:50] Analyse pour Tesla
[2025-05-17 12:52:50] Téléchargement des prix pour Tesla
[2025-05-17 12:52:50] Chargement des news depuis JSONS\Tesla_news.json
[2025-05-17 12:52:50] Chargement du modèle depuis ./ProsusAI_finetuned
[2025-05-17 12:53:46] Chargement du modèle depuis ./finbert_

  plt.show()


[2025-05-17 12:54:40] Analyse pour Shell
[2025-05-17 12:54:40] Téléchargement des prix pour Shell
[2025-05-17 12:54:40] Chargement des news depuis JSONS\Shell_news.json
[2025-05-17 12:54:40] Chargement du modèle depuis ./ProsusAI_finetuned
[2025-05-17 12:55:38] Chargement du modèle depuis ./finbert_finetuned


  plt.show()
