# Config class

In [1]:
class Arguments:
  def __init__(self):
    self.batch_size = 1
    self.gpu_id = 0
    self.train = False
    self.checkpoint_dir = "outputs"
    self.text_prompt = ''
    self.output_dir = "outputs"
    self.dataset_name = "sample"
    self.train_data_dir = None
    self.val_data_dir = None
    self.test_data_dir = None
    self.optimizer = "Adam"
    self.epochs = 200
    self.lr = 0.1
    self.patch_threshold = 100
    self.test_mask_size = 512
    self.save_test_predictions = True
    self.dice_coef = 10
    self.boundary_coef = 0.1
    self.focal_coef = 5

# Utility Functions

In [2]:
import numpy as np
import torch
import random
import colorsys
import cv2

def get_random_crop_coordinates(crop_scale_range, image_width, image_height):
    rand_number = random.random()
    rand_number *= crop_scale_range[1] - crop_scale_range[0]
    rand_number += crop_scale_range[0]
    patch_size = int(rand_number * min(image_width, image_height))
    if patch_size != min(image_width, image_height):
        x_start = random.randint(0, image_width - patch_size)
        y_start = random.randint(0, image_height - patch_size)
    else:
        x_start = 0
        y_start = 0
    return x_start, x_start + patch_size, y_start, y_start + patch_size

In [3]:
def get_crops_coords(image_size, patch_size, num_patchs_per_side):
    h, w = image_size
    if num_patchs_per_side == 1:
        x_step_size = y_step_size = 0
    else:
        x_step_size = (w - patch_size) // (num_patchs_per_side - 1)
        y_step_size = (h - patch_size) // (num_patchs_per_side - 1)
    crops_coords = []
    for i in range(num_patchs_per_side):
        for j in range(num_patchs_per_side):
            y_start, y_end, x_start, x_end = (
                i * y_step_size,
                i * y_step_size + patch_size,
                j * x_step_size,
                j * x_step_size + patch_size,
            )
            crops_coords.append([y_start, y_end, x_start, x_end])
    return crops_coords

In [4]:
def generate_distinct_colors(n):
    colors = []
    if n == 1:
        return [(255, 255, 255)]
    for i in range(n):
        hue = i / n
        saturation = 0.9
        value = 0.9
        rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        scaled_rgb = tuple(int(x * 255) for x in rgb)
        colors.append(scaled_rgb)
    return colors

In [5]:
def calculate_iou(prediction, mask):
    intersection = prediction * mask
    union = prediction + mask - intersection
    return intersection.sum() / (union.sum() + 1e-7)

In [6]:
def get_colored_segmentation(mask, boundry_mask, image, colors):
    boundry_mask_rgb = 0
    if boundry_mask is not None:
        boundry_mask_rgb = torch.repeat_interleave(boundry_mask[None, ...], 3, 0).type(
            torch.float
        )
        for j in range(3):
            for i in range(1, len(colors) + 1):
                boundry_mask_rgb[j] = torch.where(
                    boundry_mask_rgb[j] == i,
                    colors[i - 1][j] / 255,
                    boundry_mask_rgb[j],
                )
    mask_rgb = torch.repeat_interleave(mask[None, ...], 3, 0).type(torch.float)
    for j in range(3):
        for i in range(1, len(colors) + 1):
            mask_rgb[j] = torch.where(
                mask_rgb[j] == i, colors[i - 1][j] / 255, mask_rgb[j]
            )
    if boundry_mask is not None:
        return (boundry_mask_rgb * 0.6 + mask_rgb * 0.3 + image * 0.4).permute(1, 2, 0)
    else:
        return (mask_rgb * 0.6 + image * 0.4).permute(1, 2, 0)

# Datalaoder

In [7]:
import pytorch_lightning as pl
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from glob import glob
import os
import random


