In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
import torchvision.models as models
from tqdm import tqdm
import random
import numpy as np
import gc
import os

In [None]:
import gdown
import zipfile

# === Google Drive File IDs ===
EMBEDDINGS_FILE_ID = "1k_ua1tHAWKLJQt_89oFUy-yZVum1onPr"
KEYWORD_FILE_ID = "1VdVOWPoSk40Yucacg2WQNuuKNLqrrlD3"
TRIPLETS_FILE_ID = "1aZoyVQ0rsn8yQPM36tfTa_5sjfLKi8Zb"
PATCH_ZIP_ID = "1JA3aoGyhyKbCIn4uikyKOAoN0X_T_D5d"

# === Config ===
EMBED_DIM = 128
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 10
SEED = 42

# Seed every RNG the notebook touches (Python, NumPy, torch; torch.manual_seed also
# seeds all CUDA devices) so weight init and DataLoader shuffling are reproducible
# under Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_DIR = "data"
PATCH_DIR = os.path.join(DATA_DIR, "terrain_patches")


def _download_from_drive(file_id, dest):
    """Download Google Drive file `file_id` to `dest` unless `dest` already exists.

    Returns `dest` so calls can be used directly in assignments.
    """
    if not os.path.exists(dest):
        gdown.download(f"https://drive.google.com/uc?id={file_id}", dest, quiet=False)
    return dest


# === Download parquet embeddings, keyword vectors and contrastive triplets ===
os.makedirs(DATA_DIR, exist_ok=True)
parquet_path = _download_from_drive(EMBEDDINGS_FILE_ID, os.path.join(DATA_DIR, "terrain_embeddings.parquet"))
keyword_path = _download_from_drive(KEYWORD_FILE_ID, os.path.join(DATA_DIR, "keyword_vecs.pt"))
triplets_path = _download_from_drive(TRIPLETS_FILE_ID, os.path.join(DATA_DIR, "triplets.pt"))

# === Download and unzip terrain patches ===
patch_zip_path = os.path.join(DATA_DIR, "terrain_patches.zip")
if not os.path.exists(PATCH_DIR):
    gdown.download(f"https://drive.google.com/uc?id={PATCH_ZIP_ID}", patch_zip_path, quiet=False)
    with zipfile.ZipFile(patch_zip_path, "r") as zip_ref:
        zip_ref.extractall(DATA_DIR)
    # The existence check above (on re-runs) and the dataset both rely on the archive
    # unpacking to PATCH_DIR; fail here with a clear message rather than later per sample.
    if not os.path.isdir(PATCH_DIR):
        raise FileNotFoundError(
            f"{patch_zip_path} did not extract to {PATCH_DIR}; check the archive's top-level folder name"
        )


df = pd.read_parquet(parquet_path)
triplets = torch.load(triplets_path)
keyword_vecs = torch.load(keyword_path)

Downloaded file size: 11635930
First 4 bytes (should be 'PAR1'): b'PAR1'


