In [4]:
import pandas as pd
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message="Geometry is in a geographic CRS.*")
warnings.simplefilter(action="ignore", category=pd.errors.SettingWithCopyWarning)

# Load in hybrid boundary stats

In [5]:
def merge_hybrid(shp_path, hybrid_stats_path):
    """
    Merges hybrid boundary shapefile with hybrid statistics.
    
    Parameters:
    shp_path (str): Path to the hybrid boundary shapefile.
    hybrid_stats_path (str): Path to the hybrid statistics CSV file.
    
    Returns:
    gpd.GeoDataFrame: Merged GeoDataFrame containing hybrid boundaries and statistics.
    """
    hybrid = gpd.read_file(shp_path)
    hybrid_stats = pd.read_csv(hybrid_stats_path)
    
    stats = hybrid_stats[['Admin 2', 'Year', 'Area Harvested: ha', 'Yield: MT/ha', 'Quantity Produced: MT', 'Source crop']]
    stats.columns = ["name_state", "year", "area_harvested", "yield", "production", "ag_hy_crop"]
    
    merged_hybrid = stats.merge(hybrid, how='left')
    merged_hybrid = gpd.GeoDataFrame(merged_hybrid, geometry='geometry')
    
    return hybrid, merged_hybrid

In [6]:
shp_path = "../shapefiles/2016_2023_hybrid_boundary_071925.shp" 
hybrid_stats_path = "../../data/temporally_harmonized_crop_data_072425.csv"

In [7]:
hybrid, merged_hybrid = merge_hybrid(shp_path, hybrid_stats_path)


# Load in ICRISAT product

In [8]:
# Load in the 1966 boundaries with names matched to ICRISAT data
icri_path = "../shapefiles/icrisat_apportioned/icrisat_boundary_match.shp"
icri = gpd.read_file(icri_path)

icri_apportioned_path = '../shapefiles/icrisat_apportioned/ICRISAT-District Level Data_Apportioned.csv'
icri_data = pd.read_csv(icri_apportioned_path)

In [9]:
icri.head(10).to_csv("icri_head.csv", index=False)

In [10]:
if hybrid.crs != icri.crs:
    icri = icri.to_crs(hybrid.crs)

colname1 = "name"
colname2 = "Dist_Name"

hybrid['name'] = hybrid[colname1]
icri['name'] = icri[colname2]

In [11]:
#Load in the ICRISAT apportioned data
icri = icri.rename(columns={'Name_12': 'Dist Name', 'NAME_1' : 'State Name'})
icri.loc[icri['State Name']== "Uttaranchal", "State Name"] = "Uttar Pradesh"
merged_icri = gpd.GeoDataFrame(icri.merge(icri_data, how="left", on=['Dist Name', 'State Name']))

# Constrained optimization approach

