In [1]:
%load_ext autoreload
%autoreload 2

import warnings
from pathlib import Path

import albumentations as A
import lightning as L
import numpy as np
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.warnings import PossibleUserWarning

from leaf_disease.datasets.dataset import LeafImageDataModule, DataAugmentation
from leaf_disease.lit_models.lit_model import LitModel
from leaf_disease.models import Resnet, Efficientnet, ResnetSSL

warnings.filterwarnings("ignore", category=PossibleUserWarning)

In [2]:
# Display the preprocessing preset attached to torchvision's ResNet-50
# IMAGENET1K_V1 weights (crop size, resize size, mean/std, interpolation).
# The `resnet` module imported here is reused by the next cell.
from torchvision.models import resnet

resnet.ResNet50_Weights.IMAGENET1K_V1.transforms()

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [3]:
# Same inspection for the IMAGENET1K_V2 weights, to compare against V1.
resnet.ResNet50_Weights.IMAGENET1K_V2.transforms()

ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [4]:
# Albumentations pipelines. The data module further down is currently given
# the project's DataAugmentation object instead (see `transforms` below).

# Side length of the square crop. 224 is the `crop_size` reported by the
# torchvision weight presets inspected above.
image_size = 224

# ImageNet channel statistics. These are also A.Normalize's defaults; they are
# spelled out so the normalisation applied is visible in the notebook.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

train_transform = A.Compose(
    [
        A.RandomResizedCrop(height=image_size, width=image_size),
        # Scale pixels by 1 / max_pixel_value, then standardise each channel
        # with the mean/std above. The result is NOT confined to [0.0, 1.0].
        A.Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255),
        ToTensorV2(),  # HWC numpy image -> CHW torch tensor, i.e. [C, H, W]
    ]
)
valid_transform = A.Compose(
    [
        A.CenterCrop(height=image_size, width=image_size),
        A.Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255),
        ToTensorV2(),  # HWC numpy image -> CHW torch tensor, i.e. [C, H, W]
    ]
)

# Project-side augmentation object; this is what the data module receives.
transforms = DataAugmentation()
transforms

<leaf_disease.datasets.augmentation.DataAugmentation at 0x28135bed0>

In [5]:
# image_size = 512

# train_transform = A.Compose(
#     [
#         A.RandomResizedCrop(image_size, image_size),
#         A.Transpose(p=0.5),
#         A.HorizontalFlip(p=0.5),
#         A.VerticalFlip(p=0.5),
#         A.ShiftScaleRotate(p=0.5),
#         A.Normalize(
#             mean=[0.485, 0.456, 0.406],
#             std=[0.229, 0.224, 0.225],
#             max_pixel_value=255
#         ),
#         ToTensorV2(),   # HWC image -> CHW tensor, i.e. [C, H, W]
# ])

# valid_transform = A.Compose(
#     [ 
#         A.Resize(image_size, image_size),
#         A.Normalize(
#             mean=[0.485, 0.456, 0.406],
#             std=[0.229, 0.224, 0.225],
#             max_pixel_value=255
#         ),
#         ToTensorV2(),  # HWC image -> CHW tensor, i.e. [C, H, W]
#     ]
# )

In [6]:
# {
#     "resnet18": {"model": "resnet18", "weights": "ResNet18_Weights.IMAGENET1K_V1"},
#     "resnext50_32x4d_v1": {"model": "resnext50_32x4d", "weights": "ResNeXt50_32X4D_Weights.IMAGENET1K_V1"},
#     "resnext50_32x4d_v2": {"model": "resnext50_32x4d", "weights": "ResNeXt50_32X4D_Weights.IMAGENET1K_V2"}
# }

In [7]:
train_image_path = Path("input/train_images/")

# Data module reading images from `train_image_path`; it is handed the
# DataAugmentation object created earlier in the notebook.
dm = LeafImageDataModule(
    image_path=train_image_path,
    batch_size=32,
    num_workers=4,
    transforms=transforms,
    # Alternative: pass the Albumentations pipelines instead.
    # train_transform=train_transform,
    # test_transform=valid_transform,
)

# Backbone selection. `weights` names the torchvision weights enum member to
# load. Passing it explicitly (instead of "DEFAULT") pins the pretrained
# checkpoint, so the run does not change if torchvision moves its default.
# NOTE(review): assumes Resnet forwards `weights` unchanged to the torchvision
# builder, as the Efficientnet alternative below does -- confirm.
model_name = "resnext50_32x4d"
weights = "ResNeXt50_32X4D_Weights.IMAGENET1K_V2"

