# In [None]:
# Implement the D2-Net for image matching and retrieval

import os
import sys
import glob
import argparse
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from tensorboardX import SummaryWriter
from utils.utils import *
from utils.utils import *
from utils.dataloader import *
from utils.model import *

# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Command-line options: data/model/log locations, data loading, optimisation,
# per-phase iteration caps, checkpointing and visualisation.
parser = argparse.ArgumentParser(description='D2-Net')
parser.add_argument('--data_path', type=str, default='data')
parser.add_argument('--model_path', type=str, default='models')
parser.add_argument('--log_path', type=str, default='logs')
parser.add_argument('--log_freq', type=int, default=100)
parser.add_argument('--checkpoint', type=str, default=None)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--max_train_iter', type=int, default=100000)
parser.add_argument('--max_val_iter', type=int, default=1000)
parser.add_argument('--max_test_iter', type=int, default=1000)
parser.add_argument('--save_freq', type=int, default=1000)
parser.add_argument('--vis', action='store_true')
parser.add_argument('--vis_freq', type=int, default=100)

# parse_known_args (instead of parse_args) tolerates arguments this parser does
# not define -- notably the ``-f <kernel.json>`` a Jupyter kernel receives in
# sys.argv -- rather than aborting with "unrecognized arguments". They are
# reported (so typos such as ``--batchsize`` are still visible) and ignored.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print('Ignoring unrecognized arguments: %s' % ' '.join(_unknown_args))

# Build one D2NetDataset per split, all rooted at --data_path
# (constructed in the order train, val, test).
train_dataset, val_dataset, test_dataset = (
    D2NetDataset(args.data_path, split) for split in ('train', 'val', 'test')
)

# Wrap each split in a DataLoader with the CLI batch size and worker count.
# Only the training split is shuffled. The val/test passes are presumably
# capped by --max_val_iter / --max_test_iter (evaluation loop not visible
# here -- confirm), so shuffling them would evaluate a different subset and
# order on every pass. That makes metrics and visualisations non-comparable
# across runs.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

# Instantiate the network, move it onto the selected device and show it.
model = D2Net()
model = model.to(device)
print(model)

# Adam over all of the model's parameters, at the --lr learning rate.
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Mean-squared-error criterion ('mean' is nn.MSELoss's default reduction).
# NOTE(review): the D2-Net paper trains with a triplet margin ranking loss;
# confirm that MSE is really intended here.
loss_fn = nn.MSELoss(reduction='mean')

# TensorBoard event writer, logging under --log_path.
writer = SummaryWriter(args.log_path)

# Optionally resume from --checkpoint: restore model and optimizer state and
# the counters stored under 'epoch' and 'iter'; otherwise start from zero.
# NOTE(review): whether 'epoch' is the last completed epoch or the next one to
# run depends on how checkpoints are saved -- confirm before using start_epoch.
if args.checkpoint is not None:
    # map_location remaps tensors saved on another device (e.g. a CUDA GPU)
    # onto the device selected above, so a GPU checkpoint can be resumed on a
    # CPU-only machine instead of failing to deserialize.
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iter = checkpoint['iter']
    print('Load checkpoint %s (epoch: %d, iter: %d)' % (args.checkpoint, start_epoch, start_iter))
else:
    # Fresh run: counters start at zero.
    start_epoch = 0
    start_iter = 0

# Train the model

