# 链家租房数据爬虫与分析

本项目包含从链家网爬取租房数据、数据清洗、可视化分析以及价格预测建模的完整流程。

## 1. 环境准备与配置

In [13]:
# 安装依赖 (如果尚未安装，请取消注释并运行)
# !pip install catboost plotly beautifulsoup4 pandas numpy matplotlib scikit-learn fake-useragent requests seaborn

In [14]:
import logging
import random
import re
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
from bs4 import BeautifulSoup, Tag
from catboost import CatBoostRegressor, Pool
from fake_useragent import UserAgent
from plotly.subplots import make_subplots
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

# 配置 Pandas 显示选项
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger(__name__)

In [15]:
# --- 全局配置 ---

# 项目目录
PROJECT_ROOT = Path('.').resolve()
DATA_DIR = PROJECT_ROOT / "data"
IMAGE_DIR = PROJECT_ROOT / "images"

DATA_DIR.mkdir(
    parents=True,
    exist_ok=True
)
IMAGE_DIR.mkdir(
    parents=True,
    exist_ok=True
)

# 爬虫参数
START_PAGE = 1
END_PAGE = 50
MIN_DELAY = 6
MAX_DELAY = 8
MAX_FAILURES = 3

# 城市映射
CITIES_MAP = {
    'bj': '北京',
    'sh': '上海',
    'gz': '广州',
    'sz': '深圳',
    'wh': '武汉',
    'cd': '成都',
    'hz': '杭州',
    'xm': '厦门'
}

# 分析目标城市
ANALYSIS_CITY = '厦门'

# Cookie (必须替换为有效的登录 Cookie)
COOKIE = """lianjia_uuid=21879120-1b21-49b8-ae51-bf462eeea672; crosSdkDT2019DeviceId=-n8sihi-e5a4d0-yh0lk4hhy11ahyr-49vxcihat; ftkrc_=775a8b9e-29c8-4701-9d81-299c3b5bf38c; lfrc_=0794bd25-2752-46e5-999f-f277ab4340ff; sensorsdata2015jssdkcross=%7B%22distinct_id%22%3A%2219b4b9589c6d51-049b59bc23a5d88-4c657b58-2073600-19b4b9589c715e0%22%2C%22%24device_id%22%3A%2219b4b9589c6d51-049b59bc23a5d88-4c657b58-2073600-19b4b9589c715e0%22%2C%22props%22%3A%7B%22%24latest_traffic_source_type%22%3A%22%E4%BB%98%E8%B4%B9%E5%B9%BF%E5%91%8A%E6%B5%81%E9%87%8F%22%2C%22%24latest_referrer%22%3A%22https%3A%2F%2Fwww.baidu.com%2Fother.php%22%2C%22%24latest_referrer_host%22%3A%22www.baidu.com%22%2C%22%24latest_search_keyword%22%3A%22%E9%93%BE%E5%AE%B6%22%2C%22%24latest_utm_source%22%3A%22baidu%22%2C%22%24latest_utm_medium%22%3A%22pinzhuan%22%2C%22%24latest_utm_campaign%22%3A%22wyfs%22%2C%22%24latest_utm_content%22%3A%22biaotimiaoxau%22%2C%22%24latest_utm_term%22%3A%22biaoti%22%7D%7D; select_city=350200; GUARANTEE_POPUP_SHOW=true; GUARANTEE_BANNER_SHOW=true; login_ucid=2000000512872949; lianjia_token=2.00146b5cf64f63e0ef05c675c7c01d6fe3; lianjia_token_secure=2.00146b5cf64f63e0ef05c675c7c01d6fe3; security_ticket=Km1prrtLg0Rg6VBikAI5CXZQJqbJqniEkzfZBIAFsRUKMwzf3ygXT4qIzHEY+m01fEyaWujvPlNDuIIDgRrUWuVJXsRKMYvgVz/tDNPdaYKQ/s6UC9VSeLRjuRnHL9TD2PsfzmQTbaJ77eOSZHdMY01opdG0VPzeFbw10olMk9Y=; lianjia_ssid=e07cfb39-a97c-47bc-a4c7-75311a9b540d; hip=DlrVud7_QcpqSJqmQ1pb9ZK_3mRILkADZ4p7nfsgd_ldNVKpWSNYljrzceIA75Hc6CsFSLFri3HcMvtqKAAcaLrZMsMtAox8c5KYH9fLGvaupn7XDDzy58bbmVb6gAlAn5XLoqufE9apWDqjDYRxlxp8e9GkP14Vd_v0iB-33gopZTwC53PUOJ8-94hvc7xSN6TZGfVg_SThBoKscBNh8n3qk97HF7QDsPKLXnoN07omTUQ_meyk-VXcFlcyT1ZZfNSeiIEA2X9BxK_Ovn_R_MAKYbFvDmQUhCXBGg%3D%3D; srcid=eyJ0Ijoie1wiZGF0YVwiOlwiYTZjZDRkMzVmZjE3NjcxOTg5YTFhMzE4NzQ4NGMzZDBmNDMzYzYwNWFlNGM4MjVkMmJhOTdlNmM3NWI3MWUyMjFmZmJkZDdlNTNkNWQzOGQ3MjA0ZmE5MjEwOTA3OGUwNWQ5NzVkMDYyNTJmYjgwYmNiYWQxMGMzNDZmZmFkZDA3OWU0N2JlMGVmMjRmOGI1Mjc4YmIzYWNhYzY2OTNiZGUxMmUwMzFjODBjMTY[... 259 chars omitted ...]"""