pytorch_model = Resnet(num_classes=5, model=model_name, weights=weights)
# pytorch_model

# Alternative backbones (other candidate name: resnext50_32x4d_ssl):
# model_name = "efficientnet_b0"
# weights = "EfficientNet_B0_Weights.IMAGENET1K_V1"
# pytorch_model = Efficientnet(num_classes=5, model=model_name, weights=weights)
# pytorch_model

# model_name = "resnet18"
# weights = "ResNet18_Weights.IMAGENET1K_V1"

Using cache found in /Users/cespeleta/.cache/torch/hub/pytorch_vision_main


In [8]:
# Wrap the backbone in the Lightning module together with the optimiser
# hyperparameters it is given.
lit_model = LitModel(
    pytorch_model=pytorch_model,
    learning_rate=1e-4,
    weight_decay=1e-6,
)

In [10]:
# Early stopping and checkpointing both watch the same metric, in the same
# direction ("max": a larger val_acc counts as an improvement).
monitor = "val_acc"
mode = "max"

# Stop once the monitored metric has failed to improve for 3 validation runs.
early_stopping = EarlyStopping(monitor=monitor, mode=mode, patience=3, verbose=True)

# Keep the best-scoring weights as "best" and also save the last state.
checkpoint = ModelCheckpoint(
    filename="best",
    monitor=monitor,
    mode=mode,
    save_last=True,
    verbose=True,
)

# Log the learning rate (and momentum) once per epoch.
lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)

callbacks = [early_stopping, checkpoint, lr_monitor]

In [11]:
# Fix the global seed before anything stochastic is constructed.
L.seed_everything(123)

# TensorBoard event files are written under logs/<model_name>/version_N/tf_logs.
tb_logger = TensorBoardLogger(save_dir="logs/", name=model_name, sub_dir="tf_logs")

trainer = L.Trainer(
    # Hardware
    accelerator="mps",
    devices=1,
    # Schedule
    max_epochs=10,
    num_sanity_val_steps=1,
    # Optimisation behaviour
    gradient_clip_val=0.1,  # from Kaggle Notebook
    deterministic=True,
    # Bookkeeping
    logger=tb_logger,
    callbacks=callbacks,
)

Global seed set to 123
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
# Train `lit_model` on the batches provided by the data module. The callbacks
# registered on the trainer (early stopping, checkpointing, LR logging) run
# as part of this call.
trainer.fit(model=lit_model, datamodule=dm)

len(train_image_paths)=17117
len(valid_image_paths)=4280



  | Name          | Type               | Params | In sizes       | Out sizes
----------------------------------------------------------------------------------
0 | pytorch_model | Resnet             | 23.0 M | [1, 3, 32, 32] | [1, 5]   
1 | train_acc     | MulticlassAccuracy | 0      | ?              | ?        
2 | valid_acc     | MulticlassAccuracy | 0      | ?              | ?        
3 | test_acc      | MulticlassAccuracy | 0      | ?              | ?        
----------------------------------------------------------------------------------
23.0 M    Trainable params
0         Non-trainable params
23.0 M    Total params
91.961    Total estimated model params size (MB)


                                                                           

  tp = tp.sum(dim=0 if multidim_average == "global" else 1)


Epoch 0: 100%|██████████| 534/534 [08:10<00:00,  1.09it/s, v_num=3, val_loss=0.499, val_acc=0.830, train_acc=0.778]

Metric val_acc improved. New best score: 0.830
Epoch 0, global step 534: 'val_acc' reached 0.83014 (best 0.83014), saving model to 'logs/resnext50_32x4d/version_3/checkpoints/best.ckpt' as top 1


Epoch 1: 100%|██████████| 534/534 [08:06<00:00,  1.10it/s, v_num=3, val_loss=0.465, val_acc=0.840, train_acc=0.848]

Metric val_acc improved by 0.010 >= min_delta = 0.0. New best score: 0.840
Epoch 1, global step 1068: 'val_acc' reached 0.84019 (best 0.84019), saving model to 'logs/resnext50_32x4d/version_3/checkpoints/best.ckpt' as top 1


