In [None]:
# Mount Google Drive under /content/drive; the checkpoint and logs are copied
# to a folder below this mount point at the end of the notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
 # Install required libs
# pytorch-lightning is pinned to 1.9: the LightningModule defined later in this
# notebook implements the `training_epoch_end` / `validation_epoch_end` /
# `test_epoch_end` hooks, which belong to the 1.x API.
# NOTE(review): the other packages are unpinned, so a fresh runtime may pull
# newer, incompatible releases — consider pinning them as well.
!pip install segmentation-models-pytorch
!pip install pytorch-lightning==1.9
!pip install albumentations
!pip install torch

In [None]:
!pip uninstall crcmod
!pip install --no-cache-dir -U crcmod

In [None]:
import os
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import albumentations as albu
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import PIL
import torch
import numpy as np
import segmentation_models_pytorch as smp
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Seed the Python, NumPy and PyTorch random generators for reproducibility;
# workers=True asks Lightning to also derive seeds for DataLoader workers.
pl.seed_everything(78, workers=True)

In [None]:
# Download the normalization statistics into the working directory and load them.
# normalize_image() subtracts mean_std[0] and divides by mean_std[1].
!gsutil cp 'gs://res-id/cnn/training/prepped_gaip/mean_std_sentinel_v7.npy' .
mean_std = np.load('./mean_std_sentinel_v7.npy')

In [None]:
# Sanity check: display the layout of the statistics array
# (index 0 along the first axis is subtracted, index 1 is the divisor).
mean_std.shape

In [None]:
!gsutil cp gs://res-id/cnn/training/prepped_gaip/reservoirs_10band_nofp_ab_v7.zip .
!mkdir -p ./data
!unzip reservoirs_10band_nofp_ab_v7.zip -d ./data/

In [None]:
# Root folder of the unpacked dataset.
# NOTE(review): assumes the archive extracts into a top-level
# `reservoirs_10band/` directory — confirm against the zip contents.
DATA_DIR = './data/reservoirs_10band/'

In [None]:
# Training split: images (x) and annotation masks (y).
x_train_dir = os.path.join(DATA_DIR, 'img_dir/train')
y_train_dir = os.path.join(DATA_DIR, 'ann_dir/train')

# Validation split: images (x) and annotation masks (y).
x_valid_dir = os.path.join(DATA_DIR, 'img_dir/val')
y_valid_dir = os.path.join(DATA_DIR, 'ann_dir/val')

In [None]:
def get_training_augmentation():
    """Build the augmentation pipeline used for training samples.

    Returns:
        albumentations.Compose: horizontal flip, vertical flip and 90-degree
        rotation (each applied with probability 0.5), followed by a
        ShiftScaleRotate that only rescales (shift and rotation limits are 0),
        also applied with probability 0.5. Image/mask shape checking is
        disabled on the composed pipeline.
    """
    flips_and_rotations = [
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.RandomRotate90(p=0.5),
    ]
    # Scale-only transform: the shift and rotate limits are both zero.
    random_rescale = albu.ShiftScaleRotate(
        scale_limit=(0, 0.25),
        rotate_limit=0,
        shift_limit=0.0,
        p=0.5,
        border_mode=0,
    )
    return albu.Compose([*flips_and_rotations, random_rescale], is_check_shapes=False)


def to_tensor(x, **kwargs):
    """Reorder a 3-D array from (axis0, axis1, axis2) to (axis2, axis0, axis1).

    Args:
        x: array supporting ``transpose`` and ``astype`` (e.g. a NumPy array).
        **kwargs: accepted and ignored, so the function can be used as a
            callback that receives extra keyword arguments.
    Return:
        The reordered array cast to float32.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)

def get_preprocessing():
    """Construct the preprocessing transform.

    The transform applies ``to_tensor`` to both the image and the mask.

    Return:
        transform: albumentations.Compose (with shape checking disabled)
    """
    tensor_conversion = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([tensor_conversion], is_check_shapes=False)

def normalize_image(ar, mean_std):
    """Standardize an image with precomputed statistics.

    Args:
        ar: image array.
        mean_std: indexable whose element 0 is subtracted from ``ar`` and whose
            element 1 divides the result (presumably per-band mean and standard
            deviation — TODO confirm). Both must broadcast against ``ar``.
    Return:
        ``(ar - mean_std[0]) / mean_std[1]``
    """
    mean, std = mean_std[0], mean_std[1]
    return (ar - mean) / std

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset. Read images, apply normalization, augmentation and
    preprocessing transformations.

    The mask of an image is looked up in ``masks_dir`` under the image's file
    name with ``.tif`` replaced by ``.png``.

    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        classes (list): names of the classes (case-insensitive, taken from
            ``CLASSES``) to extract from the segmentation mask, one output
            channel per name. ``None`` selects every class in ``CLASSES``.
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.). When given, random Gaussian noise is
            additionally added to the augmented image.
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. shape manipulation, etc.)
        mean_std: statistics forwarded to ``normalize_image``; ``None`` skips
            normalization.

    Raises:
        ValueError: if a name in ``classes`` is not listed in ``CLASSES``.

    """

    CLASSES = ['background', 'water']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
            mean_std = None
    ):
        # os.listdir() returns entries in arbitrary order; sorting makes index i
        # refer to the same sample on every run and machine.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id.replace('.tif', '.png')) for image_id in self.ids]

        # The signature defaults to classes=None; fall back to all known classes
        # instead of failing when iterating over None.
        if classes is None:
            classes = self.CLASSES

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.mean_std = mean_std

    def __getitem__(self, i):
        """Load sample ``i`` and return ``{'image': ..., 'mask': ...}``."""

        # read data
        image = io.imread(self.images_fps[i])
        mask = io.imread(self.masks_fps[i])

        # extract certain classes from mask: one binary channel per class value,
        # stacked along the last axis
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        if self.mean_std is not None:
            image = normalize_image(image, self.mean_std)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=(image.astype(np.float32)), mask=mask)
            image, mask = sample['image'], sample['mask']
            # additive Gaussian noise whose standard deviation is itself drawn
            # uniformly from [0, 0.1) for every sample
            gauss_scale = np.random.uniform(0, 0.1)
            image = image + np.random.normal(scale=gauss_scale, size=image.shape)

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return {'image':image, 'mask':mask}

    def __len__(self):
        """Number of image files found in ``images_dir``."""
        return len(self.ids)

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in one row.

    Every keyword argument becomes one panel; its name, with underscores turned
    into spaces and title-cased, is used as the panel title. The first image is
    shown using channels 7, 9 and 8 (in that order) of its last axis and its
    mean is printed; all other images are shown unchanged.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        if position != 1:
            plt.imshow(picture)
            continue
        # First panel: pick three channels of the multi-band image for display.
        plt.imshow(picture[:, :, [7, 9, 8]])
        print('Final mean', picture.mean())

    plt.show()

In [None]:
# Lets look at data we have

# Preview dataset: normalized and augmented, but without the preprocessing
# transform, so the image keeps the channels-last layout that visualize()
# indexes into.
dataset = Dataset(x_train_dir, y_train_dir, classes=['Water'],
                #   preprocessing=get_preprocessing(),
                  augmentation=get_training_augmentation(),
                  mean_std=mean_std,
)

batch = dataset[8] # get some sample (index 8 is an arbitrary choice)
# The mask has a single class channel here; squeeze() drops it for plotting.
visualize(
    image=batch['image'],
    water_mask=batch['mask'].squeeze(),
)

In [None]:
class ResModel(pl.LightningModule):
    """LightningModule wrapping an smp.MAnet for binary segmentation.

    Training uses a Dice loss computed on logits; per-epoch IoU (image-wise and
    dataset-wise) is logged for the train, valid and test stages.

    Args:
        encoder_name: encoder backbone name forwarded to ``smp.MAnet``.
        in_channels: number of input image channels.
        out_classes: number of output mask channels.
        **kwargs: accepted but not used anywhere in this class.
            NOTE(review): extra keyword arguments are silently dropped.
    """

    def __init__(self, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        # aux_params adds an auxiliary head to the model; forward() below only
        # uses the first output of the model, so the auxiliary output does not
        # take part in the loss.
        self.model = smp.MAnet(encoder_name=encoder_name, in_channels=in_channels, classes=out_classes,# encoder_weights=None, decoder_use_batchnorm=True,
                                      aux_params=dict(
                                          classes=1,
                                          dropout=0.5
                                          )
        )

        # Dice loss in binary mode, fed with raw logits.
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

        # Predictions are center-cropped to 500x500 before being compared with
        # the target mask.
        # NOTE(review): presumably the inputs are padded to a multiple of 32 and
        # the masks are 500x500 — confirm against the data.
        self.crop_transform = torchvision.transforms.CenterCrop(500)

    def forward(self, image):
        """Return the segmentation logits for ``image``, center-cropped to 500x500."""
        # No normalization happens here; the Dataset normalizes the input.
        # The model returns several outputs; index 0 is used as the mask.
        mask = self.model(image)[0]
        return self.crop_transform(mask)

    def shared_step(self, batch, stage):
        """Compute the loss and confusion statistics for one batch.

        Args:
            batch: dict with an ``"image"`` and a ``"mask"`` tensor.
            stage: stage name; not used inside this method.

        Returns:
            dict with ``"loss"`` and the per-image ``"tp"``, ``"fp"``, ``"fn"``,
            ``"tn"`` statistics.
        """
        # NOTE(review): the loss is returned but never passed to self.log(), so
        # it does not show up in the logger output.

        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32:
        # encoder and decoder are connected by skip connections and the encoder
        # usually has 5 stages of downsampling by a factor of 2 (2 ^ 5 = 32).
        # E.g. an 84x84 input yields encoder features of size 42, 21, 10, 5,
        # while the decoder upsamples 5 -> 10, 20, 40, 80, and concatenating
        # the mismatching feature maps raises an error.
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Lets compute metrics for some threshold
        # first convert mask values to probabilities, then
        # apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute IoU metric by two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class
        # these values will be aggregated in the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        """Aggregate the step statistics of an epoch and log IoU metrics.

        Args:
            outputs: list of the dicts returned by ``shared_step``.
            stage: prefix used for the logged metric names.
        """
        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset
        # with "empty" images (images without target class) a large gap could be observed.
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }

        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Lightning hook: run ``shared_step`` for a training batch."""
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        """Lightning hook: log the ``train_*`` IoU metrics."""
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: run ``shared_step`` for a validation batch."""
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        """Lightning hook: log the ``valid_*`` IoU metrics."""

        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        """Lightning hook: run ``shared_step`` for a test batch."""
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        """Lightning hook: log the ``test_*`` IoU metrics."""
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        """Return Adam (lr=1e-4) with an exponential LR decay of factor 0.9."""
        # No weight decay is applied (1e-5 was considered as a starting value).
        optim = torch.optim.Adam(self.parameters(), lr=0.0001) #, weight_decay=1e-5) # Adding weight decay = 1e-5 to start
        # Decay factor is 0.9 (0.95 was used previously).
        return [optim], [torch.optim.lr_scheduler.ExponentialLR(optim, 0.9)] # was 0.95

In [None]:
# MAnet with a ResNet-34 encoder: 10 input channels, a single output channel
# (the water mask).
model = ResModel("resnet34", in_channels=10, out_classes=1)

In [None]:

# Stop training when the dataset-wise validation IoU has not improved for 5
# consecutive validation runs (any increase counts, min_delta=0).
early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(monitor="valid_dataset_iou", mode='max', min_delta=0.00, patience=5)
# Keep only the single checkpoint with the highest dataset-wise validation IoU.
checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath="/content/", save_top_k=1, monitor="valid_dataset_iou", mode='max')
# Metrics are written as CSV below /content/logs/manet_resnet_model25.
logger = pl.loggers.CSVLogger("/content/logs/", name="manet_resnet_model25")

In [None]:
# Only the water class is extracted from the masks (single-channel target).
CLASSES = ['Water']
# Training data: normalized, augmented, then converted to channels-first float32.
train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    preprocessing=get_preprocessing(),
    augmentation=get_training_augmentation(),
    classes=CLASSES,
    mean_std=mean_std,
)

# Validation data: same normalization and preprocessing, but no augmentation
# (and therefore no added noise).
valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    preprocessing=get_preprocessing(),
    classes=CLASSES,
    mean_std=mean_std,
)

# Training batches are shuffled; validation runs one image at a time in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=2)

In [None]:


# Train on one GPU for at most 80 epochs; early stopping and checkpointing are
# driven by the validation IoU callbacks, metrics go to the CSV logger.
# `accelerator="gpu", devices=1` replaces the deprecated `gpus=1` argument and
# selects the same hardware on Lightning 1.9.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=80,
    callbacks=[early_stop_callback, checkpoint_callback],
    deterministic=True,
    logger=logger,
)


trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)

In [None]:

!cp /content/*.ckpt /content/drive/MyDrive/pytorch_training/
!cp -r /content/logs/* /content/drive/MyDrive/pytorch_training/

In [None]:
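# Optional: uncomment to release the Colab runtime once training has finished
# and the artifacts have been copied to Drive.
# from google.colab import runtime
# runtime.unassign()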