In [5]:
# === Filter None values ===
def safe_collate(batch):
    """Collate `batch` after discarding any samples that are None.

    Returns None when no sample survives, letting the training loop skip the batch.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if not kept:
        return None
    return default_collate(kept)

# === Text Encoder ===
class TextEncoder(nn.Module):
    """Project rows of a precomputed keyword-vector table into the shared EMBED_DIM space.

    `keyword_vecs` is a 2-D lookup table indexed by the integer `indices` passed to
    `forward`. The training loop passes dataframe row positions, so this presumably
    holds one keyword vector per `df` row — TODO confirm alignment.
    """

    def __init__(self, keyword_vecs):
        super().__init__()
        # Non-persistent buffer: it follows `model.to(...)` / `.cuda()` like the module's
        # parameters, but stays out of state_dict so existing checkpoints load unchanged.
        # Cast to float32 so the lookup result matches nn.Linear's default dtype.
        self.register_buffer(
            "keyword_vecs", keyword_vecs.to(DEVICE, dtype=torch.float32), persistent=False
        )
        self.proj = nn.Linear(self.keyword_vecs.shape[1], EMBED_DIM)

    def forward(self, indices):
        """Gather `keyword_vecs[indices]` and project the result to EMBED_DIM."""
        return self.proj(self.keyword_vecs[indices])

# === Geo Encoder ===
class GeoEncoder(nn.Module):
    """Embed a 2-value coordinate (lat, lon) into EMBED_DIM with a 3-layer ReLU MLP."""

    def __init__(self):
        super().__init__()
        hidden = 64
        # Same layer order as before, so state_dict keys (mlp.0 / mlp.2 / mlp.4) and the
        # sequence of random initialisations are unchanged.
        layers = [
            nn.Linear(2, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, EMBED_DIM),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, coords):
        """Move `coords` (last dim must be 2) to DEVICE and run them through the MLP.

        NOTE(review): coordinates enter unscaled (raw degrees from the dataset);
        consider normalising them to roughly [-1, 1].
        """
        moved = coords.to(DEVICE)
        return self.mlp(moved)

# === Terrain Encoder ===
class TerrainEncoder(nn.Module):
    """ImageNet-pretrained ResNet-18 adapted to single-channel terrain patches.

    Returns L2-normalised embeddings of size `dim`.
    """

    def __init__(self, dim=128):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13);
        # IMAGENET1K_V1 is the same checkpoint that `pretrained=True` downloads.
        base = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Collapse the RGB stem to one input channel by averaging the pretrained
        # filters over the colour axis, then copy them in without autograd tracking.
        rgb_weight = base.conv1.weight
        base.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            base.conv1.weight.copy_(rgb_weight.mean(dim=1, keepdim=True))

        # Keep everything up to and including global average pooling; drop the fc head.
        self.encoder = nn.Sequential(*list(base.children())[:-1])
        self.proj = nn.Linear(512, dim)

    def forward(self, x):
        """Encode `x` (assumed (batch, 1, H, W) — TODO confirm) into unit-norm `dim`-d vectors."""
        x = torch.flatten(self.encoder(x), 1)  # (batch, 512, 1, 1) -> (batch, 512)
        return F.normalize(self.proj(x), dim=-1)

# === Fused Encoder ===
class FusedEncoder(nn.Module):
    """Concatenate text, geo and terrain embeddings and fuse them into one EMBED_DIM vector."""

    def __init__(self, keyword_vecs):
        super().__init__()
        self.text_encoder = TextEncoder(keyword_vecs)
        self.geo_encoder = GeoEncoder()
        # Pass EMBED_DIM explicitly: `fuse` expects all three parts to be EMBED_DIM wide,
        # and TerrainEncoder's own default (128) only matches EMBED_DIM by coincidence.
        self.terrain_encoder = TerrainEncoder(dim=EMBED_DIM)
        self.fuse = nn.Sequential(
            nn.Linear(3 * EMBED_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, EMBED_DIM)
        )

    def forward(self, texts, coords, terrain):
        """Embed each modality and fuse the concatenation.

        Args:
            texts: integer indices into the text encoder's keyword-vector table.
            coords: (lat, lon) pairs, last dim 2.
            terrain: terrain patch batch for the ResNet encoder.
        """
        text_emb = self.text_encoder(texts)
        geo_emb = self.geo_encoder(coords)
        # NOTE(review): terrain_emb is L2-normalised inside TerrainEncoder while the
        # other two are not, so their scales can differ considerably.
        terrain_emb = self.terrain_encoder(terrain)
        return self.fuse(torch.cat([text_emb, geo_emb, terrain_emb], dim=-1))

### Triplet dataset for contrastive learning (terrain patches are loaded lazily from disk per sample)

In [6]:
class TripletDataset(Dataset):
    """Serve (anchor, positive, negative) samples for contrastive training.

    Each item is a flat 9-tuple ``(patch, row_idx, coord) * 3`` in anchor, positive,
    negative order:
      - ``patch``: float32 tensor from ``load_patch`` (a leading channel dim is added),
      - ``row_idx``: positional row index into ``df`` (the training loop uses it as the
        keyword-vector index),
      - ``coord``: float32 tensor ``[lat, lon]``.

    Args:
        df: frame with ``terrain_patch_id``, ``lat`` and ``lon`` columns.
        triplets: indexable of ``(anchor, positive, negative)`` positional row indices
            (ints, NumPy ints or 0-d tensors).
        patch_dir: directory holding ``<terrain_patch_id>.npy`` files.
        elevation_scale: divisor applied to raw patch values (default 1000.0 —
            presumably metres to kilometres; TODO confirm units).
    """

    def __init__(self, df, triplets, patch_dir="data/terrain_patches", elevation_scale=1000.0):
        self.df = df.reset_index(drop=True)
        self.triplets = triplets
        self.patch_dir = patch_dir
        self.elevation_scale = elevation_scale
        # Extract the needed columns once, column-wise. A per-item `df.iloc[i]` builds a
        # row Series every call and, on an all-numeric frame, upcasts integer patch ids
        # to float (id 7 -> "7.0.npy").
        self._patch_ids = self.df["terrain_patch_id"].to_numpy()
        self._coords = torch.tensor(
            self.df[["lat", "lon"]].to_numpy(dtype=np.float64), dtype=torch.float32
        )

    def load_patch(self, patch_id):
        """Load ``<patch_dir>/<patch_id>.npy`` as float32 with a leading channel dim, divided by ``elevation_scale``.

        Assumes each file holds a 2-D (H, W) array, giving a (1, H, W) tensor — TODO confirm.
        """
        path = os.path.join(self.patch_dir, f"{patch_id}.npy")
        patch = np.load(path, mmap_mode='r')
        # torch.tensor copies, so the read-only memory map is never shared with torch.
        return torch.tensor(patch).unsqueeze(0).float() / self.elevation_scale

    def _get_item(self, i):
        """Return ``(patch, row_idx, coord)`` for positional row ``i``."""
        i = int(i)  # normalise 0-d tensors / NumPy ints for indexing and collation
        return self.load_patch(self._patch_ids[i]), i, self._coords[i]

    def __getitem__(self, idx):
        a_idx, p_idx, n_idx = self.triplets[idx]
        return self._get_item(a_idx) + self._get_item(p_idx) + self._get_item(n_idx)

    def __len__(self):
        return len(self.triplets)

In [7]:
# === InfoNCE Loss ===
def info_nce(anchor, positive, negative, temperature=0.07):
    """Two-way InfoNCE loss with one positive and one negative per anchor.

    Computes ``mean(-log(exp(s_p/t) / (exp(s_p/t) + exp(s_n/t))))``, where ``s_p`` / ``s_n``
    are cosine similarities of the anchor with the positive / negative (inputs are
    L2-normalised along the last dim). It is evaluated as ``softplus((s_n - s_p) / t)``,
    which is algebraically identical but never exponentiates a large logit. It therefore
    stays finite for small temperatures and in half precision, where the naive
    ``exp(1 / 0.07) ~= 1.6e6`` already exceeds float16's maximum.

    Args:
        anchor, positive, negative: embedding tensors of matching shape (..., dim).
        temperature: positive softmax temperature.

    Returns:
        Scalar tensor: the loss averaged over all leading dimensions.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    anchor = F.normalize(anchor, dim=-1)
    positive = F.normalize(positive, dim=-1)
    negative = F.normalize(negative, dim=-1)
    pos_logit = torch.sum(anchor * positive, dim=-1) / temperature
    neg_logit = torch.sum(anchor * negative, dim=-1) / temperature
    # -log(e^p / (e^p + e^n)) == log(1 + e^(n - p)) == softplus(n - p)
    return F.softplus(neg_logit - pos_logit).mean()

