# TerraFusion Property Condition Model Training

This notebook trains a model to predict property condition from images.
The model will output a score from 1-5, where:
- 1: Poor condition
- 2: Fair condition
- 3: Average condition
- 4: Good condition
- 5: Excellent condition

We'll use a pre-trained MobileNetV2 model and fine-tune it for our classification task.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import random
from pathlib import Path

In [None]:
# Seed each random number generator this notebook relies on (PyTorch, the
# standard library, NumPy) with the same fixed value so reruns are repeatable.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(42)

## 1. Configuration

In [None]:
# Filesystem locations
DATASET_DIR = "dataset"  # root folder holding the per-grade subfolders 1-5
MODEL_SAVE_PATH = "models/condition_model.pth"  # where the trained weights are written

# Training hyperparameters
NUM_CLASSES = 5  # condition grades 1-5
NUM_EPOCHS = 5
BATCH_SIZE = 16
LEARNING_RATE = 0.001

# Run on the GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Dataset Preparation

In [None]:
# Training pipeline: random crop / flip / rotation / colour jitter, then
# conversion to a tensor and per-channel normalisation.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
train_transforms = transforms.Compose(_train_steps)

# Validation pipeline: no random steps, just resize, centre crop to 224,
# tensor conversion and the same normalisation as training.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
val_transforms = transforms.Compose(_val_steps)