In [12]:
def transform_all_crops_all_years(merged_hybrid_data, trained_optimizer, years=None):
    """
    Apply trained weights to ALL crops in your data for all years (2016-2022)
    Uses weights trained on Wheat/Soyabean/Barley and applies to all crops
    
    Parameters:
    -----------
    merged_hybrid_data : pd.DataFrame
        Your full dataset with all crops and years
    trained_optimizer : FastSpatialOptimizer
        Your trained optimizer (trained on 3 crops)
    years : list, optional
        Years to transform (defaults to 2016-2022)
    """
    
    if years is None:
        years = [2016, 2017, 2018, 2019, 2020, 2021, 2022]
    
    # Get all available crops from your data
    all_crops = sorted(merged_hybrid_data['ag_hy_crop'].unique())
    
    # Filter to crops that have ICRISAT mappings
    crops_with_mappings = []
    crops_without_mappings = []
    
    for crop in all_crops:
        if crop.upper() in trained_optimizer.crop_mapping:
            crops_with_mappings.append(crop)
        else:
            crops_without_mappings.append(crop)
    
    print("=== TRANSFORMING ALL CROPS TO ICRISAT FORMAT ===")
    print(f"Using weights trained on: Wheat, Soyabean, Barley")
    print(f"Applying to {len(crops_with_mappings)} crops with ICRISAT mappings")
    print(f"Years: {years}")
    
    if crops_without_mappings:
        print(f"\n⚠️  {len(crops_without_mappings)} crops without ICRISAT mappings (will be skipped):")
        for crop in crops_without_mappings[:10]:  # Show first 10
            print(f"    {crop}")
        if len(crops_without_mappings) > 10:
            print(f"    ... and {len(crops_without_mappings)-10} more")
    
    print(f"\n✅ Crops to be transformed:")
    for crop in crops_with_mappings:
        icrisat_crop = trained_optimizer.crop_mapping.get(crop.upper(), crop.upper())
        print(f"  {crop} -> {icrisat_crop}")
    
    # Create reverse mapping from target indices to district names
    target_idx_to_name = {idx: name for name, idx in trained_optimizer.target_name_to_idx.items()}
    
    all_results = []
    
    print(f"\n=== PROCESSING YEARS ===")
    for year in years:
        print(f"\nProcessing {year}...")
        
        year_totals = {'original_area': 0, 'original_prod': 0, 'transformed_area': 0, 'transformed_prod': 0}
        
        for crop in crops_with_mappings:
            # Prepare source data for this crop and year
            try:
                crop_data = trained_optimizer.prepare_crop_data_vectors_only(merged_hybrid_data, crop, year)
                
                # Transform using trained weights (from Wheat/Soyabean/Barley)
                transformed_area = trained_optimizer.weight_matrix_area @ crop_data['source_area']
                transformed_production = trained_optimizer.weight_matrix_production @ crop_data['source_production']
                
                # Get ICRISAT crop name
                icrisat_crop = trained_optimizer.crop_mapping.get(crop.upper(), crop.upper())
                
                # Track totals for conservation check
                orig_area = crop_data['source_area'].sum()
                orig_prod = crop_data['source_production'].sum()
                trans_area = transformed_area.sum()
                trans_prod = transformed_production.sum()
                
                year_totals['original_area'] += orig_area
                year_totals['original_prod'] += orig_prod
                year_totals['transformed_area'] += trans_area
                year_totals['transformed_prod'] += trans_prod
                
                # Create results for this crop
                for target_idx in range(trained_optimizer.n_target):
                    district_name = target_idx_to_name[target_idx]
                    
                    area_val = transformed_area[target_idx]
                    prod_val = transformed_production[target_idx]
                    
                    all_results.append({
                        'district_name': district_name,
                        'year': year,
                        'crop': crop,
                        'icrisat_crop': icrisat_crop,
                        'area_ha': area_val,
                        'production_tons': prod_val,
                        'yield_kg_per_ha': prod_val / area_val if area_val > 0 else 0
                    })
                
            except Exception as e:
                print(f"    ⚠️  Error processing {crop}: {e}")
        
        # Print year-level conservation summary
        area_conservation = abs(year_totals['transformed_area'] - year_totals['original_area']) / year_totals['original_area'] * 100 if year_totals['original_area'] > 0 else 0
        prod_conservation = abs(year_totals['transformed_prod'] - year_totals['original_prod']) / year_totals['original_prod'] * 100 if year_totals['original_prod'] > 0 else 0
        
        print(f"  Year {year} totals: Area conservation error: {area_conservation:.2f}%, Prod conservation error: {prod_conservation:.2f}%")
    
    # Convert to DataFrame
    result_df = pd.DataFrame(all_results)
    
    print(f"\n✅ ALL CROPS TRANSFORMATION COMPLETE!")
    print(f"Output shape: {result_df.shape}")
    print(f"Years: {sorted(result_df['year'].unique())}")
    print(f"Crops transformed: {len(result_df['crop'].unique())}")
    print(f"Districts: {len(result_df['district_name'].unique())}")
    
    # Summary by crop
    print(f"\nCrop summary:")
    crop_summary = result_df.groupby('crop').agg({
        'area_ha': 'sum',
        'production_tons': 'sum'
    }).round(0)
    
    for crop in crop_summary.index[:10]:  # Show top 10 crops by area
        area = crop_summary.loc[crop, 'area_ha']
        prod = crop_summary.loc[crop, 'production_tons']
        print(f"  {crop}: {area:,.0f} ha, {prod:,.0f} tons")
    
    if len(crop_summary) > 10:
        print(f"  ... and {len(crop_summary)-10} more crops")
    
    return result_df

def create_full_icrisat_format(transformed_long_df):
    """
    Convert the long format (all crops) to ICRISAT wide format
    
    Parameters:
    -----------
    transformed_long_df : pd.DataFrame
        Output from transform_all_crops_all_years()
    """
    
    print("=== CREATING FULL ICRISAT FORMAT ===")
    
    # Get unique combinations
    years = sorted(transformed_long_df['year'].unique())
    districts = sorted(transformed_long_df['district_name'].unique())
    crops = sorted(transformed_long_df['icrisat_crop'].unique())
    
    print(f"Years: {len(years)}")
    print(f"Districts: {len(districts)}")  
    print(f"Crops: {len(crops)}")
    
    all_rows = []
    
    for year in years:
        print(f"Processing {year}...")
        
        for district in districts:
            # Start row with district info
            row = {
                'name': district,
                'Year': year
            }
            
            # Get data for this district-year
            district_year_data = transformed_long_df[
                (transformed_long_df['district_name'] == district) & 
                (transformed_long_df['year'] == year)
            ]
            
            # Add each crop's data as columns
            for _, crop_row in district_year_data.iterrows():
                icrisat_crop = crop_row['icrisat_crop']
                
                # Create column names in ICRISAT format
                area_col = f'{icrisat_crop} AREA (1000 ha)'
                prod_col = f'{icrisat_crop} PRODUCTION (1000 tons)'
                yield_col = f'{icrisat_crop} YIELD (Kg per ha)'
                
                # Convert back to ICRISAT units (divide by 1000)
                row[area_col] = crop_row['area_ha'] / 1000
                row[prod_col] = crop_row['production_tons'] / 1000
                row[yield_col] = crop_row['yield_kg_per_ha']
            
            all_rows.append(row)
    
    result_df = pd.DataFrame(all_rows)
    
    print(f"✅ FULL ICRISAT FORMAT CREATED!")
    print(f"Shape: {result_df.shape}")
    print(f"Columns: {len(result_df.columns)} (name, Year + {len(result_df.columns)-2} crop columns)")
    
    return result_df

In [13]:
import pandas as pd
import numpy as np
import geopandas as gpd
import warnings
warnings.filterwarnings('ignore')

