# Athena Phase 2: Training the Micro Inspector

This notebook trains a U-Net segmentation model on the TTPLA dataset to detect power lines (cables), towers, and vegetation.

In [None]:
import os
import json
import numpy as np
import cv2
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

## 1. Dataset Class

In [None]:
class TTPLADataset(Dataset):
    """Semantic-segmentation dataset over a folder of TTPLA images.

    Every ``.jpg`` file in ``root_dir`` is one sample. If a ``.json`` file
    with the same stem exists next to it, the entries of its ``shapes``
    list (each with a ``label`` and a list of ``points``) are rasterised
    into a class-index mask. Labels are matched case-insensitively by
    substring:

    * 0 - background (default, and any label matching nothing below)
    * 1 - label contains ``cable`` or ``wire``
    * 2 - label contains ``tower``
    * 3 - label contains ``tree`` or ``vegetation``

    Args:
        root_dir: Directory holding the images and their annotation files.
        transform: Stored on the instance as ``self.transform``.
            NOTE: it is not applied inside ``__getitem__``.
        target_size: ``(width, height)`` passed to ``cv2.resize`` for both
            the image and the mask. Defaults to ``(512, 512)``.
    """

    # Ordered (keywords, class_id) rules; the first rule with a keyword
    # contained in the lower-cased label wins.
    _LABEL_RULES = (
        (('cable', 'wire'), 1),
        (('tower',), 2),
        (('tree', 'vegetation'), 3),
    )

    def __init__(self, root_dir, transform=None, target_size=(512, 512)):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same sample on every run and every machine.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.endswith('.jpg')
        )
        self.transform = transform
        self.target_size = target_size

        # Class mapping (list index == class id written into the mask)
        self.classes = ['background', 'cable', 'tower', 'vegetation']

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``root_dir``."""
        return len(self.image_files)

    @classmethod
    def _class_id_for_label(cls, label):
        """Map an annotation label to its class id (0 when nothing matches)."""
        label = label.lower()
        for keywords, class_id in cls._LABEL_RULES:
            if any(keyword in label for keyword in keywords):
                return class_id
        return 0

    def _build_mask(self, json_path, size):
        """Rasterise the annotation file into an ``(h, w)`` uint8 class mask.

        A missing annotation file yields an all-background mask.
        """
        h, w = size
        mask = np.zeros((h, w), dtype=np.uint8)

        if not os.path.exists(json_path):
            return mask

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for shape in data.get('shapes', []):
            class_id = self._class_id_for_label(shape['label'])
            if class_id == 0:
                continue

            points = np.array(shape['points'], dtype=np.int32)
            # Skip empty or malformed point lists instead of crashing in OpenCV.
            if points.ndim != 2 or points.shape[0] == 0:
                continue
            cv2.fillPoly(mask, [points], class_id)

        return mask

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for sample ``idx``.

        Returns:
            image: float tensor in channel-first layout, scaled to [0, 1].
            mask: long tensor of class ids, resized with nearest-neighbour
                interpolation so no intermediate class values are created.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)
        # splitext only swaps the real extension; str.replace would also
        # rewrite a '.jpg' occurring earlier in the file name.
        json_path = os.path.join(
            self.root_dir, os.path.splitext(img_name)[0] + '.json'
        )

        # Load Image (cv2.imread signals failure by returning None)
        image = cv2.imread(img_path)
        if image is None:
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Create Mask at the original resolution, before resizing
        mask = self._build_mask(json_path, image.shape[:2])

        # Simple Resize for testing (should be robust transform in prod)
        image = cv2.resize(image, self.target_size)
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        # To Tensor
        image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
        mask = torch.from_numpy(mask).long()

        return image, mask

In [None]:
# Build the dataset and eyeball the first sample next to its mask.
DATA_DIR = '../data/TTPLA'
dataset = TTPLADataset(DATA_DIR)
print(f"Dataset Size: {len(dataset)}")

img, mask = dataset[0]

plt.figure(figsize=(10, 5))
# The image tensor is channel-first; imshow wants channels last.
panels = (("Image", img.permute(1, 2, 0)), ("Mask", mask))
for position, (panel_title, panel_data) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(panel_data)
    plt.title(panel_title)
plt.show()

## 2. Model Setup

In [None]:
# U-Net with a ResNet-18 encoder initialised from "imagenet" weights.
# Three input channels (RGB) and four output classes, matching the
# dataset's background / cable / tower / vegetation mapping.
unet_config = {
    "encoder_name": "resnet18",
    "encoder_weights": "imagenet",
    "in_channels": 3,
    "classes": 4,
}
model = smp.Unet(**unet_config)

## 3. Training Loop (Prototype)

In [None]:
# Prototype training run: Adam optimiser with pixel-wise cross-entropy.
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# CrossEntropyLoss consumes the model's raw per-class outputs together with
# the integer (long) class-id masks produced by the dataset.
criterion = torch.nn.CrossEntropyLoss()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Training on {device}")
model.to(device)

epochs = 1

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for accumulation and logging.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        if i % 10 == 0:
            print(f"Epoch {epoch+1}, Step {i}, Loss: {batch_loss}")

    # Report the epoch mean; skipped when the loader yielded no batches.
    if num_batches:
        print(f"Epoch {epoch+1}, Average Loss: {running_loss / num_batches}")

print("Training Finished")

# Create the output directory first so a missing folder cannot throw away
# the trained weights at the very last step.
model_path = "../models/unet_ttpla_prototype.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)