In [1]:
from __future__ import print_function, division
import os
import argparse
import torch
import torch.nn as nn
from os.path import exists
from torch.utils.data import Dataset, DataLoader
from model.cnn_geometric_model import CNNGeometric, TwoStageCNNGeometric
from data.pf_dataset import PFDataset, PFPascalDataset
from data.download_datasets import download_PF_willow
from image.normalization import NormalizeImageDict, normalize_image
from util.torch_util import BatchTensorToVars, str_to_bool
from geotnf.transformation import GeometricTnf
from geotnf.point_tnf import *
import matplotlib.pyplot as plt
from skimage import io
from collections import OrderedDict
import torch.nn.functional as F

# for compatibility with Python 2
try:
    input = raw_input
except NameError:
    pass

"""

Script to demonstrate evaluation on a trained model

"""

print('WeakAlign demo script')

# Argument parsing
parser = argparse.ArgumentParser(description='WeakAlign PyTorch implementation')
# Paths
parser.add_argument('--model', type=str, default='trained_models/weakalign_resnet101_affine_tps.pth.tar', help='Trained two-stage model filename')
parser.add_argument('--model-aff', type=str, default='', help='Trained affine model filename')
parser.add_argument('--model-tps', type=str, default='', help='Trained TPS model filename')
parser.add_argument('--pf-path', type=str, default='datasets/proposal-flow-pascal', help='Path to PF dataset')
parser.add_argument('--feature-extraction-cnn', type=str, default='resnet101', help='feature extraction CNN model architecture: vgg/resnet101')
parser.add_argument('--tps-reg-factor', type=float, default=0.0, help='regularisation factor for tps tnf')

args = parser.parse_args()

use_cuda = torch.cuda.is_available()

do_aff = not args.model_aff==''
do_tps = not args.model_tps==''

if args.pf_path=='':
    args.args.pf_path='datasets/proposal-flow-pascal/'
    
# Download dataset if needed
if not exists(args.pf_path):
    download_PF_pascal(args.pf_path)

# Create model
print('Creating CNN model...')
model = TwoStageCNNGeometric(use_cuda=use_cuda,
                             return_correlation=False,
                             feature_extraction_cnn=args.feature_extraction_cnn)

# Load trained weights
print('Loading trained model weights...')
if args.model!='':
    checkpoint = torch.load(args.model, map_location=lambda storage, loc: storage)
    checkpoint['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()])
        
    for name, param in model.FeatureExtraction.state_dict().items():
        model.FeatureExtraction.state_dict()[name].copy_(checkpoint['state_dict']['FeatureExtraction.' + name])    
    for name, param in model.FeatureRegression.state_dict().items():
        model.FeatureRegression.state_dict()[name].copy_(checkpoint['state_dict']['FeatureRegression.' + name])
    for name, param in model.FeatureRegression2.state_dict().items():
        model.FeatureRegression2.state_dict()[name].copy_(checkpoint['state_dict']['FeatureRegression2.' + name])
        
else:
    checkpoint_aff = torch.load(args.model_aff, map_location=lambda storage, loc: storage)
    checkpoint_aff['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in checkpoint_aff['state_dict'].items()])
    for name, param in model.FeatureExtraction.state_dict().items():
        model.FeatureExtraction.state_dict()[name].copy_(checkpoint_aff['state_dict']['FeatureExtraction.' + name])    
    for name, param in model.FeatureRegression.state_dict().items():
        model.FeatureRegression.state_dict()[name].copy_(checkpoint_aff['state_dict']['FeatureRegression.' + name])

    checkpoint_tps = torch.load(args.model_tps, map_location=lambda storage, loc: storage)
    checkpoint_tps['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in checkpoint_tps['state_dict'].items()])
    for name, param in model.FeatureRegression2.state_dict().items():
        model.FeatureRegression2.state_dict()[name].copy_(checkpoint_tps['state_dict']['FeatureRegression.' + name])

WeakAlign demo script


usage: ipykernel_launcher.py [-h] [--model MODEL] [--model-aff MODEL_AFF]
                             [--model-tps MODEL_TPS] [--pf-path PF_PATH]
                             [--feature-extraction-cnn FEATURE_EXTRACTION_CNN]
                             [--tps-reg-factor TPS_REG_FACTOR]
ipykernel_launcher.py: error: unrecognized arguments: -f /run/user/1000/jupyter/kernel-01f53b7f-88d6-48f1-bccd-82ea0cb1df4f.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
