# Modules

### Imports

In [None]:
import torch 
import torchvision
import os 	
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader 
import torch.nn as nn
import torch.nn.functional as F
import timm
from tqdm import tqdm

### Segmentation head

In [None]:
class SegmentationHead(nn.Sequential):
    """Final prediction head: Conv2d -> optional bilinear upsample -> activation.

    Args:
        in_channels: channels of the incoming decoder feature map.
        out_channels: number of predicted channels.
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``,
            which keeps the spatial size unchanged for odd kernels.
        activation: name passed to ``Activation`` (e.g. None or "sigmoid").
        upsampling: bilinear scale factor; values <= 1 use ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        head_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2),
        ]
        if upsampling > 1:
            head_layers.append(nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=False))
        else:
            head_layers.append(nn.Identity())
        head_layers.append(Activation(activation))
        super().__init__(*head_layers)

### Base modules (Conv-BN-ReLU, activation)

In [None]:
class Conv2dReLU(nn.Sequential):
    """Conv2d followed by BatchNorm2d (or Identity) and an in-place ReLU.

    The convolution carries a bias only when ``use_batchnorm`` is False. When
    batch norm is disabled, an ``nn.Identity`` takes its slot so the module
    always has three layers: conv, norm, relu.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, use_batchnorm=True):
        layers = (
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_batchnorm,
            ),
            nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity(),
            nn.ReLU(inplace=True),
        )
        super(Conv2dReLU, self).__init__(*layers)
class Activation(nn.Module):
    """Output activation selected by name.

    Args:
        name: None or "identity" for a pass-through, "sigmoid" for ``nn.Sigmoid``.
        **params: forwarded to ``nn.Identity`` (which ignores them).

    Raises:
        ValueError: if ``name`` is not one of the supported values. Failing here
            is clearer than an AttributeError later in ``forward``.
    """

    def __init__(self, name, **params):
        super().__init__()
        if name is None or name == "identity":
            self.activation = nn.Identity(**params)
        elif name == "sigmoid":
            self.activation = nn.Sigmoid()
        else:
            raise ValueError(f"Unsupported activation: {name!r}; expected None, 'identity' or 'sigmoid'")

    def forward(self, x):
        return self.activation(x)

### Config

In [None]:
import yaml

# Notebook configuration. Every key looked up later must exist here:
# model.dropout is read when building Model, train.num_workers by the DataLoaders.
cfg = yaml.safe_load("""

model:
  encoder: resnest26d.gluon_in1k
  pretrained: True
  decoder_channels: [256, 128, 64, 32, 16]
  dropout: 0.0

train:
  num_workers: 2

""")

## U-Net decoder

In [None]:
class DecoderBlock(nn.Module):
    """One U-Net decoder stage: upsample, concatenate the skip, two Conv-BN-ReLU.

    Args:
        in_channels: channels of the incoming (deeper) feature map.
        skip_channels: channels of the skip feature map (0 when there is none).
        out_channels: channels produced by both convolutions.
        use_batchnorm: forwarded to ``Conv2dReLU``.
        dropout: dropout probability applied to the skip features only.
    """

    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, dropout=0):
        super().__init__()
        conv_in_channels = in_channels + skip_channels
        self.conv1 = Conv2dReLU(conv_in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv2 = Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.dropout_skip = nn.Dropout(p=dropout)

    def forward(self, x, skip=None):
        if skip is None:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            # Upsample straight to the skip's spatial size. This equals a x2
            # upsample when sizes are exact multiples, and still lines up for
            # odd sizes, where a fixed x2 would make torch.cat fail.
            x = F.interpolate(x, size=skip.shape[-2:], mode='nearest')
            skip = self.dropout_skip(skip)
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class UnetDecoder(nn.Module):
    """U-Net decoder: a chain of DecoderBlocks fed by encoder skip connections.

    Args:
        encoder_channels: channels of each encoder feature map. The last entry
            is treated as the deepest map (the decoder input), so this is
            presumably ordered shallow -> deep, as the encoder returns them.
        decoder_channels: output channels of each decoder block. Block i uses
            the i-th skip counted from the deep end; the last block has none.
        use_batchnorm: forwarded to every DecoderBlock.
        dropout: skip-connection dropout probability for every DecoderBlock.
    """

    def __init__(self, encoder_channels, decoder_channels, use_batchnorm=True, dropout=0):
        super().__init__()
        # Work deepest-first: the deepest feature map is the decoder input.
        encoder_channels = encoder_channels[::-1]
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        # Every remaining encoder stage supplies one skip; the final block has none.
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels
        self.center = nn.Identity()
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, use_batchnorm=use_batchnorm, dropout=dropout)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, features):
        features = features[::-1]
        head = features[0]
        skips = features[1:]
        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
        return x

