In [56]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from enum import Enum
import json
import plotly.graph_objects as go

In [57]:
class Material(Enum):
    """Material codes for the voxels of the 3-D lake grid.

    The integer ``.value`` of each member is what gets stored in the material
    array (``material_np``). The map JSON's ``surfaceType`` / ``underWaterType``
    codes are written into that array directly, so they are presumably meant to
    use this same numbering — TODO confirm against the map editor.
    """
    WATER = 0
    WATER_GRASS = 1
    STONE = 2
    DRIFTWOOD = 3
    PIER = 4
    GROUND = 5  # default for voxels that are not part of any water column

class FishType(Enum):
    """Fish species known to this notebook.

    Declaration order matters: per-species arrays are built by iterating
    ``FishType``, so index ``i`` of such an array belongs to the ``i``-th member.
    The "no fish" slot is placed after the last member, at index ``len(FishType)``.
    """
    BUFFALO_FISH = 0
    YELLOW_PERCH = 1
    REDEAR_SUNFISH = 2

# Table header text -> enum member
material_mapping = {
    'Water': Material.WATER,
    'Water Grass': Material.WATER_GRASS,
    'Stone': Material.STONE,
    'Driftwood': Material.DRIFTWOOD,
    'Pier': Material.PIER,
    'Ground': Material.GROUND
}

# Chinese fish name (used as the row label in the pond sheet) -> FishType member
fish_mapping = {
    '水牛鱼': FishType.BUFFALO_FISH,      # buffalo fish
    '小冠太阳鱼': FishType.REDEAR_SUNFISH,  # redear sunfish
    '黄金鲈': FishType.YELLOW_PERCH       # yellow perch
}

# Reverse lookups: enum member -> table header / fish name
reverse_material_mapping = {v: k for k, v in material_mapping.items()}
reverse_fish_mapping = {v: k for k, v in fish_mapping.items()}

# Plotly marker color per Material, keyed by the integer value (what the grid arrays store)
material_color_mapping = {
    Material.WATER.value: 'blue',
    Material.WATER_GRASS.value: 'green',
    Material.STONE.value: 'white',
    Material.DRIFTWOOD.value: 'brown',
    Material.PIER.value: 'black',
    Material.GROUND.value: 'rgba(0,0,0,0)'  # alpha 0: ground voxels are drawn fully transparent
}

# Number of fish species. The "no catch / no fish" weight goes right after them,
# i.e. at index len(FishType).
fishTypeCount = len(FishType)
noFishIndex = fishTypeCount
print(f'钓空/无鱼权重将被插入到第{noFishIndex}个位置')

# All example tables live in one workbook. Open it once and parse every sheet from the
# same handle, instead of re-opening and re-parsing the file for each sheet.
EXAMPLE_DATA_PATH = 'exampleTables/exampleData.xlsx'
with pd.ExcelFile(EXAMPLE_DATA_PATH) as example_workbook:
    material_df = pd.read_excel(example_workbook, sheet_name='material', index_col=0)
    environment_df = pd.read_excel(example_workbook, sheet_name='environment', index_col=0)
    pond_df = pd.read_excel(example_workbook, sheet_name='pond', index_col=0)
    # The bait sheet is read without an index column (default RangeIndex).
    bait_df = pd.read_excel(example_workbook, sheet_name='bait')

# Column-header name constants for the spreadsheet tables (the headers are Chinese
# text in the workbook; an English gloss is given on each line).
# N_PROB_SCORE_BASELINE is read from pond_df, whose rows are labelled by fish name;
# which sheet(s) hold the other columns is not shown here — TODO confirm.
N_FAV_TEMPERATURE = "喜好水温"  # preferred water temperature
N_COEF_TEMPERATURE = "水温-权重衰减系数"  # water temperature: weight decay coefficient
N_FAV_OXYGEN = "喜好含氧量"  # preferred oxygen content
N_COEF_OXYGEN = "含氧量-权重衰减系数"  # oxygen content: weight decay coefficient
N_VISUAL_LOW = "视觉低光阈值"  # vision: low-light threshold
N_COEF_SMELL = "嗅觉衰减系数"  # smell: decay coefficient
N_SMELL_MAX = "嗅觉最大距离"  # smell: maximum distance
N_PROB_SCORE_BASELINE = "基准权重"  # baseline weight

