In [1]:
# include
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import LambdaLR
from pathlib import Path
from pytorch_metric_learning import losses

from PDDD.Codes.modelpy.visual_model.ResNet_50_101_152 import ResNet152
from PDDD.Codes.modelpy.visual_model.ViT_L import VisionTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Config 
# ----------------------------
BATCH_SIZE = 60
# NUM_CLASSES = 27
BASE_LR = 0.003
WEIGHT_DECAY = 0.01
WARMUP_EPOCHS = 2 # Trong 2 epoch đầu, learning rate sẽ được tăng dần từ nhỏ đến BASE_LR
EPOCHS = 100
LAYER_DECAY = 0.8
ACCUM_GRAD_STEPS = 1 # Gradient accumulation: để mô phỏng batch size lớn hơn.
DATA_DIR = "/media/icnlab/Data/Manh/tinyML/FieldPlant-11/cropped"


In [3]:
# DataModule
# ----------------------------
class FilteredImageFolder(ImageFolder):
    def __init__(self, root, included_classes, **kwargs):
        super().__init__(root, **kwargs)
        # Lưu lại các chỉ số class cần giữ
        included_indices = [self.class_to_idx[cls] for cls in included_classes]
        self.samples = [s for s in self.samples if s[1] in included_indices]
        self.targets = [s[1] for s in self.samples]


class Dataset(pl.LightningDataModule):
    # init dataset, split, transform, dataloader
    def __init__(self, data_dir, batch_size, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        
        # Define separate transforms for training (with augmentation) and evaluation
        self.train_transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.RandomHorizontalFlip(p=0.5),  # Randomly flip images horizontally
            T.RandomRotation(degrees=15),    # Randomly rotate images by up to 15 degrees
            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),  # Random color adjustments
            T.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Random translations
            T.RandomPerspective(distortion_scale=0.2, p=0.5),  # Random perspective changes
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
        ])
        
        # Transform for validation and testing (no augmentation)
        self.eval_transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.416, 0.468, 0.355],   # normalize
                                std=[0.210, 0.206, 0.213])
        ])  # in PDDD paper
        
        # Ensure ratios sum to 1
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-5, "Ratios must sum to 1"
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

    def setup(self, stage=None):
        # Load the full dataset with eval transform initially
        self.full_dataset = ImageFolder(self.data_dir, transform=self.eval_transform)

        # bug
        # included = ['Tomato Brown Spots', 'Tomato blight leaf', 'Tomato healthy', 'Tomato leaf yellow virus']
        # self.full_dataset = FilteredImageFolder(self.data_dir, included, transform=self.eval_transform)
        
    # Calculate split sizes
        dataset_size = len(self.full_dataset)
        train_size = int(dataset_size * self.train_ratio)
        val_size = int(dataset_size * self.val_ratio)
        test_size = dataset_size - train_size - val_size
        
        # Split the dataset
        from torch.utils.data import random_split
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            self.full_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42)  # For reproducibility
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=8)
        
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=8)

In [4]:
# Layer-wise LR decay helper
# ----------------------------
def get_layer_decay_param_groups(model, base_lr, weight_decay, layer_decay):
    param_groups = []
    layers = list(model.named_parameters())
    num_layers = len(layers)

    for i, (name, param) in enumerate(layers):
        if not param.requires_grad:
            continue
        lr = base_lr * (layer_decay ** (num_layers - i - 1))
        param_groups.append({
            "params": [param],
            "lr": lr,
            "weight_decay": weight_decay if param.ndim >= 2 else 0.0
        })
    return param_groups

