This script fine-tunes a pretrained ConvNeXt-Large image encoder on aerial images of H3 regions using Circle Loss. Training samples are triplets of an anchor region, a positive region (a neighbour 1–8 H3 rings away from the anchor) and a negative region (a neighbour 9–16 rings away). The model learns to increase the similarity between the anchor and positive embeddings while decreasing the similarity between the anchor and negative embeddings. Like triplet loss, Circle Loss contrasts positive and negative pairs, but it re-weights each similarity score by how far it is from its optimum, so pairs that are far from their target receive larger gradients. After training for a fixed number of epochs, embeddings are generated for the central (unbuffered) regions and saved to a CSV file for further analysis.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import convnext_large, ConvNeXt_Large_Weights
from torchvision import transforms
from PIL import Image
import pandas as pd
import geopandas as gpd
import numpy as np
from srai.neighbourhoods import H3Neighbourhood
import random
import wandb
from tqdm import tqdm
import os


In [None]:
class BufferedH3TripletDataset(Dataset):
    """Triplet dataset of region images sampled by H3 ring distance.

    Every row of ``regions_buffered_gdf`` serves as an anchor. For each anchor, a
    positive is drawn from its neighbours ``positive_rings`` rings away and a
    negative from its neighbours ``negative_rings`` rings away. Neighbours come
    from an ``H3Neighbourhood`` built from ``regions_buffered_gdf`` (srai limits
    the neighbours to regions present in that frame).

    Args:
        regions_buffered_gdf: GeoDataFrame indexed by region id (H3 cell ids).
        image_dir: Directory containing one ``<region_id>.jpg`` per region.
        positive_rings: Inclusive ``(min, max)`` ring distances for positives.
        negative_rings: Inclusive ``(min, max)`` ring distances for negatives.

    Each item is ``((anchor, positive, negative), (anchor_id, positive_id, negative_id))``,
    where every image is a normalised ``(3, 224, 224)`` tensor.
    """

    def __init__(self, regions_buffered_gdf, image_dir, positive_rings=(1, 8), negative_rings=(9, 16)):
        self.regions_buffered_gdf = regions_buffered_gdf
        self.image_dir = image_dir
        self.positive_rings = positive_rings
        self.negative_rings = negative_rings
        # ImageNet mean/std, matching the pretrained encoder's preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.neighborhood = H3Neighbourhood(regions_buffered_gdf)

    def __len__(self):
        return len(self.regions_buffered_gdf)

    def __getitem__(self, idx):
        anchor_id = self.regions_buffered_gdf.index[idx]
        positive_id = self._sample_neighbour(anchor_id, *self.positive_rings)
        negative_id = self._sample_neighbour(anchor_id, *self.negative_rings)

        anchor_image = self.load_image(anchor_id)
        positive_image = self.load_image(positive_id)
        negative_image = self.load_image(negative_id)

        return (anchor_image, positive_image, negative_image), (anchor_id, positive_id, negative_id)

    def _sample_neighbour(self, region_id, min_ring, max_ring):
        """Return a random neighbour of ``region_id`` between ``min_ring`` and ``max_ring`` rings away.

        Rings are tried in random order, so the first ring tried is uniform over
        the range. An empty ring (e.g. for cells near the edge of the buffered
        area, which are anchors too) falls through to another ring instead of
        giving up. ``region_id`` itself is returned only when every ring in the
        range is empty.
        """
        rings = list(range(min_ring, max_ring + 1))
        random.shuffle(rings)
        for ring in rings:
            neighbours = self.neighborhood.get_neighbours_at_distance(region_id, ring)
            if neighbours:
                # Sort so the draw depends only on the RNG state, not on set iteration order.
                return random.choice(sorted(neighbours))
        # NOTE(review): for negatives this still yields a degenerate triplet
        # (negative == anchor); consider skipping such anchors instead.
        return region_id

    def load_image(self, region_id):
        """Load ``<image_dir>/<region_id>.jpg`` as a normalised tensor, or zeros if the file is missing."""
        image_path = os.path.join(self.image_dir, f"{region_id}.jpg")
        if not os.path.exists(image_path):
            # A missing image becomes an all-zero tensor rather than an error.
            return torch.zeros(3, 224, 224)
        with Image.open(image_path) as image:
            return self.transform(image.convert('RGB'))

