# 50States10K - State Classification Training

This notebook runs the training pipeline for the state classification model using PyTorch.

In [None]:
# Check if running in Colab
import sys
IN_COLAB = 'google.colab' in sys.modules
print(f"Running in Colab: {IN_COLAB}")

if IN_COLAB:
    # Install dependencies
    !pip install wandb
    !pip install tqdm
    
    # Download and run the setup script
    !wget -O colab_setup.py https://raw.githubusercontent.com/yourusername/state-classifier/main/colab_setup.py
    from colab_setup import setup_environment
    
    # Set up environment
    repo_path = setup_environment(
        github_repo="yourusername/state-classifier",
        branch="main",
        data_drive_path="/content/drive/MyDrive/50States10K"
    )

In [None]:
# Import modules from the package
from state_classifier.config.config_utils import load_config
from state_classifier.experiment import Experiment

# Report whether CUDA is usable and, if it is, the name and total memory of device 0
import torch

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    device_name = torch.cuda.get_device_name(0)
    device_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {device_name}")
    # total_memory is divided by 1e9 to print the figure in (decimal) gigabytes
    print(f"Memory: {device_props.total_memory / 1e9:.2f} GB")

In [None]:
# Log in to Weights & Biases up front, before the training cell calls wandb.init().
# NOTE(review): called with no arguments, so this presumably relies on stored
# credentials, an environment variable, or an interactive prompt — confirm for
# unattended runs.
import wandb
wandb.login()

In [None]:
# Locate config.yaml, load it, and point the dataset roots at the Colab data
# directory when one has been advertised through the environment.
import os

# On Colab the file is looked up under repo_path (returned by setup_environment);
# otherwise it is resolved relative to the current working directory.
config_path = os.path.join(repo_path, "config.yaml") if IN_COLAB else "config.yaml"

config = load_config(config_path)

# The override applies only on Colab and only when STATE_CLASSIFIER_DATA is set;
# in every other case the paths from the config file are left untouched.
use_env_data_root = IN_COLAB and "STATE_CLASSIFIER_DATA" in os.environ
if use_env_data_root:
    data_path = os.environ["STATE_CLASSIFIER_DATA"]
    config.dataset_root = os.path.join(data_path, "train")
    config.test_dataset_root = os.path.join(data_path, "test")

# Echo the effective paths so a wrong location is visible before training starts
print(f"Dataset path: {config.dataset_root}")
print(f"Test dataset path: {config.test_dataset_root}")

In [None]:
# Build the experiment from the loaded config, then train and test it inside a
# Weights & Biases run.
experiment = Experiment(config)

# Use the run returned by wandb.init() as a context manager so the run is always
# closed — including when train()/test() raises or the kernel is interrupted.
# Previously wandb.finish() was only reached on success, leaving a dangling
# active run after any failure.
with wandb.init(
    project=config.wandb.project,
    # A random id suffix keeps run names distinct across repeated executions
    name=f"resnet101_{wandb.util.generate_id()}"
):
    # Train the model
    experiment.train()

    # Test the model
    experiment.test()