In [1]:
# Mount/connect my google drive
#from google.colab import drive
#drive.mount('/content/drive')



#path_to_images = '/content/drive/MyDrive/Data/UAVVaste/test_train_val/images/train'
path_to_images = './sample_data/images/'
#path_to_masks = '/content/drive/MyDrive/Data/UAVVaste/test_train_val/masks/train'
#path_to_json = ' /content/drive/MyDrive/Data/UAVVaste/train_reduced_annotations.json'
path_to_json = './sample_data/train_reduced_annotations.json'


In [38]:
#!pip install segment-anything
! pip install git+https://github.com/facebookresearch/segment-anything.git &> /dev/null
#! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth &> /dev/null
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth &> /dev/null

#model_type = 'vit_b'
#checkpoint = './sam_vit_b_01ec64.pth'

model_type = 'vit_h'
checkpoint = './sam_vit_h_4b8939.pth'

In [40]:
from segment_anything import SamPredictor, sam_model_registry, utils
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch
import os
import sys

# Loading the model
#sam = sam_model_registry[model_type](checkpoint=checkpoint)
#predictor = SamPredictor(sam)
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)



In [41]:
import os
import json
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision.transforms import ToTensor, Normalize, Compose, Resize

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, json_path, transform=None):
        self.root_dir = root_dir
        self.json_path = json_path
        self.transform = transform
        self.data = self.load_json()

    def load_json(self):
        with open(self.json_path) as json_file:
            data = json.load(json_file)
        return data

    def __len__(self):
        return len(self.data['images'])

    def preprocess(self, x):
        x = Resize((224, 224))(x)
        x = ToTensor()(x)
        x = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(x)
        return x

    def extract_masks_bboxes_from_json(self, annotation, image_info):
        image_id = image_info["id"]
        masks = []
        bboxes = []

        # Get the 'width' and 'height' from image_info
        width = image_info["width"]
        height = image_info["height"]

        segmentation = annotation.get("segmentation")
        bbox = annotation.get("bbox")

        if segmentation is not None:
            mask = Image.new('L', (width, height), 0)
            draw = ImageDraw.Draw(mask)

            for segment in segmentation:
                flattened_segment = [int(coord) for coord in segment]
                coordinates = [(flattened_segment[i + 1], flattened_segment[i]) for i in range(0, len(flattened_segment), 2)]
                draw.polygon(coordinates, outline=255, fill=255)

            masks.append(mask)

        if bbox is not None and len(bbox) == 4:
            x, y, w, h = bbox
            x1, y1, x2, y2 = int(x), int(y), int(x + w), int(y + h)
            bboxes.append((x1, y1, x2, y2))

        return masks, bboxes

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        while idx < len(self.data['images']):
            image_info = self.data['images'][idx]
            image_filename = image_info["file_name"]
            img_name = os.path.join(self.root_dir, image_filename)

            print("Processing item:", idx)
            print("Image filename:", img_name)

            if not os.path.exists(img_name):
                # If the image file is missing, skip this image and move to the next one
                print("Image file not found. Skipping...")
                idx += 1
                continue

            image = Image.open(img_name)

            if self.transform:
                image = self.transform(image)

            # Preprocess the image
            image = self.preprocess(image)

            annotations_for_image = [ann for ann in self.data['annotations'] if ann['image_id'] == image_info['id']]

            masks = []
            bboxes = []

            for annotation in annotations_for_image:
                mask, bbox = self.extract_masks_bboxes_from_json(annotation, image_info)
                masks.extend(mask)
                bboxes.extend(bbox)

            # Convert masks to binary mask (optional)
            binary_masks = [torch.tensor(np.array(mask) / 255, dtype=torch.float32) for mask in masks]

            if image is not None and binary_masks and bboxes:
                return {'image': image, 'masks': binary_masks, 'bboxes': bboxes}

            # If any of the data is missing or invalid, move to the next image
            print("Invalid data. Moving to the next image...")
            idx += 1

        # If we reach the end of the dataset, return an empty dictionary to signal the end of the epoch
        return {}

    @staticmethod
    def collate_fn(batch):
        # Remove None entries from the batch (if any)
        batch = [item for item in batch if item]

        images = [item['image'] for item in batch]
        masks = [item['masks'] for item in batch]
        bboxes = [item['bboxes'] for item in batch]

        return {'images': images, 'masks': masks, 'bboxes': bboxes}


