In [11]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torch.utils.data import Dataset
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from typing import Optional
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchvision.transforms as T
from torch.utils.data import Dataset
import os
from collections import Counter
import torch.nn as nn
from pytorch_lightning.loggers import TensorBoardLogger
import tensorboard



In [12]:
# 1) Cài vào đúng Python của kernel hiện tại
import sys
!{sys.executable} -m pip install --upgrade pip
!{sys.executable} -m pip install tensorboard

# 2) Kiểm tra cài thành công
import importlib, pkgutil
spec = importlib.util.find_spec("tensorboard")
print("tensorboard found:", spec is not None)
if spec:
    import tensorboard
    print("tensorboard version:", tensorboard.__version__)


tensorboard found: True
tensorboard version: 2.20.0


In [13]:
try:
    import timm
    _USE_TIMM = True
except ImportError:
    _USE_TIMM = False


In [14]:



class PlantDataset(Dataset):
    def __init__(self, root_dir='/Users/braly/Desktop/lmvh/plant-identify/dataset',
                 split='train',
                 transform=None,
                 extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
        """
        Args:
            root_dir (str): đường dẫn tới thư mục chứa train/val/test
            split (str): 'train' | 'val' | 'test'
            transform (callable, optional): torchvision transforms or any callable applied lên PIL image
            extensions (tuple): các hậu tố file ảnh chấp nhận
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.extensions = tuple(e.lower() for e in extensions)

        self.images = []   # danh sách đường dẫn ảnh
        self.labels = []   # label dưới dạng index
        self.classes = []  # tên lớp (sorted)
        self.class_to_idx = {}
        self.idx_to_class = {}

        split_dir = os.path.join(self.root_dir, self.split)
        if not os.path.isdir(split_dir):
            raise ValueError(f"Split folder not found: {split_dir}")

        # Lấy danh sách lớp (thư mục con) và map sang index
        classes = [d for d in os.listdir(split_dir) if os.path.isdir(os.path.join(split_dir, d))]
        classes = sorted(classes)
        if len(classes) == 0:
            raise ValueError(f"No class subfolders found in {split_dir}")

        self.classes = classes
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

        # Duyệt từng thư mục lớp và thu thập ảnh
        for cls_name in self.classes:
            cls_dir = os.path.join(split_dir, cls_name)
            for root, _, files in os.walk(cls_dir):
                for fname in files:
                    if fname.lower().endswith(self.extensions):
                        path = os.path.join(root, fname)
                        self.images.append(path)
                        self.labels.append(self.class_to_idx[cls_name])

        if len(self.images) == 0:
            raise ValueError(f"No images found in {split_dir} with extensions {self.extensions}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """
        Trả về: (image, label_index)
        - image: PIL.Image (nếu transform None) hoặc transform(image)
        - label_index: int (index của lớp)
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # Mở ảnh an toàn
        with open(img_path, 'rb') as f:
            image = Image.open(f).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    # Tiện ích: trả về số ảnh / lớp
    def get_class_counts(self):
        """Trả về dict: {class_name: count}"""
        counts = Counter()
        for lbl in self.labels:
            counts[self.idx_to_class[lbl]] += 1
        return dict(counts)

    def print_stats(self):
        """In thông tin tóm tắt dataset"""
        total = len(self)
        counts = self.get_class_counts()
        print(f"Dataset split: {self.split}")
        print(f"Root dir: {self.root_dir}")
        print(f"Total images: {total}")
        print("Number of classes:", len(self.classes))
        print("Class -> index mapping:")
        for cls, idx in self.class_to_idx.items():
            print(f"  {cls:20s} -> {idx:3d} ({counts.get(cls,0)} images)")


In [15]:
from torchvision import transforms
from torch.utils.data import DataLoader

root = '/Users/braly/Desktop/lmvh/plant-identify/dataset'

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

train_ds = PlantDataset(root_dir=root, split='train', transform=transform)
val_ds   = PlantDataset(root_dir=root, split='valid', transform=transform)
test_ds  = PlantDataset(root_dir=root, split='test', transform=transform)

train_ds.print_stats()
print("Total train images:", len(train_ds))

# Lấy 1 mẫu
img, label = train_ds[0]
print(type(img), label)  # img là Tensor nếu transform -> ToTensor, label là int (index)

# Dùng DataLoader
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)