class FastSpatialOptimizer:
    """
    Fast spatial weight optimizer using iterative proportional fitting (IPF)
    Much faster and more stable than constrained optimization
    """
    
    def __init__(self):
        self.intersection_mask = None
        self.initial_weights = None
        self.weight_matrix_area = None
        self.weight_matrix_production = None
        self.crop_mapping = None
        
    def load_crop_mapping(self, crop_key_df):
        """Load crop mapping from your cropkey.csv"""
        self.crop_mapping = dict(zip(
            crop_key_df['crop'].str.upper(), 
            crop_key_df['icrisat_crop'].str.upper()
        ))
        print(f"Loaded {len(self.crop_mapping)} crop mappings")
        
    def compute_intersections(self, source_boundaries, target_boundaries):
        """Compute spatial intersections - same as before but faster"""
        print("Computing spatial intersections...")
        
        # Ensure same CRS
        if source_boundaries.crs != target_boundaries.crs:
            target_boundaries = target_boundaries.to_crs(source_boundaries.crs)
        
        n_target = len(target_boundaries)
        n_source = len(source_boundaries)
        
        print(f"Source districts: {n_source}")
        print(f"Target districts: {n_target}")
        
        # Initialize sparse intersection matrix
        intersection_weights = {}  # (i,j): weight
        
        # Create mappings
        source_name_to_idx = {row['name_state']: idx for idx, row in source_boundaries.iterrows()}
        target_name_to_idx = {row['name']: idx for idx, row in target_boundaries.iterrows()}
        
        # Store the mappings
        self.source_name_to_idx = source_name_to_idx
        self.target_name_to_idx = target_name_to_idx
        self.n_source = n_source
        self.n_target = n_target
        
        # Compute intersections - only store non-zero weights
        intersection_count = 0
        for i, target_row in target_boundaries.iterrows():
            target_geom = target_row.geometry
            target_idx = target_name_to_idx[target_row['name']]
            
            for j, source_row in source_boundaries.iterrows():
                source_geom = source_row.geometry
                source_idx = source_name_to_idx[source_row['name_state']]
                
                if target_geom.intersects(source_geom):
                    intersect_area = target_geom.intersection(source_geom).area
                    source_area = source_geom.area
                    if intersect_area > 0:
                        weight = intersect_area / source_area
                        intersection_weights[(target_idx, source_idx)] = weight
                        intersection_count += 1
        
        self.intersection_weights = intersection_weights
        print(f"Found {intersection_count} spatial intersections")
        return intersection_weights
    
    def prepare_crop_data(self, source_data, target_data, crop_name, year):
        """Prepare data for one crop and year - same as before"""
        
        # Get ICRISAT crop name
        if self.crop_mapping is None:
            raise ValueError("Must load crop mapping first")
            
        icrisat_crop = self.crop_mapping.get(crop_name.upper(), crop_name.upper())
        
        # Filter source data
        source_subset = source_data[
            (source_data['year'] == year) & 
            (source_data['ag_hy_crop'] == crop_name)
        ].copy()
        
        # Filter target data  
        target_subset = target_data[target_data['Year'] == year].copy()
        
        # Create vectors
        source_area = np.zeros(self.n_source)
        source_production = np.zeros(self.n_source)
        target_area = np.zeros(self.n_target)
        target_production = np.zeros(self.n_target)
        
        # Fill source vectors
        for _, row in source_subset.iterrows():
            if row['name_state'] in self.source_name_to_idx:
                idx = self.source_name_to_idx[row['name_state']]
                source_area[idx] = row['area_harvested'] if pd.notna(row['area_harvested']) else 0
                source_production[idx] = row['production'] if pd.notna(row['production']) else 0
        
        # ICRISAT column names
        area_col = f'{icrisat_crop} AREA (1000 ha)'
        prod_col = f'{icrisat_crop} PRODUCTION (1000 tons)'
        
        # Fill target vectors and convert units
        for _, row in target_subset.iterrows():
            if row['name'] in self.target_name_to_idx:
                idx = self.target_name_to_idx[row['name']]
                if area_col in target_subset.columns and pd.notna(row[area_col]):
                    target_area[idx] = row[area_col] * 1000  # Convert to ha
                if prod_col in target_subset.columns and pd.notna(row[prod_col]):
                    target_production[idx] = row[prod_col] * 1000  # Convert to tons
        
        print(f"  {crop_name} -> {icrisat_crop}")
        print(f"    Source area: {source_area.sum():.0f}, target area: {target_area.sum():.0f}")
        print(f"    Source prod: {source_production.sum():.0f}, target prod: {target_production.sum():.0f}")
        
        return {
            'source_area': source_area,
            'source_production': source_production,
            'target_area': target_area, 
            'target_production': target_production,
            'icrisat_crop': icrisat_crop
        }
    
    def prepare_crop_data_vectors_only(self, source_data, crop_name, year):
        """
        Helper function to prepare just the source vectors for transformation
        (doesn't need target data since we're just transforming)
        """
        
        # Filter source data
        source_subset = source_data[
            (source_data['year'] == year) & 
            (source_data['ag_hy_crop'] == crop_name)
        ].copy()
        
        # Create source vectors
        source_area = np.zeros(self.n_source)
        source_production = np.zeros(self.n_source)
        
        # Fill source vectors
        for _, row in source_subset.iterrows():
            if row['name_state'] in self.source_name_to_idx:
                idx = self.source_name_to_idx[row['name_state']]
                source_area[idx] = row['area_harvested'] if pd.notna(row['area_harvested']) else 0
                source_production[idx] = row['production'] if pd.notna(row['production']) else 0
        
        return {
            'source_area': source_area,
            'source_production': source_production
        }
    
    def optimize_weights_ipf(self, crop_data_dict, data_type='area', max_iterations=50):
        """
        Optimize weights using Iterative Proportional Fitting (IPF)
        Much faster and more stable than constrained optimization
        """
        
        crops = list(crop_data_dict.keys())
        n_crops = len(crops)
        
        print(f"Optimizing {data_type} weights using IPF...")
        print(f"  Crops: {crops}")
        
        # Prepare data matrices
        source_matrix = np.zeros((self.n_source, n_crops))
        target_matrix = np.zeros((self.n_target, n_crops))
        
        for i, crop in enumerate(crops):
            if data_type == 'area':
                source_matrix[:, i] = crop_data_dict[crop]['source_area']
                target_matrix[:, i] = crop_data_dict[crop]['target_area']
            else:
                source_matrix[:, i] = crop_data_dict[crop]['source_production']
                target_matrix[:, i] = crop_data_dict[crop]['target_production']
        
        # Check for valid crops
        valid_crops = []
        valid_cols = []
        for i, crop in enumerate(crops):
            if source_matrix[:, i].sum() > 0 and target_matrix[:, i].sum() > 0:
                valid_crops.append(crop)
                valid_cols.append(i)
        
        if len(valid_crops) == 0:
            print("❌ No valid crops")
            return np.zeros((self.n_target, self.n_source))
        
        print(f"  Valid crops: {valid_crops}")
        
        # Use only valid crops
        source_matrix = source_matrix[:, valid_cols]
        target_matrix = target_matrix[:, valid_cols]
        
        # Initialize weight matrix with area proportions
        W = np.zeros((self.n_target, self.n_source))
        for (i, j), weight in self.intersection_weights.items():
            W[i, j] = weight
        
        # Normalize so each source column sums to 1
        col_sums = W.sum(axis=0)
        for j in range(self.n_source):
            if col_sums[j] > 0:
                W[:, j] = W[:, j] / col_sums[j]
        
        print(f"  Starting IPF with {len(self.intersection_weights)} non-zero weights...")
        
        # IPF iterations
        for iteration in range(max_iterations):
            W_old = W.copy()
            
            # For each crop, adjust weights to match target totals
            for crop_idx in range(len(valid_crops)):
                source_vec = source_matrix[:, crop_idx]
                target_vec = target_matrix[:, crop_idx]
                
                # Skip if no data
                source_total = source_vec.sum()
                target_total = target_vec.sum()
                if source_total == 0 or target_total == 0:
                    continue
                
                # Current prediction
                predicted = W @ source_vec
                predicted_total = predicted.sum()
                
                # Adjust weights proportionally to match target total
                if predicted_total > 0:
                    adjustment_factor = target_total / predicted_total
                    
                    # Apply adjustment only to non-zero weights
                    for (i, j) in self.intersection_weights.keys():
                        if source_vec[j] > 0:  # Only adjust if source has data
                            W[i, j] *= adjustment_factor
            
            # Check convergence
            weight_change = np.max(np.abs(W - W_old))
            if weight_change < 1e-6:
                print(f"  ✅ IPF converged after {iteration + 1} iterations")
                break
            
            if (iteration + 1) % 10 == 0:
                print(f"    Iteration {iteration + 1}, max weight change: {weight_change:.6f}")
        
        else:
            print(f"  ⚠️  IPF reached max iterations ({max_iterations})")
        
        # Store results
        if data_type == 'area':
            self.weight_matrix_area = W
        else:
            self.weight_matrix_production = W
        
        # Calculate and print final results
        predicted_matrix = W @ source_matrix
        
        print(f"  Final Results:")
        total_rmse = 0
        for i, crop in enumerate(valid_crops):
            predicted = predicted_matrix[:, i]
            actual = target_matrix[:, i]
            
            rmse = np.sqrt(np.mean((predicted - actual)**2))
            total_rmse += rmse
            
            source_total = source_matrix[:, i].sum()
            target_total = actual.sum()
            predicted_total = predicted.sum()
            conservation_error = abs(predicted_total - target_total)
            conservation_pct = conservation_error / target_total * 100 if target_total > 0 else 0
            
            print(f"    {crop}: RMSE={rmse:.0f}, Conservation error={conservation_error:.0f} ({conservation_pct:.2f}%)")
            print(f"      Source: {source_total:.0f}, Target: {target_total:.0f}, Predicted: {predicted_total:.0f}")
        
        print(f"  Average RMSE: {total_rmse/len(valid_crops):.0f}")
        return W
    
    def validate_on_test_year(self, source_data, target_data, crops, test_year):
        """Test the trained weights on a different year"""
        
        print(f"\n=== TESTING ON YEAR {test_year} ===")
        
        # Prepare test data
        test_crop_data = {}
        for crop in crops:
            test_crop_data[crop] = self.prepare_crop_data(source_data, target_data, crop, test_year)
        
        # Test area weights
        if self.weight_matrix_area is not None:
            print("\nTesting AREA weights:")
            self._test_weights(test_crop_data, 'area')
        
        # Test production weights  
        if self.weight_matrix_production is not None:
            print("\nTesting PRODUCTION weights:")
            self._test_weights(test_crop_data, 'production')
    
    def _test_weights(self, crop_data_dict, data_type):
        """Helper function to test weights"""
        
        crops = list(crop_data_dict.keys())
        
        # Get weight matrix
        W = self.weight_matrix_area if data_type == 'area' else self.weight_matrix_production
        
        # Prepare test data
        source_matrix = np.zeros((self.n_source, len(crops)))
        target_matrix = np.zeros((self.n_target, len(crops)))
        
        for i, crop in enumerate(crops):
            if data_type == 'area':
                source_matrix[:, i] = crop_data_dict[crop]['source_area']
                target_matrix[:, i] = crop_data_dict[crop]['target_area']
            else:
                source_matrix[:, i] = crop_data_dict[crop]['source_production']
                target_matrix[:, i] = crop_data_dict[crop]['target_production']
        
        # Predict
        predicted_matrix = W @ source_matrix
        
        # Calculate metrics
        print(f"  Test Results:")
        total_rmse = 0
        valid_crops = 0
        
        for i, crop in enumerate(crops):
            predicted = predicted_matrix[:, i]
            actual = target_matrix[:, i]
            
            if actual.sum() > 0:
                rmse = np.sqrt(np.mean((predicted - actual)**2))
                total_rmse += rmse
                valid_crops += 1
                conservation_error = abs(predicted.sum() - actual.sum())
                conservation_pct = conservation_error / actual.sum() * 100
                
                print(f"    {crop}: RMSE={rmse:.0f}, Conservation error={conservation_error:.0f} ({conservation_pct:.2f}%)")
        
        if valid_crops > 0:
            print(f"  Average Test RMSE: {total_rmse/valid_crops:.0f}")

