In [1]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from PIL import Image

In [2]:
# Choose the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "cuda":
    # Enter a bfloat16 autocast context and intentionally never exit it, so it
    # stays active for the rest of the notebook.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # TF32 is only available on compute capability >= 8 (Ampere and newer):
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

using device: cuda


In [3]:
def load_images_from_folder(folder):
    """Load all images from the specified folder"""
    images = []
    for filename in sorted(os.listdir(folder)):  # Ensure images are read in order
        if filename.endswith((".jpg", ".png", ".jpeg")):  # Only read image files
            image = Image.open(os.path.join(folder, filename))  # Open the image file
            image = np.array(image.convert("RGB"))  # Convert to RGB format and turn into an array
            if image is not None:
                images.append(image)  # Add image to the list
    return images


In [4]:
# Sanity check: how many face photos were found in the input folder?
images = load_images_from_folder(folder="rubiks_cube")
len(images)

6

In [5]:
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Hiera-Large checkpoint and its matching model config.
sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

sam2 = build_sam2(
    model_cfg,
    sam2_checkpoint,
    device=device,
    apply_postprocessing=False,
)

# Settings for the automatic mask generator used on the cube photos.
mask_generator_kwargs = dict(
    points_per_side=32,
    points_per_batch=64,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=1,
    box_nms_thresh=0.6,
    crop_n_points_downscale_factor=3,
    min_mask_region_area=25.0,
    use_m2m=False,
)
mask_generator = SAM2AutomaticMaskGenerator(model=sam2, **mask_generator_kwargs)

In [6]:
def select_masks(masks, min_area=20000, max_area=60000, aspect_ratio_threshold=0.2):
    """Keep masks whose size and shape look like a single cube sticker.

    Args:
        masks: iterable of mask records with an 'area' and a 'bbox' entry;
            'bbox' is read as [x, y, width, height].
        min_area, max_area: exclusive pixel-area bounds for a sticker. The
            defaults were chosen for this notebook's photos and may need
            adjustment for other images.
        aspect_ratio_threshold: allowed deviation of width/height from 1
            (0.2 keeps ratios strictly between 0.8 and 1.2).

    Returns:
        list of the masks that pass both tests, in input order.
    """
    low_ratio = 1 - aspect_ratio_threshold
    high_ratio = 1 + aspect_ratio_threshold

    selected_masks = []
    for mask in masks:
        # Cheap area test first; it also avoids touching the bbox of most masks.
        if not (min_area < mask['area'] < max_area):
            continue
        width, height = mask['bbox'][2], mask['bbox'][3]
        # A zero-height box cannot be a sticker; skip it instead of dividing by zero.
        if height <= 0:
            continue
        if low_ratio < width / height < high_ratio:
            selected_masks.append(mask)
    return selected_masks


In [7]:
def calculate_intersection_area(mask1, mask2):
    """Count the pixels set in both masks' segmentations.

    Args:
        mask1, mask2: mask records whose 'segmentation' entries are arrays of
            the same shape; any non-zero value counts as "set" (presumably
            boolean masks from the SAM2 generator; TODO confirm).

    Returns:
        int: number of positions that are non-zero in both segmentations.
    """
    # A logical AND treats every non-zero value as True. A bitwise AND of
    # uint8 casts would miss overlaps such as 2 & 1 == 0, and it needs two
    # extra array copies.
    overlap = np.logical_and(mask1['segmentation'], mask2['segmentation'])
    return int(np.count_nonzero(overlap))


In [8]:
def filter_masks(selected_masks, image):
    """Drop masks that are almost entirely covered by another candidate mask.

    A mask is discarded when more than 90% of its pixel area intersects a
    larger mask in `selected_masks`. When two masks have exactly the same
    area, the later one is dropped in favour of the earlier one. Without this
    tie-break, two near-identical duplicates would both survive.

    Args:
        selected_masks: list of mask records with 'area' and 'segmentation'.
        image: unused; kept so existing call sites keep working.

    Returns:
        list of the surviving masks, in their original order.
    """
    def is_covered(i, mask1):
        for j, mask2 in enumerate(selected_masks):
            if j == i:
                continue
            # Only a larger mask (or, on an exact area tie, an earlier one) may absorb mask1.
            absorbs = mask2['area'] > mask1['area'] or (
                mask2['area'] == mask1['area'] and j < i
            )
            if absorbs and calculate_intersection_area(mask1, mask2) > 0.9 * mask1['area']:
                return True
        return False

    return [mask for i, mask in enumerate(selected_masks) if not is_covered(i, mask)]