In [36]:
# Load custom dataset
custom_dataset = CustomDataset(root_dir=path_to_images, json_path=path_to_json)

# Create a DataLoader
batch_size = 4  # Set your desired batch size
dataloader = DataLoader(custom_dataset, batch_size=batch_size, collate_fn=CustomDataset.collate_fn)

# Set device, better to use A100
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training loop
num_epochs = 2  # Set the desired number of epochs
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        input_images, gt_binary_masks, bboxes = batch['images'], batch['masks'], batch['bboxes']

        # Move each input image tensor to the desired device
        input_images = [image.to(device) for image in input_images]

        # Move each binary mask tensor to the desired device
        for image_masks in gt_binary_masks:
            image_masks = [mask.to(device) for mask in image_masks]

        # Here, input_images is a list of tensors, each corresponding to a batch of images.
        # gt_binary_masks is a list of lists, where each inner list contains the masks for each image in the batch.

        # You can then use these tensors for further processing in your training loop.
        # For example:
        #for image_tensor in input_images:
        #    print("Input Image Shape:", image_tensor.shape)

        #for image_masks in gt_binary_masks:
        #    for mask_tensor in image_masks:
        #        print("Binary Mask Shape:", mask_tensor.shape)

        # Rest of your training code goes here...




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask Shape: torch.Size([1628, 3668])
Binary Mask

In [None]:
# Load the SAM model
#sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')


# Set up optimizer
#optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())

# Set up loss function
#loss_fn = torch.nn.MSELoss()

# Load custom dataset
custom_dataset = CustomDataset(root_dir=path_to_images, json_path=path_to_json)

# Create a DataLoader
batch_size = 4  # Set your desired batch size
dataloader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

# Set device, better to use A100
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# Training loop
num_epochs = 2  # Set the desired number of epochs
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        input_image, gt_binary_masks, bboxes = batch['image'], batch['masks'], batch['bboxes']
        input_image, gt_binary_masks = input_image.to(device), gt_binary_masks.to(device)
        print (input_image)