钓空/无鱼权重将被插入到第3个位置


In [53]:
def _voxel_indices(shape):
    """Return the (x, y, z) integer index grids for a 3-D array of the given shape.

    Equivalent to ``np.meshgrid(arange(nx), arange(ny), arange(nz), indexing='ij')``:
    element ``[i, j, k]`` of the three grids is ``i``, ``j`` and ``k`` respectively.
    """
    x_grid, y_grid, z_grid = np.indices(shape)
    return x_grid, y_grid, z_grid


def _apply_voxel_layout(fig, shape, sizeScale):
    """Apply the 3-D scene layout shared by all voxel plots to ``fig``.

    Axis ranges cover the array extent. The z axis range is reversed so that
    z index 0 is drawn at the top. The aspect ratio is each dimension divided
    by ``sizeScale``.
    """
    nx, ny, nz = shape[0], shape[1], shape[2]
    fig.update_layout(
        scene=dict(
            xaxis=dict(nticks=10, range=[0, nx]),
            yaxis=dict(nticks=10, range=[0, ny]),
            zaxis=dict(nticks=5, range=[nz, 0]),
            aspectmode='manual',
            aspectratio=dict(x=nx / sizeScale, y=ny / sizeScale, z=nz / sizeScale),
        ),
        width=800,
        height=800,
        margin=dict(r=20, l=10, b=10, t=10),
    )


def _plot_scalar_field(data, materials, upperLimit, sizeScale, colorscale):
    """Scatter every non-GROUND voxel of ``data``, colored by ``data / upperLimit``.

    Args:
        data: 3-D array of scalar values.
        materials: array of Material codes with the same shape as ``data``; voxels
            equal to ``Material.GROUND.value`` are not plotted.
        upperLimit: value that maps to the top of the color scale (values above it
            saturate). Must be non-zero.
        sizeScale: divisor applied to each dimension to get the scene aspect ratio.
        colorscale: name of a plotly colorscale, e.g. 'Viridis'.

    Raises:
        ValueError: if ``upperLimit`` is zero.
    """
    if upperLimit == 0:
        raise ValueError("upperLimit must be non-zero")

    mask = materials != Material.GROUND.value
    x_grid, y_grid, z_grid = _voxel_indices(data.shape)

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(
        x=x_grid[mask],
        y=y_grid[mask],
        z=z_grid[mask],
        mode='markers',
        marker=dict(
            size=5,
            color=data[mask] / upperLimit,
            colorscale=colorscale,
            # Pin the color range to [0, 1]. Without cmin/cmax plotly rescales
            # to the data's own min/max, and dividing by upperLimit would have
            # no visible effect.
            cmin=0.0,
            cmax=1.0,
            opacity=0.5,  # global marker opacity
        ),
    ))
    _apply_voxel_layout(fig, data.shape, sizeScale)
    fig.show()


def plot_numpy_3d_colors(data: np.ndarray, colors: dict, sizeScale: float = 10.0):
    """Scatter every voxel of a 3-D code array, colored via a per-value lookup table.

    Args:
        data: 3-D array of codes (e.g. Material values). Every voxel is plotted,
            including those whose mapped color is fully transparent.
        colors: mapping from each value occurring in ``data`` to a plotly color
            string. A value missing from the mapping raises ``KeyError``.
        sizeScale: divisor applied to each dimension to get the scene aspect ratio.
    """
    x_grid, y_grid, z_grid = _voxel_indices(data.shape)
    # ravel() walks `data` in the same C order as the flattened 'ij' index grids,
    # so each color lines up with its coordinates without per-voxel indexing.
    color_points = np.array([colors[value] for value in data.ravel()], dtype=str)

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(
        x=x_grid.ravel(),
        y=y_grid.ravel(),
        z=z_grid.ravel(),
        mode='markers',
        marker=dict(
            size=5,
            color=color_points,
            opacity=0.5,  # global marker opacity
        ),
    ))
    _apply_voxel_layout(fig, data.shape, sizeScale)
    fig.show()


