# In [None]:
import os
import random
from pathlib import Path
from typing import Optional, Union

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import timm

# --- Configuration ---
# Dataset folders are resolved against the directory the process was started in.
PROJECT_ROOT = Path.cwd()
LANDSLIDE_IMG_DIR = PROJECT_ROOT.joinpath('landslide', 'image')
LANDSLIDE_MASK_DIR = PROJECT_ROOT.joinpath('landslide', 'mask')
NON_LANDSLIDE_IMG_DIR = PROJECT_ROOT.joinpath('non-landslide', 'image')

# Lower-case file suffixes accepted as images when the folders are scanned.
IMG_EXTS = {
    '.bmp',
    '.jpeg',
    '.jpg',
    '.png',
    '.tif',
    '.tiff',
}

# Run on the GPU when PyTorch reports CUDA support, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# --- Helper Functions ---
def load_image(path: Union[str, Path], target_size=(256, 256)) -> np.ndarray:
    """Read an image file and return it as a resized float32 RGB array.

    Args:
        path: Location of the image file; it is converted with ``str`` before
            being handed to OpenCV.
        target_size: Passed unchanged to ``cv2.resize`` as the output size,
            which OpenCV reads as ``(width, height)``.

    Returns:
        The image in RGB channel order, resized with linear interpolation,
        cast to ``np.float32`` and divided by 255.

    Raises:
        RuntimeError: If OpenCV could not read the file (``imread`` gave
            ``None``).
    """
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise RuntimeError(f"Failed to read image: {path}")

    # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, target_size, interpolation=cv2.INTER_LINEAR)

    scaled = resized.astype(np.float32) / 255.0
    return scaled

def load_mask(path: Union[str, Path], target_size=(256, 256)) -> np.ndarray:
    """Read a mask file and return it as a resized binary float32 array.

    Args:
        path: Location of the mask file; it is converted with ``str`` before
            being handed to OpenCV.
        target_size: Passed unchanged to ``cv2.resize`` as the output size,
            which OpenCV reads as ``(width, height)``.

    Returns:
        A single-channel ``np.float32`` array holding 1.0 wherever the resized
        grayscale mask is greater than zero and 0.0 elsewhere.

    Raises:
        RuntimeError: If OpenCV could not read the file (``imread`` gave
            ``None``).
    """
    gray = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise RuntimeError(f"Failed to read mask: {path}")

    # Nearest-neighbour resizing never invents intermediate gray levels, so
    # the thresholding below sees only values present in the source mask.
    resized = cv2.resize(gray, target_size, interpolation=cv2.INTER_NEAREST)

    foreground = resized > 0
    return foreground.astype(np.float32)

