In [None]:
import numpy as np
import torch
import torch.nn as nn
import os
from torchvision.transforms import Compose,ToTensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import pandas as pd
import seaborn as sns

!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

from torchvision import models
try:
    from torchsummary import summary
except:
    !pip install torchsummary > /dev/null
    from torchsummary import summary
    
import torch.nn.functional as F
from torchvision import models

# config and devices

In [None]:
# Central experiment settings read by the dataset and dataloader cells below.
CONFIG = {
    'seed': 42,
    'DATA_ROOT': '/kaggle/input/google-research-identify-contrails-reduce-global-warming/',
    'BATCH_SIZE': 16,
    'IMG_SIZE': (256, 256),
    'NUM_TRAIN_SAMPLES': 1000,
    'NUM_VAL_SAMPLES': 300,
    'NUM_TEST_SAMPLES': 2,
}

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

In [None]:
# (lower, upper) ranges used to rescale each false-colour channel before clipping to [0, 1].
_T11_BOUNDS = (243, 303)
_CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
_TDIFF_BOUNDS = (-4, 2)


def normalize_range(data, bounds):
    """Linearly rescale `data` so that bounds[0] maps to 0 and bounds[1] maps to 1.

    Values outside the bounds are NOT clipped here; they map outside [0, 1].
    """
    lower = bounds[0]
    upper = bounds[1]
    return (data - lower) / (upper - lower)


def get_ash_img(band11, band14, band15):
    """Build the three-channel false-colour image from bands 11, 14 and 15.

    The channels are stacked on a new axis at position 2 (so the inputs must be at
    least 2-D) and the result is clipped to [0, 1].
    """
    channels = [
        normalize_range(band15 - band14, _TDIFF_BOUNDS),            # red
        normalize_range(band14 - band11, _CLOUD_TOP_TDIFF_BOUNDS),  # green
        normalize_range(band14, _T11_BOUNDS),                       # blue
    ]
    return np.stack(channels, axis=2).clip(0, 1)

# Transforms

In [None]:
# One independent pipeline per split; each currently contains only ToTensor.
train_transform, val_transform, test_transform = (
    Compose([ToTensor()]) for _ in range(3)
)


# Data set

In [None]:
# Root directory that get_band_images resolves record paths against.
data_dir: str = '/kaggle/input/google-research-identify-contrails-reduce-global-warming'


def get_band_images(idx: str, parrent_folder: str, band: str) -> np.ndarray:
    """Load one band array of one record.

    Args:
        idx: record folder name.
        parrent_folder: split folder containing the record, joined onto `data_dir`.
        band: band identifier; the file read is ``band_{band}.npy``.

    Returns:
        The array stored in ``<data_dir>/<parrent_folder>/<idx>/band_<band>.npy``.
    """
    band_path = os.path.join(data_dir, parrent_folder, idx, f'band_{band}.npy')
    return np.load(band_path)


def get_ash_values(record_dir, allchannels):
    """Load bands 11, 14 and 15 of a record and convert them to a false-colour input.

    Args:
        record_dir: directory containing ``band_11.npy``, ``band_14.npy`` and ``band_15.npy``.
        allchannels: when truthy, every frame of the record is used and the result is a
            float32 ``torch.Tensor`` in channels-first layout. When falsy, only frame
            index 4 is used and the result is a channels-last ``numpy`` array, which the
            caller is expected to convert itself.

    Returns:
        The false-colour data, as a tensor or an array depending on `allchannels`.
    """
    # Load the three bands in a fixed order so they can be unpacked by name.
    band11, band14, band15 = (
        np.load(os.path.join(record_dir, f'band_{band}.npy')) for band in (11, 14, 15)
    )

    if allchannels:
        images = get_ash_img(band11, band14, band15)
        # Fold everything after (height, width) into one channel axis. For the
        # (256, 256, 3, 8) stack this is the same (256, 256, 24) layout as before,
        # but it no longer hard-codes the image size or the number of frames.
        height, width = images.shape[:2]
        flattened = np.reshape(images, (height, width, -1))
        return torch.tensor(flattened).to(torch.float32).permute(2, 0, 1)

    # 3-channel input: a single frame per band.
    # NOTE(review): index 4 is presumably the labelled frame -- confirm against the
    # dataset description.
    frame = 4
    # For a 9-channel experiment, stack frames (5, 4, 6) of each band on the last axis
    # instead, e.g. np.stack([band11[:, :, 5], band11[:, :, 4], band11[:, :, 6]], axis=-1),
    # and reshape the result to (height, width, 9) before converting to a tensor.
    return get_ash_img(band11[:, :, frame], band14[:, :, frame], band15[:, :, frame])

