# Lightning Workflow for Contrails

### Notebook to run models based on single bands / band combinations

In [1]:
!pip install lightning -q

In [2]:
# lightning library
import lightning as L
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger


# pytorch libraries
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchmetrics.classification import Dice
from torchmetrics import Precision, Recall, Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger


# other
import os
import subprocess
import pandas as pd
import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
from IPython.core.display import display
from PIL import Image

# set base directory
BASE_DIR = "/kaggle/input/google-research-identify-contrails-reduce-global-warming"

# set batch size
BATCH_SIZE = 16 if torch.cuda.is_available() else 8
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']
  from IPython.core.display import display


In [3]:
# GPU available?
torch.cuda.is_available()

True

In [5]:
# List all examples from the train / val / test directories
train_list = os.listdir(BASE_DIR + '/train/')
val_list = os.listdir(BASE_DIR + '/validation/')
test_list = os.listdir(BASE_DIR + '/test/')

print(f"number of traning examples: {len(train_list)}")
print(f"number of validation examples: {len(val_list)}")
print(f"number of test examples: {len(test_list)}")

number of traning examples: 20529
number of validation examples: 1856
number of test examples: 2


In [6]:
# see files for an arbitrary training example
band_list = os.listdir(BASE_DIR + '/train/' + train_list[0])
#band_list = [band for band in band_list if band.startswith("band_")]
print(band_list)
print(np.load(BASE_DIR + '/train/' + train_list[0] + '/band_10.npy').shape)

['band_10.npy', 'band_14.npy', 'human_individual_masks.npy', 'band_15.npy', 'band_16.npy', 'band_08.npy', 'band_09.npy', 'band_13.npy', 'band_11.npy', 'human_pixel_masks.npy', 'band_12.npy']
(256, 256, 8)


### Preparing to standardize the data

"Each input channel is standardized by subtracting the global mean and dividing by the global variance of the channel before feeding it into the network" (Joe et al., 2023).

In [7]:
# creating a list of channels -i.e. "bands" (without targets)
band_list = [band for band in band_list if band.startswith("band_")]

# create empty dicts to store band-specific global mean and std
mean_list = {}
std_list = {}

# calculate global means and stds to normalize the data before training
for n_channel, channel in enumerate(band_list):
    channel_data = []
    
    for band in band_list:
        band_data = []
    
        for train_example in train_list[:300]:
            example_data = np.load(BASE_DIR + '/train/' + train_list[n_channel] + '/' + band).flatten()
        
        band_data.append(example_data)
    
    channel_data.append(np.concatenate(band_data))  
    mean_list[channel] = np.mean(channel_data)
    std_list[channel] = np.std(channel_data)

In [8]:
print([mean_list.get(x) for x in ["band_10.npy", "band_11.npy"]])
print(std_list.values())

[257.98895, 255.35904]
dict_values([11.159583, 15.506918, 10.20968, 3.221322, 6.7606144, 8.125894, 5.7324905, 5.8463006, 10.416265])


## Creating the Dataloader

In [9]:
class ContrailsDataset(torch.utils.data.Dataset):
    def __init__(self, filepath, transform=None, bands=["band_08.npy"], timestep=4):
        self.filepath = filepath
        self.transform = transform
        self.bands = bands
        self.timestep = timestep

        # List of file names in the image and target directories
        self.examples = os.listdir(self.filepath)

    def __len__(self):
        return len(self.examples)  # assuming images and targets have a 1-1 correspondence

    def __getitem__(self, idx):
        #if torch.is_tensor(idx):
        #    idx = idx.tolist()

        img_name = os.path.join(self.filepath, self.examples[idx])
        target_name = os.path.join(self.filepath, self.examples[idx])
        
        #if len(self.bands) == 1:
        #    image = np.load(os.path.join(img_name, self.bands[0]))[:,:,self.timestep]
            
        #if len(self.bands) > 1:
        band_channels = []
        for _, b in enumerate(self.bands):
            image = np.load(os.path.join(img_name, b))[:,:,self.timestep]
            band_channels.append(image)
        image = np.dstack([b for b in band_channels])
        #image = image.transpose( 2, 0, 1)
        #image = torch.tensor(image)
        
        targets = np.load(os.path.join(target_name, "human_pixel_masks.npy"))
        targets = targets[..., -1]
        targets = torch.tensor(targets).to(torch.long)
        #image = torch.tensor(np.reshape(image, (256, 256, 1))).to(torch.float32).permute(2, 0, 1)

        if self.transform:
            image = self.transform(image)

        return image, targets