# --- Dataset Class ---
class LandslideBijieDataset(Dataset):
    """Landslide / non-landslide image tiles with optional segmentation masks.

    Every file with a recognised image suffix (``IMG_EXTS``) found directly in
    ``landslide_img_dir`` becomes a positive sample (label 1); every such file
    in ``non_landslide_img_dir`` becomes a negative sample (label 0). Files are
    listed in sorted order so that sample indices are stable across runs and
    machines. Directories that are missing contribute no samples.

    Args:
        landslide_img_dir: Folder with landslide images (``str`` or ``Path``).
        landslide_mask_dir: Optional folder with masks. A mask is looked up
            under the same file name as its image.
        non_landslide_img_dir: Optional folder with non-landslide images.
        image_size: Output size handed to ``load_image`` / ``load_mask``, which
            pass it to ``cv2.resize``; OpenCV reads it as ``(width, height)``.
        use_sam_guidance: When true, ``'sam_guidance'`` is a copy of the mask;
            otherwise it is all zeros.
    """

    def __init__(self, landslide_img_dir, landslide_mask_dir=None,
                 non_landslide_img_dir=None, image_size=(256, 256),
                 use_sam_guidance=True):
        self.samples = []
        self.image_size = image_size
        self.use_sam_guidance = use_sam_guidance

        mask_dir = Path(landslide_mask_dir) if landslide_mask_dir else None
        for img_path in self._list_images(landslide_img_dir):
            mask_path = mask_dir / img_path.name if mask_dir is not None else None
            self.samples.append({'img': img_path, 'mask': mask_path, 'label': 1})

        for img_path in self._list_images(non_landslide_img_dir):
            self.samples.append({'img': img_path, 'mask': None, 'label': 0})

    @staticmethod
    def _list_images(directory):
        """Return the image files directly inside ``directory``, sorted by path.

        An unset or non-existent directory yields an empty list.
        """
        if not directory:
            return []
        directory = Path(directory)
        if not directory.exists():
            return []
        # Directory listing order is filesystem-dependent; sorting keeps the
        # index -> sample mapping reproducible.
        return sorted(p for p in directory.iterdir() if p.suffix.lower() in IMG_EXTS)

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with
            ``'image'``: float32 tensor, channels first;
            ``'mask'``: float32 tensor with a leading channel axis of size 1;
            ``'label'``: float32 scalar tensor (1.0 landslide, 0.0 otherwise);
            ``'sam_guidance'``: tensor shaped like ``'mask'``.

        Raises:
            RuntimeError: Propagated from ``load_image`` / ``load_mask`` when a
                file cannot be read.
        """
        item = self.samples[idx]
        img = load_image(item['img'], self.image_size)

        mask_path = item['mask']
        if mask_path is not None and mask_path.exists():
            mask = load_mask(mask_path, self.image_size)
        else:
            # No mask on disk: fall back to an empty mask. ``image_size`` is
            # (width, height) for cv2.resize, whereas arrays are indexed
            # (height, width), so the order is swapped to match the image.
            width, height = self.image_size
            mask = np.zeros((height, width), dtype=np.float32)

        img_t = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW
        mask_t = torch.from_numpy(mask).unsqueeze(0)

        # NOTE(review): the guidance is a clone of the ground-truth mask, so a
        # model that consumes it is shown its own target. Confirm this is only
        # a placeholder for real SAM output before trusting any metrics.
        if self.use_sam_guidance:
            guidance = mask_t.clone()
        else:
            guidance = torch.zeros_like(mask_t)

        return {
            'image': img_t,
            'mask': mask_t,
            'label': torch.tensor(item['label'], dtype=torch.float32),
            'sam_guidance': guidance,
        }

# --- Model Architecture ---
class SAMSwinHybrid(nn.Module):
    """timm feature backbone with a segmentation decoder and a classifier head.

    Both heads read the deepest feature map returned by the backbone. An
    optional guidance map can re-weight that feature map before the heads run.

    Args:
        backbone_name: Model name passed to ``timm.create_model``.
        img_channels: Number of input channels (``in_chans`` for timm).
        pretrained: Whether timm should load pretrained weights. Set to
            ``False`` to build the model without downloading weights.
        **backbone_kwargs: Extra keyword arguments forwarded unchanged to
            ``timm.create_model`` (for example ``img_size`` for backbones that
            are built for a fixed input resolution).
    """

    def __init__(self, backbone_name='swin_tiny_patch4_window7_224', img_channels=3,
                 pretrained=True, **backbone_kwargs):
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            in_chans=img_channels,
            features_only=True,
            out_indices=(1, 2, 3),
            **backbone_kwargs,
        )
        # Only the deepest of the requested feature levels feeds the heads.
        deepest_channels = self.backbone.feature_info.channels()[-1]

        # Three stride-2 transposed convolutions: the logits come out at 8x
        # the spatial size of the deepest feature map.
        self.seg_decoder = nn.Sequential(
            nn.Conv2d(deepest_channels, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 2, stride=2), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 2, stride=2), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 2, stride=2), nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 1, 1)
        )

        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(deepest_channels, 128), nn.ReLU(True), nn.Linear(128, 1)
        )

    def forward(self, x, sam_guidance=None):
        """Run the backbone and both heads.

        Args:
            x: Input image batch for the backbone.
            sam_guidance: Optional guidance map with batch and channel axes
                (it is resized with ``F.interpolate``). When given, the deepest
                feature map is multiplied by ``1 + guidance``.

        Returns:
            ``(seg_logits, cls_logits)``: the decoder output and the classifier
            output with its size-1 channel axis squeezed away.

        NOTE(review): the segmentation logits are 8x the deepest feature map,
        not necessarily the input size. If the last backbone stage is stride
        32, that is 1/4 of the input resolution, so the logits must be resized
        before being compared with full-resolution masks. TODO confirm.

        NOTE(review): this code assumes the backbone returns channels-first
        (N, C, H, W) feature maps. Some timm releases return channels-last
        features for Swin with ``features_only=True``. Verify against the
        installed timm version.
        """
        feat = self.backbone(x)[-1]
        if sam_guidance is not None:
            # Bilinear interpolation needs a floating tensor on the same device
            # as the features, so align the guidance first.
            guidance = sam_guidance.to(device=feat.device, dtype=feat.dtype)
            guidance = F.interpolate(guidance, size=feat.shape[-2:], mode='bilinear', align_corners=False)
            # 1 + g keeps the features where g == 0 and amplifies them where g > 0.
            feat = feat * (1.0 + guidance)
        return self.seg_decoder(feat), self.cls_head(feat).squeeze(1)

# --- Main Execution ---
if __name__ == "__main__":
    # The default backbone ('swin_tiny_patch4_window7_224') is built for
    # 224x224 inputs, so tiles are loaded at that size instead of the
    # dataset's 256x256 default.
    INPUT_SIZE = (224, 224)
    # Fixed seed so the train/validation split is the same on every run.
    SPLIT_SEED = 42

    dataset = LandslideBijieDataset(
        LANDSLIDE_IMG_DIR,
        LANDSLIDE_MASK_DIR,
        NON_LANDSLIDE_IMG_DIR,
        image_size=INPUT_SIZE,
    )

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0:
        # A shuffling DataLoader over an empty split fails with an unhelpful
        # sampler error; report the actual cause instead.
        raise RuntimeError(
            f"Found {len(dataset)} image(s) under {LANDSLIDE_IMG_DIR} and "
            f"{NON_LANDSLIDE_IMG_DIR}; at least 2 are needed to build a training split."
        )

    train_ds, val_ds = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )

    loader = DataLoader(train_ds, batch_size=4, shuffle=True)

    model = SAMSwinHybrid().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.BCEWithLogitsLoss()

    print(f"Starting training on {DEVICE}...")
    # Training loop would go here