## 2. 工具类定义

In [16]:
class CsvUtil:
    @staticmethod
    def save_data(
        data: Union[pd.DataFrame, List[dict]],
        prefix: str,
        city_name: Optional[str] = None
    ) -> Path:
        df = pd.DataFrame(data) if not isinstance(data, pd.DataFrame) else data
        filename = f"{prefix}_{city_name}.csv" if city_name else f"{prefix}.csv"
        file_path = DATA_DIR / filename
        df.to_csv(
            file_path,
            index=False,
            encoding='utf-8-sig'
        )
        return file_path

    @staticmethod
    def load_data(
        prefix: str,
        merge_pattern: bool = False
    ) -> Optional[pd.DataFrame]:
        if merge_pattern:
            files = list(DATA_DIR.glob(f"{prefix}_*.csv"))
        else:
            files = list(DATA_DIR.glob(f"{prefix}.csv"))
            
        if not files:
            return None
        try:
            df_list = [
                pd.read_csv(f, encoding='utf-8-sig') for f in files
            ]
            if df_list:
                return pd.concat(
                    df_list,
                    ignore_index=True
                )
            return None
        except Exception as e:
            logger.error(f"Load data failed: {e}")
            return None


class ImageUtil:
    @staticmethod
    def save_plot(
        fig: go.Figure,
        filename: str,
        fmt: str = 'png',
        scale: float = 2.0
    ) -> str:
        file_path = IMAGE_DIR / f"{filename}.{fmt}"
        try:
            if fmt == 'html':
                fig.write_html(str(file_path))
            else:
                # 静态图片保存需要 kaleido 或 orca，如果没有安装，可以注释掉这行
                # fig.write_image(str(file_path), scale=scale)
                pass
            return str(file_path)
        except Exception as e:
            logger.warning(f"Save plot failed: {e}")
            return ""

## 3. 阶段一：数据爬取 (Spider)