def run_fast_optimization(merged_hybrid_data, merged_icri_data, hybrid_boundaries, icri_boundaries, crop_key_df):
    """
    Fast workflow using IPF for Wheat, Soyabean, and Barley
    """
    
    crops = ['Wheat', 'Soyabean', 'Barley']
    
    print("=== FAST SPATIAL WEIGHT OPTIMIZATION ===")
    print(f"Target crops: {crops}")
    print("Using Iterative Proportional Fitting (IPF) - much faster!")
    
    # Initialize
    optimizer = FastSpatialOptimizer()
    optimizer.load_crop_mapping(crop_key_df)
    
    # Compute intersections
    optimizer.compute_intersections(hybrid_boundaries, icri_boundaries)
    
    # Prepare training data (2016)
    print(f"\n=== PREPARING TRAINING DATA (2016) ===")
    train_crop_data = {}
    for crop in crops:
        train_crop_data[crop] = optimizer.prepare_crop_data(merged_hybrid_data, merged_icri_data, crop, 2016)
    
    # Train using IPF - much faster!
    print(f"\n=== TRAINING AREA WEIGHTS ===")
    area_weights = optimizer.optimize_weights_ipf(train_crop_data, 'area')
    
    print(f"\n=== TRAINING PRODUCTION WEIGHTS ===") 
    prod_weights = optimizer.optimize_weights_ipf(train_crop_data, 'production')
    
    # Test on 2017
    optimizer.validate_on_test_year(merged_hybrid_data, merged_icri_data, crops, 2017)
    
    print(f"\n=== OPTIMIZATION COMPLETE ===")
    
    return {
        'optimizer': optimizer,
        'area_weights': area_weights,
        'production_weights': prod_weights,
        'crops': crops
    }