In [5]:
# Lightning Module
# ----------------------------
class EVA02Lightning(pl.LightningModule):
    def __init__(self, embedding_dim=2048, projection_dim=128):
        super().__init__()
        # Load the base model
        model = ResNet152()
        model_path = "/media/icnlab/Data/Manh/tinyML/PDDD/model/ResNet152.std"
        model.load_state_dict(torch.load(model_path, map_location='cpu'))

        for name, param in model.named_parameters():
            if name.startswith(('conv1', 'layer1', 'layer2', 'layer3')):
                param.requires_grad = False
            else:
                param.requires_grad = True
        
    # Add a projection head for contrastive learning
        self.backbone = nn.Sequential(*list(model.children())[:-1])  # Remove final classification layer
        
        self.projection_head = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )
        
        try:
            self.backbone = torch.compile(self.backbone)
        except Exception:
            pass  # torch.compile requires PyTorch 2+

        # Define contrastive loss
        self.loss_fn = losses.SupConLoss(temperature=0.07)

    def forward(self, x):
        # Get embeddings from backbone
        features = self.backbone(x)
        # Handle potential dimension issues
        if len(features.shape) > 2:
            features = features.squeeze()
        # Get projected embeddings
        projections = self.projection_head(features)
        # Normalize projections for contrastive loss
        normalized_projections = nn.functional.normalize(projections, dim=1)
        return features, normalized_projections

    def training_step(self, batch, batch_idx):
        x, y = batch
        
        # Get embeddings and normalized projections
        _, normalized_projections = self(x)
        
        # Calculate contrastive loss
        loss = self.loss_fn(normalized_projections, y)
        
        # Log metrics
        self.log("train_loss", loss)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        
        # For validation, we'll just log the contrastive loss
        _, normalized_projections = self(x)
        val_loss = self.loss_fn(normalized_projections, y)
        
        # Log metrics
        self.log("val_loss", val_loss, prog_bar=True)
        
        # Since we don't have a classifier, we can't calculate accuracy directly
        # We could implement a nearest-neighbor evaluation here if needed

    def configure_optimizers(self):
        # Get parameters with layer-wise decay for backbone
        param_groups = get_layer_decay_param_groups(self.backbone, BASE_LR, WEIGHT_DECAY, LAYER_DECAY)
        
        # Add projection head parameters
        param_groups.append({
            "params": self.projection_head.parameters(), 
            "lr": BASE_LR, 
            "weight_decay": WEIGHT_DECAY
        })
        
        optimizer = torch.optim.AdamW(param_groups, eps=1e-6)

        def lr_schedule_fn(current_step):
            if current_step < WARMUP_EPOCHS:
                return float(current_step) / float(max(1, WARMUP_EPOCHS))
            else:
                progress = float(current_step - WARMUP_EPOCHS) / float(max(1, EPOCHS - WARMUP_EPOCHS))
                return max(0.0, 0.5 * (1.0 + torch.cos(torch.tensor(progress * 3.1415926535))))

        scheduler = {
            "scheduler": LambdaLR(optimizer, lr_lambda=lr_schedule_fn),
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [scheduler]

In [6]:
# Train
# ----------------------------
if __name__ == "__main__":
    import time
    from datetime import timedelta
    import os
    
    # Create checkpoints directory if it doesn't exist
    os.makedirs("checkpoints", exist_ok=True)
    
    pl.seed_everything(42)
    data = Dataset(DATA_DIR, BATCH_SIZE)
    model = EVA02Lightning()
    
    # Define checkpoint callback
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        filename="eva02-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,  # Save the 1 best models
        save_last=True,  # Additionally save the last model
        verbose=True,
        auto_insert_metric_name=False
    )
    
    # Add early stopping callback (optional)
    early_stop_callback = pl.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,  # Stop if no improvement for 5 epochs
        mode="min",
        verbose=True
    )
    
    # Add timer callback to track training time
    class TimingCallback(pl.Callback):
        def __init__(self):
            super().__init__()
            self.epoch_start_time = None
            self.training_start_time = None
            self.epoch_times = []
        
        def on_train_start(self, trainer, pl_module):
            self.training_start_time = time.time()
            print(f"Training started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
        
        def on_train_epoch_start(self, trainer, pl_module):
            self.epoch_start_time = time.time()
            if trainer.current_epoch > 0 and len(self.epoch_times) > 0:
                avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
                remaining_epochs = trainer.max_epochs - trainer.current_epoch
                est_remaining_time = avg_epoch_time * remaining_epochs
                print(f"Estimated time remaining: {timedelta(seconds=int(est_remaining_time))}")
        
    def on_train_epoch_end(self, trainer, pl_module):
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - self.epoch_start_time
        self.epoch_times.append(epoch_time)

        print(f"Epoch {trainer.current_epoch} completed in: {timedelta(seconds=int(epoch_time))}")

        if len(self.epoch_times) > 0:
            avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
            remaining_epochs = trainer.max_epochs - trainer.current_epoch - 1
            est_remaining_time = avg_epoch_time * remaining_epochs

            print(f"Average epoch time: {timedelta(seconds=int(avg_epoch_time))}")
            print(f"Estimated time remaining: {timedelta(seconds=int(est_remaining_time))}")
        
        def on_train_end(self, trainer, pl_module):
            total_time = time.time() - self.training_start_time
            print(f"\nTraining completed in: {timedelta(seconds=int(total_time))}")
            print(f"Finished at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    
    # Create the timing callback
    timing_callback = TimingCallback()

    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        accumulate_grad_batches=ACCUM_GRAD_STEPS,
        precision="16-mixed",
        gradient_clip_val=1.0,
        accelerator="gpu",
        devices="auto",  # multi-GPU nếu có
        # strategy="ddp_find_unused_parameters_false", 
        log_every_n_steps=10,
        callbacks=[checkpoint_callback, early_stop_callback, timing_callback],  # Add callbacks here
    )

    trainer.fit(model, datamodule=data, 
    # ckpt_path='/media/icnlab/Data/Manh/tinyML/checkpoints/eva02-07-2.61-0.00.ckpt'
    )
    
    # Print path to best model checkpoint
    print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")
    print(f"Best model score: {checkpoint_callback.best_model_score:.4f}")


Seed set to 42
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-05-29 01:41:37.152984: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748457697.165490  282514 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748457697.169352  282514 cuda_blas.cc:1407] Unable to register cuBLAS 

