Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Ignacio Rocco
committed
Dec 20, 2017
0 parents
commit e09d0eb
Showing
35 changed files
with
22,831 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
.ipynb_checkpoints | ||
__pycache__ | ||
training_data/seasons/ | ||
training_data/day-night/ | ||
training_data/theta.csv | ||
datasets/ | ||
results/ | ||
tests/ | ||
notebooks/ | ||
images/ | ||
baselines/ | ||
**/*.bak | ||
*.png |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# End-to-end weakly-supervised semantic alignment | ||
|
||
![](http://www.di.ens.fr/willow/research/weakalign/images/teaser.jpg) | ||
|
||
|
||
## About | ||
|
||
This is the implementation of the paper "End-to-end weakly-supervised semantic alignment" by I. Rocco, R. Arandjelović and J. Sivic. | ||
|
||
For more information check out the project [[website](http://www.di.ens.fr/willow/research/weakalign/)] and the paper on [[arXiv](https://arxiv.org/abs/1712.06861)]. | ||
|
||
|
||
## Getting started | ||
|
||
### Dependencies | ||
|
||
The code is implemented using Python 3 and PyTorch 0.2. All dependencies are included in the standard Anaconda distribution. | ||
|
||
### Training | ||
|
||
The code includes scripts for pre-training the models with strong supervision (`train_strong.py`) as proposed in [our previous work](http://www.di.ens.fr/willow/research/cnngeometric/), as well as to fine-tune the model using weak supervision (`train_weak.py`) as proposed in this work. | ||
|
||
Training scripts can be found in the `scripts/` folder. | ||
|
||
### Evaluation | ||
|
||
Evaluation is implemented in the `eval.py` file. It can evaluate a single affine or TPS model (with the `--model-aff` and `--model-tps` parameters respectively), or a combined affine+TPS model (with the `--model`) parameter. | ||
|
||
The evaluation dataset is passed with the `--eval-dataset` parameter. | ||
|
||
### Trained models | ||
|
||
Trained models for the baseline method using only strong supervision and the proposed method using additional weak supervision are provided below. You can store them in the `trained_models/` folder. | ||
|
||
With the provided code below you should obtain the results from Table 2 of the paper. | ||
|
||
|
||
**CNNGeometric with VGG-16 baseline:** [[affine model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_vgg16_affine.pth.tar)],[[TPS model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_vgg16_tps.pth.tar)] | ||
|
||
``` | ||
python eval.py --feature-extraction-cnn vgg --model-aff trained_models/cnngeo_vgg16_affine.pth.tar --model-tps trained_models/cnngeo_vgg16_tps.pth.tar --eval-dataset pf-pascal | ||
``` | ||
|
||
**CNNGeometric with ResNet-101 baseline:** [[affine model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_resnet101_affine.pth.tar)],[[TPS model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_resnet101_tps.pth.tar)] | ||
|
||
``` | ||
python eval.py --feature-extraction-cnn resnet101 --model-aff trained_models/cnngeo_resnet101_affine.pth.tar --model-tps trained_models/cnngeo_resnet101_tps.pth.tar --eval-dataset pf-pascal | ||
``` | ||
|
||
**Proposed method:** [[combined aff+TPS model](http://www.di.ens.fr/willow/research/weakalign/trained_models/weakalign_resnet101_affine_tps.pth.tar)] | ||
|
||
``` | ||
python eval.py --feature-extraction-cnn resnet101 --model trained_models/weakalign_resnet101_affine_tps.pth.tar --eval-dataset pf-pascal | ||
``` | ||
|
||
## BibTeX | ||
|
||
If you use this code in your project, please cite our paper: | ||
```` | ||
@article{Rocco18, | ||
author = "Rocco, I. and Arandjelovi\'c, R. and Sivic, J.", | ||
title = "End-to-end weakly-supervised semantic alignment", | ||
journal={arXiv preprint arXiv:1712.06861}, | ||
} | ||
```` | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import print_function, division | ||
import os | ||
import torch | ||
from torch.autograd import Variable | ||
from skimage import io | ||
import pandas as pd | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
from geotnf.transformation import GeometricTnf | ||
|
||
class CaltechDataset(Dataset): | ||
|
||
""" | ||
Caltech-101 image pair dataset | ||
Args: | ||
csv_file (string): Path to the csv file with image names and annotation files. | ||
dataset_path (string): Directory with the images. | ||
output_size (2-tuple): Desired output size | ||
transform (callable): Transformation for post-processing the training pair (eg. image normalization) | ||
""" | ||
|
||
def __init__(self, csv_file, dataset_path,output_size=(240,240),transform=None): | ||
|
||
self.category_names=['Faces','Faces_easy','Leopards','Motorbikes','accordion','airplanes','anchor','ant','barrel','bass','beaver','binocular','bonsai','brain','brontosaurus','buddha','butterfly','camera','cannon','car_side','ceiling_fan','cellphone','chair','chandelier','cougar_body','cougar_face','crab','crayfish','crocodile','crocodile_head','cup','dalmatian','dollar_bill','dolphin','dragonfly','electric_guitar','elephant','emu','euphonium','ewer','ferry','flamingo','flamingo_head','garfield','gerenuk','gramophone','grand_piano','hawksbill','headphone','hedgehog','helicopter','ibis','inline_skate','joshua_tree','kangaroo','ketch','lamp','laptop','llama','lobster','lotus','mandolin','mayfly','menorah','metronome','minaret','nautilus','octopus','okapi','pagoda','panda','pigeon','pizza','platypus','pyramid','revolver','rhino','rooster','saxophone','schooner','scissors','scorpion','sea_horse','snoopy','soccer_ball','stapler','starfish','stegosaurus','stop_sign','strawberry','sunflower','tick','trilobite','umbrella','watch','water_lilly','wheelchair','wild_cat','windsor_chair','wrench','yin_yang'] | ||
self.out_h, self.out_w = output_size | ||
self.pairs = pd.read_csv(csv_file) | ||
self.img_A_names = self.pairs.iloc[:,0] | ||
self.img_B_names = self.pairs.iloc[:,1] | ||
self.category = self.pairs.iloc[:,2].as_matrix().astype('float') | ||
self.annot_A_str = self.pairs.iloc[:, 3:5] | ||
self.annot_B_str = self.pairs.iloc[:, 5:] | ||
self.dataset_path = dataset_path | ||
self.transform = transform | ||
# no cuda as dataset is called from CPU threads in dataloader and produces confilct | ||
self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda = False) | ||
|
||
def __len__(self): | ||
return len(self.pairs) | ||
|
||
def __getitem__(self, idx): | ||
# get pre-processed images | ||
image_A,im_size_A = self.get_image(self.img_A_names,idx) | ||
image_B,im_size_B = self.get_image(self.img_B_names,idx) | ||
|
||
# get pre-processed point coords | ||
annot_A = self.get_points(self.annot_A_str, idx) | ||
annot_B = self.get_points(self.annot_B_str, idx) | ||
|
||
sample = {'source_image': image_A, 'target_image': image_B, 'source_im_size': im_size_A, 'target_im_size': im_size_B, 'source_polygon': annot_A, 'target_polygon': annot_B} | ||
|
||
if self.transform: | ||
sample = self.transform(sample) | ||
|
||
return sample | ||
|
||
def get_image(self,img_name_list,idx): | ||
img_name = os.path.join(self.dataset_path, img_name_list[idx]) | ||
image = io.imread(img_name) | ||
|
||
# if grayscale convert to 3-channel image | ||
if image.ndim==2: | ||
image=np.repeat(np.expand_dims(image,2),axis=2,repeats=3) | ||
|
||
# get image size | ||
im_size = np.asarray(image.shape) | ||
|
||
# convert to torch Variable | ||
image = np.expand_dims(image.transpose((2,0,1)),0) | ||
image = torch.Tensor(image.astype(np.float32)) | ||
image_var = Variable(image,requires_grad=False) | ||
|
||
# Resize image using bilinear sampling with identity affine tnf | ||
image = self.affineTnf(image_var).data.squeeze(0) | ||
|
||
im_size = torch.Tensor(im_size.astype(np.float32)) | ||
|
||
return (image, im_size) | ||
|
||
def get_points(self,point_coords_list,idx): | ||
point_coords_x = point_coords_list[point_coords_list.columns[0]][idx] | ||
point_coords_y = point_coords_list[point_coords_list.columns[1]][idx] | ||
|
||
return (point_coords_x,point_coords_y) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from os.path import exists, join, basename, dirname, splitext | ||
from os import makedirs, remove, rename | ||
from six.moves import urllib | ||
import tarfile | ||
import zipfile | ||
import requests | ||
import sys | ||
import click | ||
|
||
def download_and_uncompress(url, dest=None, chunk_size=1024, replace="ask", | ||
label="Downloading {dest_basename} ({size:.2f}MB)"): | ||
dest = dest or "./"+url.split("/")[-1] | ||
dest_dir = dirname(dest) | ||
if not exists(dest_dir): | ||
makedirs(dest_dir) | ||
if exists(dest): | ||
if (replace is False | ||
or replace == "ask" | ||
and not click.confirm("Replace {}?".format(dest))): | ||
return | ||
# download file | ||
with open(dest, "wb") as f: | ||
response = requests.get(url, stream=True) | ||
total_length = response.headers.get('content-length') | ||
|
||
if total_length is None: # no content length header | ||
f.write(response.content) | ||
else: | ||
dl = 0 | ||
total_length = int(total_length) | ||
for data in response.iter_content(chunk_size=4096): | ||
dl += len(data) | ||
f.write(data) | ||
done = int(50 * dl / total_length) | ||
sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50-done)) ) | ||
sys.stdout.write("{:.1%}".format(dl / total_length)) | ||
sys.stdout.flush() | ||
sys.stdout.write("\n") | ||
# uncompress | ||
if dest.endswith("zip"): | ||
file = zipfile.ZipFile(dest, 'r') | ||
elif dest.endswith("tar"): | ||
file = tarfile.open(dest, 'r') | ||
elif dest.endswith("tar.gz"): | ||
file = tarfile.open(dest, 'r:gz') | ||
else: | ||
return dest | ||
|
||
print("Extracting data...") | ||
file.extractall(dest_dir) | ||
file.close() | ||
|
||
return dest | ||
|
||
def download_PF_willow(dest="datasets/proposal-flow-willow"): | ||
print("Fetching PF Willow dataset ") | ||
url = "http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset.zip" | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
print('Downloading image pair list \n') ; | ||
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf.csv" | ||
file_path = join(dest,basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
def download_PF_pascal(dest="datasets/proposal-flow-pascal"): | ||
print("Fetching PF Pascal dataset ") | ||
url = "http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip" | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
print('Downloading image pair list \n') ; | ||
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf_pascal.csv" | ||
file_path = join(dest,basename(url)) | ||
download_and_uncompress(url,file_path) | ||
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/val_pairs_pf_pascal.csv" | ||
file_path = join(dest,basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
def download_pascal(dest="datasets/pascal-voc11"): | ||
print("Fetching Pascal VOC2011 dataset") | ||
url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar" | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url, file_path) | ||
|
||
def download_pascal_parts(dest="datasets/pascal-parts"): | ||
print("Fetching Pascal Parts dataset") | ||
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/pascal_data.tar" | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url, file_path) | ||
|
||
def download_caltech(dest="datasets/caltech-101"): | ||
print("Fetching Caltech-101 dataset") | ||
url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz" | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
print("Fetching Caltech-101 annotations") | ||
url="http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar" | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
print('Renaming some annotation directories\n') ; | ||
os.rename(join(dest,'Annotations','Airplanes_Side_2'),join(dest,'Annotations','airplanes')) | ||
os.rename(join(dest,'Annotations','Faces_2'),join(dest,'Annotations','Faces')) | ||
os.rename(join(dest,'Annotations','Faces_3'),join(dest,'Annotations','Faces_easy')) | ||
os.rename(join(dest,'Annotations','Motorbikes_16'),join(dest,'Annotations','Motorbikes')) | ||
print('Done renaming\n') ; | ||
|
||
print('Downloading image pair list \n') ; | ||
url='http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_caltech.csv' | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
def download_TSS(dest="datasets/tss"): | ||
print("Fetching TSS dataset ") | ||
url = "http://www.hci.iis.u-tokyo.ac.jp/datasets/data/JointCorrCoseg/TSS_CVPR2016.zip" | ||
file_path = join(dest, basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
||
print('Downloading image pair list \n') ; | ||
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_tss.csv" | ||
file_path = join(dest,basename(url)) | ||
download_and_uncompress(url,file_path) | ||
|
Oops, something went wrong.