# Segment Anything 2 (SAM2) Cleaner

## 仮想環境の作成とセットアップ

In [None]:
import os

# Paths for the virtual environment and the model cache directory.
venv_path = "/tmp/sac_env"
model_cache_dir = "/tmp/models"

# Create the model cache directory and point the Hugging Face Hub cache at it.
# NOTE(review): sac.py downloads its SAM2 checkpoints with urllib into its own
# MODEL_DIR (/tmp/sac/models), so this cache setting does not appear to be used
# for those checkpoints — confirm whether it is still needed.
os.makedirs(model_cache_dir, exist_ok=True)
os.environ['HUGGINGFACE_HUB_CACHE'] = model_cache_dir

# Create the virtual environment (only if it does not exist yet) and install
# the libraries into it.
if not os.path.exists(venv_path):
    print("Creating virtual environment...")
    !python3 -m venv {venv_path}

print("Installing dependencies...")
!{venv_path}/bin/pip install --upgrade pip

# Install PyTorch from the default package index. No --index-url is given, so
# the CUDA variant is whatever pip resolves by default; add the matching PyTorch
# wheel index here if the host (e.g. Paperspace) needs a specific CUDA build.
!{venv_path}/bin/pip install torch torchvision torchaudio
# gradio is pinned to 3.41.2: sac.py uses gr.Image(tool="sketch"), a Gradio 3.x API.
!{venv_path}/bin/pip install gradio==3.41.2 opencv-python numpy Pillow

# SAM2 itself, installed from its GitHub repository.
!{venv_path}/bin/pip install git+https://github.com/facebookresearch/segment-anything-2.git


## アプリケーションの作成

In [None]:
%%writefile sac.py
import os
import shutil
import urllib.request
import torch
import torchvision
import torchaudio
import gradio as gr
from PIL import Image, ImageChops
import numpy as np
import cv2

# --- Application info ---
# Ver 1.1.0: includes bug fixes and minor UI adjustments.
VERSION = "1.1.0"

# --- Environment settings ---
# Use matplotlib's non-interactive Agg backend (the server has no display).
# NOTE(review): this runs after the torch/gradio imports above; if one of them
# already imported matplotlib it may take effect too late, which is presumably
# why the launcher cell also exports MPLBACKEND before starting this script.
os.environ['MPLBACKEND'] = 'Agg'
WORKDIR = "/tmp/sac"
MODEL_DIR = os.path.join(WORKDIR, "models")
for _dir in (WORKDIR, MODEL_DIR):
    os.makedirs(_dir, exist_ok=True)

# SAM2 stays optional at import time so the UI can still start. If it is
# missing, install stand-ins that fail with an actionable message when a model
# is built, instead of a confusing NameError deep inside the analysis step.
try:
    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
except ImportError as _exc:
    # The `as` name is unbound when the except block ends, so keep a reference.
    _SAM2_IMPORT_ERROR = _exc
    print(f"SAM2 library not found. ({_exc})")

    def _sam2_unavailable(*_args, **_kwargs):
        """Raise an informative ImportError in place of any SAM2 API call."""
        raise ImportError(
            "SAM2 library not found. Install it with: pip install "
            "git+https://github.com/facebookresearch/segment-anything-2.git"
        ) from _SAM2_IMPORT_ERROR

    build_sam2 = _sam2_unavailable
    SAM2AutomaticMaskGenerator = _sam2_unavailable

# --- Custom CSS ---
css = """
.padded-image { height: 600px !important; }
.padded-image .image-container { padding-top: 50px !important; height: calc(100% - 50px) !important; }
.padded-image img { max-height: 100% !important; object-fit: contain !important; }
.version-text { text-align: center; color: gray; font-size: 1.6em; margin-top: 20px; font-weight: bold; }
"""