## Model: timm ResNeSt-26d encoder + U-Net decoder

In [None]:
class Model(nn.Module):
    """Binary segmentation model: timm feature encoder + U-Net decoder + sigmoid head.

    Args:
        cfg: dict whose 'model' section has 'encoder' (timm model name),
            'pretrained' (bool), 'decoder_channels' (one entry per encoder
            feature map) and, optionally, 'dropout' (defaults to 0).
        pretrained: pretrained encoder weights are requested only when both
            this argument and cfg['model']['pretrained'] are truthy.
        tta: unused; kept for interface compatibility.

    Raises:
        ValueError: if the encoder's feature count differs from
            len(decoder_channels).
    """

    def __init__(self, cfg, pretrained, tta=None):
        super().__init__()
        model_cfg = cfg['model']
        name = model_cfg['encoder']
        # 'dropout' is optional; 0 matches DecoderBlock's own default.
        dropout = model_cfg.get('dropout', 0)
        pretrained = pretrained and model_cfg['pretrained']

        self.encoder = timm.create_model(name, features_only=True, pretrained=pretrained)
        encoder_channels = self.encoder.feature_info.channels()

        decoder_channels = model_cfg['decoder_channels']
        print('Encoder channels:', name, encoder_channels)
        print('Decoder channels:', decoder_channels)

        if len(encoder_channels) != len(decoder_channels):
            raise ValueError(
                f"Encoder yields {len(encoder_channels)} feature maps but "
                f"{len(decoder_channels)} decoder channels were configured"
            )

        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            dropout=dropout,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=1, activation="sigmoid", kernel_size=3,
        )

        self._initialize_decoder(self.decoder)

    @staticmethod
    def _initialize_decoder(module):
        """Initialise decoder weights in place.

        The original code called an ``initialize_decoder`` that is never defined
        in this notebook. This mirrors the scheme of
        segmentation_models_pytorch's function of that name: Kaiming-uniform
        convs, unit/zero BatchNorm, Xavier linear layers, zero biases.
        """
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.encoder(x)
        decoder_output = self.decoder(features)
        y_pred = self.segmentation_head(decoder_output)
        return y_pred

# Dataset/DataLoader

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
#Helper function to create fc images
_T11_BOUNDS = (243, 303)
_CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
_TDIFF_BOUNDS = (-4, 2)

def normalize_range(data, bounds):
    """Maps data to the range [0, 1]."""
    return (data - bounds[0]) / (bounds[1] - bounds[0])