def transform_all_crops_all_years(merged_hybrid_data, trained_optimizer, years=None):
    """
    Apply trained weights to ALL crops in your data for all years (2016-2022)
    Uses weights trained on Wheat/Soyabean/Barley and applies to all crops
    """
    
    if years is None:
        years = [2016, 2017, 2018, 2019, 2020, 2021, 2022]
    
    # Get all available crops from your data
    all_crops = sorted(merged_hybrid_data['ag_hy_crop'].unique())
    
    # Filter to crops that have ICRISAT mappings
    crops_with_mappings = []
    crops_without_mappings = []
    
    for crop in all_crops:
        if crop.upper() in trained_optimizer.crop_mapping:
            crops_with_mappings.append(crop)
        else:
            crops_without_mappings.append(crop)
    
    print("=== TRANSFORMING ALL CROPS TO ICRISAT FORMAT ===")
    print(f"Using weights trained on: Wheat, Soyabean, Barley")
    print(f"Applying to {len(crops_with_mappings)} crops with ICRISAT mappings")
    print(f"Years: {years}")
    
    if crops_without_mappings:
        print(f"\n⚠️  {len(crops_without_mappings)} crops without ICRISAT mappings (will be skipped):")
        for crop in crops_without_mappings[:10]:  # Show first 10
            print(f"    {crop}")
        if len(crops_without_mappings) > 10:
            print(f"    ... and {len(crops_without_mappings)-10} more")
    
    print(f"\n✅ Crops to be transformed:")
    for crop in crops_with_mappings:
        icrisat_crop = trained_optimizer.crop_mapping.get(crop.upper(), crop.upper())
        print(f"  {crop} -> {icrisat_crop}")
    
    # Create reverse mapping from target indices to district names
    target_idx_to_name = {idx: name for name, idx in trained_optimizer.target_name_to_idx.items()}
    
    all_results = []
    
    print(f"\n=== PROCESSING YEARS ===")
    for year in years:
        print(f"\nProcessing {year}...")
        
        year_totals = {'original_area': 0, 'original_prod': 0, 'transformed_area': 0, 'transformed_prod': 0}
        
        for crop in crops_with_mappings:
            # Prepare source data for this crop and year
            try:
                crop_data = trained_optimizer.prepare_crop_data_vectors_only(merged_hybrid_data, crop, year)
                
                # Transform using trained weights (from Wheat/Soyabean/Barley)
                transformed_area = trained_optimizer.weight_matrix_area @ crop_data['source_area']
                transformed_production = trained_optimizer.weight_matrix_production @ crop_data['source_production']
                
                # Get ICRISAT crop name
                icrisat_crop = trained_optimizer.crop_mapping.get(crop.upper(), crop.upper())
                
                # Track totals for conservation check
                orig_area = crop_data['source_area'].sum()
                orig_prod = crop_data['source_production'].sum()
                trans_area = transformed_area.sum()
                trans_prod = transformed_production.sum()
                
                year_totals['original_area'] += orig_area
                year_totals['original_prod'] += orig_prod
                year_totals['transformed_area'] += trans_area
                year_totals['transformed_prod'] += trans_prod
                
                # Create results for this crop
                for target_idx in range(trained_optimizer.n_target):
                    district_name = target_idx_to_name[target_idx]
                    
                    area_val = transformed_area[target_idx]
                    prod_val = transformed_production[target_idx]
                    
                    all_results.append({
                        'district_name': district_name,
                        'year': year,
                        'crop': crop,
                        'icrisat_crop': icrisat_crop,
                        'area_ha': area_val,
                        'production_tons': prod_val,
                        'yield_kg_per_ha': prod_val / area_val if area_val > 0 else 0
                    })
                
            except Exception as e:
                print(f"    ⚠️  Error processing {crop}: {e}")
        
        # Print year-level conservation summary
        area_conservation = abs(year_totals['transformed_area'] - year_totals['original_area']) / year_totals['original_area'] * 100 if year_totals['original_area'] > 0 else 0
        prod_conservation = abs(year_totals['transformed_prod'] - year_totals['original_prod']) / year_totals['original_prod'] * 100 if year_totals['original_prod'] > 0 else 0
        
        print(f"  Year {year} totals: Area conservation error: {area_conservation:.2f}%, Prod conservation error: {prod_conservation:.2f}%")
    
    # Convert to DataFrame
    result_df = pd.DataFrame(all_results)
    
    print(f"\n✅ ALL CROPS TRANSFORMATION COMPLETE!")
    print(f"Output shape: {result_df.shape}")
    print(f"Years: {sorted(result_df['year'].unique())}")
    print(f"Crops transformed: {len(result_df['crop'].unique())}")
    print(f"Districts: {len(result_df['district_name'].unique())}")
    
    # Summary by crop
    print(f"\nCrop summary:")
    crop_summary = result_df.groupby('crop').agg({
        'area_ha': 'sum',
        'production_tons': 'sum'
    }).round(0)
    
    for crop in crop_summary.index[:10]:  # Show top 10 crops by area
        area = crop_summary.loc[crop, 'area_ha']
        prod = crop_summary.loc[crop, 'production_tons']
        print(f"  {crop}: {area:,.0f} ha, {prod:,.0f} tons")
    
    if len(crop_summary) > 10:
        print(f"  ... and {len(crop_summary)-10} more crops")
    
    return result_df

