In [None]:
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from src.siamese_model import SiameseNetwork
from src.fetch_historical import fetch_landsat_tile
from src.fetch_current import fetch_google_tile

class TileDataset(Dataset):
    """Dataset of (historical tile, current tile, label) triples, one per coordinate.

    For each ``(lat, lon)`` pair the historical tile is obtained from
    ``fetch_landsat_tile`` and the current tile from ``fetch_google_tile``.
    Both are resized to ``tile_size`` x ``tile_size`` and returned as float32
    tensors in channels-first layout. The label is a placeholder ``0.0``.

    If either fetch returns ``None``, BOTH tiles are replaced with all-zero
    3-channel images so the pair still has a fixed shape.

    Args:
        coords_list: Sequence of ``(lat, lon)`` pairs; one dataset item per pair.
        tile_size: Side length in pixels that both tiles are resized to.
            Defaults to 256, the size the original pipeline used.
    """

    def __init__(self, coords_list, tile_size=256):
        super().__init__()
        self.coords_list = coords_list
        self.tile_size = tile_size

    def __len__(self):
        return len(self.coords_list)

    @staticmethod
    def _ensure_channel_axis(image):
        """Return *image* with a trailing channel axis: (H, W) -> (H, W, 1); 3-D input is unchanged."""
        # cv2.resize silently drops a singleton channel axis, and a single-band
        # source may be 2-D to begin with; transpose(2, 0, 1) needs three axes.
        return image[:, :, np.newaxis] if image.ndim == 2 else image

    def _to_chw_tensor(self, image):
        """Resize *image* to ``tile_size`` and return it as a float32 (C, H, W) tensor."""
        # cv2.resize takes dsize as (width, height); the tile is square here.
        resized = cv2.resize(image, (self.tile_size, self.tile_size))
        resized = self._ensure_channel_axis(resized)
        return torch.tensor(resized.transpose(2, 0, 1), dtype=torch.float32)

    def __getitem__(self, idx):
        lat, lon = self.coords_list[idx]
        hist = fetch_landsat_tile(lat, lon)
        curr = fetch_google_tile(lat, lon)

        if hist is None or curr is None:
            # Fallback: if either fetch failed, blank out BOTH tiles so the pair
            # keeps a fixed shape. NOTE(review): such pairs are still trained on
            # with the placeholder label below — consider filtering them instead.
            hist = np.zeros((self.tile_size, self.tile_size, 3), dtype=np.float32)
            curr = np.zeros((self.tile_size, self.tile_size, 3), dtype=np.float32)

        hist_tensor = self._to_chw_tensor(hist)
        curr_tensor = self._to_chw_tensor(curr)
        label = torch.tensor(0.0)  # placeholder until real change labels exist

        return hist_tensor, curr_tensor, label

coords = [(37.7749, -122.4194), (37.7849, -122.4094)]
dataset = TileDataset(coords)
loader = DataLoader(dataset, batch_size=2)

model = SiameseNetwork()
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes
# SiameseNetwork ends in a sigmoid — confirm; otherwise use BCEWithLogitsLoss.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop (simplified)
model.train()
for epoch in range(2):
    epoch_loss = 0.0
    num_batches = 0
    for hist, curr, label in loader:
        optimizer.zero_grad()
        # reshape (not squeeze) so a size-1 batch stays shape (1,) and matches
        # label; squeeze() would collapse it to a 0-d tensor and BCELoss rejects
        # mismatched input/target sizes.
        output = model(hist, curr).reshape(label.shape)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Report the epoch mean rather than only the final batch's loss; this also
    # avoids referencing an unbound `loss` when the loader yields no batches.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}, Loss: {mean_loss}")

# torch.save does not create missing parent directories.
checkpoint_path = Path("models") / "model_v1.pth"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)