In [1]:

# Imports here are through Anaconda (Conda) for the primary PyTorch, Lightning, and TorchMetrics
# libraries. PIL, and NumPy are also used. Matplotlib is extraneous

import os
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.callbacks import ModelCheckpoint, BasePredictionWriter
from lightning.pytorch.loggers import MLFlowLogger
from torch.utils.data import Dataset, random_split, DataLoader
import torchvision.transforms.v2 as v2

import lightning as L
import torchmetrics.classification as TM

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# I set this to "high" precision for much faster compute time, it's possible for better
# performance not using this, but it does result in about a 3-4 times faster compute
torch.set_float32_matmul_precision('high')

In [3]:
# This is the core double convolution of U-Net. It is reused in the up and down
# parts of the U-Net architecture. The Mish function is used as I found it works
# best, but ReLU, SiLU, etc. also work. Mish was what I found works best personally.

# We have: [2D convolution, batch normalization (instance works less well), Mish] x2
class DubConv(nn.Module):
    
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.dub_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.Mish(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.Mish(inplace=True)   
        )
        
    def forward(self, x):
        return self.dub_conv(x)

In [4]:
# Downward limb of our U-Net. I had read studies which suggest the Average Pooling
# performs better with SAR imagery than Max Pooling. I found this to be the case as
# well and used this here.

# We have an [Average Pooling, Double Convolution] Here.
class Down(nn.Module):
    
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.avgpool_conv = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            DubConv(in_ch, out_ch)
        )
        
    def forward(self, x):
        return self.avgpool_conv(x)

In [5]:
# Upward Limb of our U-Net. We upsample using nearest neighbor algorithm. I haven't
# tested with this to find it is the best, however it was the algorithm use by my advisor
# and I remained with it. Bilinear *may* work better, but I cannot speak to this.

# We have [UpSample, Double Convolution]
class Up(nn.Module):

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_ch, in_ch // 2, kernel_size=1)
        )
        self.conv = DubConv(in_ch, out_ch)

# Here we have padding to "restore" or maintain the input image size as well.
    
    def forward(self, x1, x2):
        x1 = self.upsample(x1)

        # Pad x1 to the size of x2
        diff_h = x2.shape[2] - x1.shape[2]
        diff_w = x2.shape[3] - x1.shape[3]

        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

        # Concatenate along the channels axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

In [6]:
# Here is our core U-Net architecture. I generally used "default" or established values
# for the number of layers and features as found in other works. This works well, though

class UNet(nn.Module):

    def __init__(self, num_classes: int = 1, num_layers: int = 5, features_start: int = 64):

        super().__init__()
        self.num_layers = num_layers

        layers = [DubConv(1, features_start)]

        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2

        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2))
            feats //= 2
# We append the previous convolution layers to the new layer, as to U-Net architecture
        
        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        xi = [self.layers[0](x)]

        for layer in self.layers[1: self.num_layers]:
            xi.append(layer(xi[-1]))

        for i, layer in enumerate(self.layers[self.num_layers: -1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
            
        logits = self.layers[-1](xi[-1])
        
        return logits

In [7]:
#Transformations applied to input data, to enrich training data

transform = v2.RandomApply(transforms=[
    v2.RandomAffine((-90, 90)),
    v2.RandomAdjustSharpness(sharpness_factor=3),
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
    v2.RandomRotation(90)],
    p=0.25
)

In [8]:
# PyTorch Dataset Path (Adjust as needed to your directories)
class SARTrainData(Dataset):
    IMAGE_PATH = "images"
    MASK_PATH = "labels"
    data_path = "/home/michael/sar_crater/train_Sig0db"

    def __init__(
            self,
            data_path: str,
            img_size: tuple = (256, 256),
    ):
        self.transform = transform
        self.img_size = img_size

        self.data_path = data_path
        self.img_path = os.path.join(self.data_path, self.IMAGE_PATH)
        self.mask_path = os.path.join(self.data_path, self.MASK_PATH)
        self.img_list = self.get_filenames(self.img_path)
        self.mask_list = self.get_filenames(self.mask_path)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        
        process = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32),
            v2.Normalize((0.5, ), (0.5, ))
        ])
        
        image = Image.open(self.img_list[idx])
        mask = Image.open(self.mask_list[idx])
        #image = np.asarray([image], dtype=np.float32)
        mask = np.asarray([mask], dtype=np.float32)
        image = process(image)
        mask = torch.from_numpy(mask)
        
        img, mask = self.transform(image, mask)

        return img, mask


    def get_filenames(self, path):
        files_list = []
        for root, dirs, files in os.walk(path):
            for filename in files:
                if filename.endswith(".tif"):
                    files_list.append(os.path.join(path, filename))
        return files_list