class SAM2Handler:
    """Holds the SAM2 model and the working files for the cut-out workflow.

    Each step writes its result under WORKDIR and the next step reads it back:

    * ``original.png``        -- the uploaded image (handle_upload)
    * ``segments.png``        -- colored overlay of all segments (run_segment)
    * ``segments_labels.npy`` -- per-pixel segment id, 0 = no segment (run_segment)
    * ``mask_generated.png``  -- mask built from the selected segments (make_mask)
    * ``mask_edited.png``     -- mask after eraser strokes / invert / blur (run_crop)
    * ``output/<name>.png``   -- final RGBA cut-out (run_crop)
    """

    # Checkpoint file name -> (SAM2 config name, download URL).
    CHECKPOINTS = {
        "sam2_hiera_tiny.pt": ("sam2_hiera_t.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"),
        "sam2_hiera_small.pt": ("sam2_hiera_s.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"),
        "sam2_hiera_base_plus.pt": ("sam2_hiera_b+.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"),
        "sam2_hiera_large.pt": ("sam2_hiera_l.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"),
    }

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator = None
        self.current_model = None   # checkpoint name of the loaded network
        self.current_params = None  # generator parameters of self.generator
        self._model = None          # SAM2 network for self.current_model
        self.active_image_path = os.path.join(WORKDIR, "original.png")
        self.segments_path = os.path.join(WORKDIR, "segments.png")
        self.labels_path = os.path.join(WORKDIR, "segments_labels.npy")
        self.generated_mask_path = os.path.join(WORKDIR, "mask_generated.png")
        self.edited_mask_path = os.path.join(WORKDIR, "mask_edited.png")
        # Final results get their own directory so an upload named e.g.
        # "original.jpg" cannot overwrite the working files above.
        self.output_dir = os.path.join(WORKDIR, "output")
        os.makedirs(self.output_dir, exist_ok=True)
        self.current_filename = "result.png"

    @staticmethod
    def _remove_files(*paths):
        """Delete each existing file in *paths*; missing files are ignored."""
        for path in paths:
            if os.path.exists(path):
                os.remove(path)

    def load_model(self, model_name, p_side, iou_th, stab_th, min_area):
        """Make self.generator match the requested checkpoint and parameters.

        The checkpoint is downloaded into MODEL_DIR on first use. The network is
        rebuilt only when *model_name* changes; changing only the generator
        parameters re-wraps the already-loaded network.

        Raises:
            ValueError: if *model_name* is not a known checkpoint.
        """
        params = (p_side, iou_th, stab_th, min_area)
        if self.current_model == model_name and self.current_params == params:
            return
        if model_name not in self.CHECKPOINTS:
            raise ValueError(f"Unknown SAM2 checkpoint: {model_name!r}")
        cfg, url = self.CHECKPOINTS[model_name]

        if self.current_model != model_name:
            target_path = os.path.join(MODEL_DIR, model_name)
            if not os.path.exists(target_path):
                # Download under a temporary name and rename afterwards so an
                # interrupted download is never mistaken for a valid checkpoint.
                tmp_path = target_path + ".part"
                try:
                    urllib.request.urlretrieve(url, tmp_path)
                    os.replace(tmp_path, target_path)
                finally:
                    self._remove_files(tmp_path)
            # Drop the previous network before building the next one so its
            # memory can be released first.
            self.generator = None
            self._model = None
            self.current_model = None
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
            self._model = build_sam2(cfg, target_path, device=self.device)
            self.current_model = model_name

        self.generator = SAM2AutomaticMaskGenerator(model=self._model, points_per_side=int(p_side), pred_iou_thresh=float(iou_th), stability_score_thresh=float(stab_th), crop_n_layers=1, min_mask_region_area=int(min_area))
        self.current_params = params

    def handle_upload(self, file_obj):
        """Copy an uploaded file to the working image path and return a preview.

        Accepts an object with a ``.name`` path (Gradio's file value) or a plain
        path. Results derived from the previous image are deleted so later steps
        cannot silently reuse them.
        """
        if not file_obj:
            return None
        src_path = getattr(file_obj, "name", file_obj)
        # The result is always saved as PNG (to keep the alpha channel), so the
        # upload's extension is replaced with .png.
        base_name = os.path.splitext(os.path.basename(src_path))[0]
        self.current_filename = f"{base_name}.png"
        self._remove_files(self.segments_path, self.labels_path,
                           self.generated_mask_path, self.edited_mask_path)
        shutil.copy(src_path, self.active_image_path)
        # Return an in-memory copy so the file handle is closed right away.
        with Image.open(self.active_image_path) as img:
            return img.copy()

    def run_segment(self, model_name, p_side, iou_th, stab_th, min_area):
        """Segment the working image with SAM2 and return a colored preview.

        Saves the colored overlay to segments.png and the per-pixel segment ids
        to segments_labels.npy (read by make_mask). Returns None when no image
        is loaded, no mask is detected, or mask generation raises.
        """
        if not os.path.exists(self.active_image_path):
            return None
        # Remove older segments first so a failed run cannot leave stale data.
        self._remove_files(self.segments_path, self.labels_path)
        self.load_model(model_name, p_side, iou_th, stab_th, min_area)
        orig_pil = Image.open(self.active_image_path).convert("RGB")
        img_np = np.array(orig_pil)

        try:
            masks = self.generator.generate(img_np)
        except Exception as e:
            # Best effort: report in the server log and leave the view empty.
            print(f"SAM2 Segmentation Error: {e}")
            return None
        if not masks:
            print("Warning: No masks detected.")
            return None

        # Paint large segments first so smaller ones painted later stay
        # visible (and therefore selectable) where masks overlap.
        masks = sorted(masks, key=lambda m: int(np.count_nonzero(m['segmentation'])), reverse=True)
        labels = np.zeros(img_np.shape[:2], dtype=np.int32)
        overlay = np.zeros_like(img_np)
        for i, m in enumerate(masks):
            seg = m['segmentation']
            labels[seg] = i + 1  # 0 is reserved for "no segment"
            # Colors are for display only; they repeat after 230 segments,
            # which is why selection uses the label map instead.
            overlay[seg] = [(i*47)%230+20, (i*97)%230+20, (i*149)%230+20]
        np.save(self.labels_path, labels)
        seg_pil = Image.fromarray(overlay)
        seg_pil.save(self.segments_path)
        return Image.blend(seg_pil, orig_pil, 0.4)

    def make_mask(self, sketch_data, invert, smooth, expand):
        """Build a binary mask from the segments the user painted over.

        Args:
            sketch_data: value of the sketch-enabled segment view; either a dict
                whose "mask" entry is the stroke image, or the stroke image itself.
            invert: invert the resulting mask.
            smooth: Gaussian blur with a (2*smooth+1) kernel, then re-threshold at 128.
            expand: >0 dilates, <0 erodes, with an |expand|-pixel square kernel.

        Saves the mask to mask_generated.png and returns it blended 50/50 with
        the original image. Returns None if nothing was drawn or no segmentation
        or image is available.
        """
        if sketch_data is None:
            return None
        draw_img = sketch_data.get("mask") if isinstance(sketch_data, dict) else sketch_data
        if draw_img is None or not os.path.exists(self.labels_path) or not os.path.exists(self.active_image_path):
            return None
        labels = np.load(self.labels_path)
        strokes = np.array(draw_img.convert("L")) > 0
        if not strokes.any():
            return None
        if strokes.shape != labels.shape:
            # Map the strokes onto the segmentation grid if the editor returned another size.
            strokes = cv2.resize(strokes.astype(np.uint8), (labels.shape[1], labels.shape[0]), interpolation=cv2.INTER_NEAREST) > 0
        selected = np.unique(labels[strokes])
        selected = selected[selected != 0]  # strokes over unsegmented pixels select nothing
        mask = np.where(np.isin(labels, selected), 255, 0).astype(np.uint8)

        if smooth > 0:
            k_blur = int(smooth) * 2 + 1
            mask = cv2.GaussianBlur(mask, (k_blur, k_blur), 0)
            _, mask = cv2.threshold(mask, 128, 255, cv2.THRESH_BINARY)
        e_val = int(abs(expand))
        if e_val > 0:
            kernel = np.ones((e_val, e_val), np.uint8)
            mask = cv2.dilate(mask, kernel) if expand > 0 else cv2.erode(mask, kernel)
        if invert:
            mask = cv2.bitwise_not(mask)

        mask_pil = Image.fromarray(mask).convert("RGB")
        mask_pil.save(self.generated_mask_path)
        orig_img = Image.open(self.active_image_path).convert("RGB")
        return Image.blend(mask_pil, orig_img, 0.5)

    def run_crop(self, edited_mask_data, final_invert, crop_blur):
        """Apply the (optionally eraser-edited) mask as alpha and save the cut-out.

        Strokes drawn on the mask view are subtracted from mask_generated.png.
        The mask can then be inverted and blurred with a (2*crop_blur+1) Gaussian
        kernel. Returns (RGBA PIL image, output path), or (None, None) if inputs
        are missing.
        """
        if not edited_mask_data or not os.path.exists(self.active_image_path):
            return None, None
        if not os.path.exists(self.generated_mask_path):
            return None, None
        gen_mask_np = np.array(Image.open(self.generated_mask_path).convert("L"))
        stroke_img = edited_mask_data.get("mask") if isinstance(edited_mask_data, dict) else None
        if stroke_img is not None:
            stroke_np = np.array(stroke_img.convert("L"))
            if stroke_np.shape != gen_mask_np.shape:
                stroke_np = cv2.resize(stroke_np, (gen_mask_np.shape[1], gen_mask_np.shape[0]), interpolation=cv2.INTER_NEAREST)
            # Saturating subtraction: erased pixels drop to 0, others are unchanged.
            final_mask_np = cv2.subtract(gen_mask_np, stroke_np)
        else:
            final_mask_np = gen_mask_np
        if final_invert:
            final_mask_np = cv2.bitwise_not(final_mask_np)

        if crop_blur > 0:
            b_val = int(crop_blur) * 2 + 1
            final_mask_np = cv2.GaussianBlur(final_mask_np, (b_val, b_val), 0)

        Image.fromarray(final_mask_np).save(self.edited_mask_path)
        orig = np.array(Image.open(self.active_image_path).convert("RGB"))
        if orig.shape[:2] != final_mask_np.shape:
            final_mask_np = cv2.resize(final_mask_np, (orig.shape[1], orig.shape[0]))

        # Use the mask as the alpha channel of an RGBA copy of the original.
        res_rgba = cv2.cvtColor(orig, cv2.COLOR_RGB2RGBA)
        res_rgba[:, :, 3] = final_mask_np

        os.makedirs(self.output_dir, exist_ok=True)
        out_path = os.path.join(self.output_dir, self.current_filename)
        result = Image.fromarray(res_rgba)
        # Force PNG so the alpha channel is preserved.
        result.save(out_path, format="PNG")
        return result, out_path

