# Config class

In [None]:
class Arguments:
    """Plain, mutable bag of run settings.

    Every attribute gets a default here; callers overwrite whichever fields they
    need after construction.
    """

    def __init__(self):
        # Runtime / device
        self.batch_size = 1
        self.gpu_id = 0
        self.train = False

        # Output locations and prompt
        self.checkpoint_dir = "outputs"
        self.text_prompt = ''
        self.output_dir = "outputs"

        # Data
        self.dataset_name = "sample"
        self.train_data_dir = None
        self.val_data_dir = None
        self.test_data_dir = None

        # Optimisation
        self.optimizer = "Adam"
        self.epochs = 200
        self.lr = 0.1

        # Evaluation
        self.patch_threshold = 100
        self.test_mask_size = 512
        self.save_test_predictions = True

        # Loss-term weights
        self.dice_coef = 10
        self.boundary_coef = 0.1
        self.focal_coef = 5

# Utility Functions

In [None]:
import numpy as np
import torch
import random
import colorsys
import cv2

def get_random_crop_coordinates(crop_scale_range, image_width, image_height):
    """Pick a random square crop whose side is a random fraction of the shorter side.

    Args:
        crop_scale_range: ``(low, high)`` bounds for the fraction of
            ``min(image_width, image_height)`` used as the crop side.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        ``(x_start, x_end, y_start, y_end)``. A crop spanning the whole shorter
        side is anchored at the origin instead of being placed randomly.
    """
    low, high = crop_scale_range[0], crop_scale_range[1]
    shorter_side = min(image_width, image_height)
    scale = random.random() * (high - low) + low
    side = int(scale * shorter_side)

    if side == shorter_side:
        return 0, side, 0, side

    left = random.randint(0, image_width - side)
    top = random.randint(0, image_height - side)
    return left, left + side, top, top + side

In [None]:
def get_crops_coords(image_size, patch_size, num_patchs_per_side):
    """Tile an image with an evenly spaced ``n x n`` grid of square patches.

    Args:
        image_size: ``(height, width)`` of the image.
        patch_size: Side length of every square patch.
        num_patchs_per_side: Patches per row and per column (``n``).

    Returns:
        A row-major list of ``[y_start, y_end, x_start, x_end]`` lists. The first
        and last patch of each axis touch the image borders, up to the integer
        rounding of the step.
    """
    height, width = image_size
    if num_patchs_per_side == 1:
        dy = dx = 0
    else:
        gaps = num_patchs_per_side - 1
        dy = (height - patch_size) // gaps
        dx = (width - patch_size) // gaps

    return [
        [row * dy, row * dy + patch_size, col * dx, col * dx + patch_size]
        for row in range(num_patchs_per_side)
        for col in range(num_patchs_per_side)
    ]

In [None]:
def generate_distinct_colors(n):
    """Return ``n`` RGB tuples (0-255 ints) with hues spread evenly around the HSV wheel.

    A single requested colour is plain white. Every other colour uses saturation
    and value 0.9, with hue ``k / n`` for ``k`` in ``0..n-1``.
    """
    if n == 1:
        return [(255, 255, 255)]
    saturation = value = 0.9
    return [
        tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(k / n, saturation, value))
        for k in range(n)
    ]

In [None]:
def calculate_iou(prediction, mask):
    """Return the (soft) intersection-over-union of two masks.

    The formula is ``sum(p * m) / (sum(p + m - p * m) + 1e-7)``. For 0/1 masks this
    is the usual IoU; for probabilities it is the soft IoU. Both inputs must be
    on the same 0..1 scale; a 0/255 image will give meaningless results.

    Args:
        prediction: numpy array or torch tensor (bool, int or float).
        mask: Array/tensor broadcast-compatible with ``prediction``.

    Returns:
        A scalar of the same library as the inputs: a numpy scalar, or a 0-d
        tensor if either input is a tensor.
    """
    if isinstance(prediction, torch.Tensor) or isinstance(mask, torch.Tensor):
        prediction = torch.as_tensor(prediction)
        mask = torch.as_tensor(mask, device=prediction.device)
        # Bool/int tensors are promoted so subtraction and division are well defined.
        if not prediction.is_floating_point():
            prediction = prediction.float()
        if not mask.is_floating_point():
            mask = mask.float()
    else:
        prediction = np.asarray(prediction)
        mask = np.asarray(mask)
        # numpy refuses bool subtraction and wraps small ints; work in float instead.
        if not np.issubdtype(prediction.dtype, np.floating):
            prediction = prediction.astype(np.float64)
        if not np.issubdtype(mask.dtype, np.floating):
            mask = mask.astype(np.float64)

    intersection = prediction * mask
    union = prediction + mask - intersection
    return intersection.sum() / (union.sum() + 1e-7)