In [None]:
# Dataset class pairing false-colour images with human pixel masks.
class ContrailsDataset(Dataset):
    """False-colour images and masks for the contrails dataset.

    Each record is read from ``os.path.join(path, split, record_id)``, where
    split is "train" or "valid". The record directory must contain
    ``band_11.npy``, ``band_14.npy``, ``band_15.npy`` and
    ``human_pixel_masks.npy``.

    Args:
        path: dataset root directory.
        arr: record ids (sub-directory names) for the split.
        img_size: stored on the instance; not used for resizing here.
        train: True -> "train" split truncated to the first 18_000 ids;
            False -> "valid" split truncated to the first 1624 ids.
        batch_size: stored on the instance; not used by the dataset itself.
    """

    # Number of record ids kept per split (the id lists are truncated).
    _TRAIN_LIMIT = 18_000
    _VALID_LIMIT = 1624

    def __init__(self, path, arr, img_size, train=True, batch_size=16):
        self.arr = arr
        self.path = path
        self.img_size = img_size
        self.train = train
        self.batch_size = batch_size
        limit = self._TRAIN_LIMIT if train else self._VALID_LIMIT
        self.files = self.arr[0:limit]

    def __len__(self):
        return len(self.files)

    def augment(self, path1):
        """Load one record and build its false-colour image and mask.

        Despite the name, no augmentation is applied.

        Args:
            path1: directory of a single record.

        Returns:
            (im, ma): ``im`` is the clipped [0, 1] false-colour stack, sliced at
            index 4 of its last axis (presumably the time axis -- TODO
            confirm). ``ma`` is the raw ``human_pixel_masks.npy`` array.
        """
        b11 = np.load(os.path.join(path1, "band_11.npy"))
        b14 = np.load(os.path.join(path1, "band_14.npy"))
        b15 = np.load(os.path.join(path1, "band_15.npy"))
        ma = np.load(os.path.join(path1, "human_pixel_masks.npy"))

        # Module-level bounds shared with normalize_range's helpers.
        r = normalize_range(b15 - b14, _TDIFF_BOUNDS)
        g = normalize_range(b14 - b11, _CLOUD_TOP_TDIFF_BOUNDS)
        b = normalize_range(b14, _T11_BOUNDS)
        fc = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        im = fc[..., 4]

        return im, ma

    def __getitem__(self, index):
        split = "train" if self.train else "valid"
        record_path = os.path.join(self.path, split, self.files[index])
        img, mask = self.augment(record_path)
        # Move the last axis first (HWC -> CHW) as torch conv layers expect.
        img = img.transpose((2, 0, 1))
        mask = mask.transpose((2, 0, 1))
        return img, mask

In [None]:
# Dataset location, image size and batch sizes (training / validation).
path_to_data = "/kaggle/input/google-research-identify-contrails-reduce-global-warming/"
img_size = (256, 256)
batch_size = 16
batch_size1 = 8

# Record ids are the sub-directory names inside each split folder.
train_id_list = os.listdir(path_to_data + "train")
valid_id_list = os.listdir(path_to_data + "valid")

train_ds = ContrailsDataset(path_to_data, train_id_list, img_size, train=True, batch_size=batch_size)
valid_ds = ContrailsDataset(path_to_data, valid_id_list, img_size, train=False, batch_size=batch_size1)

In [None]:
# Fall back to 2 workers if the config has no train.num_workers entry.
num_workers = cfg.get('train', {}).get('num_workers', 2)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
# Validation is not shuffled: Dice is pooled per batch, so a fixed batch
# composition keeps the epoch-level validation score reproducible.
valid_loader = DataLoader(valid_ds, batch_size=batch_size1, shuffle=False, num_workers=num_workers, pin_memory=True)

### Check Train/Valid Images

In [None]:
# Pull one batch from each loader to sanity-check shapes and contents.
train_batch = next(iter(train_loader))
img, mask = train_batch
valid_batch = next(iter(valid_loader))
img1, mask1 = valid_batch

In [None]:
def show(img):
    """Return the first image of a batch as a channels-last NumPy array.

    Args:
        img: batch tensor where ``img[0]`` is a (C, H, W) image.

    Returns:
        np.ndarray of shape (H, W, C), ready for ``plt.imshow``. The original
        version computed this but returned None.
    """
    # detach/cpu so the call also works on GPU tensors or ones needing grad.
    first = img[0].detach().cpu().numpy()
    return first.transpose((1, 2, 0))