# Example usage
if __name__ == "__main__":
    print("Complete FastSpatialOptimizer class for spatial weight optimization")
    print()
    print("Main functions:")
    print("1. run_fast_optimization() - Train weights on 3 crops")
    print("2. transform_all_crops_all_years() - Apply to all crops")
    print()
    print("Usage:")
    print("# Train")
    print("results = run_fast_optimization(merged_hybrid, merged_icri, hybrid_bounds, icri_bounds, crop_key)")
    print("# Transform all")
    print("all_data = transform_all_crops_all_years(merged_hybrid, results['optimizer'])")

Complete FastSpatialOptimizer class for spatial weight optimization

Main functions:
1. run_fast_optimization() - Train weights on 3 crops
2. transform_all_crops_all_years() - Apply to all crops

Usage:
# Train
results = run_fast_optimization(merged_hybrid, merged_icri, hybrid_bounds, icri_bounds, crop_key)
# Transform all
all_data = transform_all_crops_all_years(merged_hybrid, results['optimizer'])


In [14]:
crop_key = pd.read_csv('../../data/cropkey.csv')  # Load your cropkey.csv with crop mappings

In [15]:
# Run the complete workflow
results = run_fast_optimization(
    merged_hybrid, merged_icri, hybrid, icri, crop_key
)

# Get your trained weight matrices
area_weights = results['area_weights']        # For transforming area data
production_weights = results['production_weights']  # For transforming production data

=== FAST SPATIAL WEIGHT OPTIMIZATION ===
Target crops: ['Wheat', 'Soyabean', 'Barley']
Using Iterative Proportional Fitting (IPF) - much faster!
Loaded 48 crop mappings
Computing spatial intersections...
Source districts: 676
Target districts: 310
Found 2614 spatial intersections

=== PREPARING TRAINING DATA (2016) ===
  Wheat -> WHEAT
    Source area: 31763623, target area: 31486790
    Source prod: 113458953, target prod: 112962820
  Soyabean -> SOYABEAN
    Source area: 11173537, target area: 11134020
    Source prod: 13182453, target prod: 13133130
  Barley -> BARLEY
    Source area: 654021, target area: 645530
    Source prod: 1931442, target prod: 1923680

=== TRAINING AREA WEIGHTS ===
Optimizing area weights using IPF...
  Crops: ['Wheat', 'Soyabean', 'Barley']
  Valid crops: ['Wheat', 'Soyabean', 'Barley']
  Starting IPF with 2614 non-zero weights...
    Iteration 10, max weight change: 0.000702
    Iteration 20, max weight change: 0.000034
    Iteration 30, max weight change: 

In [16]:
# Now apply to ALL crops:
all_crops_data = transform_all_crops_all_years(
    merged_hybrid,           # Your full dataset with all crops
    results['optimizer'] # Your trained optimizer
)

=== TRANSFORMING ALL CROPS TO ICRISAT FORMAT ===
Using weights trained on: Wheat, Soyabean, Barley
Applying to 47 crops with ICRISAT mappings
Years: [2016, 2017, 2018, 2019, 2020, 2021, 2022]