class contrailsDataset(Dataset):
    """Dataset of contrail records stored as one folder per record.

    Each item is a dict with key ``'ash'`` (the false-colour input from
    `get_ash_values`) and, for the labelled splits, ``'mask'`` (the contents of
    ``human_pixel_masks.npy`` passed through `transform`).

    Args:
        base_dir: dataset root; the split folder `mode` is looked up inside it.
        mode: split folder name, e.g. ``'train'``, ``'validation'`` or ``'test'``.
        num_samples: maximum number of records to keep.
        transform: callable applied to the mask and, when `allchannels` is falsy,
            to the false-colour array.
        allchannels: forwarded to `get_ash_values`.
    """

    # Splits whose record folders are read for a ground-truth mask.
    _LABELLED_MODES = ('train', 'validation')

    def __init__(self, base_dir: str, mode: str, num_samples: int, transform, allchannels):
        super().__init__()

        self.base_dir: str = base_dir
        self.mode: str = mode
        self.transform = transform
        self.allchannels = allchannels

        # os.path.join (rather than string concatenation) so base_dir works with or
        # without a trailing separator. os.listdir returns entries in arbitrary order,
        # so sort before truncating to make the selected subset deterministic.
        split_dir = os.path.join(self.base_dir, self.mode)
        self.records: list[str] = sorted(os.listdir(split_dir))[:num_samples]

    def get_ash_img(self, bands):
        """False-colour image for frame 4 of a stacked band array.

        `bands` is indexed as ``bands[:, :, 4, k]`` with k = 0, 1, 2 for bands 11, 14, 15.
        """
        band11 = bands[:, :, 4, 0]
        band14 = bands[:, :, 4, 1]
        band15 = bands[:, :, 4, 2]
        return get_ash_img(band11, band14, band15)

    def __getitem__(self, idx):
        record_dir = os.path.join(self.base_dir, self.mode, self.records[idx])

        ash = get_ash_values(record_dir, self.allchannels)
        # In allchannels mode get_ash_values already returns a tensor; otherwise the
        # array still has to go through the transform.
        if not self.allchannels:
            ash = self.transform(ash)

        # A single condition decides both loading and transforming the mask, so a mode
        # without masks can never reach transform(None).
        if self.mode in self._LABELLED_MODES:
            pixel_masks = np.load(os.path.join(record_dir, 'human_pixel_masks.npy'))
            return {'mask': self.transform(pixel_masks), 'ash': ash}
        return {'ash': ash}

    def __len__(self):
        return len(self.records)

In [None]:
# allchannels=True selects the 24-channel input; allchannels=False selects the 3-channel input.

# 24-channel datasets
trainData = contrailsDataset(
    CONFIG['DATA_ROOT'],
    mode='train',
    num_samples=CONFIG['NUM_TRAIN_SAMPLES'],
    transform=train_transform,
    allchannels=True,
)
valData = contrailsDataset(
    CONFIG['DATA_ROOT'],
    mode='validation',
    num_samples=CONFIG['NUM_VAL_SAMPLES'],
    transform=val_transform,
    allchannels=True,
)
testData = contrailsDataset(
    CONFIG['DATA_ROOT'],
    mode='test',
    num_samples=CONFIG['NUM_TEST_SAMPLES'],
    transform=test_transform,
    allchannels=True,
)