Dataset split: train
Root dir: /Users/braly/Desktop/lmvh/plant-identify/dataset
Total images: 20684
Number of classes: 47
Class -> index mapping:
  African Violet (Saintpaulia ionantha) ->   0 (478 images)
  Aloe Vera            ->   1 (366 images)
  Anthurium (Anthurium andraeanum) ->   2 (644 images)
  Areca Palm (Dypsis lutescens) ->   3 (258 images)
  Asparagus Fern (Asparagus setaceus) ->   4 (218 images)
  Begonia (Begonia spp.) ->   5 (312 images)
  Bird of Paradise (Strelitzia reginae) ->   6 (254 images)
  Birds Nest Fern (Asplenium nidus) ->   7 (402 images)
  Boston Fern (Nephrolepis exaltata) ->   8 (420 images)
  Calathea             ->   9 (448 images)
  Cast Iron Plant (Aspidistra elatior) ->  10 (366 images)
  Chinese Money Plant (Pilea peperomioides) ->  11 (530 images)
  Chinese evergreen (Aglaonema) ->  12 (734 images)
  Christmas Cactus (Schlumbergera bridgesii) ->  13 (418 images)
  Chrysanthemum        ->  14 (288 images)
  Ctenanthe            ->  15 (510 images)

In [16]:
class PlantDataModule(pl.LightningDataModule):
    def __init__(self,
                 root_dir: str,
                 image_size: int = 224,
                 batch_size: int = 32,
                 num_workers: int = 4,
                 pin_memory: bool = True):
        super().__init__()
        self.root_dir = root_dir
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # transforms
        self.train_transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.RandomHorizontalFlip(),
            T.RandomRotation(10),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        self.val_transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

        # placeholders set in setup()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.num_classes = None

    def setup(self, stage: Optional[str] = None):
        # Called on every GPU in DDP — keep idempotent
        if stage in (None, 'fit'):
            self.train_dataset = PlantDataset(self.root_dir, split='train', transform=self.train_transform)
            self.val_dataset = PlantDataset(self.root_dir, split='valid', transform=self.val_transform)
            self.num_classes = len(self.train_dataset.classes)
        if stage in (None, 'test'):
            self.test_dataset = PlantDataset(self.root_dir, split='test', transform=self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory)

    def test_dataloader(self):
        if self.test_dataset is None:
            return None
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory)