In [9]:
# Prediction dataset directory
class PredData(Dataset):

    data_path = "/home/michael/sar_crater/pred/pred_Sig0db"

    def __init__(
            self,
            data_path: str,
    ):

        self.data_path = data_path
        self.img_path = self.data_path
        self.img_list = self.get_filenames(self.img_path)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        
        process = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize((0.5, ), (0.5, ))
        ])

        image = Image.open(self.img_list[idx])
        img = process([image])


        return img


    def get_filenames(self, path):
        files_list = []
        for filename in os.listdir(path):
            files_list.append(os.path.join(path, filename))
        return files_list

In [10]:
# Lightning Datamodule
class SARDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "/home/michael/sar_crater/train_Sig0db", batch_size: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size


    def setup(self, stage: str):

        if stage == "fit":
            SAR_full = SARTrainData(self.data_dir)
            self.SAR_train, self.SAR_val = random_split(
                SAR_full, [0.8, 0.2], generator=torch.Generator().manual_seed(37)
            )
            
        if stage == "predict":
            self.PredData = PredData("/home/michael/sar_crater/pred_Sig0db")

    def train_dataloader(self):
        return DataLoader(self.SAR_train, batch_size=self.batch_size, num_workers=18, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.SAR_val, batch_size=self.batch_size, num_workers=8, persistent_workers=True)
    
    def predict_dataloader(self):
        return DataLoader(self.PredData, batch_size=1, num_workers=4)

In [11]:
# Lightning module combining the previous elements, makes it easier to make modifications in specific values
class LitUNet(L.LightningModule):

    def __init__(
            self,
            data_path: str = "/home/michael/sar_crater/train_Sig0db",
            batch_size: int = 16,
            lr: float = 0.0137,
            num_layers: int = 5,
            features_start: int = 64,
    ):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        self.num_layers = num_layers
        self.features_start = features_start
        self.save_hyperparameters()
        
        self.net = UNet()
        
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([20]))

        self.t_acc = TM.BinaryAccuracy(ignore_index=0)
        self.v_acc = TM.BinaryAccuracy(ignore_index=0)
        self.v_recall = TM.BinaryRecall()
        self.v_precision = TM.BinaryPrecision()
        self.v_f1 = TM.BinaryF1Score()


    def forward(self, batch):
        return self.net(batch)

    def training_step(self, batch, batch_idx):
        img, mask = batch
        out = self(img)
        loss = self.loss_fn(out, mask)
        self.t_acc(out, mask.short())
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("tacc", self.t_acc, prog_bar=True, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        img, mask = batch
        out = self(img)
        val_loss = self.loss_fn(out, mask)
        self.v_acc(out, mask.short())
        self.v_recall(out, mask.short())
        self.v_precision(out, mask.short())
        self.v_f1(out, mask.short())
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("vacc", self.v_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("vrecall", self.v_recall, on_step=False, on_epoch=True, prog_bar=True)
        self.log("vprec", self.v_precision, on_step=False, on_epoch=True, prog_bar=True)
        self.log("vf1", self.v_f1, on_step=False, on_epoch=True, prog_bar=True)
        return val_loss
    
    def test_step(self, batch, batch_idx):
        img, mask = batch
        out = self.net(img)
        recall = self.recall(out, mask.short())
        precision = self.precision(out, mask.short())
        f1 = self.f1(out, mask.short())
        log_dict = {"recall": recall, "precision": precision, "f1": f1}

        self.log_dict(log_dict, logger=True, on_step=True)
        
    def predict_step(self, batch, batch_idx):
        preds = self(batch)
        return preds
    
    def configure_optimizers(self):
        opt = torch.optim.SGD(self.net.parameters(), lr=self.lr)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=24)

        return [opt], [sch]

In [12]:
# Prediction writer class, makes output images
class PredWriter(BasePredictionWriter):
    
    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir
        
    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))