# Preview the first training image of the sampled batch.
prep_img = show(img)
plt.imshow(prep_img, interpolation="none")
plt.show()

# Loss and metric: Dice coefficient

In [None]:
# Dice coefficient (metric) and Dice loss.
def dice_coeff(y_true, y_pred):
    """Soft Dice coefficient between a target and a prediction.

    All elements are pooled into one score: a batch is scored as a whole, not
    averaged per sample. A smoothing term of 1e-5 keeps the result defined
    (equal to 1) when both inputs are all zeros.

    Args:
        y_true: target tensor of any numeric/bool dtype. It is moved to
            ``y_pred``'s device and dtype, so no global ``device`` is needed.
        y_pred: prediction tensor, presumably the same shape as ``y_true``
            with values in [0, 1].

    Returns:
        0-dim tensor holding the Dice coefficient.
    """
    y_true = y_true.to(device=y_pred.device, dtype=y_pred.dtype)
    smooth = 1e-5
    intersection = torch.sum(y_true * y_pred)
    union = torch.sum(y_true) + torch.sum(y_pred)
    return (2.0 * intersection + smooth) / (union + smooth)

class DiceLoss(torch.nn.Module):
    """Loss equal to ``1 - dice_coeff(y_true, y_pred)``; note the target comes first."""

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, y_true, y_pred):
        return 1.0 - dice_coeff(y_true, y_pred)

# Training

In [None]:
# Set up device, model and optimizer (the model could be wrapped for
# multi-GPU training here if needed). The model is created and moved to the
# device before the optimizer collects its parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(cfg, pretrained=True).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
best_val = 0

In [None]:
criterion = DiceLoss()
num_epochs = 60

# Moving the model once is enough; repeating .to(device) every epoch was a no-op.
model.to(device)

for epoch in range(num_epochs):
    # ---------------- training ----------------
    model.train()
    epoch_loss = 0
    dice_score = 0
    total_batches = len(train_loader)

    loop = tqdm(enumerate(train_loader), total=total_batches, leave=False)
    loop.set_description(f"Epoch[{epoch+1}/{num_epochs}]")

    for batch_idx, (img, mask) in loop:
        optimizer.zero_grad()
        img = img.to(device)
        mask = mask.to(device)

        masks_pred = model(img)
        loss = criterion(mask, masks_pred)
        loss.backward()
        optimizer.step()

        # criterion is DiceLoss (1 - dice_coeff), so the batch Dice follows from
        # the loss instead of recomputing the metric with autograd tracking.
        batch_loss = loss.item()
        batch_dice = 1.0 - batch_loss
        epoch_loss += batch_loss
        dice_score += batch_dice
        loop.set_postfix(loss=batch_loss, dice_coeff=batch_dice)

    # Average loss and Dice over the epoch (max(..., 1) guards an empty loader).
    avg_epoch_loss = epoch_loss / max(total_batches, 1)
    avg_dice_score = dice_score / max(total_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Average Loss: {avg_epoch_loss:.4f} - Average Dice Coefficient: {avg_dice_score:.4f}")

    # ---------------- validation ----------------
    model.eval()
    val_dice = 0

    loop2 = tqdm(enumerate(valid_loader), total=len(valid_loader), leave=False)

    with torch.no_grad():
        for batch_idx, (v_img, v_mask) in loop2:
            v_img, v_mask = v_img.to(device), v_mask.to(device)
            v_out = model(v_img)
            val_dice += dice_coeff(v_mask, v_out).item()

    average_dice = val_dice / max(len(valid_loader), 1)

    # Overwrite the checkpoint only when validation Dice improves, so
    # 'cont_best_model.pth' really holds the best model seen so far.
    if average_dice > best_val:
        best_val = average_dice
        torch.save(model.state_dict(), 'cont_best_model.pth')
        print("Model has been saved")
    print(f"Validation Dice Coefficient: {average_dice}")
        