In [None]:
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def _embed_triplet_member(model, patch, text_idx, coord):
    """Move one third of a triplet batch (anchor, positive or negative) to DEVICE and embed it."""
    return model(text_idx.long().to(DEVICE), coord.to(DEVICE), patch.to(DEVICE))


def _train_one_epoch(model, loader, optimizer, epoch):
    """Run one optimisation pass over `loader`.

    Returns the mean loss over the batches actually processed, or nan if every batch
    was skipped.
    """
    model.train()
    total_loss, n_batches = 0.0, 0
    for batch in tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        if batch is None:  # safe_collate returns None when every sample was dropped
            continue
        (anchor_patch, anchor_text_idx, anchor_coord,
         pos_patch, pos_text_idx, pos_coord,
         neg_patch, neg_text_idx, neg_coord) = batch

        a = _embed_triplet_member(model, anchor_patch, anchor_text_idx, anchor_coord)
        p = _embed_triplet_member(model, pos_patch, pos_text_idx, pos_coord)
        n = _embed_triplet_member(model, neg_patch, neg_text_idx, neg_coord)

        loss = info_nce(a, p, n)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Divide by batches processed, not len(loader), so skipped batches don't bias the mean.
    return total_loss / n_batches if n_batches else float("nan")


model = FusedEncoder(keyword_vecs).to(DEVICE)
dataset = TripletDataset(df, triplets)
# collate_fn=safe_collate makes the `batch is None` skip in the epoch loop reachable.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=safe_collate)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

for epoch in range(EPOCHS):
    epoch_loss = _train_one_epoch(model, loader, optimizer, epoch)
    print(f"Epoch {epoch+1}: Loss = {epoch_loss:.4f}")

    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"fused_encoder_epoch_{epoch+1}.pt")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Saved checkpoint to {checkpoint_path}")

torch.save(model.state_dict(), "trained_fused_encoder.pt")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]
Epoch 1/10: 100%|██████████| 524/524 [27:24<00:00,  3.14s/it]


Epoch 1: Loss = 0.1442
Saved checkpoint to checkpoints/fused_encoder_epoch_1.pt


Epoch 2/10: 100%|██████████| 524/524 [27:38<00:00,  3.16s/it]


Epoch 2: Loss = 0.0493
Saved checkpoint to checkpoints/fused_encoder_epoch_2.pt


Epoch 3/10: 100%|██████████| 524/524 [27:31<00:00,  3.15s/it]


Epoch 3: Loss = 0.0439
Saved checkpoint to checkpoints/fused_encoder_epoch_3.pt


Epoch 4/10: 100%|██████████| 524/524 [27:32<00:00,  3.15s/it]


Epoch 4: Loss = 0.0448
Saved checkpoint to checkpoints/fused_encoder_epoch_4.pt


Epoch 5/10: 100%|██████████| 524/524 [27:29<00:00,  3.15s/it]


Epoch 5: Loss = 0.0413
Saved checkpoint to checkpoints/fused_encoder_epoch_5.pt


Epoch 6/10:  58%|█████▊    | 305/524 [16:01<11:59,  3.28s/it]