# 3-channel datasets
trainData2 = contrailsDataset(
    CONFIG['DATA_ROOT'],
    mode='train',
    num_samples=CONFIG['NUM_TRAIN_SAMPLES'],
    transform=train_transform,
    allchannels=False,
)
valData2 = contrailsDataset(
    CONFIG['DATA_ROOT'],
    mode='validation',
    num_samples=CONFIG['NUM_VAL_SAMPLES'],
    transform=val_transform,
    allchannels=False,
)
testData2 = contrailsDataset(
    CONFIG['DATA_ROOT'],
    mode='test',
    num_samples=CONFIG['NUM_TEST_SAMPLES'],
    transform=test_transform,
    allchannels=False,
)

In [None]:
# Sanity check: number of records kept in each 24-channel split.
len(trainData), len(valData), len(testData)

# DataLoader

In [None]:
# Loaders over the 24-channel datasets.
train_dataloader = DataLoader(trainData,
                              batch_size=CONFIG['BATCH_SIZE'],
                              shuffle=True)

val_dataloader = DataLoader(valData,
                            batch_size=CONFIG['BATCH_SIZE'],
                            shuffle=True)

test_dataloader = DataLoader(testData,
                             batch_size=CONFIG['BATCH_SIZE'],
                             shuffle=False)

# Loaders over the 3-channel datasets. The batch size now comes from CONFIG like every
# other loader (it was a literal 16, which equals the current CONFIG value), so changing
# CONFIG['BATCH_SIZE'] affects all loaders consistently.
train_dataloader2 = DataLoader(trainData2,
                               batch_size=CONFIG['BATCH_SIZE'],
                               shuffle=True)

val_dataloader2 = DataLoader(valData2,
                             batch_size=CONFIG['BATCH_SIZE'],
                             shuffle=True)

test_dataloader2 = DataLoader(testData2,
                              batch_size=CONFIG['BATCH_SIZE'],
                              shuffle=False)

# **DATA EXPLORATION AND FORMATION**

In [None]:
def data_explotary(all_data_loader, type_data):
    """Report and plot how many images contain at least one contrail pixel.

    Args:
        all_data_loader: iterable of batches, each a dict whose ``'mask'`` entry is a
            tensor with the sample axis first; it must also expose ``.dataset`` (as a
            ``DataLoader`` does) so the dataset size can be reported.
        type_data: split name used in the pie-chart title.
    """
    contrail_image = 0
    no_contrail_image = 0

    for batch in all_data_loader:
        masks = batch['mask']
        # One boolean per sample: does any pixel equal 1? This replaces the per-pixel
        # Python loop and counts every sample in the batch, not just the first one.
        has_contrail = (masks == 1).flatten(start_dim=1).any(dim=1)
        positives = int(has_contrail.sum())
        contrail_image += positives
        no_contrail_image += masks.shape[0] - positives

    # Report the size of the dataset behind THIS loader (previously the global training
    # dataset was reported for every split).
    print("Number of dataset:", len(all_data_loader.dataset))
    print("Number of contrail images:", contrail_image)
    print("Number of no contrail images:", no_contrail_image)

    # Data for the pie chart
    data = [contrail_image, no_contrail_image]
    labels = ['Contrails Images', 'No Contrails Images']

    # Create a pie chart
    plt.figure(figsize=(6, 6))
    plt.pie(data, labels=labels, autopct='%1.1f%%', startangle=90)
    plt.title(f"Distribution of Contrails and Non-Contrails Images in the {type_data} Data")
    plt.show()

In [None]:
# Datasets and loaders used only for the class-balance exploration. The sample counts come
# from CONFIG (previously the literals 1000 and 300, which equal the current CONFIG values)
# and the validation split uses its own transform.
all_train_data = contrailsDataset(CONFIG['DATA_ROOT'], mode='train', num_samples=CONFIG['NUM_TRAIN_SAMPLES'], transform=train_transform, allchannels=False)
all_train_data_loader = DataLoader(all_train_data, batch_size=1, shuffle=False)

