In [2]:
# %pip (unlike !pip) always installs into the environment of the running kernel.
# NOTE(review): versions are not pinned, and packages imported further down
# (torch, torchvision, Pillow, psutil, requests, tqdm) are not installed here — confirm
# they are provided by the base image.
%pip install -U opencv-python tensorflow scikit-learn pandas matplotlib tensorflow_datasets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


# IMPORTATION DES LIBRAIRIES

In [3]:
import numpy as np
import pandas as pd
import os
import sklearn
import tensorflow as tf
from tensorflow import keras
import cv2
import matplotlib.pyplot as plt
import tensorflow_datasets.public_api as tfds
import requests
import zipfile
from tqdm import tqdm
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam

2025-07-21 00:45:31.537465: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753058731.551766    2431 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753058731.556384    2431 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753058731.569094    2431 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753058731.569107    2431 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753058731.569108    2431 computation_placer.cc:177] computation placer alr

# Chargement du dataset

In [4]:
def telecharger_dezip(url, chemin_sauv="plant_village_dataset.zip", extract_path="."):
    """Download a ZIP archive, extract it, then delete the archive.

    Args:
        url: HTTP(S) address of the ZIP file to download.
        chemin_sauv: Local path the archive is written to while downloading.
        extract_path: Directory receiving the extracted files; created if missing.

    Returns:
        None. Progress and failures are reported with ``print``. Network errors,
        invalid archives and any other exception are caught and reported, not raised.
    """
    print(" Début du téléchargement")
    try:
        # The context manager releases the connection even if the download fails midway;
        # the timeout prevents the cell from hanging forever on a stalled server.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()

            # Total size for the progress bar (0 when the server sends no Content-Length).
            total_size = int(response.headers.get('content-length', 0))
            block_size = 1024 * 1024  # 1 MiB chunks keep the Python-level loop short

            # Download: both the file and the progress bar are closed on any exit path.
            with open(chemin_sauv, 'wb') as file, tqdm(
                total=total_size, unit='iB', unit_scale=True
            ) as bar_progression:
                for data in response.iter_content(block_size):
                    bar_progression.update(len(data))
                    file.write(data)
                bytes_received = bar_progression.n

        if total_size != 0 and bytes_received != total_size:
            print("ERREUR, quelque chose s'est mal passé pendant le téléchargement.")
            return

        print(f"Téléchargement terminé. Fichier sauvegardé sous : {chemin_sauv}")

        # Create the extraction directory if needed (no exists/makedirs race).
        os.makedirs(extract_path, exist_ok=True)

        # Unpack the ZIP archive.
        print(f"Décompression du fichier dans le dossier : {extract_path}")
        with zipfile.ZipFile(chemin_sauv, 'r') as zip_ref:
            zip_ref.extractall(extract_path)

        print("Décompression terminée.")

        # Remove the archive after extraction to save disk space.
        print(f"Suppression du fichier {chemin_sauv}...")
        os.remove(chemin_sauv)
        print("Opération terminée avec succès !")

    except requests.exceptions.RequestException as e:
        print(f"Une erreur de réseau est survenue: {e}")
    except zipfile.BadZipFile:
        print("Erreur: Le fichier téléchargé n'est pas un fichier ZIP valide.")
    except Exception as e:
        print(f"Une erreur inattendue est survenue: {e}")

In [5]:
# Direct download link of the dataset archive (Mendeley Data, dataset id "tywbtsjrjv", version 1).
URL = "https://data.mendeley.com/datasets/tywbtsjrjv/1/files/b4e3a32f-c0bd-4060-81e9-6144231f2520/file_downloaded"

In [6]:
# Directory (relative to the current working directory) the archive is extracted into.
extract_folder = "plant_village_dataset"

In [7]:
# Skip the ~1 GB download when the dataset folder is already on disk, so that
# "Restart & Run All" stays cheap. Delete the folder to force a fresh download
# (e.g. after an interrupted extraction).
if os.path.isdir(os.path.join(extract_folder, "Plant_leave_diseases_dataset_with_augmentation")):
    print(f"Dataset déjà présent dans {extract_folder}, téléchargement ignoré.")