In [17]:
class LianJiaSpider:
    _PATTERNS = {
        'layout': re.compile(r'(\d+)室(\d+)厅(\d+)卫'),
        'area': re.compile(r'(\d+\.?\d*)㎡'),
        'orientation': re.compile(r'/\s*([\u4e00-\u9fa5\s]+)\s*/'),
        'floor_level': re.compile(r'([\u4e00-\u9fa5]+)楼层'),
        'total_floors': re.compile(r'（(\d+)层）')
    }

    def __init__(self):
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': UserAgent().random,
            'Cookie': COOKIE
        })

    def run(
        self,
        specific_cities: Dict[str, str]
    ):
        for city_code, city_name in specific_cities.items():
            data = self._crawl_city(
                city_name,
                city_code
            )
            CsvUtil.save_data(
                data,
                prefix='raw',
                city_name=city_name
            )

    def _crawl_city(
        self,
        city_name: str,
        city_code: str
    ) -> List[Dict]:
        logger.info(f">>> 启动抓取: {city_name} ({city_code})")
        items = []
        consecutive_failures = 0

        for page in range(START_PAGE, END_PAGE + 1):
            page_data = self._fetch_page(
                city_code,
                page,
                city_name
            )

            if page_data:
                items.extend(page_data)
                consecutive_failures = 0
                if page % 5 == 0:
                    logger.info(f"    √ 已抓取 {len(items)} 条数据 (当前页: {page})")
            else:
                consecutive_failures += 1
                logger.info(f"    × 第 {page} 页无数据")

            if consecutive_failures >= MAX_FAILURES:
                logger.error(f"连续失败触发熔断: {city_name}")
                break

            time.sleep(random.uniform(MIN_DELAY, MAX_DELAY))
        
        logger.info(f"<<< {city_name} 抓取完成，共 {len(items)} 条")
        return items

    def _fetch_page(
        self,
        city_code: str,
        page: int,
        city_name: str
    ) -> List[Dict]:
        url = f"https://{city_code}.lianjia.com/zufang/pg{page}/"
        try:
            resp = self.session.get(
                url,
                timeout=10
            )
            resp.raise_for_status()
            soup = BeautifulSoup(
                resp.content,
                'html.parser'
            )
            cards = soup.select('div.content__list--item')
            return [
                self._parse_card(card, city_name) for card in cards
            ]
        except Exception as e:
            logger.error(f"请求异常 (Page {page}): {e}")
            return []

    def _parse_card(
        self,
        card: Tag,
        city_name: str
    ) -> Dict[str, Any]:
        def _text(selector: str) -> str:
            el = card.select_one(selector)
            return el.get_text(strip=True) if el else ""

        def _regex(
            text: str,
            key: str,
            dtype: type = str,
            default: Any = None
        ) -> Any:
            match = self._PATTERNS[key].search(text)
            if match:
                try:
                    if match.lastindex == 1:
                        return dtype(match.group(1))
                    return match.groups()
                except (ValueError, IndexError):
                    pass
            return default

        desc = _text('p.content__list--item--des')
        floor_txt = _text('p.content__list--item--des span.hide')
        title = _text('p.content__list--item--title')
        
        locs = [
            a.get_text(strip=True) for a in card.select('p.content__list--item--des a')[:3]
        ]
        locs += [''] * (3 - len(locs))

        layout = self._PATTERNS['layout'].search(desc)
        rooms = [int(x) for x in layout.groups()] if layout else [0, 0, 0]

        tag_texts = [
            i.get_text(strip=True) for i in card.select('p.content__list--item--bottom i')
        ]

        return {
            'city': city_name,
            'title': title,
            'rent_type': '整租' if '整租' in title else ('合租' if '合租' in title else '独栋'),
            'district': locs[0],
            'sub_district': locs[1],
            'community': locs[2],
            'area_sqm': _regex(desc, 'area', float, 0.0),
            'bedrooms': rooms[0],
            'living_rooms': rooms[1],
            'bathrooms': rooms[2],
            'orientation': _regex(desc, 'orientation', str, '').strip(),
            'floor_level': _regex(floor_txt, 'floor_level', str, ''),
            'total_floors': _regex(floor_txt, 'total_floors', int, 0),
            'tags': '|'.join(tag_texts),
            'platform': _text('p.content__list--item--brand span.brand'),
            'update_time': _text('p.content__list--item--brand span.content__list--item--time'),
            'price_rmb': _text('span.content__list--item-price em')
        }