all_validation_data = contrailsDataset(CONFIG['DATA_ROOT'], mode='validation', num_samples=CONFIG['NUM_VAL_SAMPLES'], transform=val_transform, allchannels=False)
all_validation_data_loader = DataLoader(all_validation_data, batch_size=1, shuffle=False)


# Visualise the share of contrail images in the training and validation subsets
# Training
data_explotary(all_train_data_loader, 'Training')
# Validation
data_explotary(all_validation_data_loader, 'Validation')


# Models Loading and Setup

In [None]:
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'

ACTIVATION = 'sigmoid' # could be None for logits

# Create the segmentation models with a pretrained encoder. Every model is built with
# activation=ACTIVATION, so with the current setting their outputs are sigmoid probabilities.

deeplabv3 = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    in_channels=3,
    encoder_weights=ENCODER_WEIGHTS,
    classes=1,
    activation=ACTIVATION,
).to(device)

# An earlier `pspnet` with a 'mit_b3' encoder used to be built here. It was overwritten by
# the `pspnet` below before any use, so that dead construction has been removed.

pspnet2 = smp.PSPNet(
    encoder_name=ENCODER,
    in_channels=3,
    encoder_weights=ENCODER_WEIGHTS,
    classes=1,
    activation=ACTIVATION,
).to(device)

# This is used to test the performance of the 24-channels-reduced-to-3-channels approach
# against the other approaches
pspnet = smp.PSPNet(
    encoder_name=ENCODER,
    in_channels=3,
    encoder_weights=ENCODER_WEIGHTS,
    classes=1,
    activation=ACTIVATION,
).to(device)

unet = smp.Unet(
    encoder_name=ENCODER,
    in_channels=3,
    encoder_weights=ENCODER_WEIGHTS,
    classes=1,
    activation=ACTIVATION,
).to(device)

**FCN**

In [None]:
# Fully Convolutional Network (FCN)

# This is a hand-written FCN that takes a 3-channel input and increases the number of channels.
# After each of the first two convolutions we apply an activation function and max pooling.

# After the third convolution, we upsample the feature maps using ConvTranspose2d.
# In the future, we can try bilinear interpolation.

class FCN(nn.Module):
    """Small convolutional encoder/decoder for per-pixel prediction.

    Encoder: three unpadded 3x3 convolutions with ReLU; the first two are each followed
    by 2x2 max pooling. Decoder: two stride-2 transposed convolutions, with a ReLU after
    the first. Because the convolutions are unpadded, the output is spatially smaller
    than the input (a 256x256 input gives a 240x240 output).
    """

    # Order in which the layers are applied in forward().
    _LAYER_ORDER = (
        'conv1', 'relu1', 'maxpool1',
        'conv2', 'relu2', 'maxpool2',
        'conv3', 'relu3',
        'upconv1', 'relu4',
        'upconv2',
    )

    def __init__(self, num_classes=1):
        super().__init__()

        conv_kwargs = dict(kernel_size=(3, 3), stride=1, padding=0)
        up_kwargs = dict(kernel_size=(2, 2), stride=2)

        # Encoder
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, **conv_kwargs)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, **conv_kwargs)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, **conv_kwargs)
        self.relu3 = nn.ReLU()

        # Decoder
        self.upconv1 = nn.ConvTranspose2d(in_channels=64, out_channels=32, **up_kwargs)
        self.relu4 = nn.ReLU()

        self.upconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=num_classes, **up_kwargs)

    def forward(self, x):
        """Apply the encoder then the decoder; no final activation is applied."""
        for layer_name in self._LAYER_ORDER:
            x = getattr(self, layer_name)(x)
        return x




In [None]:
# Start from an empty checkpoint directory: delete any previous one, then recreate it.
!rm -rf /kaggle/working/checkpoints
!mkdir /kaggle/working/checkpoints

 # TRAINING

In [None]:
# Channel Reduction