else:
    telecharger_dezip(URL, "PlantVillage.zip", extract_folder)

 Début du téléchargement


100%|██████████| 949M/949M [00:44<00:00, 21.3MiB/s]  


Téléchargement terminé. Fichier sauvegardé sous : PlantVillage.zip
Décompression du fichier dans le dossier : plant_village_dataset
Décompression terminée.
Suppression du fichier PlantVillage.zip...
Opération terminée avec succès !


In [8]:
# Root folder holding one sub-directory per class. It is derived from `extract_folder`
# instead of a hard-coded "/workspace/..." path, so the notebook also works when the
# kernel is started from another directory.
path = os.path.join(os.path.abspath(extract_folder), "Plant_leave_diseases_dataset_with_augmentation")

In [9]:
IMG_SIZE = (224, 224)  # target size every image is resized to before being fed to a model
BATCH_SIZE = 32  # mini-batch size shared by the Keras generator and the PyTorch loaders

In [10]:
# Keras image generator that only multiplies pixel values by 1/255 (no augmentation).
data_gen=ImageDataGenerator(rescale=1./255)

In [11]:
# Batch iterator over the class sub-folders of `path`: images are resized to IMG_SIZE
# and labels are produced in "categorical" mode.
# NOTE(review): the PyTorch pipeline in the cells shown below does not use `data` —
# confirm whether this Keras iterator is still needed.
data=data_gen.flow_from_directory(
    path,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical"
)

Found 61486 images belonging to 39 classes.


# MODELISATION

## ResTS avec :

**Teacher:resnet50**

**Student:resnet18**

P, F. R. P., U, A. S., Moustafa, M. A., & Ali, M. A. S. (2023). Detecting plant disease in corn leaf using EfficientNet Architecture—An analytical approach. Electronics, 12(8), 1938. https://doi.org/10.3390/electronics12081938

In [12]:
import os
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn.functional as F

In [13]:
# --------- 1. Prepare the data ---------
# Only files with a known image extension are indexed, so stray files
# (e.g. .DS_Store, Thumbs.db) cannot crash Image.open() later on.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')

filepaths = []
labels = []
# os.listdir() returns entries in arbitrary order; sorting makes the DataFrame row
# order, and therefore the seeded train/val/test split, reproducible across machines.
folds = sorted(os.listdir(path))
for fold in folds:
    f_path = os.path.join(path, fold)
    if not os.path.isdir(f_path):
        continue  # each class is a sub-directory; skip loose files at the root
    for file in sorted(os.listdir(f_path)):
        if not file.lower().endswith(IMAGE_EXTENSIONS):
            continue
        filepaths.append(os.path.join(f_path, file))
        labels.append(fold)  # the folder name is the class label

df = pd.DataFrame({'filepaths': filepaths, 'labels': labels})
print(f"Total des images trouvées : {len(df)}")

Total des images trouvées : 61486


In [14]:
# Stratified split in two steps:
#   1. hold out 10% of all images as the test set;
#   2. split the remaining 90% into 80% train / 20% validation.
# Final proportions: 72% train / 18% validation / 10% test.
train_val_df, test_df = train_test_split(
    df,
    test_size=0.1,
    random_state=42,
    stratify=df['labels']
)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.2,
    random_state=42,
    stratify=train_val_df['labels']
)
# Every image must land in exactly one split.
assert len(train_df) + len(val_df) + len(test_df) == len(df)

In [15]:
# Map each class name to an integer index, following the sorted order of the names.
class_names = sorted(set(df['labels']))
num_classes = len(class_names)
class_to_idx = dict(zip(class_names, range(num_classes)))