In [None]:
def get_colored_segmentation(mask, boundry_mask, image, colors):
    """Blend a colour-coded label map (and optional boundary map) over an image.

    Label ``k`` (1-based) is painted with ``colors[k - 1]`` scaled to 0..1. Other
    values are left untouched: 0 stays black, and labels beyond ``len(colors)``
    keep their numeric value.

    Args:
        mask: ``(H, W)`` integer label tensor.
        boundry_mask: Optional ``(H, W)`` label tensor drawn on top, or ``None``.
        image: ``(3, H, W)`` float tensor.
        colors: Sequence of ``(r, g, b)`` tuples in 0..255.

    Returns:
        An ``(H, W, 3)`` tensor, either ``0.6 * mask + 0.4 * image`` or, with a
        boundary map, ``0.6 * boundary + 0.3 * mask + 0.4 * image``.
    """
    def _labels_to_rgb(label_map):
        # One float copy of the labels per RGB channel, then paint each label in turn.
        rgb = torch.repeat_interleave(label_map[None, ...], 3, 0).type(torch.float)
        for channel in range(3):
            for label, color in enumerate(colors, start=1):
                rgb[channel] = torch.where(
                    rgb[channel] == label, color[channel] / 255, rgb[channel]
                )
        return rgb

    mask_rgb = _labels_to_rgb(mask)
    if boundry_mask is None:
        blended = mask_rgb * 0.6 + image * 0.4
    else:
        blended = _labels_to_rgb(boundry_mask) * 0.6 + mask_rgb * 0.3 + image * 0.4
    return blended.permute(1, 2, 0)

# Dataloader

In [None]:
import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from glob import glob
import os
import random




class Dataset(TorchDataset):
    """Segmentation dataset of ``*.png`` images paired with ``*.npy`` masks.

    Images and masks are collected from ``data_dir`` and paired by their position
    in two separately sorted file lists. Each image must therefore sort into the
    same slot as its mask.

    Each item is ``(image, mask, name)``:
    - ``image`` is normalised and converted to a tensor by ``ToTensorV2``.
    - ``mask`` is a float tensor.
    - ``name`` is the image file name without its extension.

    Both splits resize to 352x352. Training additionally applies the flip/blur
    transforms and a random rotation.

    Args:
        data_dir: Directory containing the image/mask files.
        train: Apply training augmentations when True.
        mask_size: Stored only; resizing is currently fixed at 352.
        num_parts: Stored only.
        min_crop_ratio: Stored only.
        dataset_name: ``"celeba"`` selects a (-10, 10) degree rotation range;
            anything else selects (-30, 30).

    Raises:
        ValueError: If ``data_dir`` holds different numbers of images and masks.
    """

    def __init__(
        self,
        data_dir,
        train=True,
        mask_size=352,
        num_parts=1,
        min_crop_ratio=0.5,
        dataset_name: str = "sample",
    ):
        self.image_paths = sorted(glob(os.path.join(data_dir, "*.png")))
        self.mask_paths = sorted(glob(os.path.join(data_dir, "*.npy")))
        # Pairing is purely positional, so a count mismatch silently mispairs
        # images and masks (or indexes past the end of the mask list).
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"{data_dir}: found {len(self.image_paths)} .png images but "
                f"{len(self.mask_paths)} .npy masks; they are paired by sorted "
                "order and must match one-to-one"
            )
        self.train = train
        self.mask_size = mask_size
        self.num_parts = num_parts
        self.min_crop_ratio = min_crop_ratio
        self.current_part_idx = 1

        # Per-channel mean/std normalisation applied to every image.
        self.normalize_transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Rotation range (degrees) for training augmentation.
        rotation_range = (-10, 10) if dataset_name == "celeba" else (-30, 30)

        # Training stage 1: resize, flip, blur.
        self.train_transform_1 = A.Compose([
            A.Resize(352, 352),
            A.HorizontalFlip(),
            A.GaussianBlur(blur_limit=(3, 5)),
        ])

        # Training stage 2: rotation, filling the exposed border with zeros.
        self.train_transform_2 = A.Compose([
            A.Rotate(
                rotation_range,
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
                mask_value=0,
            )
        ])

        self.tensorize = A.Compose([ToTensorV2()])

        # Evaluation: resize only.
        self.resize_transform = A.Compose([A.Resize(352, 352)])

    @staticmethod
    def _load_rgb(path):
        """Read an image file as an RGB numpy array, closing the file afterwards."""
        # convert("RGB") so grayscale, palette or RGBA PNGs still have the three
        # channels the 3-channel Normalize above expects.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor, name)`` for item ``idx``."""
        image_path = self.image_paths[idx]
        image = self._load_rgb(image_path)
        # basename/splitext: portable across path separators, and strips only the
        # final extension.
        name = os.path.splitext(os.path.basename(image_path))[0]
        mask = np.load(self.mask_paths[idx])

        if self.train:
            result = self.train_transform_1(image=image, mask=mask)
            result = self.train_transform_2(image=result["image"], mask=result["mask"])
        else:
            result = self.resize_transform(image=image, mask=mask)
        image, mask = result["image"], result["mask"]

        image = self.normalize_transform(image=image)["image"]
        image = self.tensorize(image=image)["image"]
        return image, torch.from_numpy(mask).float(), name

    def __len__(self):
        return len(self.image_paths)


class DataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train/val/test ``Dataset`` splits.

    Args:
        train_data_dir: Directory with the training ``*.png`` / ``*.npy`` pairs.
        val_data_dir: Directory with the validation pairs.
        test_data_dir: Directory with the test pairs.
        batch_size: Batch size used by every loader.
        train_mask_size: Forwarded as ``mask_size`` to the training ``Dataset``.
        test_mask_size: Forwarded as ``mask_size`` to the val/test ``Dataset``.
        num_parts: Forwarded to the training ``Dataset``.
        min_crop_ratio: Forwarded to the training ``Dataset``.
        dataset_name: Forwarded to every ``Dataset``.
        num_workers: DataLoader worker processes (0 = load in the main process).
    """

    def __init__(
        self,
        train_data_dir: str = "./data",
        val_data_dir: str = "./data",
        test_data_dir: str = "./data",
        batch_size: int = 1,
        train_mask_size: int = 352,
        test_mask_size: int = 352,
        num_parts: int = 2,
        min_crop_ratio: float = 0.5,
        dataset_name: str = "sample",
        num_workers: int = 0,
    ):
        super().__init__()
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.test_data_dir = test_data_dir
        self.batch_size = batch_size
        self.train_mask_size = train_mask_size
        self.test_mask_size = test_mask_size
        self.num_parts = num_parts
        self.min_crop_ratio = min_crop_ratio
        self.dataset_name = dataset_name
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        ``"fit"`` builds train and val, ``"validate"`` builds val, ``"test"``
        builds test, and ``None`` builds all three.
        """
        if stage in ("fit", None):
            self.train_dataset = Dataset(
                data_dir=self.train_data_dir,
                train=True,
                mask_size=self.train_mask_size,
                num_parts=self.num_parts,
                min_crop_ratio=self.min_crop_ratio,
                dataset_name=self.dataset_name,
            )
        # trainer.validate() calls setup("validate"); without this branch
        # val_dataloader() would fail on a missing attribute.
        if stage in ("fit", "validate", None):
            self.val_dataset = Dataset(
                data_dir=self.val_data_dir,
                train=False,
                mask_size=self.test_mask_size,
                dataset_name=self.dataset_name,
            )
        if stage in ("test", None):
            self.test_dataset = Dataset(
                data_dir=self.test_data_dir,
                train=False,
                mask_size=self.test_mask_size,
                dataset_name=self.dataset_name,
            )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

# Custom loss functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole batch at once.

    ``predictions`` should be probabilities; the callers in this file apply
    ``torch.sigmoid`` first. ``targets`` must have the same number of elements.
    Both are flattened, so the Dice score pools every pixel in the batch rather
    than averaging per sample.

    Args:
        smooth: Additive smoothing in numerator and denominator. It keeps the loss
            defined, and equal to 0, when both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        """Return ``1 - dice`` as a scalar tensor."""
        # reshape (not view): inputs may be non-contiguous, e.g. after a
        # transpose or permute, where .view raises.
        if predictions.shape != targets.shape:
            targets = targets.reshape(predictions.shape)

        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (predictions * targets).sum()
        dice = (2. * intersection + self.smooth) / (
            predictions.sum() + targets.sum() + self.smooth
        )
        return 1 - dice

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.contrib import distance_transform

