In [None]:
import os
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [None]:
class ImagePositionPredictor(nn.Module):
    """Siamese CNN that regresses a 2-D offset plus a confidence from an image pair.

    Both images run through the *same* five conv -> batch-norm -> ReLU stages
    (the weights are shared between the two images, not duplicated). Each result
    is global-average-pooled to a 256-d vector. The two vectors are concatenated
    into a 512-d embedding and passed through a three-layer MLP.

    ``forward(img1, img2)`` returns a ``(batch, 3)`` tensor:
    columns 0-1 are squashed into (-1, 1) with tanh; column 2 into (0, 1) with sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Shared feature extractor; every stage downsamples by 2 (stride 2).
        self.conv1 = nn.Conv2d(3, 16, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

        # Collapse the remaining spatial extent to 1x1 so the head's input size
        # does not depend on the image resolution.
        self.gap = nn.AdaptiveAvgPool2d(1)

        # Regression head on the concatenated pair embedding (2 images x 256 channels).
        self.fc1 = nn.Linear(2 * 256, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 3)

        # One BatchNorm2d per conv stage, sized to that stage's output channels.
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(128)
        self.bn5 = nn.BatchNorm2d(256)

    def forward(self, img1, img2):
        # Embed img1 first, then img2 (the order matters for BatchNorm running-stat updates).
        pair_features = torch.cat([self._process_single_image(img) for img in (img1, img2)], dim=1)

        hidden = F.relu(self.fc1(pair_features))
        hidden = F.relu(self.fc2(hidden))
        raw = self.fc3(hidden)

        position = torch.tanh(raw[:, :2])        # x, y in (-1, 1)
        confidence = torch.sigmoid(raw[:, 2:3])  # confidence in (0, 1), kept 2-D
        return torch.cat([position, confidence], dim=1)

    def _process_single_image(self, img):
        """Run one image batch through the shared conv stack; returns (batch, 256)."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        features = img
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))
        return torch.flatten(self.gap(features), 1)

In [None]:
# Model size: total number of elements across every parameter tensor of a fresh model
model = ImagePositionPredictor()
param_counts = [param.numel() for param in model.parameters()]
total_params = sum(param_counts)
print(f"Total parameters: {total_params}")

In [None]:
# Smoke test: push one random 224x224 RGB pair (batch of 1) through a fresh model
model = ImagePositionPredictor()
img1, img2 = (torch.randn(1, 3, 224, 224) for _ in range(2))
output = model(img1, img2)
print(output)

In [None]:
# Exploratory pass over the capture folders: pair every non-base image with the position
# stored in its same-stem .txt file. Files whose name contains "_1." form the base capture.
folder_paths=[r"D:\Unity\AITX_PanLoc\Assets\data\21.34938_-12.25565_folder"]


def _parse_position(txt_path):
    """Return the comma-separated floats between the first "(" and ")" on the
    first line of ``txt_path``, e.g. "(1.0, 2.0, 3.0)" -> [1.0, 2.0, 3.0]."""
    # `with` closes the handle deterministically; a bare open(...).read() leaves it to the GC.
    with open(txt_path) as f:
        first_line = f.read().split("\n")[0]
    return [float(x) for x in first_line.split("(")[1].split(")")[0].split(",")]


all_data=[]
image_paths_and_locations={}
for folder_path in folder_paths:
    # sorted() makes the file order (and therefore all_data) reproducible
    file_paths=sorted(os.listdir(folder_path))

    # Get the base image and its position
    base_image=None
    base_position=None
    for file_path in file_paths:
        if "_1." not in file_path:
            continue
        if file_path.endswith(".jpg"):
            base_image = file_path
        if file_path.endswith(".txt"):
            base_position = _parse_position(os.path.join(folder_path, file_path))

    if base_image and base_position:
        print("Found base image and position")
    else:
        # Without a base there is nothing to compare against; skip this folder
        # instead of failing later on a missing file name.
        print("Base image and position not found")
        continue

    # Group the non-base .jpg/.txt files by stem -> {"image": path, "location": [floats]}.
    # Other files (e.g. editor-generated sidecars) are ignored.
    image_paths_and_locations={}
    for file_path in file_paths:
        if "_1." in file_path or not file_path.endswith((".jpg", ".txt")):
            continue
        # splitext strips only the extension, so stems that contain dots
        # (e.g. "21.34938_-12.25565_2") stay intact and the .jpg/.txt pair shares a key.
        file_name = os.path.splitext(file_path)[0]
        entry = image_paths_and_locations.setdefault(file_name, {})
        if file_path.endswith(".jpg"):
            entry["image"] = os.path.join(folder_path, file_path)
        else:
            entry["location"] = _parse_position(os.path.join(folder_path, file_path))

    all_data+=list(image_paths_and_locations.values())