⚠️  1 crops without ICRISAT mappings (will be skipped):
    Turmeric

✅ Crops to be transformed:
  Arecanut -> nan
  Arhar/Tur -> PIGEONPEA
  Bajra -> PEARL MILLET
  Banana -> nan
  Barley -> BARLEY
  Black pepper -> nan
  Cardamom -> nan
  Cashewnut -> nan
  Castor seed -> CASTOR
  Coconut -> nan
  Coriander -> nan
  Cotton(lint) -> COTTON
  Cowpea(Lobia) -> nan
  Dry chillies -> nan
  Garlic -> nan
  Ginger -> nan
  Groundnut -> GROUNDNUT
  Guar seed -> nan
  Horse-gram -> nan
  Jowar -> SORGHUM
  Jute -> nan
  Khesari -> nan
  Linseed -> nan
  Maize -> MAIZE
  Masoor -> nan
  Mesta -> nan
  Moong(Green Gram) -> nan
  Moth -> nan
  Niger seed -> nan
  Onion -> ONION
  Other Cereals -> nan
  Peas & beans (Pulses) -> nan
  Potato -> POTATOES
  Ragi -> FINGER MILLET
  Rapeseed &Mustard -> RAPESEED

In [17]:
import pandas as pd
import numpy as np

def clean_and_enhance_output(transformed_df, crop_key_df):
    """
    Post-processing: Keep only crops with ICRISAT mappings and add ICRISAT crop names
    
    Parameters:
    -----------
    transformed_df : pd.DataFrame
        Output from transform_all_crops_all_years() 
    crop_key_df : pd.DataFrame
        Your crop mapping (cropkey.csv)
        
    Returns:
    --------
    pd.DataFrame: Cleaned data with only mapped crops and ICRISAT names
    """
    
    print("=== CLEANING AND ENHANCING OUTPUT ===")
    
    # Create mapping dictionary
    crop_mapping = dict(zip(
        crop_key_df['crop'].str.strip().str.upper(), 
        crop_key_df['icrisat_crop'].str.strip().str.upper()
    ))
    
    print(f"Original data: {len(transformed_df):,} rows, {len(transformed_df['crop'].unique())} crops")
    
    # Filter to only crops that have mappings
    crops_with_mappings = set(crop_mapping.keys())
    available_crops = set(transformed_df['crop'].str.upper())
    
    # Find intersection
    valid_crops = crops_with_mappings.intersection(available_crops)
    crops_to_keep = [crop for crop in transformed_df['crop'].unique() 
                     if crop.upper() in valid_crops]
    
    print(f"Crops with ICRISAT mappings: {len(crops_to_keep)}")
    
    # Filter dataframe
    filtered_df = transformed_df[transformed_df['crop'].isin(crops_to_keep)].copy()
    
    # Add ICRISAT crop names as a separate column
    filtered_df['icrisat_crop_name'] = filtered_df['crop'].str.upper().map(crop_mapping)
    
    # Reorder columns to put ICRISAT crop name next to original crop name
    cols = ['district_name', 'year', 'crop', 'icrisat_crop_name', 'area_ha', 'production_tons', 'yield_kg_per_ha']
    if 'icrisat_crop' in filtered_df.columns:
        # Remove the old icrisat_crop column if it exists
        filtered_df = filtered_df.drop('icrisat_crop', axis=1)
    
    filtered_df = filtered_df[cols]
    
    print(f"Final data: {len(filtered_df):,} rows, {len(filtered_df['crop'].unique())} crops")
    
    # Show crop mapping summary
    print(f"\nCrop mappings applied:")
    crop_summary = filtered_df.groupby(['crop', 'icrisat_crop_name']).size().reset_index(name='observations')
    for _, row in crop_summary.iterrows():
        print(f"  {row['crop']} -> {row['icrisat_crop_name']} ({row['observations']:,} obs)")
    
    # Check for any missing mappings
    missing_mappings = filtered_df[filtered_df['icrisat_crop_name'].isna()]
    if len(missing_mappings) > 0:
        print(f"\n⚠️  {len(missing_mappings)} rows with missing ICRISAT mappings")
        missing_crops = missing_mappings['crop'].unique()
        print(f"Missing crops: {list(missing_crops)}")
    
    return filtered_df

def create_clean_icrisat_format(cleaned_long_df):
    """
    Create ICRISAT wide format from cleaned long format data
    
    Parameters:
    -----------
    cleaned_long_df : pd.DataFrame
        Output from clean_and_enhance_output()
    """
    
    print("=== CREATING CLEAN ICRISAT FORMAT ===")
    
    # Get unique combinations
    years = sorted(cleaned_long_df['year'].unique())
    districts = sorted(cleaned_long_df['district_name'].unique())
    icrisat_crops = sorted(cleaned_long_df['icrisat_crop_name'].unique())
    
    print(f"Years: {len(years)}")
    print(f"Districts: {len(districts)}")  
    print(f"ICRISAT crops: {len(icrisat_crops)}")
    
    all_rows = []
    
    for year in years:
        for district in districts:
            # Start row with district info
            row = {
                'name': district,
                'Year': year
            }
            
            # Get data for this district-year
            district_year_data = cleaned_long_df[
                (cleaned_long_df['district_name'] == district) & 
                (cleaned_long_df['year'] == year)
            ]
            
            # Add each crop's data as columns using ICRISAT names
            for _, crop_row in district_year_data.iterrows():
                icrisat_crop = crop_row['icrisat_crop_name']
                
                # Create column names in ICRISAT format
                area_col = f'{icrisat_crop} AREA (1000 ha)'
                prod_col = f'{icrisat_crop} PRODUCTION (1000 tons)'
                yield_col = f'{icrisat_crop} YIELD (Kg per ha)'
                
                # Convert back to ICRISAT units (divide by 1000)
                row[area_col] = crop_row['area_ha'] / 1000
                row[prod_col] = crop_row['production_tons'] / 1000
                row[yield_col] = crop_row['yield_kg_per_ha']
            
            all_rows.append(row)
    
    result_df = pd.DataFrame(all_rows)
    
    print(f"✅ CLEAN ICRISAT FORMAT CREATED!")
    print(f"Shape: {result_df.shape}")
    print(f"Columns: {len(result_df.columns)} (name, Year + {len(result_df.columns)-2} crop columns)")
    
    # Show sample of column names
    crop_cols = [col for col in result_df.columns if col not in ['name', 'Year']]
    print(f"Sample crop columns: {crop_cols[:6]}")
    if len(crop_cols) > 6:
        print(f"... and {len(crop_cols)-6} more")
    
    return result_df