class BoundaryLoss(nn.Module):
    """Distance-weighted penalty on predicted probabilities plus a small MSE term.

    Each predicted probability is weighted by ``|tanh(sdf / 10)|``, where ``sdf``
    is built from kornia distance transforms of the target mask. A term
    ``0.1 * mean((p - target) ** 2)`` is added on top.

    NOTE(review): the absolute value discards the sign of the SDF, so confident
    predictions deep inside the object are penalised just like false positives
    far outside it. Only the MSE term pulls the interior back toward 1. Confirm
    this is intended before relying on this loss.
    """

    def __init__(self):
        super(BoundaryLoss, self).__init__()

    def compute_target_sdf(self, target_mask):
        """Return ``neg_dist - pos_dist`` for the binarised target.

        Accepts a ``(B, H, W)`` mask, which gets a channel axis added, or an
        already 4-D mask. Values > 0 count as foreground.
        """
        if target_mask.dim() == 3:
            binary_mask = (target_mask > 0).float().unsqueeze(1)  # (B, 1, H, W)
        else:
            binary_mask = (target_mask > 0).float()

        # Distance transforms of the foreground and of the background.
        pos_dist = distance_transform(binary_mask)
        neg_dist = distance_transform(1 - binary_mask)

        # Signed distance; the sign convention follows kornia's distance_transform
        # (presumably distance to the nearest non-zero pixel — verify).
        sdf = neg_dist - pos_dist
        return sdf

    def forward(self, pred_mask, target_mask):
        """
        pred_mask: Predicted probabilities after sigmoid (batch_size, 1, height, width)
        target_mask: Ground truth binary mask (batch_size, height, width) or
            (batch_size, 1, height, width)
        """
        batch_size = pred_mask.size(0)

        # Inputs are already probabilities; no sigmoid/softmax here.
        pred_probs = pred_mask

        target_mask = target_mask.to(pred_mask.device)

        # SDF squashed into (-1, 1) so the weights are bounded.
        target_sdf = torch.tanh(self.compute_target_sdf(target_mask) / 10.0)

        weighted_probs = pred_probs.squeeze(1) * torch.abs(target_sdf).squeeze(1)
        # NOTE(review): sum / (numel / B) equals B * mean, so this term grows with
        # the batch size. It is kept as-is to preserve the current loss scale.
        loss = weighted_probs.sum() / (weighted_probs.numel() / batch_size)

        # Bring the target to (B, H, W) to match the squeezed prediction. A
        # (B, 1, H, W) target would otherwise broadcast to (B, B, H, W) and mix
        # samples across the batch.
        target_hw = target_mask.squeeze(1) if target_mask.dim() == 4 else target_mask
        reg_term = 0.1 * torch.mean((pred_probs.squeeze(1) - target_hw.float()) ** 2)

        return loss + reg_term

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, mean-reduced over all elements.

    Args:
        alpha: Constant weight applied to every element. Unlike the class-dependent
            ``alpha_t`` variant, positives and negatives share this weight.
        gamma: Focusing exponent; larger values down-weight well-classified pixels.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, prediction, target):
        """Return the mean focal loss.

        ``prediction`` holds raw logits; ``target`` is a binary mask with the same
        number of elements.
        """
        # reshape (not view) so non-contiguous inputs are accepted.
        logits = prediction.reshape(-1)
        target = target.reshape(-1).to(logits.dtype)

        # BCE from logits is numerically stable. BCE on sigmoid outputs saturates
        # for confident logits, and PyTorch clamps its log term.
        bce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')

        prob = torch.sigmoid(logits)
        pt = target * prob + (1 - target) * (1 - prob)
        focal_weight = (1 - pt) ** self.gamma

        return (self.alpha * focal_weight * bce).mean()

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss.

    ``prediction`` must be raw logits. ``BCEWithLogitsLoss`` applies the sigmoid
    internally, and the Dice term receives ``torch.sigmoid(prediction)``.

    Args:
        bce_weight: Multiplier for the BCE term.
        dice_weight: Multiplier for the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, prediction, target):
        target = target.float()
        # A (B, H, W) mask gains a channel axis to line up with (B, 1, H, W) logits.
        target = target.unsqueeze(1) if target.dim() == 3 else target

        weighted_bce = self.bce_weight * self.bce(prediction, target)
        weighted_dice = self.dice_weight * self.dice(torch.sigmoid(prediction), target)
        return weighted_bce + weighted_dice

