Notebook : entraînement d'un modèle à partir des données MongoDB et enregistrement dans MLflow

In [1]:
import os
import json
import pymongo
import mlflow
import mlflow.sklearn
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from pyspark.sql.types import StructType, StructField, StringType, MapType, TimestampType
from transformers import DistilBertTokenizer
from transformers import DistilBertModel
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from pyspark.sql import SparkSession

# Build (or reuse) a Spark session that pulls in the MongoDB Spark Connector package
spark = (
    SparkSession.builder
    .appName("myApp")
    .config("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.12:3.0.1")
    .getOrCreate()
)

# Explicit schema for the documents read from MongoDB; every field is nullable
schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in [
        ("user", StringType()),
        ("repo", StringType()),
        ("mainLanguage", StringType()),
        ("languages", MapType(StringType(), StringType())),
        ("readme", StringType()),
        ("processed_readme", StringType()),
        ("last_updated", TimestampType()),
    ]
])

# Read the "raw_data" collection of the "dev" database using the schema above
df = (
    spark.read
    .format("mongo")
    .option("database", "dev")
    .option("collection", "raw_data")
    .option("uri", "mongodb://mongo:27017/")
    .schema(schema)
    .load()
)

# Print a preview of the loaded rows
df.show()

:: loading settings :: url = jar:file:/opt/bitnami/spark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /root/.ivy2/cache
The jars for the packages stored in: /root/.ivy2/jars
org.mongodb.spark#mongo-spark-connector_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-61678298-bd72-41f8-938d-29ad269a3571;1.0
	confs: [default]
	found org.mongodb.spark#mongo-spark-connector_2.12;3.0.1 in central
	found org.mongodb#mongodb-driver-sync;4.0.5 in central
	found org.mongodb#bson;4.0.5 in central
	found org.mongodb#mongodb-driver-core;4.0.5 in central
downloading https://repo1.maven.org/maven2/org/mongodb/spark/mongo-spark-connector_2.12/3.0.1/mongo-spark-connector_2.12-3.0.1.jar ...
	[SUCCESSFUL ] org.mongodb.spark#mongo-spark-connector_2.12;3.0.1!mongo-spark-connector_2.12.jar (227ms)
downloading https://repo1.maven.org/maven2/org/mongodb/mongodb-driver-sync/4.0.5/mongodb-driver-sync-4.0.5.jar ...
	[SUCCESSFUL ] org.mongodb#mongodb-driver-sync;4.0.5!mongodb-driver-sync.jar (104ms)
downloading https://repo1.maven.org/maven2/org/m

