# EDA - CSIRO IMAGE2BIOMASS PREDICTION
---

In [1]:
# ===============
# libraries
# ===============

import os, gc, yaml, glob, pickle
import time
import random
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import polars as pl
import pandas as pd
from tqdm import tqdm

import warnings
# warnings.filterwarnings('ignore')

# original
import sys
sys.path.append("../src")
from utils.data import sep, show_df, glob_walk, set_seed, save_config_yaml, dict_to_namespace

from datetime import datetime
# Stamp the current date as YYYYMMDD and print it so the log shows when the notebook was run.
date = datetime.now().strftime("%Y%m%d")
print(f"TODAY is {date}")

TODAY is 20251212


In [2]:
# ===============
# Config
# ===============
# Single source of truth for the project root; every other path is derived from it.
_project_root = Path("/mnt/nfs/home/hidebu/study/CSIRO---Image2Biomass-Prediction/")

# Notebook-wide settings: experiment version, random seed, and data locations.
config = SimpleNamespace(
    ver=0,
    seed=2025,
    ROOT_DIR=_project_root,
    DATA_DIR=_project_root / "data/raw",
)

In [3]:
# =================
# Load data 
# =================
# Read the three competition tables from the raw-data directory.
train = pl.read_csv(config.DATA_DIR / "train.csv")
test = pl.read_csv(config.DATA_DIR / "test.csv")
sub = pl.read_csv(config.DATA_DIR / "sample_submission.csv")

# Preview each table under a banner naming the table actually shown
# (previously all three banners said "train").
# NOTE(review): the third show_df argument presumably toggles showing the tail as well — confirm in utils.data.
sep("train"); show_df(train, 3, True); display(train.describe())
sep("test"); show_df(test, 3, False);
sep("sub"); show_df(sub, 3, False);

train
(1785, 9)


sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
str,str,str,str,str,f64,f64,str,f64
"""ID1011485656__Dry_Clover_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Clover_g""",0.0
"""ID1011485656__Dry_Dead_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Dead_g""",31.9984
"""ID1011485656__Dry_Green_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Green_g""",16.2751


sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
str,str,str,str,str,f64,f64,str,f64
"""ID983582017__Dry_Green_g""","""train/ID983582017.jpg""","""2015/9/1""","""WA""","""Ryegrass""",0.64,9.0,"""Dry_Green_g""",40.94
"""ID983582017__Dry_Total_g""","""train/ID983582017.jpg""","""2015/9/1""","""WA""","""Ryegrass""",0.64,9.0,"""Dry_Total_g""",40.94
"""ID983582017__GDM_g""","""train/ID983582017.jpg""","""2015/9/1""","""WA""","""Ryegrass""",0.64,9.0,"""GDM_g""",40.94


statistic,sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
str,str,str,str,str,str,f64,f64,str,f64
"""count""","""1785""","""1785""","""1785""","""1785""","""1785""",1785.0,1785.0,"""1785""",1785.0
"""null_count""","""0""","""0""","""0""","""0""","""0""",0.0,0.0,"""0""",0.0
"""mean""",,,,,,0.657423,7.595985,,24.782295
"""std""",,,,,,0.151972,10.273725,,25.823738
"""min""","""ID1011485656__Dry_Clover_g""","""train/ID1011485656.jpg""","""2015/1/15""","""NSW""","""Clover""",0.16,1.0,"""Dry_Clover_g""",0.0
"""25%""",,,,,,0.56,3.0,,4.8182
"""50%""",,,,,,0.69,4.0,,18.2
"""75%""",,,,,,0.77,7.0,,35.9406
"""max""","""ID983582017__GDM_g""","""train/ID983582017.jpg""","""2015/9/4""","""WA""","""WhiteClover""",0.91,70.0,"""GDM_g""",185.7


train
(5, 3)