# Few-shot CLIPSeg (prompt tuning)

In [None]:
import os
from IPython.display import clear_output


import pytorch_lightning as pl
import torch
from torch import optim
import torch.nn.functional as F


import gc
from PIL import Image

import torchvision
from torchvision import transforms


# load the model

import torch
import requests

from models.clipseg import CLIPDensePredT

from matplotlib import pyplot as plt
from PIL import Image
import time




class CLIPSeg22(pl.LightningModule):
    """Prompt tuning of a frozen CLIPSeg model.

    Only ``emb_to_learn`` is optimised. It is initialised from the CLIP text
    encoding of ``'nose'``, and every CLIPSeg/CLIP weight is frozen.

    Validation writes a probability heat-map and a 0/255 binary mask per image
    under ``validations/``, and tracks the mean IoU of the thresholded
    predictions. Testing writes the same files under ``results/`` and saves the
    learned embedding to ``emb_to_learn.pt``.
    """

    # Probability above which a pixel counts as foreground in the binary masks.
    mask_threshold = 0.5

    def __init__(self, config, learning_rate=0.001):
        super().__init__()

        self.config = config
        # Fallback learning rate, used only when ``config`` has no ``lr`` attribute.
        self.learning_rate = learning_rate
        self.max_val_iou = 0
        self.val_ious = []

        self.device1 = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(self.device1)

        # Load CLIPSeg and freeze it; only the prompt embedding below is trained.
        self.model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
        self.model.train()
        for p in self.model.parameters():
            p.requires_grad = False

        # non-strict, because we only stored decoder weights (not CLIP weights)
        self.model.load_state_dict(torch.load('clipseg_weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False)

        self.model = self.model.to(self.device1)

        base_prompt = 'nose'

        import clip

        # Initialise the learnable prompt from the CLIP text encoding of base_prompt.
        # Wrapping it in nn.Parameter registers it on this module.
        with torch.no_grad():
            text_tokens = clip.tokenize(base_prompt).to(self.device1)
            emb = self.model.clip_model.encode_text(text_tokens)
            self.emb_to_learn = torch.nn.Parameter(emb.clone())

        # Parameter-free loss: build it once instead of once per training step.
        self.dice_loss = DiceLoss()

        self.i = 0

    def on_fit_start(self) -> None:
        # No-op: the CLIPSeg model is already moved to device1 in __init__.
        pass

    def training_step(self, batch, batch_idx):
        image, mask, _ = batch
        image = image.to(self.device1)
        mask = mask.to(self.device1)

        # Element 0 of the model output holds the logits.
        preds = self.model(image, self.emb_to_learn)[0]
        prediction = torch.sigmoid(preds)

        # Masks arrive as (B, H, W) and predictions presumably as (B, 1, H, W);
        # add the channel axis to the mask when ranks differ.
        if mask.dim() < prediction.dim():
            mask = mask.unsqueeze(1)
        # Alternatives defined above: BoundaryLoss (takes probabilities) and
        # CombinedLoss (takes logits, not `prediction`).
        loss = self.dice_loss(prediction, mask)

        self.i = self.i + 1
        return loss

    def _predict_masks(self, image):
        """Return ``(probs, binary)`` on the CPU for a batch, one ``(H, W)`` map per item.

        Sigmoid is applied exactly once to the logits. ``binary`` is
        ``probs > mask_threshold``.
        """
        with torch.no_grad():
            logits = self.model(image, self.emb_to_learn)[0]
        # assumes logits are (B, 1, H, W), as in training_step — take channel 0.
        probs = torch.sigmoid(logits)[:, 0].cpu()
        return probs, probs > self.mask_threshold

    def _save_prediction(self, out_dir, file_name, probs, binary):
        """Write ``<file_name>.png`` (heat-map) and ``<file_name>_bmask.png`` (0/255) to ``out_dir``."""
        # Fixed vmin/vmax so the heat-map is not stretched per image.
        plt.imsave(os.path.join(out_dir, f'{file_name}.png'), probs.numpy(), vmin=0.0, vmax=1.0)
        bw_image = binary.numpy().astype(np.uint8) * 255
        cv2.imwrite(os.path.join(out_dir, f"{file_name}_bmask.png"), bw_image)

    def on_validation_start(self):
        pass

    def on_validation_epoch_start(self):
        os.makedirs('validations', exist_ok=True)
        self.val_ious = []

    def validation_step(self, batch, batch_idx):
        image, mask, file_names = batch
        image = image.to(self.device1)

        probs, binary = self._predict_masks(image)
        gt = mask.cpu().numpy()

        for k, file_name in enumerate(file_names):
            self._save_prediction('validations', file_name, probs[k], binary[k])
            # IoU on 0/1 masks: the saved 0/255 image must not be used here.
            iou = calculate_iou(
                binary[k].numpy().astype(np.float32),
                (gt[k] > 0).astype(np.float32),
            )
            self.val_ious.append(float(iou))

        return torch.tensor(0.0)

    def on_validation_epoch_end(self):
        if self.val_ious:
            mean_iou = sum(self.val_ious) / len(self.val_ious)
            self.max_val_iou = max(self.max_val_iou, mean_iou)
            print('mean iou is ', mean_iou)
            self.log("val mean iou", mean_iou, sync_dist=True)
        gc.collect()  # python garbage collector

    def on_test_start(self) -> None:
        os.makedirs('results', exist_ok=True)
        # The embedding does not change while testing, so save it once up front.
        torch.save(self.emb_to_learn, 'emb_to_learn.pt')
        self.start_time = time.time()

    def test_step(self, batch, batch_idx):
        image, _, file_names = batch
        image = image.to(self.device1)

        probs, binary = self._predict_masks(image)
        for k, file_name in enumerate(file_names):
            self._save_prediction('results', file_name, probs[k], binary[k])

        return torch.tensor(0.0)

    def on_test_end(self) -> None:
        self.end_time = time.time()
        elapsed = self.end_time - self.start_time
        num_images = len(self.trainer.datamodule.test_dataset)

        print(f"Time taken for {num_images} images: {elapsed} seconds")
        if num_images:
            print(f"Average time per image: {elapsed / num_images} seconds")
        print("max val mean iou: ", self.max_val_iou)

    def configure_optimizers(self):
        lr = getattr(self.config, 'lr', self.learning_rate)
        # Only the prompt embedding is optimised.
        return getattr(optim, self.config.optimizer)([self.emb_to_learn], lr=lr)

    def on_before_optimizer_step(self, optimizer):
        # Print gradient stats every `grad_log_every_n_steps` steps (default: every step).
        every = getattr(self.config, 'grad_log_every_n_steps', 1)
        if self.trainer.global_step % every != 0:
            return
        grad = self.emb_to_learn.grad
        if grad is None:
            print("Embedding grad is None in training step")
        else:
            print("Embedding grad in training step:",
                  grad.min().item(),
                  grad.mean().item(),
                  grad.max().item())


# Training script

In [None]:
import pytorch_lightning as pl

def main():
    """Configure the run, build data and model, then train (optionally) and test.

    When ``config.train`` is set the model is fitted first. In either case it is
    then evaluated on the test split.
    """
    config = Arguments()
    config.dataset_name = "pascal"
    config.train_data_dir = "datasets/dataset/train"
    config.val_data_dir = "datasets/dataset/val"
    config.test_data_dir = "datasets/dataset/test"
    config.train = True
    config.epochs = 2

    dm = DataModule(
        train_data_dir=config.train_data_dir,
        val_data_dir=config.val_data_dir,
        test_data_dir=config.test_data_dir,
        batch_size=config.batch_size,
        test_mask_size=config.test_mask_size,
        dataset_name=config.dataset_name,
    )
    model = CLIPSeg22(config=config)

    # A list of GPU indices is only meaningful with CUDA. On a CPU-only machine
    # Lightning expects an integer device count, so fall back to one CPU device.
    if torch.cuda.is_available():
        devices = [config.gpu_id] if isinstance(config.gpu_id, int) else config.gpu_id
    else:
        devices = 1

    trainer = pl.Trainer(
        default_root_dir=config.output_dir,
        max_epochs=config.epochs,
        devices=devices,
        log_every_n_steps=1,
        enable_checkpointing=False,
        num_sanity_val_steps=0,
        val_check_interval=1.0,
    )
    if config.train:
        trainer.fit(model=model, datamodule=dm)
    trainer.test(model=model, datamodule=dm)


if __name__ == "__main__":
    main()

In [None]:
# # Later, to load it:
# loaded_emb = torch.load('emb_to_learn.pt')
# self.emb_to_learn = torch.nn.Parameter(loaded_emb)

In [None]:
import torch
# Notebook sanity check: as the cell's last expression, this displays whether
# PyTorch can see a CUDA device.
torch.cuda.is_available()