In [None]:
# Install required libs
!pip install -U segmentation-models-pytorch --user 

In [None]:
!python --version

In [None]:
# !pip uninstall -y segmentation-models-pytorch

## Loading data

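For this example we will use the **CamVid** dataset. It is a set of:
 - **train** images + segmentation masks
 - **validation** images + segmentation masks
 - **test** images + segmentation masks

All images are 320 pixels high and 480 pixels wide.
For more information about the dataset, visit http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/.

> **Note:** the description above comes from the original CamVid example. The code below actually loads the dataset in `DATA_DIR` (classes `berry`, `leaf`, `stem`, `flower`), splits it 80/20 into train and validation only (there is no test split), and resizes every image to 512×512.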

In [None]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import shutil
import sklearn
from glob import glob
from sklearn.model_selection import train_test_split

from distutils.dir_util import copy_tree

In [None]:
torch.cuda.empty_cache()

In [None]:
DATA_DIR = '../input/agrocodetransformed/transformed'
NB_CLASSES = 4
MAX_N_SAMPLES = 10 ** 9
ids = os.listdir(DATA_DIR)

In [None]:
train_ids, val_ids = train_test_split(ids, test_size=0.2)

In [None]:
len(train_ids), len(val_ids)

In [None]:
!mkdir train
!mkdir val
!mkdir train_labels
!mkdir val_labels

In [None]:
for sample_id in train_ids:
    img_pth = glob(os.path.join(DATA_DIR, sample_id, "images", "*"))[0]
    shutil.copy(img_pth, f"train")
    
    label_dir_pth = os.path.join(DATA_DIR, sample_id, "masks")
    copy_tree(label_dir_pth, f"train_labels/{sample_id}/")

In [None]:
for sample_id in val_ids:
    img_pth = glob(os.path.join(DATA_DIR, sample_id, "images", "*"))[0]
    shutil.copy(img_pth, f"val")
    
    label_dir_pth = os.path.join(DATA_DIR, sample_id, "masks")
    copy_tree(label_dir_pth, f"val_labels/{sample_id}/")

In [None]:
# Layout produced by the copy step: images in <split>/, one mask folder per sample in <split>_labels/
x_train_dir, y_train_dir = "train", "train_labels"
x_val_dir, y_val_dir = "val", "val_labels"

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot every keyword-argument image side by side in a single row.

    Each keyword becomes that panel's title, with underscores turned into
    spaces and title-cased (``ground_truth_mask`` -> "Ground Truth Mask").
    Axis ticks are hidden and the figure is shown immediately.
    """
    count = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset that caches resized images and multi-class masks in memory.

    Every file in ``images_dir`` is read once at construction time, converted
    from OpenCV's BGR order to RGB and resized to ``size x size``. Its masks are
    looked up in ``masks_dir/<obj_id>/``, where ``obj_id`` is the image file name
    up to its first ``.``. Every file matching ``*_class_<k>.png`` (``k`` starts
    at 1) is resized, binarized (any non-zero pixel counts) and merged into
    channel ``k - 1`` of a ``(size, size, NB_CLASSES)`` float32 mask.

    Augmentation and preprocessing run lazily in ``__getitem__``, so a random
    augmentation produces a different sample on every epoch.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to the folder holding one mask sub-folder per image
        augmentation (albumentations.Compose, optional): data transformation
            pipeline (e.g. flip, scale, etc.) applied to the raw RGB image/mask
        preprocessing (albumentations.Compose, optional): data preprocessing
            (e.g. normalization, shape manipulation, etc.). Defaults to
            ``get_preprocessing(smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS))``.
        size (int): side length that images and masks are resized to
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            augmentation=None,
            preprocessing=None,
            size=512,
    ):
        if preprocessing is None:
            # Same default the notebook always used; resolved here (not at class
            # definition) because ENCODER / ENCODER_WEIGHTS are defined in a later cell.
            preprocessing = get_preprocessing(
                smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
            )
        # Sorted so sample order (and the MAX_N_SAMPLES cut) is deterministic.
        self.images_files = sorted(glob(os.path.join(images_dir, "*")))[:MAX_N_SAMPLES]
        self.masks_dir = masks_dir
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.size = size

        self.images = []
        self.masks = []
        for image_pth in self.images_files:
            obj_id = os.path.basename(image_pth).split(".")[0]
            self.images.append(self._load_image(image_pth))
            self.masks.append(self._load_mask(obj_id))

    def _load_image(self, image_pth):
        """Read one image as an RGB uint8 array resized to ``(size, size, 3)``."""
        image = cv2.imread(image_pth)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_pth}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return cv2.resize(image, dsize=(self.size, self.size))

    def _load_mask(self, obj_id):
        """Build the ``(size, size, NB_CLASSES)`` 0/1 mask for one sample."""
        mask = np.zeros((self.size, self.size, NB_CLASSES), dtype=np.float32)
        for class_idx in range(NB_CLASSES):
            # class_idx in annotations starts with 1, so do + 1
            pattern = os.path.join(self.masks_dir, obj_id, f"*_class_{class_idx + 1}.png")
            for mask_pth in glob(pattern):
                mask_img = cv2.imread(mask_pth)
                if mask_img is None:
                    raise FileNotFoundError(f"Could not read mask: {mask_pth}")
                curr_msk = cv2.resize(mask_img, dsize=(self.size, self.size)).mean(axis=2) > 0
                mask[:, :, class_idx][curr_msk] = 1
        return mask

    def __getitem__(self, i: int):
        """Return ``(image, mask)`` for sample ``i`` after augmentation and preprocessing."""
        image, mask = self.images[i], self.masks[i]
        if self.augmentation is not None:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocessing is not None:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        return len(self.images_files)

import albumentations as albu

In [None]:
def to_tensor(x, **kwargs):
    """Reorder an HWC array to CHW and cast it to float32.

    Extra keyword arguments (albumentations' ``Lambda`` passes some) are ignored.
    """
    chw = x.transpose((2, 0, 1))
    return chw.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Build the albumentations pipeline that prepares samples for the model.

    Two steps run in order:
      1. ``preprocessing_fn`` is applied to the image only; masks are untouched.
      2. ``to_tensor`` converts both image and mask from HWC to CHW float32.

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    hwc_to_chw = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, hwc_to_chw])

## Create model and train

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp

In [None]:
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ["berry", "leaf", "stem", "flower"]
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
# Fall back to CPU so the notebook still runs on a machine without a GPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# The Dataset builds NB_CLASSES mask channels; the model must predict exactly as many
assert len(CLASSES) == NB_CLASSES, "CLASSES and NB_CLASSES disagree"

# create segmentation model with pretrained encoder
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)


In [None]:
# Encoder-specific normalization matching the pretrained weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    preprocessing=get_preprocessing(preprocessing_fn),
)
valid_dataset = Dataset(
    x_val_dir,
    y_val_dir,
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Don't start more loader workers than there are CPUs
NUM_WORKERS = min(12, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=NUM_WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=min(4, NUM_WORKERS))

In [None]:
x, y = train_dataset[0]
# Samples come out CHW; matplotlib expects HWC
plt.imshow(np.transpose(x, (1, 2, 0)).astype("float"))

In [None]:
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
model = model.to(device)

In [None]:
# create epoch runners
# each runner is a simple loop over a dataloader's batches; both share the same settings
_runner_kwargs = dict(loss=loss, metrics=metrics, device=device, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_runner_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **_runner_kwargs)

In [None]:
# Train the model, keeping the checkpoint with the best validation IoU.
N_EPOCHS = 40
LR_DROP_EPOCH = 25  # after this epoch the learning rate is lowered to LOW_LR
LOW_LR = 1e-5
CHECKPOINT_PTH = './best_model.pth'

# -inf (not 0) so the first epoch always writes a checkpoint even if its IoU is 0;
# otherwise the "load best saved checkpoint" cell could find no file.
max_score = -float('inf')

for i in range(N_EPOCHS):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model.state_dict(), CHECKPOINT_PTH)
        print('Model saved!')

    if i == LR_DROP_EPOCH:
        # The optimizer's param groups span the whole model (encoder + decoder),
        # so this lowers the learning rate everywhere, not only for the decoder.
        for param_group in optimizer.param_groups:
            param_group['lr'] = LOW_LR
        print(f'Decrease learning rate to {LOW_LR}!')

## Evaluate the best saved model (on the validation set)

In [None]:
# load best saved checkpoint
model.load_state_dict(torch.load('./best_model.pth'))
best_model = model

In [None]:
# evaluate model on test set
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)

logs = test_epoch.run(valid_loader)

## Visualize predictions

In [None]:
from skimage import io
from skimage import color
from skimage import segmentation

In [None]:
label_idx = 3

In [None]:
for i in range(5):
    n = np.random.choice(len(valid_dataset))
    
#     image_vis = test_dataset_vis[n][0].astype('uint8')
    image, gt_mask = valid_dataset[n]
    
    gt_mask = gt_mask.squeeze()
    
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())
    
    image = np.transpose(image, (1, 2, 0))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    print(pr_mask.shape, gt_mask.shape)
    visualize(
        image=image, 
        ground_truth_mask=gt_mask[label_idx], 
        predicted_mask=pr_mask[label_idx],
    )

In [None]:
from torchvision.utils import draw_segmentation_masks

In [None]:
pr_mask.shape

In [None]:


image = cv2.imread("./val/37786E52-F57E-4777-A9E8-D94CBFE89EAD_1_105_c.jpeg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

preprocessing_fn = get_preprocessing(smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS))
x, _ = preproc_image(image)

x_tensor = torch.from_numpy(x).to(DEVICE).unsqueeze(0)
pr_mask = best_model.predict(x_tensor)
pr_mask = (pr_mask.squeeze().cpu().numpy().round())


In [None]:
img_tensor = torch.tensor(np.transpose(image, (2, 0, 1)), dtype=torch.uint8)

In [None]:
pr_masks = [cv2.resize(pr_mask[idx], dsize=(img_tensor.shape[2], img_tensor.shape[1])) for idx in range(4)]

In [None]:
pr_mask = np.stack(pr_masks)

In [None]:
print(img_tensor.shape, pr_mask.shape)

segm_img = draw_segmentation_masks(img_tensor, torch.tensor(pr_mask, dtype=torch.bool), alpha=0.3, colors=["blue", "red", "purple", "green"])

In [None]:
plt.figure(figsize=(12, 12))
plt.imshow(np.transpose(segm_img.numpy(), (1, 2, 0)))

In [None]:
class DeployedSegmentation:
    """Inference wrapper around the trained FPN segmentation model.

    Args:
        weights_pth: path to a ``state_dict`` saved by the training loop
            (``torch.save(model.state_dict(), ...)``).
        device: torch device (or its name) to run on. Defaults to CUDA when
            available, otherwise CPU.
        input_size: side length the image is resized to before inference;
            the training Dataset used 512.
    """

    ENCODER = 'se_resnext50_32x4d'
    ENCODER_WEIGHTS = 'imagenet'
    ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation

    def __init__(self, weights_pth: str, device=None, input_size: int = 512):
        self.CLASSES = ["berry", "leaf", "stem", "flower"]
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.input_size = input_size

        # encoder_weights=None: every parameter is restored from the checkpoint
        # below, so there is no need to download ImageNet weights first.
        self.model = smp.FPN(
            encoder_name=self.ENCODER,
            encoder_weights=None,
            classes=len(self.CLASSES),
            activation=self.ACTIVATION,
        )
        self.model.load_state_dict(torch.load(weights_pth, map_location=self.device))
        self.model.to(self.device).eval()

        # Input normalization must still match the ImageNet-pretrained encoder used in training
        self.preprocessing = get_preprocessing(
            smp.encoders.get_preprocessing_fn(self.ENCODER, self.ENCODER_WEIGHTS)
        )

    def get_preds(self, img_pth: str, threshold: float = 0.5) -> torch.Tensor:
        """Predict per-class binary masks for one image file.

        Args:
            img_pth: path to an image readable by OpenCV.
            threshold: a pixel is assigned to a class when the model output
                exceeds this value.

        Returns:
            Bool tensor of shape ``(len(self.CLASSES), H, W)`` at the original
            image resolution; channel ``k`` corresponds to ``self.CLASSES[k]``.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        image = cv2.imread(img_pth)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_pth}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = image.shape[:2]

        resized = cv2.resize(image, dsize=(self.input_size, self.input_size))
        x = self.preprocessing(image=resized)["image"]
        x_tensor = torch.from_numpy(x).to(self.device).unsqueeze(0)
        probs = self.model.predict(x_tensor).squeeze(0).cpu().numpy()

        # Scale the per-class outputs back to the original size, then threshold
        full_size = np.stack([
            cv2.resize(channel, dsize=(width, height), interpolation=cv2.INTER_LINEAR)
            for channel in probs
        ])
        return torch.from_numpy(full_size > threshold)

In [None]:
# Inspect a single predicted class mask from the single-image inference above
plt.imshow(pr_mask[1])
plt.title(f"Predicted mask: {CLASSES[1]}");