handler = SAM2Handler()
# Library versions shown in the page footer.
version_info = f"torch: {torch.__version__} | torchvision: {torchvision.__version__} | torchaudio: {torchaudio.__version__} | gradio: {gr.__version__}"

# Preset name -> slider values (points per side, IoU threshold,
# stability threshold, minimum region area).
PRESETS = {
    "標準": (32, 0.8, 0.92, 300),
    "詳細": (64, 0.5, 0.5, 100),
    "軽量": (16, 0.8, 0.95, 400),
}

with gr.Blocks(css=css) as demo:
    gr.Markdown(f"# SAM2 画像切り抜きツール（Ver{VERSION}）")
    with gr.Row():
        # Left column: model/parameter settings and image upload.
        with gr.Column():
            gr.Markdown("### 1. 解析設定")
            model_drop = gr.Dropdown(["sam2_hiera_tiny.pt", "sam2_hiera_small.pt", "sam2_hiera_base_plus.pt", "sam2_hiera_large.pt"], label="モデル選択", value="sam2_hiera_tiny.pt")
            preset = gr.Radio(list(PRESETS), label="プリセット", value="標準")
            reset_btn = gr.Button("全てリセット", variant="stop")
            with gr.Accordion("パラメーター", open=False):
                p_side = gr.Slider(8, 128, 32, step=8, label="ポイント密度")
                iou_th = gr.Slider(0.0, 1.0, 0.8, label="IOUしきい値")
                stab_th = gr.Slider(0.0, 1.0, 0.92, label="安定しきい値")
                min_area = gr.Slider(0, 1000, 300, label="最小面積（ノイズ除去）")
            input_file = gr.File(label="画像をアップロード")
            input_preview = gr.Image(label="アップロード確認（プレビュー）", type="pil", interactive=False)
            run_btn = gr.Button("解析実行", variant="primary")
        # Right column: segment selection, mask editing and the final cut-out.
        with gr.Column():
            gr.Markdown("### 2. 切り抜き作業")
            seg_view = gr.Image(label="解析画像（切り抜き部分をなぞる）", tool="sketch", type="pil", interactive=True, elem_classes="padded-image")
            with gr.Accordion("マスク調整", open=False):
                inv = gr.Checkbox(label="反転")
                sm = gr.Slider(0, 20, 0, step=1, label="滑らか")
                ex = gr.Slider(-20, 20, 0, step=1, label="拡張")
            mask_btn = gr.Button("マスク生成", variant="primary")
            mask_view = gr.Image(label="マスク生成結果（消しゴム編集可能）", tool="sketch", type="pil", interactive=True, elem_classes="padded-image")

            # Collapsed by default (open=False).
            with gr.Accordion("切り抜き調整", open=False):
                final_inv_chk = gr.Checkbox(label="切り抜き時にマスクを反転", value=False)
                crop_blur = gr.Slider(0, 20, 0, step=1, label="ぼかし (px)")

            crop_btn = gr.Button("切り抜き実行", variant="primary")
            result_view = gr.Image(label="結果画像", type="pil", interactive=False)
            result_file = gr.File(label="保存")

    gr.HTML(f"<div class='version-text'>{version_info}</div>")

    def set_p(p):
        """Return the slider values for preset *p*."""
        return PRESETS[p]

    def reset_all():
        """Clear every view and work control, and delete the working image.

        Without deleting the working image, pressing the analyze button after a
        reset would re-analyze the previously uploaded image.
        """
        if os.path.exists(handler.active_image_path):
            os.remove(handler.active_image_path)
        # 6 image/file components, crop invert/blur, then mask invert/smooth/expand.
        return [None] * 6 + [False, 0] + [False, 0, 0]

    preset.change(set_p, preset, [p_side, iou_th, stab_th, min_area])
    input_file.change(handler.handle_upload, input_file, input_preview)

    # Each step first clears its output view, then runs the (slow) handler.
    run_btn.click(lambda: None, None, seg_view).then(
        handler.run_segment, [model_drop, p_side, iou_th, stab_th, min_area], seg_view, show_progress=True
    )
    mask_btn.click(lambda: None, None, mask_view).then(
        handler.make_mask, [seg_view, inv, sm, ex], mask_view, show_progress=True
    )
    crop_btn.click(
        handler.run_crop, [mask_view, final_inv_chk, crop_blur], [result_view, result_file], show_progress=True
    )
    reset_btn.click(reset_all, None, [input_file, input_preview, seg_view, mask_view, result_view, result_file, final_inv_chk, crop_blur, inv, sm, ex])

if __name__ == "__main__":
    # queue() routes events through Gradio's queue (concurrency_count defaults
    # to 1 in Gradio 3.x). This suits long model downloads and segmentation,
    # and keeps concurrent requests from overwriting the shared working files.
    # NOTE: share=True publishes a public Gradio share link to this app.
    demo.queue().launch(server_name="0.0.0.0", share=True)

## アプリケーションの起動

In [None]:
# Also pin MPLBACKEND=Agg on the notebook side, then run the app.
# `%env` sets the variable in the kernel's environment, which the `!` shell
# command below inherits. `venv_path` comes from the setup cell, and sac.py is
# the file written by the `%%writefile` cell, so run those cells first.
%env MPLBACKEND=Agg
!{venv_path}/bin/python sac.py