class ChanneReduction(nn.Module):
    """Map a 24-channel input to 3 channels with four stacked 1x1 convolutions.

    Channel widths are 24 -> 30 -> 20 -> 10 -> 3. No activation is applied between
    the convolutions. `num_classes` is accepted but not used.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        widths = (24, 30, 20, 10, 3)
        # Registers conv1 .. conv4 in order, one per consecutive pair of widths.
        for position, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(
                self,
                f'conv{position}',
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(1, 1)),
            )

    def forward(self, x):
        """Apply conv1 through conv4 in sequence."""
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = layer(x)
        return x


In [None]:
# Channel-reduction module applied to each 24-channel sample by ProcessedDataDataset.
# NOTE(review): no optimizer in the visible cells is given f_e's parameters, so it appears
# to stay at its random initialisation -- confirm that this is intended.
f_e = ChanneReduction()


class ProcessedDataDataset(torch.utils.data.Dataset):
    """View of a loader's underlying dataset with `f_e_model` applied to each input.

    Args:
        data_loader: object exposing ``.dataset``; items are read from that dataset by
            index and must be dicts with ``'mask'`` and ``'ash'`` entries.
        f_e_model: callable applied to each ``'ash'`` entry.
        transform: stored on the instance; not applied by this class.
    """

    def __init__(self, data_loader, f_e_model, transform):
        super().__init__()
        self.data_loader = data_loader
        self.f_e_model = f_e_model
        self.transform = transform

    def __len__(self):
        # __getitem__ indexes individual samples, so the length is the number of
        # samples. len(data_loader) is the number of BATCHES and would hide most of
        # the data.
        return len(self.data_loader.dataset)

    def __getitem__(self, idx):
        item = self.data_loader.dataset[idx]
        mask, ash = item['mask'], item['ash']

        # Pure preprocessing: no autograd graph is needed, and graph-carrying tensors
        # cannot be handed across DataLoader worker processes.
        with torch.no_grad():
            ash = self.f_e_model(ash)

        return {'mask': mask, 'ash': ash}


# Wrap the 24-channel datasets so that every sample passes through the channel-reduction
# module f_e and comes out with 3 channels.

# Training set
new_dataset = ProcessedDataDataset(train_dataloader, f_e, train_transform)
new_data_loader = DataLoader(new_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True)

# Validation set (batch size from CONFIG instead of a literal 16; validation transform
# instead of the training one)
new_dataset2 = ProcessedDataDataset(val_dataloader, f_e, val_transform)
new_data_loader2 = DataLoader(new_dataset2, batch_size=CONFIG['BATCH_SIZE'], shuffle=False)

# **Visualisation of Data**

In [None]:
def images_visualisation(batch):
    """Plot the first sample of a batch: input image, mask, and mask over the image.

    `batch` is a dict with ``'ash'`` and ``'mask'`` tensors whose axis 1 is moved to the
    end before plotting. Only index 0 of the batch is drawn.
    """
    # Move axis 1 to the end for imshow, then keep the first sample only.
    rgb_img = torch.moveaxis(batch['ash'], 1, -1)[0].detach().numpy()
    rgb_mask = torch.moveaxis(batch['mask'], 1, -1)[0].detach().numpy()

    plt.figure(figsize=(18, 6))
    image_ax, mask_ax, overlay_ax = (plt.subplot(1, 3, position) for position in (1, 2, 3))

    image_ax.imshow(rgb_img)
    image_ax.set_title('False color image')

    mask_ax.imshow(rgb_mask, interpolation='none')
    mask_ax.set_title('Ground truth contrail mask')

    # Semi-transparent mask drawn on top of the image.
    overlay_ax.imshow(rgb_img)
    overlay_ax.imshow(rgb_mask, alpha=.4, interpolation='none')
    overlay_ax.set_title('Contrail mask on false color image');


In [None]:
# Comparison between the two datasets: first a batch from the 24-channels-reduced-to-3
# dataset, then a batch from the native 3-channel dataset.
for loader in (new_data_loader, train_dataloader2):
    batch = next(iter(loader))
    images_visualisation(batch)

In [None]:
def data_visualisation(train):
    """Plot the loss histories recorded on a `training` object in three panels.

    Reads `train.epoch_losses`, `train.batch_losses` and `train.validation_loss`.
    """
    # (legend label, values, x-axis label, title) for each panel, left to right.
    panels = (
        ('Loss', train.epoch_losses, 'Epoch', 'Model Average Training Loss over Epochs'),
        ('Batch Losses', train.batch_losses, 'Batch', 'Batch Loss'),
        ('Loss', train.validation_loss, 'Epoch', 'Model Validation Loss over Epochs'),
    )

    plt.figure(figsize=(18, 6))
    for position, (column, values, x_label, title) in enumerate(panels, start=1):
        ax = plt.subplot(1, 3, position)
        # lineplot draws on the current axes, i.e. the subplot just created.
        sns.lineplot(data=pd.DataFrame({column: values}))
        ax.set_xlabel(x_label)
        ax.set_ylabel('Loss')
        ax.set_title(title)

    plt.show()

In [None]:
# Module-level IoU histories, kept so existing references still resolve.
# NOTE(review): fit() never computes an IoU; it appends one 0.0 placeholder per epoch to
# each list so they stay aligned with the loss histories. Replace with a real metric.
train_iou_list = []
valid_iou_list = []
valid_iou_sum = 0


class training:
    """Train and validate a model while recording loss histories.

    Args:
        model: the ``nn.Module`` to train.
        optimizer: optimizer that owns the model's parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.

    Attributes:
        epoch_losses: mean training loss, one float per epoch.
        batch_losses: training loss of every batch, across all epochs.
        validation_loss: mean validation loss, one float per epoch.
        learning_rates: learning rate of the first parameter group at the start of
            each epoch.
    """

    # Class-level lists retained from the original definition; fit() does not use them.
    train_iou_list, valid_iou_list = [], []

    def __init__(self, model, optimizer, loss_fn):
        self.validation_loss = []
        self.batch_losses = []
        self.epoch_losses = []
        self.learning_rates = []
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def _model_device(self):
        """Device of the model's first parameter (CPU if the model has no parameters)."""
        first_param = next(iter(self.model.parameters()), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    @staticmethod
    def _prepare_target(mask, out, is_fcn):
        """Cast the mask to float on the prediction's device, resizing it for FCN models."""
        mask = mask.float().to(out.device)
        if is_fcn:
            # The unpadded FCN shrinks its input; resize the target to whatever spatial
            # size the model actually produced (240x240 for 256x256 inputs).
            mask = F.interpolate(mask, size=out.shape[-2:], mode='bilinear', align_corners=True)
        return mask

    @staticmethod
    def _mean(values):
        """Arithmetic mean of a list of floats; NaN for an empty list."""
        return sum(values) / len(values) if values else float('nan')

    def fit(self, new_data_loader, new_data_loader2, is_implemented, is_fcn, num_epochs=10):
        """Run the training and validation loop.

        Args:
            new_data_loader: iterable of training batches (dicts with ``'mask'`` and ``'ash'``).
            new_data_loader2: iterable of validation batches with the same layout.
            is_implemented: must be truthy. The falsy branch depended on one-hot masks
                that were never defined and used to fail with a NameError mid-epoch.
            is_fcn: when truthy, the target mask is resized to the model's output size.
            num_epochs: number of epochs to run (default 10, the previous fixed value).

        Raises:
            NotImplementedError: if `is_implemented` is falsy.
        """
        if not is_implemented:
            raise NotImplementedError(
                "is_implemented=False relied on one-hot masks that were never defined; "
                "pass is_implemented=True to train against the raw binary masks."
            )

        # Follow the model's own device instead of a global `device` / hard-coded .cuda().
        device = self._model_device()

        for epoch in range(num_epochs):
            print('\nEpoch: {}'.format(epoch + 1))

            current_lr = self.optimizer.param_groups[0]['lr']
            print("New learning rate: {}".format(current_lr))
            self.learning_rates.append(current_lr)

            # ---- training ----
            self.model.train()
            epoch_batch_losses = []
            for batch in new_data_loader:
                ash = batch['ash'].to(device)
                self.optimizer.zero_grad()

                out = self.model(ash)
                target = self._prepare_target(batch['mask'], out, is_fcn)
                loss = self.loss_fn(out, target)

                loss.backward()
                # Step the optimizer this object was given, not a same-named global.
                self.optimizer.step()

                # Store plain floats so no autograd graph is kept alive.
                loss_value = loss.item()
                self.batch_losses.append(loss_value)
                epoch_batch_losses.append(loss_value)

            mean_epoch_loss = self._mean(epoch_batch_losses)
            self.epoch_losses.append(mean_epoch_loss)
            print('Train Epoch: {} Average Loss: {:.6f}'.format(epoch + 1, mean_epoch_loss))
            train_iou_list.append(0.0)  # placeholder, see module-level note

            print("\nValidating")

            # ---- validation ----
            self.model.eval()
            validation_batch_losses = []
            with torch.inference_mode():
                for batch in new_data_loader2:
                    ash2 = batch['ash'].to(device)
                    out2 = self.model(ash2)
                    target2 = self._prepare_target(batch['mask'], out2, is_fcn)
                    validation_batch_losses.append(self.loss_fn(out2, target2).item())

            avg_loss = self._mean(validation_batch_losses)
            print('Validation Epoch: {} Average Loss: {:.6f}'.format(epoch + 1, avg_loss))
            self.validation_loss.append(avg_loss)
            valid_iou_list.append(0.0)  # placeholder, see module-level note
                
        

In [None]:
# define loss function
# Dice loss that expects probabilities rather than logits (from_logits=False).
dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=False)

