In [1]:
# Mount/connect my google drive
from google.colab import drive
drive.mount('/content/drive')

path_to_images = '/content/drive/MyDrive/Data/UAVVaste/images/'
#jpg_image_names = sorted([(path_to_images + filename) for filename in os.listdir(path_to_images) if filename.endswith(".jpg")])

##change path
path_to_masks = '/content/drive/MyDrive/Data/UAVVaste/masks/pixel_masks_rgb/'
#npz_mask_names = sorted([(path_to_npz_mask + filename) for filename in os.listdir(path_to_npz_mask) if filename.endswith(".npz")])


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
#!pip install segment-anything
! pip install git+https://github.com/facebookresearch/segment-anything.git &> /dev/null
#! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth &> /dev/null
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth &> /dev/null

#model_type = 'vit_b'
#checkpoint = './sam_vit_b_01ec64.pth'

model_type = 'vit_h'
checkpoint = './sam_vit_h_4b8939.pth'

In [5]:
from segment_anything import SamPredictor, sam_model_registry
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch
import os
import sys

# Loading the model
#sam = sam_model_registry[model_type](checkpoint=checkpoint)
#predictor = SamPredictor(sam)
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)

# Define dataset with masks
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, mask_dir, transform=None):
        self.root_dir = root_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(self.root_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.images[idx])
        mask_name = os.path.join(self.mask_dir, self.images[idx])

        image = Image.open(img_name)
        mask = Image.open(mask_name)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask




# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)), # Resize to the size a model expects
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalization values for pre-trained PyTorch models
])

# Load custom dataset
dataset = CustomDataset(root_dir=path_to_images, mask_dir=path_to_masks, transform=transform)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Fine-tuning the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#predictor.sam.to(device)
#predictor.sam.train()
#predictor.to(device)
#predictor.train()
sam_model.to(device)
sam_model.train();

# Define loss function and optimizer
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(sam_model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        #outputs = predictor.sam(inputs)
        outputs = sam_model(inputs, multimask_output=None)

        #outputs = sam_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

IndexError: ignored

In [3]:
from PIL import Image
import os

image_dir = '/content/drive/MyDrive/Data/UAVVaste/images/'
mask_dir = '/content/drive/MyDrive/Data/UAVVaste/masks/pixel_masks_rgb/'

image_files = os.listdir(image_dir)
mask_files = os.listdir(mask_dir)

for image_file, mask_file in zip(image_files, mask_files):
    image_path = os.path.join(image_dir, image_file)
    mask_path = os.path.join(mask_dir, mask_file)

    image = Image.open(image_path)
    mask = Image.open(mask_path)

    image_width, image_height = image.size
    mask_width, mask_height = mask.size

    if image_width != mask_width or image_height != mask_height:
        print(f"Image: {image_file} - Dimensions: {image_width} x {image_height}")
        print(f"Mask: {mask_file} - Dimensions: {mask_width} x {mask_height}")
        print("Image and mask dimensions do not match.")
    else:
        print(f"Image: {image_file} - Dimensions: {image_width} x {image_height}")
        print(f"Mask: {mask_file} - Dimensions: {mask_width} x {mask_height}")
        print("Image and mask dimensions match.")

    image_dtype = image.getbands()[0]
    mask_dtype = mask.getbands()[0]

    if image_dtype != mask_dtype:
        print(f"Image: {image_file} - Data Type: {image_dtype}")
        print(f"Mask: {mask_file} - Data Type: {mask_dtype}")
        print("Image and mask data types do not match.")
    else:
        print(f"Image: {image_file} - Data Type: {image_dtype}")
        print(f"Mask: {mask_file} - Data Type: {mask_dtype}")
        print("Image and mask data types match.")


Image: BATCH_d07_img_6400.jpg - Dimensions: 3840 x 2160
Mask: BATCH_d07_img_2380.jpg - Dimensions: 3840 x 2160
Image and mask dimensions match.
Image: BATCH_d07_img_6400.jpg - Data Type: R
Mask: BATCH_d07_img_2380.jpg - Data Type: R
Image and mask data types match.
Image: BATCH_d07_img_580.jpg - Dimensions: 3840 x 2160
Mask: BATCH_d07_img_580.jpg - Dimensions: 3840 x 2160
Image and mask dimensions match.
Image: BATCH_d07_img_580.jpg - Data Type: R
Mask: BATCH_d07_img_580.jpg - Data Type: R
Image and mask data types match.
Image: BATCH_d07_img_1250.jpg - Dimensions: 3840 x 2160
Mask: BATCH_d07_img_880.jpg - Dimensions: 3840 x 2160
Image and mask dimensions match.
Image: BATCH_d07_img_1250.jpg - Data Type: R
Mask: BATCH_d07_img_880.jpg - Data Type: R
Image and mask data types match.
Image: BATCH_d07_img_2760.jpg - Dimensions: 3840 x 2160
Mask: BATCH_d07_img_4550.jpg - Dimensions: 3840 x 2160
Image and mask dimensions match.
Image: BATCH_d07_img_2760.jpg - Data Type: R
Mask: BATCH_d07_img

UnidentifiedImageError: ignored