# def get_random_crop_coordinates(crop_ratio_range, width, height):
#     """Generate random crop coordinates based on the given ratio range."""
#     crop_ratio = random.uniform(crop_ratio_range[0], crop_ratio_range[1])
#     crop_width = int(width * crop_ratio)
#     crop_height = int(height * crop_ratio)
    
#     x_start = random.randint(0, width - crop_width)
#     x_end = x_start + crop_width
#     y_start = random.randint(0, height - crop_height)
#     y_end = y_start + crop_height
    
#     return x_start, x_end, y_start, y_end


class Dataset(TorchDataset):
    def __init__(
        self,
        data_dir,
        train=True,
        mask_size=352,
        num_parts=1,
        min_crop_ratio=0.5,
        dataset_name: str = "sample",
    ):
        self.image_paths = sorted(glob(os.path.join(data_dir, "*.png")))
        self.mask_paths = sorted(glob(os.path.join(data_dir, "*.npy")))
        self.train = train
        self.mask_size = mask_size
        self.num_parts = num_parts
        self.min_crop_ratio = min_crop_ratio
        self.current_part_idx = 0
        
        # Normalize transform (similar to CLIP normalization)
        self.normalize_transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
        
        # For rotation angle range
        if dataset_name == "celeba":
            rotation_range = (-10, 10)
        else:
            rotation_range = (-30, 30)
            
        # Train transforms
        self.train_transform_1 = A.Compose([
            A.Resize(352, 352),
            A.HorizontalFlip(),
            A.GaussianBlur(blur_limit=(3, 5))
        ])
        
        self.train_transform_2 = A.Compose([
            A.Resize(352, 352),
            A.Rotate(
                rotation_range,
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
                mask_value=0
            )
        ])
        
        # Test transform
        self.test_transform = A.Compose([
            A.Resize(352, 352),
            ToTensorV2()
        ])

    def __getitem__(self, idx):
        # Load the image
        image = np.array(Image.open(self.image_paths[idx]))
        
        if self.train:
            # Load mask
            mask = np.load(self.mask_paths[idx])
            
            # Apply first set of transforms (resize, flip, blur)
            result = self.train_transform_1(image=image, mask=mask)
            image, mask = result["image"], result["mask"]
            
            # Calculate original mask size for current part
            original_mask_size = np.where(mask == self.current_part_idx, 1, 0).sum()
            
            # Random crop ensuring mask is included
            mask_is_included = False
            while not mask_is_included:
                x_start, x_end, y_start, y_end = get_random_crop_coordinates(
                    (self.min_crop_ratio, 1), 352, 352
                )
                aux_mask = mask[y_start:y_end, x_start:x_end]
                if (
                    original_mask_size == 0
                    or np.where(aux_mask == self.current_part_idx, 1, 0).sum() / original_mask_size > 0.3
                ):
                    mask_is_included = True
            
            image = image[y_start:y_end, x_start:x_end]
            mask = aux_mask
            
            # Apply second set of transforms (rotation)
            result = self.train_transform_2(image=image, mask=mask)
            image, mask = result["image"], result["mask"]
            
            # Convert to tensor and normalize
            result = self.normalize_transform(image=image)
            image = result["image"]
            
            # Convert mask to tensor
            mask = torch.from_numpy(mask).float()
            
            # Update current part index
            self.current_part_idx += 1
            self.current_part_idx = self.current_part_idx % self.num_parts
            
            return image, mask
        else:
            # Test mode
            if len(self.mask_paths) > 0:
                mask = np.load(self.mask_paths[idx])
                result = self.test_transform(image=image, mask=mask)
                image, mask = result["image"], result["mask"]
                mask = torch.from_numpy(mask).float()
            else:
                result = self.test_transform(image=image)
                image = result["image"]
                mask = torch.zeros((352, 352))
                
            return image, mask

    def __len__(self):
        return len(self.image_paths)