In [16]:
# --------- 2. Custom dataset ---------
class CustomImageDataset(Dataset):
    """Dataset backed by a DataFrame with 'filepaths' and 'labels' columns.

    Each item is an ``(image, label_index)`` pair: the image is opened from disk,
    converted to RGB and, when a transform is supplied, passed through it.
    """

    def __init__(self, df, class_to_idx, transform=None):
        # A fresh 0..n-1 index lets __getitem__ address rows by integer position.
        self.df = df.reset_index(drop=True)
        self.class_to_idx = class_to_idx
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Load the image of row ``idx`` and return it with its class index."""
        file_location = self.df.loc[idx, 'filepaths']
        target = self.class_to_idx[self.df.loc[idx, 'labels']]

        picture = Image.open(file_location).convert('RGB')
        return (self.transform(picture) if self.transform else picture), target

In [17]:
# --------- 3. Data augmentation and loaders ---------
# ImageNet channel statistics, matching the ImageNet-pretrained ResNet backbones
# used below. Defined once so the train and validation pipelines cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random geometric and colour augmentation, then normalization.
# NOTE(review): Resize followed by RandomResizedCrop to the same size looks redundant —
# confirm it is intentional before removing either step.
train_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation/test pipeline: deterministic resize and normalization only.
val_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [18]:
# One dataset per split; only the training split gets the augmenting transforms.
train_dataset, val_dataset, test_dataset = (
    CustomImageDataset(split_df, class_to_idx, transform=split_transforms)
    for split_df, split_transforms in (
        (train_df, train_transforms),
        (val_df, val_transforms),
        (test_df, val_transforms),
    )
)

In [19]:
# One loader per split; only the training loader shuffles its samples.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=do_shuffle, num_workers=4)
    for split_dataset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Teacher & Student Models ----------
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
# IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
teacher.fc = nn.Linear(teacher.fc.in_features, num_classes)
teacher.to(device)
teacher.eval()  # the teacher is not trained
# NOTE(review): teacher.fc is a freshly initialised layer and nothing in the cells shown
# here fine-tunes the teacher on the target classes. As written, the soft targets it
# produces come from a random classification head — fine-tune the teacher (or load a
# fine-tuned checkpoint) before distilling.

student = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
student.fc = nn.Linear(student.fc.in_features, num_classes)
student.to(device)



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 273MB/s]


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 437MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [21]:
# ---------- Distillation Loss ----------
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=3.0):
    """Blend hard-label cross-entropy with a soft-target KL-divergence term.

    Args:
        student_logits: Raw class scores of the student, classes along dim 1.
        teacher_logits: Raw class scores of the teacher, same layout.
        labels: Ground-truth class indices.
        alpha: Weight of the cross-entropy term; the KL term gets ``1 - alpha``.
        temperature: Divisor applied to both sets of logits before softmax;
            larger values give softer distributions.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * T^2 * KL``.
    """
    hard_term = F.cross_entropy(student_logits, labels)

    # Softened distributions: log-probabilities for the student, probabilities for the teacher.
    soft_targets = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)

    # The KL term is multiplied by T^2, as in the original formulation.
    soft_term = F.kl_div(student_log_probs, soft_targets, reduction="batchmean")
    soft_term = soft_term * (temperature ** 2)

    return alpha * hard_term + (1 - alpha) * soft_term


In [22]:
# Training configuration for the student.
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
epochs = 20  # maximum number of epochs
patience = 5  # epochs without validation-accuracy improvement before stopping
best_val_acc = 0.0
epochs_no_improve = 0
best_model_path = "/workspace/models/best_model_resTS"
# torch.save() does not create missing parent directories; make sure the folder exists.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

In [23]:
# ---------- Training ----------
# Trains the student by distillation for at most `epochs` epochs, evaluates it on the
# validation loader after each epoch, saves the weights whenever validation accuracy
# improves, and stops after `patience` consecutive epochs without improvement.
# `best_val_acc` and `epochs_no_improve` are read from the notebook state and are NOT
# reset here: re-run the cell that initialises them before re-running this one.
for epoch in range(epochs):
    student.train()
    running_loss, running_corrects = 0.0, 0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # The teacher only provides targets; no gradients are tracked for it.
        with torch.no_grad():
            teacher_outputs = teacher(inputs)  # teacher outputs

        student_outputs = student(inputs)
        loss = distillation_loss(student_outputs, teacher_outputs, labels, alpha=0.5, temperature=3.0)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate sample-weighted loss and correct predictions for the epoch averages.
        preds = torch.argmax(student_outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()
    
    train_loss = running_loss / len(train_dataset)
    train_acc = running_corrects / len(train_dataset)
    
    # Validation
    # The validation loss is plain cross-entropy, whereas the training loss is the
    # distillation loss, so the two printed losses are not directly comparable.
    student.eval()
    val_loss, val_corrects = 0.0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = student(inputs)
            loss = F.cross_entropy(outputs, labels)
            
            preds = torch.argmax(outputs, 1)
            val_loss += loss.item() * inputs.size(0)
            val_corrects += torch.sum(preds == labels).item()
    
    val_loss /= len(val_dataset)
    val_acc = val_corrects / len(val_dataset)
    
    print(f"Epoch [{epoch+1}/{epochs}] - "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} - "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    # Early stopping: checkpoint on a strictly better validation accuracy,
    # otherwise count the epoch towards `patience`.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(student.state_dict(), best_model_path)
        print(">> Nouveau meilleur Student sauvegardé.")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(">> Early stopping déclenché.")
            break

print(f"Entraînement ResTS terminé. Meilleure accuracy Student (val) : {best_val_acc:.4f}")

Epoch [1/20] - Train Loss: 0.7861, Train Acc: 0.8039 - Val Loss: 0.5018, Val Acc: 0.9493
>> Nouveau meilleur Student sauvegardé.
Epoch [2/20] - Train Loss: 0.6084, Train Acc: 0.9029 - Val Loss: 0.4726, Val Acc: 0.9434
Epoch [3/20] - Train Loss: 0.5686, Train Acc: 0.9220 - Val Loss: 0.5041, Val Acc: 0.9501
>> Nouveau meilleur Student sauvegardé.
Epoch [4/20] - Train Loss: 0.5401, Train Acc: 0.9350 - Val Loss: 0.4012, Val Acc: 0.9736
>> Nouveau meilleur Student sauvegardé.
Epoch [5/20] - Train Loss: 0.5264, Train Acc: 0.9412 - Val Loss: 0.3614, Val Acc: 0.9793
>> Nouveau meilleur Student sauvegardé.
Epoch [6/20] - Train Loss: 0.5118, Train Acc: 0.9466 - Val Loss: 0.3725, Val Acc: 0.9810
>> Nouveau meilleur Student sauvegardé.
Epoch [7/20] - Train Loss: 0.5032, Train Acc: 0.9488 - Val Loss: 0.3310, Val Acc: 0.9877
>> Nouveau meilleur Student sauvegardé.
Epoch [8/20] - Train Loss: 0.4933, Train Acc: 0.9549 - Val Loss: 0.3433, Val Acc: 0.9827
Epoch [9/20] - Train Loss: 0.4861, Train Acc: 0.

#### Evaluation

In [35]:
import time, psutil, torch
from sklearn.metrics import classification_report

In [36]:
# Load the trained Student (ResTS).
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict holds, and avoids executing arbitrary pickled code from the checkpoint file.
student.load_state_dict(
    torch.load("/workspace/models/best_model_resTS", map_location=device, weights_only=True)
)
student.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [37]:
# Accumulators for the predicted and true class indices collected over the test set.
all_preds, all_labels = [], []

In [38]:
# Reset the accumulators here so that re-running this cell alone does not append a
# second copy of the predictions (which would silently corrupt the report).
all_preds, all_labels = [], []

start_time = time.time()
with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader):
        batch_start = time.time()
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = student(inputs)
        preds = torch.argmax(outputs, dim=1)
        # .tolist() stores plain Python ints rather than NumPy scalars.
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

        # --- Profiling ---
        # Batch time starts after the batch has been fetched, so it excludes data loading.
        cpu_usage = psutil.cpu_percent(interval=None)
        ram = psutil.virtual_memory()
        if torch.cuda.is_available():
            gpu_mem = torch.cuda.memory_allocated() / 1024**2
        else:
            gpu_mem = 0.0
        print(f"[Batch {i+1}] Time: {time.time()-batch_start:.2f}s | CPU: {cpu_usage:.1f}% | RAM: {ram.used/1024**3:.2f}GB | GPU: {gpu_mem:.2f}MB")

end_time = time.time()

[Batch 1] Time: 0.01s | CPU: 3.9% | RAM: 85.44GB | GPU: 299.57MB
[Batch 2] Time: 0.01s | CPU: 3.1% | RAM: 85.45GB | GPU: 299.57MB
[Batch 3] Time: 0.01s | CPU: 5.6% | RAM: 85.46GB | GPU: 299.57MB
[Batch 4] Time: 0.01s | CPU: 6.0% | RAM: 85.47GB | GPU: 299.57MB
[Batch 5] Time: 0.01s | CPU: 4.7% | RAM: 85.53GB | GPU: 299.57MB
[Batch 6] Time: 0.01s | CPU: 4.5% | RAM: 85.55GB | GPU: 299.57MB
[Batch 7] Time: 0.01s | CPU: 5.1% | RAM: 85.55GB | GPU: 299.57MB
[Batch 8] Time: 0.01s | CPU: 5.6% | RAM: 85.57GB | GPU: 299.57MB
[Batch 9] Time: 0.01s | CPU: 4.6% | RAM: 85.57GB | GPU: 299.57MB
[Batch 10] Time: 0.01s | CPU: 4.9% | RAM: 85.57GB | GPU: 299.57MB
[Batch 11] Time: 0.01s | CPU: 5.2% | RAM: 85.57GB | GPU: 299.57MB
[Batch 12] Time: 0.01s | CPU: 5.2% | RAM: 85.58GB | GPU: 299.57MB
[Batch 13] Time: 0.01s | CPU: 4.6% | RAM: 85.58GB | GPU: 299.57MB
[Batch 14] Time: 0.01s | CPU: 4.5% | RAM: 85.59GB | GPU: 299.57MB
[Batch 15] Time: 0.01s | CPU: 5.5% | RAM: 85.59GB | GPU: 299.57MB
[Batch 16] Time: 0.

In [39]:
# Wall-clock duration of the whole test pass and the resulting images-per-second rate.
total_time = end_time - start_time
print(
    f"\nTemps Test Total: {total_time:.2f} sec",
    f"Throughput: {len(test_dataset) / total_time:.2f} images/sec",
    sep="\n",
)


Temps Test Total: 8.56 sec
Throughput: 718.44 images/sec


In [40]:
# Full per-class report.
# Passing `labels` explicitly keeps `target_names` aligned with the class indices even
# if some class happens to be absent from both the test labels and the predictions.
print("=== Test Set Evaluation (Student) ===")
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
))

=== Test Set Evaluation (Student) ===
                                               precision    recall  f1-score   support

                           Apple___Apple_scab       1.00      0.97      0.98       100
                            Apple___Black_rot       0.99      1.00      1.00       100
                     Apple___Cedar_apple_rust       1.00      1.00      1.00       100
                              Apple___healthy       0.98      1.00      0.99       164
                    Background_without_leaves       0.97      0.99      0.98       114
                          Blueberry___healthy       1.00      1.00      1.00       150
                      Cherry___Powdery_mildew       1.00      1.00      1.00       105
                             Cherry___healthy       1.00      0.99      0.99       100
   Corn___Cercospora_leaf_spot Gray_leaf_spot       0.95      0.99      0.97       100
                           Corn___Common_rust       1.00      1.00      1.00       119
    