In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import rasterio
from tqdm import tqdm
import timm
import optuna
from sklearn.metrics import confusion_matrix
from torch.optim.lr_scheduler import StepLR
import lovasz_losses as L
from scipy.ndimage import uniform_filter, gaussian_filter
from scipy.ndimage import convolve
from skimage.restoration import denoise_bilateral
#from util_func.tversky_loss import TverskyLoss  # Make sure this exists

In [3]:
# -------------------------------
# Dataset
# -------------------------------
class ResizeTensor:
    """Callable transform that bilinearly resizes a single (C, H, W) tensor.

    Args:
        size: Target spatial size, forwarded unchanged to ``F.interpolate``.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, tensor):
        # F.interpolate expects a leading batch axis: add one, resize, drop it.
        batched = tensor[None]
        resized = F.interpolate(
            batched,
            size=self.size,
            mode='bilinear',
            align_corners=False,
        )
        return resized[0]

class SARDataset(Dataset):
    """Single-band SAR image / label-mask pairs read from two directories.

    Images are read with rasterio (band 1), optionally speckle-filtered,
    min-max normalised and returned as a ``(1, H, W)`` tensor (before
    ``transform`` is applied). Masks are read with PIL, converted to int64 and,
    when ``image_size`` is set, resized with nearest-neighbour interpolation.

    Image and mask files are paired by their position in the *sorted*
    directory listings, so both directories must contain the same number of
    files in matching order.

    Args:
        image_dir: Directory containing the SAR rasters.
        mask_dir: Directory containing the label masks.
        transform: Optional callable applied to the ``(1, H, W)`` image tensor.
        image_size: Target ``(H, W)`` for the mask; falsy to keep native size.
        speckle_filter: One of ``"lee"``, ``"lee_sigma"``, ``"refined_lee"``,
            ``"gamma"``, ``"frost"`` (case-insensitive) or ``None``. Unknown
            names leave the image unfiltered.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, image_size=(224, 224), speckle_filter=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_size = image_size
        self.speckle_filter = speckle_filter.lower() if speckle_filter else None
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        # Pairing is purely positional, so a count mismatch means at least one
        # image would be trained against the wrong mask.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; they must match one-to-one."
            )

    @staticmethod
    def _local_stats(img, size):
        """Return the local mean and (non-negative) local variance of ``img``."""
        mean = uniform_filter(img, size)
        # E[x^2] - E[x]^2 can come out slightly negative through rounding;
        # clamp so a later sqrt cannot produce NaN.
        var = np.maximum(uniform_filter(img ** 2, size) - mean ** 2, 0.0)
        return mean, var

    def lee_filter(self, img, size=7):
        """Basic Lee filter.

        Blends each pixel with its local mean using the weight
        ``local_var / (local_var + global_var)``; the global image variance is
        used as the noise estimate, as in ``refined_lee_filter``.
        """
        mean, var = self._local_stats(img, size)
        noise_var = np.var(img)
        weight = var / (var + noise_var + 1e-6)
        return mean + weight * (img - mean)

    def lee_sigma_filter(self, img, size=7, sigma=1.0):
        """Replace pixels within ``sigma`` local std-devs of the local mean by that mean."""
        mean, var = self._local_stats(img, size)
        std = np.sqrt(var)
        within_range = np.abs(img - mean) <= sigma * std
        return np.where(within_range, mean, img)

    def refined_lee_filter(self, img, size=7):
        """Simplified refined Lee filter: Lee-style blend followed by a 3x3 mean."""
        mean = uniform_filter(img, size)
        var = uniform_filter((img - mean) ** 2, size)
        cu = np.var(img)
        W = var / (var + cu + 1e-6)
        filtered = mean + W * (img - mean)
        return uniform_filter(filtered, 3)

    def gamma_filter(self, img, size=7):
        """Divide the local mean by ``1 + local_var / local_mean**2``."""
        mean = uniform_filter(img, size)
        var = uniform_filter((img - mean) ** 2, size)
        b = var / (mean ** 2 + 1e-6)
        return mean / (1 + b)

    def frost_filter(self, img, size=5, damping_factor=2.0):
        """Frost-style filter: per-pixel exponentially weighted window average.

        The weights decay with the absolute difference from the window centre,
        scaled by the window's ``var / mean**2`` and ``damping_factor``.
        Implemented with a Python double loop, so it is slow on large images.
        """
        padded = np.pad(img, size // 2, mode='reflect')
        out = np.zeros_like(img)
        for i in range(out.shape[0]):
            for j in range(out.shape[1]):
                window = padded[i:i + size, j:j + size]
                center = window[size // 2, size // 2]
                mean = np.mean(window)
                var = np.var(window)
                coeff_var = var / (mean ** 2 + 1e-6)
                weight = np.exp(-coeff_var * np.abs(window - center) / damping_factor)
                weight /= weight.sum()
                out[i, j] = np.sum(weight * window)
        return out

    def apply_speckle_filter(self, img):
        """Run the filter selected by ``self.speckle_filter``; unknown/None returns ``img`` unchanged."""
        filters = {
            "lee": self.lee_filter,
            "gamma": self.gamma_filter,
            "frost": self.frost_filter,
            "refined_lee": self.refined_lee_filter,
            "lee_sigma": self.lee_sigma_filter,
        }
        selected = filters.get(self.speckle_filter)
        return selected(img) if selected is not None else img

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``."""
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        with rasterio.open(image_path) as src:
            image = src.read(1).astype(np.float32)

        image = self.apply_speckle_filter(image)
        # Per-image min-max normalisation; the epsilon guards constant images.
        image = (image - image.min()) / (image.max() - image.min() + 1e-6)
        image = torch.from_numpy(image).unsqueeze(0)

        if self.transform:
            image = self.transform(image)

        # Context manager closes the file handle as soon as the pixels are copied.
        with Image.open(mask_path) as mask_img:
            mask = torch.from_numpy(np.array(mask_img, dtype=np.int64))

        if self.image_size:
            # Nearest-neighbour keeps labels discrete. Index [0, 0] removes
            # exactly the two axes added here (a bare squeeze() would also
            # drop genuine size-1 spatial axes).
            mask = F.interpolate(
                mask[None, None].float(), size=self.image_size, mode='nearest'
            )[0, 0].long()

        return image, mask


In [4]:



# -------------------------------
# Model
# -------------------------------
class ViTSegmentation(nn.Module):
    """Segmentation head on top of a timm ``vit_base_patch16_224`` encoder.

    The encoder's patch tokens are reshaped into a square feature grid,
    bilinearly upsampled to the input resolution and passed through a small
    convolutional decoder that emits ``num_classes`` channels per pixel.

    Args:
        image_size: Side length used to derive the token grid size.
        patch_size: Patch side length used to derive the token grid size.
        num_classes: Number of output channels of the decoder.
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=9):
        super().__init__()
        # Encoder is created first so parameter registration order is unchanged.
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, in_chans=1)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.n_patches = (image_size // patch_size) ** 2

        embed_dim = self.vit.embed_dim
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        vit = self.vit
        batch = x.shape[0]

        # Patch tokens, with the class token prepended and positions added.
        tokens = vit.patch_embed(x)
        cls = vit.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + vit.pos_embed

        tokens = vit.norm(vit.blocks(vit.pos_drop(tokens)))

        # Drop the class token and fold the sequence back into a 2-D grid.
        grid = self.image_size // self.patch_size
        fmap = tokens[:, 1:, :].permute(0, 2, 1).reshape(batch, vit.embed_dim, grid, grid)

        fmap = F.interpolate(fmap, size=x.shape[2:], mode='bilinear', align_corners=False)
        return self.decoder(fmap)

In [5]:


# -------------------------------
# Loss functions and evaluation
# -------------------------------

import torch

# Example per-class frequencies, one entry per class (replace with your own).
class_freq = torch.tensor(
    [0.0015, 0.0170, 0.2470, 0.2334, 0.0750, 0.1391, 0.1063, 0.0679, 0.1128]
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inverse-frequency weights, rescaled to sum to 1 and moved to `device`.
class_weights = class_freq.reciprocal()
class_weights = (class_weights / class_weights.sum()).to(device)


def TverskyLoss(logits, targets, alpha=0.7, beta=0.3, gamma=1.0, eps=1e-6):
    """Focal Tversky loss computed from raw logits.

    Args:
        logits: Class scores with the class axis at dim 1; the sums below run
            over dims 2 and 3.
        targets: Integer class indices, one-hot encoded against ``logits.shape[1]``.
        alpha: Weight of the soft false positives.
        beta: Weight of the soft false negatives.
        gamma: Exponent applied to ``1 - Tversky`` before averaging.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: the mean of ``(1 - Tversky) ** gamma`` over every
        (sample, class) pair.
    """
    n_classes = logits.shape[1]
    spatial_dims = (2, 3)

    prob = F.softmax(logits, dim=1)
    one_hot = F.one_hot(targets, num_classes=n_classes).permute(0, 3, 1, 2).float()

    # Soft confusion counts per sample and per class.
    true_pos = (prob * one_hot).sum(dim=spatial_dims)
    false_pos = (prob * (1 - one_hot)).sum(dim=spatial_dims)
    false_neg = ((1 - prob) * one_hot).sum(dim=spatial_dims)

    index = (true_pos + eps) / (true_pos + alpha * false_pos + beta * false_neg + eps)
    return (1 - index).pow(gamma).mean()

# # combined_loss.py
# class CombinedLovaszFocalTverskyLoss(torch.nn.Module):
#     def __init__(self, alpha=0.7, beta=0.3, gamma=1.33, lovasz_weight=0.6, tversky_weight=0.4, ce_weight=1):
#         super().__init__()
#         self.alpha = alpha
#         self.beta = beta
#         self.gamma = gamma
#         self.lovasz_weight = lovasz_weight
#         self.tversky_weight = tversky_weight
#         self.ce_weight=ce_weight

#     def forward(self, logits, targets):
#         loss_lovasz = L.lovasz_softmax(F.softmax(logits, dim=1), targets)
#         ce_loss_clf = F.cross_entropy(logits, targets)
#         loss_tversky = focal_tversky_loss(logits, targets, self.alpha, self.beta, self.gamma)
#         return self.lovasz_weight * loss_lovasz + self.tversky_weight * loss_tversky+ self.ce_weight * ce_loss_clf


def evaluate(model, dataloader, num_classes):
    """Compute pixel accuracy and per-class IoU of ``model`` over ``dataloader``.

    Batches are moved to the device the model's parameters live on (CPU when
    the model has no parameters), so this works with or without CUDA. The
    confusion matrix is accumulated batch by batch instead of keeping every
    prediction in memory until the end.

    Args:
        model: Module mapping an input batch to logits with classes on dim 1.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        num_classes: Number of classes reported; pixels whose target or
            prediction falls outside ``[0, num_classes)`` are excluded from the
            IoU statistics (they still count towards accuracy).

    Returns:
        ``(accuracy, ious)`` where ``accuracy`` is a float and ``ious`` is a
        NumPy array of length ``num_classes``.

    Raises:
        ValueError: If ``dataloader`` yields no pixels.
    """
    model.eval()
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    total_correct = 0
    total_pixels = 0
    # Flattened (true, predicted) counts; reshaped to a square matrix below.
    flat_cm = torch.zeros(num_classes * num_classes, dtype=torch.long)

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            preds = torch.argmax(model(x), dim=1)
            total_correct += (preds == y).sum().item()
            total_pixels += torch.numel(preds)

            true_flat = y.reshape(-1).long()
            pred_flat = preds.reshape(-1)
            in_range = (true_flat >= 0) & (true_flat < num_classes) & (pred_flat < num_classes)
            pair_index = true_flat[in_range] * num_classes + pred_flat[in_range]
            flat_cm += torch.bincount(pair_index, minlength=num_classes * num_classes).cpu()

    if total_pixels == 0:
        raise ValueError("evaluate() received an empty dataloader; nothing to score.")

    cm = flat_cm.view(num_classes, num_classes).numpy()
    # IoU = TP / (TP + FP + FN); the epsilon avoids 0/0 for absent classes.
    ious = np.diag(cm) / (cm.sum(1) + cm.sum(0) - np.diag(cm) + 1e-6)

    print("Class-wise IoUs:")
    for i, iou in enumerate(ious):
        print(f"Class {i}: IoU = {iou:.4f}")

    return total_correct / total_pixels, ious

# -------------------------------
# Optuna Objective
# -------------------------------
def objective(trial):
    """Optuna objective: train ``ViTSegmentation`` and return its best validation accuracy.

    Tuned hyperparameters: learning rate, the Tversky ``alpha`` (``beta`` is
    ``1 - alpha``) and focal exponent, and the mixing weights of the
    Lovasz / Tversky / cross-entropy losses. Batch size is fixed at 16.

    The train/validation split uses a fixed seed so that every trial is scored
    on the same validation subset; otherwise trial values would not be
    comparable.

    Args:
        trial: The ``optuna`` trial used to sample hyperparameters, report
            intermediate accuracy and query pruning.

    Returns:
        The highest validation pixel accuracy seen across epochs.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner stops the trial early.
    """
    batch_size = 16
    num_classes = 9
    num_epochs = 30
    split_seed = 42

    # Hyperparameters to tune (names and sampling order are part of the study).
    lr = trial.suggest_float("lr", 1e-4, 3e-3, log=True)
    gamma_tversky = trial.suggest_float("gamma_tversky", 1.0, 3.0)
    alpha = trial.suggest_float("alpha", 0.1, 0.5)
    beta = 1 - alpha

    ce_w = trial.suggest_float("ce_weight", 0.1, 1.0)
    tversky_w = trial.suggest_float("tversky_weight", 0.1, 1.0)
    lovasz_w = trial.suggest_float("lovasz_weight", 0.1, 1.0)

    # Dataset and loaders
    transform = ResizeTensor((224, 224))
    dataset = SARDataset(
        image_dir=r"D:\train_splitted\sar_images",
        mask_dir=r"D:\train_splitted\labels",
        transform=transform,
        image_size=(224, 224),
        speckle_filter="lee_sigma"
    )
    train_size = int(0.95 * len(dataset))
    val_size = len(dataset) - train_size
    # Dedicated, fixed-seed generator: identical split for every trial.
    split_rng = torch.Generator().manual_seed(split_seed)
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], generator=split_rng
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTSegmentation(image_size=224, num_classes=num_classes).to(device)

    def total_loss_fn(out, target, class_weights=None):
        """Weighted sum of Lovasz-softmax, focal Tversky and cross-entropy losses."""
        probs = F.softmax(out, dim=1)
        loss_lovasz = L.lovasz_softmax(probs, target)
        loss_tversky = TverskyLoss(out, target, alpha=alpha, beta=beta, gamma=gamma_tversky)
        if class_weights is not None:
            loss_ce = F.cross_entropy(out, target, weight=class_weights)
        else:
            loss_ce = F.cross_entropy(out, target)
        return lovasz_w * loss_lovasz + tversky_w * loss_tversky + ce_w * loss_ce

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Halve the LR when validation accuracy plateaus. The `verbose` kwarg is
    # deprecated in recent PyTorch releases and False is the default anyway.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            # NOTE(review): no class_weights are passed here, so the CE term is
            # unweighted. Pass weights explicitly if class balancing is intended.
            loss = total_loss_fn(out, y)
            loss.backward()
            optimizer.step()

        # evaluate() switches to eval mode and disables gradients itself.
        val_acc, _ = evaluate(model, val_loader, num_classes=num_classes)

        scheduler.step(val_acc)

        trial.report(val_acc, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        best_val_acc = max(best_val_acc, val_acc)

    return best_val_acc




In [None]:
# -------------------------------
# Main
# -------------------------------
if __name__ == '__main__':
    # Seed the Python, NumPy and PyTorch RNGs before any trial runs.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # Maximise the value returned by `objective` (best validation accuracy).
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=30, timeout=80000)

    print("Best trial:")
    trial = study.best_trial
    print(f"  Value: {trial.value}")
    print("  Params: ")
    for param_name, param_value in trial.params.items():
        print(f"    {param_name}: {param_value}")


[I 2025-05-18 15:17:25,775] A new study created in memory with name: no-name-02d235f0-432b-4a54-be52-c86e63415e3d




In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import rasterio
from tqdm import tqdm
import timm
import optuna
from sklearn.metrics import confusion_matrix
from torch.optim.lr_scheduler import StepLR
import lovasz_losses as L
from scipy.ndimage import uniform_filter, gaussian_filter
from scipy.ndimage import convolve
from skimage.restoration import denoise_bilateral
#from util_func.tversky_loss import TverskyLoss  # Make sure this exists

In [None]:
# -------------------------------
# Dataset
# -------------------------------
class ResizeTensor:
    """Callable transform that bilinearly resizes a single (C, H, W) tensor.

    Args:
        size: Target spatial size, forwarded unchanged to ``F.interpolate``.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, tensor):
        # F.interpolate expects a leading batch axis: add one, resize, drop it.
        batched = tensor[None]
        resized = F.interpolate(
            batched,
            size=self.size,
            mode='bilinear',
            align_corners=False,
        )
        return resized[0]

class SARDataset(Dataset):
    """Single-band SAR image / label-mask pairs read from two directories.

    Images are read with rasterio (band 1), optionally speckle-filtered,
    min-max normalised and returned as a ``(1, H, W)`` tensor (before
    ``transform`` is applied). Masks are read with PIL, converted to int64 and,
    when ``image_size`` is set, resized with nearest-neighbour interpolation.

    Image and mask files are paired by their position in the *sorted*
    directory listings, so both directories must contain the same number of
    files in matching order.

    Args:
        image_dir: Directory containing the SAR rasters.
        mask_dir: Directory containing the label masks.
        transform: Optional callable applied to the ``(1, H, W)`` image tensor.
        image_size: Target ``(H, W)`` for the mask; falsy to keep native size.
        speckle_filter: One of ``"lee"``, ``"lee_sigma"``, ``"refined_lee"``,
            ``"gamma"``, ``"frost"`` (case-insensitive) or ``None``. Unknown
            names leave the image unfiltered.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, image_dir, mask_dir, transform=None, image_size=(224, 224), speckle_filter=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_size = image_size
        self.speckle_filter = speckle_filter.lower() if speckle_filter else None
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        # Pairing is purely positional, so a count mismatch means at least one
        # image would be trained against the wrong mask.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; they must match one-to-one."
            )

    @staticmethod
    def _local_stats(img, size):
        """Return the local mean and (non-negative) local variance of ``img``."""
        mean = uniform_filter(img, size)
        # E[x^2] - E[x]^2 can come out slightly negative through rounding;
        # clamp so a later sqrt cannot produce NaN.
        var = np.maximum(uniform_filter(img ** 2, size) - mean ** 2, 0.0)
        return mean, var

    def lee_filter(self, img, size=7):
        """Basic Lee filter.

        Blends each pixel with its local mean using the weight
        ``local_var / (local_var + global_var)``; the global image variance is
        used as the noise estimate, as in ``refined_lee_filter``.
        """
        mean, var = self._local_stats(img, size)
        noise_var = np.var(img)
        weight = var / (var + noise_var + 1e-6)
        return mean + weight * (img - mean)

    def lee_sigma_filter(self, img, size=7, sigma=1.0):
        """Replace pixels within ``sigma`` local std-devs of the local mean by that mean."""
        mean, var = self._local_stats(img, size)
        std = np.sqrt(var)
        within_range = np.abs(img - mean) <= sigma * std
        return np.where(within_range, mean, img)

    def refined_lee_filter(self, img, size=7):
        """Simplified refined Lee filter: Lee-style blend followed by a 3x3 mean."""
        mean = uniform_filter(img, size)
        var = uniform_filter((img - mean) ** 2, size)
        cu = np.var(img)
        W = var / (var + cu + 1e-6)
        filtered = mean + W * (img - mean)
        return uniform_filter(filtered, 3)

    def gamma_filter(self, img, size=7):
        """Divide the local mean by ``1 + local_var / local_mean**2``."""
        mean = uniform_filter(img, size)
        var = uniform_filter((img - mean) ** 2, size)
        b = var / (mean ** 2 + 1e-6)
        return mean / (1 + b)

    def frost_filter(self, img, size=5, damping_factor=2.0):
        """Frost-style filter: per-pixel exponentially weighted window average.

        The weights decay with the absolute difference from the window centre,
        scaled by the window's ``var / mean**2`` and ``damping_factor``.
        Implemented with a Python double loop, so it is slow on large images.
        """
        padded = np.pad(img, size // 2, mode='reflect')
        out = np.zeros_like(img)
        for i in range(out.shape[0]):
            for j in range(out.shape[1]):
                window = padded[i:i + size, j:j + size]
                center = window[size // 2, size // 2]
                mean = np.mean(window)
                var = np.var(window)
                coeff_var = var / (mean ** 2 + 1e-6)
                weight = np.exp(-coeff_var * np.abs(window - center) / damping_factor)
                weight /= weight.sum()
                out[i, j] = np.sum(weight * window)
        return out

    def apply_speckle_filter(self, img):
        """Run the filter selected by ``self.speckle_filter``; unknown/None returns ``img`` unchanged."""
        filters = {
            "lee": self.lee_filter,
            "gamma": self.gamma_filter,
            "frost": self.frost_filter,
            "refined_lee": self.refined_lee_filter,
            "lee_sigma": self.lee_sigma_filter,
        }
        selected = filters.get(self.speckle_filter)
        return selected(img) if selected is not None else img

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``."""
        image_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        with rasterio.open(image_path) as src:
            image = src.read(1).astype(np.float32)

        image = self.apply_speckle_filter(image)
        # Per-image min-max normalisation; the epsilon guards constant images.
        image = (image - image.min()) / (image.max() - image.min() + 1e-6)
        image = torch.from_numpy(image).unsqueeze(0)

        if self.transform:
            image = self.transform(image)

        # Context manager closes the file handle as soon as the pixels are copied.
        with Image.open(mask_path) as mask_img:
            mask = torch.from_numpy(np.array(mask_img, dtype=np.int64))

        if self.image_size:
            # Nearest-neighbour keeps labels discrete. Index [0, 0] removes
            # exactly the two axes added here (a bare squeeze() would also
            # drop genuine size-1 spatial axes).
            mask = F.interpolate(
                mask[None, None].float(), size=self.image_size, mode='nearest'
            )[0, 0].long()

        return image, mask


In [None]:



# -------------------------------
# Model
# -------------------------------
class ViTSegmentation(nn.Module):
    """Segmentation head on top of a timm ``vit_base_patch16_224`` encoder.

    The encoder's patch tokens are reshaped into a square feature grid,
    bilinearly upsampled to the input resolution and passed through a small
    convolutional decoder that emits ``num_classes`` channels per pixel.

    Args:
        image_size: Side length used to derive the token grid size.
        patch_size: Patch side length used to derive the token grid size.
        num_classes: Number of output channels of the decoder.
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=9):
        super().__init__()
        # Encoder is created first so parameter registration order is unchanged.
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, in_chans=1)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.n_patches = (image_size // patch_size) ** 2

        embed_dim = self.vit.embed_dim
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        vit = self.vit
        batch = x.shape[0]

        # Patch tokens, with the class token prepended and positions added.
        tokens = vit.patch_embed(x)
        cls = vit.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + vit.pos_embed

        tokens = vit.norm(vit.blocks(vit.pos_drop(tokens)))

        # Drop the class token and fold the sequence back into a 2-D grid.
        grid = self.image_size // self.patch_size
        fmap = tokens[:, 1:, :].permute(0, 2, 1).reshape(batch, vit.embed_dim, grid, grid)

        fmap = F.interpolate(fmap, size=x.shape[2:], mode='bilinear', align_corners=False)
        return self.decoder(fmap)

In [None]:


# -------------------------------
# Loss functions and evaluation
# -------------------------------



def TverskyLoss(logits, targets, alpha=0.7, beta=0.3, gamma=1.0, eps=1e-6):
    """Focal Tversky loss computed from raw logits.

    Args:
        logits: Class scores with the class axis at dim 1; the sums below run
            over dims 2 and 3.
        targets: Integer class indices, one-hot encoded against ``logits.shape[1]``.
        alpha: Weight of the soft false positives.
        beta: Weight of the soft false negatives.
        gamma: Exponent applied to ``1 - Tversky`` before averaging.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: the mean of ``(1 - Tversky) ** gamma`` over every
        (sample, class) pair.
    """
    n_classes = logits.shape[1]
    spatial_dims = (2, 3)

    prob = F.softmax(logits, dim=1)
    one_hot = F.one_hot(targets, num_classes=n_classes).permute(0, 3, 1, 2).float()

    # Soft confusion counts per sample and per class.
    true_pos = (prob * one_hot).sum(dim=spatial_dims)
    false_pos = (prob * (1 - one_hot)).sum(dim=spatial_dims)
    false_neg = ((1 - prob) * one_hot).sum(dim=spatial_dims)

    index = (true_pos + eps) / (true_pos + alpha * false_pos + beta * false_neg + eps)
    return (1 - index).pow(gamma).mean()

# # combined_loss.py
# class CombinedLovaszFocalTverskyLoss(torch.nn.Module):
#     def __init__(self, alpha=0.7, beta=0.3, gamma=1.33, lovasz_weight=0.6, tversky_weight=0.4, ce_weight=1):
#         super().__init__()
#         self.alpha = alpha
#         self.beta = beta
#         self.gamma = gamma
#         self.lovasz_weight = lovasz_weight
#         self.tversky_weight = tversky_weight
#         self.ce_weight=ce_weight

#     def forward(self, logits, targets):
#         loss_lovasz = L.lovasz_softmax(F.softmax(logits, dim=1), targets)
#         ce_loss_clf = F.cross_entropy(logits, targets)
#         loss_tversky = focal_tversky_loss(logits, targets, self.alpha, self.beta, self.gamma)
#         return self.lovasz_weight * loss_lovasz + self.tversky_weight * loss_tversky+ self.ce_weight * ce_loss_clf


def evaluate(model, dataloader, num_classes):
    """Compute pixel accuracy and per-class IoU of ``model`` over ``dataloader``.

    Batches are moved to the device the model's parameters live on (CPU when
    the model has no parameters), so this works with or without CUDA. The
    confusion matrix is accumulated batch by batch instead of keeping every
    prediction in memory until the end.

    Args:
        model: Module mapping an input batch to logits with classes on dim 1.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        num_classes: Number of classes reported; pixels whose target or
            prediction falls outside ``[0, num_classes)`` are excluded from the
            IoU statistics (they still count towards accuracy).

    Returns:
        ``(accuracy, ious)`` where ``accuracy`` is a float and ``ious`` is a
        NumPy array of length ``num_classes``.

    Raises:
        ValueError: If ``dataloader`` yields no pixels.
    """
    model.eval()
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    total_correct = 0
    total_pixels = 0
    # Flattened (true, predicted) counts; reshaped to a square matrix below.
    flat_cm = torch.zeros(num_classes * num_classes, dtype=torch.long)

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            preds = torch.argmax(model(x), dim=1)
            total_correct += (preds == y).sum().item()
            total_pixels += torch.numel(preds)

            true_flat = y.reshape(-1).long()
            pred_flat = preds.reshape(-1)
            in_range = (true_flat >= 0) & (true_flat < num_classes) & (pred_flat < num_classes)
            pair_index = true_flat[in_range] * num_classes + pred_flat[in_range]
            flat_cm += torch.bincount(pair_index, minlength=num_classes * num_classes).cpu()

    if total_pixels == 0:
        raise ValueError("evaluate() received an empty dataloader; nothing to score.")

    cm = flat_cm.view(num_classes, num_classes).numpy()
    # IoU = TP / (TP + FP + FN); the epsilon avoids 0/0 for absent classes.
    ious = np.diag(cm) / (cm.sum(1) + cm.sum(0) - np.diag(cm) + 1e-6)

    print("Class-wise IoUs:")
    for i, iou in enumerate(ious):
        print(f"Class {i}: IoU = {iou:.4f}")

    return total_correct / total_pixels, ious

# -------------------------------
# Optuna Objective
# -------------------------------
def objective(trial):
    """Optuna objective: train ``ViTSegmentation`` and return its best validation accuracy.

    Tuned hyperparameters: learning rate, batch size, StepLR decay factor,
    focal-Tversky exponent, the mixing weights of the Lovasz / Tversky /
    cross-entropy losses, and the optimizer class (Adam, SGD or AdamW).

    The train/validation split uses a fixed seed so that every trial is scored
    on the same validation subset; otherwise trial values would not be
    comparable.

    Args:
        trial: The ``optuna`` trial used to sample hyperparameters, report
            intermediate accuracy and query pruning.

    Returns:
        The highest validation pixel accuracy seen across epochs.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner stops the trial early.
    """
    num_classes = 9
    num_epochs = 30
    split_seed = 42

    # Hyperparameters to tune (names and sampling order are part of the study).
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [4, 8, 16])
    gamma_scheduler = trial.suggest_float("gamma_scheduler", 0.1, 0.9)
    gamma_tversky = trial.suggest_float("gamma_tversky", 1.0, 3.0)

    ce_w = trial.suggest_float("ce_weight", 0.1, 1.0)
    tversky_w = trial.suggest_float("tversky_weight", 0.1, 1.0)
    lovasz_w = trial.suggest_float("lovasz_weight", 0.1, 1.0)
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD", "AdamW"])

    # Dataset and loaders
    transform = ResizeTensor((224, 224))
    dataset = SARDataset(
        image_dir=r"D:\train_splitted\sar_images",
        mask_dir=r"D:\train_splitted\labels",
        transform=transform,
        image_size=(224, 224),
        speckle_filter="lee_sigma"
    )

    train_size = int(0.95 * len(dataset))
    val_size = len(dataset) - train_size
    # Dedicated, fixed-seed generator: identical split for every trial.
    split_rng = torch.Generator().manual_seed(split_seed)
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], generator=split_rng
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTSegmentation(image_size=224, num_classes=num_classes).to(device)

    def total_loss_fn(out, target):
        """Weighted sum of Lovasz-softmax, focal Tversky and cross-entropy losses."""
        probs = F.softmax(out, dim=1)  # lovasz_softmax is given probabilities here
        loss_lovasz = L.lovasz_softmax(probs, target)
        loss_tversky = TverskyLoss(out, target, gamma=gamma_tversky)
        loss_ce = F.cross_entropy(out, target)
        return lovasz_w * loss_lovasz + tversky_w * loss_tversky + ce_w * loss_ce

    # Look the sampled optimizer class up by name on torch.optim.
    optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=lr)
    # Multiply the LR by gamma_scheduler every 10 epochs.
    scheduler = StepLR(optimizer, step_size=10, gamma=gamma_scheduler)

    best_val_acc = 0
    for epoch in range(num_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = total_loss_fn(out, y)
            loss.backward()
            optimizer.step()

        scheduler.step()
        val_acc, _ = evaluate(model, val_loader, num_classes=num_classes)
        trial.report(val_acc, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        best_val_acc = max(best_val_acc, val_acc)

    return best_val_acc