sample_id,image_path,target_name
str,str,str
"""ID1001187975__Dry_Clover_g""","""test/ID1001187975.jpg""","""Dry_Clover_g"""
"""ID1001187975__Dry_Dead_g""","""test/ID1001187975.jpg""","""Dry_Dead_g"""
"""ID1001187975__Dry_Green_g""","""test/ID1001187975.jpg""","""Dry_Green_g"""


train
(5, 2)


sample_id,target
str,f64
"""ID1001187975__Dry_Clover_g""",0.0
"""ID1001187975__Dry_Dead_g""",0.0
"""ID1001187975__Dry_Green_g""",0.0


In [4]:
# =================
# Load image paths 
# =================
# Collect the .jpg files found under the train/ and test/ image directories.
train_imgs = glob_walk(config.DATA_DIR / "train", "*.jpg")
test_imgs = glob_walk(config.DATA_DIR / "test", "*.jpg")

# Report the count and the first three paths for each split.
for split_name, split_imgs in (("train", train_imgs), ("test", test_imgs)):
    print(f"Number of {split_name} imgs: {len(split_imgs)}")
    print(f"Sample: \n{split_imgs[:3]}")

Number of train imgs: 357
Sample: 
[PosixPath('/mnt/nfs/home/hidebu/study/CSIRO---Image2Biomass-Prediction/data/raw/train/ID1011485656.jpg'), PosixPath('/mnt/nfs/home/hidebu/study/CSIRO---Image2Biomass-Prediction/data/raw/train/ID1012260530.jpg'), PosixPath('/mnt/nfs/home/hidebu/study/CSIRO---Image2Biomass-Prediction/data/raw/train/ID1025234388.jpg')]
Number of test imgs: 1
Sample: 
[PosixPath('/mnt/nfs/home/hidebu/study/CSIRO---Image2Biomass-Prediction/data/raw/test/ID1001187975.jpg')]


In [5]:
# NOTE(review): this cell rebinds `train`; re-running it without reloading the CSV
# will likely fail because Sampling_Date is no longer a string column — confirm.
sampling_date = pl.col("Sampling_Date")

# Step 1: derive the image-level id (the part of sample_id before "__")
# and parse the "YYYY/M/D" date string into a datetime.
train = train.with_columns(
    pl.col("sample_id").str.split("__").list.get(0).alias("id"),
    sampling_date.str.to_datetime(format="%Y/%m/%d").alias("Sampling_Date"),
)

# Step 2: expand the parsed date into calendar features.
train = train.with_columns(
    sampling_date.dt.ordinal_day().alias("Sampling_dayofyear"),  # days since the start of the year
    sampling_date.dt.day().alias("Sampling_day"),
    sampling_date.dt.weekday().alias("Sampling_weekday"),
    sampling_date.dt.month().alias("Sampling_month"),
    sampling_date.dt.year().alias("Sampling_year"),
)
sep("process datetime"); show_df(train)

process datetime
(1785, 15)


sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target,id,Sampling_dayofyear,Sampling_day,Sampling_weekday,Sampling_month,Sampling_year
str,str,datetime[μs],str,str,f64,f64,str,f64,str,i16,i8,i8,i8,i32
"""ID1011485656__Dry_Clover_g""","""train/ID1011485656.jpg""",2015-09-04 00:00:00,"""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Clover_g""",0.0,"""ID1011485656""",247,4,5,9,2015
"""ID1011485656__Dry_Dead_g""","""train/ID1011485656.jpg""",2015-09-04 00:00:00,"""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Dead_g""",31.9984,"""ID1011485656""",247,4,5,9,2015
"""ID1011485656__Dry_Green_g""","""train/ID1011485656.jpg""",2015-09-04 00:00:00,"""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Green_g""",16.2751,"""ID1011485656""",247,4,5,9,2015


In [10]:
# Reshape from long format (one row per id x target_name) to wide format
# (one row per id, one column per target_name).
df_pivot = train.pivot(
    on="target_name",  # `columns=` was renamed to `on=` in Polars 1.0; the old name emits a DeprecationWarning
    index="id",
    values="target",
    aggregate_function="first",  # switch to e.g. "mean" if an (id, target_name) pair can repeat
)

