# In[ ]:
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
import xgboost as xgb

# --- 1. Load the SAR image (VV and VH bands) ---
ruta_imagen = "sentinel1_vv_vh.tif"
with rasterio.open(ruta_imagen) as src:
    sar_img = src.read()  # (bands, rows, cols)
    profile = src.profile
    nodata_sar = src.nodata  # declared nodata value, or None if the file declares none
sar_img = np.transpose(sar_img, (1, 2, 0))  # (rows, cols, bands)

# Work in floating point so nodata pixels can be represented as NaN.
# StandardScaler ignores NaN when fitting and keeps it when transforming;
# XGBoost treats NaN as a missing value. Either way, nodata is not mistaken
# for real backscatter.
if not np.issubdtype(sar_img.dtype, np.floating):
    sar_img = sar_img.astype(np.float32)
if nodata_sar is not None and not np.isnan(nodata_sar):
    sar_img[sar_img == nodata_sar] = np.nan

# --- 2. Load the training samples ---
# Single-band label raster; presumably on the same grid as the SAR image
# (TODO confirm). Pixels with value <= 0 are treated as unlabeled below.
ruta_muestras = "muestras.tif"
with rasterio.open(ruta_muestras) as src:
    muestras = src.read(indexes=1)  # first band as a 2-D array (rows, cols)

# --- 3. Prepare the training data ---
# Boolean indexing with the sample mask only works if both rasters share the
# same (rows, cols) grid; fail early with a clear message otherwise.
if muestras.shape != sar_img.shape[:2]:
    raise ValueError(
        f"Las muestras {muestras.shape} y la imagen SAR {sar_img.shape[:2]} "
        "no tienen las mismas dimensiones"
    )
# Keep labelled pixels (value > 0) whose values are finite in every band;
# NaN/inf pixels carry no usable backscatter and would only add noise.
pixel_valido = np.all(np.isfinite(sar_img), axis=-1)
mascara = (muestras > 0) & pixel_valido
X = sar_img[mascara]  # (n_samples, bands)
y = muestras[mascara]  # (n_samples,) original class codes

# --- 4. Split into training and test sets ---
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)

# --- 5. Normalize features ---
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# --- 6. Train the XGBoost model ---
# XGBoost's multiclass objectives require labels to be the contiguous integers
# 0..K-1, but the sample raster uses positive codes (only values > 0 are kept),
# possibly with gaps. Train on each code's index within the sorted array
# `clases`, and map every prediction back with `clases[...]` so y_test, y_pred
# and the output map all stay in the original label space.
clases = np.unique(y)
model = xgb.XGBClassifier(objective='multi:softmax', num_class=len(clases), eval_metric='mlogloss')
model.fit(X_train_scaled, np.searchsorted(clases, y_train))

# --- 7. Evaluate the model ---
# astype(np.intp): some XGBoost versions return float class indices.
y_pred = clases[model.predict(X_test_scaled).astype(np.intp)]
print("Matriz de Confusión:")
print(confusion_matrix(y_test, y_pred))
print("\nReporte de Clasificación:")
print(classification_report(y_test, y_pred))

# --- 8. Classify the whole image ---
filas, columnas, _ = sar_img.shape
X_total = sar_img.reshape(-1, sar_img.shape[2])  # (rows * cols, bands)
X_total_scaled = scaler.transform(X_total)
predicciones = clases[model.predict(X_total_scaled).astype(np.intp)]
# Pixels with a non-finite value in any band have no valid input. Label them 0,
# a code never used as a class because only samples > 0 were used for training.
predicciones[~np.all(np.isfinite(X_total), axis=1)] = 0
clasificacion = predicciones.reshape(filas, columnas)

# --- 9. Save and display the result ---
salida = "clasificacion_xgboost.tif"
# The map is written as uint8; a class code above 255 would silently wrap around.
if clasificacion.max() > np.iinfo(np.uint8).max:
    raise ValueError("Hay códigos de clase mayores que 255; no caben en uint8")
# The SAR profile may carry a float nodata value (e.g. NaN or -9999) that is
# not valid for a uint8 dataset. Use 0 instead: it is never a class code,
# since training samples are restricted to values > 0.
profile.update(dtype='uint8', count=1, nodata=0)
with rasterio.open(salida, 'w', **profile) as dst:
    dst.write(clasificacion.astype(rasterio.uint8), 1)

plt.figure(figsize=(8, 6))
# Mask nodata (0) and use nearest-neighbour sampling so that downsampled
# display never blends the colours of different classes.
plt.imshow(np.ma.masked_equal(clasificacion, 0), cmap='tab20', interpolation='nearest')
plt.title("Clasificación XGBoost")
plt.axis('off')
plt.colorbar()
plt.show()
