In [None]:
# !pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
# !pip install -q monai

In [None]:
from glob import glob
import torch
import cv2
import numpy as np
import os


def image_to_mask_path(image_path):
    """Map ``<root>/Images/<stem>.jpg`` to its mask path ``<root>/Masks/<stem>.tif``.

    Only the parent directory and the extension are rewritten, so an "Images"
    or "jpg" substring elsewhere in the path (e.g. inside the file stem) is left
    untouched -- unlike a plain ``str.replace``.
    """
    images_dir, file_name = os.path.split(image_path)
    root_dir = os.path.dirname(images_dir)
    stem, _ = os.path.splitext(file_name)
    return os.path.join(root_dir, "Masks", stem + ".tif")


# glob() returns files in arbitrary, filesystem-dependent order; sort so the
# 80/20 train/validation split below is identical on every run and machine.
new_images = sorted(glob("../data/dataset/New_dataset/Images/*.jpg"))
new_masks = [image_to_mask_path(path) for path in new_images]

split_idx = int(len(new_images) * 0.8)

train_images, train_masks = new_images[:split_idx], new_masks[:split_idx]
val_images, val_masks = new_images[split_idx:], new_masks[split_idx:]

len(train_images), len(train_masks), len(val_images), len(val_masks)

In [None]:
import torch
import numpy as np
from torchvision import transforms

def augment_image_with_mask(image, mask):
    """Randomly augment an image patch and its label mask with the same geometry.

    One set of random spatial parameters (resized crop, horizontal flip,
    vertical flip, rotation) is sampled and applied to both arrays. A colour
    jitter is then applied to the image only.

    Args:
        image: H x W x C numpy array. ColorJitter on float tensors expects
            values in [0, 1] (SAMDataset passes float32 scaled by 1/255).
        mask: H x W numpy array of labels.

    Returns:
        (image_augmented, mask_augmented): a 256 x 256 x C array and a
        256 x 256 array.
    """
    TF = transforms.functional
    nearest = transforms.InterpolationMode.NEAREST
    out_size = [256, 256]

    image_t = torch.from_numpy(image.transpose(2, 0, 1))  # C x H x W
    mask_t = torch.from_numpy(mask).unsqueeze(0)          # 1 x H x W

    # Sample each random parameter once and apply it to both tensors. This replaces
    # re-seeding the global torch RNG twice per call to keep image and mask in sync,
    # which overwrote RNG state for everything else in the process.
    # NOTE: area scales above 1.0 can never fit inside the patch, so such draws are
    # rejected by get_params (torchvision then falls back to a centre crop). The
    # range is kept as-is for parity with the original pipeline.
    top, left, height, width = transforms.RandomResizedCrop.get_params(
        image_t, scale=[0.8, 1.2], ratio=[3.0 / 4.0, 4.0 / 3.0])
    image_t = TF.resized_crop(image_t, top, left, height, width, out_size)  # bilinear default
    # Nearest-neighbour keeps mask labels intact. Bilinear would blur object edges
    # into intermediate values that the dataset's `> 0` then counts as foreground.
    mask_t = TF.resized_crop(mask_t, top, left, height, width, out_size, interpolation=nearest)

    if torch.rand(1) < 0.5:
        image_t, mask_t = TF.hflip(image_t), TF.hflip(mask_t)
    if torch.rand(1) < 0.5:
        image_t, mask_t = TF.vflip(image_t), TF.vflip(mask_t)

    angle = transforms.RandomRotation.get_params([-15.0, 15.0])
    image_t = TF.rotate(image_t, angle, interpolation=nearest)
    mask_t = TF.rotate(mask_t, angle, interpolation=nearest)

    # Photometric jitter on the image only; the mask is never colour-adjusted.
    image_t = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)(image_t)

    image_augmented = image_t.permute(1, 2, 0).numpy()  # H x W x C
    mask_augmented = mask_t.squeeze(0).numpy()          # H x W
    return image_augmented, mask_augmented


In [None]:
from torch.utils.data import Dataset