In [13]:
# Lightning module command center

model = LitUNet()
mlf_logger = MLFlowLogger(
    experiment_name="litUnetLogs", 
    tracking_uri="http://localhost:3737",
    log_model=True
)

checkpoint_callback = ModelCheckpoint(
    monitor="vf1",
    dirpath="/home/michael/sar_crater/ckpt",
    filename="sar-{epoch:02d}-{vf1:.2f}",
    save_top_k=3,
    mode="max"
)

pred_writer = PredWriter(output_dir="/home/michael/sar_crater/inference", write_interval="epoch")

trainer = L.Trainer(
    accelerator="gpu", devices=1,
    log_every_n_steps=10,
    profiler="simple",
    precision="16-mixed",
    default_root_dir="/home/michael/sar_crater",
    enable_checkpointing=True,
    logger=mlf_logger,
    max_epochs=1000,
    callbacks=[checkpoint_callback, pred_writer]
)

dm = SARDataModule(data_dir="/home/michael/sar_crater/train_Sig0db")

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Run module for fit
trainer.fit(model, datamodule=dm)

/home/michael/sar_crater/py_env/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/michael/sar_crater/ckpt exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type              | Params | Mode 
----------------------------------------------------------
0 | net         | UNet              | 28.9 M | train
1 | loss_fn     | BCEWithLogitsLoss | 0      | train
2 | t_acc       | BinaryAccuracy    | 0      | train
3 | v_acc       | BinaryAccuracy    | 0      | train
4 | v_recall    | BinaryRecall      | 0      | train
5 | v_precision | BinaryPrecision   | 0      | train
6 | v_f1        | BinaryF1Score     | 0      | train
----------------------------------------------------------
28.9 M    Trainable params
0         Non-trainable params
28.9 M    Total params
115.790   Total estimated model params size (MB)
109       Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 236/236 [00:42<00:00,  5.50it/s, v_num=1e9c, train_loss_step=0.425, tacc_step=0.324]
[Aidation: |          | 0/? [00:00<?, ?it/s]
[Aidation:   0%|          | 0/59 [00:00<?, ?it/s]
[Aidation DataLoader 0:   0%|          | 0/59 [00:00<?, ?it/s]
[Aidation DataLoader 0:   2%|▏         | 1/59 [00:00<00:03, 15.42it/s]
[Aidation DataLoader 0:   3%|▎         | 2/59 [00:00<00:03, 15.31it/s]
[Aidation DataLoader 0:   5%|▌         | 3/59 [00:00<00:03, 15.33it/s]
[Aidation DataLoader 0:   7%|▋         | 4/59 [00:00<00:03, 15.28it/s]
[Aidation DataLoader 0:   8%|▊         | 5/59 [00:00<00:03, 15.25it/s]
[Aidation DataLoader 0:  10%|█         | 6/59 [00:00<00:03, 15.28it/s]
[Aidation DataLoader 0:  12%|█▏        | 7/59 [00:00<00:03, 15.31it/s]
[Aidation DataLoader 0:  14%|█▎        | 8/59 [00:00<00:03, 15.33it/s]
[Aidation DataLoader 0:  15%|█▌        | 9/59 [00:00<00:03, 15.29it/s]
[Aidation DataLoader 0:  17%|█▋        | 10/59 [00:00<00:03, 15.31it/s]
[Aidat

In [None]:
# Run module to predict
trainer.predict(datamodule=dm)