In [None]:
# Custom dataset class for property condition images
class PropertyConditionDataset(Dataset):
    """Image dataset read from ``root_dir/<grade>/`` folders.

    Each subfolder is named after a condition grade (``1`` .. ``num_classes``).
    Every image inside it becomes one ``(path, label)`` entry in
    ``self.samples``, where ``label`` is the zero-based grade (``grade - 1``).

    Args:
        root_dir: Folder that contains the per-grade subfolders.
        transform: Optional callable applied to the loaded RGB PIL image.
        num_classes: Number of grade folders to scan, starting at ``1``.
    """

    # File suffixes (compared case-insensitively) that count as images
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, root_dir, transform=None, num_classes=5):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Collect all image paths and labels
        for condition in range(1, num_classes + 1):
            condition_dir = os.path.join(root_dir, str(condition))
            # isdir (not exists): a stray file named like a grade must be
            # skipped rather than crash os.listdir().
            if not os.path.isdir(condition_dir):
                continue
            # os.listdir() returns entries in arbitrary order; sorting keeps
            # sample indices stable, so a seeded random split is reproducible.
            for img_name in sorted(os.listdir(condition_dir)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    img_path = os.path.join(condition_dir, img_name)
                    # Labels are 0-based (grade 1 -> label 0, ...)
                    self.samples.append((img_path, condition - 1))

        # Warn (but do not fail) when nothing was found
        if not self.samples:
            print(f"WARNING: No image samples found in {root_dir}")
            print("Please make sure you have images in the dataset/1, dataset/2, etc. folders")

    def __len__(self):
        """Return the number of collected ``(path, label)`` samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one is set. If loading or transforming raises, the error is printed
        and an all-zero ``(3, 224, 224)`` tensor is returned with the original
        label, so one unreadable file does not abort a whole epoch.
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the file handle as soon as the pixel
            # data has been copied into the converted image.
            with Image.open(img_path) as img_file:
                image = img_file.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return image, label

        except Exception as e:  # deliberate best-effort: keep training going
            print(f"Error loading image {img_path}: {str(e)}")
            # Return a fallback image and the same label
            fallback = torch.zeros((3, 224, 224))
            return fallback, label

In [None]:
import copy

# Create the dataset (scans DATASET_DIR once); it carries the training transforms
dataset = PropertyConditionDataset(DATASET_DIR, transform=train_transforms)

# random_split() returns two Subset views over the SAME dataset object, so
# assigning `val_dataset.dataset.transform` would also replace the transforms
# used for training (silently disabling augmentation). Validation therefore
# gets its own shallow copy: it shares the `samples` list with `dataset` but
# holds an independent `transform` attribute.
val_view = copy.copy(dataset)
val_view.transform = val_transforms

# Calculate split sizes
total_size = len(dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # 20% for validation

# Split the dataset, then re-point the validation indices at the copy that
# uses the validation transforms
train_dataset, val_split = torch.utils.data.random_split(dataset, [train_size, val_size])
val_dataset = torch.utils.data.Subset(val_view, val_split.indices)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Total samples: {total_size}")
print(f"Training samples: {train_size}")
print(f"Validation samples: {val_size}")

# Count samples per class (labels are 0-based, printed grades are 1-based)
class_counts = [0] * NUM_CLASSES
for _, label in dataset.samples:
    class_counts[label] += 1

for condition, count in enumerate(class_counts, 1):
    print(f"Condition {condition}: {count} samples")

## 3. Model Architecture

In [None]:
# Load pre-trained MobileNetV2
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version before changing.
model = models.mobilenet_v2(pretrained=True)

# Modify the classifier to output 5 classes
# The layer at classifier[1] is replaced; its `in_features` is reused so the new
# nn.Linear accepts the same input width and emits NUM_CLASSES outputs.
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_features, NUM_CLASSES)

# Move model to device
model = model.to(device)

# Define loss function and optimizer
# Adam receives model.parameters(), i.e. all parameters of the network are
# optimised, not only the replaced classifier layer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

## 4. Training Function

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise it is put in eval mode and run without
    gradient tracking. Both values are plain Python floats, averaged over
    ``len(loader.dataset)``. An empty dataset yields ``(0.0, 0.0)``.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_correct = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean, so weight it by the batch size
            total_loss += loss.item() * inputs.size(0)
            total_correct += (preds == labels).sum().item()

    num_samples = len(loader.dataset)
    if num_samples == 0:
        # Nothing was seen; avoid dividing by zero
        return 0.0, 0.0
    return total_loss / num_samples, total_correct / num_samples


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device,
                save_path=None):
    """Train ``model`` for ``num_epochs`` and checkpoint the best epoch.

    Args:
        model: Network to train (updated in place).
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        num_epochs: Number of full train + validation passes.
        device: Device the batches are moved to.
        save_path: File the best ``state_dict`` is written to. Defaults to the
            module-level ``MODEL_SAVE_PATH``.

    Returns:
        ``(model, history)``. ``history`` maps ``'train_loss'``, ``'val_loss'``,
        ``'train_acc'`` and ``'val_acc'`` to lists holding one float per epoch.
        ``model`` holds the weights of the LAST epoch; the best weights are
        only on disk.
    """
    if save_path is None:
        save_path = MODEL_SAVE_PATH

    # History for plotting
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    best_val_acc = 0.0

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Print epoch stats
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
        print('-' * 40)

        # Save best model (strictly better validation accuracy only)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # dirname is '' for a bare filename, and os.makedirs('') raises
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {val_acc:.4f}")

    return model, history

## 5. Train the Model

In [None]:
# Training only makes sense with at least one sample; otherwise point the
# user at the folders that need images.
if len(dataset) == 0:
    print("No training samples found. Add images to dataset/1, dataset/2, etc. folders.")
else:
    print("Starting training...")
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=NUM_EPOCHS,
        device=device,
    )
    print("Training complete!")

## 6. Visualize Training Results

In [None]:
if 'history' in locals():
    # One row, two panels: loss on the left, accuracy on the right
    plt.figure(figsize=(12, 4))

    # (train key, train legend, val key, val legend, y-axis label, panel title)
    panels = [
        ('train_loss', 'Training Loss', 'val_loss', 'Validation Loss',
         'Loss', 'Loss over Epochs'),
        ('train_acc', 'Training Accuracy', 'val_acc', 'Validation Accuracy',
         'Accuracy', 'Accuracy over Epochs'),
    ]

    for panel_number, panel in enumerate(panels, start=1):
        train_key, train_label, val_key, val_label, y_label, panel_title = panel
        plt.subplot(1, 2, panel_number)
        plt.plot(history[train_key], label=train_label)
        plt.plot(history[val_key], label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(panel_title)
        plt.legend()

    plt.tight_layout()
    plt.show()

## 7. Test the Model on Sample Images

In [None]:
def predict_condition(model, img_path, device):
    """Predict the condition grade of a single image file.

    Args:
        model: Classifier whose output row holds one logit per grade.
        img_path: Path of the image to score.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        ``(condition_class, weighted_score, probs)`` where ``condition_class``
        is the most likely grade on the 1-based scale, ``weighted_score`` is
        the probability-weighted mean grade, and ``probs`` is a NumPy array of
        the per-grade softmax probabilities.

    Note:
        The model is switched to eval mode and left that way.
    """
    # Same deterministic preprocessing as the validation pipeline
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Close the file handle as soon as the pixels have been copied
    with Image.open(img_path) as img_file:
        img = img_file.convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Get predictions
    model.eval()
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]

        # Get class prediction
        _, predicted = torch.max(outputs, 1)
        condition_class = predicted.item() + 1  # Convert back to the 1-based scale

    # Probability-weighted mean grade; iterating over the probabilities
    # themselves follows the model's output width instead of assuming 5.
    weighted_score = sum(
        grade * prob for grade, prob in enumerate(probs.tolist(), start=1)
    )

    return condition_class, weighted_score, probs.cpu().numpy()

In [None]:
# Test on a few validation samples if we have samples and a saved checkpoint
if len(val_dataset) > 0 and os.path.exists(MODEL_SAVE_PATH):
    # Load the best saved model. map_location remaps the stored tensors onto
    # the current device, so a checkpoint written on a GPU still loads on a
    # CPU-only machine.
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    model.eval()

    # Get a few samples (at most 5, without repeats)
    sample_indices = np.random.choice(len(val_dataset), min(5, len(val_dataset)), replace=False)

    plt.figure(figsize=(15, 10))
    for i, idx in enumerate(sample_indices):
        # Map the position inside the validation subset back to the
        # underlying dataset's (path, label) entry
        img_path, true_label = val_dataset.dataset.samples[val_dataset.indices[idx]]
        true_label += 1  # Convert back to 1-5 scale

        # Predict
        pred_class, weighted_score, probs = predict_condition(model, img_path, device)

        # Display image and predictions; the converted copy stays usable
        # after the file handle is closed
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        plt.subplot(2, 3, i+1)
        plt.imshow(img)
        plt.title(f"True: {true_label}, Pred: {pred_class}\nScore: {weighted_score:.2f}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    print("Model saved to:", os.path.abspath(MODEL_SAVE_PATH))
else:
    print("No model saved or no validation samples to test.")

## Next Steps

To use this trained model in production:

1. Gather a balanced dataset of property images labeled by condition (at least 20-30 images per class)
2. Train the model using this notebook
3. The model will be saved to `models/condition_model.pth`
4. The `model_loader.py` script will automatically use this model for inference