class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_data_dir: str = "./data",
        val_data_dir: str = "./data",
        test_data_dir: str = "./data",
        batch_size: int = 1,
        train_mask_size: int = 352,
        test_mask_size: int = 352,
        num_parts: int = 2,
        min_crop_ratio: float = 0.5,
        dataset_name: str = "sample",
    ):
        super().__init__()
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.test_data_dir = test_data_dir
        self.batch_size = batch_size
        self.train_mask_size = train_mask_size
        self.test_mask_size = test_mask_size
        self.num_parts = num_parts
        self.min_crop_ratio = min_crop_ratio
        self.dataset_name = dataset_name

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = Dataset(
                data_dir=self.train_data_dir,
                train=True,
                mask_size=self.train_mask_size,
                num_parts=self.num_parts,
                min_crop_ratio=self.min_crop_ratio,
                dataset_name=self.dataset_name,
            )
            self.val_dataset = Dataset(
                data_dir=self.val_data_dir,
                train=False,
                mask_size=self.test_mask_size,
            )
        elif stage == "test":
            self.test_dataset = Dataset(
                data_dir=self.test_data_dir,
                train=False,
                mask_size=self.test_mask_size,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, num_workers=0, shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=0, shuffle=False
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=0, shuffle=False
        )

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.8). Upgrade using: pip install --upgrade albumentations


# custom loss function

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# class DiceLoss(nn.Module):
#     def __init__(self, smooth=1.0):
#         super(DiceLoss, self).__init__()
#         self.smooth = smooth
        
#     def forward(self, predictions, targets):
#         """
#         Compute Dice loss between predictions and target masks
        
#         Args:
#             predictions: Tensor of shape (B, H, W) or (H, W), predicted probabilities
#             targets: Binary mask of shape (B, H, W) or (H, W)
            
#         Returns:
#             Dice loss (1 - Dice coefficient)
#         """
#         # Flatten the tensors
#         predictions = predictions.view(-1)
#         targets = targets.view(-1)
        
#         # Calculate intersection and union
#         intersection = (predictions * targets).sum()
#         union = predictions.sum() + targets.sum()
        
#         # Calculate Dice coefficient
#         dice = (2. * intersection + self.smooth) / (union + self.smooth)
        
#         # Return Dice loss
#         return 1 - dice



class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        
    def forward(self, predictions, targets):
        # Ensure shapes match
        if predictions.shape != targets.shape:
            targets = targets.view(predictions.shape)
            
        # Flatten prediction and target tensors
        predictions = predictions.view(-1)
        targets = targets.view(-1)
        
        intersection = (predictions * targets).sum()
        dice = (2. * intersection + self.smooth) / (
            predictions.sum() + targets.sum() + self.smooth
        )
        
        return 1 - dice

# ClipSeg@@@@

In [9]:
import os
from IPython.display import clear_output


import pytorch_lightning as pl
import torch
from torch import optim
import torch.nn.functional as F


import gc
from PIL import Image

import torchvision
from torchvision import transforms


# load the model

import torch
import requests

from models.clipseg import CLIPDensePredT

from matplotlib import pyplot as plt
from PIL import Image