Epoch 2: 100%|██████████| 534/534 [08:00<00:00,  1.11it/s, v_num=3, val_loss=0.469, val_acc=0.849, train_acc=0.864]

Metric val_acc improved by 0.008 >= min_delta = 0.0. New best score: 0.849
Epoch 2, global step 1602: 'val_acc' reached 0.84860 (best 0.84860), saving model to 'logs/resnext50_32x4d/version_3/checkpoints/best.ckpt' as top 1


Epoch 3: 100%|██████████| 534/534 [07:59<00:00,  1.11it/s, v_num=3, val_loss=0.474, val_acc=0.846, train_acc=0.881]

Epoch 3, global step 2136: 'val_acc' was not in top 1


Epoch 4: 100%|██████████| 534/534 [08:00<00:00,  1.11it/s, v_num=3, val_loss=0.461, val_acc=0.849, train_acc=0.894]

Metric val_acc improved by 0.000 >= min_delta = 0.0. New best score: 0.849
Epoch 4, global step 2670: 'val_acc' reached 0.84907 (best 0.84907), saving model to 'logs/resnext50_32x4d/version_3/checkpoints/best.ckpt' as top 1


Epoch 5: 100%|██████████| 534/534 [08:08<00:00,  1.09it/s, v_num=3, val_loss=0.465, val_acc=0.851, train_acc=0.908]

Metric val_acc improved by 0.002 >= min_delta = 0.0. New best score: 0.851
Epoch 5, global step 3204: 'val_acc' reached 0.85093 (best 0.85093), saving model to 'logs/resnext50_32x4d/version_3/checkpoints/best.ckpt' as top 1


Epoch 6:  73%|███████▎  | 390/534 [36:40<13:32,  5.64s/it, v_num=3, val_loss=0.465, val_acc=0.851, train_acc=0.908]  

In [None]:
# Best monitored score recorded by early stopping. The callback is fetched
# through the Trainer's accessor rather than by list position, so this keeps
# working if the order or number of callbacks changes.
trainer.early_stopping_callback.best_score

# Inference

In [None]:
# Run the data module's setup hook, then fetch the loader it exposes for
# prediction. `test_dl` is consumed by the prediction loop further down.
dm.setup()
test_dl = dm.predict_dataloader()

In [None]:
# Locate the best checkpoint through the Trainer's ModelCheckpoint accessor
# instead of relying on that callback's position in `trainer.callbacks`.
chk_path = trainer.checkpoint_callback.best_model_path
# chk_path = "logs/efficientnet_b0/version_12/checkpoints/best.ckpt"
print(chk_path)

# Load onto the CPU first so the file opens regardless of the device the
# tensors were saved from; the model is moved to the target device below.
chk = torch.load(chk_path, map_location="cpu")

# LitModel holds the network under the attribute `pytorch_model` (see the model
# summary printed by trainer.fit), so its parameters are stored under keys that
# start with "pytorch_model.". Strip exactly that leading prefix: a blanket
# str.replace("model.", "") would also match the "model." inside
# "pytorch_model." and any "model." occurring later in a key.
prefix = "pytorch_model."
model_weights = {
    key[len(prefix):]: tensor
    for key, tensor in chk["state_dict"].items()
    if key.startswith(prefix)
}

# Rebuild the architecture that was actually trained (same arguments as the
# training cell) so the checkpoint keys match. The pretrained weights loaded by
# the constructor are overwritten immediately by load_state_dict.
model = Resnet(num_classes=5, model=model_name, weights="DEFAULT")
model.load_state_dict(model_weights)
model.to("mps")
model.eval();  # trailing ';' hides the full module repr in the cell output

In [None]:
# Submission template; its `label` column is overwritten with the predictions.
test_df = pd.read_csv("input/sample_submission.csv")

# Collect the model output of every batch on the CPU. Gradients are disabled,
# so the tensors carry no autograd history and need no explicit detach.
final_preds = []
with torch.no_grad():
    for xb, _ in test_dl:
        batch_pred = model(xb.to("mps"))
        final_preds.append(batch_pred.cpu())

# Predicted class = index of the largest output in each row.
final_pred_class = torch.cat(final_preds).argmax(dim=1).numpy()

# Item assignment always targets the column; attribute assignment would only
# set a Python attribute if the column were missing.
test_df["label"] = final_pred_class
# test_df.to_csv("submission.csv", index=False)
test_df