Add results to readme
jpainam committed Jul 17, 2018
1 parent f1d1c9e commit 0e7169c
Showing 3 changed files with 18 additions and 77 deletions.
14 changes: 8 additions & 6 deletions README.md
@@ -6,19 +6,21 @@
- Download [CUHK03 Dataset](http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html)
- Download [VIPeR Dataset](https://vision.soe.ucsc.edu/node/178)

###Test
###Testing

```bash
python test_cuhk03.py --which_epoch 59 --name cuhk03_dense --use_dense
python test_cuhk03.py --model_path ./cuhk03/model.pth --use_dense
python eval_cuhk03.py
```

###Current results

| Dataset | Rank 1 | Rank 5 | Rank 10 | Rank 20 | mAP |
| --- | --- | --- | --- | --- | --- |
| CUHK03-Dense Baseline | 0.679181 | 0.909378 | 0.953513 | 0.977772 | 0.781033 |
| CUHK03-Dense SLS_ReID | 0.843203 | 0.971337 | 0.989171 | 0.996290 | 0.899174 |
| CUHK03-ResNet Baseline |0.750155 | 0.950839 | 0.979171 | 0.991119 | 0.838676 |
| CUHK03-ResNet SLS_ReID | 0.909938 | 0.982435 | 0.992477 | 0.997420 | 0.941790 |
| `CUHK03-Dense Baseline` | 0.679181 | 0.909378 | 0.953513 | 0.977772 | 0.781033 |
| `CUHK03-Dense SLS_ReID` | 0.843203 | 0.971337 | 0.989171 | 0.996290 | 0.899174 |
| `CUHK03-ResNet Baseline` |0.750155 | 0.950839 | 0.979171 | 0.991119 | 0.838676 |
| `CUHK03-ResNet SLS_ReID` | 0.909938 | 0.982435 | 0.992477 | 0.997420 | 0.941790 |
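
The Rank-k and mAP columns above are CMC and mean-average-precision scores computed by eval_cuhk03.py from the saved distance matrix. As a reference, here is a minimal sketch of how both metrics fall out of the ranked match vector for a single query; the helper is hypothetical and only mirrors the cumsum-based logic of the eval_market1501 code shown further down, it is not part of the repository:

```python
import numpy as np

def rank_k_and_ap(matches, ks=(1, 5, 10, 20)):
    """matches: binary vector over the ranked gallery, 1 = correct identity.
    Returns the CMC value at each rank in `ks` and the average precision."""
    matches = np.asarray(matches, dtype=np.float64)
    cmc = matches.cumsum()
    cmc[cmc > 1] = 1                                  # CMC saturates after the first hit
    hits = {k: cmc[min(k, len(cmc)) - 1] for k in ks}
    precision = matches.cumsum() / (np.arange(len(matches)) + 1.0)
    ap = (precision * matches).sum() / max(matches.sum(), 1.0)
    return hits, ap

# first correct match at rank 2, second at rank 4 -> Rank-1 = 0.0, Rank-5 = 1.0, AP = 0.5
hits, ap = rank_k_and_ap([0, 1, 0, 1, 0])
print(hits, ap)
```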



60 changes: 3 additions & 57 deletions eval_cuhk03.py
@@ -1,62 +1,8 @@
import numpy as np
from collections import defaultdict
import scipy.io
import torch
import torch.nn as nn

def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        orig_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = orig_cmc.cumsum()
        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        tmp_cmc = orig_cmc.cumsum()
        tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
        AP = tmp_cmc.sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP


def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
    """Evaluation with cuhk03 metric
@@ -129,7 +75,7 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):


if __name__ == '__main__':
    result = scipy.io.loadmat('./plsro_result.mat')
    result = scipy.io.loadmat('./result.mat')
    distmat = result['distmat']
    q_pids = np.squeeze(result['q_pids'])
    g_pids = np.squeeze(result['g_pids'])
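test_cuhk03.py and eval_cuhk03.py communicate through the result.mat file written by scipy.io.savemat. Only the distmat, q_pids, g_pids, query_feature and gallery_feature keys are visible in this diff; the camera-ID keys in the sketch below are assumptions inferred from the eval_cuhk03() signature:

```python
import numpy as np
import scipy.io

# Sketch of reading the result.mat hand-off written by test_cuhk03.py.
# distmat / q_pids / g_pids / query_feature / gallery_feature appear in the
# diff; the *_camids keys are assumed from the eval_cuhk03() signature.
result = scipy.io.loadmat('./result.mat')
distmat = result['distmat']                   # (num_query, num_gallery) distances
q_pids = np.squeeze(result['q_pids'])         # query person IDs
g_pids = np.squeeze(result['g_pids'])         # gallery person IDs
q_camids = np.squeeze(result['q_camids'])     # assumed key: query camera IDs
g_camids = np.squeeze(result['g_camids'])     # assumed key: gallery camera IDs

# Rank the gallery for the first query, nearest first.
order = np.argsort(distmat[0])
print(g_pids[order][:10])                     # top-10 predicted identities
```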
21 changes: 7 additions & 14 deletions test_cuhk03.py
@@ -15,13 +15,11 @@

import argparse
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--data_dir',default='cuhk03',type=str, help='./test_data')
parser.add_argument('--name', default='resnet', type=str, help='save model path')
parser.add_argument('--model_path', default='resnet', type=str, help='save model path')
parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
parser.add_argument('--n_classe', default=1367, help='n classes' )
parser.add_argument('--use_dense', action='store_true', help='use densenet')
parser.add_argument('--n_classe', default=1367, help='n classes')
parser.add_argument('--dataset', default='/home/paul/datasets', type=str, help='Path to the dataset')

opt = parser.parse_args()
n_classe = opt.n_classe
@@ -82,12 +80,12 @@ def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
              'query_feature': qf.numpy(), 'gallery_feature': gf.numpy()}
    print(qf.numpy())
    print(gf.numpy())
    scipy.io.savemat('./plsro_result.mat', result)
    scipy.io.savemat('./result.mat', result)



def load_network(network):
    save_path = os.path.join('./'+opt.data_dir+'/'+opt.name+'/net_%s.pth' %opt.which_epoch)
    save_path = os.path.join(opt.model_path)
    network.load_state_dict(torch.load(save_path))
    return network

@@ -97,11 +95,6 @@ def load_network(network):
if __name__ == '__main__':

    use_gpu = torch.cuda.is_available()
    height = 224
    width = 224
    if opt.use_dense:
        height = 144
        width = 288
    data_transforms = transforms.Compose([
        transforms.Resize((288, 144), interpolation=3),
        transforms.ToTensor(),
@@ -120,7 +113,7 @@ def load_network(network):
    model = model.cuda()

    dataset = data_manager.init_img_dataset(
        root='/home/paul/datasets', name=opt.data_dir, split_id=0, cuhk03_classic_split=True)
        root=opt.dataset, name='cuhk03', split_id=0, cuhk03_classic_split=True)

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=data_transforms),
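The body of test() that produces distmat is collapsed in this diff, so the distance metric used on the extracted qf/gf features is not visible. A common choice for re-ID feature matrices of shape (num_images, feature_dim) is squared Euclidean distance; the sketch below illustrates that assumption and is not the repository's exact code:

```python
import torch

def pairwise_distmat(qf, gf):
    """Squared Euclidean distances between query (m, d) and gallery (n, d)
    feature matrices. A common re-ID choice; the metric actually used in
    test() is not shown in this diff."""
    m, n = qf.size(0), gf.size(0)
    dist = qf.pow(2).sum(dim=1, keepdim=True).expand(m, n) \
         + gf.pow(2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist = dist - 2 * qf @ gf.t()
    return dist.clamp(min=0)

# e.g. 4 query and 6 gallery embeddings of dimension 512
distmat = pairwise_distmat(torch.randn(4, 512), torch.randn(6, 512))
print(distmat.shape)   # torch.Size([4, 6])
```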