class FineTunedConvNeXt(nn.Module):
    """ConvNeXt-Large encoder initialised from torchvision's default ImageNet weights.

    Args:
        keep_classifier: When True (the default, compatible with existing
            checkpoints), the final ``Linear`` layer of the ImageNet
            classification head is kept, so ``forward`` returns 1000 class
            logits per image. When False, that layer is replaced by
            ``nn.Identity`` and ``forward`` returns the 1536-dimensional pooled,
            layer-normalised features that would feed the classifier. Checkpoints
            saved with one setting cannot be loaded into the other, because the
            state-dict keys differ.

    ``forward`` returns a ``(batch, dim)`` tensor.
    """

    def __init__(self, keep_classifier=True):
        super().__init__()
        self.convnext = convnext_large(weights=ConvNeXt_Large_Weights.DEFAULT)
        if not keep_classifier:
            # torchvision's ConvNeXt head is Sequential(LayerNorm2d, Flatten, Linear);
            # dropping the Linear exposes the penultimate features.
            self.convnext.classifier[-1] = nn.Identity()

    def forward(self, x):
        features = self.convnext(x)
        # Flatten everything after the batch dimension (a no-op for 2-D output).
        return features.flatten(start_dim=1)

class CircleLoss(nn.Module):
    """Circle Loss over a batch of anchor-positive and anchor-negative similarities.

    ``forward(sp, sn)`` reduces each input with ``logsumexp`` over dim 0 and
    returns a scalar loss. Each similarity is weighted by how far it is from its
    optimum: ``1 + m`` for positives and ``-m`` for negatives. The decision
    margins are ``1 - m`` and ``m``. These constants assume similarities in
    ``[-1, 1]`` (e.g. cosine similarity) — callers should normalise accordingly.

    Args:
        m: Relaxation margin.
        gamma: Scale factor applied to every logit.
    """

    def __init__(self, m=0.25, gamma=256):
        super().__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, sp, sn):
        # Adaptive weights are treated as constants: no gradient flows through them.
        weight_pos = (1 - sp.detach() + self.m).clamp_min(0.)
        weight_neg = (sn.detach() + self.m).clamp_min(0.)

        margin_pos, margin_neg = 1 - self.m, self.m

        pos_logits = -weight_pos * (sp - margin_pos) * self.gamma
        neg_logits = weight_neg * (sn - margin_neg) * self.gamma

        combined = torch.logsumexp(neg_logits, dim=0) + torch.logsumexp(pos_logits, dim=0)
        return self.soft_plus(combined)