# Plain SGD over all DeepLabV3+ parameters.
optimizer = torch.optim.SGD(deeplabv3.parameters(), lr=0.001)

train = training(deeplabv3, optimizer, dice_loss)
train.fit(train_dataloader2, val_dataloader2, is_implemented=True, is_fcn=False)
data_visualisation(train)


In [None]:


# Dice loss that expects probabilities rather than logits (from_logits=False).
dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=False)

# Plain SGD over all PSPNet parameters.
optimizer = torch.optim.SGD(pspnet2.parameters(), lr=0.001)

print('This is PSPNET')

train = training(pspnet2, optimizer, dice_loss)
train.fit(train_dataloader2, val_dataloader2, is_implemented=True, is_fcn=False)

data_visualisation(train)




In [None]:
# Place the model on the notebook's selected device instead of a hard-coded 'cuda', so the
# model construction also works on a CPU-only session.
fcn = FCN().to(device)

# FCN.forward applies no final activation, so its outputs are logits: use the
# logits-based BCE loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(fcn.parameters(), lr=0.08, momentum=0.9)

train = training(fcn, optimizer, criterion)
# is_fcn=True: the target masks are resized to the FCN's smaller output.
train.fit(train_dataloader2, val_dataloader2, True, True)
data_visualisation(train)

