In [152]:
import numpy as np
import h5py
import matplotlib.pyplot as plt
import scipy
import torch
import torch.nn as nn
from PIL import Image
from scipy import ndimage
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
import pytorch_lightning as pl
import flash 
from flash.image import ImageClassifier
from flash.core.data.data_module import DataModule

%matplotlib inline
np.random.seed(1)

In [153]:
# Sanity check of the axis reordering used below: the trailing axis of a
# (10, 5, 3) array is moved to the front, giving shape (3, 10, 5).
arr = np.transpose(np.random.randn(10, 5, 3), axes=(2, 0, 1))
arr.shape

(3, 10, 5)

In [154]:
# Loading the data (signs)
def get_imgs_labels(h5_file_path):
    """Read images, labels and the class list from an HDF5 file.

    Datasets are selected by their position in the file's key listing:
    position 0 is returned as the class list, position 1 as the images and
    position 2 as the labels.
    NOTE(review): this relies on the key order of the file (presumably
    ``list_classes``, ``*_set_x``, ``*_set_y``) -- confirm against the data.

    Args:
        h5_file_path: Path of the HDF5 file to open read-only.

    Returns:
        Tuple ``(imgs, labels, list_classes)`` of NumPy arrays. ``imgs`` is
        returned with its axes permuted by ``(0, 3, 1, 2)``, i.e. the last
        stored axis becomes axis 1.
    """
    # Everything is copied into in-memory arrays inside the `with` block, so
    # the file handle is released on exit (also when a read raises) instead of
    # staying open for the lifetime of the kernel.
    with h5py.File(h5_file_path, "r") as f:
        ds_keys = list(f.keys())
        imgs = np.array(f[ds_keys[1]])
        labels = np.array(f[ds_keys[2]])
        list_classes = np.array(f[ds_keys[0]])
    imgs = np.transpose(imgs, (0, 3, 1, 2))
    return imgs, labels, list_classes

# Load the training and test splits, then report the image / label array
# shapes of each split (training first, test second).
train_x, train_y, train_classes = get_imgs_labels("./datasets/train_signs.h5")
test_x, test_y, test_classes = get_imgs_labels("./datasets/test_signs.h5")
for split_x, split_y in ((train_x, train_y), (test_x, test_y)):
    print(split_x.shape, split_y.shape)

(1080, 3, 64, 64) (1080,)
(120, 3, 64, 64) (120,)


In [156]:
# CONSTANTS

# Number of stratified cross-validation folds (passed to get_skf_index).
NUM_FOLDS = 5
# Batch size handed to the Flash data module.
BATCH_SIZE = 64
# Worker count handed to the Flash data module.
NUM_WORKERS = 4

In [157]:
#img = Image.fromarray(np.uint8(test_x[0])).convert('RGB')
#img

In [159]:
from flash.core.data.data_source import DataSource, DefaultDataKeys
from torchvision.datasets.folder import make_dataset
from typing import Any, Dict, Iterable , Mapping, Sequence, Callable   

