In [26]:
# Install necessary libraries (from first code, Cell 1)
try:
    import torch, torchvision, transformers, torchmetrics, PIL, exifread, requests, google
except ImportError:
    !pip install torch torchvision transformers torchmetrics pillow exifread requests

# Imports and custom dataset (from first code, Cell 2)
import os
os.environ['TK_SILENCE_DEPRECATION'] = '1'  # Suppress Tkinter warnings
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import ViTModel
from torch.utils.data import DataLoader, Dataset
import exifread
import requests
from google.colab import files

class SatelliteWildfireDataset(Dataset):
    def __init__(self, image_dir):
        self.image_dir = image_dir
        self.images = []
        self.labels = []
        categories = ['Smoke', 'Seaside', 'Land', 'Haze', 'Dust', 'Cloud']
        for i in range(len(categories)):
            category_name = categories[i]
            folder = os.path.join(image_dir, category_name)
            if not os.path.exists(folder):
                continue
            files = os.listdir(folder)
            for file in files:
                if file.endswith('.tif'):
                    self.images.append(os.path.join(folder, file))
                    self.labels.append(i)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[index]
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image = transform(image)
        return image, label

# Custom ViT model (from first code, Cell 3)
class WildfireViTModel(nn.Module):
    def __init__(self):
        super(WildfireViTModel, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        for parameters in self.vit.parameters():
            parameters.requires_grad = False
        self.extra_layer = nn.Linear(768, 256)
        self.relu = nn.ReLU()
        self.final_layer = nn.Linear(256, 6)

    def forward(self, input_images):
        outputs = self.vit(pixel_values=input_images)
        cls_output = outputs.last_hidden_state[:, 0, :]
        hidden = self.extra_layer(cls_output)
        activated = self.relu(hidden)
        logits = self.final_layer(activated)
        return logits

# Utility functions: classification (first code, Cell 6), GPS, and fire station (second code)
def classify_image(image_path, model):
    try:
        image = Image.open(image_path).convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        input_tensor = transform(image)
        input_batch = input_tensor.unsqueeze(0)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_batch = input_batch.to(device)
        model = model.to(device)
        with torch.no_grad():
            output = model(input_batch)
        _, predicted_class = torch.max(output, 1)
        class_labels = ['Smoke', 'Seaside', 'Land', 'Haze', 'Dust', 'Cloud']
        predicted_label = class_labels[predicted_class.item()]
        if predicted_label == 'Smoke':
            return f"Predicted class: {predicted_label}"
        return "No smoke detected"
    except Exception as e:
        return f"Error classifying image: {e}"

def get_gps_coordinates(image_path):
    try:
        with open(image_path, 'rb') as f:
            tags = exifread.process_file(f)
            if all(k in tags for k in ['GPS GPSLatitude', 'GPS GPSLatitudeRef', 'GPS GPSLongitude', 'GPS GPSLongitudeRef']):
                lat, lat_ref = tags['GPS GPSLatitude'].values, tags['GPS GPSLatitudeRef'].values
                lon, lon_ref = tags['GPS GPSLongitude'].values, tags['GPS GPSLongitudeRef'].values
                lat_deg = float(lat[0].num)/float(lat[0].den) + float(lat[1].num)/(float(lat[1].den)*60) + float(lat[2].num)/(float(lat[2].den)*3600)
                if lat_ref == 'S':
                    lat_deg *= -1
                lon_deg = float(lon[0].num)/float(lon[0].den) + float(lon[1].num)/(float(lon[1].den)*60) + float(lon[2].num)/(float(lon[2].den)*3600)
                if lon_ref == 'W':
                    lon_deg *= -1
                return lat_deg, lon_deg
    except Exception as e:
        print(f"Error reading EXIF data: {e}")
    return None

def get_fire_station_phone_number(lat, lon, api_key):
    try:
        url = f"https://places.googleapis.com/v1/places:searchNearby"
        payload = {
            "locationRestriction": {
                "circle": {
                    "center": {
                        "latitude": lat,
                        "longitude": lon
                    },
                    "radius": 5000.0
                }
            },
            "includedTypes": ["fire_station"]
        }
        headers = {
            "Content-Type": "application/json",
            "X-Goog-Api-Key": api_key,
            "X-Goog-FieldMask": "places.displayName,places.formattedAddress,places.internationalPhoneNumber"
        }
        response = requests.post(url, json=payload, headers=headers)
        data = response.json()
        if response.status_code != 200:
            return f"API Error: {response.status_code}"
        if "places" not in data or not data["places"]:
            return "No fire stations found nearby"
        place = data["places"][0]
        name = place.get("displayName", {}).get("text", "Unnamed Fire Station")
        phone = place.get("internationalPhoneNumber", "Phone number not available")
        address = place.get("formattedAddress", "Address unavailable")
        return {
            "name": name,
            "phone": phone,
            "address": address
        }
    except Exception as e:
        return f"Error retrieving phone number: {str(e)}"

# Colab-compatible inference (adapted from first code, Cell 6)
def run_inference(model):
    print("Please upload an image to test!")
    uploaded = files.upload()
    file_name = list(uploaded.keys())[0]
    print(f"Got your file: {file_name}")

    classification = classify_image(file_name, model)

    if classification == 'Predicted class: Smoke':
        location = get_gps_coordinates(file_name)
        if location:
            lat, lon = location
            api_key = "YOUR_API_KEY"  # Replace with your Google Places API key
            fire_station = get_fire_station_phone_number(lat, lon, api_key)
            if isinstance(fire_station, dict):
                alert_message = (
                    f"Wildfire detected at:\n\n"
                    f"{lat}, {lon}\n\n"
                    f"Nearest Fire Station:\n"
                    f"{fire_station['name']}\n"
                    f"{fire_station['phone']}\n"
                    f"{fire_station['address']}"
                )
            else:
                alert_message = (
                    f"Wildfire detected at:\n\n"
                    f"{lat}, {lon}\n\n"
                    f"Nearest Fire Station:\n"
                    f"{fire_station}"
                )
        else:
            alert_message = "Smoke detected but no location data found."
    else:
        alert_message = classification

    print("Analysis completed")
    print(alert_message)

    os.remove(file_name)
    print(f"Deleted {file_name} from Colab.")

# Main execution (from first code, Cell 5, adapted)
if __name__ == "__main__":
    model = WildfireViTModel()
    model_path = "/content/WEIGHTS/wildfire_model.pth"  # Weights will be uploaded
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
        print("Model loaded successfully!")
    else:
        print("Model weights not found. Please upload wildfire_model.pth to /content/WEIGHTS/ or train the model.")
        # Training (from first code, Cell 5, commented out)
        """
        from torchmetrics import Accuracy, F1Score
        def train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, epochs):
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = model.to(device)
            class_count = 6
            train_accuracy_metric = Accuracy(task="multiclass", num_classes=class_count).to(device)
            val_accuracy_metric = Accuracy(task="multiclass", num_classes=class_count).to(device)
            f1_metric = F1Score(task="multiclass", num_classes=class_count, average='macro').to(device)
            for epoch in range(epochs):
                model.train()
                train_loss = 0
                train_accuracy_metric.reset()
                f1_metric.reset()
                for images, labels in train_loader:
                    images, labels = images.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    train_loss += loss.item()
                    preds = torch.argmax(outputs, dim=1)
                    train_accuracy_metric.update(preds, labels)
                    f1_metric.update(preds, labels)
                avg_train_loss = train_loss / len(train_loader)
                train_accuracy = train_accuracy_metric.compute().item()
                train_f1 = f1_metric.compute().item()
                model.eval()
                val_loss = 0
                val_accuracy_metric.reset()
                f1_metric.reset()
                with torch.no_grad():
                    for images, labels in test_loader:
                        images, labels = images.to(device), labels.to(device)
                        outputs = model(images)
                        loss = criterion(outputs, labels)
                        val_loss += loss.item()
                        preds = torch.argmax(outputs, dim=1)
                        val_accuracy_metric.update(preds, labels)
                        f1_metric.update(preds, labels)
                avg_val_loss = val_loss / len(test_loader)
                val_accuracy = val_accuracy_metric.compute().item()
                val_f1 = f1_metric.compute().item()
                print(f"Epoch {epoch + 1}:")
                print(f"  Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Train F1: {train_f1:.4f}")
                print(f"  Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")
            return model

        from google.colab import drive
        drive.mount('/content/drive')
        image_dir = "/content/drive/MyDrive/archive"  # Update with your dataset path
        dataset = SatelliteWildfireDataset(image_dir)
        train_size = int(0.8 * len(dataset))
        test_size = len(dataset) - train_size
        train_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size])
        train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        model = train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, epochs=1)
        torch.save(model.state_dict(), "/content/WEIGHTS/wildfire_model.pth")
        print("Model trained and saved!")
        """
    model.eval()
    run_inference(model)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model weights not found. Please upload wildfire_model.pth to /content/WEIGHTS/ or train the model.
Please upload an image to test!


Saving smoke_444.tif to smoke_444.tif
Got your file: smoke_444.tif
Analysis completed
No smoke detected
Deleted smoke_444.tif from Colab.
