# Geolocation Model Training - Google Colab

**Before running:**
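1. Have your dataset zip (e.g. `archive.zip`) ready on your computer — Step 2 uploads it to the Colab runtime; Google Drive is only used to save the trained model and results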
2. Runtime → Change runtime type → **GPU** (T4 recommended)
3. Run cells in order from top to bottom

**Expected time:** 2-3 hours for 20 epochs

## Step 1: Check GPU

In [None]:
# Report whether a CUDA GPU is visible to PyTorch (training on CPU would take far longer).
import torch

has_gpu = torch.cuda.is_available()
print("GPU Available:", has_gpu)
if not has_gpu:
    print("WARNING: No GPU detected! Go to Runtime → Change runtime type → GPU")
else:
    print("GPU Name:", torch.cuda.get_device_name(0))
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print("GPU Memory:", total_bytes / 1e9, "GB")

## Step 2: Upload Dataset

First mount Google Drive (used to save the model and results), then upload your dataset zip from your computer to the Colab runtime.

In [None]:
# Mount Google Drive; the outputs of this notebook are written under MyDrive so they outlive the runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
print("Google Drive mounted (for saving model)")

In [None]:
# Upload the dataset zip from your computer into the Colab runtime.
from google.colab import files
import os

uploaded = files.upload()  # Select your dataset zip file

if not uploaded:
    # Stop here with a clear message; otherwise Step 3 fails later with a NameError on DATA_DIR.
    raise RuntimeError("No file uploaded! Re-run this cell and select your dataset zip.")

# Only the first selected file is used.
zip_filename = next(iter(uploaded))
# files.upload() writes into the current working directory, so resolve the path from there.
DATA_DIR = os.path.abspath(zip_filename)

# Check file size
file_size = os.path.getsize(DATA_DIR) / (1024**3)
print("\nUpload complete!")
print(f"   File: {zip_filename}")
print(f"   Size: {file_size:.2f} GB")

## Step 3: Extract Dataset

In [None]:
import zipfile
from tqdm import tqdm

EXTRACT_DIR = "/content/streetview_data"


def _extract_with_progress(zip_path, dest_dir):
    """Unpack every member of ``zip_path`` into ``dest_dir`` with a tqdm progress bar."""
    with zipfile.ZipFile(zip_path, 'r') as archive:
        members = archive.namelist()
        print(f"Total files: {len(members)}")
        # One extract() call per member (instead of extractall) so tqdm can track progress.
        for member in tqdm(members, desc="Extracting"):
            archive.extract(member, dest_dir)


print("Extracting dataset...")
print("This will take 2-5 minutes.\n")
_extract_with_progress(DATA_DIR, EXTRACT_DIR)
print("\nExtraction complete!")

# The archive is redundant once unpacked; delete it to free space on the runtime disk.
os.remove(DATA_DIR)
print("Removed zip file to save space")

## Step 4: Define Model and Training Code

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import json
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter

class StreetViewDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image file paths.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of class labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to each PIL image (e.g. a torchvision
            ``Compose``). When it is falsy, the raw PIL image is returned.

    An image that cannot be opened/decoded is replaced by a black 224x224 RGB
    placeholder and a warning is emitted, so one bad file does not abort an epoch.
    Errors raised by ``transform`` itself are NOT swallowed.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            # Context manager closes the underlying file handle deterministically.
            with Image.open(path) as raw:
                image = raw.convert('RGB')
        except Exception as exc:  # best-effort: corrupt/truncated/unsupported images
            import warnings
            warnings.warn(f"Could not load image {path!r} ({exc}); using a black placeholder.")
            image = Image.new('RGB', (224, 224))
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]


class GeoLocalizationModel(nn.Module):
    """Image classifier: a (optionally ImageNet-pretrained) CNN backbone plus an MLP head.

    Args:
        num_classes: Number of output logits, one per class.
        backbone: Backbone architecture name; only ``'resnet50'`` is supported.
        pretrained: If True, initialise the backbone with ImageNet weights.

    Raises:
        ValueError: If ``backbone`` is not supported.
    """

    def __init__(self, num_classes, backbone='resnet50', pretrained=True):
        super().__init__()

        if backbone == 'resnet50':
            self.backbone = self._build_resnet50(pretrained)
            num_features = self.backbone.fc.in_features
            # Replace ResNet's ImageNet classifier so the backbone emits pooled features.
            self.backbone.fc = nn.Identity()
        else:
            raise ValueError(f"Unknown backbone: {backbone}")

        # Two hidden layers with decreasing dropout before the final class logits.
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @staticmethod
    def _build_resnet50(pretrained):
        """Create a ResNet-50 using the ``weights=`` API when torchvision provides it."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # Older torchvision only understands the legacy boolean flag.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint the deprecated pretrained=True flag loaded.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """Return class logits for a batch of images.

        ``x`` is presumably an (N, 3, H, W) normalised image tensor — TODO confirm
        against the DataLoader output.
        """
        features = self.backbone(x)
        return self.classifier(features)


def get_transforms(training=True, image_size=224, resize_size=256):
    """Build the torchvision preprocessing pipeline.

    Args:
        training: If True, return the augmenting training pipeline; otherwise the
            deterministic evaluation pipeline.
        image_size: Side length of the square crop fed to the model.
        resize_size: Side length images are resized to before cropping.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalised tensor.

    Both pipelines resize to ``resize_size`` and then crop ``image_size``, so the
    model sees content at the same scale in training (random crop + flip + colour
    jitter) and evaluation (centre crop). Both normalise with the ImageNet mean/std
    expected by the pretrained backbone. Any inference script should mirror the
    evaluation pipeline.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    resize = transforms.Resize((resize_size, resize_size))
    if training:
        return transforms.Compose([
            resize,
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            normalize,
        ])
    return transforms.Compose([
        resize,
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])


def load_dataset(data_dir, extensions=(".jpg",)):
    """Collect image paths and integer labels from a one-folder-per-country tree.

    Expects ``data_dir/<country_name>/<image files>``. Hidden folders and the
    ``__MACOSX`` metadata folder added by macOS zips are ignored. Directories and
    files are visited in sorted order, so the returned order — and any seeded
    split built on it — is reproducible across extractions.

    Args:
        data_dir: Root directory whose immediate subdirectories are class folders.
        extensions: File suffixes to include, compared case-insensitively.

    Returns:
        ``(image_paths, encoded_labels, label_encoder)`` where ``image_paths`` is a
        list of str, ``encoded_labels`` are the integer codes from a fitted
        ``LabelEncoder``, and ``label_encoder`` is that encoder.
    """
    print(f"Loading dataset from {data_dir}...")

    wanted_suffixes = {ext.lower() for ext in extensions}
    data_path = Path(data_dir)
    country_dirs = sorted(
        d for d in data_path.iterdir()
        if d.is_dir() and not d.name.startswith(('.', '__MACOSX'))
    )

    print(f"Found {len(country_dirs)} countries")

    image_paths = []
    country_names = []
    for country_dir in tqdm(country_dirs, desc="Loading images"):
        country_name = country_dir.name
        for img_path in sorted(country_dir.iterdir()):
            # Case-insensitive suffix match covers .jpg, .JPG, .Jpg, ...
            if img_path.is_file() and img_path.suffix.lower() in wanted_suffixes:
                image_paths.append(str(img_path))
                country_names.append(country_name)

    print(f"Total images: {len(image_paths)}")

    label_encoder = LabelEncoder()
    encoded_labels = label_encoder.fit_transform(country_names)

    # Print stats
    country_counts = Counter(country_names)
    print(f"\nTop 10 countries:")
    for country, count in country_counts.most_common(10):
        print(f"  {country}: {count} images")

    return image_paths, encoded_labels, label_encoder


def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module to train; switched to training mode.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use ``reduction='mean'`` (the
            ``nn.CrossEntropyLoss`` default) so it can be re-weighted by batch size.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over every sample seen this epoch. The
        loss is sample-weighted, so a short final batch is not over-counted.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses (batch mean * batch size)
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc="Training")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        total += batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()

        # Use our own counters: tqdm only refreshes pbar.n periodically, so it lags.
        pbar.set_postfix({
            'loss': f'{loss_sum / total:.3f}',
            'acc': f'{100. * correct / total:.2f}%'
        })

    if total == 0:
        raise ValueError("train_loader yielded no batches")
    return loss_sum / total, 100. * correct / total


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: Module to evaluate; switched to eval mode (dropout disabled).
        val_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use ``reduction='mean'`` (the
            ``nn.CrossEntropyLoss`` default) so it can be re-weighted by batch size.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over every validation sample, with the
        loss weighted by batch size.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses (batch mean * batch size)
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc="Validation")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()

            # Use our own counters: tqdm only refreshes pbar.n periodically, so it lags.
            pbar.set_postfix({
                'loss': f'{loss_sum / total:.3f}',
                'acc': f'{100. * correct / total:.2f}%'
            })

    if total == 0:
        raise ValueError("val_loader yielded no batches")
    return loss_sum / total, 100. * correct / total

# Confirms that this cell's class and function definitions executed without error.
print("Model and training functions defined!")

## Step 5: Load Data and Prepare Training

In [None]:
# Find the folder that actually holds the per-country directories. Many zips wrap
# everything in one top-level folder; macOS-created zips also add a __MACOSX
# metadata folder, and hidden dot-folders may appear. Ignore those so they don't
# stop us from descending into the real top-level folder.
data_dir = EXTRACT_DIR
subdirs = [
    d for d in Path(EXTRACT_DIR).iterdir()
    if d.is_dir() and not d.name.startswith(('.', '__MACOSX'))
]
if len(subdirs) == 1:
    data_dir = str(subdirs[0])

print(f"Data directory: {data_dir}")

# Load dataset
image_paths, labels, label_encoder = load_dataset(data_dir)

# Save the index -> country-name mapping to Google Drive so predictions can be decoded later.
output_dir = "/content/drive/MyDrive/geolocation_model"
os.makedirs(output_dir, exist_ok=True)

# json.dump writes the integer keys as strings ("0", "1", ...).
label_mapping = {i: label for i, label in enumerate(label_encoder.classes_)}
with open(f"{output_dir}/label_mapping.json", 'w') as f:
    json.dump(label_mapping, f, indent=2)

print("\nLabel mapping saved to Google Drive!")

In [None]:
# Drop classes too small to stratify, split train/val, and build the DataLoaders.
from collections import Counter

print("Filtering dataset...")

# A stratified split needs at least this many samples in every class.
min_samples = 2
class_counts = Counter(labels)

# Keep only (path, label) pairs whose class is common enough, preserving order.
image_paths_filtered = []
labels_filtered = []
for path, label in zip(image_paths, labels):
    if class_counts[label] >= min_samples:
        image_paths_filtered.append(path)
        labels_filtered.append(label)

removed_samples = len(image_paths) - len(image_paths_filtered)
remaining_classes = len(set(labels_filtered))

print(f"Removed {removed_samples} images from classes with < {min_samples} samples")
print(f"Remaining: {len(image_paths_filtered)} images across {remaining_classes} countries")

# Stratify so the validation set keeps each country's share of the images.
X_train, X_val, y_train, y_val = train_test_split(
    image_paths_filtered,
    labels_filtered,
    test_size=0.2,
    stratify=labels_filtered,
    random_state=42,
)

print(f"\nTrain set: {len(X_train)} images")
print(f"Val set: {len(X_val)} images")

# Randomised augmentation for training; the deterministic pipeline for validation.
train_dataset = StreetViewDataset(X_train, y_train, transform=get_transforms(training=True))
val_dataset = StreetViewDataset(X_val, y_val, transform=get_transforms(training=False))

# Settings shared by both loaders; pinned host memory speeds up copies to the GPU.
BATCH_SIZE = 64
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print("\nData loaders created!")

## Step 6: Train the Model!

This will take approximately **2-3 hours** on a T4 GPU.

Progress will be shown with:
- Loss (lower is better)
- Accuracy (higher is better)

The best model so far (highest validation accuracy) is automatically saved to your Google Drive as `best_model.pth`.

In [None]:
# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Sized from the full encoder (not the filtered labels) so output index i always
# corresponds to entry i of label_mapping.json.
num_classes = len(label_encoder.classes_)
EPOCHS = 20
LEARNING_RATE = 0.001

# Create model
print(f"\nCreating model with {num_classes} classes...")
model = GeoLocalizationModel(num_classes, backbone='resnet50', pretrained=True)
model = model.to(device)

# Loss and optimizer; halve the LR when validation loss stalls for 3 epochs.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print("\nStarting training...\n")
print("=" * 60)

history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_acc = 0.0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 60)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Update learning rate
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Persist the curves every epoch so a Colab disconnect doesn't lose them.
    with open(f"{output_dir}/training_history.json", 'w') as f:
        json.dump(history, f, indent=2)

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Save best model to Google Drive. Always save after the first epoch so a
    # checkpoint exists even if validation accuracy never rises above 0%.
    if epoch == 0 or val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'num_classes': num_classes,
            'backbone': 'resnet50'
        }, f"{output_dir}/best_model.pth")
        print(f"  ✓ New best model saved to Google Drive! (Val Acc: {val_acc:.2f}%)")

print("\n" + "=" * 60)
print("Training complete!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"Model saved to: {output_dir}/best_model.pth")
print("=" * 60)

## Step 7: Plot Training Results

In [None]:
# Persist the raw per-epoch metrics next to the model.
with open(f"{output_dir}/training_history.json", 'w') as f:
    json.dump(history, f, indent=2)

# Side-by-side curves: loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# (axis, history key suffix, legend word, y-axis label, title)
panels = [
    (ax1, 'loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    (ax2, 'acc', 'Accuracy', 'Accuracy (%)', 'Training and Validation Accuracy'),
]
for ax, key, legend_word, y_label, title in panels:
    ax.plot(history[f'train_{key}'], label=f'Train {legend_word}', marker='o')
    ax.plot(history[f'val_{key}'], label=f'Val {legend_word}', marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig(f"{output_dir}/training_history.png", dpi=150, bbox_inches='tight')
plt.show()

print("\nTraining plots saved to Google Drive!")

## Step 8: Download Model to Your Computer

Your trained model is now in your Google Drive at:
`MyDrive/geolocation_model/`

Files saved:
- `best_model.pth` - Your trained model
- `label_mapping.json` - Country name mappings
- `training_history.json` - Training metrics
- `training_history.png` - Accuracy/loss graphs

You can download these from Google Drive and use them with the prediction script on your local machine!

In [None]:
# Show what ended up in the Drive output folder, with sizes in MB.
import os

separator = "=" * 60
print("Files saved to Google Drive:")
print(separator)
for name in os.listdir(output_dir):
    size_mb = os.path.getsize(os.path.join(output_dir, name)) / (1024**2)
    print(f"  {name:30s} - {size_mb:.2f} MB")
print(separator)
print("\nAccess them at: MyDrive/geolocation_model/")