class SignsDataSource(DataSource):
    """Flash data source that reads image/label pairs from an HDF5 file."""

    def load_data(self, h5_file_path: str) -> Sequence[Mapping[str, Any]]:
        """Load every (image, label) pair of the file into memory.

        The image dataset is taken from position 1 and the label dataset from
        position 2 of the file's key listing.
        NOTE(review): relies on the key order of the file -- confirm it matches
        the datasets this is used with.

        Args:
            h5_file_path: Path of the HDF5 file to open read-only.

        Returns:
            One dict per sample, mapping ``DefaultDataKeys.INPUT`` to the image
            array and ``DefaultDataKeys.TARGET`` to its label.
        """
        # Copy both datasets into NumPy arrays inside the `with` block so the
        # file handle is closed before returning rather than leaked.
        with h5py.File(h5_file_path, "r") as f:
            ds_keys = list(f.keys())
            img_arr = np.array(f[ds_keys[1]])
            label_arr = np.array(f[ds_keys[2]])
        return [
            {
                DefaultDataKeys.INPUT: img,
                DefaultDataKeys.TARGET: label,
            }
            for img, label in zip(img_arr, label_arr)
        ]

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Replace the sample's input array with an RGB PIL image.

        The input is cast to ``uint8`` before conversion. The dict is modified
        in place and the same object is returned.
        """
        pixels = np.uint8(sample[DefaultDataKeys.INPUT])
        sample[DefaultDataKeys.INPUT] = Image.fromarray(pixels).convert('RGB')
        return sample


In [161]:
from typing import Optional
from flash.core.data.process import Preprocess
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources

class SignsPreprocess(Preprocess):
    """Flash ``Preprocess`` whose only registered data source is ``SignsDataSource``."""

    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
    ):
        """Forward the per-stage transforms and register the data source.

        ``SignsDataSource`` is registered under ``DefaultDataSources.FILES``,
        which is also made the default data source.

        Args:
            train_transform: Optional transform dict for the training stage.
            val_transform: Optional transform dict for the validation stage.
            test_transform: Optional transform dict for the test stage.
            predict_transform: Optional transform dict for the predict stage.
        """
        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FILES: SignsDataSource(),
            },
            default_data_source=DefaultDataSources.FILES,
        )

    def default_transforms(self) -> Dict[str, Callable]:
        """Return the transforms used when a stage has none supplied.

        Samples produced by ``SignsDataSource`` are dicts keyed by
        ``DefaultDataKeys``, so every transform is wrapped in ``ApplyToKeys``
        to act on one entry of the dict instead of on the dict itself
        (``ToTensor`` / ``Normalize`` cannot consume a dict).

        Returns:
            Dict with ``to_tensor_transform``, ``post_tensor_transform`` and
            ``collate`` entries.
        """
        # Imported here so the class works on its own, independent of which
        # other notebook cells have already been executed.
        from flash.core.data.transforms import ApplyToKeys

        return {
            "to_tensor_transform": nn.Sequential(
                ApplyToKeys(DefaultDataKeys.INPUT, transforms.ToTensor()),
                ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
            ),
            # Per-channel mean / std -- presumably the ImageNet statistics
            # matching a pretrained backbone; confirm.
            "post_tensor_transform": ApplyToKeys(
                DefaultDataKeys.INPUT,
                transforms.Normalize(
                    torch.tensor([0.485, 0.456, 0.406]),
                    torch.tensor([0.229, 0.224, 0.225]),
                ),
            ),
            "collate": torch.utils.data._utils.collate.default_collate,
        }

    def get_state_dict(self) -> Dict[str, Any]:
        """Return a shallow copy of ``self.transforms`` as the state."""
        return {**self.transforms}

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        """Build a new instance, passing the state entries as keyword arguments.

        ``strict`` is accepted but not used.
        """
        return cls(**state_dict)

In [163]:
# Given training data and labels as NumPy arrays, build a fold-index array with
# one entry per sample. Entry i holds the (1-based) number of the fold in which
# sample i is used for validation; in every other cross-validation iteration
# that sample belongs to the training portion (a typical ratio being 80:20).
def get_skf_index(num_folds, X, y):
    """Assign each sample to the stratified fold that validates on it.

    Args:
        num_folds: Number of splits passed to ``StratifiedKFold``.
        X: Samples, forwarded to ``StratifiedKFold.split``.
        y: Labels, forwarded to ``StratifiedKFold.split``; its length sets the
            length of the result.

    Returns:
        Float array of ``len(y)`` entries; positions in the validation part of
        the k-th split (counting from 1) hold ``k``. Positions never selected
        for validation keep the initial value 0.
    """
    splitter = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
    fold_of_sample = np.zeros(len(y))
    for fold_number, (_, held_out) in enumerate(splitter.split(X=X, y=y), start=1):
        fold_of_sample[held_out] = fold_number
    return fold_of_sample

# Fold assignment for every training sample, computed once with NUM_FOLDS splits.
k_folds = get_skf_index(num_folds=NUM_FOLDS, X=train_x, y=train_y)

In [165]:
def split_data(fold, kfolds, X, y):
    """Split samples and labels into training and validation parts for one fold.

    Args:
        fold: Zero-based fold number; it is compared against ``kfolds`` as
            ``fold + 1``.
        kfolds: Per-sample fold assignment, aligned with ``X`` and ``y``.
        X: Samples, indexed along the first axis.
        y: Labels, indexed along the first axis.

    Returns:
        ``(train_X, train_y, val_X, val_y)`` where the validation part holds
        the entries whose ``kfolds`` value equals ``fold + 1`` and the training
        part holds the entries whose value differs from it.
    """
    fold_id = fold + 1
    in_train = kfolds != fold_id
    in_val = kfolds == fold_id
    return X[in_train], y[in_train], X[in_val], y[in_val]

In [166]:
from flash.image.classification.transforms import default_transforms
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.data_source import DefaultDataKeys
from flash.image import ImageClassificationData

# Build the Flash data module for a single cross-validation fold.
#
# The fold subsets get their own names on purpose: unpacking the labels into
# `train_y` would overwrite the full label array loaded from the HDF5 file, so
# a second run of this cell (or a later fold) would index the shortened labels
# with the full-length fold mask and fail.
VAL_FOLD = 0  # zero-based fold held out for validation
fold_train_X, fold_train_y, fold_val_X, fold_val_y = split_data(
    VAL_FOLD, k_folds, train_x, train_y
)

# Transforms operate on sample dicts, hence the ApplyToKeys wrappers: inputs are
# converted to tensors and normalized per channel, targets are made tensors.
signs_default_transform = {
    "to_tensor_transform": nn.Sequential(
        ApplyToKeys(DefaultDataKeys.INPUT, transforms.ToTensor()),
        ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
    ),
    "post_tensor_transform": ApplyToKeys(
        DefaultDataKeys.INPUT,
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ),
    "collate": torch.utils.data._utils.collate.default_collate,
}

data_module = ImageClassificationData.from_numpy(
    train_data=fold_train_X,
    train_targets=fold_train_y,
    val_data=fold_val_X,
    val_targets=fold_val_y,
    #test_data = test_x,
    #test_targets = test_y,
    train_transform=signs_default_transform,
    val_transform=signs_default_transform,
    #test_transform = signs_default_transform,
    #predict_transform = signs_default_transform,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [148]:
# Peek at validation sample 10: print its target, then display the type of its input.
item = data_module.val_dataset[10]
target, image = item[DefaultDataKeys.TARGET], item[DefaultDataKeys.INPUT]
print(target)
type(image)

3


PIL.Image.Image

In [177]:
# Label of test sample 10. Note this indexes the test split, not `data_module.val_dataset`.
test_y[10]

5

In [172]:
from flash.image import ImageClassifier

# Seed the Python, NumPy and PyTorch RNGs before the model is created so the
# fine-tuning run gives the same result on Restart & Run All. Only NumPy was
# seeded at the top of the notebook.
pl.seed_everything(42)

# ResNet-18 backbone with a 6-class head.
model = ImageClassifier(backbone="resnet18", num_classes=6)

# Use every visible GPU (0 means CPU) and fine-tune with the "freeze" strategy.
trainer = flash.Trainer(max_epochs=10, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=data_module, strategy="freeze")


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type           | Params
-------------------------------------------------
0 | train_metrics | ModuleDict     | 0     
1 | val_metrics   | ModuleDict     | 0     
2 | adapter       | DefaultAdapter | 11.2 M
-------------------------------------------------
12.7 K    Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.718    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [179]:
# Predict the class of test image 110, passed as a one-element list via the NumPy data source.
model.predict([test_x[110]], data_source=DefaultDataSources.NUMPY)

[0]