# Per-id metadata: take the first value observed within each id group.
meta = train.group_by("id").agg([
    pl.first("Pre_GSHH_NDVI").alias("Pre_GSHH_NDVI"),
    pl.first("Height_Ave_cm").alias("Height_Ave_cm"),
])

df_pivot = df_pivot.join(meta, on="id", how="left")

# Put "id" first; all other columns keep their current order (target columns, then metadata).
# NOTE: despite its name, target_cols also contains the two metadata columns joined above.
target_cols = [c for c in df_pivot.columns if c != "id"]
df_pivot = df_pivot.select(["id"] + target_cols)
show_df(df_pivot)

(357, 8)


  df_pivot = train.pivot(


id,Dry_Clover_g,Dry_Dead_g,Dry_Green_g,Dry_Total_g,GDM_g,Pre_GSHH_NDVI,Height_Ave_cm
str,f64,f64,f64,f64,f64,f64,f64
"""ID1011485656""",0.0,31.9984,16.2751,48.2735,16.275,0.62,4.6667
"""ID1012260530""",0.0,0.0,7.6,7.6,7.6,0.55,16.0
"""ID1025234388""",6.05,0.0,0.0,6.05,6.05,0.38,1.0


In [None]:
# Identify numerical and categorical columns
numetric_cols = train.select(pl.selectors.numeric()).columns
print(f"Numerical columns: {numetric_cols}")

# Identifier / path / timestamp columns are not categorical features.
# "id" is listed next to "sample_id" so the derived per-image identifier is not
# treated (and plotted) as a category.
non_feature_cols = {"sample_id", "id", "image_path", "Sampling_Date"}
# Build the exclusion set once instead of concatenating lists for every column.
excluded_cols = non_feature_cols.union(numetric_cols)
categorical_cols = [col for col in train.columns if col not in excluded_cols]
print(f"Categorical columns: {categorical_cols}")

In [None]:
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
%matplotlib inline

def plot_hist(df, ncols=4, features=None, title="show_features", show_output=True):
    """
    Draw one density histogram (log-scaled y axis) per column on a grid of subplots.

    Non-finite values are removed before plotting. A column is skipped with a
    printed warning when it has no finite value, or when its values cannot be
    tested with ``np.isfinite`` (e.g. string columns), instead of raising.

    Parameters
    ----------
    df : pl.DataFrame
        Input frame; each column is read with ``df[name].to_numpy()``.
    ncols : int
        Number of subplot columns in the grid.
    features : list, optional
        Column names to plot. Defaults to every column of ``df``.
    title : str
        Figure-level title.
    show_output : bool
        When True, ``plt.show()`` is called after drawing.

    Returns
    -------
    None
    """
    if features is None:
        features = df.columns

    # Convert every column exactly once and keep (name, raw values, finite values)
    # for the ones that can be plotted.
    plottable = []
    for col_name in features:
        col_data = df[col_name].to_numpy()
        try:
            finite_mask = np.isfinite(col_data)
        except TypeError:
            # np.isfinite rejects string/object arrays; skip rather than abort the whole figure.
            print(f"警告: '{col_name}' は数値型ではないためスキップします")
            continue
        if finite_mask.any():
            plottable.append((col_name, col_data, col_data[finite_mask]))
        else:
            print(f"警告: '{col_name}' は全てNaNのためスキップします")

    if not plottable:
        print("エラー: 有効なカラムがありません")
        return

    # Grid size: enough rows to hold every plottable column.
    nrows = int(np.ceil(len(plottable) / ncols))

    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid needs no special case.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 3), squeeze=False)
    axes = axes.flatten()

    for ax, (col_name, col_data, col_data_clean) in zip(axes, plottable):
        ax.hist(col_data_clean, bins=60, alpha=0.7, density=True,
                histtype="stepfilled", label="data", log=True)
        ax.set_title(f"{col_name}\n(valid: {len(col_data_clean)}/{len(col_data)})")
        ax.legend()

    # Hide the unused trailing axes of the grid.
    for ax in axes[len(plottable):]:
        ax.set_visible(False)

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()

    if show_output:
        plt.show()