Training started at: 2025-05-29 01:42:04                                   
Epoch 0: 100%|██████████| 33/33 [01:58<00:00,  0.28it/s, v_num=18, val_loss=4.230]

Metric val_loss improved. New best score: 4.230
Epoch 0, global step 33: 'val_loss' reached 4.22982 (best 4.22982), saving model to '/media/icnlab/Data/Manh/tinyML/checkpoints/eva02-00-4.23-0.00.ckpt' as top 1


Epoch 1: 100%|██████████| 33/33 [00:09<00:00,  3.47it/s, v_num=18, val_loss=3.760]

Metric val_loss improved by 0.468 >= min_delta = 0.0. New best score: 3.762
Epoch 1, global step 66: 'val_loss' reached 3.76182 (best 3.76182), saving model to '/media/icnlab/Data/Manh/tinyML/checkpoints/eva02-01-3.76-0.00.ckpt' as top 1


Epoch 2: 100%|██████████| 33/33 [00:09<00:00,  3.49it/s, v_num=18, val_loss=3.760]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 3.757
Epoch 2, global step 99: 'val_loss' reached 3.75687 (best 3.75687), saving model to '/media/icnlab/Data/Manh/tinyML/checkpoints/eva02-02-3.76-0.00.ckpt' as top 1


Epoch 3: 100%|██████████| 33/33 [00:09<00:00,  3.47it/s, v_num=18, val_loss=3.740]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 3.736
Epoch 3, global step 132: 'val_loss' reached 3.73562 (best 3.73562), saving model to '/media/icnlab/Data/Manh/tinyML/checkpoints/eva02-03-3.74-0.00.ckpt' as top 1


Epoch 4: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s, v_num=18, val_loss=3.720]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 3.717
Epoch 4, global step 165: 'val_loss' reached 3.71676 (best 3.71676), saving model to '/media/icnlab/Data/Manh/tinyML/checkpoints/eva02-04-3.72-0.00.ckpt' as top 1


Epoch 5: 100%|██████████| 33/33 [00:09<00:00,  3.38it/s, v_num=18, val_loss=3.770]

Epoch 5, global step 198: 'val_loss' was not in top 1


Epoch 6: 100%|██████████| 33/33 [00:09<00:00,  3.38it/s, v_num=18, val_loss=3.770]

Epoch 6, global step 231: 'val_loss' was not in top 1


Epoch 7: 100%|██████████| 33/33 [00:09<00:00,  3.34it/s, v_num=18, val_loss=3.830]

Epoch 7, global step 264: 'val_loss' was not in top 1


Epoch 8: 100%|██████████| 33/33 [00:10<00:00,  3.29it/s, v_num=18, val_loss=3.950]

Epoch 8, global step 297: 'val_loss' was not in top 1


Epoch 9: 100%|██████████| 33/33 [00:10<00:00,  3.27it/s, v_num=18, val_loss=4.040]

Monitored metric val_loss did not improve in the last 5 records. Best score: 3.717. Signaling Trainer to stop.
Epoch 9, global step 330: 'val_loss' was not in top 1


Epoch 9: 100%|██████████| 33/33 [00:12<00:00,  2.58it/s, v_num=18, val_loss=4.040]
Best model checkpoint: /media/icnlab/Data/Manh/tinyML/checkpoints/eva02-04-3.72-0.00.ckpt
Best model score: 3.7168


In [None]:
best_model = EVA02Lightning.load_from_checkpoint(checkpoint_callback.best_model_path)