In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python

# Input data files are available in the read-only "../input/" directory (i.e. /kaggle/input/).
# This notebook reads its preprocessed contrail arrays from CFG.GLOBAL_PATH under that directory.

import os


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, ConcatDataset
import os
import numpy as np
!pip install -q segmentation_models_pytorch
!pip download segmentation_models_pytorch -d
!pip install -q torchmetrics
from torchmetrics.classification import Dice
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.checkpoint as C
import torchvision.transforms.functional as fn
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torchvision import models
!pip install -q torchsummary
from torchsummary import summary
from sklearn.model_selection import KFold

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

[0m
Usage:   
  pip download [options] <requirement specifier> [package-index-options] ...
  pip download [options] -r <requirements file> [package-index-options] ...
  pip download [options] <vcs project url> ...
  pip download [options] <local project path> ...
  pip download [options] <archive url/path> ...

-d option requires 1 argument
[0m

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


[0m

device(type='cuda')

In [2]:
class CFG:
    """Notebook-wide settings, read as class attributes (``CFG.lr``, ``CFG.epochs``, ...)."""

    # --- data -----------------------------------------------------------------
    GLOBAL_PATH = '/kaggle/input/google-research-identify-contrails-preprocessing'
    resize = 256          # side length passed to the crop/resize augmentation

    # --- model ----------------------------------------------------------------
    model = 'UNET'        # informational; the training cell builds smp.Unet directly
    encoder = 'tu-tf_efficientnetv2_s.in21k_ft_in1k'
    weights = 'imagenet'

    # --- run mode -------------------------------------------------------------
    final = True          # True: one run on all data; False: k-fold cross-validation
    k_folds = 5

    # --- optimisation ---------------------------------------------------------
    batch_size = 32
    optimizer = 'Adam'    # informational; the training cell builds optim.Adam directly
    lr = 5e-4             # NOTE(review): OneCycleLR(max_lr=0.001) in the training cell overrides this
    epochs = 50

In [3]:
class ContrailDataset(Dataset):
    """Contrail segmentation samples stored as one directory of ``.npy`` files per record.

    Layout read by this class::

        <base_dir>/<data_type>/<record_id>/image.npy
        <base_dir>/<data_type>/<record_id>/human_pixel_masks.npy

    Each item is ``(image, mask)``. In both tensors the last axis is moved to the
    front, and the mask is cast to float. The arrays are presumably an (H, W, C)
    false-colour image and an (H, W, 1) mask — TODO confirm against the
    preprocessing notebook.
    """

    def __init__(self, base_dir, data_type='train', transform=None):
        """
        Args:
            base_dir: root directory containing the ``data_type`` sub-directory.
            data_type: sub-directory name, e.g. ``'train_images'``.
            transform: optional callable invoked as ``transform(image=..., mask=...)``
                that returns a dict with ``'image'`` and ``'mask'`` (albumentations style).
        """
        self.base_dir = base_dir
        self.data_type = data_type
        # Sorted so the index -> record mapping is stable across runs and machines.
        # os.listdir order is arbitrary, which would make KFold splits irreproducible.
        self.record = sorted(os.listdir(os.path.join(self.base_dir, self.data_type)))
        self.transform = transform

    def __len__(self):
        return len(self.record)

    def __getitem__(self, idx):
        record_id = self.record[idx]
        record_dir = os.path.join(self.base_dir, self.data_type, record_id)

        false_color = np.load(os.path.join(record_dir, 'image.npy'))
        human_pixel_mask = np.load(os.path.join(record_dir, 'human_pixel_masks.npy'))

        if self.transform is not None:
            transformed = self.transform(image=false_color, mask=human_pixel_mask)
            false_color = transformed['image']
            human_pixel_mask = transformed['mask']

        # Flip/rot90-style transforms may return negative-stride views, which
        # torch.from_numpy rejects. Make the buffers contiguous first.
        false_color = torch.from_numpy(np.ascontiguousarray(false_color))
        human_pixel_mask = torch.from_numpy(np.ascontiguousarray(human_pixel_mask))

        # Move the trailing (channel) axis to the front: channels-last -> channels-first.
        false_color = torch.moveaxis(false_color, -1, 0)
        human_pixel_mask = torch.moveaxis(human_pixel_mask, -1, 0)

        return false_color, human_pixel_mask.float()

In [4]:

# Training-time augmentation: random horizontal/vertical flips and 90-degree
# rotations, then a random crop (75-100% of the area) resized to CFG.resize x CFG.resize.
_train_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(),
    A.RandomResizedCrop(CFG.resize, CFG.resize, scale=[.75, 1.0]),
]
albu_transform = A.Compose(_train_augmentations)


In [5]:
training_data = ContrailDataset(base_dir=CFG.GLOBAL_PATH, data_type='train_images', transform=albu_transform)
validation_data = ContrailDataset(base_dir=CFG.GLOBAL_PATH, data_type='validate_images', transform=albu_transform)

# Both official splits are pooled. The final run trains on all of it, and the CV run
# re-splits it into folds, so the augmenting transform is applied to both parts here.
full_data = ConcatDataset([training_data, validation_data])

# Fixed seed so fold membership is reproducible and each saved `model_fold_{k}_*`
# checkpoint can be matched back to the indices it was validated on.
KFOLD_SEED = 42
kfold = KFold(n_splits=CFG.k_folds, shuffle=True, random_state=KFOLD_SEED)

In [6]:
def dice_avg(y_p, y_t, smooth=1e-3):
    """Mean per-sample Dice score over a batch.

    Dice is computed separately for every (sample, channel) slice by summing over
    dims 2 and 3, then averaged. This assumes 4-D (N, C, H, W) inputs — TODO confirm
    for other ranks. ``smooth`` prevents 0/0 on empty slices (an empty prediction on
    an empty target scores 1).
    """
    intersection = torch.sum(y_p * y_t, dim=(2, 3))
    union = torch.sum(y_p, dim=(2, 3)) + torch.sum(y_t, dim=(2, 3))
    score = (2 * intersection + smooth) / (union + smooth)
    return torch.mean(score)


def dice_loss_avg(y_p, y_t, smooth=1e-3):
    """``1 - dice_avg``: per-sample Dice loss averaged over the batch."""
    return 1 - dice_avg(y_p, y_t, smooth=smooth)


def dice_global(y_p, y_t, smooth=1e-3):
    """Dice score pooled over every element of the batch.

    Args:
        y_p: predictions. Soft values (e.g. sigmoid probabilities) are allowed,
            which keeps the score differentiable for use as a loss.
        y_t: targets, same shape as (or broadcastable with) ``y_p``.
        smooth: additive constant in numerator and denominator. It avoids 0/0 when
            both tensors are all zeros; the score is then 1.

    Returns:
        0-dim tensor ``(2*sum(y_p*y_t) + smooth) / (sum(y_p) + sum(y_t) + smooth)``.
    """
    intersection = torch.sum(y_p * y_t)
    union = torch.sum(y_p) + torch.sum(y_t)
    return (2.0 * intersection + smooth) / (union + smooth)


def dice_loss_global(y_p, y_t, smooth=1e-3):
    """``1 - dice_global``. ``smooth`` is forwarded; its default matches ``dice_global``."""
    return 1 - dice_global(y_p, y_t, smooth=smooth)

In [7]:
@torch.no_grad()
def validate_data(val_loader, net=None):
    """Score a segmentation model on a validation loader with torchmetrics' ``Dice``.

    Args:
        val_loader: iterable of ``(image, mask)`` batches. ``mask`` is cast to int
            for the metric; presumably a binary 0/1 mask — TODO confirm.
        net: model to evaluate. Defaults to the notebook-global ``model``, so existing
            ``validate_data(val_loader)`` calls keep working.

    Returns:
        float: Dice accumulated over every batch in the loader. This is one global
        score, not a mean of per-batch scores, so a small final batch is not
        over-weighted. Returns ``float('nan')`` if the loader yields no batches.

    The model is switched to eval mode for the pass, so BatchNorm uses and does not
    update its running statistics. It is restored to its previous mode afterwards.
    """
    if net is None:
        net = model  # notebook-global, resolved at call time
    was_training = net.training
    net.eval()
    torch.cuda.empty_cache()

    dice = Dice().to(device)
    bar = tqdm(val_loader)
    n_batches = 0
    try:
        for image, mask in bar:
            image, mask = image.to(device), mask.to(device)
            pred_mask = net(image)
            # Calling the metric returns this batch's score AND accumulates the
            # global state that compute() reports.
            dice(pred_mask, mask.int())
            n_batches += 1
            bar.set_postfix(ValidDiceGlobal=f'{dice.compute().item():.4f}')
    finally:
        net.train(was_training)

    if n_batches == 0:
        return float('nan')
    return dice.compute().item()

def _build_model():
    """Return a fresh sigmoid-output UNet (3 input channels, 1 class) moved to `device`."""
    net = smp.Unet(
        encoder_name=CFG.encoder,
        encoder_weights=CFG.weights,
        in_channels=3,
        classes=1,
        activation="sigmoid")
    return net.to(device)


def _make_optimizer_and_scheduler(net, steps_per_epoch):
    """Return ``(Adam optimizer, OneCycleLR scheduler)`` for `net`; the scheduler is stepped per batch.

    NOTE: OneCycleLR overwrites the optimizer's lr. It starts at max_lr / div_factor
    (0.001 / 25 = 4e-5 with the defaults, matching the logged 'learning rate') and
    peaks at max_lr, so CFG.lr has no effect on the schedule.
    """
    opt = optim.Adam(net.parameters(), lr=CFG.lr)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=0.001, steps_per_epoch=steps_per_epoch, epochs=CFG.epochs)
    return opt, sched


def _train_one_epoch(net, loader, opt, sched, loss_fn):
    """Train `net` for one pass over `loader`; return the mean batch loss (nan if empty)."""
    net.train()
    bar = tqdm(loader)
    loss_sum = 0.0
    n_batches = 0
    for image, mask in bar:
        image, mask = image.to(device), mask.to(device)
        opt.zero_grad()
        pred_mask = net(image)
        loss = loss_fn(pred_mask, mask)
        loss.backward()
        opt.step()
        sched.step()  # OneCycleLR is defined in optimizer steps, not epochs
        batch_loss = loss.item()
        loss_sum += batch_loss
        n_batches += 1
        bar.set_postfix(TrainLossGlobal=f'{loss_sum/n_batches:.4f}',
                        TrainLossBatch=f'{batch_loss}')
    return loss_sum / n_batches if n_batches else float('nan')


if CFG.final:
    # Final run: train on the pooled train+validate data with no hold-out.
    model = _build_model()
    # shuffle=True: full_data is a ConcatDataset laid out as [all train, all validate].
    # Without shuffling, every epoch would replay that fixed order in identical batches.
    train_loader = torch.utils.data.DataLoader(
        full_data,
        batch_size=CFG.batch_size,
        shuffle=True)
    optimizer, scheduler = _make_optimizer_and_scheduler(model, len(train_loader))
    print(f'learning rate: {optimizer.param_groups[0]["lr"]}')

    train_loss_history = []
    for epoch in range(1, CFG.epochs + 1):
        print(f'________epoch: {epoch}________')
        epoch_loss = _train_one_epoch(model, train_loader, optimizer, scheduler, dice_loss_global)
        train_loss_history.append(epoch_loss)
        print(epoch_loss)
        if epoch % 5 == 0:
            torch.save(model, f'epoch_{epoch}_loss_{epoch_loss:.4f}.pt')
    # The loop has already written this checkpoint when CFG.epochs is a multiple of 5.
    if train_loss_history and CFG.epochs % 5 != 0:
        torch.save(model, f'epoch_{CFG.epochs}_loss_{train_loss_history[-1]:.4f}.pt')

else:
    import copy

    # Cross-validation: fold validation must be scored on un-augmented images. Build a
    # mirror of full_data whose parts only resize (deterministic). copy.copy shares each
    # part's `record` list, so index i is the same sample in full_data and full_data_eval.
    eval_transform = A.Compose([A.Resize(CFG.resize, CFG.resize)])
    eval_parts = []
    for part in full_data.datasets:
        part_eval = copy.copy(part)
        part_eval.transform = eval_transform
        eval_parts.append(part_eval)
    full_data_eval = ConcatDataset(eval_parts)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(full_data)):
        print(f'Fold: {fold}')
        model = _build_model()

        train_loader = torch.utils.data.DataLoader(
            full_data,
            batch_size=CFG.batch_size,
            sampler=torch.utils.data.SubsetRandomSampler(train_ids))
        val_loader = torch.utils.data.DataLoader(
            full_data_eval,
            batch_size=CFG.batch_size,
            sampler=torch.utils.data.SubsetRandomSampler(val_ids))

        optimizer, scheduler = _make_optimizer_and_scheduler(model, len(train_loader))
        print(f'learning rate: {optimizer.param_groups[0]["lr"]}')

        criterion = nn.BCELoss().to(device)
        train_loss_history = []
        eval_dice_global = []
        bst_dice = 0
        # Tracked for the (disabled) early-stopping rule: stop once epoch - bst_epoch >= 10.
        bst_epoch = 1

        for epoch in range(1, CFG.epochs + 1):
            print(f'________epoch: {epoch}________')
            train_loss_history.append(
                _train_one_epoch(model, train_loader, optimizer, scheduler, criterion))

            # Inference mode while scoring: BatchNorm must use, not update, running stats.
            model.eval()
            valid_dice = validate_data(val_loader)
            eval_dice_global.append(valid_dice)
            print(f'learning rate: {optimizer.param_groups[0]["lr"]}')
            print(f'Valid dice: {valid_dice}')

            if valid_dice > bst_dice:
                bst_dice = valid_dice
                bst_epoch = epoch
                torch.save(model.state_dict(), f'model_fold_{fold}_state_dict_epoch_{epoch}_dice_{bst_dice:.4f}.pth')
                torch.save(model, f'model_fold_{fold}_epoch_{epoch}_dice_{bst_dice:.4f}.pt')
                print(f"current model saved! Epoch: {epoch} global dice: {bst_dice}")

