# Imports

In [1]:
!pip install -q torchsummary
!pip install -q segmentation_models_pytorch

[0m

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

from matplotlib import animation
from IPython import display

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn
import torchvision.models as models

from torchsummary import summary
import segmentation_models_pytorch as smp

# Utilites

In [3]:
data_dir: str = '/kaggle/input/google-research-identify-contrails-reduce-global-warming'

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [5]:
def get_band_images(idx: str, parrent_folder: str, band: str) -> np.array:
    return np.load(os.path.join(data_dir, parrent_folder, idx, f'band_{band}.npy'))

In [6]:
_T11_BOUNDS = (243, 303)
_CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
_TDIFF_BOUNDS = (-4, 2)

def normalize_range(data, bounds):
    """Maps data to the range [0, 1]."""
    return (data - bounds[0]) / (bounds[1] - bounds[0])


def get_ash_color_images(idx: str, parrent_folder: str, get_mask_frame_only=False) -> np.array:
    band11 = get_band_images(idx, parrent_folder, '11')
    band14 = get_band_images(idx, parrent_folder, '14')
    band15 = get_band_images(idx, parrent_folder, '15')
    
    if get_mask_frame_only:
        band11 = band11[:,:,4]
        band14 = band14[:,:,4]
        band15 = band15[:,:,4]

    r = normalize_range(band15 - band14, _TDIFF_BOUNDS)
    g = normalize_range(band14 - band11, _CLOUD_TOP_TDIFF_BOUNDS)
    b = normalize_range(band14, _T11_BOUNDS)
    false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
    return false_color

In [7]:
def get_mask_image(idx: str, parrent_folder: str) -> np.array:
    return np.load(os.path.join(data_dir, parrent_folder, idx, 'human_pixel_masks.npy')) 

# Model

In [8]:
class Upsampled(nn.Module):
    def __init__(self, back_bone: nn.Module):
        super(Upsampled, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.back_bone = back_bone
        self.down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)

    def forward(self, x):
        # Forward pass through the layers
        x = self.up(x)
        x = self.back_bone(x)
        x = self.down(x)
        return x
    
    def forward_no_downsample_back(self, x):
        # Forward pass through the layers
        x = self.up(x)
        x = self.back_bone(x)
        return x 

In [9]:
core = smp.Unet(
    encoder_name ='efficientnet-b0',
    encoder_weights=None,
    in_channels=3,   
    classes=1,
    activation=None#"sigmoid",
    )
model = Upsampled(core)
model.to(device)

Upsampled(
  (up): Upsample(scale_factor=2.0, mode='bilinear')
  (back_bone): Unet(
    (encoder): EfficientNetEncoder(
      (_conv_stem): Conv2dStaticSamePadding(
        3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
        (static_padding): ZeroPad2d((0, 1, 0, 1))
      )
      (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_blocks): ModuleList(
        (0): MBConvBlock(
          (_depthwise_conv): Conv2dStaticSamePadding(
            32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
            (static_padding): ZeroPad2d((1, 1, 1, 1))
          )
          (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
          (_se_reduce): Conv2dStaticSamePadding(
            32, 8, kernel_size=(1, 1), stride=(1, 1)
            (static_padding): Identity()
          )
          (_se_expand): Conv2dStaticSamePadding(
            8, 32, kernel_size=(1,

In [10]:
#model.load_state_dict(torch.load('/kaggle/working/model_checkpoint_e1.pt'))
#model.eval()
#model.to(device)

In [11]:
summary(model, input_size=(3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1          [-1, 3, 512, 512]               0
         ZeroPad2d-2          [-1, 3, 513, 513]               0
Conv2dStaticSamePadding-3         [-1, 32, 256, 256]             864
       BatchNorm2d-4         [-1, 32, 256, 256]              64
MemoryEfficientSwish-5         [-1, 32, 256, 256]               0
         ZeroPad2d-6         [-1, 32, 258, 258]               0
Conv2dStaticSamePadding-7         [-1, 32, 256, 256]             288
       BatchNorm2d-8         [-1, 32, 256, 256]              64
MemoryEfficientSwish-9         [-1, 32, 256, 256]               0
         Identity-10             [-1, 32, 1, 1]               0
Conv2dStaticSamePadding-11              [-1, 8, 1, 1]             264
MemoryEfficientSwish-12              [-1, 8, 1, 1]               0
         Identity-13              [-1, 8, 1, 1]               0
Conv2dStaticSame

# Trainer

In [12]:
class Dice(nn.Module):
    def __init__(self, use_sigmoid=True):
        super(Dice, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.use_sigmoid = use_sigmoid

    def forward(self, inputs, targets, smooth=1):
        if self.use_sigmoid:
            inputs = self.sigmoid(inputs)       
        
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()
        dice = (2.0 *intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return dice
    
dice = Dice()

In [13]:
class MyTrainer:
    def __init__(self, model, optimizer, loss_fn, lr_scheduler):
        self.validation_losses = []
        self.batch_losses = []
        self.epoch_losses = []
        self.learning_rates = []
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.lr_scheduler = lr_scheduler
        self._check_optim_net_aligned()

    # Ensures that the given optimizer points to the given model
    def _check_optim_net_aligned(self):
        assert self.optimizer.param_groups[0]['params'] == list(self.model.parameters())

    # Trains the model
    def fit(self,
            train_dataloader: DataLoader,
            test_dataloader: DataLoader,
            epochs: int = 10,
            eval_every: int = 1,
            ):
  
        for e in range(epochs):
            print("New learning rate: {}".format(self.lr_scheduler.get_last_lr()))
            self.learning_rates.append(self.lr_scheduler.get_last_lr()[0])

            # Stores data about the batch
            batch_losses = []
            sub_batch_losses = []

            for i, data in enumerate(train_dataloader):
                self.model.train()
                if i % 100 == 0:
                    print(f'epotch: {e} batch: {i}/{len(train_dataloader)} loss: {torch.Tensor(sub_batch_losses).mean()}')
                    sub_batch_losses.clear()
                # Every data instance is an input + label pair
                images, mask = data
                
                images = images.to(device)
                mask = mask.to(device)

                # Zero your gradients for every batch!
                self.optimizer.zero_grad()
                # Make predictions for this batch
                outputs = self.model(images)#['out']
                # Compute the loss and its gradients
                loss = self.loss_fn(outputs, mask)
                loss.backward()
                # Adjust learning weights
                self.optimizer.step()

                # Saves data
                self.batch_losses.append(loss.item())
                batch_losses.append(loss)
                sub_batch_losses.append(loss)

            # Adjusts learning rate
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # Reports on the path
            mean_epoch_loss = torch.Tensor(batch_losses).mean()
            self.epoch_losses.append(mean_epoch_loss.item())
            print('Train Epoch: {} Average Loss: {:.6f}'.format(e, mean_epoch_loss))

            # Reports on the training progress
            if (e + 1) % eval_every == 0:
                torch.save(self.model.state_dict(), "model_checkpoint_e" + str(e) + ".pt")
                with torch.no_grad():
                    self.model.eval()
                    losses = []
                    for i, data in enumerate(test_dataloader):
                        # Every data instance is an input + label pair
                        images, mask = data

                        images = images.to(device)
                        mask = mask.to(device)

                        output = self.model(images)#['out']
                        loss = self.loss_fn(output, mask)
                        losses.append(loss.item())
                        
                    avg_loss = torch.Tensor(losses).mean().item()
                    self.validation_losses.append(avg_loss)
                    print("Validation loss after", (e + 1), "epochs was", round(avg_loss, 4))

# Dataset

In [14]:
class ContrailsAshDataset(torch.utils.data.Dataset):
    def __init__(self, parrent_folder: str):
        self.df_idx: pd.DataFrame = pd.DataFrame({'idx': os.listdir(f'/kaggle/input/google-research-identify-contrails-reduce-global-warming/{parrent_folder}')})
        self.parrent_folder: str = parrent_folder

    def __len__(self):
        return len(self.df_idx)

    def __getitem__(self, idx):
        image_id: str = str(self.df_idx.iloc[idx]['idx'])
        images = torch.tensor(np.reshape(get_ash_color_images(image_id, self.parrent_folder, get_mask_frame_only=True), (256, 256, 3))).to(torch.float32).permute(2, 0, 1)
        mask = torch.tensor(get_mask_image(image_id, self.parrent_folder)).to(torch.float32).permute(2, 0, 1)
        return images, mask

In [15]:
dataset_train = ContrailsAshDataset('train')
dataset_validation = ContrailsAshDataset('validation')

data_loader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, num_workers=2, drop_last=True)
data_loader_validation = DataLoader(dataset_validation, batch_size=16, shuffle=True, num_workers=2, drop_last=True)

In [16]:
len(os.listdir(f'/kaggle/input/google-research-identify-contrails-reduce-global-warming/train'))

20529

# Train

In [17]:
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10))# smp.losses.DiceLoss(mode='binary') 
optimizer = optim.Adam(model.parameters(), lr=5e-4)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)
model.train()

num_epochs = 7

trainer = MyTrainer(model, optimizer, criterion, lr_scheduler)
trainer.fit(data_loader_train, data_loader_validation, epochs=num_epochs)

New learning rate: [0.0005]
epotch: 0 batch: 0/1283 loss: nan
epotch: 0 batch: 10/1283 loss: 0.5986009240150452
epotch: 0 batch: 20/1283 loss: 0.40431323647499084
epotch: 0 batch: 30/1283 loss: 0.30684125423431396
epotch: 0 batch: 40/1283 loss: 0.2558574676513672
epotch: 0 batch: 50/1283 loss: 0.23649191856384277
epotch: 0 batch: 60/1283 loss: 0.23285862803459167
epotch: 0 batch: 70/1283 loss: 0.21275082230567932
epotch: 0 batch: 80/1283 loss: 0.20246577262878418
epotch: 0 batch: 90/1283 loss: 0.1882757693529129
epotch: 0 batch: 100/1283 loss: 0.18319836258888245
epotch: 0 batch: 110/1283 loss: 0.2171717882156372
epotch: 0 batch: 120/1283 loss: 0.16159948706626892
epotch: 0 batch: 130/1283 loss: 0.15907007455825806
epotch: 0 batch: 140/1283 loss: 0.14879387617111206
epotch: 0 batch: 150/1283 loss: 0.16033130884170532
epotch: 0 batch: 160/1283 loss: 0.15710127353668213
epotch: 0 batch: 170/1283 loss: 0.17864295840263367


KeyboardInterrupt: 

# Find Optimal Threshold

In [None]:
class DiceThresholdTester:
    
    def __init__(self, model: nn.Module, data_loader: torch.utils.data.DataLoader):
        self.model = model
        self.data_loader = data_loader
        self.cumulative_mask_pred = []
        self.cumulative_mask_true = []
        
    def precalculate_prediction(self) -> None:
        self.model.eval()
        with torch.no_grad():
            sigmoid = nn.Sigmoid()

            for images, mask_true in self.data_loader:
                if torch.cuda.is_available():
                    images = images.cuda()

                mask_pred = sigmoid(model.forward(images))

                self.cumulative_mask_pred.append(mask_pred.cpu().detach().numpy())
                self.cumulative_mask_true.append(mask_true.cpu().detach().numpy())

            self.cumulative_mask_pred = np.concatenate(self.cumulative_mask_pred, axis=0)
            self.cumulative_mask_true = np.concatenate(self.cumulative_mask_true, axis=0)

            self.cumulative_mask_pred = torch.flatten(torch.from_numpy(self.cumulative_mask_pred))
            self.cumulative_mask_true = torch.flatten(torch.from_numpy(self.cumulative_mask_true))
    
    def test_threshold(self, threshold: float) -> float:
        _dice = Dice(use_sigmoid=False)
        after_threshold = np.zeros(self.cumulative_mask_pred.shape)
        after_threshold[self.cumulative_mask_pred[:] > threshold] = 1
        after_threshold[self.cumulative_mask_pred[:] < threshold] = 0
        after_threshold = torch.flatten(torch.from_numpy(after_threshold))
        return _dice(self.cumulative_mask_true, after_threshold).item()

In [None]:
dice_threshold_tester = DiceThresholdTester(model, data_loader_validation)
dice_threshold_tester.precalculate_prediction()

In [None]:
thresholds_to_test = [round(x * 0.01, 2) for x in range(101)]

optim_threshold = 0.975
best_dice_score = -1

thresholds = []
dice_scores = []

for t in thresholds_to_test:
    dice_score = dice_threshold_tester.test_threshold(t)
    if dice_score > best_dice_score:
        best_dice_score = dice_score
        optim_threshold = t
    
    thresholds.append(t)
    dice_scores.append(dice_score)
    
print(f'Best Threshold: {optim_threshold} with dice: {best_dice_score}')
df_threshold_data = pd.DataFrame({'Threshold': thresholds, 'Dice Score': dice_scores})

In [None]:
sns.lineplot(data=df_threshold_data, x='Threshold', y='Dice Score')
plt.axhline(y=best_dice_score, color='green')
plt.axvline(x=optim_threshold, color='green')
plt.text(-0.02, best_dice_score * 0.96, f'{best_dice_score:.3f}', va='center', ha='left', color='green')
plt.text(optim_threshold - 0.01, 0.02, f'{optim_threshold}', va='center', ha='right', color='green')
plt.ylim(bottom=0)
plt.title('Threshold vs Dice Score')
plt.show()

# Evaluate

In [None]:
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

batches_to_show = 4
model.eval()

for i, data in enumerate(data_loader_validation):
    images, mask = data
    
    # Predict mask for this instance
    if torch.cuda.is_available():
        images = images.cuda()
    predicated_mask = sigmoid(model.forward(images[:, :, :, :]).cpu().detach().numpy())
    
    # Apply threshold
    predicated_mask_with_threshold = np.zeros((images.shape[0], 256, 256))
    predicated_mask_with_threshold[predicated_mask[:, 0, :, :] < optim_threshold] = 0
    predicated_mask_with_threshold[predicated_mask[:, 0, :, :] > optim_threshold] = 1
    
    images = images.cpu()
        
    for img_num in range(0, images.shape[0]):
        fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(20,10))
        axes = axes.flatten()
        
        # Show groud trought 
        axes[0].imshow(mask[img_num, 0, :, :])
        axes[0].axis('off')
        axes[0].set_title('Ground Truth')
        
        # Show ash color scheme input image
        axes[1].imshow( np.concatenate(
            (
            np.expand_dims(images[img_num, 4, :, :], axis=2),
            np.expand_dims(images[img_num, 12, :, :], axis=2),
            np.expand_dims(images[img_num, 20, :, :], axis=2)
        ), axis=2))
        axes[1].axis('off')
        axes[1].set_title('Ash color scheeme input - Frame 4')

        # Show predicted mask
        axes[2].imshow(predicated_mask[img_num, 0, :, :], vmin=0, vmax=1)
        axes[2].axis('off')
        axes[2].set_title('Predicted probability mask')

        # Show predicted mask after threshold
        axes[3].imshow(predicated_mask_with_threshold[img_num, :, :])
        axes[3].axis('off')
        axes[3].set_title('Predicted mask with threshold')
        plt.show()
    
    if i + 1 >= batches_to_show:
        break