Implementation of the paper "Topology-Aware Representation Alignment for Semi-Supervised Vision-Language Learning".
Vision-language models have shown strong performance, but they often generalize poorly to specialized domains.
While semi-supervised vision-language learning mitigates this limitation by leveraging a small set of labeled image-text pairs together with abundant unlabeled images, existing methods remain fundamentally pairwise and fail to model the global structure of multimodal representation manifolds.
Existing topology-based alignment methods rely on persistence diagram matching, which neither guarantees geometric alignment nor utilizes the image-text pairing information central to vision-language learning.
We propose Topology-Aware Multimodal Representation Alignment (ToMA), a framework that uses persistent homology to identify topologically salient edges and aligns them across modalities through available cross-modal correspondences.
ToMA leverages both the global topological structure of the representation manifolds and the image-text pairing information available in the labeled data.
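As a rough illustration of the idea (a minimal sketch, not the repository's actual implementation; `topology_alignment_loss` and its `k` parameter are invented for this example): the 0-dimensional persistent homology of a Vietoris-Rips filtration pairs each merge event with a minimum-spanning-tree edge, so the most persistent edges of the image-embedding graph can be read off an MST, and the corresponding text-text distances can be aligned to them through the known image-text pairing.

```python
import torch
import torch.nn.functional as F
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import pdist, squareform

def topology_alignment_loss(img_emb, txt_emb, k=32):
    """Illustrative topology-aware alignment loss (not the paper's exact formulation).

    img_emb, txt_emb: (N, D) embeddings of paired images and texts, so index i
    in one modality corresponds to index i in the other. The 0-dimensional
    persistence pairs of the Vietoris-Rips filtration on the image embeddings
    coincide with minimum-spanning-tree edges; the k longest MST edges are
    taken as the topologically salient ones, and the matching text-text edge
    lengths are pulled toward the image-side lengths via the pairing.
    """
    with torch.no_grad():
        # Dense pairwise distances on the image side (fine for batch-sized N).
        d_img = squareform(pdist(img_emb.detach().cpu().numpy()))
        mst = minimum_spanning_tree(d_img).tocoo()
        # Keep the k most persistent (longest) MST edges.
        order = mst.data.argsort()[::-1][:k]
        rows = torch.as_tensor(mst.row[order].copy(), dtype=torch.long, device=img_emb.device)
        cols = torch.as_tensor(mst.col[order].copy(), dtype=torch.long, device=img_emb.device)

    # Lengths of the salient edges in each modality.
    d_img_sel = (img_emb[rows] - img_emb[cols]).norm(dim=-1)
    d_txt_sel = (txt_emb[rows] - txt_emb[cols]).norm(dim=-1)
    # Encourage the two modalities to agree on the salient edge geometry.
    return F.mse_loss(d_txt_sel, d_img_sel)
```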
This repository contains the code used to train and evaluate:
- SemiCLIP baselines
- ToMA (topology-aware alignment)
- ToMA-domain (domain-wise topology-aware alignment)
The experiments cover:
- Remote sensing
  - in-distribution semi-supervised setting
  - distribution-shift semi-supervised setting
- Fashion
  - semi-supervised setting on Fashion200k, FashionGen, and Polyvore
.
├── main.py
├── environment.yml
├── README.md
├── data/
├── custom/
├── training/
├── keywords/
│ ├── RS/
│ └── fashion/
├── scripts/
│ ├── train_rs_stage1_2.sh
│ ├── train_rs_shift_stage1_2.sh
│ ├── train_fashion_stage1_2.sh
│ ├── eval_RS_stage2.sh
│ └── eval_fashion_stage2.sh
└── scripts_semiclip/
├── train_rs_stage1.sh
├── train_rs_stage2.sh
├── train_fashion_stage1.sh
└── train_fashion_stage2.sh
The code was prepared and tested in the following environment:
- Ubuntu Linux
- Python 3.9
- PyTorch with CUDA support
- NVIDIA A100 GPU
Create the conda environment with:
conda env create --file environment.yml
conda activate toma
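As an optional sanity check after installation, you can verify that the environment provides a CUDA-enabled PyTorch build:

```python
# Optional sanity check: confirm PyTorch is installed and can see the GPU.
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```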
All scripts assume the dataset root is ./data/. Please place the datasets under ./data/ so that the training and evaluation scripts can find them through:

--data-dir "./data/"

A recommended structure is:
data/
├── aerial/
│ ├── RSICD/
│ ├── UCM_captions/
│ ├── Sydney_captions/
│ ├── RESISC45/
│ ├── WHU-RS19/
│ ├── RSSCN7/
│ └── AID/
├── fashion/
│ ├── fashion200k/
│ ├── FashionGen/
│ └── PolyvoreOutfits/
└── ...
Please make sure that the downloaded files, captions, and split files are arranged so that they match the expected names used by the codebase and scripts.
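A small standalone check like the following (not part of the repository; the folder names simply mirror the layout sketched above) can help confirm the expected dataset directories exist before launching training:

```python
# Hypothetical helper: verify the dataset layout shown above exists under ./data/.
from pathlib import Path

EXPECTED = [
    "aerial/RSICD", "aerial/UCM_captions", "aerial/Sydney_captions",
    "aerial/RESISC45", "aerial/WHU-RS19", "aerial/RSSCN7", "aerial/AID",
    "fashion/fashion200k", "fashion/FashionGen", "fashion/PolyvoreOutfits",
]

def check_data_root(root="./data"):
    root = Path(root)
    missing = [p for p in EXPECTED if not (root / p).is_dir()]
    for p in missing:
        print(f"missing: {root / p}")
    return not missing

if __name__ == "__main__":
    print("ok" if check_data_root() else "some datasets are missing")
```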
Train the SemiCLIP baselines:

bash scripts_semiclip/train_rs_stage1.sh
bash scripts_semiclip/train_rs_stage2.sh
bash scripts_semiclip/train_fashion_stage1.sh
bash scripts_semiclip/train_fashion_stage2.sh

Train ToMA:

bash scripts/train_rs_stage1_2.sh
bash scripts/train_rs_shift_stage1_2.sh
bash scripts/train_fashion_stage1_2.sh

Evaluate on the remote sensing benchmarks:

bash scripts/eval_RS_stage2.sh

This script evaluates:
- zero-shot classification on: RSICD-CLS, UCM-CLS, WHU-RS19, RSSCN7, AID
- image-text retrieval on: RSICD, UCM, Sydney
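Zero-shot classification follows the standard CLIP-style recipe: class names are embedded as text prompts and each image is assigned to the class with the highest cosine similarity. A minimal, repository-independent sketch over precomputed embeddings (the function name and signature are illustrative, not the evaluation code used by the scripts):

```python
import torch

@torch.no_grad()
def zero_shot_accuracy(image_feats, class_text_feats, labels):
    """image_feats: (N, D) image embeddings; class_text_feats: (C, D), one text
    embedding per class name; labels: (N,) ground-truth class indices."""
    image_feats = torch.nn.functional.normalize(image_feats, dim=-1)
    class_text_feats = torch.nn.functional.normalize(class_text_feats, dim=-1)
    logits = image_feats @ class_text_feats.t()   # (N, C) cosine similarities
    preds = logits.argmax(dim=-1)                 # predicted class per image
    return (preds == labels).float().mean().item()
```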
Evaluate on the fashion benchmarks:

bash scripts/eval_fashion_stage2.sh

This script evaluates:
- zero-shot classification on: Fashion200k-CLS, Fashion200k-SUBCLS, FashionGen-CLS, FashionGen-SUBCLS, Polyvore-CLS
- image-text retrieval on: Fashion200k, FashionGen, Polyvore
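Image-text retrieval is typically scored with recall@K: rank the gallery of one modality by cosine similarity to each query from the other, and check whether the paired item appears in the top K. A minimal sketch over precomputed embeddings (again illustrative, assuming one caption per image for simplicity):

```python
import torch

@torch.no_grad()
def recall_at_k(query_feats, gallery_feats, k=5):
    """query_feats, gallery_feats: (N, D) embeddings where row i of the query
    is paired with row i of the gallery (e.g. images vs. their captions)."""
    q = torch.nn.functional.normalize(query_feats, dim=-1)
    g = torch.nn.functional.normalize(gallery_feats, dim=-1)
    sims = q @ g.t()                              # (N, N) cosine similarities
    topk = sims.topk(k, dim=-1).indices           # top-k gallery indices per query
    targets = torch.arange(q.size(0), device=q.device).unsqueeze(-1)
    return (topk == targets).any(dim=-1).float().mean().item()
```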
Important: before running scripts/eval_fashion_stage2.sh, please fill the ckpts array in that script with the checkpoint names produced by training.
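For example, the filled array might look like the following (the entries are placeholders; substitute the checkpoint names produced by your own training runs):

```bash
# In scripts/eval_fashion_stage2.sh -- placeholder entries, replace with your own checkpoints.
ckpts=(
  "your_fashion_stage2_checkpoint_1"
  "your_fashion_stage2_checkpoint_2"
)
```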
Download links for the remote sensing datasets:
- BaiduPan
- GoogleDrive
- UCM_captions-BaiduPan
- Sydney_captions-BaiduPan
- UCM_captions-MEGA
- RSICD-MEGA
- Sydney_captions-MEGA
- RESISC45
- WHU-RS19
- RSSCN7
- AID
This implementation builds on publicly available codebases from prior semi-supervised vision-language learning projects. We thank the authors for making their code publicly available.
