# SAM2 + LaMa: Smart Furniture Restoration (Error-Free Version)

SAM2 segmentation masks often leave out parts of the furniture, such as pillows and cushions. This notebook detects those gaps and restores them with BigLaMa inpainting.

**If you get any import errors:**
1. Go to Runtime → Restart runtime
2. Run all cells again

This usually fixes all dependency issues.

In [None]:
# @title 1️⃣ Install Dependencies { display-mode: "form" }
# @markdown This cell installs BigLaMa and SAM2 with all required dependencies

# Cell 1: Install with compatible versions
!pip install -q --upgrade pip

# CRITICAL: Install compatible NumPy version first (before anything else)
!pip uninstall -y numpy
!pip install -q numpy==1.24.3

# Clone LaMa repository
!git clone https://github.com/advimman/lama.git

# Install PyTorch with specific versions for compatibility
!pip uninstall -y torch torchvision torchaudio
!pip install -q torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

# Install core dependencies with specific versions
!pip install -q opencv-python==4.8.1.78 matplotlib==3.7.2 scipy==1.10.1
!pip install -q pyyaml==6.0.1 tqdm easydict scikit-image==0.21.0 scikit-learn==1.3.0
!pip install -q joblib pandas==2.0.3 packaging omegaconf==2.3.0

# Install albumentations 0.5.2 for LaMa compatibility
!pip uninstall -y albumentations albucore
!pip install -q albumentations==0.5.2

# Install SAM2 without dependencies
!pip install -q --no-deps git+https://github.com/facebookresearch/sam2.git
# Install SAM2 minimal dependencies
!pip install -q hydra-core==1.3.2 iopath==0.1.10 pillow==9.5.0 submitit==1.5.1

# Install LaMa dependencies
%cd lama
!pip install -q pytorch-lightning==1.2.9
!pip install -q kornia==0.6.7
!pip install -q webdataset
!pip install -q wldhx.yadisk-direct

# Download BigLaMa model
print("Downloading BigLaMa model from HuggingFace...")
!curl -LJO https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
!unzip -q big-lama.zip
!rm big-lama.zip

%cd ..

# Install UI components
!pip install -q ipywidgets ipycanvas

# Final check - ensure NumPy didn't get upgraded
!pip install -q numpy==1.24.3

# Verify installations
print("\n✅ Verifying installations:")
!python -c "import numpy; print(f'NumPy: {numpy.__version__}')"
!python -c "import torch; print(f'PyTorch: {torch.__version__}')"
!python -c "import torchvision; print(f'Torchvision: {torchvision.__version__}')"
!python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
!pip show albumentations | grep Version

print("\n✅ Installation complete!")
print("BigLaMa is ready to use")
print("\n⚠️ IMPORTANT: Restart runtime now!")
print("Runtime → Restart runtime")
print("Then run all cells in order")
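
After the restart, the pinned versions from the install cell can be checked programmatically. The following is a minimal sketch, not part of the original notebook: `check_pins`, `installed_version`, and the `PINS` table are hypothetical names, and the pins are copied from the install commands above.

```python
from importlib import metadata

# Pins copied from the install cell; extend as needed.
PINS = {
    "numpy": "1.24.3",
    "albumentations": "0.5.2",
    "omegaconf": "2.3.0",
}

def installed_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None

def check_pins(pins, get_version=installed_version):
    """Return {package: (expected, found)} for every pin that does not match."""
    mismatches = {}
    for name, expected in pins.items():
        found = get_version(name)
        # Ignore local build tags such as "+cu118" when comparing
        if found is None or found.split("+")[0] != expected:
            mismatches[name] = (expected, found)
    return mismatches

for name, (expected, found) in check_pins(PINS).items():
    print(f"Mismatch: {name} expected {expected}, found {found}")
```

An empty result means every listed package matches its pin. A mismatch usually means a later install step upgraded the package.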

In [None]:
# @title 1.5️⃣ Fix Compatibility Issues { display-mode: "form" }
# @markdown This cell replaces LaMa's aug.py with a minimal stand-in so that the albumentations imports succeed

# Cell 1.5: Fix DualIAATransform import issue
print("Fixing albumentations compatibility...")

import os

# Create directory if it doesn't exist
os.makedirs('./lama/saicinpainting/training/data', exist_ok=True)

# Create a patch for the aug.py file
patch_content = '''from albumentations.core.transforms_interface import DualTransform
import imgaug.augmenters as iaa
import numpy as np

# Create DualIAATransform if it doesn't exist
try:
    from albumentations import DualIAATransform, to_tuple
except ImportError:
    from albumentations.core.transforms_interface import to_tuple
    
    class DualIAATransform(DualTransform):
        """Base class for IAA transforms."""
        def __init__(self, always_apply=False, p=0.5):
            super(DualIAATransform, self).__init__(always_apply, p)

class IAAAffine2(DualIAATransform):
    """Place a regular grid of points on the input and randomly move the neighbourhood of these point around
    via affine transformations.

    Note:
        This class introduce interpolation artifacts to mask if it has values other than {0;1}

    Args:
        scale (float, tuple of float): Scaling factor to use, where 1.0 represents no change and 0.5 is
            zoomed out to 50 percent of the original size.
            * If a single float, then that value will be used for all images.
            * If a tuple (a, b), then a value will be uniformly sampled per image from the interval [a, b].
            * If a list, then a random value will be sampled from that list per image.
            * If a StochasticParameter, then from that parameter per image.
        translate_percent (float, tuple of float dict-like):
            Translation as a fraction of the image height/width (x-translation, y-translation),
            where 0 denotes "no change" and 0.5 denotes "half of the axis size".
            * If None, then no translation will be performed.
            * If a single float, then that value will be used for all images.
            * If a tuple (a, b), then a value will be uniformly sampled per image from the interval [a, b].
            * If a list, then a random value will be sampled from that list per image.
            * If a StochasticParameter, then from that parameter per image.
            * If a dictionary, then it may contain the keys x and/or y. Each of these keys may have the
              same datatypes as described above. Using a dictionary allows to set different values for the
              two axis and sampling will then happen independently per axis, resulting in samples that differ
              between the axes.
        translate_px (int, tuple of int dict-like):
            Translation in pixels.
        rotate (float, tuple of float):
            Rotation in degrees (-360 to 360), where 0 denotes "no change" and 45 denotes a rotation of
            45 degrees in clockwise direction.
        shear (float, tuple of float):
            Shear in degrees (-360 to 360), where 0 denotes "no change".
        order (int, list of int, str, list of str, iap.ALL): Interpolation order to use. Same meaning as in skimage:
            * 0: Nearest-neighbor
            * 1: Bi-linear (default)
            * 2: Bi-quadratic (not recommended by skimage)
            * 3: Bi-cubic
            * 4: Bi-quartic
            * 5: Bi-quintic
        cval (float, tuple of float):
            The constant value to use when filling in newly created pixels.
        mode (string, list of string): Method to use when filling in newly created pixels.
            Same meaning as in skimage (and numpy.pad):
            * 'constant': Pads with a constant value
            * 'edge': Pads with the edge values of array
            * 'symmetric': Pads with the reflection of the vector mirrored along the edge of the array.
            * 'reflect': Pads with the reflection of the vector mirrored on the first and last values of
              the vector along each axis.
            * 'wrap': Pads with the wrap of the vector along the axis. The first values are used to pad
              the end and the end values are used to pad the beginning.
    Targets:
        image, mask
    """

    def __init__(
        self,
        scale=1.0,
        translate_percent=None,
        translate_px=None,
        rotate=0.0,
        shear=0.0,
        order=1,
        cval=0,
        mode="reflect",
        always_apply=False,
        p=0.5,
    ):
        super(IAAAffine2, self).__init__(always_apply, p)
        self.scale = scale
        self.translate_percent = translate_percent
        self.translate_px = translate_px
        self.rotate = rotate
        self.shear = shear
        self.order = order
        self.cval = cval
        self.mode = mode

    def apply(self, img, matrix, **params):
        import cv2
        return cv2.warpAffine(img, matrix[:2], (img.shape[1], img.shape[0]), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)

    def apply_to_mask(self, img, matrix, **params):
        import cv2
        return cv2.warpAffine(img, matrix[:2], (img.shape[1], img.shape[0]), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REFLECT_101)

    def get_params(self):
        # Placeholder: the identity matrix leaves the input unchanged
        return {"matrix": np.eye(3)}

    def get_transform_init_args_names(self):
        return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")


class IAAPerspective2(DualIAATransform):
    """Perform a random four point perspective transform of the input.

    Note:
        This class introduce interpolation artifacts to mask if it has values other than {0;1}

    Args:
        scale (float, tuple of float): Standard deviation of the normal distributions. These are used to sample
            the random distances of the subimage's corners from the full image's corners. The sampled values
            reflect percentage values (with respect to image height/width). Recommended values are in the
            range 0.0 to 0.1.
        keep_size (bool): Whether to resize image's back to their original size after applying the
            perspective transform. If set to False, the resulting images may end up having different shapes
            and will always be a list, never an array.
    Targets:
        image, mask
    """

    def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5):
        super(IAAPerspective2, self).__init__(always_apply, p)
        self.scale = scale
        self.keep_size = keep_size

    def apply(self, img, matrix, max_width, max_height, **params):
        import cv2
        return cv2.warpPerspective(img, matrix, (max_width, max_height))

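    def apply_to_mask(self, img, matrix, max_width, max_height, **params):
        import cv2
        # Nearest-neighbour interpolation keeps mask values binary
        return cv2.warpPerspective(img, matrix, (max_width, max_height), flags=cv2.INTER_NEAREST)

    def get_params(self):
        import numpy as np
        # Placeholder parameters: identity matrix with a fixed 100x100 output size
        return {"matrix": np.eye(3), "max_width": 100, "max_height": 100}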

    def get_transform_init_args_names(self):
        return ("scale", "keep_size")
'''

# Write the patched file
with open('./lama/saicinpainting/training/data/aug.py', 'w') as f:
    f.write(patch_content)

print("✅ Created compatibility patch for aug.py")
print("The import error should now be resolved")
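
Before overwriting `aug.py`, it can help to confirm that the replacement text parses and defines the two classes the patch is meant to provide. This is a minimal sketch using only the standard library. `defined_classes` is a hypothetical helper, and the `sample` string stands in for `patch_content`.

```python
import ast

def defined_classes(source):
    """Parse Python source and return the names of its top-level classes.

    Raises SyntaxError if the source is not valid Python.
    """
    tree = ast.parse(source)
    return [node.name for node in tree.body if isinstance(node, ast.ClassDef)]

# Stand-in for patch_content; in the notebook, pass patch_content itself.
sample = '''
class Base:
    pass

class IAAAffine2(Base):
    pass

class IAAPerspective2(Base):
    pass
'''

names = defined_classes(sample)
missing_classes = {"IAAAffine2", "IAAPerspective2"} - set(names)
print("classes:", names, "missing:", sorted(missing_classes))
```

Parsing does not import anything, so this check works even when albumentations or imgaug are not yet importable.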

In [None]:
# @title 2️⃣ Import Libraries { display-mode: "form" }
# @markdown This cell imports all required libraries for SAM2 and BigLaMa

# Cell 2: Import everything
import warnings
warnings.filterwarnings('ignore')

try:
    import torch
    import numpy as np
    import cv2
    from PIL import Image
    import matplotlib.pyplot as plt
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    
    # Import LaMa modules
    import sys
    sys.path.append('./lama')
    from saicinpainting.training.trainers import load_checkpoint
    from saicinpainting.evaluation.utils import move_to_device
    from saicinpainting.evaluation.refinement import refine_predict
    import yaml
    from omegaconf import OmegaConf
    from saicinpainting.evaluation.data import pad_img_to_modulo
    
    from scipy import ndimage
    from google.colab import files, output
    output.enable_custom_widget_manager()
    
    print("✅ All imports successful!")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
except ImportError as e:
    print("❌ Import error:", e)
    print("\nPlease:")
    print("1. Go to Runtime → Restart runtime")
    print("2. Run Cell 1 again")
    print("3. Then run this cell")
    raise  # Stop here so later cells do not run with missing imports

In [None]:
# @title 3️⃣ Load Models { display-mode: "form" }
# @markdown This cell loads SAM2 for segmentation and BigLaMa for inpainting

# Cell 3: Load models with error handling
print("Loading models...")

# Import required modules
import torch
import yaml
from omegaconf import OmegaConf
import sys
sys.path.append('./lama')
from saicinpainting.training.trainers import load_checkpoint

# Try to load SAM2 with error handling
try:
    # Disable JIT for compatibility
    torch.jit._state.disable()
    
    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
    print("✅ SAM2 loaded")
except Exception as e:
    print(f"⚠️ SAM2 loading issue: {e}")
    print("Trying alternative loading method...")
    
    # Alternative: fall back to the smaller base-plus checkpoint
    try:
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        # Use the base model instead
        checkpoint = "facebook/sam2-hiera-base-plus"
        predictor = SAM2ImagePredictor.from_pretrained(checkpoint)
        print("✅ SAM2 loaded (base model)")
    except Exception:
        print("❌ Could not load SAM2. You may need to restart runtime.")
        raise

# Load BigLaMa model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Setup config for BigLaMa
predict_config = OmegaConf.create({
    'model': {
        'path': './lama/big-lama/big-lama_places512_G.pth'
    },
    'refine': True,
    'refiner': {
        'gpu_ids': '0',
        'modulo': 8,
        'n_iters': 15,
        'lr': 0.002,
        'min_side': 512,
        'max_scales': 3,
        'px_budget': 1800000
    },
    'out_key': 'inpainted'
})

# Load model checkpoint
train_config_path = './lama/big-lama/config.yaml'
with open(train_config_path, 'r') as f:
    train_config = OmegaConf.create(yaml.safe_load(f))

train_config.training_model.predict_only = True
train_config.visualizer.kind = 'noop'

# Initialize BigLaMa
checkpoint_path = predict_config.model.path
model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
model.to(device)
model.eval()

print("✅ BigLaMa loaded")
print("   Model: big-lama_places512_G.pth")
print("   Refinement: Enabled (15 iterations)")

print("\nReady to process images!")
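
If `load_checkpoint` fails with a missing-file error, the unpacked archive may not match the model path configured above. The sketch below lists candidate weight files under the model directory. `find_checkpoints` is a hypothetical helper, and the file extensions are an assumption about how the weights are stored.

```python
from pathlib import Path

def find_checkpoints(model_dir, extensions=(".pth", ".ckpt")):
    """Return weight files under model_dir (searched recursively), sorted by path."""
    root = Path(model_dir)
    if not root.is_dir():
        return []
    return sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix in extensions
    )

candidates = find_checkpoints("./lama/big-lama")
if candidates:
    for path in candidates:
        print("Found:", path)
else:
    print("No checkpoint files found under ./lama/big-lama")
```

If the listed file differs from the configured `model.path`, update the config to point at the file that actually exists.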

In [None]:
# @title 4️⃣ Upload and Display Image { display-mode: "form" }
# @markdown Upload your furniture image and prepare it for processing

# Cell 4: Upload and process image
print("Upload your furniture image:")
uploaded = files.upload()
filename = list(uploaded.keys())[0]

# Load image
image = Image.open(filename).convert("RGB")
image_np = np.array(image)

plt.figure(figsize=(10, 8))
plt.imshow(image)
plt.title(f"Original Image ({image.size[0]}x{image.size[1]})")
plt.axis('off')
plt.show()

# Set image for SAM2
predictor.set_image(image_np)

In [None]:
# @title 5️⃣ Select Furniture with SAM2 { display-mode: "form" }
# @markdown Enter the coordinates of a point on the furniture to create a segmentation mask

# Cell 5: Simple point selection
print("Enter pixel coordinates of a point on the furniture:")
print("(x runs left to right, y runs top to bottom; press Enter to use the image center)")

# Get center point as default
center_x = image.size[0] // 2
center_y = image.size[1] // 2

x = int(input(f"X coordinate (default {center_x}): ") or center_x)
y = int(input(f"Y coordinate (default {center_y}): ") or center_y)

# Generate mask
coords = np.array([[x, y]])
labels = np.array([1])

masks, scores, _ = predictor.predict(
    point_coords=coords,
    point_labels=labels,
    multimask_output=False
)

# Get mask
mask = (masks[0] > 0.5).astype(np.uint8) * 255

# Show results
fig, axes = plt.subplots(1, 2, figsize=(15, 8))
axes[0].imshow(image)
axes[0].plot(x, y, 'go', markersize=15)
axes[0].set_title("Click Point")
axes[0].axis('off')

axes[1].imshow(mask, cmap='gray')
axes[1].set_title("SAM2 Mask")
axes[1].axis('off')
plt.show()
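
The coordinate prompts in this cell fail on non-numeric input and accept points outside the image. A minimal sketch of more forgiving parsing follows. `parse_coord` is a hypothetical helper, not part of SAM2.

```python
def parse_coord(text, default, limit):
    """Parse a pixel coordinate from user input.

    Empty or non-numeric input falls back to `default`; the result is
    clamped to the valid index range [0, limit - 1].
    """
    stripped = text.strip()
    try:
        value = int(stripped) if stripped else default
    except ValueError:
        value = default
    return max(0, min(value, limit - 1))

# Example for an 800x600 image
width, height = 800, 600
print(parse_coord("", width // 2, width))        # 400
print(parse_coord("9999", width // 2, width))    # 799
print(parse_coord("abc", height // 2, height))   # 300
```

In the notebook this would wrap each `input(...)` call, with `image.size[0]` and `image.size[1]` as the limits.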

In [None]:
# @title 6️⃣ Smart Furniture Restoration Preparation { display-mode: "form" }
# @markdown Detect missing furniture parts and prepare for inpainting

# Cell 6: Smart furniture restoration
print("🔄 Running smart furniture restoration...")

# 1. Detect complete furniture boundary
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (50, 50))
closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
filled_mask = ndimage.binary_fill_holes(closed_mask).astype(np.uint8) * 255

# 2. Find areas to inpaint (missing parts)
missing_parts = cv2.bitwise_and(filled_mask, cv2.bitwise_not(mask))

# 3. Build a background-removed preview, then detect white artifacts
bg_array = np.array(image)
bg_array[mask == 0] = 255  # White background
bg_removed = Image.fromarray(bg_array)

# Near-white pixels inside the filled furniture outline
gray = cv2.cvtColor(bg_array, cv2.COLOR_RGB2GRAY)
white_artifacts = ((gray > 240) & (filled_mask > 0)).astype(np.uint8) * 255

# 4. Combine all areas to inpaint
inpaint_mask = cv2.bitwise_or(missing_parts, white_artifacts)
kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
inpaint_mask = cv2.dilate(inpaint_mask, kernel_small, iterations=1)

# Show what we'll inpaint
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
axes[0].imshow(bg_removed)
axes[0].set_title("Background Removed (with holes)")
axes[0].axis('off')

axes[1].imshow(inpaint_mask, cmap='gray')
axes[1].set_title("Areas to Restore")
axes[1].axis('off')

# Overlay
overlay = bg_array.copy()
overlay[inpaint_mask > 0] = [255, 0, 0]
axes[2].imshow(overlay)
axes[2].set_title("Areas to Restore (Red)")
axes[2].axis('off')
plt.show()

print("✅ Smart mask created!")
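
Step 1 relies on hole filling: a hole is a background region that cannot be reached from the image border. The toy sketch below reimplements that idea in pure Python on a small grid to show how `missing_parts` arises. It is an illustration only, since the notebook itself uses `ndimage.binary_fill_holes`, and `fill_holes` is a hypothetical name.

```python
from collections import deque

def fill_holes(grid):
    """Return a copy of a 0/1 grid with enclosed background regions set to 1.

    Background cells reachable from the border (4-connectivity) stay 0.
    """
    h, w = len(grid), len(grid[0])
    outside = [[False] * w for _ in range(h)]
    queue = deque()
    # Seed the flood fill with every background cell on the border
    for y in range(h):
        for x in range(w):
            on_border = y in (0, h - 1) or x in (0, w - 1)
            if on_border and grid[y][x] == 0:
                outside[y][x] = True
                queue.append((y, x))
    # Spread through connected background cells
    while queue:
        y, x = queue.popleft()
        for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
            if 0 <= ny < h and 0 <= nx < w and grid[ny][nx] == 0 and not outside[ny][nx]:
                outside[ny][nx] = True
                queue.append((ny, nx))
    return [[0 if outside[y][x] else 1 for x in range(w)] for y in range(h)]

# A 5x5 "sofa" mask with one missing "cushion" pixel in the middle
toy_mask = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
toy_filled = fill_holes(toy_mask)
# Missing parts = inside the filled outline but absent from the original mask
toy_missing = [
    [f & (1 - m) for f, m in zip(frow, mrow)]
    for frow, mrow in zip(toy_filled, toy_mask)
]
print(toy_missing[2])  # [0, 0, 1, 0, 0]
```

The morphological closing before the fill serves a related purpose: it bridges gaps in the outline so that regions open to the background become enclosed and can be filled.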

In [None]:
# @title 7️⃣ Run BigLaMa Inpainting { display-mode: "form" }
# @markdown Restore furniture using BigLaMa with refinement

# Cell 7: Run BigLaMa inpainting
print("🎨 Running BigLaMa inpainting with refinement...")

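# Prepare CHW float arrays in [0, 1]; the mask is binarized to {0, 1}
img_chw = np.transpose(np.array(image), (2, 0, 1)).astype('float32') / 255.0
mask_chw = (inpaint_mask[None, ...] > 0).astype('float32')

# Pad to a multiple of 8 for the model.
# pad_img_to_modulo takes one CHW NumPy array and the modulus.
img_chw = pad_img_to_modulo(img_chw, 8)
mask_chw = pad_img_to_modulo(mask_chw, 8)

img_tensor = torch.from_numpy(img_chw).unsqueeze(0)
mask_tensor = torch.from_numpy(mask_chw).unsqueeze(0)

# Build the batch. refine_predict reads unpad_to_size as per-sample
# tensors (height, width), as a collated LaMa dataset batch would provide.
batch = {
    'image': img_tensor,
    'mask': mask_tensor,
    'unpad_to_size': [torch.tensor([image.size[1]]), torch.tensor([image.size[0]])],
}

if predict_config.refine:
    # Refinement optimizes intermediate features by gradient descent, so it
    # must run outside torch.no_grad(). It places data on the GPUs listed in
    # gpu_ids itself and returns the inpainted image tensor (1, 3, H, W).
    print("Running inference with refinement...")
    refined = refine_predict(batch, model, **predict_config.refiner)
    cur_res = refined[0].permute(1, 2, 0).detach().cpu().numpy()
else:
    # Single forward pass without refinement
    print("Running inference...")
    with torch.no_grad():
        batch = move_to_device(batch, device)
        batch = model(batch)
        cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()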

# Convert back to image
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cur_res[:image.size[1], :image.size[0]]  # Crop to original size
result = Image.fromarray(cur_res)

# Apply furniture mask to remove background
final = Image.new('RGB', result.size, (255, 255, 255))
result_array = np.array(result)
final_array = np.array(final)
final_array[filled_mask > 0] = result_array[filled_mask > 0]
final = Image.fromarray(final_array)

print("✅ BigLaMa inpainting complete!")
print("   - Used 15 refinement iterations")
print("   - Maximum quality output")

# Show results
fig, axes = plt.subplots(1, 3, figsize=(20, 8))

axes[0].imshow(image)
axes[0].set_title("Original", fontsize=16)
axes[0].axis('off')

axes[1].imshow(bg_removed)
axes[1].set_title("SAM2 Result (with holes)", fontsize=16)
axes[1].axis('off')

axes[2].imshow(final)
axes[2].set_title("BigLaMa Restored (Maximum Quality)", fontsize=16)
axes[2].axis('off')

plt.tight_layout()
plt.show()
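
The padding step in this cell rounds each image side up to the next multiple of 8. Below is a minimal sketch of that arithmetic. `ceil_modulo` and `padded_size` are local illustrations, assumed to follow the same round-up rule as LaMa's padding helper.

```python
def ceil_modulo(x, mod):
    """Round x up to the nearest multiple of mod."""
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod

def padded_size(width, height, mod=8):
    """Return (padded_width, padded_height, pad_right, pad_bottom)."""
    pw, ph = ceil_modulo(width, mod), ceil_modulo(height, mod)
    return pw, ph, pw - width, ph - height

print(padded_size(1001, 750))  # (1008, 752, 7, 2)
print(padded_size(512, 512))   # (512, 512, 0, 0)
```

The crop back to the original size after inference removes exactly these `pad_right` and `pad_bottom` pixels.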

In [None]:
# @title 8️⃣ Save and Download Results { display-mode: "form" }
# @markdown Save the restored image and comparison

# Cell 8: Save results
base_name = filename.rsplit('.', 1)[0]

# Save restored image
restored_name = f"{base_name}_restored.png"
final.save(restored_name)
print(f"✅ Saved: {restored_name}")

# Save comparison
comparison = Image.new('RGB', (image.width * 3, image.height))
comparison.paste(image, (0, 0))
comparison.paste(bg_removed, (image.width, 0))
comparison.paste(final, (image.width * 2, 0))

comparison_name = f"{base_name}_comparison.jpg"
comparison.save(comparison_name)
print(f"✅ Saved: {comparison_name}")

# Download
files.download(restored_name)
files.download(comparison_name)

print("\n✅ Smart furniture restoration complete!")
print("   - Restored missing pillows/cushions")
print("   - Removed white artifacts")
print("   - Preserved furniture structure")