In [9]:
def sort_masks_by_position(masks):
    """Return masks in reading order: rows top to bottom, each row left to right.

    Masks are first ordered by the top edge of their bbox (bbox[1]). A new row
    starts whenever a mask's top edge is more than half the current row's
    reference height (bbox[3] of the row's first mask) away from the previous
    mask's top edge. Each row is then ordered by its left edge (bbox[0]).
    An empty input gives an empty list.
    """
    if not masks:
        return []

    by_top = sorted(masks, key=lambda m: m['bbox'][1])

    rows = [[by_top[0]]]
    reference_height = by_top[0]['bbox'][3]
    previous_top = by_top[0]['bbox'][1]
    for m in by_top[1:]:
        top, height = m['bbox'][1], m['bbox'][3]
        if abs(top - previous_top) > reference_height / 2:
            rows.append([m])
            reference_height = height
        else:
            rows[-1].append(m)
        previous_top = top

    ordered = []
    for row in rows:
        ordered.extend(sorted(row, key=lambda m: m['bbox'][0]))
    return ordered


In [10]:
def get_dominant_color_rgb(image, mask, index):
    """Return the mean (R, G, B) colour of the image pixels covered by a mask.

    Despite the function name, this is the per-channel mean and not a
    histogram peak. Nothing is displayed.

    Args:
        image: H x W x C array; channels 0, 1 and 2 are read as R, G, B
            (load_images_from_folder converts images to RGB).
        mask: mask record whose 'segmentation' is an H x W array; pixels that
            are non-zero after a uint8 cast are included.
        index: unused; kept so existing call sites keep working.

    Returns:
        tuple (r_mean, g_mean, b_mean). If the mask selects no pixels, all
        three values are NaN.
    """
    selected = mask['segmentation'].astype(np.uint8) > 0
    # Boolean indexing gives an (N, C) array of the covered pixels directly,
    # so no separate masked copy of the whole image is needed.
    pixels = image[selected][:, :3]
    if pixels.shape[0] == 0:
        # Return NaNs explicitly; np.mean of an empty array emits a RuntimeWarning.
        nan = float("nan")
        return (nan, nan, nan)
    r_mean, g_mean, b_mean = pixels.mean(axis=0)
    return (r_mean, g_mean, b_mean)

def get_color_letter_from_rgb(rgb):
    """Classify an (r, g, b) triple as a cube colour initial.

    Returns one of "r", "g", "b", "y", "w", "o", or "?" when no rule matches
    (which includes NaN components). The channel ranges overlap, so the rules
    are tried in a fixed order and the first match wins. Do not reorder them.
    """
    # (letter, (r_lo, r_hi), (g_lo, g_hi), (b_lo, b_hi)); bounds are inclusive.
    rules = (
        ("r", (90, 170), (0, 60), (0, 60)),         # red
        ("g", (0, 100), (70, 255), (0, 100)),       # green
        ("b", (0, 100), (0, 100), (90, 255)),       # blue
        ("y", (120, 255), (120, 255), (0, 80)),     # yellow
        ("w", (120, 255), (125, 255), (140, 255)),  # white
        ("o", (100, 255), (50, 200), (20, 100)),    # orange
    )
    r, g, b = rgb
    channels = (r, g, b)
    for letter, *bounds in rules:
        if all(lo <= value <= hi for value, (lo, hi) in zip(channels, bounds)):
            return letter
    return "?"  # unknown colour

def read_cube_face_rgb(image, masks):
    """Read one cube face as a string of colour initials.

    The masks are put in reading order with sort_masks_by_position, and at most
    the first 9 are classified. Each is classified via get_dominant_color_rgb
    and get_color_letter_from_rgb.

    Returns:
        str of up to 9 letters, one per sticker in reading order.
    """
    ordered_masks = sort_masks_by_position(masks)
    letters = [
        get_color_letter_from_rgb(get_dominant_color_rgb(image, sticker, position))
        for position, sticker in enumerate(ordered_masks[:9])
    ]
    return "".join(letters)

In [11]:
def process_image(image):
    """Segment one face photo and return its colour string (up to 9 letters).

    Pipeline:
        1. SAM2 automatic masks (module-level `mask_generator`)
        2. size / aspect-ratio filter (select_masks)
        3. removal of nested masks (filter_masks)
        4. colour reading in reading order (read_cube_face_rgb; it sorts the
           masks by position itself, so no separate sort is needed here)
    """
    masks = mask_generator.generate(image)
    candidate_masks = select_masks(masks)
    sticker_masks = filter_masks(candidate_masks, image)
    return read_cube_face_rgb(image, sticker_masks)