def plotViridis(data: np.ndarray, materials: np.ndarray, upperLimit: float, sizeScale: float = 10.0):
    """Plot the non-GROUND voxels of ``data`` on the 'Viridis' colorscale.

    ``data / upperLimit`` is mapped onto [0, 1]. ``materials`` (same shape as
    ``data``) selects which voxels are shown. Raises ValueError if
    ``upperLimit`` is zero.
    """
    _plot_scalar_field(data, materials, upperLimit, sizeScale, 'Viridis')


def plotGreys(data: np.ndarray, materials: np.ndarray, upperLimit: float, sizeScale: float = 10.0):
    """Plot the non-GROUND voxels of ``data`` on the 'Greys' colorscale.

    ``data / upperLimit`` is mapped onto [0, 1]. ``materials`` (same shape as
    ``data``) selects which voxels are shown. Raises ValueError if
    ``upperLimit`` is zero.
    """
    _plot_scalar_field(data, materials, upperLimit, sizeScale, 'Greys')
    

In [50]:
# Map exported by the 2-D map editor. JSON is UTF-8 by specification, so decode it
# explicitly rather than relying on the platform's locale encoding (e.g. GBK on a
# Chinese-locale Windows machine).
MAP_JSON_PATH = 'map_OL2.0_v1_159x64.json'
with open(MAP_JSON_PATH, 'r', encoding='utf-8') as f:
    map_data = json.load(f)

# assumes mapDatas is a list of rows, each a list of cell dicts carrying at least
# 'x', 'y' and 'waterDepth' (the keys read below) — TODO confirm the full schema
mapDatas = map_data['mapDatas']

# Scan all cells for the maximum water depth and the largest x / y coordinates
# (all start at 0, so an empty map yields 0 for each).
maxDepth = 0
maxX = 0
maxY = 0

for rows in mapDatas:
    for item in rows:
        maxDepth = max(maxDepth, item['waterDepth'])
        maxX = max(maxX, item['x'])
        maxY = max(maxY, item['y'])

# Lake grid dimensions. x / y are used directly as array indices, so the extent is max + 1.
# The depth extent is the deepest waterDepth, i.e. the number of z layers.
map_length = maxX + 1  # extent along x
map_width = maxY + 1  # extent along y
map_depth = maxDepth    # extent along z (depth)

print(f'湖的尺寸为 {map_length} x {map_width} x {map_depth}')

湖的尺寸为 159 x 64 x 5


In [51]:
# 3-D material grid indexed as [x, y, z]; every voxel starts out as GROUND.
material_np = np.full((map_length, map_width, map_depth), Material.GROUND.value, dtype=int)

for rows in mapDatas:
    for item in rows:
        x = item['x']
        y = item['y']
        depth = item['waterDepth']
        if depth <= 0:
            # No water column here. Without this guard the bottom-barrier write
            # below would use index depth - 1 == -1 and silently overwrite the
            # deepest layer of a ground column.
            continue
        # Start from the material authored in the 2-D editor. A surface object
        # (non-zero surfaceType) fills the whole water column down to the bed;
        # otherwise the column is plain water.
        surface = item['surfaceType']
        material_np[x, y, :depth] = surface if surface != 0 else Material.WATER.value
        # A bottom barrier (non-zero underWaterType) replaces the bottom-most layer.
        if item['underWaterType'] != 0:
            material_np[x, y, depth - 1] = item['underWaterType']

# plot_numpy_3d_colors(material_np, material_color_mapping, 10.0)

In [61]:
# Baseline probability weight for each fish species, in FishType declaration order.
# Each weight is looked up in the pond sheet by the species' Chinese name.
fish_prob_weights = [
    pond_df.loc[reverse_fish_mapping[species], N_PROB_SCORE_BASELINE]
    for species in FishType
]
fish_prob_weights_np = np.asarray(fish_prob_weights)
print(np.shape(fish_prob_weights_np))

(3,)
