Importation des Bibliothéques

In [41]:
import gym
from gym import spaces
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
from typing import Dict, Tuple, List, Any
from functools import partial
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import random
import os
import pickle
from sklearn.model_selection import train_test_split

In [42]:
class ExperienceBuffer:
    def __init__(self, capacity: int = 10000):
        self.buffer = []
        self.capacity = capacity

    def add(self, state: Dict, action: np.ndarray, reward: float,
            next_state: Dict, done: bool) -> None:
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> List:
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

Implémentation de l'Environement

In [43]:
class TextFeatureSelectionEnv(gym.Env):
    def __init__(self, embeddings: np.ndarray, labels: np.ndarray,
                 model: Any, batch_size: int = 32):
        super().__init__()

        self._validate_inputs(embeddings, labels)

        self.embeddings = jnp.array(embeddings)
        self.labels = jnp.array(labels)
        self.model = model
        self.batch_size = batch_size

        self.num_samples, self.num_tokens, self.embedding_dim = embeddings.shape

        self.action_space = spaces.Box(
            low=0, high=1,
            shape=(self.num_tokens,),
            dtype=np.float32
        )
        self.observation_space = spaces.Dict({
            'embedding': spaces.Box(
                low=-np.inf, high=np.inf,
                shape=(self.num_tokens, self.embedding_dim),
                dtype=np.float32
            ),
            'mask': spaces.Box(
                low=0, high=1,
                shape=(self.num_tokens,),
                dtype=np.float32
            ),
            'performance': spaces.Box(
                low=0, high=1,
                shape=(1,),
                dtype=np.float32
            )
        })

        self.current_batch_idx = 0
        self.current_episode = 0
        self.performance_history = []
        self.token_usage_stats = jnp.zeros(self.num_tokens)
        self.best_features = {i: [] for i in range(labels.shape[1])}  # Dictionnaire pour les meilleures caractéristiques

    def _validate_inputs(self, embeddings: np.ndarray, labels: np.ndarray) -> None:
        if not isinstance(embeddings, np.ndarray) or not isinstance(labels, np.ndarray):
            raise TypeError("Embeddings and labels must be numpy arrays")
        if len(embeddings.shape) != 3:
            raise ValueError("Embeddings must be 3-dimensional")
        if len(labels.shape) != 2:
            raise ValueError("Labels must be 2-dimensional")
        if embeddings.shape[0] != labels.shape[0]:
            raise ValueError("Number of samples must match between embeddings and labels")

    def _get_batch(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
        start_idx = self.current_batch_idx * self.batch_size
        end_idx = min(start_idx + self.batch_size, self.num_samples)
        return (
            self.embeddings[start_idx:end_idx],
            self.labels[start_idx:end_idx]
        )

    def reset(self) -> Dict[str, jnp.ndarray]:
        key = jax.random.PRNGKey(self.current_episode)
        perm = jax.random.permutation(key, self.num_samples)
        self.embeddings = self.embeddings[perm]
        self.labels = self.labels[perm]

        self.current_batch_idx = 0
        self.current_episode += 1

        batch_embeddings, _ = self._get_batch()

        return {
            'embedding': batch_embeddings[0],
            'mask': jnp.ones(self.num_tokens),
            'performance': jnp.zeros(1)
        }

    @staticmethod
    @jax.jit
    def _compute_metrics(predictions: jnp.ndarray, true_labels: jnp.ndarray) -> Dict[str, jnp.ndarray]:
        pred_labels = jnp.argmax(predictions, axis=1)
        true_labels = jnp.argmax(true_labels, axis=1)
        accuracy = jnp.mean(pred_labels == true_labels)

        return {
            'accuracy': accuracy,
            'pred_labels': pred_labels
        }

    def step(self, action: jnp.ndarray) -> Tuple[Dict[str, jnp.ndarray], float, bool, Dict]:
        batch_embeddings, batch_labels = self._get_batch()

        action = jnp.clip(action, 0, 1)
        masked_embeddings = batch_embeddings * action[:, None]

        predictions = self.model(masked_embeddings)
        predictions = jnp.array(predictions)

        metrics = self._compute_metrics(predictions, batch_labels)

        self.token_usage_stats += action
        self.performance_history.append(metrics['accuracy'].item())

        reward = self._compute_reward(metrics['accuracy'], action)

        self.current_batch_idx += 1
        done = self.current_batch_idx * self.batch_size >= self.num_samples

        current_class_id = int(jnp.argmax(batch_labels[0]))

        selected_features = jnp.where(action > 0.5)[0].tolist()
        feature_names = [f'feature_{i}' for i in selected_features]  # Générer les noms de caractéristiques
        self.best_features[current_class_id].extend(feature_names)

        if not done:
            next_batch_embeddings, _ = self._get_batch()
            next_state = {
                'embedding': next_batch_embeddings[0],
                'mask': action,
                'performance': jnp.array([metrics['accuracy']])
            }
        else:
            next_state = self.reset()

        info = {
            'metrics': metrics,
            'token_usage': self.token_usage_stats / (self.current_batch_idx + 1),
            'class_id': current_class_id,
            'selected_features': selected_features
        }

        return next_state, float(reward), done, info

    def _compute_reward(self, accuracy: float, action: jnp.ndarray) -> float:
        # Contraindre l'action à sélectionner un maximum de 64 caractéristiques
        max_features = 64
        sorted_indices = jnp.argsort(action)[::-1]  # Trier les indices par ordre décroissant de valeur
        top_indices = sorted_indices[:max_features]  # Sélectionner les 64 meilleurs
    
        # Créer une action finale avec uniquement les 64 meilleurs activés
        final_action = jnp.zeros_like(action)
        final_action = final_action.at[top_indices].set(1)  # Activer uniquement les 64 meilleurs
    
        # Calculer la pénalité pour la parcimonie
        sparsity_penalty = 0.1 * jnp.mean(final_action)
    
        # Calculer le nombre de caractéristiques sélectionnées
        num_selected = jnp.sum(final_action)
    
        # Appliquer une pénalité si le nombre de caractéristiques dépasse 64 (ne devrait pas arriver avec la contrainte stricte)
        feature_count_penalty = max(0, (num_selected - max_features) * 0.05)
    
        # Calculer la récompense finale
        return float(accuracy - sparsity_penalty - feature_count_penalty)

In [44]:
class PolicyNetwork(nn.Module):
    """Attention-based policy network for token selection"""

    @nn.compact
    def __call__(self, state: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        embedding = state['embedding']
        mask = state['mask']
        performance = state['performance']

        if embedding.ndim == 2:
            embedding = jnp.expand_dims(embedding, 0)
            mask = jnp.expand_dims(mask, 0)
            performance = jnp.expand_dims(performance, 0)

        x = nn.Dense(512)(embedding)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)

        q = nn.Dense(512)(x)
        k = nn.Dense(512)(x)
        v = nn.Dense(512)(x)

        scale = jnp.sqrt(512)
        scores = jnp.einsum('bik,bjk->bij', q, k) / scale
        attention_weights = nn.softmax(scores, axis=-1)
        attended = jnp.einsum('bij,bjk->bik', attention_weights, v)

        mask_info = jnp.expand_dims(mask, -1)
        perf_info = jnp.broadcast_to(
            jnp.expand_dims(performance, 1),
            (attended.shape[0], attended.shape[1], 1)
        )

        x = jnp.concatenate([attended, mask_info, perf_info], axis=-1)

        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        x = nn.sigmoid(x)

        return x.squeeze((-1, 0))

In [45]:

class RLAgent:
    def __init__(
        self,
        env: gym.Env,
        model: Any,  # Ajoutez le modèle ici
        learning_rate: float = 1e-4,
        gamma: float = 0.99,
        buffer_capacity: int = 10000
    ):
        self.env = env
        self.model = model  # Stockez le modèle dans l'agent
        self.gamma = gamma
        self.experience_buffer = ExperienceBuffer(buffer_capacity)

        # Initialize policy
        self.policy = PolicyNetwork()
        self.optimizer = optax.adam(learning_rate)

        # Initialize parameters
        dummy_state = env.reset()
        self.params = self.policy.init(
            jax.random.PRNGKey(0),
            {k: jnp.expand_dims(v, 0) for k, v in dummy_state.items()}
        )
        self.opt_state = self.optimizer.init(self.params)
        # Initialiser best_features pour chaque classe
        self.best_features = {i: [] for i in range(env.labels.shape[1])}  # Dictionnaire pour les meilleures caractéristiques

    @partial(jax.jit, static_argnums=(0,))
    def _policy_step(
        self,
        params: Dict,
        state: Dict[str, jnp.ndarray],
        key: jnp.ndarray,
        k: int = 10
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        logits = self.policy.apply(params, state)

        # Ajouter du bruit pour l'exploration
        noise = jax.random.normal(key, logits.shape) * 0.1
        logits = logits + noise

        action = self._select_top_k_features(logits, k)

        log_prob = jnp.sum(
            action * jnp.log(logits + 1e-8) +
            (1 - action) * jnp.log(1 - logits + 1e-8)
        )
        return action, log_prob

    @partial(jax.jit, static_argnums=(0,))
    def _compute_loss(self, params: Dict, batch: List) -> jnp.ndarray:
        total_loss = 0.0
        for state, action, reward, _, _ in batch:
            logits = self.policy.apply(params, state)
            log_prob = jnp.sum(
                action * jnp.log(logits + 1e-8) +
                (1 - action) * jnp.log(1 - logits + 1e-8)
            )
            total_loss -= log_prob * reward
        return total_loss / len(batch)

    def _select_top_k_features(self, logits: jnp.ndarray, k: int) -> jnp.ndarray:
        k = jnp.clip(k, 0, logits.shape[0])  # Assurer que k ne dépasse pas le nombre de logits
        threshold = jnp.sort(logits)[-k]
        return jnp.where(logits >= threshold, 1.0, 0.0)

    def train(
        self,
        num_episodes: int,
        max_steps_per_episode: int = 100,
        batch_size: int = 32,
        max_features: int = 64,
        save_dir: str = "models"
    ) -> Tuple[List[float], Dict[int, np.ndarray]]:
        self.training = True
        returns = []
        best_return = float('-inf')

        os.makedirs(save_dir, exist_ok=True)

        try:
            for episode in range(num_episodes):
                print(f"\n{'='*20} Épisode {episode + 1}/{num_episodes} {'='*20}")

                state = self.env.reset()
                episode_return = 0
                episode_steps = 0

                while episode_steps < max_steps_per_episode:
                    print(f"\nÉtape {episode_steps + 1}/{max_steps_per_episode}")

                    key = jax.random.PRNGKey(episode * max_steps_per_episode + episode_steps)

                    action, log_prob = self._policy_step(
                        self.params,
                        state,
                        key,
                        k=random.randint(1, max_features)
                    )

                    next_state, reward, done, info = self.env.step(action)
                    episode_return += reward

                    self.experience_buffer.add(state, action, reward, next_state, done)

                    if len(self.experience_buffer.buffer) >= batch_size:
                        batch = self.experience_buffer.sample(batch_size)
                        loss = self._compute_loss(self.params, batch)
                        grads = jax.grad(lambda p: self._compute_loss(p, batch))(self.params)
                        updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
                        self.params = optax.apply_updates(self.params, updates)

                    print(f"Récompense: {reward:.4f}, Retour cumulé: {episode_return:.4f}, Caractéristiques sélectionnées: {info['selected_features']}")

                    if done:
                        print("Épisode terminé")
                        break

                    state = next_state
                    episode_steps += 1

                returns.append(episode_return)

                print(f"\nStatistiques de l'épisode {episode + 1}:")
                print(f"Retour total: {episode_return:.4f}")

        except KeyboardInterrupt:
            print("\nEntraînement interrompu par l'utilisateur")

        finally:
            self.training = False

            if returns:
                print("\nStatistiques finales d'entraînement:")
                print(f"Nombre total d'épisodes complétés: {len(returns)}")
                print(f"Meilleur retour: {max(returns):.4f}")
                print(f"Retour moyen: {np.mean(returns):.4f}")
            else:
                print("Aucun épisode n'a été complété.")

        self.display_best_features()  # Afficher les meilleures caractéristiques à la fin
        return returns, {}

    def display_best_features(self):
        print("\nMeilleures caractéristiques par classe :")
        for class_id, features in self.best_features.items():
            print(f"Meilleures caractéristiques pour la classe {class_id}: {set(features)}")

    def save_model(self, path: str):
        """Sauvegarde les paramètres du modèle dans un fichier."""
        with open(path, 'wb') as f:
            pickle.dump(self.params, f)
        print(f"Modèle sauvegardé à : {path}")

    def evaluate(self, test_embeddings: np.ndarray, test_labels: np.ndarray) -> float:
        """Évalue le modèle sur l'ensemble de test."""
        self.training = False  # Assurez-vous que nous ne sommes pas en mode entraînement
        test_env = TextFeatureSelectionEnv(test_embeddings, test_labels, self.model)
        state = test_env.reset()
        total_accuracy = 0.0
        num_batches = 0

        while True:
            action = self._select_top_k_features(self.policy.apply(self.params, state), k=test_env.num_tokens)
            next_state, reward, done, info = test_env.step(action)
            metrics = info['metrics']
            total_accuracy += metrics['accuracy']
            num_batches += 1

            if done:
                break
            state = next_state
        print(f"Total Accuracy: {total_accuracy}, Number of Batches: {num_batches}")  # Debugging line
        return total_accuracy / num_batches if num_batches > 0 else 0.0

In [46]:
# Chargement des données
embeddings = np.load("embeddings.npy")  # de taille (5000,128,768)
labels = np.load("labels_one_hot.npy")  # de taille (5000,)
model = load_model('modele.keras')

In [47]:
# Préparer les indices pour chaque classe
num_classes = labels.shape[1]
train_embeddings = []
test_embeddings = []
train_labels = []
test_labels = []

for class_id in range(num_classes):
    class_indices = np.where(labels[:, class_id] == 1)[0]
    train_idx, test_idx = train_test_split(class_indices, test_size=0.2, random_state=42)
    train_embeddings.append(embeddings[train_idx])
    test_embeddings.append(embeddings[test_idx])
    train_labels.append(labels[train_idx])
    test_labels.append(labels[test_idx])

train_embeddings = np.vstack(train_embeddings)
test_embeddings = np.vstack(test_embeddings)
train_labels = np.vstack(train_labels)
test_labels = np.vstack(test_labels)


In [48]:
print(f"Taille des données d'entraînement : {train_embeddings.shape}, {train_labels.shape}")
print(f"Taille des données de test : {test_embeddings.shape}, {test_labels.shape}")

Taille des données d'entraînement : (4000, 128, 768), (4000, 10)
Taille des données de test : (1000, 128, 768), (1000, 10)


In [49]:
env = TextFeatureSelectionEnv(train_embeddings, train_labels, model)
agent = RLAgent(env, model)

In [56]:
# Lancer l'entraînement
returns, feature_counts = agent.train(
    num_episodes=100,
    max_steps_per_episode=100,
    batch_size=32,
    max_features=64
)



Étape 1/100
Récompense: 0.2201, Retour cumulé: 0.2201
Caractéristiques sélectionnées : [6, 7, 15, 16, 20, 25, 33, 42, 44, 52, 54, 67, 77, 78, 85, 92, 95, 102, 107, 111, 115, 116, 120, 124]

Étape 21/100
Récompense: 0.1914, Retour cumulé: 4.2331
Caractéristiques sélectionnées : [0, 4, 5, 7, 11, 12, 17, 37, 52, 58, 60, 63, 68, 78, 81, 89, 92, 94, 95, 102, 109, 112, 113, 119, 123]

Étape 41/100
Récompense: 0.1588, Retour cumulé: 8.0265
Caractéristiques sélectionnées : [10, 14, 21, 24, 25, 27, 38, 55, 56, 59, 70, 76, 77, 79, 81, 93, 99, 104, 107, 112, 122, 124]

Étape 61/100
Récompense: 0.2135, Retour cumulé: 12.2148
Caractéristiques sélectionnées : [5, 7, 8, 9, 30, 32, 33, 34, 37, 42, 48, 52, 61, 71, 76, 85, 90, 92, 93, 98, 107, 109, 125]

Étape 81/100
Récompense: 0.2598, Retour cumulé: 16.2088
Caractéristiques sélectionnées : [0, 4, 5, 7, 9, 12, 17, 20, 26, 27, 29, 43, 52, 58, 59, 60, 63, 78, 81, 96, 102, 109, 112, 123, 126]

Étape 100/100
Récompense: 0.2120, Retour cumulé: 20.0230
Car