class SAMDataset(Dataset):
    """Patch-level dataset yielding SAM inputs plus a binary ground-truth mask.

    Every source image is cut into a ``grid_size x grid_size`` grid of square
    ``patch_size`` patches (defaults: 4 x 4 patches of 256 px, i.e. 1024 x 1024
    images). Each patch is repeated ``augment_factor`` times in a row. Only
    images whose mask has a non-zero pixel fraction of at least
    ``building_threshold`` are used.

    Args:
        images: list of colour image file paths.
        masks: list of mask file paths, aligned index-by-index with ``images``.
        processor: SAM processor used to build the model inputs.
        augment_factor: number of consecutive dataset indices mapped to each patch.
        building_threshold: minimum foreground fraction for an image to be kept.
        transform: if True, apply ``augment_image_with_mask`` to each patch
            (NOTE: that function always outputs 256 x 256 patches).
        patch_size: side length of a square patch in pixels.
        grid_size: number of patches along each image side.
    """

    def __init__(self, images, masks, processor, augment_factor=1, building_threshold=0.15,
                 transform=False, patch_size=256, grid_size=4):
        self.images = images
        self.masks = masks
        self.processor = processor
        self.augment_factor = augment_factor
        self.building_threshold = building_threshold
        self.transform = transform
        self.patch_size = patch_size
        self.grid_size = grid_size
        # Reads every mask once up front; needs the attributes above.
        self.valid_indices = self._filter_images()

    @property
    def patches_per_image(self):
        """Number of patches cut from each source image."""
        return self.grid_size * self.grid_size

    @staticmethod
    def _read(path, *flags):
        """Read an image with cv2, raising instead of silently returning None."""
        array = cv2.imread(path, *flags)
        if array is None:
            raise OSError(f"cv2 could not read {path!r} (missing or unsupported file)")
        return array

    def _filter_images(self):
        """Return indices of images whose mask foreground fraction meets the threshold."""
        valid_indices = []
        for idx, (_, mask_path) in enumerate(zip(self.images, self.masks)):
            mask = self._read(mask_path, cv2.IMREAD_GRAYSCALE)
            building_area = (mask > 0).sum() / mask.size  # fraction of all pixels
            if building_area >= self.building_threshold:
                valid_indices.append(idx)
        return valid_indices

    def __len__(self):
        return len(self.valid_indices) * self.patches_per_image * self.augment_factor

    def _locate(self, idx):
        """Map a dataset index to (source image index, patch y offset, patch x offset)."""
        n_patches = self.patches_per_image
        # The modulo keeps out-of-range (and negative) indices wrapping as before.
        effective_index = (idx // self.augment_factor) % (len(self.valid_indices) * n_patches)
        image_idx = self.valid_indices[effective_index // n_patches]
        patch_idx = effective_index % n_patches  # row-major position in the grid
        y_offset = (patch_idx // self.grid_size) * self.patch_size
        x_offset = (patch_idx % self.grid_size) * self.patch_size
        return image_idx, y_offset, x_offset

    def __getitem__(self, idx):
        image_idx, y_offset, x_offset = self._locate(idx)
        image = cv2.cvtColor(self._read(self.images[image_idx]), cv2.COLOR_BGR2RGB)
        mask = self._read(self.masks[image_idx], cv2.IMREAD_GRAYSCALE)

        ps = self.patch_size
        image = image[y_offset:y_offset + ps, x_offset:x_offset + ps]
        ground_truth_mask = mask[y_offset:y_offset + ps, x_offset:x_offset + ps]
        expected = (ps, ps)
        if image.shape[:2] != expected or ground_truth_mask.shape != expected:
            raise ValueError(
                f"Patch at (y={y_offset}, x={x_offset}) of {self.images[image_idx]} has shape "
                f"{image.shape[:2]}, expected {expected}; source images must be at least "
                f"{ps * self.grid_size} x {ps * self.grid_size} pixels."
            )

        image = (image / 255).astype(np.float32)
        if self.transform:
            image, ground_truth_mask = augment_image_with_mask(image, ground_truth_mask)

        # The patch is already scaled to [0, 1]; do_rescale=False stops the processor
        # from multiplying by 1/255 a second time.
        inputs = self.processor(image, input_boxes=[[[0, 0, 0, 0]]], return_tensors="pt",
                                do_rescale=False)

        # Remove the batch dimension which the processor adds by default.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # Binary ground-truth segmentation (any non-zero mask pixel is foreground).
        inputs["ground_truth_mask"] = (ground_truth_mask > 0).astype(int)
        inputs["image_patch"] = image
        return inputs


In [None]:
from transformers import SamProcessor

# Processor paired with the same SAM ViT-B checkpoint the model is loaded from.
SAM_CHECKPOINT = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)

In [None]:
# Build the train and validation datasets the same way; augmentation is off for both.
train_dataset, val_dataset = (
    SAMDataset(image_paths, mask_paths, processor=processor, transform=False)
    for image_paths, mask_paths in ((train_images, train_masks), (val_images, val_masks))
)
len(train_dataset), len(val_dataset)

In [None]:
# Fetch the sample once. Each indexing call re-reads the files from disk and re-runs
# the processor, and with transform=True it would draw a different augmentation.
sample = train_dataset[1]
sample['pixel_values'].shape, sample['ground_truth_mask'].shape, sample['image_patch'].shape

In [None]:
import matplotlib.pyplot as plt
idx = 555  # NOTE: SAMDataset wraps indices past its length, so this never raises
# Index the dataset once so the image and mask come from the same call; with
# transform=True, two calls would apply two different random augmentations.
sample = train_dataset[idx]

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(sample['image_patch'])
axs[0].set_title(f"Image patch #{idx}")
axs[1].imshow(sample['ground_truth_mask'])
axs[1].set_title(f"Ground-truth mask #{idx}")
for ax in axs:
    ax.axis("off")
plt.show()

In [None]:
# List each field of a single sample together with its shape.
example = train_dataset[0]
for field_name, field_value in example.items():
    print(f"{field_name} {field_value.shape}")

In [None]:
from torch.utils.data import DataLoader
# Shuffle only the training split; validation order stays fixed.
BATCH_SIZE = 4
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
len(train_dataloader), len(val_dataloader)

In [None]:
# Peek at one collated training batch and list the shape of every field.
batch = next(iter(train_dataloader))
for field_name, field_tensor in batch.items():
    print(f"{field_name} {field_tensor.shape}")

In [None]:
from transformers import SamModel 
from torch.optim import AdamW
import monai

model = SamModel.from_pretrained("facebook/sam-vit-base")

# Freeze the image and prompt encoders so only the mask decoder gets gradients.
FROZEN_PREFIXES = ("vision_encoder", "prompt_encoder")
for param_name, parameter in model.named_parameters():
    if param_name.startswith(FROZEN_PREFIXES):
        parameter.requires_grad_(False)
    elif param_name.startswith("mask_decoder"):
        parameter.requires_grad_(True)

# Only mask-decoder parameters are handed to the optimizer.
optimizer = AdamW(model.mask_decoder.parameters(), lr=1e-4)
seg_loss = torch.nn.BCEWithLogitsLoss()
# NOTE(review): defined but not used; the training loop below calls seg_loss.
seg_loss1 = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [None]:
import torch
import gc
# First collect unreachable Python objects, then ask PyTorch to release
# unused cached GPU memory.
for release in (gc.collect, torch.cuda.empty_cache):
    release()

In [None]:
# Take the first training batch for the inspection cell below.
# NOTE(review): duplicates the `next(iter(train_dataloader))` fetch done earlier.
# If the loader is empty the loop body never runs and `batch` keeps any previous value.
for batch in train_dataloader:
    break

In [None]:
# Print each field name of `batch` followed by its shape.
for batch_key in batch:
    print(batch_key, batch[batch_key].shape)

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

num_epochs = 20

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def compute_batch_loss(model, batch, loss_fn, device):
    """Run SAM on one batch (one mask per box prompt) and score it against the ground truth.

    Shared by the training and validation phases so both compute the loss identically.

    Returns:
        (outputs, ground_truth_masks, loss): the raw model outputs, the float
        ground-truth masks on ``device``, and the scalar loss tensor.
    """
    outputs = model(pixel_values=batch["pixel_values"].to(device),
                    input_boxes=batch["input_boxes"].to(device),
                    multimask_output=False)
    # NOTE(review): assumes pred_masks is (B, 1, 1, H, W), so after squeeze(1) it
    # matches the (B, 1, H, W) ground truth built below -- confirm for other prompts.
    predicted_masks = outputs.pred_masks.squeeze(1)
    ground_truth_masks = batch["ground_truth_mask"].float().to(device)
    loss = loss_fn(predicted_masks, ground_truth_masks.unsqueeze(1))
    return outputs, ground_truth_masks, loss


def _mean_or_nan(values):
    """Average of a list of floats, or NaN for an empty list (statistics.mean would raise)."""
    return mean(values) if values else float("nan")


for epoch in range(num_epochs):
    # Training phase
    model.train()
    epoch_losses = []
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
        # `outputs` / `ground_truth_masks` stay module-level so later cells can inspect them.
        outputs, ground_truth_masks, loss = compute_batch_loss(model, batch, seg_loss, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    avg_train_loss = _mean_or_nan(epoch_losses)
    print(f'EPOCH {epoch + 1}:')
    print(f'Mean Training Loss: {avg_train_loss}')

    # Validation phase
    model.eval()
    val_losses = []
    with torch.no_grad():  # No gradients needed for validation
        for batch in tqdm(val_dataloader, desc=f"Validation Epoch {epoch + 1}/{num_epochs}"):
            outputs, ground_truth_masks, val_loss = compute_batch_loss(model, batch, seg_loss, device)
            val_losses.append(val_loss.item())

    avg_val_loss = _mean_or_nan(val_losses)
    print(f'Mean Validation Loss: {avg_val_loss}')


In [None]:
# `outputs` is left over from the last validation batch of the training cell
# (hidden state: this cell only works after that cell has run).
# Show a compact per-field summary instead of dumping the full tensors.
{name: tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
 for name, value in outputs.items()}

In [None]:
# `ground_truth_masks` is left over from the last validation batch of the training
# cell. Summarise its shape and distinct values rather than printing every pixel.
ground_truth_masks.shape, ground_truth_masks.unique()