### ...testing the dataloader:

In [10]:
bands = ["band_11.npy"]
# Transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([mean_list.get(x) for x in bands],
                         [std_list.get(x) for x in bands]), #using the precomputed values
])

In [11]:
# Create your datasets
train_dataset = ContrailsDataset(
    filepath= BASE_DIR + '/train/', 
    transform=transform,
    bands=list(["band_08.npy"])
)

val_dataset = ContrailsDataset(
    filepath= BASE_DIR + '/validation/', 
    transform=transform,
)

test_dataset = ContrailsDataset(
    filepath= BASE_DIR + '/test/', 
    transform=transform,
)

# Create your dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [11]:
# load one batch of images and labels
dataiter = iter(train_dataloader)
images, targets = next(dataiter)

# check shapes of the batch
images = torch.tensor(images).flatten(start_dim=2)
print(images.shape)

torch.Size([16, 1, 65536])


  images = torch.tensor(images).flatten(start_dim=2)


In [11]:
# load one batch of images and labels
dataiter = iter(train_dataloader)
images, targets = next(dataiter)

# check shapes of the batch
print(images.shape)
print(targets.shape)

torch.Size([16, 1, 256, 256])
torch.Size([16, 256, 256])


## Lightning Data Module

In [12]:
class ContrailsDataModule(L.LightningDataModule):
    def __init__(
        self,
        data_dir: str = BASE_DIR,
        batch_size: int = BATCH_SIZE,
        bands: list = bands,
        timestep: int = 4,
    ):
    
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.bands = bands
        self.timestep = timestep
        
        # Transformations
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([mean_list.get(x) for x in self.bands],
                                 [std_list.get(x) for x in self.bands]), #using the precomputed values
        ])

        
        self.dims = (len(self.bands), 256, 256)
        
    def prepare_data(self):
        pass

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:

            self.train_dataset = ContrailsDataset(
                filepath= self.data_dir + '/train/', 
                transform=self.transform,
                bands=self.bands,
                timestep=self.timestep,
                )
            
            self.val_dataset = ContrailsDataset(
                filepath= self.data_dir + '/validation/', 
                transform=self.transform,
                bands=self.bands,
                )

        
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            
                self.test_dataset = ContrailsDataset(
                filepath= self.data_dir + '/test/', 
                transform=self.transform,
                bands=self.bands,
                )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(self.val_dataset, 
                          batch_size=self.batch_size, 
                          shuffle=False,
                         )
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                         shuffle=False)

    #def test_dataloader(self):
    #    return DataLoader(self.mnist_test, batch_size=self.batch_size, bands=list(["band_12"])


## Implement the U-Net Model from https://paperswithcode.com/method/u-net

In [13]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
    