Downloading model.safetensors:   0%|          | 0.00/86.5M [00:00<?, ?B/s]

learning rate: 3.9999999999999996e-05
________epoch: 1________


  0%|          | 0/700 [00:00<?, ?it/s]

0.9643108475208283
________epoch: 2________


  0%|          | 0/700 [00:00<?, ?it/s]

0.7316208530323846
________epoch: 3________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4761033661024911
________epoch: 4________


  0%|          | 0/700 [00:00<?, ?it/s]

0.447090664931706
________epoch: 5________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4367298596245902
________epoch: 6________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4360324386187962
________epoch: 7________


  0%|          | 0/700 [00:00<?, ?it/s]

0.43780123753207073
________epoch: 8________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4412498084987913
________epoch: 9________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4356612040315356
________epoch: 10________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4382559654542378
________epoch: 11________


  0%|          | 0/700 [00:00<?, ?it/s]

0.439342075245721
________epoch: 12________


  0%|          | 0/700 [00:00<?, ?it/s]

0.43539730438164304
________epoch: 13________


  0%|          | 0/700 [00:00<?, ?it/s]

0.43536198249885016
________epoch: 14________


  0%|          | 0/700 [00:00<?, ?it/s]

0.43396732696465085
________epoch: 15________


  0%|          | 0/700 [00:00<?, ?it/s]