class CLIPSeg22(pl.LightningModule):
    def __init__(self, config, learning_rate=0.001):
        super().__init__()

        self.config = config
        # self.save_hyperparameters(config.__dict__)
        self.max_val_iou = 0
        self.val_ious = []

        # self.device1 = 'cpu'
        self.device1 = 'cuda' if torch.cuda.is_available() else 'cpu'

        # load model
        self.model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
        self.model.train()
        for p in self.model.parameters():
            p.requires_grad = False

        # non-strict, because we only stored decoder weights (not CLIP weights)
        self.model.load_state_dict(torch.load('clipseg_weights/rd64-uni.pth', map_location=torch.device('cpu')), strict=False)

        self.model = self.model.to(self.device1)

        base_prompt = ''

        import clip

        # text_tokens = clip.tokenize(base_prompt).to(self.device1)
        # emb = self.model.clip_model.encode_text(text_tokens)
        # self.emb_to_learn = emb.detach().clone().to(self.device1)  # Detach and clone to make it a leaf tensor
        # self.emb_to_learn.requires_grad_(True)  # Correct way to set requires_grad

        # Create embedding properly as a model parameter
        with torch.no_grad():
            text_tokens = clip.tokenize(base_prompt).to(self.device1)
            emb = self.model.clip_model.encode_text(text_tokens)
            self.emb_to_learn = torch.nn.Parameter(emb.clone())



    def on_fit_start(self) -> None: #pl: called at the beginning of the fit()
        # move model to gpu
        pass

    def training_step(self, batch, batch_idx):
        image, mask = batch
        image = image.to(self.device1)
        mask = mask.to(self.device1)
        
        # print('training step was called..........')
        # print(image.shape)
        
        # Get model predictions
        # print('self emb to learn is : ', self.emb_to_learn.shape)
        preds = self.model(image, self.emb_to_learn)[0]
        # print('preds grad: ', preds.requires_grad)
        
        # Apply sigmoid but keep the computational graph intact
        prediction = torch.sigmoid(preds)
        
        # print('prediction shape: ', prediction.shape)
        # print('mask shape :', mask.shape)
        
        # Make sure mask and prediction have compatible shapes
        # Assuming mask is [B, H, W] and prediction is [B, 1, H, W]
        if len(mask.shape) < len(prediction.shape):
            mask = mask.unsqueeze(1)
        
        # Use DiceLoss
        dice = DiceLoss().to(self.device1)
        loss = dice(prediction, mask)
        # loss = F.cross_entropy(
        #     prediction,
        #     mask.type(torch.float16),
        # )
        # print('loss is : ', loss)

        # if self.emb_to_learn.grad is not None:
        #     print("Embedding grad in training step:", 
        #         self.emb_to_learn.grad.min().item(),
        #         self.emb_to_learn.grad.mean().item(), 
        #         self.emb_to_learn.grad.max().item())
        # else:
        #     print("Embedding grad is None in training step")
        
        # Log the loss if needed
        # self.log("train_loss", loss, on_step=True, prog_bar=True)
        # print('mean is : lkals;dafkj: ',self.emb_to_learn.mean())
        
        
        clear_output(wait=True)
        
        return loss




    def on_validation_start(self):
        pass

    def on_validation_epoch_start(self):
        self.val_ious = []

    def validation_step(self, batch, batch_idx):
        image, mask = batch
        #calculate mask / iou

        # self.log("val mean iou", mean_iou.cpu(), on_step=True, sync_dist=True)
        return torch.tensor(0.0)

    def on_validation_epoch_end(self):
        # epoch_mean_iou = sum(self.val_ious) / len(self.val_ious)
        # if epoch_mean_iou >= self.max_val_iou:
        #     self.max_val_iou = epoch_mean_iou
        #     for i, embedding in enumerate(self.emb_to_learn): #save embeddings
        #         torch.save(
        #             embedding,
        #             os.path.join(self.checkpoint_dir, f"embedding_{i}.pth"),
        #         )
        gc.collect() #python garbage collector

    def on_test_start(self) -> None:
        # make directory for saving results
        # load the embeddings if needed
        pass

    def test_step(self, batch, batch_idx):
        # Saving just the emb_to_learn parameter
        torch.save(self.emb_to_learn, 'emb_to_learn.pt')


        # do your test visualize/iou/save results/log results
        test_pth = 'datasets/dataset/test'
        print('--------------------------------------------------------------')
        img_list = [x for x in os.listdir(test_pth) if '.png' in x]
        print(len(img_list))

        print('--------------------------------------------------------------')
        for img_name in img_list:
            input_image = Image.open(f'{test_pth}/{img_name}')
            mask = np.load(f"{test_pth}/{img_name.replace('.png', '.npy')}")


            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                transforms.Resize((352, 352)),
            ])

            img = transform(input_image).unsqueeze(0)


            # predict
            with torch.no_grad():
                preds = self.model(img, self.emb_to_learn, is_training=False)[0]

            #mask generation
            # filename = f"mask.png"
            # here we save the second mask
            m = torch.sigmoid(preds[0][0]).cpu()
            plt.imsave(f'results/{img_name}',m)


            # Create binary mask
            img2 = cv2.imread(f'results/{img_name}')
            gray_image = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
            (thresh, bw_image) = cv2.threshold(gray_image, 150, 255, cv2.THRESH_BINARY)

            # Save the binary mask using cv2.imwrite instead of plt.imsave
            bmask_filename = f"results/{img_name.replace('.png', '')}_bmask.png"
            cv2.imwrite(bmask_filename, bw_image)






            # img2 = cv2.imread(f'results/{img_name}')
            # gray_image = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

            # (thresh, bw_image) = cv2.threshold(gray_image, 150, 255, cv2.THRESH_BINARY)

            # # fix color format
            # cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)

            # # Image.fromarray(bw_image)

            # # visualize prediction
            # _, ax = plt.subplots(1, 4, figsize=(10, 4))
            # [a.axis('off') for a in ax.flatten()]
            # ax[0].imshow(input_image)
            # ax[1].imshow(m)
            # ax[1].text(0, -15, 'polyp')
            # ax[2].imshow(bw_image)
            # ax[3].imshow(mask, cmap='gray')


        return torch.tensor(0.0)

    def on_test_end(self) -> None:
        print("max val mean iou: ", self.max_val_iou)

    def configure_optimizers(self):
        optimizer = getattr(optim, self.config.optimizer)(
            [self.emb_to_learn],  # Pass the parameter directly
            lr=self.config.lr,
        )
        return optimizer

    def on_before_optimizer_step(self, optimizer):
        if self.trainer.global_step % 1 == 0:  # Avoid excessive logging
            for param in self.emb_to_learn:
                if param.grad is None:
                    print( "is None, requires_grad : ", param.requires_grad)
                else:
                    print ( "grad", param.grad.data.max(), param.grad.data.min())