In [18]:
# 执行爬虫
def get_missing_cities() -> dict:
    existing_files = list(DATA_DIR.glob("raw_*.csv"))
    existing_names = {
        p.stem.split('_')[1] for p in existing_files
    }
    return {
        c: n for c, n in CITIES_MAP.items() if n not in existing_names
    }

missing_cities = get_missing_cities()
if missing_cities:
    print(f"需要爬取的城市: {list(missing_cities.values())}")
    # LianJiaSpider().run(missing_cities)  # 取消注释以运行爬虫
else:
    print("所有目标城市数据已存在，跳过爬取。")

所有目标城市数据已存在，跳过爬取。


## 4. 阶段二：数据清洗 (Cleaning)

In [19]:
class DataCleaner:
    COLS_TEXT = [
        'city',
        'title',
        'rent_type',
        'district',
        'sub_district',
        'community',
        'orientation',
        'floor_level',
        'tags',
        'platform',
        'update_time',
        'price_rmb'
    ]
    COLS_NUM = [
        'area_sqm',
        'bedrooms',
        'living_rooms',
        'bathrooms',
        'total_floors'
    ]
    _PATTERN_PRICE = re.compile(r'(\d+\.?\d*)')
    _PATTERN_REL_DATE = re.compile(r'(\d+)\s*(天|周|个月|月|年)前')
    _DATE_OFFSET_MAP = {
        '天': 1,
        '周': 7,
        '月': 30,
        '年': 365
    }

    def __init__(
        self,
        df: pd.DataFrame
    ):
        self.df = df.copy()
        self.now = datetime.now()

    @classmethod
    def execute_task(cls) -> pd.DataFrame:
        raw_df = CsvUtil.load_data(
            "raw",
            merge_pattern=True
        )
        if raw_df is None or raw_df.empty:
            logger.warning("无原始数据")
            return pd.DataFrame()
        
        cleaned_df = cls(raw_df).process()
        CsvUtil.save_data(
            cleaned_df,
            prefix='cleaned'
        )
        return cleaned_df

    def process(self) -> pd.DataFrame:
        initial_len = len(self.df)
        self.df = (
            self.df
            .pipe(self._clean_text)
            .pipe(self._clean_numeric)
            .pipe(self._parse_prices)
            .pipe(self._standardize_dates)
            .pipe(self._remove_outliers)
        )
        logger.info(f"清洗完成: {initial_len} -> {len(self.df)} (剔除 {initial_len - len(self.df)} 条)")
        return self.df

    def _clean_text(
        self,
        df: pd.DataFrame
    ) -> pd.DataFrame:
        cols = df.columns.intersection(self.COLS_TEXT)
        df[cols] = df[cols].fillna('未知').astype(str).apply(lambda x: x.str.strip())
        df[cols] = df[cols].replace(
            r'^(nan|NaN|None|NULL|)$',
            '未知',
            regex=True
        )
        return df

    def _clean_numeric(
        self,
        df: pd.DataFrame
    ) -> pd.DataFrame:
        for col in df.columns.intersection(self.COLS_NUM):
            df[col] = pd.to_numeric(
                df[col],
                errors='coerce'
            ).fillna(0)
        return df

    def _parse_prices(
        self,
        df: pd.DataFrame
    ) -> pd.DataFrame:
        def _calc_stats(val: str) -> pd.Series:
            nums = [
                float(x) for x in self._PATTERN_PRICE.findall(str(val))
            ]
            if nums:
                return pd.Series([
                    min(nums),
                    max(nums),
                    sum(nums) / len(nums)
                ])
            return pd.Series([0.0, 0.0, 0.0])

        stats = df['price_rmb'].apply(_calc_stats)
        stats.columns = [
            'price_min',
            'price_max',
            'price_avg'
        ]
        return pd.concat(
            [df, stats],
            axis=1
        )

    def _standardize_dates(
        self,
        df: pd.DataFrame
    ) -> pd.DataFrame:
        def _parse(date_str: str) -> str:
            if '今天' in date_str:
                return self.now.strftime('%Y-%m-%d')
            match = self._PATTERN_REL_DATE.search(date_str)
            if match:
                val, unit = match.groups()
                days = int(val) * self._DATE_OFFSET_MAP.get(unit.replace('个', ''), 0)
                return (self.now - timedelta(days=days)).strftime('%Y-%m-%d')
            return date_str
            
        df['clean_date'] = df['update_time'].apply(_parse)
        return df

    def _remove_outliers(
        self,
        df: pd.DataFrame
    ) -> pd.DataFrame:
        df = df[df['price_avg'] > 0]
        if df.empty:
            return df
        low, high = df['price_avg'].quantile(0.03), df['price_avg'].quantile(0.97)
        return df[(df['price_avg'] >= low) & (df['price_avg'] <= high)]

