In [1]:
!pip install -U pip wheel setuptools
!pip install -U tensorflow tensorflow-metal

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [4]:
import os, multiprocessing as mp
import tensorflow as tf

# Réduit le spam de logs
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")

def setup_apple_silicon():
    """Configure TF pour Apple Silicon (M1/M2/M3)."""
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        # --- Chemin GPU (Metal) ---
        # Mixed precision = GROS boost sur GPU Apple
        try:
            tf.keras.mixed_precision.set_global_policy("mixed_float16")
        except Exception:
            # (Keras 3 standalone)
            from keras import mixed_precision
            mixed_precision.set_global_policy("mixed_float16")

        tf.config.set_soft_device_placement(True)  # place auto sur CPU si besoin
        batch_size = 128
        print("✅ GPU Metal détecté :", gpus)
        print("🟣 Mixed precision activée (float16).")
        print(f"🧠 Batch size conseillé : {batch_size}")
        return dict(use_xla=False, on_gpu=True, batch=batch_size)     # XLA pas indispensable/utile ici
    else:
        # --- Chemin CPU ---
        cores = mp.cpu_count()
        batch_size = 32
        tf.keras.mixed_precision.set_global_policy("float32")  # plus stable sur CPU
        tf.config.threading.set_intra_op_parallelism_threads(cores)
        tf.config.threading.set_inter_op_parallelism_threads(min(4, max(2, cores // 4)))
        os.environ["OMP_NUM_THREADS"] = str(cores)
        print("🟡 Pas de GPU Metal. Optimisation CPU (threads).")
        print(f"🧠 Batch size conseillé : {batch_size}")
        return dict(use_xla=True, on_gpu=False, batch=batch_size)    # XLA (jit) souvent bénéfique sur CPU

CONF = setup_apple_silicon()


✅ GPU Metal détecté : [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
🟣 Mixed precision activée (float16).


In [5]:
!pip install -q faiss-cpu datasets pandas sentence-transformers sacrebleu tf-keras

In [6]:
# === Imports légers & ordonnés ===
import math
import random
import pathlib
import datetime as dt
from collections import Counter

import numpy as np
import tensorflow as tf
from tensorflow.keras import callbacks as Kcb
from tensorflow.keras import mixed_precision

# (optionnel) petites infos de run
print("TF", tf.__version__)
print("Devices:", tf.config.list_physical_devices())



TF 2.19.0
Devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
✅ Metal GPU détecté → policy 'mixed_float16' activée (variables en float32).


In [17]:
# =========================
# Données
# =========================
def load_squad_pairs():
    ds = load_dataset("squad", split="train")
    pairs = []
    for it in ds:
        ctx = (it["context"] or "").strip()
        q = (it["question"] or "").strip()
        ans = it["answers"]["text"][0].strip() if it["answers"]["text"] else ""
        if ctx and q and ans:
            pairs.append((f"{ctx}\nQ: {q}", ans))
    print(f"✅ SQuAD: {len(pairs)} paires")
    return pairs

def load_shirayuki_pairs(csv_path="shirayuki.csv"):
    df = pd.read_csv(csv_path)
    pairs = [(str(i).strip(), str(o).strip())
             for i,o in zip(df["guy"], df["girl"])
             if str(i).strip() and str(o).strip()]
    print(f"✅ Shirayuki: {len(pairs)} paires")
    return pairs

def split_pairs(pairs, val_ratio=0.02, seed=42):
    rng = np.random.default_rng(seed)
    idx = np.arange(len(pairs))
    rng.shuffle(idx)
    cut = max(1, int(len(pairs) * (1 - val_ratio)))
    train_idx, val_idx = idx[:cut], idx[cut:]
    train = [pairs[i] for i in train_idx]
    val = [pairs[i] for i in val_idx]
    return train, val

def make_ds_from_pairs(pairs, tokenizer, max_len=96, batch_size=64, shuffle=True):
    X = [x for x,_ in pairs]
    Y = [f"[START] {y} [END]" for _,y in pairs]
    enc = tokenizer(X)
    out = tokenizer(Y)
    dec_in = out[:, :-1]
    dec_tg = out[:, 1:]
    ds = tf.data.Dataset.from_tensor_slices(
        ({"encoder_input": enc, "decoder_input": dec_in}, dec_tg)
    )
    if shuffle:
        ds = ds.shuffle(10000)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    steps = math.ceil(len(pairs) / batch_size)
    return ds, steps

def prepare_datasets(pairs, tokenizer=None, vocab_size=20000, max_len=96, batch_size=64, val_ratio=0.02):
    train_pairs, val_pairs = split_pairs(pairs, val_ratio=val_ratio)
    X_all = [x for x,_ in pairs]
    Y_all = [f"[START] {y} [END]" for _,y in pairs]
    if tokenizer is None:
        tokenizer = TextVectorization(
            max_tokens=vocab_size,
            output_sequence_length=max_len,
            standardize="lower_and_strip_punctuation",
            split="whitespace"
        )
        tokenizer.adapt(X_all + Y_all)
    train_ds, train_steps = make_ds_from_pairs(train_pairs, tokenizer, max_len, batch_size, shuffle=True)
    val_ds, val_steps     = make_ds_from_pairs(val_pairs, tokenizer, max_len, batch_size, shuffle=False)
    return tokenizer, train_ds, val_ds, train_steps, val_steps

In [9]:
# =========================
# Mémoire FAISS (RAG light)
# =========================
import os
import faiss
from sentence_transformers import SentenceTransformer


EMBED_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
EMBED_DIM = 384
MEMORY_FILE = "shirayuki_memory.jsonl"
INDEX_FILE = "shirayuki_faiss.index"
index = faiss.read_index(INDEX_FILE) if os.path.exists(INDEX_FILE) else faiss.IndexFlatL2(EMBED_DIM)

def _encode(text): return np.array([EMBED_MODEL.encode(text)], dtype="float32")

def save_to_memory(user_text, bot_text):
    ts = datetime.datetime.now().isoformat()
    index.add(_encode(user_text))
    with open(MEMORY_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps({"input": user_text, "response": bot_text, "timestamp": ts}, ensure_ascii=False) + "\n")
    faiss.write_index(index, INDEX_FILE)

def search_memory(query, top_k=3):
    if index.ntotal == 0 or not os.path.exists(MEMORY_FILE): return []
    D, I = index.search(_encode(query), top_k)
    with open(MEMORY_FILE, "r", encoding="utf-8") as f:
        mem = [json.loads(l) for l in f]
    return [mem[i] for i in I[0] if 0 <= i < len(mem)]


  from .autonotebook import tqdm as notebook_tqdm


In [10]:

# =========================
# Masques (compatibles Keras MHA)
# =========================
PAD = 0
def padding_mask_2d(token_ids):
    return tf.cast(tf.not_equal(token_ids, PAD), tf.float32)   # (B,T)
def self_attention_mask(tokens):
    m = padding_mask_2d(tokens)                                # (B,T)
    return tf.einsum("bi,bj->bij", m, m)                       # (B,T,T)
def look_ahead_matrix(T):
    return tf.linalg.band_part(tf.ones((T, T), dtype=tf.float32), -1, 0)  # (T,T)
def decoder_self_mask(dec_tokens):
    m = padding_mask_2d(dec_tokens)                            # (B,Td)
    pad_pair = tf.einsum("bi,bj->bij", m, m)                   # (B,Td,Td)
    la = look_ahead_matrix(tf.shape(dec_tokens)[1])            # (Td,Td)
    return pad_pair * la                                       # (B,Td,Td)
def cross_attention_mask(dec_tokens, enc_tokens):
    m_dec = padding_mask_2d(dec_tokens)                        # (B,Td)
    m_enc = padding_mask_2d(enc_tokens)                        # (B,Te)
    return tf.einsum("bi,bj->bij", m_dec, m_enc)               # (B,Td,Te)


In [11]:
# =========================
# Modèle Transformer
# =========================
class Block(tf.keras.layers.Layer):
    def __init__(self, d, h, ff, drop=0.1, decoder=False):
        super().__init__()
        self.decoder = decoder
        self.self_att = MultiHeadAttention(num_heads=h, key_dim=d//h, dropout=drop)
        self.ln1 = LayerNormalization(epsilon=1e-6)
        self.do1 = Dropout(drop)
        if decoder:
            self.cross = MultiHeadAttention(num_heads=h, key_dim=d//h, dropout=drop)
            self.ln_c = LayerNormalization(epsilon=1e-6)
            self.do_c = Dropout(drop)
        self.ffn = tf.keras.Sequential([Dense(ff, activation="gelu"), Dense(d)])
        self.ln2 = LayerNormalization(epsilon=1e-6)
        self.do2 = Dropout(drop)
    def call(self, x, enc_out=None, self_mask=None, enc_mask=None, training=False):
        a = self.self_att(x, x, x, attention_mask=self_mask, training=training)
        x = self.ln1(x + self.do1(a, training=training))
        if self.decoder and enc_out is not None:
            a2 = self.cross(x, enc_out, enc_out, attention_mask=enc_mask, training=training)
            x = self.ln_c(x + self.do_c(a2, training=training))
        f = self.ffn(x)
        return self.ln2(x + self.do2(f, training=training))

class Seq2Seq(tf.keras.Model):
    def __init__(self, vocab_size, d=256, h=8, ff=768, max_len=96, L=4, drop=0.1):
        super().__init__()
        self.d, self.max_len = d, max_len
        self.tok_emb = Embedding(vocab_size, d)
        self.pos_emb = Embedding(max_len, d)
        self.enc = [Block(d, h, ff, drop, decoder=False) for _ in range(L)]
        self.dec = [Block(d, h, ff, drop, decoder=True) for _ in range(L)]
        self.final = Dense(vocab_size)
    def _add_pos(self, tok_ids):
        T = tf.shape(tok_ids)[1]
        return self.tok_emb(tok_ids) + self.pos_emb(tf.range(T)[tf.newaxis, :])
    def encode(self, enc_tokens, training=False):
        x = self._add_pos(enc_tokens)
        mask = self_attention_mask(enc_tokens)                 # (B,Te,Te)
        for blk in self.enc:
            x = blk(x, self_mask=mask, training=training)
        return x
    def decode(self, dec_tokens, enc_tokens, enc_out, training=False):
        y = self._add_pos(dec_tokens)
        self_m = decoder_self_mask(dec_tokens)                 # (B,Td,Td)
        cross_m = cross_attention_mask(dec_tokens, enc_tokens) # (B,Td,Te)
        for blk in self.dec:
            y = blk(y, enc_out=enc_out, self_mask=self_m, enc_mask=cross_m, training=training)
        return y
    def call(self, inputs, training=False):
        enc_tokens = inputs["encoder_input"]
        dec_tokens = inputs["decoder_input"]
        enc_out = self.encode(enc_tokens, training=training)
        dec_out = self.decode(dec_tokens, enc_tokens, enc_out, training=training)
        return self.final(dec_out)


In [12]:
# =========================
# Génération
# =========================
def build_generation(tokenizer, model):
    vocab = tokenizer.get_vocabulary()
    tok2id = {t:i for i,t in enumerate(vocab)}
    START = tok2id.get("[START]", 1)
    END = tok2id.get("[END]", 2)

    @tf.function(reduce_retracing=True)
    def _tf_encode(enc_tokens):
        return model.encode(enc_tokens, training=False)
    @tf.function(reduce_retracing=True)
    def _tf_decode(dec_tokens, enc_tokens, enc_out):
        y = model.decode(dec_tokens, enc_tokens, enc_out, training=False)
        return model.final(y)[:, -1, :]

    def generate_response(prompt, max_new_tokens=64, temperature=0.7, top_k=None, use_memory=True, save_mem=True):
        ctx = ""
        if use_memory:
            hits = search_memory(prompt, top_k=3)
            if hits:
                ctx = "\n".join([f"User: {m['input']}\nShirayuki: {m['response']}" for m in hits]) + "\n"
        full_inp = ctx + f"User: {prompt}\nShirayuki:"

        enc_tokens = tokenizer([full_inp])
        enc_out = _tf_encode(enc_tokens)

        y = tf.constant([[START]], dtype=tf.int64)
        for _ in range(max_new_tokens):
            logits = _tf_decode(y, enc_tokens, enc_out)
            if temperature and temperature > 0:
                logits = logits / temperature
                if top_k and top_k > 0:
                    values, indices = tf.math.top_k(logits, k=top_k)
                    probs = tf.nn.softmax(values)
                    next_id_rel = tf.random.categorical(tf.math.log(probs), 1)
                    next_id = tf.gather(indices, next_id_rel, batch_dims=1)
                    next_token = int(next_id.numpy()[0][0])
                else:
                    next_token = int(tf.random.categorical(logits, 1).numpy()[0][0])
            else:
                next_token = int(tf.argmax(logits, axis=-1).numpy()[0])
            if next_token == END: break
            y = tf.concat([y, tf.constant([[next_token]], dtype=tf.int64)], axis=1)

        id2tok = {i:t for i,t in enumerate(vocab)}
        toks = [id2tok.get(int(t), "") for t in y.numpy()[0] if int(t) not in (0, START, END)]
        text = " ".join(toks).strip()
        if save_mem:
            save_to_memory(prompt, text)
        return text or "[Aucune réponse générée]"

    return generate_response

In [None]:
# =========================
# Callbacks
# =========================
def build_callbacks(run_name="run"):
    ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = pathlib.Path("logs") / f"{run_name}-{ts}"
    ckpt_dir = pathlib.Path("ckpts") / f"{run_name}-{ts}"
    log_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # LR schedule: warmup -> cosine
    warmup_epochs = 1
    max_epochs = 50
    base_lr = 1e-3
    min_lr = 1e-5
    def lr_schedule(epoch, lr):
        if epoch < warmup_epochs:
            return base_lr * (epoch + 1) / warmup_epochs
        # cosine decay from base_lr to min_lr
        t = (epoch - warmup_epochs) / max(1, (max_epochs - warmup_epochs))
        return float(min_lr + 0.5*(base_lr - min_lr)*(1 + math.cos(math.pi * t)))

    def make_gen_cb(gen_fn):
        sample_prompts = ["Hello Shirayuki", "How are you today?"]
        def _on_epoch_end(epoch, logs=None):
            print("\n🧪 Samples:")
            for p in sample_prompts:
                print(" >", p)
                print(" >", gen_fn(p, temperature=0.8, top_k=40))
        return Kcb.LambdaCallback(on_epoch_end=_on_epoch_end)

    cbs = [
        Kcb.TensorBoard(log_dir=str(log_dir), histogram_freq=0, write_graph=True),
        Kcb.BackupAndRestore(backup_dir=str(log_dir / "backup")),
        Kcb.ModelCheckpoint(
            filepath=str(ckpt_dir / "{epoch:02d}-{val_loss:.3f}.weights.h5"),
            save_weights_only=True, monitor="val_loss", mode="min", save_best_only=True, verbose=1
        ),
        Kcb.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True, verbose=1),
        Kcb.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2, min_lr=1e-5, verbose=1),
        Kcb.LearningRateScheduler(lr_schedule, verbose=0),
        Kcb.CSVLogger(str(log_dir / "training.csv"), append=False),
        Kcb.TerminateOnNaN(),
    ]
    return cbs


In [14]:
# ===== Imports utiles =====
import math, datetime, pathlib, random
from collections import Counter
import numpy as np
import tensorflow as tf
Kcb = tf.keras.callbacks

# -------------------------------------------------
# 1) Warmup + Cosine schedule (sur les *steps*)
# -------------------------------------------------
class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, min_lr, warmup_steps, total_steps):
        self.base_lr = float(base_lr)
        self.min_lr = float(min_lr)
        self.warmup_steps = int(warmup_steps)
        self.total_steps = int(total_steps)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warm = tf.cond(step < self.warmup_steps,
            lambda: self.base_lr * (step + 1.0) / tf.maximum(1.0, float(self.warmup_steps)),
            lambda: self.min_lr + 0.5*(self.base_lr - self.min_lr) *
                    (1.0 + tf.cos(math.pi * tf.minimum(1.0,
                       (step - self.warmup_steps) / tf.maximum(1.0, float(self.total_steps - self.warmup_steps))))))
        return warm

# -------------------------------------------------
# 2) BLEU-4 et ROUGE-L (implémentation light)
# -------------------------------------------------
def _tokenize(s): return s.strip().split()

def _ngrams(toks, n):
    return [tuple(toks[i:i+n]) for i in range(len(toks)-n+1)]

def bleu4(ref, hyp, smooth=1.0):
    # ref/hyp: strings
    r = _tokenize(ref); h = _tokenize(hyp)
    if len(h) == 0: return 0.0
    precisions = []
    for n in range(1, 5):
        R = Counter(_ngrams(r, n)); H = Counter(_ngrams(h, n))
        overlap = sum((R & H).values())
        total = max(sum(H.values()), 1)
        precisions.append((overlap + smooth) / (total + smooth))
    bp = math.exp(1 - len(r)/max(len(h), 1)) if len(h) < len(r) else 1.0
    return float(bp * math.exp(sum(map(math.log, precisions)) / 4.0))

def _lcs_len(a, b):
    # a, b: list of tokens
    m, n = len(a), len(b)
    dp = [[0]*(n+1) for _ in range(m+1)]
    for i in range(m):
        ai = a[i]
        row = dp[i]
        row1 = dp[i+1]
        for j in range(n):
            row1[j+1] = row[j] + 1 if ai == b[j] else max(row1[j], row[j+1])
    return dp[m][n]

def rouge_l_f1(ref, hyp, beta=1.2):
    r = _tokenize(ref); h = _tokenize(hyp)
    if len(r) == 0 or len(h) == 0: return 0.0
    L = _lcs_len(r, h)
    p = L / len(h); rc = L / len(r)
    if p + rc == 0: return 0.0
    b2 = beta * beta
    return float((1 + b2) * p * rc / (rc + b2 * p))

# -------------------------------------------------
# 3) Callback d’éval NLG (BLEU/ROUGE/PPL + TB)
# -------------------------------------------------
class EvalNLG(Kcb.Callback):
    def __init__(self, gen_fn, eval_pairs, log_dir, every=1, name_prefix="val"):
        super().__init__()
        self.gen = gen_fn
        self.pairs = eval_pairs
        self.every = int(every)
        self.name = name_prefix
        self.tb = tf.summary.create_file_writer(str(log_dir))

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every != 0: return
        logs = logs or {}
        preds, refs = [], []
        for q, a in self.pairs:
            # greedy pour des métriques stables
            hyp = self.gen(q, temperature=0.0)  # top_k=None par défaut
            preds.append(hyp); refs.append(a)
        bleu = float(np.mean([bleu4(r, h) for r, h in zip(refs, preds)]))
        rouge = float(np.mean([rouge_l_f1(r, h) for r, h in zip(refs, preds)]))
        ppl = float(np.exp(logs['val_loss'])) if 'val_loss' in logs else float('nan')

        # injecter dans logs -> utilisable par EarlyStopping/Checkpoint
        logs[f'{self.name}_bleu'] = bleu
        logs[f'{self.name}_rougeL'] = rouge
        logs[f'{self.name}_perplexity'] = ppl

        # log TensorBoard
        lr = self._current_lr()
        with self.tb.as_default():
            tf.summary.scalar(f'{self.name}_bleu', bleu, step=epoch)
            tf.summary.scalar(f'{self.name}_rougeL', rouge, step=epoch)
            tf.summary.scalar(f'{self.name}_perplexity', ppl, step=epoch)
            if lr is not None:
                tf.summary.scalar('lr', lr, step=epoch)

        # console
        print(f"\n📊 {self.name.upper()} — BLEU: {bleu:.3f} | ROUGE-L: {rouge:.3f} | PPL: {ppl:.1f} | LR: {lr:.2e if lr else np.nan}")

        # petit aperçu
        for p in ["Hello Shirayuki", "How are you today?"]:
            print(" >", p)
            print(" >", self.gen(p, temperature=0.8, top_k=40))

    def _current_lr(self):
        lr = getattr(self.model.optimizer, 'learning_rate', None)
        if lr is None: return None
        try:
            # schedule -> callable
            return float(lr(self.model.optimizer.iterations))
        except TypeError:
            # constant -> variable
            return float(tf.keras.backend.get_value(lr))

# -------------------------------------------------
# 4) Callbacks pack (cohérent et centré ROUGE-L)
# -------------------------------------------------
def build_callbacks_optim(run_name, gen_fn, eval_pairs, max_epochs, steps_per_epoch):
    ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = pathlib.Path("logs") / f"{run_name}-{ts}"
    ckpt_dir = pathlib.Path("ckpts") / f"{run_name}-{ts}"
    log_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    eval_cb = EvalNLG(gen_fn, eval_pairs, log_dir, every=1, name_prefix="val")

    cbs = [
        # IMPORTANT: EvalNLG AVANT Checkpoint/EarlyStopping (il insère 'val_rougeL' dans logs)
        eval_cb,
        Kcb.TensorBoard(log_dir=str(log_dir), histogram_freq=0, write_graph=True, profile_batch=(10, 20)),
        Kcb.BackupAndRestore(backup_dir=str(log_dir / "backup")),
        # meilleur modèle par ROUGE-L
        Kcb.ModelCheckpoint(
            filepath=str(ckpt_dir / "best-rouge-{epoch:02d}-{val_rougeL:.3f}.weights.h5"),
            save_weights_only=True, monitor="val_rougeL", mode="max", save_best_only=True, verbose=1
        ),
        # on garde aussi le meilleur par val_loss (utile pour perplexity)
        Kcb.ModelCheckpoint(
            filepath=str(ckpt_dir / "best-loss-{epoch:02d}-{val_loss:.3f}.weights.h5"),
            save_weights_only=True, monitor="val_loss", mode="min", save_best_only=True, verbose=1
        ),
        # early stop sur la vraie fitness
        Kcb.EarlyStopping(monitor="val_rougeL", patience=4, mode="max", restore_best_weights=True, verbose=1),
        Kcb.CSVLogger(str(log_dir / "training.csv"), append=False),
        Kcb.TerminateOnNaN(),
    ]
    return cbs, log_dir


In [None]:
# =========================
# Données + modèle (Apple Silicon ready)
# =========================
import os, random
import numpy as np
import tensorflow as tf
from datasets import load_dataset

# (Optionnel) un peu plus de reproductibilité
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Charge les paires
squad_pairs = load_squad_pairs()

# Batch dynamique selon GPU/CPU
tokenizer, squad_train, squad_val, squad_steps, squad_val_steps = prepare_datasets(
    squad_pairs, vocab_size=20000, max_len=96, batch_size=CONF["batch"], val_ratio=0.02
)

# (Recommandé) écarter le padding (id=0) de la loss/accuracy via sample_weight
PAD_ID = 0
def add_mask(x, y):
    # sw: (B, T) float32 ; 1 pour tokens non-pad, 0 pour pad
    sw = tf.cast(tf.not_equal(y, PAD_ID), tf.float32)
    return x, y, sw

AUTOTUNE = tf.data.AUTOTUNE
# + parallel map et cache (utile si les paires tiennent en RAM)
squad_train = (squad_train
               .map(add_mask, num_parallel_calls=AUTOTUNE)
               .prefetch(AUTOTUNE))
squad_val   = (squad_val
               .map(add_mask, num_parallel_calls=AUTOTUNE)
               .prefetch(AUTOTUNE))

# Scheduler de LR sur les steps (on scale linéairement si batch changé)
max_epochs   = 50
scale        = float(CONF["batch"]) / 64.0
base_lr      = 3e-4 * scale
min_lr       = 1e-5
total_steps  = int(squad_steps * max_epochs)
warmup_steps = int(0.1 * total_steps)

lr_sched = WarmupCosine(base_lr=base_lr, min_lr=min_lr,
                        warmup_steps=warmup_steps, total_steps=total_steps)

# Optimizer = AdamW + clipnorm (compat Keras/TF)
# Essaye TF-Keras 2.x, puis Keras 3, puis Adam fallback
opt = None
# TF >= 2.11
try:
    opt = tf.keras.optimizers.AdamW(
        learning_rate=lr_sched, weight_decay=1e-4, clipnorm=1.0
    )
except Exception:
    pass
# Keras 3 (backend-agnostic)
if opt is None:
    try:
        import keras
        opt = keras.optimizers.AdamW(
            learning_rate=lr_sched, weight_decay=1e-4, clipnorm=1.0
        )
    except Exception:
        opt = tf.keras.optimizers.Adam(learning_rate=lr_sched, clipnorm=1.0)

# Modèle
# TextVectorization a bien tokenizer.vocabulary_size()
vocab_size = tokenizer.vocabulary_size() if hasattr(tokenizer, "vocabulary_size") else \
             (getattr(tokenizer, "num_words", None) or getattr(tokenizer, "vocab_size", None) or 20000)

model = Seq2Seq(
    vocab_size=vocab_size, d=256, h=8, ff=768,
    max_len=96, L=4, drop=0.1
)

# XLA: sur Apple/Metal, évite XLA GPU. On l’active seulement s’il n’y a PAS de GPU logique.
use_xla_cpu_only = bool(CONF.get("use_xla")) and (len(tf.config.list_logical_devices('GPU')) == 0)

# Compile
model.compile(
    optimizer=opt,
    loss=SparseCEFromLogitsFP32(),                   # stable avec mixed precision
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="tok_acc")],
    jit_compile=use_xla_cpu_only
)

# Génération pour les callbacks
generate_response = build_generation(tokenizer, model)

# Petit set d'éval texte (256 QA aléatoires mais fixes)
random.seed(SEED)
eval_pairs = random.sample(squad_pairs, k=min(256, len(squad_pairs)))

cbs, log_dir = build_callbacks_optim(
    "pretrain_squad", generate_response, eval_pairs, max_epochs, squad_steps
)

model.build(input_shape={
    "encoder_input": (None, None),   # (batch, seq_len)
    "decoder_input": (None, None)
})

print("🚀 Pré-entraînement sur SQuAD (optimisé M3 — GPU si dispo)…")

history = model.fit(
    squad_train,
    validation_data=squad_val,
    epochs=max_epochs,
    steps_per_epoch=int(squad_steps),
    validation_steps=int(squad_val_steps),
    callbacks=cbs
)


✅ SQuAD: 87599 paires


KeyError: 'batch'

In [None]:
with tf.device("/CPU:0"):
    # model.fit(
    #     squad_train,
    #     validation_data=squad_val,
    #     epochs=5,
    #     steps_per_epoch=squad_steps,
    #     validation_steps=squad_val_steps,
    #     callbacks=cbs_pre,
    #     verbose=1
    # )

    shirayuki_pairs = load_shirayuki_pairs("/Users/christopher/Documents/IA/ani/datasets/conversation_dataset_ShirayukiV3.csv")   # <-- assure le fichier présent
    _, sh_train, sh_val, sh_steps, sh_val_steps = prepare_datasets(
        shirayuki_pairs, tokenizer=tokenizer, max_len=96, batch_size=64, val_ratio=0.05
    )
    cbs_ft = build_callbacks("finetune_shirayuki")
    def _on_epoch_end_ft(epoch, logs=None):
        print("\n🧪 FT Samples:")
        for p in ["Hello Shirayuki", "Peux-tu m'aider à planifier ma journée ?"]:
            print(" >", p)
            print(" >", generate_response(p, temperature=0.8, top_k=40))
    cbs_ft.append(Kcb.LambdaCallback(on_epoch_end=_on_epoch_end_ft))

    print("🔄 Fine-tuning sur Shirayuki...")
    model.fit(
        sh_train,
        validation_data=sh_val,
        epochs=10,
        steps_per_epoch=sh_steps,
        validation_steps=sh_val_steps,
        callbacks=cbs_ft,
        verbose=1
    )


✅ Shirayuki: 4362 paires
🔄 Fine-tuning sur Shirayuki...
Epoch 1/10
[1m65/65[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - loss: 1.0858
Epoch 1: val_loss improved from inf to 0.91253, saving model to ckpts/finetune_shirayuki-20250810-221942/01-0.913.weights.h5

🧪 FT Samples:
 > Hello Shirayuki
 > if i that ii mean even get that do just just i say it not like i end
 > Peux-tu m'aider à planifier ma journée ?
 > up at that end
[1m65/65[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m170s[0m 3s/step - loss: 1.0844 - val_loss: 0.9125 - learning_rate: 0.0010
Epoch 2/10
[1m65/65[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - loss: 0.8993
Epoch 2: val_loss improved from 0.91253 to 0.88246, saving model to ckpts/finetune_shirayuki-20250810-221942/02-0.882.weights.h5

🧪 FT Samples:
 > Hello Shirayuki
 > even of
 > Peux-tu m'aider à planifier ma journée ?
 > to hold so i mean that things ii just just think it end
[1m65/65[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[3

In [None]:
# =========================
# Démo rapide post-entraînement
# =========================
tests = [
    "Hello Shirayuki",
    "How are you today?",
    "What's your favorite music?",
    "Peux-tu m'aider à planifier ma journée ?"
]
for t in tests:
    print("\n> 💬", t)
    print("> 🤖", generate_response(t, temperature=0.8, top_k=40))


> 💬 Hello Shirayuki
> 🤖 end

> 💬 How are you today?
> 🤖 end and end

> 💬 What's your favorite music?
> 🤖 end end

> 💬 Peux-tu m'aider à planifier ma journée ?
> 🤖 end