0.43015779120581493
________epoch: 16________


  0%|          | 0/700 [00:00<?, ?it/s]

0.43105138736111775
________epoch: 17________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4281018306527819
________epoch: 18________


  0%|          | 0/700 [00:00<?, ?it/s]

0.42191984423569273
________epoch: 19________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4210760826723916
________epoch: 20________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4172872535671507
________epoch: 21________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4159701496362686
________epoch: 22________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4136221992118018
________epoch: 23________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4092823051554816
________epoch: 24________


  0%|          | 0/700 [00:00<?, ?it/s]

0.41060241673673903
________epoch: 25________


  0%|          | 0/700 [00:00<?, ?it/s]

0.40838723949023653
________epoch: 26________


  0%|          | 0/700 [00:00<?, ?it/s]

0.4056344676869256
________epoch: 27________


  0%|          | 0/700 [00:00<?, ?it/s]

0.40244755242552077
________epoch: 28________


  0%|          | 0/700 [00:00<?, ?it/s]

0.39913658729621343
________epoch: 29________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3963793475287301
________epoch: 30________


  0%|          | 0/700 [00:00<?, ?it/s]

0.39409872821399144
________epoch: 31________


  0%|          | 0/700 [00:00<?, ?it/s]

0.39146067091396874
________epoch: 32________


  0%|          | 0/700 [00:00<?, ?it/s]