In [20]:
# 执行清洗并预览数据
df_clean = DataCleaner.execute_task()
if not df_clean.empty:
    display(df_clean.head())
    print(f"Data Shape: {df_clean.shape}")
else:
    print("没有数据可供清洗。")

[01:27:35] 清洗完成: 7680 -> 7218 (剔除 462 条)


Unnamed: 0,city,title,rent_type,district,sub_district,community,area_sqm,bedrooms,living_rooms,bathrooms,orientation,floor_level,total_floors,tags,platform,update_time,price_rmb,price_min,price_max,price_avg,clean_date
0,上海,整租·杨家浜小区 1室1厅 南,整租,杨浦,控江路,杨家浜小区,36.0,1,1,1,南,低,3,自营|新上|近地铁|精装|押一付一|随时看房|首次出租,贝壳优选,今天维护,3200,3200.0,3200.0,3200.0,2026-01-08
1,上海,整租·金杨六街坊 1室1厅 南,整租,浦东,金杨,金杨六街坊,41.07,1,1,1,南,中,6,官方核验|近地铁|随时看房,链家,1天前维护,3400,3400.0,3400.0,3400.0,2026-01-07
2,上海,整租·东方明珠大宁公寓 2室2厅 东南/南,整租,静安,大宁,东方明珠大宁公寓,95.18,2,2,2,东南 南,低,25,自营|新上|押一付一|双卫生间|随时看房|首次出租,贝壳优选,今天维护,7500,7500.0,7500.0,7500.0,2026-01-08
3,上海,独栋·方隅服务公寓 上海会展旗舰店 【康复医院陪护】近华山西院康复医院 可做饭 陪护优选拎包...,独栋,未知,未知,未知,50.0,1,1,1,未知,未知,0,独栋公寓|月租|精装|开放厨房|押一付一,方隅服务公寓,1天前维护,4670,4670.0,4670.0,4670.0,2026-01-07
4,上海,整租·绿地海域笙晖(公寓) 3室1厅 南,整租,宝山,杨行,绿地海域笙晖(公寓),96.32,3,1,1,南,低,16,自营|新上|精装|押一付一|随时看房|首次出租,贝壳优选,今天维护,6100,6100.0,6100.0,6100.0,2026-01-08


Data Shape: (7218, 21)


## 5. 阶段三：数据可视化 (Visualization)