In [None]:
from pylab import rcParams
def show_category_data(df:pd.DataFrame, col:str, title:str)->None:
    """Print frequency statistics for one column and draw a bar chart of its most frequent values.

    Prints the number of unique values and the 20 most frequent values, then
    plots the counts of the 50 most frequent values. The figure size is set on
    the figure itself, so the global matplotlib ``rcParams`` are left untouched.

    Args:
        df (pd.DataFrame): input data
        col (str): name of the column to summarise
        title (str): title of the figure
    """
    # Computed once; value_counts() returns counts sorted from most to least frequent.
    counts = df[col].value_counts()

    print(f'Unique number of {col}: {len(df[col].unique())}')
    print("-"*80); print(f"Frequent appearances of {col}"); print("-"*80); print(f"{counts.head(20)}")

    topn = counts.head(50)

    # Truncate tick labels to 20 characters so long category names do not dominate the figure.
    labels = [str(x)[:20] + ("…" if len(str(x)) > 20 else "") for x in topn.index]

    fig, ax = plt.subplots(figsize=(10, 5))
    topn.plot(kind="bar", ax=ax)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=90)  # use the truncated labels
    ax.set_title(title)
    ax.set_xlabel(col)
    fig.tight_layout()
    plt.show()

In [None]:
# Draw a density histogram (log-scaled y axis) for every numeric column of `train`.
plot_hist(train, ncols=4, features=numetric_cols, title="show_features", show_output=True)

In [None]:
# Convert to pandas once; the conversion does not depend on the column being shown,
# so repeating it on every iteration only copies the whole frame again.
train_pd = train.to_pandas()
for col in categorical_cols:
    sep(col)
    show_category_data(train_pd, col, col)

# IMAGES

In [None]:
import polars as pl
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import math

def _resolve_column(df, requested, keywords, label):
    """Return `requested` if it is a column of `df`; otherwise fall back to the first
    column whose lower-cased name contains one of `keywords`. Returns None (after
    printing an error) when nothing matches."""
    if requested in df.columns:
        return requested

    candidates = [c for c in df.columns if any(k in c.lower() for k in keywords)]
    if candidates:
        print(f"'{requested}' が見つかりません。代わりに '{candidates[0]}' を使用します")
        return candidates[0]

    print(f"エラー: {label}列が見つかりません。利用可能な列: {df.columns}")
    return None


def _draw_image(ax, image_full_path, image_rel_path, row):
    """Show the image at `image_full_path` on `ax`, or a red placeholder text if it cannot be read."""
    try:
        # Context manager closes the file handle; imshow is called while the file is still open.
        with Image.open(image_full_path) as img:
            ax.imshow(img)

        image_id = Path(image_rel_path).stem
        title = f"ID: {image_id}"

        # Append the sampling date when the row carries one.
        if "Sampling_Date" in row:
            title += f"\n{row['Sampling_Date']}"

        ax.set_title(title, fontsize=9)
    except FileNotFoundError:
        ax.text(0.5, 0.5, f"画像なし\n{Path(image_rel_path).name}",
                ha='center', va='center', fontsize=8, color='red')
        ax.set_title(f"ID: {Path(image_rel_path).stem}", fontsize=9)
    except Exception as e:
        # Best effort: one unreadable image must not abort the whole grid.
        ax.text(0.5, 0.5, f"エラー\n{type(e).__name__}",
                ha='center', va='center', fontsize=8, color='red')
        ax.set_title(f"Error: {str(e)[:20]}", fontsize=8)
    ax.axis('off')