+---------------+-------------------+------------+--------------------+--------------------------+--------------------+--------------------+
|           user|               repo|mainLanguage|           languages|                    readme|    processed_readme|        last_updated|
+---------------+-------------------+------------+--------------------+--------------------------+--------------------+--------------------+
|             d3|                 d3|       Shell|{Shell -> NULL, J...|      # D3: Data-Driven...|# d ##3 : data - ...|2025-03-05 11:13:...|
| papers-we-love|     papers-we-love|       Shell|     {Shell -> NULL}|      ﻿## ![Papers We L...|# # ! [ papers we...|2025-03-05 11:14:...|
|         nvm-sh|                nvm|       Shell|{Dockerfile -> NU...|      <a href="https://...|< a hr ##ef = " h...|2025-03-05 11:15:...|
|           base|               node|       Shell|{Dockerfile -> NU...|      ![Base](logo.webp...|! [ base ] ( logo...|2025-03-05 11:16:...|
|            

In [3]:
# Count the rows of the DataFrame loaded above and print the total.
num_lines = df.count()
print(f"Number of lines in the dataframe: {num_lines}")

Number of lines in the dataframe: 50


In [4]:
# Connect to MongoDB once and keep handles on the "dev" database and its
# "raw_data" collection for the cells below.
client = pymongo.MongoClient("mongodb://mongo:27017/")
db = client.get_database("dev")
collection = db.get_collection("raw_data")

In [5]:
# Check whether the "dev" database is listed by the server and report the result
db_list = client.list_database_names()
print("La base de données existe." if "dev" in db_list else "La base de données n'existe pas.")

La base de données existe.


In [6]:
# Load the training data from MongoDB
def load_data_from_mongo(source_collection=None):
    """Load README texts and main-language labels from MongoDB.

    Args:
        source_collection: Object exposing a pymongo-style ``find`` method.
            When omitted, the notebook-level ``collection`` is used.

    Returns:
        A ``(texts, labels)`` tuple of two lists of equal length. A missing
        or null field is returned as ``""``.
    """
    if source_collection is None:
        source_collection = collection

    # Only fetch the fields that are read below.
    projection = {"_id": 0, "readme_clean": 1, "processed_readme": 1, "mainLanguage": 1}

    texts, labels = [], []
    for item in source_collection.find({}, projection):
        # "readme_clean" keeps priority so existing data behaves as before;
        # the schema used in this notebook stores the cleaned text under
        # "processed_readme", so fall back to it instead of returning "".
        texts.append(item.get("readme_clean") or item.get("processed_readme") or "")
        # `or ""` also covers a key that is present with a null value.
        labels.append(item.get("mainLanguage") or "")
    return texts, labels


In [7]:
# Load the pretrained DistilBERT tokenizer and model from the same checkpoint
DISTILBERT_CHECKPOINT = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(DISTILBERT_CHECKPOINT)
bert_model = DistilBertModel.from_pretrained(DISTILBERT_CHECKPOINT)

def encode_texts_with_bert(texts, tokenizer, model):
    """Tokenize and vectorize texts with DistilBERT.

    Each text is tokenized on its own (padded/truncated to 512 tokens) and
    run through ``model``; the hidden state at position 0 of the sequence is
    kept as the vector for that text.

    Args:
        texts: Iterable of texts to encode.
        tokenizer: Callable returning the keyword inputs expected by ``model``.
        model: Callable whose output exposes ``last_hidden_state``.

    Returns:
        A NumPy array built from the per-text vectors, in input order.
    """

    def _embed(document):
        tokens = tokenizer(document, padding='max_length', truncation=True, max_length=512, return_tensors="pt")
        hidden_states = model(**tokens).last_hidden_state
        # Position 0 of the sequence is the CLS token, used to represent the whole text
        cls_state = hidden_states[:, 0, :]
        return cls_state.squeeze().numpy()

    # Inference only: gradients are never needed here
    with torch.no_grad():
        vectors = [_embed(document) for document in texts]

    return np.array(vectors)

# Load the texts and labels from MongoDB, then encode the labels.
texts, labels = load_data_from_mongo()
label_encoder = LabelEncoder()
# Fit the encoder on the labels and transform them in one step; the fitted
# encoder stays available as `label_encoder`.
encoded_labels = label_encoder.fit_transform(labels)



In [8]:
# Turn each text into a DistilBERT embedding. The classifier trained below
# needs numeric feature vectors, so the raw strings cannot be split and
# fitted directly.
document_vectors = encode_texts_with_bert(texts, tokenizer, bert_model)

# Split the embeddings and encoded labels into train/test sets (80/20, fixed seed)
X_train, X_test, y_train, y_test = train_test_split(document_vectors, encoded_labels, test_size=0.2, random_state=42)


In [9]:
# Train the classifier: a random forest with 100 estimators and a fixed seed (42).
classifier = RandomForestClassifier(n_estimators=100, random_state=42)
classifier.fit(X_train, y_train)



In [10]:
# Evaluate the model: score the train split first, then the test split
train_score, test_score = (
    classifier.score(features, targets)
    for features, targets in ((X_train, y_train), (X_test, y_test))
)

print(f"Train Accuracy: {train_score:.4f}")
print(f"Test Accuracy: {test_score:.4f}")



The git executable must be specified in one of the following ways:
    - be included in your $PATH
    - be set via $GIT_PYTHON_GIT_EXECUTABLE
    - explicitly set via git.refresh(<full-path-to-git-executable>)

All git commands will error until this is rectified.

This initial message can be silenced or aggravated in the future by setting the
$GIT_PYTHON_REFRESH environment variable. Use one of the following values:
    - quiet|q|silence|s|silent|none|n|0: for no message or exception
    - error|e|exception|raise|r|2: for a raised exception

Example:
    export GIT_PYTHON_REFRESH=quiet



Train Accuracy: 1.0000
Test Accuracy: 1.0000


MlflowException: API request to http://localhost:8090/api/2.0/mlflow/runs/create failed with exception HTTPConnectionPool(host='localhost', port=8090): Max retries exceeded with url: /api/2.0/mlflow/runs/create (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f3a2d1f76b0>: Failed to establish a new connection: [Errno 111] Connection refused'))

In [12]:
# Record the trained model, its type and both accuracy metrics in MLflow
mlflow.set_tracking_uri("http://mlflow:8080")

with mlflow.start_run():
    mlflow.sklearn.log_model(classifier, "random_forest_model")
    mlflow.log_param("model_type", "RandomForestClassifier")
    # Same two metrics, logged in the same order (train, then test)
    for metric_name, metric_value in (
        ("train_accuracy", train_score),
        ("test_accuracy", test_score),
    ):
        mlflow.log_metric(metric_name, metric_value)

print("Modèle entraîné et enregistré dans MLflow")



🏃 View run suave-owl-916 at: http://mlflow:8080/#/experiments/0/runs/8748aeecbf664e978bc5ef70b279e4a4
🧪 View experiment at: http://mlflow:8080/#/experiments/0
Modèle entraîné et enregistré dans MLflow