In [None]:
print('This is a UNET')

# Use the notebook's selected device instead of a hard-coded 'cuda'.
unet = unet.to(device)

# `unet` was built with activation='sigmoid', so it already outputs probabilities.
# BCEWithLogitsLoss would apply a second sigmoid on top of them; BCELoss is the matching
# probability-space loss. (The unused DiceLoss that was created here has been dropped.)
criterion = nn.BCELoss()

# define optimizer
optimizer = torch.optim.SGD([
    dict(params=unet.parameters(), lr=0.001),
])

train = training(unet, optimizer, criterion)
train.fit(train_dataloader2, val_dataloader2, True, False)
data_visualisation(train)



# Ensemble Deep Learning Method


In [None]:

# Stacking ensemble: it feeds the same input to deeplabv3, pspnet and unet, concatenates
# their single-channel predictions along the channel axis and combines them with a 1x1
# convolution.

class Ensemble_Stacking(nn.Module):
    """Combine three segmentation models with a learned 1x1 convolution.

    Args:
        deeplabv3, pspnet2, unet: base models; each must map the input to a
            single-channel prediction of identical spatial size.
        num_classes: number of output channels of the combining convolution.

    The output is the raw result of the 1x1 convolution (logits), which is what
    ``nn.BCEWithLogitsLoss`` expects. The previous trailing ReLU clamped every output
    to >= 0, so sigmoid(output) could never fall below 0.5.
    """

    def __init__(self, deeplabv3, pspnet2, unet, num_classes):
        super(Ensemble_Stacking, self).__init__()

        # Base models are registered as sub-modules, so their parameters are part of
        # this module's parameters().
        self.deeplabv3 = deeplabv3
        self.pspnet = pspnet2
        self.unet = unet

        # 3 input channels = one prediction channel per base model.
        self.stacking_model = nn.Sequential(
            nn.Conv2d(3, num_classes, kernel_size=1),
        )

    def forward(self, ash):
        # Keep the batch axis intact: the old .squeeze(0) removed it for batches of
        # size 1, after which cat(dim=1) joined the predictions along height.
        predictions = [model(ash) for model in (self.deeplabv3, self.pspnet, self.unet)]

        # Stack the predictions along the channel dimension -> (batch, 3, H, W)
        stacked_predictions = torch.cat(predictions, dim=1)

        return self.stacking_model(stacked_predictions)