class UNet2(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet2, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [14]:
class U_Net(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True, lr=0.001):
        super(UNet2, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.learning_rate = lr
        
        # counter and accumulator for progress
        self.counter = 0
        self.progress = []
        
        # loss
        self.loss_function = nn.CrossEntropyLoss()
        
        # optimizer
        self.optimiser = torch.optim.Adam(self.parameters(), lr=learning_rate)

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
    
    def train(self, input, target):
        pred = self(input)
        loss = self.loss(pred, target)
        
        # increase counter and accumulate error every 10
        self.counter += 1
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass
        if (self.counter % 1000 == 0):
            print("counter = ", self.counter)
            pass
        
        # zero gradients, perform a backward pass, and update the weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        
        
    def plot_progress(self):
        df = pd.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

## Creating the Lightning Module for training

In [15]:
class Contrails_U_Net(L.LightningModule):
    def __init__(self, learning_rate=1e-5, n_channels = 1, n_classes = 2):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.model = UNet2(self.n_channels, self.n_classes)
        self.CEL = nn.CrossEntropyLoss(weight = torch.Tensor([0.57, 4.17]))
        self.learning_rate = learning_rate
        
        self.precision = Precision(task="binary", average='macro', num_classes=2) # binary classification 
        self.recall = Recall(task="binary", average='macro', num_classes=2)  
        self.accuracy = Accuracy(task="binary", average='macro', num_classes=2)
        #self.dice = Dice(num_classes=2, average='macro') 
        
        
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.CEL(logits, y)
        
        # Compute metrics
        train_precision = self.precision(logits.softmax(dim=1)[:, 1, ...], y)
        train_recall = self.recall(logits.softmax(dim=1)[:, 1, ...], y)
        train_accuracy = self.accuracy(logits.softmax(dim=1)[:, 1, ...], y)
        #train_dice = self.dice(torch.squeeze(logits.softmax(dim=1)[:, 1, ...], dim=1), y)

        # Logging metrics
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_precision', train_precision, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_recall', train_recall, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_accuracy', train_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        #self.log('train_dice', train_dice, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.CEL(logits, y)

        # Compute metrics
        val_precision = self.precision(logits.softmax(dim=1)[:, 1, ...], y)
        val_recall = self.recall(logits.softmax(dim=1)[:, 1, ...], y)
        val_accuracy = self.accuracy(logits.softmax(dim=1)[:, 1, ...], y)
        #val_dice = self.dice(logits.softmax(dim=1)[:, 1, ...], y)

        # Logging metrics
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_precision', val_precision, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_recall', val_recall, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_accuracy', val_accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        #self.log('val_dice', val_dice, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.CEL(logits, y)

        # Compute metrics
        test_precision = self.precision(logits.softmax(dim=1)[:, 1, ...], y)
        test_recall = self.recall(logits.softmax(dim=1)[:, 1, ...], y)
        test_accuracy = self.accuracy(logits.softmax(dim=1)[:, 1, ...], y)
        #test_dice = self.dice(logits.softmax(dim=1)[:, 1, ...], y) # Compute Dice for test set


        # Logging metrics
        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_precision', test_precision, prog_bar=True, logger=True)
        self.log('test_recall', test_recall, prog_bar=True, logger=True)
        self.log('test_accuracy', test_accuracy, prog_bar=True, logger=True)
        #self.log('test_dice', test_dice, prog_bar=True, logger=True)


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
        

### Create early-stopping and checkpoint callback function 

In [18]:
early_stop_callback = EarlyStopping(
   monitor='val_loss',
   patience=3,
   verbose=False,
   mode='min'
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='model/',
    filename='best-checkpoint',
    save_top_k=1,
    mode='min')

## Training..!

In [17]:
model = Contrails_U_Net(learning_rate=5e-4)
dm = ContrailsDataModule(batch_size=16, bands=["band_08.npy"], timestep=4)

# create the csv logger for results
# Instantiate CSV logger
csv_logger = CSVLogger('logs/', name='csv_log')


trainer = Trainer(
    accelerator="auto",
    devices = "auto",
    max_epochs=8,
    callbacks=[TQDMProgressBar(refresh_rate=20)], # optional: early_stop_callback, checkpoint_callback
    logger=csv_logger,
    log_every_n_steps =100,
)
trainer.fit(model, dm)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name      | Type             | Params
-----------------------------------------------
0 | model     | UNet2            | 17.3 M
1 | CEL       | CrossEntropyLoss | 0     
2 | precision | BinaryPrecision  | 0     
3 | recall    | BinaryRecall     | 0     
4 | accuracy  | BinaryAccuracy   | 0     
-----------------------------------------------
17.3 M    Trainable params
0         Non-trainable params
17.3 M    Total params
69.065    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

## Plot training progress

In [None]:
# Read data from CSV file
log_data = pd.read_csv('logs/csv_log/version_0/metrics.csv')
log_data.columns

In [None]:
log_data[log_data['train_precision'].notna()]['train_precision']

In [None]:
# Read data from CSV file
log_data = pd.read_csv('logs/csv_log/version_0/metrics.csv')

# Plot training loss
plt.figure(figsize=(10,6))
plt.plot(log_data[log_data['train_precision'].notna()]['step'], log_data[log_data['train_precision'].notna()]['train_precision'], label='Training Loss')
#plt.plot(log_data['step'], log_data['val_loss_step'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Training Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()
plt.clf()

In [None]:
# Plot training precision
plt.figure(figsize=(10,6))
plt.plot(log_data['step'], log_data['train_precision'], label='Training Precision')
plt.plot(log_data['step'], log_data['val_precision'], label='Validation Precision')
plt.title('Training and Validation Precision')
plt.xlabel('Training Steps')
plt.ylabel('Precision')
plt.legend()
plt.show()

## Testset Prediction

In [None]:
# After training
trainer.test()

# Or load a pre-trained model and test
model = Contrails_U_Net.load_from_checkpoint('path_to_checkpoint.ckpt')
trainer = Trainer()
trainer.test(model, test_dataloaders=dm.test_dataloader())