In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image 
import timm 
import pandas as pd
from sklearn.preprocessing import StandardScaler
import numpy as np
import os
from torchvision import models 

# --- Helper Function to Initialize EfficientNet-B0 (Used for Frozen Model) ---
# (Using the robust version provided)
def get_model(model_name, num_classes, pretrained=False, checkpoint_path=None, device="cuda"):
    """
    Build a torchvision EfficientNet-B0 with a resized head and optionally restore weights.

    The final linear layer (``classifier[1]``) is swapped for one with
    ``num_classes`` outputs. When ``checkpoint_path`` is given, the file is read
    with ``torch.load`` and its weights are loaded strictly into the model; the
    state dict may be stored bare, under a ``'state_dict'`` or
    ``'model_state_dict'`` key, or as a pickled ``nn.Module``. A leading
    ``'module.'`` on parameter names is stripped.

    Args:
        model_name (str): Only ``'efficientnet_b0'`` is supported.
        num_classes (int): Output size of the replacement classifier layer.
        pretrained (bool): Request torchvision ImageNet weights. Only honoured
            when ``checkpoint_path`` is None.
        checkpoint_path (str, optional): Path of a checkpoint file to restore.
        device (str or torch.device): Where the model and checkpoint tensors go.

    Returns:
        torch.nn.Module: The model, moved to ``device``.

    Raises:
        ValueError: Unsupported ``model_name`` or the classifier could not be replaced.
        FileNotFoundError: ``checkpoint_path`` does not exist.
        RuntimeError: The checkpoint could not be loaded into the model.
    """
    print(f"--- Initializing base model: {model_name} for {num_classes} classes ---")
    if model_name != "efficientnet_b0":
        raise ValueError(f"Unsupported model: {model_name}")

    # ImageNet weights are only requested when no checkpoint will be loaded on top.
    want_imagenet = pretrained and checkpoint_path is None
    model = models.efficientnet_b0(weights='IMAGENET1K_V1' if want_imagenet else None)
    try:
        head_in_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(head_in_features, num_classes)
        print(f"   Replaced classifier head for {num_classes} classes.")
    except Exception as e:
        raise ValueError(f"Could not modify torchvision efficientnet_b0 classifier[1]: {e}")

    # Put the (possibly still randomly initialised) skeleton on the target device.
    model = model.to(device)

    # No checkpoint requested: report where the weights came from and finish.
    if not checkpoint_path:
        if pretrained:
            print(f"✅ Using torchvision ImageNet pretrained weights for {model_name}.")
        else:
            print(f"   Model {model_name} initialized with random weights (pretrained=False, no checkpoint).")
        return model

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    print(f"--- Loading specific checkpoint: {checkpoint_path}")
    payload = torch.load(checkpoint_path, map_location=device)

    # Unwrap the common container layouts down to a plain state dict.
    if isinstance(payload, dict):
        if 'state_dict' in payload:
            state_dict = payload['state_dict']
            print("   (Loaded from 'state_dict' key)")
        elif 'model_state_dict' in payload:
            state_dict = payload['model_state_dict']
            print("   (Loaded from 'model_state_dict' key)")
        else:
            # A dict without a wrapper key is taken to be the state dict itself.
            state_dict = payload
            print("   (Loaded dictionary directly as state_dict)")
    elif isinstance(payload, nn.Module):
        state_dict = payload.state_dict()
        print("   (Loaded state_dict from a saved nn.Module object)")
    else:
        state_dict = payload
        print("   (Loaded unknown format directly, assuming state_dict)")

    # Drop the 'module.' prefix (7 characters) from any key that carries it.
    new_state_dict = {
        (key[7:] if key.startswith('module.') else key): tensor
        for key, tensor in state_dict.items()
    }
    if any(key.startswith('module.') for key in state_dict):
        print("   (Handled 'module.' prefix in keys)")

    try:
        load_result = model.load_state_dict(new_state_dict, strict=True)
        if load_result.missing_keys:
            print(f"   WARNING: Missing keys during load_state_dict: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"   WARNING: Unexpected keys during load_state_dict: {load_result.unexpected_keys}")
        print(f"✅ Loaded checkpoint weights into model structure.")
    except Exception as e:
        print(f"❌ ERROR loading state_dict into model structure: {e}")
        # Show a sample of both key sets to make naming mismatches easy to spot.
        print("   Model keys (first 5):", list(model.state_dict().keys())[:5], "...")
        print("   Checkpoint keys (first 5):", list(new_state_dict.keys())[:5], "...")
        raise RuntimeError(f"Failed to load checkpoint {checkpoint_path} into {model_name}") from e

    return model

# --- Combined Geo-Prediction Model ---
class CombinedGeoModel(nn.Module):
    """
    Latitude/longitude regressor combining a frozen region classifier with a
    trainable image embedder.

    For each image batch the frozen classifier produces region logits (or
    softmax probabilities), the trainable embedder produces a feature vector,
    and the two are concatenated and passed through an MLP head with 2 outputs.
    """

    def __init__(self,
                 frozen_region_model_path,
                 num_region_classes=15,
                 head_hidden_dims=(512, 128),
                 use_softmax_for_region=True,
                 trainable_effnet_name='efficientnet_b0',
                 device='cuda'):
        """
        Initializes the combined model.

        Args:
            frozen_region_model_path (str): Path to the checkpoint file for the frozen Region_ID model.
            num_region_classes (int): Number of classes for the Region_ID model.
            head_hidden_dims (sequence of int): Hidden layer widths of the regression head.
                Only iterated, never mutated; the default is a tuple so no mutable
                object is shared between instances.
            use_softmax_for_region (bool): Whether to use softmax output (True) or logits (False)
                                           from the frozen region model.
            trainable_effnet_name (str): timm model name used for the trainable embedder.
            device (str or torch.device): Device to place all sub-modules on.
        """
        super().__init__()
        self.num_region_classes = num_region_classes
        self.use_softmax_for_region = use_softmax_for_region
        self.device = device

        # 1. Load and freeze the Region_ID classifier
        print("--- Loading Frozen Region Model ---")
        self.frozen_region_model = get_model(
            model_name='efficientnet_b0',
            num_classes=num_region_classes,
            pretrained=False,  # weights come from the checkpoint below
            checkpoint_path=frozen_region_model_path,
            device=self.device
            )
        for param in self.frozen_region_model.parameters():
            param.requires_grad = False
        # eval() here is not sufficient on its own; train() below keeps it in eval mode.
        self.frozen_region_model.eval()
        print("Frozen Region Model loaded and parameters frozen.")

        # 2. Load the trainable image embedder (timm)
        print("\n--- Loading Trainable Image Embedder ---")
        self.trainable_image_embedder = timm.create_model(
            trainable_effnet_name,
            pretrained=True
        )
        # Width of the features that feed the classifier we are about to remove.
        self.embedding_dim = self.trainable_image_embedder.get_classifier().in_features
        # Remove the classifier; forward() handles either 2D or 4D embedder output.
        self.trainable_image_embedder.reset_classifier(0, '')
        print(f"Trainable Image Embedder ({trainable_effnet_name}) loaded with output dim: {self.embedding_dim}")
        self.trainable_image_embedder = self.trainable_image_embedder.to(self.device)

        # Pooling/flattening used by forward() when the embedder returns a 4D feature map.
        self.pool = nn.AdaptiveAvgPool2d(output_size=1).to(self.device)
        self.flat = nn.Flatten(start_dim=1).to(self.device)

        # 3. Regression head: [Linear -> BatchNorm1d -> ReLU -> Dropout] per hidden dim, then Linear -> 2
        print("\n--- Defining Regression Head ---")
        input_head_dim = self.embedding_dim + self.num_region_classes
        print(f"Regression head input dimension: {input_head_dim} ({self.embedding_dim} from embedder + {self.num_region_classes} from region model)")
        layers = []
        current_dim = input_head_dim
        for hidden_dim in head_hidden_dims:
            layers.append(nn.Linear(current_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.3))
            current_dim = hidden_dim

        # Final output layer for 2 values (lat, lon)
        layers.append(nn.Linear(current_dim, 2))
        self.regression_head = nn.Sequential(*layers).to(self.device)
        print(f"Regression head structure: {self.regression_head}")
        print("\nModel Initialization Complete.")

    def train(self, mode=True):
        """
        Set the training mode of the trainable parts while keeping the frozen
        region model in evaluation mode.

        ``nn.Module.train`` recurses into every child, so a plain
        ``model.train()`` would put the frozen classifier back into train mode
        (updating its BatchNorm running statistics and enabling its dropout)
        even though its parameters do not require gradients.

        Args:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            CombinedGeoModel: self.
        """
        super().train(mode)
        self.frozen_region_model.eval()
        return self

    def forward(self, x):
        """
        Forward pass through the combined model.

        Args:
            x (torch.Tensor): Input image batch; passed unchanged to both sub-models
                (assumed (B, C, H, W) and already on the model's device).

        Returns:
            torch.Tensor: Regression head output of shape (B, 2).

        Raises:
            RuntimeError: If the embedder output is neither 2D nor 4D.
        """
        # 1. Region features from the frozen model; no autograd graph is built for this branch.
        with torch.no_grad():
            region_logits = self.frozen_region_model(x)
            if self.use_softmax_for_region:
                region_features = F.softmax(region_logits, dim=1)
            else:
                region_features = region_logits

        # 2. Image embedding from the trainable model (gradients flow through this branch).
        image_features = self.trainable_image_embedder(x)

        if image_features.ndim == 4:
            # Unpooled feature map: global-average-pool to (B, C, 1, 1), then flatten to (B, C).
            image_features = self.pool(image_features)
            image_features = self.flat(image_features)
        elif image_features.ndim != 2:
            raise RuntimeError(f"Unexpected dimension for image_features from embedder: {image_features.ndim}. Expected 2D or 4D. Shape: {image_features.shape}")

        # 3. Concatenate along the feature axis: (B, embedding_dim + num_region_classes)
        combined_features = torch.cat((image_features, region_features), dim=1)

        # 4. Regress the two outputs
        output = self.regression_head(combined_features)

        return output

# --- Custom Dataset with Exclusion Logic ---
class GeoImageDataset(Dataset):
    """
    Image dataset backed by a CSV with 'filename', 'latitude' and 'longitude' columns.

    Each item is ``(image, labels)`` where ``labels`` is a float32 tensor
    ``[scaled_lat, scaled_lon]``. Labels are scaled once in ``__init__``.
    """

    def __init__(self, csv_file, img_dir, transform=None, scaler_lat=None, scaler_lon=None, exclude_filenames=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
            scaler_lat: Fitted scaler exposing ``transform`` for latitude, or None.
            scaler_lon: Fitted scaler exposing ``transform`` for longitude, or None.
                Unless BOTH scalers are given, the CSV values are used as-is
                (assumed to be already scaled).
            exclude_filenames (list, optional): A list of filenames (e.g., 'img_0095.jpg') to exclude.

        Raises:
            FileNotFoundError: ``csv_file`` does not exist.
            ValueError: A required column is missing from the CSV.
        """
        print(f"--- Loading dataset from: {csv_file}")
        try:
            self.data_frame = pd.read_csv(csv_file)
        except FileNotFoundError:
            print(f"ERROR: CSV file not found at {csv_file}")
            raise
        self.img_dir = img_dir
        self.transform = transform
        self.scaler_lat = scaler_lat
        self.scaler_lon = scaler_lon

        # Ensure required columns exist
        if 'latitude' not in self.data_frame.columns or 'longitude' not in self.data_frame.columns:
            raise ValueError(f"CSV '{csv_file}' must contain 'latitude' and 'longitude' columns.")
        if 'filename' not in self.data_frame.columns:
            raise ValueError(f"CSV '{csv_file}' must contain a 'filename' column.")

        # --- Exclusion Logic ---
        if exclude_filenames:
            initial_len = len(self.data_frame)
            keep_mask = ~self.data_frame['filename'].isin(exclude_filenames)
            # reset_index keeps row positions contiguous so they line up with the label arrays
            self.data_frame = self.data_frame[keep_mask].reset_index(drop=True)
            print(f"   Excluded {initial_len - len(self.data_frame)} rows based on exclude_filenames list.")
            if len(self.data_frame) == 0:
                print(f"WARNING: All rows were excluded from {csv_file}. Check your exclude_filenames list.")

        # Pre-scale labels if scalers are provided.
        # Explicit None checks: an estimator's truthiness is not a reliable "was provided" signal.
        if self.scaler_lat is not None and self.scaler_lon is not None:
            print("   Scaling latitude and longitude using provided scalers.")
            try:
                self.scaled_lat = self._scale_column(self.scaler_lat, self.data_frame['latitude'].values)
                self.scaled_lon = self._scale_column(self.scaler_lon, self.data_frame['longitude'].values)
            except Exception as e:
                print(f"Error during scaling: {e}")
                print("Ensure scalers were fitted correctly on the training data.")
                raise
        else:
            # If scalers not provided, assume columns 'latitude', 'longitude' are ALREADY scaled
            print("   WARNING: Scalers not provided. Assuming 'latitude' and 'longitude' columns in CSV are already scaled.")
            self.scaled_lat = self.data_frame['latitude'].values
            self.scaled_lon = self.data_frame['longitude'].values

    @staticmethod
    def _scale_column(scaler, values):
        """Return ``values`` (1-D array) scaled by ``scaler`` as a flat array.

        The scaler receives a plain (n, 1) ndarray, matching how the scalers in
        this file are fitted (on ``.values``), which avoids sklearn's
        feature-name mismatch warning. An empty column skips the scaler
        entirely, so a fully excluded dataset can still be constructed.
        """
        if len(values) == 0:
            return np.empty(0, dtype=np.float64)
        return scaler.transform(values.reshape(-1, 1)).flatten()

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx (int or 0-d tensor): Row position in the (filtered) CSV.

        Returns:
            tuple: ``(image, labels)``; ``image`` is the RGB PIL image after
            ``transform`` (if any), ``labels`` is a float32 tensor of shape (2,).

        Raises:
            IndexError: ``idx`` is outside ``[0, len(self))``.
            FileNotFoundError: The image file does not exist.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Validate before any I/O. Raising IndexError (not KeyError) also lets plain
        # `for sample in dataset` iteration terminate at the end of the data.
        if not 0 <= idx < len(self.data_frame):
            raise IndexError(f"Index {idx} out of bounds after filtering/scaling. Dataset length: {len(self)}. Scaled lat length: {len(self.scaled_lat)}")

        img_filename = self.data_frame['filename'].iloc[idx]
        img_name = os.path.join(self.img_dir, img_filename)
        try:
            image = Image.open(img_name).convert('RGB')
        except FileNotFoundError:
            print(f"Error: Image not found at {img_name} (referenced in CSV row {idx})")
            raise FileNotFoundError(f"Image not found: {img_name}")
        except Exception as e:
            print(f"Error opening or processing image {img_name}: {e}")
            raise  # Re-raise other image loading errors

        if self.transform:
            image = self.transform(image)

        labels = torch.tensor([self.scaled_lat[idx], self.scaled_lon[idx]], dtype=torch.float32)

        return image, labels


# --- Weighted Loss Function ---
class WeightedLatLonLoss(nn.Module):
    """Weighted sum of two MSE terms: column 0 (latitude) and column 1 (longitude)."""

    def __init__(self, lat_weight=0.4, lon_weight=0.6):
        """
        Args:
            lat_weight (float): Weight of the column-0 MSE term.
            lon_weight (float): Weight of the column-1 MSE term.
                If the two weights do not sum to 1 (within 1e-6) they are
                rescaled so that they do.
        """
        super().__init__()
        weight_sum = lat_weight + lon_weight
        if abs(weight_sum - 1.0) > 1e-6:
            print(f"WARNING: Weights ({lat_weight}, {lon_weight}) do not sum close to 1. Normalizing.")
            lat_weight, lon_weight = lat_weight / weight_sum, lon_weight / weight_sum

        self.lat_weight = lat_weight
        self.lon_weight = lon_weight
        # The same MSE criterion is applied to each output column separately.
        self.mse = nn.MSELoss()
        print(f"Initialized WeightedLatLonLoss with lat_weight={self.lat_weight:.2f}, lon_weight={self.lon_weight:.2f}")

    def forward(self, preds, targets):
        """
        Args:
            preds (torch.Tensor): Predictions of shape (B, 2).
            targets (torch.Tensor): Targets of shape (B, 2).

        Returns:
            torch.Tensor: Scalar ``lat_weight * MSE(col 0) + lon_weight * MSE(col 1)``.

        Raises:
            ValueError: If the shapes differ or are not (B, 2).
        """
        shapes_ok = (
            preds.shape == targets.shape
            and preds.ndim == 2
            and preds.shape[1] == 2
        )
        if not shapes_ok:
            raise ValueError(f"Shape mismatch or incorrect dimensions for WeightedLatLonLoss. "
                             f"Preds shape: {preds.shape}, Targets shape: {targets.shape}. Expected (B, 2).")

        lat_term = self.lat_weight * self.mse(preds[:, 0], targets[:, 0])
        lon_term = self.lon_weight * self.mse(preds[:, 1], targets[:, 1])
        return lat_term + lon_term


# --- Setup & Example Usage ---

# 1. Configuration
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
_BANNER = "========================================="
print(f"{_BANNER}\nUsing device: {DEVICE}\n{_BANNER}\n")

# --- USER-DEFINED PATHS --- >>> PLEASE CHANGE THESE <<<
# Training images and their annotation CSV
TRAIN_IMG_DIR = "/kaggle/input/latlong-dataset/images_train_combine"
TRAIN_CSV_PATH = "/kaggle/input/latlong-dataset/train_combine.csv"
# Validation images and their annotation CSV
VAL_IMG_DIR = "/kaggle/input/val-dataset/images_val"
VAL_CSV_PATH = "/kaggle/input/val-dataset/labels_val.csv"
# Checkpoint of the region-classification model that stays frozen
FROZEN_MODEL_PATH = "/kaggle/input/saved-models/kaggle/working/saved_models/best_efficientnet_b0.pt"
# Output locations: fitted scalers and model checkpoints
SCALER_SAVE_DIR = "/kaggle/working/"
MODEL_SAVE_DIR = "/kaggle/working/combined_model_checkpoints"
# --- END USER-DEFINED PATHS ---

# Model shape
NUM_REGION_CLASSES = 15             # must equal the output size of the frozen model
IMG_SIZE = 256                      # images are resized to IMG_SIZE x IMG_SIZE
HEAD_HIDDEN_DIMS = [512, 256, 128]  # hidden widths of the regression head
TRAINABLE_EMBEDDER_NAME = 'efficientnet_b0'  # backbone of the trainable embedder
USE_SOFTMAX_REGION = True           # True: softmax probabilities, False: raw logits

# Optimisation
BATCH_SIZE = 32
LEARNING_RATE = 1e-4                # initial learning rate
NUM_EPOCHS = 40                     # lower this for a quick smoke test
LAT_LOSS_WEIGHT = 0.4               # weight of the latitude term in the loss
LON_LOSS_WEIGHT = 0.6               # weight of the longitude term in the loss

# Create the output directories up front (no-op if they already exist)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(SCALER_SAVE_DIR, exist_ok=True)

# --- Filenames to exclude from validation set ---
# The numeric IDs below are turned into filenames of the form 'img_XXXX.jpg'
# (zero-padded to 4 digits); this assumes the validation CSV uses that naming.
ids_to_exclude = [95, 145, 146, 158, 159, 160, 161]
# `image_id` rather than `id`, so the builtin id() is not shadowed.
validation_exclude_filenames = [f"img_{image_id:04d}.jpg" for image_id in ids_to_exclude]
print(f"\n--- Files to exclude from Validation Set: {validation_exclude_filenames} ---")


# 2. Data Preprocessing & Scaling
# IMPORTANT: Fit scalers ONLY on the TRAINING data!
scaler_lat = StandardScaler()
scaler_lon = StandardScaler()
scaler_lat_path = os.path.join(SCALER_SAVE_DIR, 'scaler_lat.joblib')
scaler_lon_path = os.path.join(SCALER_SAVE_DIR, 'scaler_lon.joblib')

# joblib is optional: without it the scalers are simply re-fitted on every run.
try:
    import joblib
except ImportError:
    joblib = None
    print("WARNING: joblib not found. Cannot save/load scalers. Will fit scalers every run.")

# One load-or-fit path for both cases, so the same error handling applies with or
# without joblib. Failures stop the script with a non-zero exit status.
try:
    if joblib is not None and os.path.exists(scaler_lat_path) and os.path.exists(scaler_lon_path):
        # Reuse previously saved scalers.
        # NOTE(review): nothing here checks that they were fitted on the current training CSV.
        scaler_lat = joblib.load(scaler_lat_path)
        scaler_lon = joblib.load(scaler_lon_path)
        print(f"--- Loaded existing scalers from {SCALER_SAVE_DIR} ---")
    else:
        if joblib is not None:
            print(f"\n--- Fitting scalers on TRAINING data: {TRAIN_CSV_PATH} ---")
        else:
            print(f"\n--- Fitting scalers on TRAINING data: {TRAIN_CSV_PATH} (joblib not found) ---")
        # Load only the training data for fitting
        train_df_for_scaling = pd.read_csv(TRAIN_CSV_PATH)
        if 'latitude' not in train_df_for_scaling.columns or 'longitude' not in train_df_for_scaling.columns:
            raise KeyError("Training CSV must contain 'latitude' and 'longitude' for scaler fitting.")

        # Double-bracket selection yields the (n_samples, 1) array the scaler expects
        scaler_lat.fit(train_df_for_scaling[['latitude']].values)
        scaler_lon.fit(train_df_for_scaling[['longitude']].values)

        if joblib is not None:
            print("   Scalers fitted successfully on training data.")
            # Save the fitted scalers for inference later
            joblib.dump(scaler_lat, scaler_lat_path)
            joblib.dump(scaler_lon, scaler_lon_path)
            print(f"   Scalers saved to {SCALER_SAVE_DIR}")
        else:
            print("   Scalers fitted successfully on training data (not saved).")

except FileNotFoundError:
    print(f"ERROR: Training CSV file not found at {TRAIN_CSV_PATH}. Cannot fit scalers.")
    raise SystemExit(1)
except KeyError as e:
    print(f"ERROR: {e} Column missing in {TRAIN_CSV_PATH}. Cannot fit scalers.")
    raise SystemExit(1)
except Exception as e:
    print(f"An unexpected error occurred during scaler handling: {e}")
    raise SystemExit(1)


# 3. Transforms
# Both pipelines resize to a square, convert to a tensor and normalise with the
# standard ImageNet channel statistics; only the training pipeline adds
# random augmentation between the resize and the tensor conversion.
_target_size = (IMG_SIZE, IMG_SIZE)

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

_train_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
]

train_transform = transforms.Compose(
    [transforms.Resize(_target_size), *_train_augmentations, transforms.ToTensor(), normalize]
)

val_transform = transforms.Compose(
    [transforms.Resize(_target_size), transforms.ToTensor(), normalize]
)

# 4. Datasets and DataLoaders
# Sets `dataloaders_ready`; later steps are skipped when it is False.
print("\n--- Creating Datasets and DataLoaders ---")
try:
    train_dataset = GeoImageDataset(csv_file=TRAIN_CSV_PATH,
                                    img_dir=TRAIN_IMG_DIR,
                                    transform=train_transform,
                                    scaler_lat=scaler_lat, # Pass fitted scalers
                                    scaler_lon=scaler_lon)

    val_dataset = GeoImageDataset(csv_file=VAL_CSV_PATH,
                                  img_dir=VAL_IMG_DIR,
                                  transform=val_transform,
                                  scaler_lat=scaler_lat, # Use the SAME scalers fitted on train data
                                  scaler_lon=scaler_lon,
                                  exclude_filenames=validation_exclude_filenames) # Pass the list of files to exclude

    # os.cpu_count() may return None when the count cannot be determined; treat that as 1 CPU.
    loader_workers = (os.cpu_count() or 1) // 2
    # Pinned host memory only helps host-to-GPU copies, so only request it when CUDA is present.
    use_pinned_memory = torch.cuda.is_available()

    # drop_last=True: the final partial training batch of each epoch is skipped.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=loader_workers, pin_memory=use_pinned_memory, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=loader_workers, pin_memory=use_pinned_memory)

    print(f"\nTrain dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")
    print(f"Train DataLoader steps per epoch: {len(train_loader)}")
    print(f"Validation DataLoader steps per epoch: {len(val_loader)}")
    dataloaders_ready = True

except FileNotFoundError as e:
    print(f"\nERROR Creating Datasets: Required file/directory not found.")
    print(f"Details: {e}")
    print(f"Please check paths: TRAIN_IMG_DIR, TRAIN_CSV_PATH, VAL_IMG_DIR, VAL_CSV_PATH")
    dataloaders_ready = False
except ValueError as e: # Errors from Dataset init (e.g., missing columns)
    print(f"\nERROR Creating Datasets: {e}")
    dataloaders_ready = False
except Exception as e:
    print(f"\nAn unexpected error occurred creating Dataset/DataLoader: {e}")
    dataloaders_ready = False


# 5. Initialize Model
# `model_ready` stays False unless the model is constructed successfully;
# construction is only attempted when the datasets loaded correctly.
model_ready = False
if dataloaders_ready:
    print("\n--- Initializing Combined Geo Model ---")
    try:
        model = CombinedGeoModel(
            frozen_region_model_path=FROZEN_MODEL_PATH,
            num_region_classes=NUM_REGION_CLASSES,
            head_hidden_dims=HEAD_HIDDEN_DIMS,
            use_softmax_for_region=USE_SOFTMAX_REGION,
            trainable_effnet_name=TRAINABLE_EMBEDDER_NAME,
            device=DEVICE,
        )
    except FileNotFoundError as e:
        # The frozen-model checkpoint is missing
        print(f"\nERROR Initializing Model: Checkpoint file not found.")
        print(f"Details: {e}")
        print(f"Please check path: FROZEN_MODEL_PATH='{FROZEN_MODEL_PATH}'")
    except (ValueError, RuntimeError) as e:
        # e.g. unsupported model name or a state_dict that does not fit
        print(f"\nERROR Initializing Model: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during model initialization: {e}")
    else:
        model_ready = True

# 6. Loss Function and Optimizer (only if model is ready)
# Sets `optimizer_ready`; the training loop is skipped when it is False.
if model_ready:
    print("\n--- Setting up Loss and Optimizer ---")
    criterion = WeightedLatLonLoss(lat_weight=LAT_LOSS_WEIGHT, lon_weight=LON_LOSS_WEIGHT)
    # Only optimize parameters that require gradients. A list (not a one-shot
    # filter iterator) so the collection can still be inspected after the optimizer
    # has consumed it.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=LEARNING_RATE,
        weight_decay=1e-5 # decoupled weight decay
    )
    # Learning rate scheduler. The `verbose` argument is intentionally not passed:
    # PyTorch has deprecated it, so read the current learning rate from
    # optimizer.param_groups instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',     # Reduce LR when the monitored value stops decreasing
                                                           factor=0.1,     # Reduce LR by a factor of 0.1
                                                           patience=5)     # Wait 5 epochs of no improvement before reducing
    optimizer_ready = True
else:
    optimizer_ready = False


# 7. Training loop: one training pass and one validation pass per epoch.
# Saves the best model (by validation loss) and a periodic checkpoint every 5 epochs.
if dataloaders_ready and model_ready and optimizer_ready:
    print("\n--- Starting Training Loop ---")
    best_val_loss = float('inf')

    for epoch in range(NUM_EPOCHS):
        print("-" * 30)
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
        print("-" * 30)

        # --- Training Phase ---
        model.train() # Set model to training mode
        running_train_loss = 0.0
        processed_samples_train = 0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # criterion returns a batch mean; weight it by batch size to accumulate a sum
            running_train_loss += loss.item() * inputs.size(0)
            processed_samples_train += inputs.size(0)

            if (i + 1) % 100 == 0: # Print every 100 mini-batches
                avg_batch_loss = running_train_loss / processed_samples_train
                print(f'  Train Step [{i+1}/{len(train_loader)}], Avg Loss: {avg_batch_loss:.6f}')

        # Average over the samples actually seen. The loader may skip the last partial
        # batch (drop_last=True), so dividing by len(train_dataset) would under-report
        # the loss. max(..., 1) guards the case where no batch was produced at all.
        epoch_train_loss = running_train_loss / max(processed_samples_train, 1)
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] Average Training Loss: {epoch_train_loss:.6f}')

        # --- Validation Phase ---
        model.eval() # Set model to evaluation mode
        running_val_loss = 0.0
        processed_samples_val = 0

        with torch.no_grad():
            for inputs_val, labels_val in val_loader:
                inputs_val, labels_val = inputs_val.to(DEVICE), labels_val.to(DEVICE)
                outputs_val = model(inputs_val)
                loss_val = criterion(outputs_val, labels_val)
                running_val_loss += loss_val.item() * inputs_val.size(0)
                processed_samples_val += inputs_val.size(0)

        # None means "no validation result this epoch" (validation set empty after exclusion)
        epoch_val_loss = None
        if processed_samples_val > 0:
            epoch_val_loss = running_val_loss / processed_samples_val
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] Validation Loss: {epoch_val_loss:.6f}')

            # Step the scheduler on validation loss and report any LR change explicitly
            lr_before_step = optimizer.param_groups[0]['lr']
            scheduler.step(epoch_val_loss)
            lr_after_step = optimizer.param_groups[0]['lr']
            if lr_after_step != lr_before_step:
                print(f"   Learning rate changed: {lr_before_step:.2e} -> {lr_after_step:.2e}")

            # Save the best model based on validation loss
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                best_model_path = os.path.join(MODEL_SAVE_DIR, 'best_combined_model.pth')
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': best_val_loss,
                    'scaler_lat_path': scaler_lat_path, # Store path to scalers used
                    'scaler_lon_path': scaler_lon_path,
                }, best_model_path)
                print(f"   *** New best model saved to {best_model_path} (Val Loss: {best_val_loss:.6f}) ***")
        else:
            print("   Validation set is empty, skipping validation loss calculation and model saving.")

        # Save checkpoint periodically (every 5 epochs)
        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(MODEL_SAVE_DIR, f'combined_model_epoch_{epoch+1}.pth')
            torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': epoch_train_loss,
                    'val_loss': epoch_val_loss,
                    'scaler_lat_path': scaler_lat_path, # Store path to scalers used
                    'scaler_lon_path': scaler_lon_path,
                }, checkpoint_path)
            print(f"   Checkpoint saved to {checkpoint_path}")


    print('\nFinished Training')

else:
    print("\n--- Training Skipped ---")
    # Report the first setup stage that failed
    if not dataloaders_ready:
        print("Reason: Dataloaders failed to initialize.")
    elif not model_ready:
        print("Reason: Model failed to initialize.")
    elif not optimizer_ready:
        print("Reason: Optimizer setup failed.")