def visualize_species_images(
    df: "pl.DataFrame",
    config,
    species_col: str = "Species",
    image_path_col: str = "Image_path",
    max_images: int = 8,
    ncols: int = 4,
    sample_step: int = 5  # sampling interval
):
    """
    Show a grid of sample images for every distinct value of the species column.

    For each species the matching rows are sub-sampled (every `sample_step`-th
    row, at most `max_images` rows) and the referenced images are drawn in a
    grid with `ncols` columns. Unreadable or missing images are replaced by a
    placeholder text.

    Parameters
    ----------
    df : pl.DataFrame
        Frame holding at least a species column and an image-path column.
    config : object
        Settings object; only ``config.DATA_DIR`` is read, as the directory the
        image paths are relative to.
    species_col : str
        Name of the species column. If absent, the first column whose name
        contains "species" (case-insensitive) is used instead.
    image_path_col : str
        Name of the image-path column. If absent, the first column whose name
        contains "image" or "path" (case-insensitive) is used instead.
    max_images : int
        Maximum number of images shown per species.
    ncols : int
        Number of grid columns; must be >= 1.
    sample_step : int
        Sampling interval (e.g. 5 takes one row out of every 5); must be >= 1.

    Returns
    -------
    None
        Also returns None early, after printing an error, when a required
        column cannot be resolved.

    Raises
    ------
    ValueError
        If `sample_step` or `ncols` is smaller than 1.
    """
    # Fail fast with a clear message instead of an obscure error deep inside the loop.
    if sample_step < 1:
        raise ValueError(f"sample_step must be >= 1, got {sample_step}")
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")

    # Resolve the column names, falling back to look-alike columns.
    species_col = _resolve_column(df, species_col, ("species",), "Species")
    if species_col is None:
        return

    image_path_col = _resolve_column(df, image_path_col, ("image", "path"), "画像パス")
    if image_path_col is None:
        return

    print(f"使用する列: Species='{species_col}', ImagePath='{image_path_col}'")
    print(f"サンプリング間隔: {sample_step}行ごと\n")

    unique_species = df[species_col].unique().sort()
    print(f"ユニークなSpecies数: {len(unique_species)}\n")

    for species in unique_species:
        # Rows belonging to this species.
        species_data = df.filter(pl.col(species_col) == species)

        if len(species_data) == 0:
            continue

        # Take every `sample_step`-th row, capped at `max_images`.
        if hasattr(species_data, 'gather_every'):
            sampled = species_data.gather_every(sample_step).head(max_images)
        else:
            # Fallback for Polars versions without gather_every: index explicitly.
            indices = list(range(0, len(species_data), sample_step))[:max_images]
            sampled = species_data[indices]

        print(f"{species}: 総データ数={len(species_data)}, サンプリング後={len(sampled)}")

        if len(sampled) == 0:
            continue

        rows = sampled.to_dicts()

        # Grid size for this species.
        n_images = len(rows)
        nrows = math.ceil(n_images / ncols)

        # squeeze=False always yields a 2-D array of axes, so a 1x1 grid needs no special case.
        fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4), squeeze=False)
        axes = axes.flatten()

        for ax, row in zip(axes, rows):
            image_rel_path = row[image_path_col]
            _draw_image(ax, Path(config.DATA_DIR) / image_rel_path, image_rel_path, row)

        # Blank out the unused trailing cells of the grid.
        for ax in axes[n_images:]:
            ax.axis('off')

        plt.suptitle(f"Species: {species} (表示: {n_images}/{len(species_data)})",
                     fontsize=16, fontweight='bold', y=0.98)
        plt.tight_layout()
        plt.show()
        # Release the figure so a long species list does not accumulate open figures.
        plt.close(fig)

        print(f"✓ {species}: {n_images}枚の画像を表示\n")

# Usage example: sample every 5th row within each species.
# NOTE(review): train has 1785 rows for 357 images, i.e. 5 rows per image, so a step of 5
# presumably yields one row per image — this relies on each image's rows being contiguous; confirm.
visualize_species_images(train, config, max_images=8, ncols=4, sample_step=5)