def verify_district_alignment(cleaned_data, icrisat_boundaries):
    """
    Verify that district names in your transformed data match ICRISAT boundaries
    
    Parameters:
    -----------
    cleaned_data : pd.DataFrame
        Output from clean_and_enhance_output()
    icrisat_boundaries : geopandas.GeoDataFrame
        Your ICRISAT boundaries with 'name' column
    """
    
    print("=== VERIFYING DISTRICT ALIGNMENT ===")
    
    # Get district names from both
    final_districts = set(cleaned_data['district_name'].unique())
    icrisat_districts = set(icrisat_boundaries['name'].unique())
    
    print(f"Districts in transformed data: {len(final_districts)}")
    print(f"Districts in ICRISAT boundaries: {len(icrisat_districts)}")
    
    # Check alignment
    perfect_match = final_districts == icrisat_districts
    print(f"Perfect match: {perfect_match}")
    
    if not perfect_match:
        missing_in_boundaries = final_districts - icrisat_districts
        missing_in_data = icrisat_districts - final_districts
        
        if missing_in_boundaries:
            print(f"\nDistricts in data but not in boundaries: {len(missing_in_boundaries)}")
            for district in list(missing_in_boundaries)[:5]:
                print(f"  {district}")
            if len(missing_in_boundaries) > 5:
                print(f"  ... and {len(missing_in_boundaries)-5} more")
        
        if missing_in_data:
            print(f"\nDistricts in boundaries but not in data: {len(missing_in_data)}")
            for district in list(missing_in_data)[:5]:
                print(f"  {district}")
            if len(missing_in_data) > 5:
                print(f"  ... and {len(missing_in_data)-5} more")
    
    overlap = len(final_districts.intersection(icrisat_districts))
    overlap_pct = overlap / len(icrisat_districts) * 100
    print(f"\nOverlap: {overlap}/{len(icrisat_districts)} districts ({overlap_pct:.1f}%)")
    
    if overlap_pct > 95:
        print("✅ Excellent alignment - ready for mapping!")
    elif overlap_pct > 90:
        print("⚠️  Good alignment - minor mismatches")
    else:
        print("❌ Poor alignment - check district name matching")
    
    return {
        'perfect_match': perfect_match,
        'overlap_count': overlap,
        'overlap_percentage': overlap_pct,
        'missing_in_boundaries': missing_in_boundaries if not perfect_match else set(),
        'missing_in_data': missing_in_data if not perfect_match else set()
    }

# Example usage
if __name__ == "__main__":
    print("Post-processing functions for crop data transformation")
    print()
    print("Usage:")
    print("1. clean_and_enhance_output(transformed_df, crop_key_df)")
    print("2. create_clean_icrisat_format(cleaned_long_df)")
    print("3. verify_district_alignment(cleaned_data, icrisat_boundaries)")
    print()
    print("Example:")
    print("# Clean the data")
    print("cleaned = clean_and_enhance_output(all_crops_transformed, crop_key)")
    print("# Create ICRISAT format") 
    print("icrisat_wide = create_clean_icrisat_format(cleaned)")
    print("# Verify alignment")
    print("alignment = verify_district_alignment(cleaned, icrisat_boundaries)")

Post-processing functions for crop data transformation

Usage:
1. clean_and_enhance_output(transformed_df, crop_key_df)
2. create_clean_icrisat_format(cleaned_long_df)
3. verify_district_alignment(cleaned_data, icrisat_boundaries)

Example:
# Clean the data
cleaned = clean_and_enhance_output(all_crops_transformed, crop_key)
# Create ICRISAT format
icrisat_wide = create_clean_icrisat_format(cleaned)
# Verify alignment
alignment = verify_district_alignment(cleaned, icrisat_boundaries)


In [None]:
# After you've run transform_all_crops_all_years():
# Clean up the results (filter to mapped crops + add ICRISAT names)
cleaned_data = clean_and_enhance_output(
    all_crops_data,  # Your raw output 
    crop_key            # Your crop mapping
)

# Save the cleaned long format
cleaned_data.to_csv('all_crops_icrisat_clean_long.csv', index=False)

# Create clean ICRISAT wide format
clean_icrisat_wide = create_clean_icrisat_format(cleaned_data)
clean_icrisat_wide.to_csv('all_crops_icrisat_clean_wide.csv', index=False)

NameError: name 'all_crops_transformed' is not defined