In [17]:
class ViTLightning(pl.LightningModule):
    def __init__(self,
                 num_classes: int,
                 lr: float = 3e-4,
                 weight_decay: float = 1e-2,
                 backbone_name: str = "vit_base_patch16_224",
                 pretrained: bool = True,
                 freeze_backbone: bool = False):
        """
        If timm is available, use timm.create_model(backbone_name, pretrained=True, num_classes=num_classes).
        Otherwise try torchvision's vit_b_16 (if installed).
        """
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.backbone_name = backbone_name
        self.pretrained = pretrained
        self.freeze_backbone = freeze_backbone

        # Build model
        if _USE_TIMM:
            # timm handles classifier creation
            self.model = timm.create_model(self.backbone_name, pretrained=self.pretrained, num_classes=self.num_classes)
        else:
            # fallback to torchvision ViT if available
            try:
                from torchvision import models as tv_models
                vit_builder = getattr(tv_models, "vit_b_16", None)
                if vit_builder is None:
                    raise RuntimeError("torchvision ViT not available; please install timm.")
                # torchvision vit builder signatures vary; try to create without classifier then add head
                backbone = vit_builder(weights="IMAGENET1K_V1") if hasattr(vit_builder, '__call__') else vit_builder(pretrained=self.pretrained)
                # remove existing head if present
                if hasattr(backbone, 'heads'):
                    feat_dim = backbone.heads.head.in_features if hasattr(backbone.heads, 'head') else getattr(backbone, 'hidden_dim', 768)
                    backbone.heads = nn.Identity()
                elif hasattr(backbone, 'head'):
                    feat_dim = backbone.head.in_features
                    backbone.head = nn.Identity()
                else:
                    feat_dim = getattr(backbone, 'hidden_dim', 768)
                # create classifier
                head = nn.Linear(feat_dim, self.num_classes)
                self.model = nn.Sequential(backbone, head)
            except Exception as e:
                raise RuntimeError("No ViT backbone available. Install timm or use recent torchvision.") from e

        # optionally freeze backbone parameters for fine-tuning head only
        if self.freeze_backbone:
            for name, p in self.model.named_parameters():
                if "head" not in name and "heads" not in name and "classifier" not in name:
                    p.requires_grad = False

        # loss + metrics
        self.criterion = nn.CrossEntropyLoss()
        # use torchmetrics for metrics if available
        try:
            import torchmetrics
            self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)
            self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)
        except Exception:
            # fallback simple trackers
            self.train_acc = None
            self.val_acc = None

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)

        # log loss
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False)
        if self.train_acc is not None:
            acc = self.train_acc(preds, y)
            self.log("train/acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        else:
            # rough acc
            acc = (preds == y).float().mean()
            self.log("train/acc", acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)

        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        if self.val_acc is not None:
            acc = self.val_acc(preds, y)
            self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        else:
            acc = (preds == y).float().mean()
            self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        if self.val_acc is not None:
            acc = self.val_acc(preds, y)
            self.log("test/acc", acc, on_step=False, on_epoch=True)
        else:
            acc = (preds == y).float().mean()
            self.log("test/acc", acc, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # small scheduler example (cosine)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss"}}



In [18]:
import os
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt

# Callback lưu loss/acc mỗi epoch để sau đó vẽ
class LossHistory(pl.Callback):
    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []
        self.val_accs = []

    def _get_metric(self, trainer, keys):
        """Trả về float hoặc None — keys là list các key thử (khác PL versions có key khác nhau)."""
        for k in keys:
            if k in trainer.callback_metrics:
                v = trainer.callback_metrics[k]
                try:
                    return float(v)
                except Exception:
                    return float(v.item())
        return None

    def on_validation_epoch_end(self, trainer, pl_module):
        # Thường trainer.callback_metrics sẽ chứa "train/loss" (on_epoch=True) và "val/loss"
        tr_loss = self._get_metric(trainer, ["train/loss", "train_loss", "train/loss_epoch", "train_loss_epoch"])
        v_loss = self._get_metric(trainer, ["val/loss", "val_loss", "val/loss_epoch", "val_loss_epoch"])
        v_acc  = self._get_metric(trainer, ["val/acc", "val_acc", "val/acc_epoch", "val_acc_epoch"])

        # Append only when available (safety)
        if tr_loss is not None:
            self.train_losses.append(tr_loss)
        if v_loss is not None:
            self.val_losses.append(v_loss)
        if v_acc is not None:
            self.val_accs.append(v_acc)


In [None]:
# Thay đường dẫn dataset của bạn ở đây
data_dir = "/Users/braly/Desktop/lmvh/plant-identify/dataset"

# hyperparams
image_size = 224
batch_size = 32
num_workers = 4
max_epochs = 10   # chỉnh tuỳ ý
precision = 16 if torch.cuda.is_available() else 32
gpus = 1 if torch.cuda.is_available() else 0

# Nếu bạn đã có PlantDataModule và ViTLightning trong notebook, dùng trực tiếp:
dm = PlantDataModule(root_dir=data_dir,
                     image_size=image_size,
                     batch_size=batch_size,
                     num_workers=num_workers)
dm.setup('fit')
print("Num classes:", dm.num_classes)

model = ViTLightning(num_classes=dm.num_classes,
                     lr=3e-4,
                     weight_decay=1e-2,
                     backbone_name="vit_base_patch16_224",
                     pretrained=True,
                     freeze_backbone=False)

# Logger + callbacks
logger = TensorBoardLogger(save_dir="lightning_logs", name="vit_plants_notebook")
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints",
    monitor="val/acc",     # nếu bạn muốn monitor val/loss thay "val/acc"
    mode="max",
    save_top_k=3,
    filename="vit-{epoch:02d}-{val/acc:.4f}"
)
lr_monitor = LearningRateMonitor(logging_interval='step')
loss_hist_cb = LossHistory()
devices_for_trainer = gpus if gpus > 0 else 1
# Trainer (an toàn trong notebook: devices=None nếu no GPU)
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator="gpu" if gpus > 0 else "cpu",
    devices=gpus if gpus > 0 else 1,  # <-- sửa None thành 1
    precision=precision,
    callbacks=[checkpoint_cb, lr_monitor, loss_hist_cb],
    logger=logger,
    log_every_n_steps=50,
    enable_progress_bar=True,
    deterministic=True,
)

# Start training
trainer.fit(model, datamodule=dm)


Num classes: 47


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | VisionTransformer  | 85.8 M | train
1 | criterion | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.339   Total estimated model params size (MB)
279       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
Traceback (most recent call last):
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    from multiprocessing.spawn import spawn_main; [31mspawn_main[0m[1;31m(tracker_fd=79, pipe_handle=96)[0m
                                                  [31m~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m122[0m, in [35mspawn_main[0m
    exitcode = _main(fd, parent_sentinel)
  File [35m"/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m132[0m, in [35m_main[0m
    self = reduction.pickle.load(from_parent)
[1;35mAttributeError[0m: [35mCan

In [None]:
# Lưu checkpoint (trainer đã tự lưu theo callback ModelCheckpoint).
print("Best checkpoint path:", checkpoint_cb.best_model_path)

# Lưu state_dict final của LightningModule (dùng để load bằng model.load_state_dict)
os.makedirs("saved_models", exist_ok=True)
final_path = "saved_models/vit_plants_final_state_dict.pth"
torch.save(model.state_dict(), final_path)
print("Saved state_dict to:", final_path)

# Nếu muốn lưu toàn bộ checkpoint (bao gồm optimizer state) dùng:
trainer.save_checkpoint("saved_models/vit_plants_full_checkpoint.ckpt")
print("Saved full checkpoint to saved_models/vit_plants_full_checkpoint.ckpt")