'''
                # Image encoding
        with torch.no_grad():
            image_embedding = sam_model.image_encoder(input_image)

        # Prompt encoding
        with torch.no_grad():
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(points=None, boxes=bboxes, masks=None)

        # Mask decoding
        low_res_masks, iou_predictions = sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        # Postprocessing (adjust input_size and original_image_size accordingly)
        upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

        # Generate binary mask
        binary_mask = F.normalize(F.threshold(upscaled_masks, 0.0, 0)).to(device)

        # Update running_loss for monitoring
        running_loss += loss.item()

    # Print the average loss for the epoch
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")

# Save the fine-tuned model
torch.save(sam_model.state_dict(), 'fine_tuned_sam.pth')

''


SyntaxError: ignored

In [23]:
path_to_json = './sample_data/train_reduced_annotations.json'
import json

# Replace 'path_to_json' with the actual path to your JSON file
with open(path_to_json, 'r') as json_file:
    data = json.load(json_file)

# Print the names of the image files
for image_info in data['images']:
    image_filename = image_info["file_name"]
    print(image_filename)



BATCH_d07_img_880.jpg
BATCH_d07_img_4550.jpg
BATCH_d07_img_410.jpg
BATCH_d07_img_6030.jpg
BATCH_d07_img_3470.jpg
BATCH_d07_img_6970.jpg
BATCH_d07_img_5110.jpg
BATCH_d07_img_1250.jpg
BATCH_d07_img_450.jpg
BATCH_d07_img_2760.jpg
BATCH_d07_img_2340.jpg
BATCH_d07_img_610.jpg
BATCH_d07_img_7260.jpg
BATCH_d07_img_4370.jpg
BATCH_d07_img_1760.jpg
BATCH_d07_img_4730.jpg
BATCH_d07_img_1700.jpg
BATCH_d07_img_4670.jpg
BATCH_d07_img_1390.jpg
BATCH_d07_img_360.jpg
BATCH_d07_img_600.jpg
BATCH_d07_img_1120.jpg
BATCH_d07_img_510.jpg
BATCH_d07_img_6520.jpg
BATCH_d07_img_2670.jpg
BATCH_d07_img_2500.jpg
BATCH_d07_img_2820.jpg
BATCH_d07_img_3330.jpg
BATCH_d07_img_6230.jpg
BATCH_d07_img_5020.jpg
BATCH_d07_img_4430.jpg
BATCH_d07_img_4720.jpg
BATCH_d07_img_2220.jpg
BATCH_d07_img_1010.jpg
BATCH_d07_img_4880.jpg
BATCH_d07_img_5380.jpg
BATCH_d07_img_3860.jpg
BATCH_d07_img_5470.jpg
BATCH_d07_img_5780.jpg
BATCH_d07_img_5960.jpg
BATCH_d07_img_320.jpg
BATCH_d07_img_5400.jpg
BATCH_d07_img_1450.jpg
BATCH_d07_img_1830.

In [20]:
json_data = custom_dataset.load_json()
print(json_data['images'][:5])  # Print the first 5 images

[{'id': 3, 'width': 3840, 'height': 2160, 'file_name': 'BATCH_d07_img_880.jpg', 'license': None, 'flickr_url': 'https://live.staticflickr.com/65535/50678072893_5eb0e82526_o.jpg', 'coco_url': None, 'date_captured': None, 'flickr_640_url': None}, {'id': 4, 'width': 3840, 'height': 2160, 'file_name': 'BATCH_d07_img_4550.jpg', 'license': None, 'flickr_url': 'https://live.staticflickr.com/65535/50678066013_e7dcd2b895_o.jpg', 'coco_url': None, 'date_captured': None, 'flickr_640_url': None}, {'id': 5, 'width': 3840, 'height': 2160, 'file_name': 'BATCH_d07_img_410.jpg', 'license': None, 'flickr_url': 'https://live.staticflickr.com/65535/50678074913_5e7abe940a_o.jpg', 'coco_url': None, 'date_captured': None, 'flickr_640_url': None}, {'id': 6, 'width': 3840, 'height': 2160, 'file_name': 'BATCH_d07_img_6030.jpg', 'license': None, 'flickr_url': 'https://live.staticflickr.com/65535/50678894982_9e8f25972d_o.jpg', 'coco_url': None, 'date_captured': None, 'flickr_640_url': None}, {'id': 7, 'width': 38

In [27]:
ls ./sample_data/images

batch_01_frame_24.jpg   batch_s02_img_102.jpg  batch_s02_img_35.jpg
batch_01_frame_35.jpg   batch_s02_img_107.jpg  batch_s02_img_36.jpg
batch_01_frame_9.jpg    batch_s02_img_108.jpg  batch_s02_img_41.jpg
batch_02_img_3160.jpg   batch_s02_img_10.jpg   batch_s02_img_50.jpg
batch_02_img_720.jpg    batch_s02_img_111.jpg  batch_s02_img_56.jpg
batch_03_img_2160.jpg   batch_s02_img_113.jpg  batch_s02_img_60.jpg
batch_03_img_2180.jpg   batch_s02_img_116.jpg  batch_s02_img_64.jpg
batch_03_img_2720.jpg   batch_s02_img_118.jpg  batch_s02_img_66.jpg
batch_03_img_2920.jpg   batch_s02_img_11.jpg   batch_s02_img_85.jpg
batch_03_img_3120.jpg   batch_s02_img_125.jpg  batch_s02_img_8.jpg
batch_04_img_1400.jpg   batch_s02_img_127.jpg  batch_s02_img_91.jpg
batch_04_img_1620.jpg   batch_s02_img_132.jpg  batch_s02_img_93.jpg
batch_04_img_1780.jpg   batch_s02_img_136.jpg  batch_s02_img_99.jpg
batch_04_img_2780.jpg   batch_s02_img_138.jpg  BATCH_s04_img_1160.jpg
batch_04_img_4140.jpg   batch_s02_img_139.jpg  