In [None]:
# Preview the collected samples (count + first few) instead of dumping the whole list
len(all_data), all_data[:5]

In [None]:
# Per-stem {"image", "location"} mapping; it is rebuilt per folder, so this only
# reflects the LAST folder scanned above. Show a few entries rather than all of them.
dict(list(image_paths_and_locations.items())[:5])

In [None]:
class ImagePairDataset(Dataset):
    """Pairs every capture in a folder with that folder's base image.

    Each folder in ``folder_paths`` must contain one base image (a ``.jpg`` whose
    name contains ``"_1."``) with a ``.txt`` companion, plus any number of other
    ``.jpg`` images that each have a same-stem ``.txt`` file. The first line of
    every ``.txt`` file must hold a parenthesised, comma-separated position such
    as ``"(x, y, z)"``. Folders without a base image/position are skipped with a
    message; images without a ``.txt`` are skipped silently.

    A sample is ``(base_img, img, target)`` with ``target = [x, y, confidence]``:
    the offset of the image's position from the base position along components
    0 and 2, divided by ``max_distance`` and clipped to [-1, 1]; confidence is
    always 1.0 for real samples.

    Args:
        folder_paths: iterable of directories to scan.
        transform: optional callable applied to each loaded PIL image.
        max_distance: positive offset that maps to +/-1 (0.1 by default).
        dummy_shape: shape of the zero tensor returned when an image pair cannot
            be read; should match the shape ``transform`` produces.
    """

    def __init__(self, folder_paths, transform=None, max_distance=0.1, dummy_shape=(3, 2048, 1024)):
        if max_distance <= 0:
            raise ValueError("max_distance must be positive")
        self.folder_paths = folder_paths
        self.transform = transform
        self.max_distance = max_distance
        self.dummy_shape = tuple(dummy_shape)
        self.image_pairs = self._load_image_pairs()

    @staticmethod
    def _read_position(txt_path):
        """Parse the floats between the first "(" and ")" on the first line of ``txt_path``."""
        with open(txt_path, 'r') as f:
            first_line = f.read().split("\n")[0]
        return [float(v) for v in first_line.split("(")[1].split(")")[0].split(",")]

    def _find_base(self, folder_path, file_names):
        """Return ``(base_image_path, base_position)`` for one folder; either may be None."""
        base_image = None
        base_position = None
        for name in file_names:
            if "_1." not in name:
                continue
            if name.endswith(".jpg"):
                base_image = os.path.join(folder_path, name)
            if name.endswith(".txt"):
                base_position = self._read_position(os.path.join(folder_path, name))
        return base_image, base_position

    def _load_image_pairs(self):
        """Scan all folders and build ``(base_path, img_path, x, y, confidence)`` tuples."""
        all_data = []
        for folder_path in self.folder_paths:
            # os.listdir order is arbitrary; sorting keeps dataset indices (and a
            # seeded random_split over them) reproducible.
            file_names = sorted(os.listdir(folder_path))

            base_image, base_position = self._find_base(folder_path, file_names)
            if not (base_image and base_position):
                print(f"Base image and position not found in {folder_path}")
                continue

            for name in file_names:
                if "_1." in name or not name.endswith(".jpg"):
                    continue

                # splitext strips only the extension, so stems containing dots
                # (e.g. "21.34938_-12.25565_2") still resolve to their .txt file.
                stem = os.path.splitext(name)[0]
                txt_path = os.path.join(folder_path, f"{stem}.txt")
                if not os.path.exists(txt_path):
                    continue

                location = self._read_position(txt_path)
                # Components 0 and 2 are used as the 2-D plane; presumably the
                # horizontal axes of the capture coordinates (e.g. Unity x/z) — confirm.
                rel_x = location[0] - base_position[0]
                rel_y = location[2] - base_position[2]
                norm_x = float(np.clip(rel_x / self.max_distance, -1, 1))
                norm_y = float(np.clip(rel_y / self.max_distance, -1, 1))
                confidence = 1.0

                all_data.append((base_image, os.path.join(folder_path, name), norm_x, norm_y, confidence))

        return all_data

    def __len__(self):
        return len(self.image_pairs)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file handle."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def __getitem__(self, idx):
        # Indexing stays outside the try so an out-of-range idx raises IndexError
        # as the sequence protocol requires (plain iteration relies on it to stop).
        base_img_path, img_path, x, y, confidence = self.image_pairs[idx]
        try:
            base_img = self._load_rgb(base_img_path)
            img = self._load_rgb(img_path)

            if self.transform:
                base_img = self.transform(base_img)
                img = self.transform(img)
        except (OSError, ValueError) as e:
            # Best effort for unreadable/missing image files only; programming
            # errors (NameError, TypeError, ...) propagate instead of being masked.
            print(f"Error loading image pair at index {idx}: {str(e)}")
            # Zero placeholder with confidence 0 so callers can spot failed loads.
            dummy_img = torch.zeros(self.dummy_shape)
            dummy_label = torch.tensor([0, 0, 0], dtype=torch.float)
            return dummy_img, dummy_img, dummy_label

        return base_img, img, torch.tensor([x, y, confidence], dtype=torch.float)

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    When ``optimizer`` is given the model is put in train mode and updated after
    every batch; otherwise it is put in eval mode and run without gradients.
    Returns NaN (instead of dividing by zero) when the loader's dataset is empty.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for img1, img2, labels in loader:
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(img1, img2)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # criterion returns a per-batch mean, so re-weight by batch size
            # before averaging over the whole dataset.
            total_loss += loss.item() * img1.size(0)
    num_samples = len(loader.dataset)
    return total_loss / num_samples if num_samples else float("nan")