In [None]:
def train_model(model, dataloader, optimizer, criterion, device, epochs, checkpoint_dir, resume_epoch=0):
    """Fine-tune ``model`` on anchor/positive/negative batches, checkpointing every epoch.

    Args:
        model: Encoder mapping an image batch to a ``(batch, dim)`` feature tensor.
        dataloader: Yields ``((anchor, positive, negative), ids)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(sp, sn)`` with the per-sample
            anchor-positive and anchor-negative cosine similarities.
        device: Device the image batches are moved to.
        epochs: Total number of epochs; epochs ``resume_epoch + 1 .. epochs`` are run.
        checkpoint_dir: Directory for checkpoints (created if missing).
        resume_epoch: Number of epochs already completed.

    Side effects:
        Writes ``checkpoint_epoch_<N>.pth`` after every epoch and
        ``best_model_10.pth`` whenever the epoch's average loss beats the best
        so far. At the end it writes ``final_model_10.pth`` (model ``state_dict``
        only). Logs batch and epoch metrics to wandb.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, 'best_model_10.pth')
    best_loss = float('inf')
    if resume_epoch > 0 and os.path.exists(best_model_path):
        # Carry over the pre-restart best loss, so a worse resumed epoch does not
        # overwrite the existing best checkpoint (loads that checkpoint once, to CPU).
        best_loss = torch.load(best_model_path, map_location='cpu').get('loss', best_loss)

    for epoch in range(resume_epoch, epochs):
        model.train()
        total_loss = 0
        for (anchor_imgs, positive_imgs, negative_imgs), _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            anchor_imgs = anchor_imgs.to(device)
            positive_imgs = positive_imgs.to(device)
            negative_imgs = negative_imgs.to(device)

            optimizer.zero_grad()

            anchor_features = model(anchor_imgs)
            positive_features = model(positive_imgs)
            negative_features = model(negative_imgs)

            # Circle Loss's optima and margins (1 + m, -m, 1 - m, m) assume
            # similarities in [-1, 1]. Raw dot products are unbounded, and any
            # sp > 1 + m would zero out the positive weight, so use cosine similarity.
            sp = torch.nn.functional.cosine_similarity(anchor_features, positive_features, dim=1)
            sn = torch.nn.functional.cosine_similarity(anchor_features, negative_features, dim=1)

            loss = criterion(sp, sn)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            current_lr = optimizer.param_groups[0]['lr']
            wandb.log({"batch_loss": loss.item(), "learning_rate": current_lr})

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")
        wandb.log({"epoch": epoch+1, "average_loss": avg_loss})

        # Full checkpoint, so training can be resumed from this epoch.
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(checkpoint, best_model_path)

    # Final weights only (no optimizer state).
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'final_model_10.pth'))

In [None]:
def generate_embeddings(model, regions_gdf, image_dir, device, batch_size=32, num_workers=4):
    """Embed the image of every region in ``regions_gdf`` with ``model``.

    Args:
        model: Encoder mapping an image batch to a ``(batch, dim)`` feature tensor.
        regions_gdf: (Geo)DataFrame whose index holds the region ids; images are
            read from ``<image_dir>/<region_id>.jpg``.
        image_dir: Directory with the region images.
        device: Device the image batches are moved to.
        batch_size: Images per forward pass.
        num_workers: DataLoader worker processes. Pass 0 if worker start-up fails
            (e.g. on Windows/Jupyter, where workers must be able to import the
            dataset class).

    Returns:
        DataFrame with one row per region id (index named like ``regions_gdf``'s
        index) and one column per feature dimension.
    """
    model.eval()
    # Same preprocessing as training (ImageNet mean/std).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataloader = DataLoader(
        RegionDataset(regions_gdf, image_dir, transform),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    embeddings = {}
    with torch.no_grad():
        for images, region_ids in tqdm(dataloader, desc="Generating embeddings"):
            # One device-to-host transfer per batch rather than per row.
            features = model(images.to(device)).cpu().numpy()
            embeddings.update(zip(region_ids, features))

    embeddings_df = pd.DataFrame.from_dict(embeddings, orient='index')
    # Keep the id column labelled (e.g. "region_id") when written to CSV.
    embeddings_df.index.name = regions_gdf.index.name
    return embeddings_df

class RegionDataset(Dataset):
    """Map-style dataset yielding ``(transformed_image, region_id)`` per row of a (Geo)DataFrame.

    The image for a region is read from ``<image_dir>/<region_id>.jpg`` and
    converted to RGB before ``transform`` is applied. A missing file is not
    replaced by a placeholder; opening it raises.
    """

    def __init__(self, regions_gdf, image_dir, transform):
        self.regions_gdf = regions_gdf
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return self.regions_gdf.shape[0]

    def __getitem__(self, idx):
        region_id = self.regions_gdf.index[idx]
        path = os.path.join(self.image_dir, f"{region_id}.jpg")
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb), region_id

In [None]:
# Use this function in your main script
if __name__ == "__main__":
    wandb.init(project="Urban_Representation_Learning", config={
        "learning_rate": 1e-5,
        "epochs": 1,
        "batch_size": 16,
        "resolution": 10,
        "weight_decay": 1e-4,
    })

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    regions_gdf = gpd.read_file("selected_regions_10.geojson").set_index("region_id")
    regions_buffered_gdf = gpd.read_file("selected_regions_buffered_10.geojson").set_index("region_id")
    image_dir = r"D:\tu delft\Afstuderen\aerial_images_10"

    dataset = BufferedH3TripletDataset(regions_buffered_gdf, image_dir)
    dataloader = DataLoader(dataset, batch_size=wandb.config.batch_size, shuffle=True)

    model = FineTunedConvNeXt().to(device)
    optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate, weight_decay=wandb.config.weight_decay)
    criterion = CircleLoss()

    checkpoint_dir = r"/Phase 6 Experiments/checkpoints_res10"
    resume_epoch = 0

    # Check if there's a checkpoint to resume from
    checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_epoch_')])
    if checkpoints:
        latest_checkpoint = checkpoints[-1]
        checkpoint = torch.load(os.path.join(checkpoint_dir, latest_checkpoint))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        resume_epoch = checkpoint['epoch']
        print(f"Resuming training from epoch {resume_epoch}")

    print("Starting training...")
    train_model(model, dataloader, optimizer, criterion, device, wandb.config.epochs, checkpoint_dir, resume_epoch)

    print("Generating embeddings for central regions...")
    embeddings_df = generate_embeddings(model, regions_gdf, image_dir, device)

    output_dir = r"D:\tu delft\Afstuderen\Phase 6 Experiments\embeddings"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"learned_finetune_circle_h3_res_{wandb.config.resolution}.csv")
    embeddings_df.to_csv(output_file)
    print(f"Embeddings saved to {output_file}")

    wandb.finish()
    print("All done!")