In [21]:
class RentalDataVisualizer:
    LABELS_MAP = {
        'city': '城市',
        'price_avg': '月租均价 (元)',
        'price_per_sqm': '每平米单价 (元/㎡)',
        'is_subway': '房源类型',
        'district': '行政区',
        'count': '房源数量',
        'range': '价格区间'
    }
    COLOR_SEQ = px.colors.qualitative.Plotly

    def __init__(
        self,
        df: pd.DataFrame
    ):
        self.df = df[df['district'] != '未知'].copy()
        self.df['price_per_sqm'] = self.df['price_avg'] / self.df['area_sqm']
        self.df['is_subway'] = np.where(
            self.df['tags'].str.contains('近地铁'),
            '地铁房',
            '普通房'
        )

    def _apply_style(
        self,
        fig,
        title: str
    ):
        fig.update_layout(
            title={
                'text': title,
                'x': 0.5,
                'y': 0.95,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            template="plotly_white",
            font=dict(
                family="Microsoft YaHei",
                size=12
            ),
            margin=dict(
                l=40,
                r=40,
                t=80,
                b=40
            ),
            showlegend=True
        )
        return fig

    def plot_city_comparison(self):
        metrics = self.df.groupby('city')[['price_avg', 'price_per_sqm']].mean().reset_index()
        fig = make_subplots(specs=[[{"secondary_y": True}]])
        
        fig.add_trace(
            go.Bar(
                x=metrics['city'],
                y=metrics['price_avg'],
                name="月租均价",
                marker_color=self.COLOR_SEQ[0]
            ),
            secondary_y=False
        )
        
        fig.add_trace(
            go.Scatter(
                x=metrics['city'],
                y=metrics['price_per_sqm'],
                name="每平米单价",
                mode='lines+markers',
                line=dict(
                    color=self.COLOR_SEQ[1],
                    width=3
                )
            ),
            secondary_y=True
        )
        
        fig.update_yaxes(
            title_text="月租均价 (元)",
            secondary_y=False
        )
        fig.update_yaxes(
            title_text="每平米单价 (元/㎡)",
            secondary_y=True,
            showgrid=False
        )
        self._apply_style(
            fig,
            "各城市租金水平对比"
        )
        ImageUtil.save_plot(
            fig,
            "各城市租金水平对比"
        )
        return fig

    def plot_subway_premium(self):
        metrics = self.df.groupby('is_subway')['price_avg'].mean().reset_index()
        fig = px.bar(
            metrics,
            x='is_subway',
            y='price_avg',
            labels=self.LABELS_MAP,
            color='is_subway',
            color_discrete_sequence=self.COLOR_SEQ,
            text_auto='.2f'
        )
        fig.update_traces(textposition='outside')
        self._apply_style(
            fig,
            "地铁房 vs 普通房：月租均价对比"
        )
        ImageUtil.save_plot(
            fig,
            "地铁房 vs 普通房：月租均价对比"
        )
        return fig

    def plot_city_analysis(
        self,
        city: str
    ):
        city_data = self.df[self.df['city'] == city]
        if city_data.empty:
            logger.warning(f"未找到城市 {city} 的数据")
            return None
        
        # 1. Box Plot
        fig_box = px.box(
            city_data,
            x='district',
            y='price_avg',
            labels=self.LABELS_MAP,
            color='district',
            color_discrete_sequence=self.COLOR_SEQ
        )
        fig_box.update_xaxes(categoryorder='median ascending')
        self._apply_style(
            fig_box,
            f"{city} - 各区域租金分布"
        )
        ImageUtil.save_plot(
            fig_box,
            f"{city} - 各区域租金分布"
        )
        fig_box.show()

        # 2. Top 10
        top10 = city_data.groupby('district')['price_avg'].mean().nlargest(10).reset_index()
        fig_bar = px.bar(
            top10,
            x='price_avg',
            y='district',
            orientation='h',
            labels=self.LABELS_MAP,
            color='price_avg',
            color_continuous_scale='Blues'
        )
        fig_bar.update_yaxes(autorange="reversed")
        self._apply_style(
            fig_bar,
            f"{city} - 租金最贵行政区 Top 10"
        )
        ImageUtil.save_plot(
            fig_bar,
            f"{city} - 租金最贵行政区 Top 10"
        )
        fig_bar.show()

        # 3. Pie Chart
        bins = [0, 2000, 4000, 6000, 8000, 10000, float('inf')]
        labels = ['2k以下', '2k-4k', '4k-6k', '6k-8k', '8k-1w', '1w以上']
        counts = pd.cut(
            city_data['price_avg'],
            bins=bins,
            labels=labels
        ).value_counts().reset_index()
        counts.columns = ['range', 'count']
        
        fig_pie = px.pie(
            counts,
            values='count',
            names='range',
            labels=self.LABELS_MAP,
            color_discrete_sequence=self.COLOR_SEQ,
            hole=0.4
        )
        fig_pie.update_traces(
            textposition='inside',
            textinfo='percent+label'
        )
        self._apply_style(
            fig_pie,
            f"{city} - 房源价格区间占比"
        )
        ImageUtil.save_plot(
            fig_pie,
            f"{city} - 房源价格区间占比"
        )
        fig_pie.show()

In [22]:
# 执行可视化
if not df_clean.empty:
    viz = RentalDataVisualizer(df_clean)
    
    # 1. 城市对比
    fig1 = viz.plot_city_comparison()
    fig1.show()
    
    # 2. 地铁房溢价
    fig2 = viz.plot_subway_premium()
    fig2.show()
    
    # 3. 特定城市深度分析
    viz.plot_city_analysis(ANALYSIS_CITY)

## 6. 阶段四：价格预测建模 (Modeling)

In [23]:
class ModelPipeline:
    CATEGORICAL_FEATURES = [
        'city',
        'rent_type',
        'district',
        'sub_district',
        'orientation',
        'floor_level'
    ]
    NUMERIC_FEATURES = [
        'area_sqm',
        'bedrooms',
        'living_rooms',
        'bathrooms',
        'total_floors',
        'is_subway'
    ]
    TARGET_COLUMN = 'price_avg'
    FEATURE_DISPLAY_MAP = {
        'city': '城市',
        'rent_type': '租赁方式',
        'district': '行政区',
        'sub_district': '商圈',
        'orientation': '朝向',
        'floor_level': '楼层等级',
        'area_sqm': '面积(㎡)',
        'bedrooms': '室数',
        'living_rooms': '厅数',
        'bathrooms': '卫数',
        'total_floors': '总楼层',
        'is_subway': '是否近地铁'
    }
    COLOR_SEQ = px.colors.qualitative.Plotly

    def __init__(
        self,
        raw_df: pd.DataFrame
    ):
        self.raw_df = raw_df
        self.performance_metrics: List[Dict] = []

    def run(self):
        features, target = self._prepare_data()
        self.all_feature_names = features.columns.tolist()
        
        train_x, test_x, train_y, test_y = train_test_split(
            features,
            target,
            test_size=0.2,
            random_state=42
        )

        print("Training CatBoost...")
        catboost_model = self._train_catboost(
            train_x,
            train_y,
            test_x,
            test_y
        )
        self._evaluate(
            catboost_model,
            test_x,
            test_y,
            "CatBoost"
        )

        print("Training RandomForest...")
        rf_model, test_x_enc = self._train_random_forest(
            train_x,
            train_y,
            test_x
        )
        self._evaluate(
            rf_model,
            test_x_enc,
            test_y,
            "RandomForest"
        )

        self._plot_metrics_comparison()
        self._plot_importance(catboost_model)

    def _prepare_data(self) -> Tuple[pd.DataFrame, pd.Series]:
        df = self.raw_df.copy()
        df['is_subway'] = df['tags'].str.contains(
            '近地铁',
            na=False
        ).astype(int)
        
        return (
            df[self.CATEGORICAL_FEATURES + self.NUMERIC_FEATURES],
            df[self.TARGET_COLUMN]
        )

    def _train_catboost(
        self,
        train_x,
        train_y,
        test_x,
        test_y
    ):
        model = CatBoostRegressor(
            iterations=1000,
            learning_rate=0.05,
            depth=5,
            loss_function='RMSE',
            verbose=0,
            early_stopping_rounds=50,
            allow_writing_files=False
        )
        model.fit(
            Pool(
                train_x,
                train_y,
                cat_features=self.CATEGORICAL_FEATURES
            ),
            eval_set=Pool(
                test_x,
                test_y,
                cat_features=self.CATEGORICAL_FEATURES
            ),
            use_best_model=True
        )
        return model

    def _train_random_forest(
        self,
        train_x,
        train_y,
        test_x
    ):
        encoder = OrdinalEncoder(
            handle_unknown='use_encoded_value',
            unknown_value=-1
        )
        train_x_enc = train_x.copy()
        test_x_enc = test_x.copy()
        
        train_x_enc[self.CATEGORICAL_FEATURES] = encoder.fit_transform(
            train_x[self.CATEGORICAL_FEATURES]
        )
        test_x_enc[self.CATEGORICAL_FEATURES] = encoder.transform(
            test_x[self.CATEGORICAL_FEATURES]
        )
        
        model = RandomForestRegressor(
            n_estimators=100,
            max_depth=20,
            n_jobs=-1,
            random_state=42
        )
        model.fit(
            train_x_enc,
            train_y
        )
        return model, test_x_enc

    def _evaluate(
        self,
        model,
        features,
        target,
        model_name: str
    ):
        pred = model.predict(features)
        rmse = np.sqrt(mean_squared_error(target, pred))
        mae = mean_absolute_error(target, pred)
        r2 = r2_score(target, pred)
        
        print(f"{model_name}: RMSE={rmse:.2f}, MAE={mae:.2f}, R2={r2:.3f}")
        
        self.performance_metrics.extend([
            {'Model': model_name, 'Metric': 'RMSE', 'Value': rmse},
            {'Model': model_name, 'Metric': 'MAE', 'Value': mae},
            {'Model': model_name, 'Metric': 'R2', 'Value': r2}
        ])

    def _plot_metrics_comparison(self):
        metrics_df = pd.DataFrame(self.performance_metrics)
        fig = make_subplots(
            rows=1,
            cols=2,
            subplot_titles=("模型拟合度 (R2)", "误差指标 (RMSE/MAE)")
        )
        
        r2_df = metrics_df[metrics_df['Metric'] == 'R2']
        fig.add_trace(
            go.Bar(
                x=r2_df['Model'],
                y=r2_df['Value'],
                name='R2',
                text=r2_df['Value'].round(3),
                textposition='auto',
                marker_color=self.COLOR_SEQ[0]
            ),
            row=1,
            col=1
        )
        
        err_df = metrics_df[metrics_df['Metric'].isin(['RMSE', 'MAE'])]
        for i, metric in enumerate(['RMSE', 'MAE']):
            d = err_df[err_df['Metric'] == metric]
            fig.add_trace(
                go.Bar(
                    x=d['Model'],
                    y=d['Value'],
                    name=metric,
                    marker_color=self.COLOR_SEQ[i+1]
                ),
                row=1,
                col=2
            )
            
        fig.update_layout(
            barmode='group',
            title_text="模型性能评估对比"
        )
        ImageUtil.save_plot(
            fig,
            "模型性能对比图"
        )
        fig.show()

    def _plot_importance(
        self,
        model
    ):
        importance_df = pd.DataFrame({
            'feature': self.all_feature_names,
            'score': model.get_feature_importance()
        })
        importance_df['feature_cn'] = importance_df['feature'].map(self.FEATURE_DISPLAY_MAP)
        importance_df = importance_df.sort_values(
            'score',
            ascending=True
        ).tail(10)
        
        fig = px.bar(
            importance_df,
            x='score',
            y='feature_cn',
            orientation='h',
            labels={'score': '贡献度', 'feature_cn': '特征'},
            color='score',
            color_continuous_scale='Viridis'
        )
        fig.update_layout(title_text="CatBoost 核心特征贡献排行")
        ImageUtil.save_plot(
            fig,
            "CatBoost 核心特征贡献排行"
        )
        fig.show()

In [24]:
# 执行建模
if not df_clean.empty:
    ModelPipeline(df_clean).run()
else:
    print("没有数据可供建模。")

Training CatBoost...
CatBoost: RMSE=1614.60, MAE=985.77, R2=0.804
Training RandomForest...
RandomForest: RMSE=1653.61, MAE=977.83, R2=0.794