0.38945904544421606
________epoch: 33________


  0%|          | 0/700 [00:00<?, ?it/s]

0.38426465834890094
________epoch: 34________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3832127640928541
________epoch: 35________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3797973151717867
________epoch: 36________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3773034170695714
________epoch: 37________


  0%|          | 0/700 [00:00<?, ?it/s]

0.37475249350070955
________epoch: 38________


  0%|          | 0/700 [00:00<?, ?it/s]

0.37270886983190266
________epoch: 39________


  0%|          | 0/700 [00:00<?, ?it/s]

0.37082892911774773
________epoch: 40________


  0%|          | 0/700 [00:00<?, ?it/s]

0.36845471245901923
________epoch: 41________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3668031392778669
________epoch: 42________


  0%|          | 0/700 [00:00<?, ?it/s]

0.364614394562585
________epoch: 43________


  0%|          | 0/700 [00:00<?, ?it/s]

0.36233170832906453
________epoch: 44________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3605370286532811
________epoch: 45________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3612325250250953
________epoch: 46________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3602426024845668
________epoch: 47________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3593974049602236
________epoch: 48________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3598492746693747
________epoch: 49________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3584819539104189
________epoch: 50________


  0%|          | 0/700 [00:00<?, ?it/s]

0.3582445244278227
