
Commit

initial commit
Ignacio Rocco committed Dec 20, 2017
0 parents commit e09d0eb
Showing 35 changed files with 22,831 additions and 0 deletions.
13 changes: 13 additions & 0 deletions .gitignore
@@ -0,0 +1,13 @@
.ipynb_checkpoints
__pycache__
training_data/seasons/
training_data/day-night/
training_data/theta.csv
datasets/
results/
tests/
notebooks/
images/
baselines/
**/*.bak
*.png
67 changes: 67 additions & 0 deletions README.md
@@ -0,0 +1,67 @@
# End-to-end weakly-supervised semantic alignment

![](http://www.di.ens.fr/willow/research/weakalign/images/teaser.jpg)


## About

This is the implementation of the paper "End-to-end weakly-supervised semantic alignment" by I. Rocco, R. Arandjelović and J. Sivic.

For more information check out the project [[website](http://www.di.ens.fr/willow/research/weakalign/)] and the paper on [[arXiv](https://arxiv.org/abs/1712.06861)].


## Getting started

### Dependencies

The code is implemented using Python 3 and PyTorch 0.2. All dependencies are included in the standard Anaconda distribution.

### Training

The code includes scripts for pre-training the models with strong supervision (`train_strong.py`), as proposed in [our previous work](http://www.di.ens.fr/willow/research/cnngeometric/), as well as for fine-tuning the model using weak supervision (`train_weak.py`), as proposed in this work.

Training scripts can be found in the `scripts/` folder.

### Evaluation

Evaluation is implemented in the `eval.py` file. It can evaluate a single affine or TPS model (with the `--model-aff` and `--model-tps` parameters, respectively), or a combined affine+TPS model (with the `--model` parameter).

The evaluation dataset is passed with the `--eval-dataset` parameter.

### Trained models

Trained models are provided below for the baseline method, which uses only strong supervision, and for the proposed method, which adds weak supervision. Store them in the `trained_models/` folder.

With the commands below you should obtain the results reported in Table 2 of the paper.


**CNNGeometric with VGG-16 baseline:** [[affine model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_vgg16_affine.pth.tar)], [[TPS model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_vgg16_tps.pth.tar)]

```
python eval.py --feature-extraction-cnn vgg --model-aff trained_models/cnngeo_vgg16_affine.pth.tar --model-tps trained_models/cnngeo_vgg16_tps.pth.tar --eval-dataset pf-pascal
```

**CNNGeometric with ResNet-101 baseline:** [[affine model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_resnet101_affine.pth.tar)], [[TPS model](http://www.di.ens.fr/willow/research/weakalign/trained_models/cnngeo_resnet101_tps.pth.tar)]

```
python eval.py --feature-extraction-cnn resnet101 --model-aff trained_models/cnngeo_resnet101_affine.pth.tar --model-tps trained_models/cnngeo_resnet101_tps.pth.tar --eval-dataset pf-pascal
```

**Proposed method:** [[combined aff+TPS model](http://www.di.ens.fr/willow/research/weakalign/trained_models/weakalign_resnet101_affine_tps.pth.tar)]

```
python eval.py --feature-extraction-cnn resnet101 --model trained_models/weakalign_resnet101_affine_tps.pth.tar --eval-dataset pf-pascal
```

## BibTeX

If you use this code in your project, please cite our paper:
````
@article{Rocco18,
author = "Rocco, I. and Arandjelovi\'c, R. and Sivic, J.",
title = "End-to-end weakly-supervised semantic alignment",
journal={arXiv preprint arXiv:1712.06861},
}
````


88 changes: 88 additions & 0 deletions data/caltech_dataset.py
@@ -0,0 +1,88 @@
from __future__ import print_function, division
import os
import torch
from torch.autograd import Variable
from skimage import io
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from geotnf.transformation import GeometricTnf

class CaltechDataset(Dataset):

"""
Caltech-101 image pair dataset
Args:
csv_file (string): Path to the csv file with image names and annotation files.
dataset_path (string): Directory with the images.
output_size (2-tuple): Desired output size
transform (callable): Transformation for post-processing the training pair (eg. image normalization)
"""

def __init__(self, csv_file, dataset_path,output_size=(240,240),transform=None):

self.category_names=['Faces','Faces_easy','Leopards','Motorbikes','accordion','airplanes','anchor','ant','barrel','bass','beaver','binocular','bonsai','brain','brontosaurus','buddha','butterfly','camera','cannon','car_side','ceiling_fan','cellphone','chair','chandelier','cougar_body','cougar_face','crab','crayfish','crocodile','crocodile_head','cup','dalmatian','dollar_bill','dolphin','dragonfly','electric_guitar','elephant','emu','euphonium','ewer','ferry','flamingo','flamingo_head','garfield','gerenuk','gramophone','grand_piano','hawksbill','headphone','hedgehog','helicopter','ibis','inline_skate','joshua_tree','kangaroo','ketch','lamp','laptop','llama','lobster','lotus','mandolin','mayfly','menorah','metronome','minaret','nautilus','octopus','okapi','pagoda','panda','pigeon','pizza','platypus','pyramid','revolver','rhino','rooster','saxophone','schooner','scissors','scorpion','sea_horse','snoopy','soccer_ball','stapler','starfish','stegosaurus','stop_sign','strawberry','sunflower','tick','trilobite','umbrella','watch','water_lilly','wheelchair','wild_cat','windsor_chair','wrench','yin_yang']
self.out_h, self.out_w = output_size
self.pairs = pd.read_csv(csv_file)
self.img_A_names = self.pairs.iloc[:,0]
self.img_B_names = self.pairs.iloc[:,1]
        self.category = self.pairs.iloc[:,2].values.astype('float')
self.annot_A_str = self.pairs.iloc[:, 3:5]
self.annot_B_str = self.pairs.iloc[:, 5:]
self.dataset_path = dataset_path
self.transform = transform
        # no CUDA here, as the dataset is used from CPU worker threads in the dataloader and using CUDA there causes conflicts
self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda = False)

def __len__(self):
return len(self.pairs)

def __getitem__(self, idx):
# get pre-processed images
image_A,im_size_A = self.get_image(self.img_A_names,idx)
image_B,im_size_B = self.get_image(self.img_B_names,idx)

# get pre-processed point coords
annot_A = self.get_points(self.annot_A_str, idx)
annot_B = self.get_points(self.annot_B_str, idx)

sample = {'source_image': image_A, 'target_image': image_B, 'source_im_size': im_size_A, 'target_im_size': im_size_B, 'source_polygon': annot_A, 'target_polygon': annot_B}

if self.transform:
sample = self.transform(sample)

return sample

def get_image(self,img_name_list,idx):
img_name = os.path.join(self.dataset_path, img_name_list[idx])
image = io.imread(img_name)

# if grayscale convert to 3-channel image
if image.ndim==2:
image=np.repeat(np.expand_dims(image,2),axis=2,repeats=3)

# get image size
im_size = np.asarray(image.shape)

# convert to torch Variable
image = np.expand_dims(image.transpose((2,0,1)),0)
image = torch.Tensor(image.astype(np.float32))
image_var = Variable(image,requires_grad=False)

# Resize image using bilinear sampling with identity affine tnf
image = self.affineTnf(image_var).data.squeeze(0)

im_size = torch.Tensor(im_size.astype(np.float32))

return (image, im_size)

def get_points(self,point_coords_list,idx):
point_coords_x = point_coords_list[point_coords_list.columns[0]][idx]
point_coords_y = point_coords_list[point_coords_list.columns[1]][idx]

return (point_coords_x,point_coords_y)
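
For reference, a minimal usage sketch of the dataset class above: the CSV file name and dataset path are hypothetical placeholders, and the normalization transform is omitted.

```
# Minimal sketch (hypothetical paths): wrap CaltechDataset in a PyTorch DataLoader
from torch.utils.data import DataLoader
from data.caltech_dataset import CaltechDataset

dataset = CaltechDataset(csv_file='test_pairs_caltech.csv',      # placeholder path
                         dataset_path='datasets/caltech-101/',   # placeholder path
                         output_size=(240, 240))                 # transform omitted
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=4)

batch = next(iter(loader))
# batch is a dict with 'source_image', 'target_image', image sizes and annotation polygons
print(batch['source_image'].shape)  # torch.Size([4, 3, 240, 240])
```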

125 changes: 125 additions & 0 deletions data/download_datasets.py
@@ -0,0 +1,125 @@
from os.path import exists, join, basename, dirname, splitext
from os import makedirs, remove, rename
from six.moves import urllib
import tarfile
import zipfile
import requests
import sys
import click

def download_and_uncompress(url, dest=None, chunk_size=1024, replace="ask",
label="Downloading {dest_basename} ({size:.2f}MB)"):
dest = dest or "./"+url.split("/")[-1]
dest_dir = dirname(dest)
if not exists(dest_dir):
makedirs(dest_dir)
if exists(dest):
if (replace is False
or replace == "ask"
and not click.confirm("Replace {}?".format(dest))):
return
# download file
with open(dest, "wb") as f:
response = requests.get(url, stream=True)
total_length = response.headers.get('content-length')

if total_length is None: # no content length header
f.write(response.content)
else:
dl = 0
total_length = int(total_length)
            for data in response.iter_content(chunk_size=chunk_size):
dl += len(data)
f.write(data)
done = int(50 * dl / total_length)
sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50-done)) )
sys.stdout.write("{:.1%}".format(dl / total_length))
sys.stdout.flush()
sys.stdout.write("\n")
# uncompress
if dest.endswith("zip"):
file = zipfile.ZipFile(dest, 'r')
elif dest.endswith("tar"):
file = tarfile.open(dest, 'r')
elif dest.endswith("tar.gz"):
file = tarfile.open(dest, 'r:gz')
else:
return dest

print("Extracting data...")
file.extractall(dest_dir)
file.close()

return dest

def download_PF_willow(dest="datasets/proposal-flow-willow"):
print("Fetching PF Willow dataset ")
url = "http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset.zip"
file_path = join(dest, basename(url))
download_and_uncompress(url,file_path)

    print('Downloading image pair list \n')
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf.csv"
file_path = join(dest,basename(url))
download_and_uncompress(url,file_path)

def download_PF_pascal(dest="datasets/proposal-flow-pascal"):
print("Fetching PF Pascal dataset ")
url = "http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip"
file_path = join(dest, basename(url))
download_and_uncompress(url,file_path)

    print('Downloading image pair list \n')
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf_pascal.csv"
file_path = join(dest,basename(url))
download_and_uncompress(url,file_path)
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/val_pairs_pf_pascal.csv"
file_path = join(dest,basename(url))
download_and_uncompress(url,file_path)

def download_pascal(dest="datasets/pascal-voc11"):
print("Fetching Pascal VOC2011 dataset")
url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar"
file_path = join(dest, basename(url))
download_and_uncompress(url, file_path)

def download_pascal_parts(dest="datasets/pascal-parts"):
print("Fetching Pascal Parts dataset")
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/pascal_data.tar"
file_path = join(dest, basename(url))
download_and_uncompress(url, file_path)

def download_caltech(dest="datasets/caltech-101"):
print("Fetching Caltech-101 dataset")
url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz"
file_path = join(dest, basename(url))
download_and_uncompress(url,file_path)

print("Fetching Caltech-101 annotations")
url="http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar"
file_path = join(dest, basename(url))
download_and_uncompress(url,file_path)

    print('Renaming some annotation directories\n')
    # use the rename imported from os above (the os module itself is not imported)
    rename(join(dest,'Annotations','Airplanes_Side_2'),join(dest,'Annotations','airplanes'))
    rename(join(dest,'Annotations','Faces_2'),join(dest,'Annotations','Faces'))
    rename(join(dest,'Annotations','Faces_3'),join(dest,'Annotations','Faces_easy'))
    rename(join(dest,'Annotations','Motorbikes_16'),join(dest,'Annotations','Motorbikes'))
    print('Done renaming\n')

    print('Downloading image pair list \n')
url='http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_caltech.csv'
file_path = join(dest, basename(url))
download_and_uncompress(url,file_path)

def download_TSS(dest="datasets/tss"):
print("Fetching TSS dataset ")
url = "http://www.hci.iis.u-tokyo.ac.jp/datasets/data/JointCorrCoseg/TSS_CVPR2016.zip"
file_path = join(dest, basename(url))
download_and_uncompress(url,file_path)

    print('Downloading image pair list \n')
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_tss.csv"
file_path = join(dest,basename(url))
download_and_uncompress(url,file_path)
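
The helpers above can be run together from a short driver script; this is a sketch under the assumption that the file is importable as `data.download_datasets` (each helper writes to the default destination shown in its signature).

```
# Hypothetical driver: fetch the evaluation datasets with the helpers defined above
from data.download_datasets import (download_PF_willow, download_PF_pascal,
                                    download_caltech, download_TSS)

download_PF_willow()   # -> datasets/proposal-flow-willow
download_PF_pascal()   # -> datasets/proposal-flow-pascal
download_caltech()     # -> datasets/caltech-101
download_TSS()         # -> datasets/tss
```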