In [12]:
def generate_cube_configuration(images):
    """Run process_image on each face photo and collect the colour strings.

    Prints each face's string as it is produced (faces numbered from 1) and
    returns the strings in the same order as `images`.
    """
    cube_config = []
    for face_number, face_image in enumerate(images, start=1):
        face_config = process_image(face_image)
        print(f"Configuration for Face {face_number}: {face_config}")
        cube_config.append(face_config)
    return cube_config


In [13]:
folder = "rubiks_cube"  # folder with one photo per cube face
images = load_images_from_folder(folder)

# A cube has exactly six faces; refuse to guess when the folder disagrees.
if len(images) != 6:
    print(f"Error: Found {len(images)} images, but 6 are required.")
else:
    cube_configuration = generate_cube_configuration(images)
    print("The configuration for the six faces of the Rubik's cube:")
    for i, config in enumerate(cube_configuration):
        print(f"Face {i+1}: {config}")


Configuration for Face 1: rrryyoggo
Configuration for Face 2: wwwggwyyg
Configuration for Face 3: rogrryybb
Configuration for Face 4: bwrbwowyo
Configuration for Face 5: wbybbwboo
Configuration for Face 6: ggbgoryro
The configuration for the six faces of the Rubik's cube:
Face 1: rrryyoggo
Face 2: wwwggwyyg
Face 3: rogrryybb
Face 4: bwrbwowyo
Face 5: wbybbwboo
Face 6: ggbgoryro


In [14]:
# Display the per-face colour strings, in photo (sorted filename) order.
cube_configuration

['rrryyoggo', 'wwwggwyyg', 'rogrryybb', 'bwrbwowyo', 'wbybbwboo', 'ggbgoryro']

In [15]:
# Centre colours in the face order the solver string needs. With the
# centre -> face-letter map used in the next cells (y->U, g->R, r->F, w->D,
# b->L, o->B), this yields the U, R, F, D, L, B order kociemba expects.
color_order = ['y', 'g', 'r', 'w', 'b', 'o']

# Validate before sorting. A misread face (wrong sticker count, "?" centre,
# or two faces with the same centre) would otherwise fail inside .index()
# with an opaque error, or silently produce an invalid cube string.
bad_faces = [f for f in cube_configuration if len(f) != 9]
if bad_faces:
    raise ValueError(f"Expected 9 stickers per face, got: {bad_faces}")
face_centers = [f[4] for f in cube_configuration]
if sorted(face_centers) != sorted(color_order):
    raise ValueError(
        f"Face centres {face_centers} are not a permutation of {color_order}; "
        "check segmentation / colour classification."
    )

# Pair each face with its centre sticker (index 4 of the 3x3 reading order).
config_with_centers = [(face[4], face) for face in cube_configuration]

# Sort faces by the position of their centre colour in color_order.
sorted_config = sorted(config_with_centers, key=lambda x: color_order.index(x[0]))

# Extract the sorted face configurations.
sorted_faces = [face for _, face in sorted_config]

# Concatenate the sorted faces into one 54-character string.
config_str = ''.join(sorted_faces)

print(f"Concatenated string after sorting: {config_str}")


Concatenated string after sorting: rrryyoggowwwggwyygrogrryybbbwrbwowyowbybbwbooggbgoryro


In [16]:
# Centre sticker colour -> kociemba face letter for this cube's colour scheme.
center_color_map = dict(w='D', r='F', g='R', y='U', o='B', b='L')

In [17]:
# Translate every sticker colour into the face letter kociemba expects.
# Fail loudly on colours the map does not know (e.g. "?") rather than
# dropping them. Dropping would yield a shorter string that is not a valid
# cube description.
unknown_colors = sorted(set(config_str).difference(center_color_map))
if unknown_colors:
    raise ValueError(f"Unrecognised sticker colours {unknown_colors} in {config_str!r}")
new_config_str = "".join(center_color_map[c] for c in config_str)
print(f"Converted Rubik's cube state: {new_config_str}")


Converted Rubik's cube state: FFFUUBRRBDDDRRDUURFBRFFUULLLDFLDBDUBDLULLDLBBRRLRBFUFB


In [18]:
import kociemba

# Solve from the facelet string built above (faces in U, R, F, D, L, B order).
# The displayed result is the move sequence returned by kociemba.solve.
kociemba.solve(new_config_str)

"B' R D F B' R2 D2 B2 U2 D2 L2 U2 F2 R2"