def train_model(model, train_loader, val_loader, num_epochs, learning_rate, device):
    """Train ``model`` with MSE loss and Adam, printing train/val loss per epoch.

    Args:
        model: module called as ``model(img1, img2)``; assumed to be on ``device`` already.
        train_loader / val_loader: loaders yielding ``(img1, img2, labels)`` batches.
        num_epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: device every batch is moved to.

    Returns:
        The same ``model`` object, left in eval mode after the final validation pass.
        An empty train or validation set is reported as a ``nan`` loss.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    return model

# Preprocessing applied to each image of a pair: resize, convert the PIL image to a
# float tensor in [0, 1], then standardize every channel with the ImageNet statistics
# commonly used with torchvision models.
# NOTE(review): Resize takes (height, width), so tensors come out 2048 tall x 1024 wide;
# if the captures are wide panoramas this may be transposed — confirm.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((2048, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)


In [None]:
# Build the pair dataset, split it 80/20 with a fixed seed, and wrap both parts in loaders.
# (The model itself is created in the training cell below.)
folder_paths = [r"D:\Unity\AITX_PanLoc\Assets\data\21.34938_-12.25565_folder"]
SPLIT_SEED = 42

full_dataset = ImagePairDataset(folder_paths=folder_paths, transform=transform)
assert len(full_dataset) > 0, f"No image pairs found in {folder_paths}"

# Split the dataset into train and validation sets; the seeded generator makes the
# split identical on every Restart & Run All.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# On Windows, DataLoader workers are spawned processes that must re-import the Dataset
# class; classes defined in a notebook's __main__ typically cannot be found there, so
# load in-process on Windows and use worker processes elsewhere.
num_workers = 0 if os.name == "nt" else 4
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=num_workers)

In [None]:

# Fresh model on the GPU when one is visible (CPU otherwise), so re-running this
# cell always starts training from scratch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ImagePositionPredictor()
model = model.to(device)

# Hyper-parameters for this run
num_epochs, learning_rate = 50, 0.001
trained_model = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    device=device,
)

# Persist only the learned weights (state_dict), not the pickled module object
checkpoint_path = 'image_position_predictor.pth'
torch.save(trained_model.state_dict(), checkpoint_path)

In [None]:
# Sanity pass over the training loader: each batch must pair equally-shaped images
# with (batch, 3) targets. ImagePairDataset returns a zero placeholder with
# confidence 0 when an image pair fails to load, so count those too.
n_samples = 0
n_placeholders = 0
for img1, img2, labels in train_loader:
    assert img1.shape == img2.shape, f"image shape mismatch: {tuple(img1.shape)} vs {tuple(img2.shape)}"
    assert labels.shape == (img1.size(0), 3), f"unexpected label shape {tuple(labels.shape)}"
    n_samples += labels.size(0)
    n_placeholders += int((labels[:, 2] == 0).sum())
print(f"Checked {n_samples} training samples; {n_placeholders} placeholder(s) from failed loads")