# Training script

In [10]:
import pytorch_lightning as pl

def main():
    config = Arguments()
    config.dataset_name = "pascal"
    config.train_data_dir = "datasets/dataset/train"
    config.val_data_dir = ""
    config.test_data_dir = ""
    config.train = True
    config.epochs = 500


    dm = DataModule(
        train_data_dir=config.train_data_dir,
        val_data_dir=config.val_data_dir,
        test_data_dir=config.test_data_dir,
        batch_size=config.batch_size,
        test_mask_size=config.test_mask_size,
        dataset_name=config.dataset_name,
    )
    model = CLIPSeg22(config=config)
    if isinstance(config.gpu_id, int):
        gpu_id = [config.gpu_id]
    else:
        gpu_id = config.gpu_id
    trainer = pl.Trainer(
        # accelerator="gpu",
        default_root_dir=config.output_dir,
        max_epochs=config.epochs,
        devices=gpu_id,
        log_every_n_steps=1,
        enable_checkpointing=False,
        num_sanity_val_steps=0,
    )
    if config.train:
        trainer.fit(model=model, datamodule=dm)
        trainer.test(model=model, datamodule=dm)
    else:
        trainer.test(model=model, datamodule=dm)

main()

is None, requires_grad :  True
Epoch 499: 100%|██████████| 11/11 [00:00<00:00, 30.17it/s, v_num=79]

`Trainer.fit` stopped: `max_epochs=500` reached.


Epoch 499: 100%|██████████| 11/11 [00:00<00:00, 30.17it/s, v_num=79]


You are using a CUDA device ('NVIDIA GeForce RTX 4070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]



--------------------------------------------------------------
681
--------------------------------------------------------------
Testing DataLoader 0: 100%|██████████| 1/1 [00:56<00:00, 56.29s/it]
max val mean iou:  0


In [None]:
# # Later, to load it:
# loaded_emb = torch.load('emb_to_learn.pt')
# self.emb_to_learn = torch.nn.Parameter(loaded_emb)