# Build the stacking ensemble from the three base models and move it to the selected device.
ensemble_model = Ensemble_Stacking(deeplabv3, pspnet2, unet, 1)
ensemble_model.to(device)

dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=False)
criterion = nn.BCEWithLogitsLoss()

# SGD over every parameter of the ensemble, which includes the three base models.
optimizer = torch.optim.SGD(ensemble_model.parameters(), lr=0.001)

train = training(ensemble_model, optimizer, criterion)
train.fit(train_dataloader2, val_dataloader2, is_implemented=True, is_fcn=False)
data_visualisation(train)

# 24 channels transformed to 3 channels approach

In [None]:
# To reproduce the performance of 24 channels transformed to 3 channels, uncomment the code below

# From here, 
# define loss function
#dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=False)

# define optimizer
#optimizer = torch.optim.SGD([ 
#    dict(params=pspnet.parameters(), lr=0.001),
#])

#train = training(pspnet, optimizer, dice_loss)
#train.fit(new_data_loader, new_data_loader2, True, False)

#data_visualisation(train)

# End here. 

# Prediction

Predict and Display the image

In [None]:
# Predict on one validation batch and show, per sample, the input image, the ground-truth
# mask and the predicted mask.
with torch.no_grad():
    batch = next(iter(val_dataloader2))

    ash = batch['ash'].to("cpu")
    mask = batch['mask'].to("cpu")

    # The current model is deeplabv3.
    # To visualise another model's predictions, replace deeplabv3 with the name of that
    # model, such as fcn, unet or pspnet2.
    # NOTE: nn.Module.to() moves the model in place, so deeplabv3 itself is on the CPU
    # after this cell.
    model = deeplabv3.to("cpu")
    # Make inference mode explicit rather than relying on the state left by training.
    model.eval()
    prediction = model(ash)
    pred_mask = prediction

    # Move axis 1 to the end for imshow.
    image = torch.moveaxis(ash, 1, -1)
    mask = torch.moveaxis(mask, 1, -1)
    pred_mask = torch.moveaxis(prediction, 1, -1)

    image, mask, pred_mask = image.cpu(), mask.cpu(), pred_mask.detach().cpu()

    # Iterate over the samples actually present in the batch; a fixed range(16) raised
    # IndexError whenever the batch was smaller.
    for i in range(len(image)):

        img = image[i].detach().numpy()
        mask_np = mask[i].detach().numpy()
        pred = pred_mask[i].detach().numpy()

        plt.figure(figsize=(18, 6))

        ax = plt.subplot(1, 3, 1)
        ax.imshow(img)
        ax.set_title('False color image')

        ax = plt.subplot(1, 3, 2)
        ax.imshow(mask_np, interpolation='none')
        ax.set_title('Ground truth contrail mask')

        ax = plt.subplot(1, 3, 3)
        ax.imshow(pred, interpolation='none')
        ax.set_title('Predicted_Mask')
    
    
