In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python torch torchvision
!pip install torch torchvision torchaudio


In [None]:
import requests

import os

# URL for the model file
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

# Local path to save the model
model_path = "sam_vit_b_01ec64.pth"

# Download the checkpoint only if it is not already present, so re-running the notebook
# does not fetch it again. The body is streamed to disk in chunks instead of being held in
# memory, and it is written to a temporary file that is renamed only after a complete,
# successful download, so a failure never leaves a truncated checkpoint at model_path.
if os.path.exists(model_path):
    print("Model already present at", model_path)
else:
    tmp_path = model_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        # Without this, an HTTP error page would be saved as if it were the model file.
        response.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, model_path)
    print("Model downloaded and saved to", model_path)

In [None]:
!apt-get update
!apt-get install -y libgl1-mesa-glx


In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import SamPredictor, sam_model_registry

# Load the image. cv2.imread returns BGR, so convert to RGB for the predictor.
img_path = '/content/0002434fcecc427f805e7e8e4e63ad76.jpg'
image = cv2.imread(img_path)
if image is None:
    # cv2.imread signals a missing or undecodable file by returning None rather than raising;
    # without this check the failure shows up as an opaque cv2.error inside cvtColor.
    raise OSError(f"Could not read image: {img_path}")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Load the SAM model (using the smaller ViT-B variant)
model_type = "vit_b"
sam_checkpoint = "sam_vit_b_01ec64.pth"  # same filename the download cell saves to
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the SAM model and move it to the selected device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Create the predictor and set the image it will segment
predictor = SamPredictor(sam)
predictor.set_image(image_rgb)

# Prompt points as (x, y) pixel coordinates, spread over different regions of the garment
# (including near its edges). They were picked by hand for this specific image.
input_points = np.array([[230, 140], [230, 300], [170, 250], [290, 250], [230, 380]])
# Label 1 marks every point as foreground (part of the t-shirt).
input_labels = np.ones(len(input_points), dtype=int)

# Predict a single mask for the t-shirt
with torch.no_grad():
    masks, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,
    )

# Apply the mask: keep only the pixels inside the predicted t-shirt region
mask = masks[0]  # multimask_output=False -> a single mask
segmented_image = cv2.bitwise_and(image, image, mask=mask.astype('uint8'))

# Save the result. cv2.imwrite reports failure through its return value, not an exception.
output_file = "segmented_tshirt.png"
if not cv2.imwrite(output_file, segmented_image):
    raise OSError(f"Failed to write {output_file}")

# Display the segmented image (convert from OpenCV's BGR back to RGB for matplotlib)
plt.imshow(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
plt.title("Segmented T-shirt")
plt.axis('off')
plt.show()


In [None]:
import torch
import numpy as np
import cv2
import os
from segment_anything import SamPredictor, sam_model_registry
import matplotlib.pyplot as plt

# Directory containing the images to be segmented, and where the results are written
input_directory = '/content/drive/MyDrive/AAA'
output_directory = '/content/drive/MyDrive/AAB'

# Fail fast if the input folder is missing (e.g. Google Drive is not mounted). This check
# must run before creating the output folder: otherwise makedirs would create
# /content/drive/... on the local disk, which can then prevent mounting Drive there.
if not os.path.isdir(input_directory):
    raise FileNotFoundError(f"Input directory not found: {input_directory}")

# Ensure the output directory exists (exist_ok avoids a check-then-create race)
os.makedirs(output_directory, exist_ok=True)

# Load the SAM model
model_type = "vit_b"
sam_checkpoint = '/content/sam_vit_b_01ec64.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the SAM model and move it to the selected device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# One predictor, created once at module level and reused by segment_image() for every image
predictor = SamPredictor(sam)

# Function to calculate standardized points based on image size
def get_standardized_points(image_shape, reference_shape=None):
    """Return the foreground prompt points to use for an image of the given shape.

    The base points are fixed pixel coordinates that were picked by hand on a sample image.
    By default they are returned unchanged, which matches the original behaviour of ignoring
    ``image_shape``. If ``reference_shape`` (the shape of the image the points were picked
    on) is given, the points are rescaled in proportion to ``image_shape``. They then land on
    the same relative positions in images of other sizes.

    Args:
        image_shape: Shape of the target image, ``(height, width)`` or
            ``(height, width, channels)`` as given by ``ndarray.shape``.
        reference_shape: Optional ``(height, width[, channels])`` of the image the base
            points were picked on. ``None`` (default) disables rescaling.

    Returns:
        Integer ``np.ndarray`` of shape ``(5, 2)`` holding ``[x, y]`` pixel coordinates.
    """
    points = np.array([[227, 98], [230, 507], [10, 500], [269, 964], [553, 431]])
    if reference_shape is None:
        return points

    height, width = image_shape[:2]
    ref_height, ref_width = reference_shape[:2]
    # Column 0 is x (scales with width); column 1 is y (scales with height).
    scale = np.array([width / ref_width, height / ref_height])
    return np.rint(points * scale).astype(int)

# Function to segment an image
def segment_image(image_path, output_path):
    """Segment one image with SAM using the standard prompt points and save the masked result.

    Relies on the module-level ``predictor`` and ``get_standardized_points``. Pixels outside
    the predicted mask are zeroed in the saved image.

    Args:
        image_path: Path of the image to segment.
        output_path: Path to write the segmented image to; its extension selects the format.

    Raises:
        OSError: If the image cannot be read, or the result cannot be written.
    """
    # Load the image; cv2.imread returns None (instead of raising) for missing/undecodable files.
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Set the image in the predictor
    predictor.set_image(image_rgb)

    # Prompt points for this image; label 1 marks every point as foreground
    input_points = get_standardized_points(image.shape)
    input_labels = np.ones(len(input_points), dtype=int)

    # Predict a single mask
    with torch.no_grad():
        masks, _, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=False,
        )

    # Apply the mask to the (BGR) image
    mask = masks[0]  # multimask_output=False -> a single mask
    segmented_image = cv2.bitwise_and(image, image, mask=mask.astype('uint8'))

    # cv2.imwrite reports failure through its return value rather than raising
    if not cv2.imwrite(output_path, segmented_image):
        raise OSError(f"Failed to write segmented image: {output_path}")

# Images to process: regular files whose extension matches case-insensitively (so camera-style
# names like "IMG_001.JPG" are not silently skipped). Sorted because os.listdir returns entries
# in arbitrary order, and a fixed order makes runs reproducible.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
image_files = sorted(
    f for f in os.listdir(input_directory)
    if f.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(input_directory, f))
)

# Loop through each image and apply segmentation
for image_file in image_files:
    input_path = os.path.join(input_directory, image_file)
    output_path = os.path.join(output_directory, f'segmented_{image_file}')
    segment_image(input_path, output_path)

print("Segmentation completed for all images.")

