diff --git a/HandCraftedModules.py b/HandCraftedModules.py new file mode 100644 index 0000000..10f3d6d --- /dev/null +++ b/HandCraftedModules.py @@ -0,0 +1,292 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import math +import numpy as np +from Utils import GaussianBlur, CircularGaussKernel +from LAF import abc2A,rectifyAffineTransformationUpIsUp, sc_y_x2LAFs +from Utils import generate_2dgrid, generate_2dgrid, generate_3dgrid +from Utils import zero_response_at_border + + +class ScalePyramid(nn.Module): + def __init__(self, nLevels = 3, init_sigma = 1.6, border = 5): + super(ScalePyramid,self).__init__() + self.nLevels = nLevels; + self.init_sigma = init_sigma + self.sigmaStep = 2 ** (1. / float(self.nLevels)) + #print 'step',self.sigmaStep + self.b = border + self.minSize = 2 * self.b + 2 + 1; + return + def forward(self,x): + pixelDistance = 1.0; + curSigma = 0.5 + if self.init_sigma > curSigma: + sigma = np.sqrt(self.init_sigma**2 - curSigma**2) + curSigma = self.init_sigma + curr = GaussianBlur(sigma = sigma)(x) + else: + curr = x + sigmas = [[curSigma]] + pixel_dists = [[1.0]] + pyr = [[curr]] + j = 0 + while True: + curr = pyr[-1][0] + for i in range(1, self.nLevels + 2): + sigma = curSigma * np.sqrt(self.sigmaStep*self.sigmaStep - 1.0 ) + #print 'blur sigma', sigma + curr = GaussianBlur(sigma = sigma)(curr) + curSigma *= self.sigmaStep + pyr[j].append(curr) + sigmas[j].append(curSigma) + pixel_dists[j].append(pixelDistance) + if i == self.nLevels: + nextOctaveFirstLevel = F.avg_pool2d(curr, kernel_size = 1, stride = 2, padding = 0) + pixelDistance = pixelDistance * 2.0 + curSigma = self.init_sigma + if (nextOctaveFirstLevel[0,0,:,:].size(0) <= self.minSize) or (nextOctaveFirstLevel[0,0,:,:].size(1) <= self.minSize): + break + pyr.append([nextOctaveFirstLevel]) + sigmas.append([curSigma]) + pixel_dists.append([pixelDistance]) + j+=1 + return pyr, sigmas, pixel_dists + +class HessianResp(nn.Module): + def __init__(self): + super(HessianResp, self).__init__() + + self.gx = nn.Conv2d(1, 1, kernel_size=(1,3), bias = False) + self.gx.weight.data = torch.from_numpy(np.array([[[[0.5, 0, -0.5]]]], dtype=np.float32)) + + self.gy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gy.weight.data = torch.from_numpy(np.array([[[[0.5], [0], [-0.5]]]], dtype=np.float32)) + + self.gxx = nn.Conv2d(1, 1, kernel_size=(1,3),bias = False) + self.gxx.weight.data = torch.from_numpy(np.array([[[[1.0, -2.0, 1.0]]]], dtype=np.float32)) + + self.gyy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gyy.weight.data = torch.from_numpy(np.array([[[[1.0], [-2.0], [1.0]]]], dtype=np.float32)) + return + def forward(self, x, scale): + gxx = self.gxx(F.pad(x, (1,1,0, 0), 'replicate')) + gyy = self.gyy(F.pad(x, (0,0, 1,1), 'replicate')) + gxy = self.gy(F.pad(self.gx(F.pad(x, (1,1,0, 0), 'replicate')), (0,0, 1,1), 'replicate')) + return torch.abs(gxx * gyy - gxy * gxy) * (scale**4) + + +class AffineShapeEstimator(nn.Module): + def __init__(self, threshold = 0.001, patch_size = 19): + super(AffineShapeEstimator, self).__init__() + self.threshold = threshold; + self.PS = patch_size + self.gx = nn.Conv2d(1, 1, kernel_size=(1,3), bias = False) + self.gx.weight.data = torch.from_numpy(np.array([[[[-1, 0, 1]]]], dtype=np.float32)) + self.gy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gy.weight.data = torch.from_numpy(np.array([[[[-1], [0], [1]]]], dtype=np.float32)) + self.gk = torch.from_numpy(CircularGaussKernel(kernlen = self.PS, 
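+        # (editorial comment, describing the surrounding code:) gk is a circular
+        # Gaussian window used in forward() to weight the gradient second-moment matrix
+        #   M = sum_x w(x) * [gx*gx, gx*gy; gx*gy, gy*gy](x);
+        # invSqrt() below then yields the rectifying affine transform A ~ M^(-1/2),
+        # one step of Baumberg-style affine adaptation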
sigma = (self.PS / 2) /3.0).astype(np.float32)) + self.gk = Variable(self.gk, requires_grad=False) + return + def invSqrt(self,a,b,c): + eps = 1e-12 + mask = (b != 0).float() + r1 = mask * (c - a) / (2. * b + eps) + t1 = torch.sign(r1) / (torch.abs(r1) + torch.sqrt(1. + r1*r1)); + r = 1.0 / torch.sqrt( 1. + t1*t1) + t = t1*r; + r = r * mask + 1.0 * (1.0 - mask); + t = t * mask; + + x = 1. / torch.sqrt( r*r*a - 2.0*r*t*b + t*t*c) + z = 1. / torch.sqrt( t*t*a + 2.0*r*t*b + r*r*c) + + d = torch.sqrt( x * z) + + x = x / d + z = z / d + + l1 = torch.max(x,z) + l2 = torch.min(x,z) + + new_a = r*r*x + t*t*z + new_b = -r*t*x + t*r*z + new_c = t*t*x + r*r *z + + return new_a, new_b, new_c, l1, l2 + def forward(self,x): + if x.is_cuda: + self.gk = self.gk.cuda() + else: + self.gk = self.gk.cpu() + gx = self.gx(F.pad(x, (1, 1, 0, 0), 'replicate')) + gy = self.gy(F.pad(x, (0, 0, 1, 1), 'replicate')) + a1 = (gx * gx * self.gk.unsqueeze(0).unsqueeze(0).expand_as(gx)).view(x.size(0),-1).mean(dim=1) + b1 = (gx * gy * self.gk.unsqueeze(0).unsqueeze(0).expand_as(gx)).view(x.size(0),-1).mean(dim=1) + c1 = (gy * gy * self.gk.unsqueeze(0).unsqueeze(0).expand_as(gx)).view(x.size(0),-1).mean(dim=1) + a, b, c, l1, l2 = self.invSqrt(a1,b1,c1) + rat1 = l1/l2 + mask = (torch.abs(rat1) <= 6.).float().view(-1); + return rectifyAffineTransformationUpIsUp(abc2A(a,b,c)), mask +class OrientationDetector(nn.Module): + def __init__(self, + mrSize = 3.0, patch_size = None): + super(OrientationDetector, self).__init__() + if patch_size is None: + patch_size = 32; + self.PS = patch_size; + self.bin_weight_kernel_size, self.bin_weight_stride = self.get_bin_weight_kernel_size_and_stride(self.PS, 1) + self.mrSize = mrSize; + self.num_ang_bins = 36 + self.gx = nn.Conv2d(1, 1, kernel_size=(1,3), bias = False) + self.gx.weight.data = torch.from_numpy(np.array([[[[0.5, 0, -0.5]]]], dtype=np.float32)) + + self.gy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gy.weight.data = torch.from_numpy(np.array([[[[0.5], [0], [-0.5]]]], dtype=np.float32)) + + self.angular_smooth = nn.Conv1d(1, 1, kernel_size=3, padding = 1, bias = False) + self.angular_smooth.weight.data = torch.from_numpy(np.array([[[0.33, 0.34, 0.33]]], dtype=np.float32)) + + self.gk = 10. 
* torch.from_numpy(CircularGaussKernel(kernlen=self.PS).astype(np.float32)) + self.gk = Variable(self.gk, requires_grad=False) + return + def get_bin_weight_kernel_size_and_stride(self, patch_size, num_spatial_bins): + bin_weight_stride = int(round(2.0 * np.floor(patch_size / 2) / float(num_spatial_bins + 1))) + bin_weight_kernel_size = int(2 * bin_weight_stride - 1); + return bin_weight_kernel_size, bin_weight_stride + def get_rotation_matrix(self, angle_in_radians): + angle_in_radians = angle_in_radians.view(-1, 1, 1); + sin_a = torch.sin(angle_in_radians) + cos_a = torch.cos(angle_in_radians) + A1_x = torch.cat([cos_a, sin_a], dim = 2) + A2_x = torch.cat([-sin_a, cos_a], dim = 2) + transform = torch.cat([A1_x,A2_x], dim = 1) + return transform + + def forward(self, x, return_rot_matrix = False): + gx = self.gx(F.pad(x, (1,1,0, 0), 'replicate')) + gy = self.gy(F.pad(x, (0,0, 1,1), 'replicate')) + mag = torch.sqrt(gx * gx + gy * gy + 1e-10) + if x.is_cuda: + self.gk = self.gk.cuda() + mag = mag * self.gk.unsqueeze(0).unsqueeze(0).expand_as(mag) + ori = torch.atan2(gy,gx) + o_big = float(self.num_ang_bins) *(ori + 1.0 * math.pi )/ (2.0 * math.pi) + bo0_big = torch.floor(o_big) + wo1_big = o_big - bo0_big + bo0_big = bo0_big % self.num_ang_bins + bo1_big = (bo0_big + 1) % self.num_ang_bins + wo0_big = (1.0 - wo1_big) * mag + wo1_big = wo1_big * mag + ang_bins = [] + for i in range(0, self.num_ang_bins): + ang_bins.append(F.adaptive_avg_pool2d((bo0_big == i).float() * wo0_big, (1,1))) + ang_bins = torch.cat(ang_bins,1).view(-1,1,self.num_ang_bins) + ang_bins = self.angular_smooth(ang_bins) + values, indices = ang_bins.view(-1,self.num_ang_bins).max(1) + angle = -((2. * float(np.pi) * indices.float() / float(self.num_ang_bins)) - float(math.pi)) + if return_rot_matrix: + return self.get_rotation_matrix(angle) + return angle + +class NMS2d(nn.Module): + def __init__(self, kernel_size = 3, threshold = 0): + super(NMS2d, self).__init__() + self.MP = nn.MaxPool2d(kernel_size, stride=1, return_indices=False, padding = kernel_size/2) + self.eps = 1e-5 + self.th = threshold + return + def forward(self, x): + #local_maxima = self.MP(x) + if self.th > self.eps: + return x * (x > self.th).float() * ((x + self.eps - self.MP(x)) > 0).float() + else: + return ((x - self.MP(x) + self.eps) > 0).float() * x + +class NMS3d(nn.Module): + def __init__(self, kernel_size = 3, threshold = 0): + super(NMS3d, self).__init__() + self.MP = nn.MaxPool3d(kernel_size, stride=1, return_indices=False, padding = (0, kernel_size/2, kernel_size/2)) + self.eps = 1e-5 + self.th = threshold + return + def forward(self, x): + #local_maxima = self.MP(x) + if self.th > self.eps: + return x * (x > self.th).float() * ((x + self.eps - self.MP(x)) > 0).float() + else: + return ((x - self.MP(x) + self.eps) > 0).float() * x + +class NMS3dAndComposeA(nn.Module): + def __init__(self, w = 0, h = 0, kernel_size = 3, threshold = 0, scales = None, border = 3, mrSize = 1.0): + super(NMS3dAndComposeA, self).__init__() + self.eps = 1e-7 + self.ks = 3 + self.th = threshold + self.cube_idxs = [] + self.border = border + self.mrSize = mrSize + self.beta = 1.0 + self.grid_ones = Variable(torch.ones(3,3,3,3), requires_grad=False) + self.NMS3d = NMS3d(kernel_size, threshold) + if (w > 0) and (h > 0): + self.spatial_grid = generate_2dgrid(h, w, False).view(1, h, w,2).permute(3,1, 2, 0) + self.spatial_grid = Variable(self.spatial_grid) + else: + self.spatial_grid = None + return + def forward(self, low, cur, high, num_features = 0, octaveMap = None, 
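+                # low, cur, high below are Hessian response maps at three adjacent
+                # pyramid levels; self.NMS3d takes maxima over their 3x3x3
+                # scale-space neighbourhood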
scales = None): + assert low.size() == cur.size() == high.size() + #Filter responce map + self.is_cuda = low.is_cuda; + resp3d = torch.cat([low,cur,high], dim = 1) + + mrSize_border = int(self.mrSize); + if octaveMap is not None: + nmsed_resp = zero_response_at_border(self.NMS3d(resp3d.unsqueeze(1)).squeeze(1)[:,1:2,:,:], mrSize_border) * (1. - octaveMap.float()) + else: + nmsed_resp = zero_response_at_border(self.NMS3d(resp3d.unsqueeze(1)).squeeze(1)[:,1:2,:,:], mrSize_border) + + num_of_nonzero_responces = (nmsed_resp > 0).sum().data[0] + if (num_of_nonzero_responces == 0): + return None,None,None + if octaveMap is not None: + octaveMap = (octaveMap.float() + nmsed_resp.float()).byte() + + nmsed_resp = nmsed_resp.view(-1) + if (num_features > 0) and (num_features < num_of_nonzero_responces): + nmsed_resp, idxs = torch.topk(nmsed_resp, k = num_features); + else: + idxs = nmsed_resp.data.nonzero().squeeze() + nmsed_resp = nmsed_resp[idxs] + #Get point coordinates grid + + if type(scales) is not list: + self.grid = generate_3dgrid(3,self.ks,self.ks) + else: + self.grid = generate_3dgrid(scales,self.ks,self.ks) + self.grid = Variable(self.grid.t().contiguous().view(3,3,3,3), requires_grad=False) + if self.spatial_grid is None: + self.spatial_grid = generate_2dgrid(low.size(2), low.size(3), False).view(1, low.size(2), low.size(3),2).permute(3,1, 2, 0) + self.spatial_grid = Variable(self.spatial_grid) + if self.is_cuda: + self.spatial_grid = self.spatial_grid.cuda() + self.grid_ones = self.grid_ones.cuda() + self.grid = self.grid.cuda() + + #residual_to_patch_center + sc_y_x = F.conv2d(resp3d, self.grid, + padding = 1) / (F.conv2d(resp3d, self.grid_ones, padding = 1) + 1e-8) + + ##maxima coords + sc_y_x[0,1:,:,:] = sc_y_x[0,1:,:,:] + self.spatial_grid[:,:,:,0] + sc_y_x = sc_y_x.view(3,-1).t() + sc_y_x = sc_y_x[idxs,:] + + min_size = float(min((cur.size(2)), cur.size(3))) + sc_y_x[:,0] = sc_y_x[:,0] / min_size + sc_y_x[:,1] = sc_y_x[:,1] / float(cur.size(2)) + sc_y_x[:,2] = sc_y_x[:,2] / float(cur.size(3)) + return nmsed_resp, sc_y_x2LAFs(sc_y_x), octaveMap diff --git a/HardNet.py b/HardNet.py new file mode 100644 index 0000000..9c86049 --- /dev/null +++ b/HardNet.py @@ -0,0 +1,113 @@ +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +import time +import os +import math +import numpy as np + +class L2Norm(nn.Module): + def __init__(self): + super(L2Norm,self).__init__() + self.eps = 1e-8 + def forward(self, x): + norm = torch.sqrt(torch.sum(x * x, dim = 1) + self.eps) + x= x / norm.unsqueeze(-1).expand_as(x) + return x + +class L1Norm(nn.Module): + def __init__(self): + super(L1Norm,self).__init__() + self.eps = 1e-10 + def forward(self, x): + norm = torch.sum(torch.abs(x), dim = 1) + self.eps + x= x / norm.expand_as(x) + return x + +class HardNetNarELU(nn.Module): + """TFeat model definition + """ + + def __init__(self,sm): + super(HardNetNarELU, self).__init__() + + self.features = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ELU(), + nn.Conv2d(16, 16, kernel_size=3, padding=1), + nn.ELU(), + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), + nn.ELU(), + nn.Conv2d(32, 32, kernel_size=3, padding=1), + nn.ELU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2,padding=1), + nn.ELU(), + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.ELU() + ) + self.classifier = nn.Sequential( + nn.Dropout(0.1), + nn.Conv2d(64, 128, kernel_size=8), + nn.BatchNorm2d(128, 
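+            # this 8x8 conv head maps the final 8x8x64 feature map to a 128-D
+            # descriptor; note that forward() below currently bypasses the
+            # classifier and global-average-pools the 64-channel features instead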
affine=False)) + self.SIFT = sm + return + + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + #print(sp) + return (x - mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + + def forward(self, input): + x_features = self.features(input)#self.input_norm(input)) + #x = self.classifier[1](x_features) + x = nn.AdaptiveAvgPool2d(1)(x_features).view(x_features.size(0), -1) + return x + #return L2Norm()(x) + + +class HardNet(nn.Module): + """HardNet model definition + """ + def __init__(self): + super(HardNet, self).__init__() + + self.features = nn.Sequential( + nn.Conv2d(1, 32, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=3, stride=2,padding=1, bias = False), + nn.BatchNorm2d(128, affine=False), + nn.ReLU(), + nn.Conv2d(128, 128, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(128, affine=False), + nn.ReLU(), + nn.Dropout(0.1), + nn.Conv2d(128, 128, kernel_size=8, bias = False), + nn.BatchNorm2d(128, affine=False), + ) + #self.features.apply(weights_init) + + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - mp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + + def forward(self, input): + x_features = self.features(self.input_norm(input)) + x = x_features.view(x_features.size(0), -1) + return L2Norm()(x) diff --git a/LAF.py b/LAF.py new file mode 100644 index 0000000..4aa1633 --- /dev/null +++ b/LAF.py @@ -0,0 +1,293 @@ +import numpy as np +import matplotlib.pyplot as plt +from copy import deepcopy +from scipy.spatial.distance import cdist +from numpy.linalg import inv +from scipy.linalg import schur, sqrtm +import torch +from torch.autograd import Variable + +##########numpy +def invSqrt(a,b,c): + eps = 1e-12 + mask = (b != 0) + r1 = mask * (c - a) / (2. * b + eps) + t1 = np.sign(r1) / (np.abs(r1) + np.sqrt(1. + r1*r1)); + r = 1.0 / np.sqrt( 1. + t1*t1) + t = t1*r; + + r = r * mask + 1.0 * (1.0 - mask); + t = t * mask; + + x = 1. / np.sqrt( r*r*a - 2*r*t*b + t*t*c) + z = 1. 
/ np.sqrt( t*t*a + 2*r*t*b + r*r*c) + + d = np.sqrt( x * z) + + x = x / d + z = z / d + + new_a = r*r*x + t*t*z + new_b = -r*t*x + t*r*z + new_c = t*t*x + r*r *z + + return new_a, new_b, new_c + +def Ell2LAF(ell): + A23 = np.zeros((2,3)) + A23[0,2] = ell[0] + A23[1,2] = ell[1] + a = ell[2] + b = ell[3] + c = ell[4] + sc = np.sqrt(np.sqrt(a*c - b*b)) + ia,ib,ic = invSqrt(a,b,c) #because sqrtm returns ::-1, ::-1 matrix, don`t know why + A = np.array([[ia, ib], [ib, ic]]) / sc + sc = np.sqrt(A[0,0] * A[1,1] - A[1,0] * A[0,1]) + A23[0:2,0:2] = rectifyAffineTransformationUpIsUp(A / sc) * sc + return A23 + +def rectifyAffineTransformationUpIsUp_np(A): + det = np.sqrt(np.abs(A[0,0]*A[1,1] - A[1,0]*A[0,1] + 1e-10)) + b2a2 = np.sqrt(A[0,1] * A[0,1] + A[0,0] * A[0,0]) + A_new = np.zeros((2,2)) + A_new[0,0] = b2a2 / det + A_new[0,1] = 0 + A_new[1,0] = (A[1,1]*A[0,1]+A[1,0]*A[0,0])/(b2a2*det) + A_new[1,1] = det / b2a2 + return A_new + +def ells2LAFs(ells): + LAFs = np.zeros((len(ells), 2,3)) + for i in range(len(ells)): + LAFs[i,:,:] = Ell2LAF(ells[i,:]) + return LAFs + +def LAF2pts(LAF, n_pts = 50): + a = np.linspace(0, 2*np.pi, n_pts); + x = [0] + x.extend(list(np.sin(a))) + x = np.array(x).reshape(1,-1) + y = [0] + y.extend(list(np.cos(a))) + y = np.array(y).reshape(1,-1) + HLAF = np.concatenate([LAF, np.array([0,0,1]).reshape(1,3)]) + H_pts =np.concatenate([x,y,np.ones(x.shape)]) + H_pts_out = np.transpose(np.matmul(HLAF, H_pts)) + H_pts_out[:,0] = H_pts_out[:,0] / H_pts_out[:, 2] + H_pts_out[:,1] = H_pts_out[:,1] / H_pts_out[:, 2] + return H_pts_out[:,0:2] + + +def convertLAFs_to_A23format(LAFs): + sh = LAFs.shape + if (len(sh) == 3) and (sh[1] == 2) and (sh[2] == 3): # n x 2 x 3 classical [A, (x;y)] matrix + work_LAFs = deepcopy(LAFs) + elif (len(sh) == 2) and (sh[1] == 7): #flat format, x y scale a11 a12 a21 a22 + work_LAFs = np.zeros((sh[0], 2,3)) + work_LAFs[:,0,2] = LAFs[:,0] + work_LAFs[:,1,2] = LAFs[:,1] + work_LAFs[:,0,0] = LAFs[:,2] * LAFs[:,3] + work_LAFs[:,0,1] = LAFs[:,2] * LAFs[:,4] + work_LAFs[:,1,0] = LAFs[:,2] * LAFs[:,5] + work_LAFs[:,1,1] = LAFs[:,2] * LAFs[:,6] + elif (len(sh) == 2) and (sh[1] == 6): #flat format, x y s*a11 s*a12 s*a21 s*a22 + work_LAFs = np.zeros((sh[0], 2,3)) + work_LAFs[:,0,2] = LAFs[:,0] + work_LAFs[:,1,2] = LAFs[:,1] + work_LAFs[:,0,0] = LAFs[:,2] + work_LAFs[:,0,1] = LAFs[:,3] + work_LAFs[:,1,0] = LAFs[:,4] + work_LAFs[:,1,1] = LAFs[:,5] + else: + print 'Unknown LAF format' + return None + return work_LAFs + +def LAFs2ell(in_LAFs): + LAFs = convertLAFs_to_A23format(in_LAFs) + ellipses = np.zeros((len(LAFs),5)) + for i in range(len(LAFs)): + LAF = deepcopy(LAFs[i,:,:]) + scale = np.sqrt(LAF[0,0]*LAF[1,1] - LAF[0,1]*LAF[1, 0] + 1e-10) + u, W, v = np.linalg.svd(LAF[0:2,0:2] / scale, full_matrices=True) + W[0] = 1. / (W[0]*W[0]*scale*scale) + W[1] = 1. 
/ (W[1]*W[1]*scale*scale) + A = np.matmul(np.matmul(u, np.diag(W)), u.transpose()) + ellipses[i,0] = LAF[0,2] + ellipses[i,1] = LAF[1,2] + ellipses[i,2] = A[0,0] + ellipses[i,3] = A[0,1] + ellipses[i,4] = A[1,1] + return ellipses + +def visualize_LAFs(img, LAFs): + work_LAFs = convertLAFs_to_A23format(LAFs) + plt.figure() + plt.imshow(255 - img) + for i in range(len(work_LAFs)): + ell = LAF2pts(work_LAFs[i,:,:]) + plt.plot( ell[:,0], ell[:,1], 'r') + plt.show() + return + +####pytorch + +def get_normalized_affine_shape(tilt, angle_in_radians): + assert tilt.size(0) == angle_in_radians.size(0) + num = tilt.size(0) + tilt_A = Variable(torch.eye(2).view(1,2,2).repeat(num,1,1)) + if tilt.is_cuda: + tilt_A = tilt_A.cuda() + tilt_A[:,0,0] = tilt; + rotmat = get_rotation_matrix(angle_in_radians) + out_A = rectifyAffineTransformationUpIsUp(torch.bmm(rotmat, torch.bmm(tilt_A, rotmat))) + #re_scale = (1.0/torch.sqrt((out_A **2).sum(dim=1).max(dim=1)[0])) #It is heuristic to for keeping scale change small + #re_scale = (0.5 + 0.5/torch.sqrt((out_A **2).sum(dim=1).max(dim=1)[0])) #It is heuristic to for keeping scale change small + return out_A# * re_scale.view(-1,1,1).expand(num,2,2) + +def get_rotation_matrix(angle_in_radians): + angle_in_radians = angle_in_radians.view(-1, 1, 1); + sin_a = torch.sin(angle_in_radians) + cos_a = torch.cos(angle_in_radians) + A1_x = torch.cat([cos_a, sin_a], dim = 2) + A2_x = torch.cat([-sin_a, cos_a], dim = 2) + transform = torch.cat([A1_x,A2_x], dim = 1) + return transform + +def rectifyAffineTransformationUpIsUp(A): + det = torch.sqrt(torch.abs(A[:,0,0]*A[:,1,1] - A[:,1,0]*A[:,0,1] + 1e-10)) + b2a2 = torch.sqrt(A[:,0,1] * A[:,0,1] + A[:,0,0] * A[:,0,0]) + A1_ell = torch.cat([(b2a2 / det).contiguous().view(-1,1,1), 0 * det.view(-1,1,1)], dim = 2) + A2_ell = torch.cat([((A[:,1,1]*A[:,0,1]+A[:,1,0]*A[:,0,0])/(b2a2*det)).contiguous().view(-1,1,1), + (det / b2a2).contiguous().view(-1,1,1)], dim = 2) + return torch.cat([A1_ell, A2_ell], dim = 1) + + + +def abc2A(a,b,c, normalize = False): + A1_ell = torch.cat([a.view(-1,1,1), b.view(-1,1,1)], dim = 2) + A2_ell = torch.cat([b.view(-1,1,1), c.view(-1,1,1)], dim = 2) + return torch.cat([A1_ell, A2_ell], dim = 1) + + + +def angles2A(angles): + cos_a = torch.cos(angles).view(-1, 1, 1) + sin_a = torch.sin(angles).view(-1, 1, 1) + A1_ang = torch.cat([cos_a, sin_a], dim = 2) + A2_ang = torch.cat([-sin_a, cos_a], dim = 2) + return torch.cat([A1_ang, A2_ang], dim = 1) + +def generate_patch_grid_from_normalized_LAFs(LAFs, w, h, PS): + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3) * min_size + coef[0,0,2] = w + coef[0,1,2] = h + if LAFs.is_cuda: + coef = coef.cuda() + grid = torch.nn.functional.affine_grid(LAFs * Variable(coef.expand(num_lafs,2,3)), torch.Size((num_lafs,1,PS,PS))) + grid[:,:,:,0] = 2.0 * grid[:,:,:,0] / float(w) - 1.0 + grid[:,:,:,1] = 2.0 * grid[:,:,:,1] / float(h) - 1.0 + return grid + +def extract_patches(img, LAFs, PS = 32): + w = img.size(3) + h = img.size(2) + ch = img.size(1) + grid = generate_patch_grid_from_normalized_LAFs(LAFs, float(w),float(h), PS) + return torch.nn.functional.grid_sample(img.expand(grid.size(0), ch, h, w), grid) + +def get_pyramid_inverted_index_for_LAFs(LAFs, PS, sigmas): + return + +def extract_patches_from_pyramid_with_inv_index(scale_pyramid, pyr_inv_idxs, LAFs, PS = 19): + patches = torch.zeros(LAFs.size(0),scale_pyramid[0][0].size(1), PS, PS) + if LAFs.is_cuda: + patches = patches.cuda() + patches = Variable(patches) + if pyr_inv_idxs is not None: + for 
i in range(len(scale_pyramid)): + for j in range(len(scale_pyramid[i])): + cur_lvl_idxs = pyr_inv_idxs[i][j] + if cur_lvl_idxs is None: + continue + cur_lvl_idxs = cur_lvl_idxs.view(-1) + #print i,j,cur_lvl_idxs.shape + patches[cur_lvl_idxs,:,:,:] = extract_patches(scale_pyramid[i][j], LAFs[cur_lvl_idxs, :,:], PS ) + return patches + +def get_inverted_pyr_index(scale_pyr, pyr_idxs, level_idxs): + pyr_inv_idxs = [] + ### Precompute octave inverted indexes + for i in range(len(scale_pyr)): + pyr_inv_idxs.append([]) + cur_idxs = pyr_idxs == i #torch.nonzero((pyr_idxs == i).data) + for j in range(0, len(scale_pyr[i])): + cur_lvl_idxs = torch.nonzero(((level_idxs == j) * cur_idxs).data) + if len(cur_lvl_idxs.size()) == 0: + pyr_inv_idxs[i].append(None) + else: + pyr_inv_idxs[i].append(cur_lvl_idxs.squeeze()) + return pyr_inv_idxs + + +def denormalizeLAFs(LAFs, w, h): + w = float(w) + h = float(h) + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3).float() * min_size + coef[0,0,2] = w + coef[0,1,2] = h + if LAFs.is_cuda: + coef = coef.cuda() + return Variable(coef.expand(num_lafs,2,3)) * LAFs + +def normalizeLAFs(LAFs, w, h): + w = float(w) + h = float(h) + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3).float() / min_size + coef[0,0,2] = 1.0 / w + coef[0,1,2] = 1.0 / h + if LAFs.is_cuda: + coef = coef.cuda() + return Variable(coef.expand(num_lafs,2,3)) * LAFs + +def sc_y_x2LAFs(sc_y_x): + base_LAF = torch.eye(2).float().unsqueeze(0).expand(sc_y_x.size(0),2,2) + if sc_y_x.is_cuda: + base_LAF = base_LAF.cuda() + base_A = Variable(base_LAF, requires_grad=False) + A = sc_y_x[:,:1].unsqueeze(1).expand_as(base_A) * base_A + LAFs = torch.cat([A, + torch.cat([sc_y_x[:,2:].unsqueeze(-1), + sc_y_x[:,1:2].unsqueeze(-1)], dim=1)], dim = 2) + + return LAFs +def get_LAFs_scales(LAFs): + return torch.sqrt(torch.abs(LAFs[:,0,0] *LAFs[:,1,1] - LAFs[:,0,1] * LAFs[:,1,0]) + 1e-12) + +def get_pyramid_and_level_index_for_LAFs(dLAFs, sigmas, pix_dists, PS): + scales = get_LAFs_scales(dLAFs); + needed_sigmas = scales / PS; + sigmas_full_list = [] + level_idxs_full = [] + oct_idxs_full = [] + for oct_idx in range(len(sigmas)): + sigmas_full_list = sigmas_full_list + list(np.array(sigmas[oct_idx])*np.array(pix_dists[oct_idx])) + oct_idxs_full = oct_idxs_full + [oct_idx]*len(sigmas[oct_idx]) + level_idxs_full = level_idxs_full + range(0,len(sigmas[oct_idx])) + oct_idxs_full = torch.LongTensor(oct_idxs_full) + level_idxs_full = torch.LongTensor(level_idxs_full) + + closest_imgs = cdist(np.array(sigmas_full_list).reshape(-1,1), needed_sigmas.data.cpu().numpy().reshape(-1,1)).argmin(axis = 0) + closest_imgs = torch.from_numpy(closest_imgs) + if dLAFs.is_cuda: + closest_imgs = closest_imgs.cuda() + oct_idxs_full = oct_idxs_full.cuda() + level_idxs_full = level_idxs_full.cuda() + return Variable(oct_idxs_full[closest_imgs]), Variable(level_idxs_full[closest_imgs]) diff --git a/Losses.py b/Losses.py new file mode 100644 index 0000000..f5f99af --- /dev/null +++ b/Losses.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import sys + +def distance_matrix_vector(anchor, positive): + """Given batch of anchor descriptors and positive descriptors calculate distance matrix""" + + d1_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1) + d2_sq = torch.sum(positive * positive, dim=1).unsqueeze(-1) + + eps = 1e-6 + return torch.sqrt((d1_sq.repeat(1, anchor.size(0)) + torch.t(d2_sq.repeat(1, positive.size(0))) + - 2.0 * torch.bmm(anchor.unsqueeze(0), 
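+                       # all pairwise L2 distances at once, via the identity
+                       # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b;
+                       # eps keeps the sqrt away from negative round-off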
torch.t(positive).unsqueeze(0)).squeeze(0))+eps)
+
+def distance_vectors_pairwise(anchor, positive, negative):
+    """Given batches of anchor, positive and negative descriptors, compute the three pairwise distances within each triplet"""
+
+    a_sq = torch.sum(anchor * anchor, dim=1)
+    p_sq = torch.sum(positive * positive, dim=1)
+    n_sq = torch.sum(negative * negative, dim=1)
+
+    eps = 1e-8
+    d_a_p = torch.sqrt(a_sq + p_sq - 2*torch.sum(anchor * positive, dim = 1) + eps)
+    d_a_n = torch.sqrt(a_sq + n_sq - 2*torch.sum(anchor * negative, dim = 1) + eps)
+    d_p_n = torch.sqrt(p_sq + n_sq - 2*torch.sum(positive * negative, dim = 1) + eps)
+    return d_a_p, d_a_n, d_p_n
+
+def loss_random_sampling(anchor, positive, negative, anchor_swap = False, margin = 1.0, loss_type = "triplet_margin"):
+    """Triplet loss with randomly sampled negatives (no in-batch hard negative mining).
+    """
+
+    assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal."
+    assert anchor.size() == negative.size(), "Input sizes between positive and negative must be equal."
+    assert anchor.dim() == 2, "Input must be a 2D matrix."
+    eps = 1e-8
+    (pos, d_a_n, d_p_n) = distance_vectors_pairwise(anchor, positive, negative)
+    if anchor_swap:
+        min_neg = torch.min(d_a_n, d_p_n)
+    else:
+        min_neg = d_a_n
+
+    if loss_type == "triplet_margin":
+        loss = torch.clamp(margin + pos - min_neg, min=0.0)
+    elif loss_type == 'softmax':
+        exp_pos = torch.exp(2.0 - pos);
+        exp_den = exp_pos + torch.exp(2.0 - min_neg) + eps;
+        loss = - torch.log( exp_pos / exp_den )
+    elif loss_type == 'contrastive':
+        loss = torch.clamp(margin - min_neg, min=0.0) + pos;
+    else:
+        print ('Unknown loss type. Try triplet_margin, softmax or contrastive')
+        sys.exit(1)
+    loss = torch.mean(loss)
+    return loss
+
+def loss_L2Net(anchor, positive, anchor_swap = False, margin = 1.0, loss_type = "triplet_margin"):
+    """L2Net loss: uses the whole batch as negatives, not only the hardest one.
+    """
+
+    assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal."
+    assert anchor.dim() == 2, "Input must be a 2D matrix."
+    eps = 1e-8
+    dist_matrix = distance_matrix_vector(anchor, positive)
+    eye = torch.autograd.Variable(torch.eye(dist_matrix.size(1))).cuda()
+
+    # steps to filter out same patches that occur in distance matrix as negatives
+    pos1 = torch.diag(dist_matrix)
+    dist_without_min_on_diag = dist_matrix+eye*10
+    mask = (dist_without_min_on_diag.ge(0.008)-1)*-1
+    mask = mask.type_as(dist_without_min_on_diag)*10
+    dist_without_min_on_diag = dist_without_min_on_diag+mask
+
+    if loss_type == 'softmax':
+        exp_pos = torch.exp(2.0 - pos1);
+        exp_den = torch.sum(torch.exp(2.0 - dist_matrix),1) + eps;
+        loss = -torch.log( exp_pos / exp_den )
+        if anchor_swap:
+            exp_den1 = torch.sum(torch.exp(2.0 - dist_matrix),0) + eps;
+            loss += -torch.log( exp_pos / exp_den1 )
+    else:
+        print ('Only softmax loss works with L2Net sampling')
+        sys.exit(1)
+    loss = torch.mean(loss)
+    return loss
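+# A minimal usage sketch for the losses in this file (illustrative only; it
+# assumes 128-D descriptors and a CUDA build, since loss_L2Net and loss_HardNet
+# call .cuda() internally):
+#
+#   anchor   = torch.autograd.Variable(torch.randn(128, 128).cuda())
+#   positive = torch.autograd.Variable(torch.randn(128, 128).cuda())
+#   loss = loss_HardNet(anchor, positive, anchor_swap = True,
+#                       margin = 1.0, batch_reduce = 'min',
+#                       loss_type = "triplet_margin")
+#
+# With batch_reduce = 'min' every anchor is paired with its hardest in-batch
+# negative from the distance matrix; anchor_swap additionally mines negatives
+# for the positive and keeps the harder of the two.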
+def loss_HardNet(anchor, positive, anchor_swap = False, anchor_ave = False,\
+                 margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin"):
+    """HardNet margin loss: penalizes the positive distance against the closest in-batch negative distance taken from the distance matrix.
+    """
+
+    assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal."
+    assert anchor.dim() == 2, "Input must be a 2D matrix."
+    eps = 1e-8
+    dist_matrix = distance_matrix_vector(anchor, positive) + eps
+    eye = torch.autograd.Variable(torch.eye(dist_matrix.size(1))).cuda()
+
+    # steps to filter out same patches that occur in distance matrix as negatives
+    pos1 = torch.diag(dist_matrix)
+    dist_without_min_on_diag = dist_matrix+eye*10
+    mask = (dist_without_min_on_diag.ge(0.008).float()-1)*-1
+    mask = mask.type_as(dist_without_min_on_diag)*10
+    dist_without_min_on_diag = dist_without_min_on_diag+mask
+    if batch_reduce == 'min':
+        min_neg = torch.min(dist_without_min_on_diag,1)[0]
+        if anchor_swap:
+            min_neg2 = torch.min(dist_without_min_on_diag,0)[0]
+            min_neg = torch.min(min_neg,min_neg2)
+        if False:  # debug branch: additionally mine negatives among the anchors and positives themselves
+            dist_matrix_a = distance_matrix_vector(anchor, anchor)+ eps
+            dist_matrix_p = distance_matrix_vector(positive,positive)+eps
+            dist_without_min_on_diag_a = dist_matrix_a+eye*10
+            dist_without_min_on_diag_p = dist_matrix_p+eye*10
+            min_neg_a = torch.min(dist_without_min_on_diag_a,1)[0]
+            min_neg_p = torch.t(torch.min(dist_without_min_on_diag_p,0)[0])
+            min_neg_3 = torch.min(min_neg_p,min_neg_a)
+            min_neg = torch.min(min_neg,min_neg_3)
+            print (min_neg_a)
+            print (min_neg_p)
+            print (min_neg_3)
+            print (min_neg)
+        min_neg = min_neg
+        pos = pos1
+    elif batch_reduce == 'average':
+        pos = pos1.repeat(anchor.size(0)).view(-1,1).squeeze(0)
+        min_neg = dist_without_min_on_diag.view(-1,1)
+        if anchor_swap:
+            min_neg2 = torch.t(dist_without_min_on_diag).contiguous().view(-1,1)
+            min_neg = torch.min(min_neg,min_neg2)
+        min_neg = min_neg.squeeze(0)
+    elif batch_reduce == 'random':
+        idxs = torch.autograd.Variable(torch.randperm(anchor.size()[0]).long()).cuda()
+        min_neg = dist_without_min_on_diag.gather(1,idxs.view(-1,1))
+        if anchor_swap:
+            min_neg2 = torch.t(dist_without_min_on_diag).gather(1,idxs.view(-1,1))
+            min_neg = torch.min(min_neg,min_neg2)
+        min_neg = torch.t(min_neg).squeeze(0)
+        pos = pos1
+    else:
+        print ('Unknown batch reduce mode. Try min, average or random')
+        sys.exit(1)
+    if loss_type == "triplet_margin":
+        loss = torch.clamp(margin + pos - min_neg, min=0.0)
+    elif loss_type == 'softmax':
+        exp_pos = torch.exp(2.0 - pos);
+        exp_den = exp_pos + torch.exp(2.0 - min_neg) + eps;
+        loss = - torch.log( exp_pos / exp_den )
+    elif loss_type == 'contrastive':
+        loss = torch.clamp(margin - min_neg, min=0.0) + pos;
+    else:
+        print ('Unknown loss type. Try triplet_margin, softmax or contrastive')
+        sys.exit(1)
+    loss = torch.mean(loss)
+    return loss
+
+
+def global_orthogonal_regularization(anchor, negative):
+
+    neg_dis = torch.sum(torch.mul(anchor,negative),1)
+    dim = anchor.size(1)
+    gor = torch.pow(torch.mean(neg_dis),2) + torch.clamp(torch.mean(torch.pow(neg_dis,2))-1.0/dim, min=0.0)
+
+    return gor
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..b14f1c3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,58 @@
+# AffNet model implementation
+CNN-based affine shape estimator.
+
+AffNet model implementation in PyTorch for the technical report "Learning Discriminative Affine Regions via Discriminability".
+
+
+## Datasets and Training
+
+To download the datasets and start training AffNet:
+
+```bash
+git clone https://github.com/ducha-aiki/affnet
+./run_me.sh
+```
+
+## Pre-trained models
+
+Pre-trained models can be found in the folder pretrained: AffNet.pth
+
+## Usage example
+
+We provide two examples of how to estimate affine shape with AffNet.
+The first operates on a patch-column file in [HPatches](https://github.com/hpatches/hpatches-benchmark) format, i.e. a grayscale image with w = patchSize and h = nPatches * patchSize:
+
+```
+cd examples/just_shape
+python detect_affine_shape.py imgs/face.png out.txt
+```
+
+The output file format is an upright affine frame: a11 0 a21 a22
+
+
+The second runs AffNet inside a PyTorch implementation of Hessian-Affine;
+2000 is the number of regions to detect.
+
+```
+cd examples/hesaffnet
+python hesaffnet.py img/cat.png ells-affnet.txt 2000
+python hesaffBaum.py img/cat.png ells-Baumberg.txt 2000
+```
+
+The output ells-affnet.txt is in [Oxford affine](http://www.robots.ox.ac.uk/~vgg/research/affine/) format:
+1.0
+128
+x y a b c
+
+
+## Citation
+
+Please cite us if you use this code:
+
+```
+@article{AffNet2017,
+    author = {Dmytro Mishkin and Filip Radenovic and Jiri Matas},
+    title = "{Learning Discriminative Affine Regions via Discriminability}",
+    year = 2017,
+    month = nov}
+```
diff --git a/Utils.py b/Utils.py
new file mode 100644
index 0000000..c56b5e8
--- /dev/null
+++ b/Utils.py
@@ -0,0 +1,182 @@
+import torch
+import torch.nn.init
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import cv2
+import numpy as np
+
+# resize image to size 32x32
+cv2_scale = lambda x: cv2.resize(x, dsize=(32, 32),
+                                 interpolation=cv2.INTER_LINEAR)
+# reshape image
+np_reshape32 = lambda x: np.reshape(x, (32, 32, 1))
+np_reshape64 = lambda x: np.reshape(x, (64, 64, 1))
+
+def zeros_like(x):
+    assert x.__class__.__name__.find('Variable') != -1 or x.__class__.__name__.find('Tensor') != -1, "Object is neither a Tensor nor a Variable"
+    y = torch.zeros(x.size())
+    if x.is_cuda:
+        y = y.cuda()
+    if x.__class__.__name__ == 'Variable':
+        return torch.autograd.Variable(y, requires_grad=x.requires_grad)
+    elif x.__class__.__name__.find('Tensor') != -1:
+        return y
+
+def ones_like(x):
+    assert x.__class__.__name__.find('Variable') != -1 or x.__class__.__name__.find('Tensor') != -1, "Object is neither a Tensor nor a Variable"
+    y = torch.ones(x.size())
+    if x.is_cuda:
+        y = y.cuda()
+    if x.__class__.__name__ == 'Variable':
+        return torch.autograd.Variable(y, requires_grad=x.requires_grad)
+    elif x.__class__.__name__.find('Tensor') != -1:
+        return y
+
+
+def batched_forward(model, data, batch_size, **kwargs):
+    n_patches = len(data)
+    if n_patches > batch_size:
+        bs = batch_size
+        n_batches = n_patches // bs + 1
+        for batch_idx in range(n_batches):
+            st = batch_idx * bs
+            if batch_idx == n_batches - 1:
+                if (batch_idx + 1) * bs > n_patches:
+                    end = n_patches
+                else:
+                    end = (batch_idx + 1) * bs
+            else:
+                end = (batch_idx + 1) * bs
+            if st >= end:
+                continue
+            if batch_idx == 0:
+                first_batch_out = model(data[st:end], **kwargs)
+                out_size = torch.Size([n_patches] + list(first_batch_out.size()[1:]))
+                #out_size[0] = n_patches
+                out = torch.zeros(out_size);
+                if data.is_cuda:
+                    out = out.cuda()
+                out = Variable(out)
+                out[st:end] = first_batch_out
+            else:
+                out[st:end,:,:] = model(data[st:end], **kwargs)
+        return out
+    else:
+        return model(data, **kwargs)
+
+class L2Norm(nn.Module):
+    def __init__(self):
+        super(L2Norm,self).__init__()
+        self.eps = 1e-10
+    def forward(self, x):
+        norm = torch.sqrt(torch.sum(x * x, dim = 1) + self.eps)
+        x = x / norm.unsqueeze(-1).expand_as(x)
+        return x
+
+class L1Norm(nn.Module):
+    def __init__(self):
+        super(L1Norm,self).__init__()
+        self.eps = 1e-10
+    def forward(self, x):
+        norm = torch.sum(torch.abs(x), dim = 1) + self.eps
+        x = x / norm.expand_as(x)
+        return x
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no',
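+                       # typical (illustrative) use is as an argparse type, e.g.
+                       #   parser.add_argument('--anchorswap', type=str2bool, default=True)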
'false', 'f', 'n', '0'): + return False + +def CircularGaussKernel(kernlen=None, circ_zeros = False, sigma = None, norm = True): + assert ((kernlen is not None) or sigma is not None) + if kernlen is None: + kernlen = int(2.0 * 3.0 * sigma + 1.0) + if (kernlen % 2 == 0): + kernlen = kernlen + 1; + halfSize = kernlen / 2; + halfSize = kernlen / 2; + r2 = float(halfSize*halfSize) + if sigma is None: + sigma2 = 0.9 * r2; + sigma = np.sqrt(sigma2) + else: + sigma2 = 2.0 * sigma * sigma + x = np.linspace(-halfSize,halfSize,kernlen) + xv, yv = np.meshgrid(x, x, sparse=False, indexing='xy') + distsq = (xv)**2 + (yv)**2 + kernel = np.exp(-( distsq/ (sigma2))) + if circ_zeros: + kernel *= (distsq <= r2).astype(np.float32) + if norm: + kernel /= np.sum(kernel) + return kernel + +def generate_2dgrid(h,w, centered = True): + if centered: + x = torch.linspace(-w/2+1, w/2, w) + y = torch.linspace(-h/2+1, h/2, h) + else: + x = torch.linspace(0, w-1, w) + y = torch.linspace(0, h-1, h) + grid2d = torch.stack([y.repeat(w,1).t().contiguous().view(-1), x.repeat(h)],1) + return grid2d + +def generate_3dgrid(d, h, w, centered = True): + if type(d) is not list: + if centered: + z = torch.linspace(-d/2+1, d/2, d) + else: + z = torch.linspace(0, d-1, d) + dl = d + else: + z = torch.FloatTensor(d) + dl = len(d) + grid2d = generate_2dgrid(h,w, centered = centered) + grid3d = torch.cat([z.repeat(w*h,1).t().contiguous().view(-1,1), grid2d.repeat(dl,1)],dim = 1) + return grid3d + +def zero_response_at_border(x, b): + if (b < x.size(3)) and (b < x.size(2)): + x[:, :, 0:b, :] = 0 + x[:, :, x.size(2) - b: , :] = 0 + x[:, :, :, 0:b] = 0 + x[:, :, :, x.size(3) - b: ] = 0 + else: + return x * 0 + return x + +class GaussianBlur(nn.Module): + def __init__(self, sigma=1.6): + super(GaussianBlur, self).__init__() + weight = self.calculate_weights(sigma) + self.register_buffer('buf', weight) + return + def calculate_weights(self, sigma): + kernel = CircularGaussKernel(sigma = sigma, circ_zeros = False) + h,w = kernel.shape + halfSize = float(h) / 2.; + self.pad = int(np.floor(halfSize)) + return torch.from_numpy(kernel.astype(np.float32)).view(1,1,h,w); + def forward(self, x): + w = Variable(self.buf) + if x.is_cuda: + w = w.cuda() + return F.conv2d(F.pad(x, (self.pad,self.pad,self.pad,self.pad), 'replicate'), w, padding = 0) + +def batch_eig2x2(A): + trace = A[:,0,0] + A[:,1,1] + delta1 = (trace*trace - 4 * ( A[:,0,0]* A[:,1,1] - A[:,1,0]* A[:,0,1])) + mask = delta1 > 0 + delta = torch.sqrt(torch.abs(delta1)) + l1 = mask.float() * (trace + delta) / 2.0 + 1000. 
* (1.0 - mask.float()) + l2 = mask.float() * (trace - delta) / 2.0 + 0.0001 * (1.0 - mask.float()) + return l1,l2 + +def line_prepender(filename, line): + with open(filename, 'r+') as f: + content = f.read() + f.seek(0, 0) + f.write(line.rstrip('\r\n') + '\n' + content) + return diff --git a/architectures.py b/architectures.py new file mode 100644 index 0000000..ade3aca --- /dev/null +++ b/architectures.py @@ -0,0 +1,190 @@ +from __future__ import division, print_function +import os +import errno +import numpy as np +import sys +from copy import deepcopy +import math +import torch +import torch.nn.init +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision.transforms as transforms +from torch.autograd import Variable +from Utils import L2Norm, generate_2dgrid +from Utils import str2bool +from LAF import denormalizeLAFs, LAFs2ell, abc2A, extract_patches,normalizeLAFs, get_rotation_matrix +from LAF import get_LAFs_scales, get_normalized_affine_shape +from LAF import rectifyAffineTransformationUpIsUp + +class OriNetFast(nn.Module): + def __init__(self, PS = 16): + super(OriNetFast, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Dropout(0.25), + nn.Conv2d(64, 2, kernel_size=int(PS/4), stride=1,padding=1, bias = True), + nn.Tanh(), + nn.AdaptiveAvgPool2d(1) + ) + self.PS = PS + self.features.apply(self.weights_init) + self.halfPS = int(PS/4) + return + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def weights_init(self,m): + if isinstance(m, nn.Conv2d): + nn.init.orthogonal(m.weight.data, gain=0.9) + try: + nn.init.constant(m.bias.data, 0.01) + except: + pass + return + def forward(self, input, return_rot_matrix = True): + xy = self.features(self.input_norm(input)).view(-1,2) + angle = torch.atan2(xy[:,0] + 1e-8, xy[:,1]+1e-8); + if return_rot_matrix: + return get_rotation_matrix(angle) + return angle + +class GHH(nn.Module): + def __init__(self, n_in, n_out, s = 4, m = 4): + super(GHH, self).__init__() + self.n_out = n_out + self.s = s + self.m = m + self.conv = nn.Linear(n_in, n_out * s * m) + d = torch.arange(0, s) + self.deltas = -1.0 * (d % 2 != 0).float() + 1.0 * (d % 2 == 0).float() + self.deltas = Variable(self.deltas) + return + def forward(self,x): + x_feats = self.conv(x.view(x.size(0),-1)).view(x.size(0), self.n_out, self.s, self.m); + max_feats = x_feats.max(dim = 3)[0]; + if x.is_cuda: + self.deltas = self.deltas.cuda() + else: + self.deltas = self.deltas.cpu() + out = (max_feats * self.deltas.view(1,1,-1).expand_as(max_feats)).sum(dim = 2) + return out + +class YiNet(nn.Module): + def __init__(self, PS = 28): + 
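+        # orientation estimator in the spirit of Yi et al., "Learning to Assign
+        # Orientations to Feature Points" (CVPR 2016): conv layers followed by two
+        # GHH blocks regress a 2-vector whose atan2 is the patch orientation;
+        # import_weights() below loads the original network's .npy weights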
super(YiNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 10, kernel_size=5, padding=0, bias = True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding = 1), + nn.Conv2d(10, 20, kernel_size=5, stride=1, padding=0, bias = True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=4, stride=2, padding = 2), + nn.Conv2d(20, 50, kernel_size=3, stride=1, padding=0, bias = True), + nn.ReLU(), + nn.AdaptiveMaxPool2d(1), + GHH(50, 100), + GHH(100, 2) + ) + self.input_mean = 0.427117081207483 + self.input_std = 0.21888339179665006; + self.PS = PS + return + def import_weights(self, dir_name): + self.features[0].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer0_W.npy'))).float() + self.features[0].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer0_b.npy'))).float().view(-1) + self.features[3].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer1_W.npy'))).float() + self.features[3].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer1_b.npy'))).float().view(-1) + self.features[6].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer2_W.npy'))).float() + self.features[6].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer2_b.npy'))).float().view(-1) + self.features[9].conv.weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer3_W.npy'))).float().view(50, 1600).contiguous().t().contiguous()#.view(1600, 50, 1, 1).contiguous() + self.features[9].conv.bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer3_b.npy'))).float().view(1600) + self.features[10].conv.weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer4_W.npy'))).float().view(100, 32).contiguous().t().contiguous()#.view(32, 100, 1, 1).contiguous() + self.features[10].conv.bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer4_b.npy'))).float().view(32) + self.input_mean = float(np.load(os.path.join(dir_name, 'input_mean.npy'))) + self.input_std = float(np.load(os.path.join(dir_name, 'input_std.npy'))) + return + def input_norm1(self,x): + return (x - self.input_mean) / self.input_std + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def forward(self, input, return_rot_matrix = False): + xy = self.features(self.input_norm(input)) + angle = torch.atan2(xy[:,0] + 1e-8, xy[:,1]+1e-8); + if return_rot_matrix: + return get_rotation_matrix(-angle) + return angle + +class AffNetFast(nn.Module): + def __init__(self, PS = 32): + super(AffNetFast, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Dropout(0.25), + nn.Conv2d(64, 3, kernel_size=8, stride=1, padding=0, bias = True), + 
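+            # the Tanh-bounded 3-channel head regresses the residual affine shape;
+            # forward() maps (x0, x1, x2) to [[1 + x0, 0], [x1, 1 + x2]] and then
+            # rectifies it so that "up is up"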
nn.Tanh(), + nn.AdaptiveAvgPool2d(1) + ) + self.PS = PS + self.features.apply(self.weights_init) + self.halfPS = int(PS/2) + return + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def weights_init(self,m): + if isinstance(m, nn.Conv2d): + nn.init.orthogonal(m.weight.data, gain=0.8) + try: + nn.init.constant(m.bias.data, 0.01) + except: + pass + return + def forward(self, input, return_A_matrix = False): + xy = self.features(self.input_norm(input)).view(-1,3) + a1 = torch.cat([1.0 + xy[:,0].contiguous().view(-1,1,1), 0 * xy[:,0].contiguous().view(-1,1,1)], dim = 2).contiguous() + a2 = torch.cat([xy[:,1].contiguous().view(-1,1,1), 1.0 + xy[:,2].contiguous().view(-1,1,1)], dim = 2).contiguous() + return rectifyAffineTransformationUpIsUp(torch.cat([a1,a2], dim = 1).contiguous()) + diff --git a/augmentation.py b/augmentation.py new file mode 100644 index 0000000..eeb09bf --- /dev/null +++ b/augmentation.py @@ -0,0 +1,58 @@ +import numpy as np +from PIL import Image + +import sys +from copy import deepcopy +import argparse +import math +import torch.utils.data as data +import torch +import torch.nn.init +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +import torchvision.transforms as transforms +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +from LAF import get_rotation_matrix,get_normalized_affine_shape + +def get_random_rotation_LAFs(patches, angle_mag = math.pi): + rot_LAFs = Variable(torch.FloatTensor([[0.5, 0, 0.5],[0, 0.5, 0.5]]).unsqueeze(0).repeat(patches.size(0),1,1)); + phi = (Variable(2.0 * torch.rand(patches.size(0)) - 1.0) ).view(-1,1,1) + if patches.is_cuda: + rot_LAFs = rot_LAFs.cuda() + phi = phi.cuda() + rotmat = get_rotation_matrix(angle_mag * phi) + inv_rotmat = get_rotation_matrix(-angle_mag * phi) + rot_LAFs[:,0:2,0:2] = torch.bmm(rotmat, rot_LAFs[:,0:2,0:2]); + return rot_LAFs, inv_rotmat + +def get_random_shifts_LAFs(patches, w_mag, h_mag = 3): + shift_w = (torch.IntTensor(patches.size(0)).random_(2*w_mag) - w_mag / 2).float() / 2.0 + shift_h = (torch.IntTensor(patches.size(0)).random_(2*w_mag) - w_mag / 2).float() / 2.0 + if patches.is_cuda: + shift_h = shift_h.cuda() + shift_w = shift_w.cuda() + shift_h = Variable(shift_h) + shift_w = Variable(shift_w) + return shift_w, shift_h + +def get_random_norm_affine_LAFs(patches, max_tilt = 1.0): + assert max_tilt > 0 + aff_LAFs = Variable(torch.FloatTensor([[0.5, 0, 0.5],[0, 0.5, 0.5]]).unsqueeze(0).repeat(patches.size(0),1,1)); + tilt = Variable( 1/max_tilt + (max_tilt - 1./max_tilt)* torch.rand(patches.size(0), 1, 1)); + phi = math.pi * (Variable(2.0 * torch.rand(patches.size(0)) - 1.0) ).view(-1,1,1) + if patches.is_cuda: + tilt = tilt.cuda() + phi = phi.cuda() + aff_LAFs = aff_LAFs.cuda() + TA = get_normalized_affine_shape(tilt, phi) + inv_TA = Variable(torch.zeros(patches.size(0),2,2)); + if patches.is_cuda: + inv_TA = inv_TA.cuda() + for i in range(len(inv_TA)): + inv_TA[i,:,:] = TA[i,:,:].inverse(); + aff_LAFs[:,0:2,0:2] = torch.bmm(TA, aff_LAFs[:,0:2,0:2]); + return aff_LAFs, inv_TA; + diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..fdaf812 --- /dev/null +++ b/dataset.py @@ -0,0 +1,417 @@ +# Training settings +import os +import errno +import numpy as np +from PIL import Image +import torchvision.datasets as dset + +import 
sys
+from copy import deepcopy
+import argparse
+import math
+import torch.utils.data as data
+import torch
+import torch.nn.init
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+
+import torchvision.transforms as transforms
+from torch.autograd import Variable
+import torch.backends.cudnn as cudnn
+from tqdm import tqdm
+import random
+import cv2
+import copy
+from Utils import str2bool
+
+def find_files(_data_dir, _image_ext):
+    """Return a list with the file names of the images containing the patches
+    """
+    files = []
+    # find those files with the specified extension
+    for file_dir in os.listdir(_data_dir):
+        if file_dir.endswith(_image_ext):
+            files.append(os.path.join(_data_dir, file_dir))
+    return sorted(files)  # sort files in ascending order to keep relations
+
+def np2torch(npr):
+    if len(npr.shape) == 4:
+        return torch.from_numpy(np.rollaxis(npr, 3, 1))
+    elif len(npr.shape) == 3:
+        return torch.from_numpy(np.rollaxis(npr, 2, 0))
+    else:
+        return torch.from_numpy(npr)
+
+def read_patch_file(fname, patch_w = 65, patch_h = 65, start_patch_idx = 0):
+    img = Image.open(fname).convert('RGB')
+    width, height = img.size
+    #print (img.size, patch_w, patch_h)
+    assert ((height % patch_h == 0) and (width % patch_w == 0))
+    patch_idxs = []
+    patches = []
+    current_patch_idx = start_patch_idx
+    for y in range(0, height, patch_h):
+        patch_idxs.append([])
+        curr_patches = []
+        for x in range(0, width, patch_w):
+            patch = np.array(img.crop((x, y, x + patch_w, y + patch_h))).mean(axis = 2, keepdims = True)
+            #print(patch.astype(np.float32).std(), patch.mean())
+            if (patch.mean() != 0) and (patch.astype(np.float32).std() > 1e-2):
+                curr_patches.append(patch.astype(np.uint8))
+                patch_idxs[-1].append(current_patch_idx)
+                current_patch_idx+=1
+        if len(curr_patches) > 1:
+            patches = patches + curr_patches
+        else:
+            for i in range(len(curr_patches)):
+                current_patch_idx -=1
+            patch_idxs = patch_idxs[:-1]
+    return np2torch(np.array(patches)), patch_idxs, patch_idxs[-1][-1]
+
+def read_image_dir(dir_name, ext, patch_w, patch_h, good_fnames):
+    fnames = find_files(dir_name, ext)
+    patches = []
+    idxs = []
+    current_max_idx = 0
+    for f in fnames:
+        if f.split('/')[-1].replace('.png', '') not in good_fnames:
+            continue
+        try:
+            torch_patches, p_idxs_list, max_idx = read_patch_file(f, patch_w, patch_h, current_max_idx)
+        except Exception:
+            continue
+        current_max_idx = max_idx + 1
+        #if patches is None:
+        #    patches = torch_patches
+        #    idxs = p_idxs_list
+        #else:
+        patches.append(torch_patches)
+        idxs = idxs + p_idxs_list
+        print (f, len(idxs))
+    print( 'torch.cat')
+    patches = torch.cat(patches, dim = 0)
+    print ('done')
+    return patches, idxs
+
+
+class HPatchesDM(data.Dataset):
+    image_ext = 'png'
+    def __init__(self, root, name, train=True, transform=None,
+                 download=True, pw = 65, ph = 65,
+                 n_pairs = 1000, batch_size = 128, split_name = 'b'):
+        self.root = os.path.expanduser(root)
+        self.name = name
+        self.n_pairs = n_pairs
+        self.split_name = split_name
+        self.batch_size = batch_size
+        self.train = train
+        self.data_dir = os.path.join(self.root, name)
+        if self.train:
+            self.data_file = os.path.join(self.root, '{}.pt'.format(self.name + '_train' ))
+        else:
+            self.data_file = os.path.join(self.root, '{}.pt'.format(self.name + '_test' ))
+        self.transform = transform
+        self.patch_h = ph
+        self.patch_w = pw
+        self.batch_size = batch_size
+        if download:
+            self.download()
+
+        if not self._check_datafile_exists():
+            raise RuntimeError('Dataset not found.'
+ + ' You can use download=True to download it') + + # load the serialized data + self.patches, self.idxs = torch.load(self.data_file) + print('Generating {} triplets'.format(self.n_pairs)) + self.pairs = self.generate_pairs(self.idxs, self.n_pairs) + return + def generate_pairs(self, labels, n_pairs): + pairs = [] + n_classes = len(labels) + # add only unique indices in batch + already_idxs = set() + for x in tqdm(range(n_pairs)): + if len(already_idxs) >= self.batch_size: + already_idxs = set() + c1 = np.random.randint(0, n_classes) + while c1 in already_idxs: + c1 = np.random.randint(0, n_classes) + while len(labels[c1]) < 3: + c1 = np.random.randint(0, n_classes) + already_idxs.add(c1) + if len(labels[c1]) == 2: # hack to speed up process + n1, n2 = 0, 1 + else: + n1 = np.random.randint(0, len(labels[c1])) + while (self.patches[labels[c1][n1],:,:,:].float().std() < 1e-2): + n1 = np.random.randint(0, len(labels[c1])) + n2 = np.random.randint(0, len(labels[c1])) + while (self.patches[labels[c1][n2],:,:,:].float().std() < 1e-2): + n2 = np.random.randint(0, len(labels[c1])) + pairs.append([labels[c1][n1], labels[c1][n2]]) + return torch.LongTensor(np.array(pairs)) + def __getitem__(self, index): + def transform_pair(i1,i2): + if self.transform is not None: + return self.transform(i1.cpu().numpy()), self.transform(i2.cpu().numpy()) + else: + return i1,i2 + t = self.pairs[index] + a, p = self.patches[t[0],:,:,:], self.patches[t[1],:,:,:] + a1,p1 = transform_pair(a,p) + return (a1,p1) + + def __len__(self): + return len(self.pairs) + + def _check_datafile_exists(self): + return os.path.exists(self.data_file) + + def _check_downloaded(self): + return os.path.exists(self.data_dir) + + def download(self): + if self._check_datafile_exists(): + print('# Found cached data {}'.format(self.data_file)) + return + # process and save as torch files + print('# Caching data {}'.format(self.data_file)) + import json + from pprint import pprint + #print self.urls['splits'] + with open(os.path.join(self.root, 'splits.json')) as splits_file: + data = json.load(splits_file) + if self.train: + self.img_fnames = data[self.split_name]['train'] + else: + self.img_fnames = data[self.split_name]['test'] + dataset = read_image_dir(self.data_dir, self.image_ext, self.patch_w, self.patch_h, self.img_fnames) + print('saving...') + with open(self.data_file, 'wb') as f: + torch.save(dataset, f) + return +class TotalDatasetsLoader(data.Dataset): + + def __init__(self, datasets_path, train = True, transform = None, batch_size = None, n_triplets = 5000000, fliprot = False, *arg, **kw): + super(TotalDatasetsLoader, self).__init__() + + datasets_path = [os.path.join(datasets_path, dataset) for dataset in os.listdir(datasets_path)] + + datasets = [torch.load(dataset) for dataset in datasets_path] + + data, labels = datasets[0][0], datasets[0][1] + + for i in range(1,len(datasets)): + data = torch.cat([data,datasets[i][0]]) + labels = torch.cat([labels, datasets[i][1]+torch.max(labels)+1]) + + del datasets + + self.data, self.labels = data, labels + self.transform = transform + self.train = train + self.n_triplets = n_triplets + self.batch_size = batch_size + self.fliprot = fliprot + if self.train: + print('Generating {} triplets'.format(self.n_triplets)) + self.triplets = self.generate_triplets(self.labels, self.n_triplets, self.batch_size) + + + def generate_triplets(self, labels, num_triplets, batch_size): + def create_indices(_labels): + inds = dict() + for idx, ind in enumerate(_labels): + if ind not in inds: + inds[ind] = 
[] + inds[ind].append(idx) + return inds + + triplets = [] + indices = create_indices(labels) + unique_labels = np.unique(labels.numpy()) + n_classes = unique_labels.shape[0] + # add only unique indices in batch + already_idxs = set() + + for x in tqdm(range(num_triplets)): + if len(already_idxs) >= batch_size: + already_idxs = set() + c1 = np.random.randint(0, n_classes) + while c1 in already_idxs: + c1 = np.random.randint(0, n_classes) + already_idxs.add(c1) + c2 = np.random.randint(0, n_classes) + while c1 == c2: + c2 = np.random.randint(0, n_classes) + if len(indices[c1]) == 2: # hack to speed up process + n1, n2 = 0, 1 + else: + n1 = np.random.randint(0, len(indices[c1])) + n2 = np.random.randint(0, len(indices[c1])) + while n1 == n2: + n2 = np.random.randint(0, len(indices[c1])) + n3 = np.random.randint(0, len(indices[c2])) + triplets.append([indices[c1][n1], indices[c1][n2], indices[c2][n3]]) + return torch.LongTensor(np.array(triplets)) + + def __getitem__(self, index): + def transform_img(img): + if self.transform is not None: + img = (img.numpy())/255.0 + img = self.transform(img) + return img + + t = self.triplets[index] + a, p, n = self.data[t[0]], self.data[t[1]], self.data[t[2]] + + img_a = transform_img(a) + img_p = transform_img(p) + + # transform images if required + if self.fliprot: + do_flip = random.random() > 0.5 + do_rot = random.random() > 0.5 + + if do_rot: + img_a = img_a.permute(0,2,1) + img_p = img_p.permute(0,2,1) + + if do_flip: + img_a = torch.from_numpy(deepcopy(img_a.numpy()[:,:,::-1])) + img_p = torch.from_numpy(deepcopy(img_p.numpy()[:,:,::-1])) + return img_a, img_p + + def __len__(self): + if self.train: + return self.triplets.size(0) + +class TripletPhotoTour(dset.PhotoTour): + """From the PhotoTour Dataset it generates triplet samples + note: a triplet is composed by a pair of matching images and one of + different class. 
+ """ + urls = { + 'notredame_harris': [ + 'http://matthewalunbrown.com/patchdata/notredame_harris.zip', + 'notredame_harris.zip', + '69f8c90f78e171349abdf0307afefe4d' + ], + 'yosemite_harris': [ + 'http://matthewalunbrown.com/patchdata/yosemite_harris.zip', + 'yosemite_harris.zip', + 'a73253d1c6fbd3ba2613c45065c00d46' + ], + 'liberty_harris': [ + 'http://matthewalunbrown.com/patchdata/liberty_harris.zip', + 'liberty_harris.zip', + 'c731fcfb3abb4091110d0ae8c7ba182c' + ], + 'notredame': [ + 'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip', + 'notredame.zip', + '509eda8535847b8c0a90bbb210c83484' + ], + 'yosemite': [ + 'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip', + 'yosemite.zip', + '533b2e8eb7ede31be40abc317b2fd4f0' + ], + 'liberty': [ + 'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip', + 'liberty.zip', + 'fdd9152f138ea5ef2091746689176414' + ], + } + mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437, 'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437} + std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019, 'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019} + lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092, 'liberty_harris': 379587, 'yosemite_harris': 450912 , 'notredame_harris': 325295} + def __init__(self, train=True, transform=None, batch_size = None, n_triplets = 5000, load_random_triplets = False, *arg, **kw): + super(TripletPhotoTour, self).__init__(*arg, **kw) + self.transform = transform + self.out_triplets = load_random_triplets + self.train = train + self.n_triplets = 1000 + self.batch_size = batch_size + + if self.train: + print('Generating {} triplets'.format(self.n_triplets)) + self.triplets = self.generate_triplets(self.labels, self.n_triplets) + def generate_triplets(self,labels, num_triplets): + def create_indices(_labels): + inds = dict() + for idx, ind in enumerate(_labels): + if ind not in inds: + inds[ind] = [] + inds[ind].append(idx) + return inds + triplets = [] + indices = create_indices(labels) + unique_labels = np.unique(labels.numpy()) + n_classes = unique_labels.shape[0] + # add only unique indices in batch + already_idxs = set() + for x in tqdm(range(num_triplets)): + if len(already_idxs) >= self.batch_size: + already_idxs = set() + c1 = np.random.randint(0, n_classes - 1) + while c1 in already_idxs: + c1 = np.random.randint(0, n_classes - 1) + already_idxs.add(c1) + c2 = np.random.randint(0, n_classes - 1) + while c1 == c2: + c2 = np.random.randint(0, n_classes - 1) + if len(indices[c1]) == 2: # hack to speed up process + n1, n2 = 0, 1 + else: + n1 = np.random.randint(0, len(indices[c1]) - 1) + n2 = np.random.randint(0, len(indices[c1]) - 1) + while n1 == n2: + n2 = np.random.randint(0, len(indices[c1]) - 1) + n3 = np.random.randint(0, len(indices[c2]) - 1) + triplets.append([indices[c1][n1], indices[c1][n2], indices[c2][n3]]) + return torch.LongTensor(np.array(triplets)) + def __getitem__(self, index): + def transform_img(img): + if self.transform is not None: + img = self.transform(img.numpy()) + return img + + if not self.train: + m = self.matches[index] + img1 = transform_img(self.data[m[0]]) + img2 = transform_img(self.data[m[1]]) + return img1, img2, m[2] + + t = self.triplets[index] + a, p, n = self.data[t[0]], self.data[t[1]], self.data[t[2]] + img_a = transform_img(a) + img_p = transform_img(p) + img_n = None + if self.out_triplets: + img_n = transform_img(n) + # transform images if required + if True:#args.fliprot: + do_flip = random.random() > 0.5 + 
do_rot = random.random() > 0.5 + if do_rot: + img_a = img_a.permute(0,2,1) + img_p = img_p.permute(0,2,1) + if self.out_triplets: + img_n = img_n.permute(0,2,1) + if do_flip: + img_a = torch.from_numpy(deepcopy(img_a.numpy()[:,:,::-1])) + img_p = torch.from_numpy(deepcopy(img_p.numpy()[:,:,::-1])) + if self.out_triplets: + img_n = torch.from_numpy(deepcopy(img_n.numpy()[:,:,::-1])) + if self.out_triplets: + return (img_a, img_p, img_n) + else: + return (img_a, img_p) + + def __len__(self): + if self.train: + return self.triplets.size(0) + else: + return self.matches.size(0) + diff --git a/examples/hesaffnet/HandCraftedModules.py b/examples/hesaffnet/HandCraftedModules.py new file mode 100644 index 0000000..cf67b75 --- /dev/null +++ b/examples/hesaffnet/HandCraftedModules.py @@ -0,0 +1,292 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import math +import numpy as np +from Utils import GaussianBlur, CircularGaussKernel +from LAF import abc2A,rectifyAffineTransformationUpIsUp, sc_y_x2LAFs +from Utils import generate_2dgrid, generate_2dgrid, generate_3dgrid +from Utils import zero_response_at_border + + +class ScalePyramid(nn.Module): + def __init__(self, nLevels = 3, init_sigma = 1.6, border = 5): + super(ScalePyramid,self).__init__() + self.nLevels = nLevels; + self.init_sigma = init_sigma + self.sigmaStep = 2 ** (1. / float(self.nLevels)) + #print 'step',self.sigmaStep + self.b = border + self.minSize = 2 * self.b + 2 + 1; + return + def forward(self,x): + pixelDistance = 1.0; + curSigma = 0.5 + if self.init_sigma > curSigma: + sigma = np.sqrt(self.init_sigma**2 - curSigma**2) + curSigma = self.init_sigma + curr = GaussianBlur(sigma = sigma)(x) + else: + curr = x + sigmas = [[curSigma]] + pixel_dists = [[1.0]] + pyr = [[curr]] + j = 0 + while True: + curr = pyr[-1][0] + for i in range(1, self.nLevels + 2): + sigma = curSigma * np.sqrt(self.sigmaStep*self.sigmaStep - 1.0 ) + #print 'blur sigma', sigma + curr = GaussianBlur(sigma = sigma)(curr) + curSigma *= self.sigmaStep + pyr[j].append(curr) + sigmas[j].append(curSigma) + pixel_dists[j].append(pixelDistance) + if i == self.nLevels: + nextOctaveFirstLevel = F.avg_pool2d(curr, kernel_size = 1, stride = 2, padding = 0) + pixelDistance = pixelDistance * 2.0 + curSigma = self.init_sigma + if (nextOctaveFirstLevel[0,0,:,:].size(0) <= self.minSize) or (nextOctaveFirstLevel[0,0,:,:].size(1) <= self.minSize): + break + pyr.append([nextOctaveFirstLevel]) + sigmas.append([curSigma]) + pixel_dists.append([pixelDistance]) + j+=1 + return pyr, sigmas, pixel_dists + +class HessianResp(nn.Module): + def __init__(self): + super(HessianResp, self).__init__() + + self.gx = nn.Conv2d(1, 1, kernel_size=(1,3), bias = False) + self.gx.weight.data = torch.from_numpy(np.array([[[[0.5, 0, -0.5]]]], dtype=np.float32)) + + self.gy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gy.weight.data = torch.from_numpy(np.array([[[[0.5], [0], [-0.5]]]], dtype=np.float32)) + + self.gxx = nn.Conv2d(1, 1, kernel_size=(1,3),bias = False) + self.gxx.weight.data = torch.from_numpy(np.array([[[[1.0, -2.0, 1.0]]]], dtype=np.float32)) + + self.gyy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gyy.weight.data = torch.from_numpy(np.array([[[[1.0], [-2.0], [1.0]]]], dtype=np.float32)) + return + def forward(self, x, scale): + gxx = self.gxx(F.pad(x, (1,1,0, 0), 'replicate')) + gyy = self.gyy(F.pad(x, (0,0, 1,1), 'replicate')) + gxy = self.gy(F.pad(self.gx(F.pad(x, (1,1,0, 0), 'replicate')), (0,0, 
1,1), 'replicate')) + return torch.abs(gxx * gyy - gxy * gxy) * (scale**4) + + +class AffineShapeEstimator(nn.Module): + def __init__(self, threshold = 0.001, patch_size = 19): + super(AffineShapeEstimator, self).__init__() + self.threshold = threshold; + self.PS = patch_size + self.gx = nn.Conv2d(1, 1, kernel_size=(1,3), bias = False) + self.gx.weight.data = torch.from_numpy(np.array([[[[-1, 0, 1]]]], dtype=np.float32)) + self.gy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gy.weight.data = torch.from_numpy(np.array([[[[-1], [0], [1]]]], dtype=np.float32)) + self.gk = torch.from_numpy(CircularGaussKernel(kernlen = self.PS, sigma = (self.PS / 2) /3.0).astype(np.float32)) + self.gk = Variable(self.gk, requires_grad=False) + return + def invSqrt(self,a,b,c): + eps = 1e-12 + mask = (b != 0).float() + r1 = mask * (c - a) / (2. * b + eps) + t1 = torch.sign(r1) / (torch.abs(r1) + torch.sqrt(1. + r1*r1)); + r = 1.0 / torch.sqrt( 1. + t1*t1) + t = t1*r; + r = r * mask + 1.0 * (1.0 - mask); + t = t * mask; + + x = 1. / torch.sqrt( r*r*a - 2.0*r*t*b + t*t*c) + z = 1. / torch.sqrt( t*t*a + 2.0*r*t*b + r*r*c) + + d = torch.sqrt( x * z) + + x = x / d + z = z / d + + l1 = torch.max(x,z) + l2 = torch.min(x,z) + + new_a = r*r*x + t*t*z + new_b = -r*t*x + t*r*z + new_c = t*t*x + r*r *z + + return new_a, new_b, new_c, l1, l2 + def forward(self,x): + if x.is_cuda: + self.gk = self.gk.cuda() + else: + self.gk = self.gk.cpu() + gx = self.gx(F.pad(x, (1, 1, 0, 0), 'replicate')) + gy = self.gy(F.pad(x, (0, 0, 1, 1), 'replicate')) + a1 = (gx * gx * self.gk.unsqueeze(0).unsqueeze(0).expand_as(gx)).view(x.size(0),-1).mean(dim=1) + b1 = (gx * gy * self.gk.unsqueeze(0).unsqueeze(0).expand_as(gx)).view(x.size(0),-1).mean(dim=1) + c1 = (gy * gy * self.gk.unsqueeze(0).unsqueeze(0).expand_as(gx)).view(x.size(0),-1).mean(dim=1) + a, b, c, l1, l2 = self.invSqrt(a1,b1,c1) + rat1 = l1/l2 + mask = (torch.abs(rat1) <= 6.).float().view(-1); + return rectifyAffineTransformationUpIsUp(abc2A(a,b,c))#, mask +class OrientationDetector(nn.Module): + def __init__(self, + mrSize = 3.0, patch_size = None): + super(OrientationDetector, self).__init__() + if patch_size is None: + patch_size = 32; + self.PS = patch_size; + self.bin_weight_kernel_size, self.bin_weight_stride = self.get_bin_weight_kernel_size_and_stride(self.PS, 1) + self.mrSize = mrSize; + self.num_ang_bins = 36 + self.gx = nn.Conv2d(1, 1, kernel_size=(1,3), bias = False) + self.gx.weight.data = torch.from_numpy(np.array([[[[0.5, 0, -0.5]]]], dtype=np.float32)) + + self.gy = nn.Conv2d(1, 1, kernel_size=(3,1), bias = False) + self.gy.weight.data = torch.from_numpy(np.array([[[[0.5], [0], [-0.5]]]], dtype=np.float32)) + + self.angular_smooth = nn.Conv1d(1, 1, kernel_size=3, padding = 1, bias = False) + self.angular_smooth.weight.data = torch.from_numpy(np.array([[[0.33, 0.34, 0.33]]], dtype=np.float32)) + + self.gk = 10. 
* torch.from_numpy(CircularGaussKernel(kernlen=self.PS).astype(np.float32)) + self.gk = Variable(self.gk, requires_grad=False) + return + def get_bin_weight_kernel_size_and_stride(self, patch_size, num_spatial_bins): + bin_weight_stride = int(round(2.0 * np.floor(patch_size / 2) / float(num_spatial_bins + 1))) + bin_weight_kernel_size = int(2 * bin_weight_stride - 1); + return bin_weight_kernel_size, bin_weight_stride + def get_rotation_matrix(self, angle_in_radians): + angle_in_radians = angle_in_radians.view(-1, 1, 1); + sin_a = torch.sin(angle_in_radians) + cos_a = torch.cos(angle_in_radians) + A1_x = torch.cat([cos_a, sin_a], dim = 2) + A2_x = torch.cat([-sin_a, cos_a], dim = 2) + transform = torch.cat([A1_x,A2_x], dim = 1) + return transform + + def forward(self, x, return_rot_matrix = False): + gx = self.gx(F.pad(x, (1,1,0, 0), 'replicate')) + gy = self.gy(F.pad(x, (0,0, 1,1), 'replicate')) + mag = torch.sqrt(gx * gx + gy * gy + 1e-10) + if x.is_cuda: + self.gk = self.gk.cuda() + mag = mag * self.gk.unsqueeze(0).unsqueeze(0).expand_as(mag) + ori = torch.atan2(gy,gx) + o_big = float(self.num_ang_bins) *(ori + 1.0 * math.pi )/ (2.0 * math.pi) + bo0_big = torch.floor(o_big) + wo1_big = o_big - bo0_big + bo0_big = bo0_big % self.num_ang_bins + bo1_big = (bo0_big + 1) % self.num_ang_bins + wo0_big = (1.0 - wo1_big) * mag + wo1_big = wo1_big * mag + ang_bins = [] + for i in range(0, self.num_ang_bins): + ang_bins.append(F.adaptive_avg_pool2d((bo0_big == i).float() * wo0_big, (1,1))) + ang_bins = torch.cat(ang_bins,1).view(-1,1,self.num_ang_bins) + ang_bins = self.angular_smooth(ang_bins) + values, indices = ang_bins.view(-1,self.num_ang_bins).max(1) + angle = -((2. * float(np.pi) * indices.float() / float(self.num_ang_bins)) - float(math.pi)) + if return_rot_matrix: + return self.get_rotation_matrix(angle) + return angle + +class NMS2d(nn.Module): + def __init__(self, kernel_size = 3, threshold = 0): + super(NMS2d, self).__init__() + self.MP = nn.MaxPool2d(kernel_size, stride=1, return_indices=False, padding = kernel_size/2) + self.eps = 1e-5 + self.th = threshold + return + def forward(self, x): + #local_maxima = self.MP(x) + if self.th > self.eps: + return x * (x > self.th).float() * ((x + self.eps - self.MP(x)) > 0).float() + else: + return ((x - self.MP(x) + self.eps) > 0).float() * x + +class NMS3d(nn.Module): + def __init__(self, kernel_size = 3, threshold = 0): + super(NMS3d, self).__init__() + self.MP = nn.MaxPool3d(kernel_size, stride=1, return_indices=False, padding = (0, kernel_size/2, kernel_size/2)) + self.eps = 1e-5 + self.th = threshold + return + def forward(self, x): + #local_maxima = self.MP(x) + if self.th > self.eps: + return x * (x > self.th).float() * ((x + self.eps - self.MP(x)) > 0).float() + else: + return ((x - self.MP(x) + self.eps) > 0).float() * x + +class NMS3dAndComposeA(nn.Module): + def __init__(self, w = 0, h = 0, kernel_size = 3, threshold = 0, scales = None, border = 3, mrSize = 1.0): + super(NMS3dAndComposeA, self).__init__() + self.eps = 1e-7 + self.ks = 3 + self.th = threshold + self.cube_idxs = [] + self.border = border + self.mrSize = mrSize + self.beta = 1.0 + self.grid_ones = Variable(torch.ones(3,3,3,3), requires_grad=False) + self.NMS3d = NMS3d(kernel_size, threshold) + if (w > 0) and (h > 0): + self.spatial_grid = generate_2dgrid(h, w, False).view(1, h, w,2).permute(3,1, 2, 0) + self.spatial_grid = Variable(self.spatial_grid) + else: + self.spatial_grid = None + return + def forward(self, low, cur, high, num_features = 0, octaveMap = None, 
scales = None): + assert low.size() == cur.size() == high.size() + #Filter responce map + self.is_cuda = low.is_cuda; + resp3d = torch.cat([low,cur,high], dim = 1) + + mrSize_border = int(self.mrSize); + if octaveMap is not None: + nmsed_resp = zero_response_at_border(self.NMS3d(resp3d.unsqueeze(1)).squeeze(1)[:,1:2,:,:], mrSize_border) * (1. - octaveMap.float()) + else: + nmsed_resp = zero_response_at_border(self.NMS3d(resp3d.unsqueeze(1)).squeeze(1)[:,1:2,:,:], mrSize_border) + + num_of_nonzero_responces = (nmsed_resp > 0).sum().data[0] + if (num_of_nonzero_responces == 0): + return None,None,None + if octaveMap is not None: + octaveMap = (octaveMap.float() + nmsed_resp.float()).byte() + + nmsed_resp = nmsed_resp.view(-1) + if (num_features > 0) and (num_features < num_of_nonzero_responces): + nmsed_resp, idxs = torch.topk(nmsed_resp, k = num_features); + else: + idxs = nmsed_resp.data.nonzero().squeeze() + nmsed_resp = nmsed_resp[idxs] + #Get point coordinates grid + + if type(scales) is not list: + self.grid = generate_3dgrid(3,self.ks,self.ks) + else: + self.grid = generate_3dgrid(scales,self.ks,self.ks) + self.grid = Variable(self.grid.t().contiguous().view(3,3,3,3), requires_grad=False) + if self.spatial_grid is None: + self.spatial_grid = generate_2dgrid(low.size(2), low.size(3), False).view(1, low.size(2), low.size(3),2).permute(3,1, 2, 0) + self.spatial_grid = Variable(self.spatial_grid) + if self.is_cuda: + self.spatial_grid = self.spatial_grid.cuda() + self.grid_ones = self.grid_ones.cuda() + self.grid = self.grid.cuda() + + #residual_to_patch_center + sc_y_x = F.conv2d(resp3d, self.grid, + padding = 1) / (F.conv2d(resp3d, self.grid_ones, padding = 1) + 1e-8) + + ##maxima coords + sc_y_x[0,1:,:,:] = sc_y_x[0,1:,:,:] + self.spatial_grid[:,:,:,0] + sc_y_x = sc_y_x.view(3,-1).t() + sc_y_x = sc_y_x[idxs,:] + + min_size = float(min((cur.size(2)), cur.size(3))) + sc_y_x[:,0] = sc_y_x[:,0] / min_size + sc_y_x[:,1] = sc_y_x[:,1] / float(cur.size(2)) + sc_y_x[:,2] = sc_y_x[:,2] / float(cur.size(3)) + return nmsed_resp, sc_y_x2LAFs(sc_y_x), octaveMap diff --git a/examples/hesaffnet/LAF.py b/examples/hesaffnet/LAF.py new file mode 100644 index 0000000..4aa1633 --- /dev/null +++ b/examples/hesaffnet/LAF.py @@ -0,0 +1,293 @@ +import numpy as np +import matplotlib.pyplot as plt +from copy import deepcopy +from scipy.spatial.distance import cdist +from numpy.linalg import inv +from scipy.linalg import schur, sqrtm +import torch +from torch.autograd import Variable + +##########numpy +def invSqrt(a,b,c): + eps = 1e-12 + mask = (b != 0) + r1 = mask * (c - a) / (2. * b + eps) + t1 = np.sign(r1) / (np.abs(r1) + np.sqrt(1. + r1*r1)); + r = 1.0 / np.sqrt( 1. + t1*t1) + t = t1*r; + + r = r * mask + 1.0 * (1.0 - mask); + t = t * mask; + + x = 1. / np.sqrt( r*r*a - 2*r*t*b + t*t*c) + z = 1. 
/ np.sqrt( t*t*a + 2*r*t*b + r*r*c) + + d = np.sqrt( x * z) + + x = x / d + z = z / d + + new_a = r*r*x + t*t*z + new_b = -r*t*x + t*r*z + new_c = t*t*x + r*r *z + + return new_a, new_b, new_c + +def Ell2LAF(ell): + A23 = np.zeros((2,3)) + A23[0,2] = ell[0] + A23[1,2] = ell[1] + a = ell[2] + b = ell[3] + c = ell[4] + sc = np.sqrt(np.sqrt(a*c - b*b)) + ia,ib,ic = invSqrt(a,b,c) #because sqrtm returns ::-1, ::-1 matrix, don`t know why + A = np.array([[ia, ib], [ib, ic]]) / sc + sc = np.sqrt(A[0,0] * A[1,1] - A[1,0] * A[0,1]) + A23[0:2,0:2] = rectifyAffineTransformationUpIsUp(A / sc) * sc + return A23 + +def rectifyAffineTransformationUpIsUp_np(A): + det = np.sqrt(np.abs(A[0,0]*A[1,1] - A[1,0]*A[0,1] + 1e-10)) + b2a2 = np.sqrt(A[0,1] * A[0,1] + A[0,0] * A[0,0]) + A_new = np.zeros((2,2)) + A_new[0,0] = b2a2 / det + A_new[0,1] = 0 + A_new[1,0] = (A[1,1]*A[0,1]+A[1,0]*A[0,0])/(b2a2*det) + A_new[1,1] = det / b2a2 + return A_new + +def ells2LAFs(ells): + LAFs = np.zeros((len(ells), 2,3)) + for i in range(len(ells)): + LAFs[i,:,:] = Ell2LAF(ells[i,:]) + return LAFs + +def LAF2pts(LAF, n_pts = 50): + a = np.linspace(0, 2*np.pi, n_pts); + x = [0] + x.extend(list(np.sin(a))) + x = np.array(x).reshape(1,-1) + y = [0] + y.extend(list(np.cos(a))) + y = np.array(y).reshape(1,-1) + HLAF = np.concatenate([LAF, np.array([0,0,1]).reshape(1,3)]) + H_pts =np.concatenate([x,y,np.ones(x.shape)]) + H_pts_out = np.transpose(np.matmul(HLAF, H_pts)) + H_pts_out[:,0] = H_pts_out[:,0] / H_pts_out[:, 2] + H_pts_out[:,1] = H_pts_out[:,1] / H_pts_out[:, 2] + return H_pts_out[:,0:2] + + +def convertLAFs_to_A23format(LAFs): + sh = LAFs.shape + if (len(sh) == 3) and (sh[1] == 2) and (sh[2] == 3): # n x 2 x 3 classical [A, (x;y)] matrix + work_LAFs = deepcopy(LAFs) + elif (len(sh) == 2) and (sh[1] == 7): #flat format, x y scale a11 a12 a21 a22 + work_LAFs = np.zeros((sh[0], 2,3)) + work_LAFs[:,0,2] = LAFs[:,0] + work_LAFs[:,1,2] = LAFs[:,1] + work_LAFs[:,0,0] = LAFs[:,2] * LAFs[:,3] + work_LAFs[:,0,1] = LAFs[:,2] * LAFs[:,4] + work_LAFs[:,1,0] = LAFs[:,2] * LAFs[:,5] + work_LAFs[:,1,1] = LAFs[:,2] * LAFs[:,6] + elif (len(sh) == 2) and (sh[1] == 6): #flat format, x y s*a11 s*a12 s*a21 s*a22 + work_LAFs = np.zeros((sh[0], 2,3)) + work_LAFs[:,0,2] = LAFs[:,0] + work_LAFs[:,1,2] = LAFs[:,1] + work_LAFs[:,0,0] = LAFs[:,2] + work_LAFs[:,0,1] = LAFs[:,3] + work_LAFs[:,1,0] = LAFs[:,4] + work_LAFs[:,1,1] = LAFs[:,5] + else: + print 'Unknown LAF format' + return None + return work_LAFs + +def LAFs2ell(in_LAFs): + LAFs = convertLAFs_to_A23format(in_LAFs) + ellipses = np.zeros((len(LAFs),5)) + for i in range(len(LAFs)): + LAF = deepcopy(LAFs[i,:,:]) + scale = np.sqrt(LAF[0,0]*LAF[1,1] - LAF[0,1]*LAF[1, 0] + 1e-10) + u, W, v = np.linalg.svd(LAF[0:2,0:2] / scale, full_matrices=True) + W[0] = 1. / (W[0]*W[0]*scale*scale) + W[1] = 1. 
/ (W[1]*W[1]*scale*scale) + A = np.matmul(np.matmul(u, np.diag(W)), u.transpose()) + ellipses[i,0] = LAF[0,2] + ellipses[i,1] = LAF[1,2] + ellipses[i,2] = A[0,0] + ellipses[i,3] = A[0,1] + ellipses[i,4] = A[1,1] + return ellipses + +def visualize_LAFs(img, LAFs): + work_LAFs = convertLAFs_to_A23format(LAFs) + plt.figure() + plt.imshow(255 - img) + for i in range(len(work_LAFs)): + ell = LAF2pts(work_LAFs[i,:,:]) + plt.plot( ell[:,0], ell[:,1], 'r') + plt.show() + return + +####pytorch + +def get_normalized_affine_shape(tilt, angle_in_radians): + assert tilt.size(0) == angle_in_radians.size(0) + num = tilt.size(0) + tilt_A = Variable(torch.eye(2).view(1,2,2).repeat(num,1,1)) + if tilt.is_cuda: + tilt_A = tilt_A.cuda() + tilt_A[:,0,0] = tilt; + rotmat = get_rotation_matrix(angle_in_radians) + out_A = rectifyAffineTransformationUpIsUp(torch.bmm(rotmat, torch.bmm(tilt_A, rotmat))) + #re_scale = (1.0/torch.sqrt((out_A **2).sum(dim=1).max(dim=1)[0])) #It is heuristic to for keeping scale change small + #re_scale = (0.5 + 0.5/torch.sqrt((out_A **2).sum(dim=1).max(dim=1)[0])) #It is heuristic to for keeping scale change small + return out_A# * re_scale.view(-1,1,1).expand(num,2,2) + +def get_rotation_matrix(angle_in_radians): + angle_in_radians = angle_in_radians.view(-1, 1, 1); + sin_a = torch.sin(angle_in_radians) + cos_a = torch.cos(angle_in_radians) + A1_x = torch.cat([cos_a, sin_a], dim = 2) + A2_x = torch.cat([-sin_a, cos_a], dim = 2) + transform = torch.cat([A1_x,A2_x], dim = 1) + return transform + +def rectifyAffineTransformationUpIsUp(A): + det = torch.sqrt(torch.abs(A[:,0,0]*A[:,1,1] - A[:,1,0]*A[:,0,1] + 1e-10)) + b2a2 = torch.sqrt(A[:,0,1] * A[:,0,1] + A[:,0,0] * A[:,0,0]) + A1_ell = torch.cat([(b2a2 / det).contiguous().view(-1,1,1), 0 * det.view(-1,1,1)], dim = 2) + A2_ell = torch.cat([((A[:,1,1]*A[:,0,1]+A[:,1,0]*A[:,0,0])/(b2a2*det)).contiguous().view(-1,1,1), + (det / b2a2).contiguous().view(-1,1,1)], dim = 2) + return torch.cat([A1_ell, A2_ell], dim = 1) + + + +def abc2A(a,b,c, normalize = False): + A1_ell = torch.cat([a.view(-1,1,1), b.view(-1,1,1)], dim = 2) + A2_ell = torch.cat([b.view(-1,1,1), c.view(-1,1,1)], dim = 2) + return torch.cat([A1_ell, A2_ell], dim = 1) + + + +def angles2A(angles): + cos_a = torch.cos(angles).view(-1, 1, 1) + sin_a = torch.sin(angles).view(-1, 1, 1) + A1_ang = torch.cat([cos_a, sin_a], dim = 2) + A2_ang = torch.cat([-sin_a, cos_a], dim = 2) + return torch.cat([A1_ang, A2_ang], dim = 1) + +def generate_patch_grid_from_normalized_LAFs(LAFs, w, h, PS): + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3) * min_size + coef[0,0,2] = w + coef[0,1,2] = h + if LAFs.is_cuda: + coef = coef.cuda() + grid = torch.nn.functional.affine_grid(LAFs * Variable(coef.expand(num_lafs,2,3)), torch.Size((num_lafs,1,PS,PS))) + grid[:,:,:,0] = 2.0 * grid[:,:,:,0] / float(w) - 1.0 + grid[:,:,:,1] = 2.0 * grid[:,:,:,1] / float(h) - 1.0 + return grid + +def extract_patches(img, LAFs, PS = 32): + w = img.size(3) + h = img.size(2) + ch = img.size(1) + grid = generate_patch_grid_from_normalized_LAFs(LAFs, float(w),float(h), PS) + return torch.nn.functional.grid_sample(img.expand(grid.size(0), ch, h, w), grid) + +def get_pyramid_inverted_index_for_LAFs(LAFs, PS, sigmas): + return + +def extract_patches_from_pyramid_with_inv_index(scale_pyramid, pyr_inv_idxs, LAFs, PS = 19): + patches = torch.zeros(LAFs.size(0),scale_pyramid[0][0].size(1), PS, PS) + if LAFs.is_cuda: + patches = patches.cuda() + patches = Variable(patches) + if pyr_inv_idxs is not None: + for 
i in range(len(scale_pyramid)): + for j in range(len(scale_pyramid[i])): + cur_lvl_idxs = pyr_inv_idxs[i][j] + if cur_lvl_idxs is None: + continue + cur_lvl_idxs = cur_lvl_idxs.view(-1) + #print i,j,cur_lvl_idxs.shape + patches[cur_lvl_idxs,:,:,:] = extract_patches(scale_pyramid[i][j], LAFs[cur_lvl_idxs, :,:], PS ) + return patches + +def get_inverted_pyr_index(scale_pyr, pyr_idxs, level_idxs): + pyr_inv_idxs = [] + ### Precompute octave inverted indexes + for i in range(len(scale_pyr)): + pyr_inv_idxs.append([]) + cur_idxs = pyr_idxs == i #torch.nonzero((pyr_idxs == i).data) + for j in range(0, len(scale_pyr[i])): + cur_lvl_idxs = torch.nonzero(((level_idxs == j) * cur_idxs).data) + if len(cur_lvl_idxs.size()) == 0: + pyr_inv_idxs[i].append(None) + else: + pyr_inv_idxs[i].append(cur_lvl_idxs.squeeze()) + return pyr_inv_idxs + + +def denormalizeLAFs(LAFs, w, h): + w = float(w) + h = float(h) + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3).float() * min_size + coef[0,0,2] = w + coef[0,1,2] = h + if LAFs.is_cuda: + coef = coef.cuda() + return Variable(coef.expand(num_lafs,2,3)) * LAFs + +def normalizeLAFs(LAFs, w, h): + w = float(w) + h = float(h) + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3).float() / min_size + coef[0,0,2] = 1.0 / w + coef[0,1,2] = 1.0 / h + if LAFs.is_cuda: + coef = coef.cuda() + return Variable(coef.expand(num_lafs,2,3)) * LAFs + +def sc_y_x2LAFs(sc_y_x): + base_LAF = torch.eye(2).float().unsqueeze(0).expand(sc_y_x.size(0),2,2) + if sc_y_x.is_cuda: + base_LAF = base_LAF.cuda() + base_A = Variable(base_LAF, requires_grad=False) + A = sc_y_x[:,:1].unsqueeze(1).expand_as(base_A) * base_A + LAFs = torch.cat([A, + torch.cat([sc_y_x[:,2:].unsqueeze(-1), + sc_y_x[:,1:2].unsqueeze(-1)], dim=1)], dim = 2) + + return LAFs +def get_LAFs_scales(LAFs): + return torch.sqrt(torch.abs(LAFs[:,0,0] *LAFs[:,1,1] - LAFs[:,0,1] * LAFs[:,1,0]) + 1e-12) + +def get_pyramid_and_level_index_for_LAFs(dLAFs, sigmas, pix_dists, PS): + scales = get_LAFs_scales(dLAFs); + needed_sigmas = scales / PS; + sigmas_full_list = [] + level_idxs_full = [] + oct_idxs_full = [] + for oct_idx in range(len(sigmas)): + sigmas_full_list = sigmas_full_list + list(np.array(sigmas[oct_idx])*np.array(pix_dists[oct_idx])) + oct_idxs_full = oct_idxs_full + [oct_idx]*len(sigmas[oct_idx]) + level_idxs_full = level_idxs_full + range(0,len(sigmas[oct_idx])) + oct_idxs_full = torch.LongTensor(oct_idxs_full) + level_idxs_full = torch.LongTensor(level_idxs_full) + + closest_imgs = cdist(np.array(sigmas_full_list).reshape(-1,1), needed_sigmas.data.cpu().numpy().reshape(-1,1)).argmin(axis = 0) + closest_imgs = torch.from_numpy(closest_imgs) + if dLAFs.is_cuda: + closest_imgs = closest_imgs.cuda() + oct_idxs_full = oct_idxs_full.cuda() + level_idxs_full = level_idxs_full.cuda() + return Variable(oct_idxs_full[closest_imgs]), Variable(level_idxs_full[closest_imgs]) diff --git a/examples/hesaffnet/NMS.py b/examples/hesaffnet/NMS.py new file mode 100644 index 0000000..4fa7a2f --- /dev/null +++ b/examples/hesaffnet/NMS.py @@ -0,0 +1,103 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.autograd import Variable +from Utils import CircularGaussKernel, generate_2dgrid, generate_2dgrid, generate_3dgrid, zero_response_at_border +from LAF import sc_y_x2LAFs + +class NMS2d(nn.Module): + def __init__(self, kernel_size = 3, threshold = 0): + super(NMS2d, self).__init__() + self.MP = nn.MaxPool2d(kernel_size, stride=1, return_indices=False, padding = 
kernel_size/2)
+        self.eps = 1e-5
+        self.th = threshold
+        return
+    def forward(self, x):
+        # a response survives only if it equals the local maximum of its
+        # neighbourhood (and beats the threshold, when one is given)
+        if self.th > self.eps:
+            return x * (x > self.th).float() * ((x + self.eps - self.MP(x)) > 0).float()
+        else:
+            return ((x - self.MP(x) + self.eps) > 0).float() * x
+
+class NMS3d(nn.Module):
+    def __init__(self, kernel_size = 3, threshold = 0):
+        super(NMS3d, self).__init__()
+        self.MP = nn.MaxPool3d(kernel_size, stride=1, return_indices=False, padding = (0, kernel_size/2, kernel_size/2))
+        self.eps = 1e-5
+        self.th = threshold
+        return
+    def forward(self, x):
+        if self.th > self.eps:
+            return x * (x > self.th).float() * ((x + self.eps - self.MP(x)) > 0).float()
+        else:
+            return ((x - self.MP(x) + self.eps) > 0).float() * x
+
+
+class NMS3dAndComposeA(nn.Module):
+    def __init__(self, kernel_size = 3, threshold = 0, scales = None, border = 3, mrSize = 1.0):
+        super(NMS3dAndComposeA, self).__init__()
+        self.eps = 1e-7
+        self.ks = 3
+        if type(scales) is not list:
+            self.grid = generate_3dgrid(3, self.ks, self.ks)
+        else:
+            self.grid = generate_3dgrid(scales, self.ks, self.ks)
+        self.grid = Variable(self.grid.t().contiguous().view(3,3,3,3), requires_grad=False)
+        self.th = threshold
+        self.cube_idxs = []
+        self.border = border
+        self.mrSize = mrSize
+        self.beta = 1.0
+        self.grid_ones = Variable(torch.ones(3,3,3,3), requires_grad=False)
+        self.NMS3d = NMS3d(kernel_size, threshold)
+        return
+    def forward(self, low, cur, high, octaveMap = None, num_features = 0):
+        assert low.size() == cur.size() == high.size()
+
+        # Filter response map
+        self.is_cuda = low.is_cuda
+        resp3d = torch.cat([low, cur, high], dim = 1)
+
+        mrSize_border = int(self.mrSize)
+        if octaveMap is not None:
+            nmsed_resp = zero_response_at_border(self.NMS3d(resp3d.unsqueeze(1)).squeeze(1)[:,1:2,:,:], mrSize_border) * (1. 
- octaveMap.float()) + else: + nmsed_resp = zero_response_at_border(self.NMS3d(resp3d.unsqueeze(1)).squeeze(1)[:,1:2,:,:], mrSize_border) + + num_of_nonzero_responces = (nmsed_resp > 0).sum().data[0] + if (num_of_nonzero_responces == 0): + return None,None,None + if octaveMap is not None: + octaveMap = (octaveMap.float() + nmsed_resp.float()).byte() + + nmsed_resp = nmsed_resp.view(-1) + if (num_features > 0) and (num_features < num_of_nonzero_responces): + nmsed_resp, idxs = torch.topk(nmsed_resp, k = num_features); + else: + idxs = nmsed_resp.data.nonzero().squeeze() + nmsed_resp = nmsed_resp[idxs] + + #Get point coordinates + + spatial_grid = Variable(generate_2dgrid(low.size(2), low.size(3), False)).view(1,low.size(2), low.size(3),2) + spatial_grid = spatial_grid.permute(3,1, 2, 0) + if self.is_cuda: + spatial_grid = spatial_grid.cuda() + self.grid = self.grid.cuda() + self.grid_ones = self.grid_ones.cuda() + #residual_to_patch_center + sc_y_x = F.conv2d(resp3d, self.grid, + padding = 1) / (F.conv2d(resp3d, self.grid_ones, padding = 1) + 1e-8) + + ##maxima coords + sc_y_x[0,1:,:,:] = sc_y_x[0,1:,:,:] + spatial_grid[:,:,:,0] + sc_y_x = sc_y_x.view(3,-1).t() + sc_y_x = sc_y_x[idxs,:] + + min_size = float(min((cur.size(2)), cur.size(3))) + sc_y_x[:,0] = sc_y_x[:,0] / min_size + sc_y_x[:,1] = sc_y_x[:,1] / float(cur.size(2)) + sc_y_x[:,2] = sc_y_x[:,2] / float(cur.size(3)) + + return nmsed_resp, sc_y_x2LAFs(sc_y_x), octaveMap \ No newline at end of file diff --git a/examples/hesaffnet/SparseImgRepresenter.py b/examples/hesaffnet/SparseImgRepresenter.py new file mode 100644 index 0000000..a16315f --- /dev/null +++ b/examples/hesaffnet/SparseImgRepresenter.py @@ -0,0 +1,200 @@ +import torch +import torch.nn as nn +import numpy as np +import math +import torch.nn.functional as F +from torch.autograd import Variable +from copy import deepcopy +from Utils import GaussianBlur, batch_eig2x2, line_prepender, batched_forward +from LAF import LAFs2ell,abc2A, angles2A, generate_patch_grid_from_normalized_LAFs, extract_patches, get_inverted_pyr_index, denormalizeLAFs, extract_patches_from_pyramid_with_inv_index, rectifyAffineTransformationUpIsUp +from LAF import get_pyramid_and_level_index_for_LAFs, normalizeLAFs +from HandCraftedModules import HessianResp, AffineShapeEstimator, OrientationDetector, ScalePyramid, NMS3dAndComposeA +import time + +class ScaleSpaceAffinePatchExtractor(nn.Module): + def __init__(self, + border = 16, + num_features = 500, + patch_size = 32, + mrSize = 3.0, + nlevels = 3, + num_Baum_iters = 0, + init_sigma = 1.6, + RespNet = None, OriNet = None, AffNet = None): + super(ScaleSpaceAffinePatchExtractor, self).__init__() + self.mrSize = mrSize + self.PS = patch_size + self.b = border; + self.num = num_features + self.nlevels = nlevels + self.num_Baum_iters = num_Baum_iters + self.init_sigma = init_sigma + if RespNet is not None: + self.RespNet = RespNet + else: + self.RespNet = HessianResp() + if OriNet is not None: + self.OriNet = OriNet + else: + self.OriNet= OrientationDetector(patch_size = 19); + if AffNet is not None: + self.AffNet = AffNet + else: + self.AffNet = AffineShapeEstimator(patch_size = 19) + self.ScalePyrGen = ScalePyramid(nLevels = self.nlevels, init_sigma = self.init_sigma, border = self.b) + return + + def multiScaleDetector(self,x, num_features = 0): + t = time.time() + self.scale_pyr, self.sigmas, self.pix_dists = self.ScalePyrGen(x) + ### Detect keypoints in scale space + aff_matrices = [] + top_responces = [] + pyr_idxs = [] + level_idxs = [] + det_t = 0 
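+        # The detection loop below slides a (low, cur, high) triplet of Hessian
+        # response maps over every octave: NMS3dAndComposeA keeps responses that
+        # are maxima in their 3x3x3 scale-space neighbourhood and composes the
+        # corresponding LAFs from the interpolated (scale, y, x) coordinates.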
+ nmst = 0 + for oct_idx in range(len(self.sigmas)): + #print oct_idx + octave = self.scale_pyr[oct_idx] + sigmas_oct = self.sigmas[oct_idx] + pix_dists_oct = self.pix_dists[oct_idx] + low = None + cur = None + high = None + octaveMap = (self.scale_pyr[oct_idx][0] * 0).byte() + nms_f = NMS3dAndComposeA(w = octave[0].size(3), + h = octave[0].size(2), + border = self.b, mrSize = self.mrSize) + for level_idx in range(1, len(octave)-1): + if cur is None: + low = self.RespNet(octave[level_idx - 1], (sigmas_oct[level_idx - 1 ])) + else: + low = cur + if high is None: + cur = self.RespNet(octave[level_idx ], (sigmas_oct[level_idx ])) + else: + cur = high + high = self.RespNet(octave[level_idx + 1], (sigmas_oct[level_idx + 1 ])) + top_resp, aff_matrix, octaveMap_current = nms_f(low, cur, high, + num_features = num_features, + octaveMap = octaveMap, + scales = sigmas_oct[level_idx - 1:level_idx + 2]) + if top_resp is None: + continue + octaveMap = octaveMap_current + aff_matrices.append(aff_matrix), top_responces.append(top_resp) + pyr_id = Variable(oct_idx * torch.ones(aff_matrix.size(0))) + lev_id = Variable((level_idx - 1) * torch.ones(aff_matrix.size(0))) #prevBlur + if x.is_cuda: + pyr_id = pyr_id.cuda() + lev_id = lev_id.cuda() + pyr_idxs.append(pyr_id) + level_idxs.append(lev_id) + all_responses = torch.cat(top_responces, dim = 0) + aff_m_scales = torch.cat(aff_matrices,dim = 0) + pyr_idxs_scales = torch.cat(pyr_idxs,dim = 0) + level_idxs_scale = torch.cat(level_idxs, dim = 0) + if (num_features > 0) and (num_features < all_responses.size(0)): + all_responses, idxs = torch.topk(all_responses, k = num_features); + LAFs = torch.index_select(aff_m_scales, 0, idxs) + final_pyr_idxs = pyr_idxs_scales[idxs] + final_level_idxs = level_idxs_scale[idxs] + else: + return all_responses, aff_m_scales, pyr_idxs_scales , level_idxs_scale + return all_responses, LAFs, final_pyr_idxs, final_level_idxs, + + def getAffineShape(self, final_resp, LAFs, final_pyr_idxs, final_level_idxs, num_features = 0): + pe_time = 0 + affnet_time = 0 + pyr_inv_idxs = get_inverted_pyr_index(self.scale_pyr, final_pyr_idxs, final_level_idxs) + t = time.time() + patches_small = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, LAFs, PS = self.AffNet.PS) + pe_time+=time.time() - t + t = time.time() + base_A = torch.eye(2).unsqueeze(0).expand(final_pyr_idxs.size(0),2,2) + if final_resp.is_cuda: + base_A = base_A.cuda() + base_A = Variable(base_A) + is_good = None + n_patches = patches_small.size(0) + for i in range(self.num_Baum_iters): + t = time.time() + A = batched_forward(self.AffNet, patches_small, 512) + is_good_current = 1 + affnet_time += time.time() - t + if is_good is None: + is_good = is_good_current + else: + is_good = is_good * is_good_current + base_A = torch.bmm(A, base_A); + new_LAFs = torch.cat([torch.bmm(base_A,LAFs[:,:,0:2]), LAFs[:,:,2:] ], dim =2) + #print torch.sqrt(new_LAFs[0,0,0]*new_LAFs[0,1,1] - new_LAFs[0,1,0] *new_LAFs[0,0,1]) * scale_pyr[0][0].size(2) + if i != self.num_Baum_iters - 1: + pe_time+=time.time() - t + t = time.time() + patches_small = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, new_LAFs, PS = self.AffNet.PS) + pe_time+= time.time() - t + l1,l2 = batch_eig2x2(A) + ratio1 = torch.abs(l1 / (l2 + 1e-8)) + converged_mask = (ratio1 <= 1.2) * (ratio1 >= (0.8)) + l1,l2 = batch_eig2x2(base_A) + ratio = torch.abs(l1 / (l2 + 1e-8)) + idxs_mask = ((ratio < 6.0) * (ratio > (1./6.)))# * converged_mask.float()) > 0 + num_survived = 
idxs_mask.float().sum() + if (num_features > 0) and (num_survived.data[0] > num_features): + final_resp = final_resp * idxs_mask.float() #zero bad points + final_resp, idxs = torch.topk(final_resp, k = num_features); + else: + idxs = torch.nonzero(idxs_mask.data).view(-1).long() + final_resp = final_resp[idxs] + final_pyr_idxs = final_pyr_idxs[idxs] + final_level_idxs = final_level_idxs[idxs] + base_A = torch.index_select(base_A, 0, idxs) + LAFs = torch.index_select(LAFs, 0, idxs) + new_LAFs = torch.cat([torch.bmm(rectifyAffineTransformationUpIsUp(base_A), LAFs[:,:,0:2]), + LAFs[:,:,2:]], dim =2) + print 'affnet_time',affnet_time + print 'pe_time', pe_time + return final_resp, new_LAFs, final_pyr_idxs, final_level_idxs + + def getOrientation(self, LAFs, final_pyr_idxs, final_level_idxs): + pyr_inv_idxs = get_inverted_pyr_index(self.scale_pyr, final_pyr_idxs, final_level_idxs) + patches_small = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, LAFs, PS = self.OriNet.PS) + max_iters = 1 + ### Detect orientation + for i in range(max_iters): + angles = self.OriNet(patches_small) + #print np.degrees(ori.data.cpu().numpy().ravel()[1]) + LAFs = torch.cat([torch.bmm(angles2A(angles), LAFs[:,:,:2]), LAFs[:,:,2:]], dim = 2) + if i != max_iters: + patches_small = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, LAFs, PS = self.OriNet.PS) + return LAFs + def extract_patches_from_pyr(self, dLAFs, PS = 41): + pyr_idxs, level_idxs = get_pyramid_and_level_index_for_LAFs(dLAFs, self.sigmas, self.pix_dists, PS) + pyr_inv_idxs = get_inverted_pyr_index(self.scale_pyr, pyr_idxs, level_idxs) + patches = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, + pyr_inv_idxs, + normalizeLAFs(dLAFs, self.scale_pyr[0][0].size(3), self.scale_pyr[0][0].size(2)), + PS = PS) + return patches + def forward(self,x): + ### Detection + t = time.time() + num_features_prefilter = self.num + if self.num_Baum_iters > 0: + num_features_prefilter = int(1.5 * self.num); + responses, LAFs, final_pyr_idxs, final_level_idxs = self.multiScaleDetector(x,num_features_prefilter) + print time.time() - t, 'detection multiscale' + t = time.time() + LAFs[:,0:2,0:2] = self.mrSize * LAFs[:,:,0:2] + if self.num_Baum_iters > 0: + responses, LAFs, final_pyr_idxs, final_level_idxs = self.getAffineShape(responses, LAFs, final_pyr_idxs, final_level_idxs, self.num) + print time.time() - t, 'affine shape iters' + t = time.time() + #LAFs = self.getOrientation(scale_pyr, LAFs, final_pyr_idxs, final_level_idxs) + #pyr_inv_idxs = get_inverted_pyr_index(scale_pyr, final_pyr_idxs, final_level_idxs) + #patches = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, LAFs, PS = self.PS) + #patches = extract_patches(x, LAFs, PS = self.PS) + #print time.time() - t, len(LAFs), ' patches extraction' + return denormalizeLAFs(LAFs, x.size(3), x.size(2)), responses diff --git a/examples/hesaffnet/Utils.py b/examples/hesaffnet/Utils.py new file mode 100644 index 0000000..af37a7f --- /dev/null +++ b/examples/hesaffnet/Utils.py @@ -0,0 +1,187 @@ +import torch +import torch.nn.init +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import cv2 +import numpy as np + +# resize image to size 32x32 +cv2_scale = lambda x: cv2.resize(x, dsize=(32, 32), + interpolation=cv2.INTER_LINEAR) +# reshape image +np_reshape = lambda x: np.reshape(x, (32, 32, 1)) + +def zeros_like(x): + assert x.__class__.__name__.find('Variable') != -1 or x.__class__.__name__.find('Tensor') != 
-1, "Object is neither a Tensor nor a Variable" + y = torch.zeros(x.size()) + if x.is_cuda: + y = y.cuda() + if x.__class__.__name__ == 'Variable': + return torch.autograd.Variable(y, requires_grad=x.requires_grad) + elif x.__class__.__name__.find('Tensor') != -1: + return torch.zeros(y) + +def ones_like(x): + assert x.__class__.__name__.find('Variable') != -1 or x.__class__.__name__.find('Tensor') != -1, "Object is neither a Tensor nor a Variable" + y = torch.ones(x.size()) + if x.is_cuda: + y = y.cuda() + if x.__class__.__name__ == 'Variable': + return torch.autograd.Variable(y, requires_grad=x.requires_grad) + elif x.__class__.__name__.find('Tensor') != -1: + return torch.ones(y) + + +def batched_forward(model, data, batch_size, **kwargs): + n_patches = len(data) + if n_patches > batch_size: + bs = batch_size + n_batches = n_patches / bs + 1 + for batch_idx in range(n_batches): + st = batch_idx * bs + if batch_idx == n_batches - 1: + if (batch_idx + 1) * bs > n_patches: + end = n_patches + else: + end = (batch_idx + 1) * bs + else: + end = (batch_idx + 1) * bs + if st >= end: + continue + if batch_idx == 0: + try: + first_batch_out = model(data[st:end])#, kwargs) + except: + first_batch_out = model(data[st:end])# kwargs) + out_size = torch.Size([n_patches] + list(first_batch_out.size()[1:])) + #out_size[0] = n_patches + out = torch.zeros(out_size); + if data.is_cuda: + out = out.cuda() + out = Variable(out) + out[st:end] = first_batch_out + else: + try: + out[st:end,:,:] = model(data[st:end])#, kwargs) + except: + out[st:end,:,:] = model(data[st:end])#, kwargs) + return out + else: + return model(data)#, kwargs) + +class L2Norm(nn.Module): + def __init__(self): + super(L2Norm,self).__init__() + self.eps = 1e-10 + def forward(self, x): + norm = torch.sqrt(torch.sum(x * x, dim = 1) + self.eps) + x= x / norm.unsqueeze(-1).expand_as(x) + return x + +class L1Norm(nn.Module): + def __init__(self): + super(L1Norm,self).__init__() + self.eps = 1e-10 + def forward(self, x): + norm = torch.sum(torch.abs(x), dim = 1) + self.eps + x= x / norm.expand_as(x) + return x + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + +def CircularGaussKernel(kernlen=None, circ_zeros = False, sigma = None, norm = True): + assert ((kernlen is not None) or sigma is not None) + if kernlen is None: + kernlen = int(2.0 * 3.0 * sigma + 1.0) + if (kernlen % 2 == 0): + kernlen = kernlen + 1; + halfSize = kernlen / 2; + halfSize = kernlen / 2; + r2 = float(halfSize*halfSize) + if sigma is None: + sigma2 = 0.9 * r2; + sigma = np.sqrt(sigma2) + else: + sigma2 = 2.0 * sigma * sigma + x = np.linspace(-halfSize,halfSize,kernlen) + xv, yv = np.meshgrid(x, x, sparse=False, indexing='xy') + distsq = (xv)**2 + (yv)**2 + kernel = np.exp(-( distsq/ (sigma2))) + if circ_zeros: + kernel *= (distsq <= r2).astype(np.float32) + if norm: + kernel /= np.sum(kernel) + return kernel + +def generate_2dgrid(h,w, centered = True): + if centered: + x = torch.linspace(-w/2+1, w/2, w) + y = torch.linspace(-h/2+1, h/2, h) + else: + x = torch.linspace(0, w-1, w) + y = torch.linspace(0, h-1, h) + grid2d = torch.stack([y.repeat(w,1).t().contiguous().view(-1), x.repeat(h)],1) + return grid2d + +def generate_3dgrid(d, h, w, centered = True): + if type(d) is not list: + if centered: + z = torch.linspace(-d/2+1, d/2, d) + else: + z = torch.linspace(0, d-1, d) + dl = d + else: + z = torch.FloatTensor(d) + dl = len(d) + grid2d = generate_2dgrid(h,w, centered = 
centered) + grid3d = torch.cat([z.repeat(w*h,1).t().contiguous().view(-1,1), grid2d.repeat(dl,1)],dim = 1) + return grid3d + +def zero_response_at_border(x, b): + if (b < x.size(3)) and (b < x.size(2)): + x[:, :, 0:b, :] = 0 + x[:, :, x.size(2) - b: , :] = 0 + x[:, :, :, 0:b] = 0 + x[:, :, :, x.size(3) - b: ] = 0 + else: + return x * 0 + return x + +class GaussianBlur(nn.Module): + def __init__(self, sigma=1.6): + super(GaussianBlur, self).__init__() + weight = self.calculate_weights(sigma) + self.register_buffer('buf', weight) + return + def calculate_weights(self, sigma): + kernel = CircularGaussKernel(sigma = sigma, circ_zeros = False) + h,w = kernel.shape + halfSize = float(h) / 2.; + self.pad = int(np.floor(halfSize)) + return torch.from_numpy(kernel.astype(np.float32)).view(1,1,h,w); + def forward(self, x): + w = Variable(self.buf) + if x.is_cuda: + w = w.cuda() + return F.conv2d(F.pad(x, (self.pad,self.pad,self.pad,self.pad), 'replicate'), w, padding = 0) + +def batch_eig2x2(A): + trace = A[:,0,0] + A[:,1,1] + delta1 = (trace*trace - 4 * ( A[:,0,0]* A[:,1,1] - A[:,1,0]* A[:,0,1])) + mask = delta1 > 0 + delta = torch.sqrt(torch.abs(delta1)) + l1 = mask.float() * (trace + delta) / 2.0 + 1000. * (1.0 - mask.float()) + l2 = mask.float() * (trace - delta) / 2.0 + 0.0001 * (1.0 - mask.float()) + return l1,l2 + +def line_prepender(filename, line): + with open(filename, 'r+') as f: + content = f.read() + f.seek(0, 0) + f.write(line.rstrip('\r\n') + '\n' + content) + return diff --git a/examples/hesaffnet/architectures.py b/examples/hesaffnet/architectures.py new file mode 100644 index 0000000..ade3aca --- /dev/null +++ b/examples/hesaffnet/architectures.py @@ -0,0 +1,190 @@ +from __future__ import division, print_function +import os +import errno +import numpy as np +import sys +from copy import deepcopy +import math +import torch +import torch.nn.init +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision.transforms as transforms +from torch.autograd import Variable +from Utils import L2Norm, generate_2dgrid +from Utils import str2bool +from LAF import denormalizeLAFs, LAFs2ell, abc2A, extract_patches,normalizeLAFs, get_rotation_matrix +from LAF import get_LAFs_scales, get_normalized_affine_shape +from LAF import rectifyAffineTransformationUpIsUp + +class OriNetFast(nn.Module): + def __init__(self, PS = 16): + super(OriNetFast, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Dropout(0.25), + nn.Conv2d(64, 2, kernel_size=int(PS/4), stride=1,padding=1, bias = True), + nn.Tanh(), + nn.AdaptiveAvgPool2d(1) + ) + self.PS = PS + self.features.apply(self.weights_init) + self.halfPS = int(PS/4) + return + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - 
mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def weights_init(self,m): + if isinstance(m, nn.Conv2d): + nn.init.orthogonal(m.weight.data, gain=0.9) + try: + nn.init.constant(m.bias.data, 0.01) + except: + pass + return + def forward(self, input, return_rot_matrix = True): + xy = self.features(self.input_norm(input)).view(-1,2) + angle = torch.atan2(xy[:,0] + 1e-8, xy[:,1]+1e-8); + if return_rot_matrix: + return get_rotation_matrix(angle) + return angle + +class GHH(nn.Module): + def __init__(self, n_in, n_out, s = 4, m = 4): + super(GHH, self).__init__() + self.n_out = n_out + self.s = s + self.m = m + self.conv = nn.Linear(n_in, n_out * s * m) + d = torch.arange(0, s) + self.deltas = -1.0 * (d % 2 != 0).float() + 1.0 * (d % 2 == 0).float() + self.deltas = Variable(self.deltas) + return + def forward(self,x): + x_feats = self.conv(x.view(x.size(0),-1)).view(x.size(0), self.n_out, self.s, self.m); + max_feats = x_feats.max(dim = 3)[0]; + if x.is_cuda: + self.deltas = self.deltas.cuda() + else: + self.deltas = self.deltas.cpu() + out = (max_feats * self.deltas.view(1,1,-1).expand_as(max_feats)).sum(dim = 2) + return out + +class YiNet(nn.Module): + def __init__(self, PS = 28): + super(YiNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 10, kernel_size=5, padding=0, bias = True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding = 1), + nn.Conv2d(10, 20, kernel_size=5, stride=1, padding=0, bias = True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=4, stride=2, padding = 2), + nn.Conv2d(20, 50, kernel_size=3, stride=1, padding=0, bias = True), + nn.ReLU(), + nn.AdaptiveMaxPool2d(1), + GHH(50, 100), + GHH(100, 2) + ) + self.input_mean = 0.427117081207483 + self.input_std = 0.21888339179665006; + self.PS = PS + return + def import_weights(self, dir_name): + self.features[0].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer0_W.npy'))).float() + self.features[0].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer0_b.npy'))).float().view(-1) + self.features[3].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer1_W.npy'))).float() + self.features[3].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer1_b.npy'))).float().view(-1) + self.features[6].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer2_W.npy'))).float() + self.features[6].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer2_b.npy'))).float().view(-1) + self.features[9].conv.weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer3_W.npy'))).float().view(50, 1600).contiguous().t().contiguous()#.view(1600, 50, 1, 1).contiguous() + self.features[9].conv.bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer3_b.npy'))).float().view(1600) + self.features[10].conv.weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer4_W.npy'))).float().view(100, 32).contiguous().t().contiguous()#.view(32, 100, 1, 1).contiguous() + self.features[10].conv.bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer4_b.npy'))).float().view(32) + self.input_mean = float(np.load(os.path.join(dir_name, 'input_mean.npy'))) + self.input_std = float(np.load(os.path.join(dir_name, 'input_std.npy'))) + return + def input_norm1(self,x): + return (x - self.input_mean) / self.input_std + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - 
mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def forward(self, input, return_rot_matrix = False): + xy = self.features(self.input_norm(input)) + angle = torch.atan2(xy[:,0] + 1e-8, xy[:,1]+1e-8); + if return_rot_matrix: + return get_rotation_matrix(-angle) + return angle + +class AffNetFast(nn.Module): + def __init__(self, PS = 32): + super(AffNetFast, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Dropout(0.25), + nn.Conv2d(64, 3, kernel_size=8, stride=1, padding=0, bias = True), + nn.Tanh(), + nn.AdaptiveAvgPool2d(1) + ) + self.PS = PS + self.features.apply(self.weights_init) + self.halfPS = int(PS/2) + return + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def weights_init(self,m): + if isinstance(m, nn.Conv2d): + nn.init.orthogonal(m.weight.data, gain=0.8) + try: + nn.init.constant(m.bias.data, 0.01) + except: + pass + return + def forward(self, input, return_A_matrix = False): + xy = self.features(self.input_norm(input)).view(-1,3) + a1 = torch.cat([1.0 + xy[:,0].contiguous().view(-1,1,1), 0 * xy[:,0].contiguous().view(-1,1,1)], dim = 2).contiguous() + a2 = torch.cat([xy[:,1].contiguous().view(-1,1,1), 1.0 + xy[:,2].contiguous().view(-1,1,1)], dim = 2).contiguous() + return rectifyAffineTransformationUpIsUp(torch.cat([a1,a2], dim = 1).contiguous()) + diff --git a/examples/hesaffnet/hesaffBaum.py b/examples/hesaffnet/hesaffBaum.py new file mode 100644 index 0000000..7126398 --- /dev/null +++ b/examples/hesaffnet/hesaffBaum.py @@ -0,0 +1,50 @@ +#!/usr/bin/python2 -utt +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import numpy as np +import sys +import os +import time + +from PIL import Image +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +import torch.optim as optim +from tqdm import tqdm +import math +import torch.nn.functional as F + +from copy import deepcopy + +from SparseImgRepresenter import ScaleSpaceAffinePatchExtractor +from LAF import denormalizeLAFs, LAFs2ell, abc2A +from Utils import line_prepender +from architectures import AffNetFast +from HandCraftedModules import AffineShapeEstimator +USE_CUDA = False +try: + input_img_fname = sys.argv[1] + output_fname = sys.argv[2] + nfeats = int(sys.argv[3]) +except: + print "Wrong input format. 
Try python hesaffBaum.py imgs/cat.png cat.txt 2000" + sys.exit(1) + +img = Image.open(input_img_fname).convert('RGB') +img = np.mean(np.array(img), axis = 2) + +var_image = torch.autograd.Variable(torch.from_numpy(img.astype(np.float32)), volatile = True) +var_image_reshape = var_image.view(1, 1, var_image.size(0),var_image.size(1)) + +HA = ScaleSpaceAffinePatchExtractor( mrSize = 5.192, num_features = nfeats, border = 5, num_Baum_iters = 16, AffNet = AffineShapeEstimator(patch_size=19)) +if USE_CUDA: + HA = HA.cuda() + var_image_reshape = var_image_reshape.cuda() + +LAFs, resp = HA(var_image_reshape) +ells = LAFs2ell(LAFs.data.cpu().numpy()) + +np.savetxt(output_fname, ells, delimiter=' ', fmt='%10.10f') +line_prepender(output_fname, str(len(ells))) +line_prepender(output_fname, '1.0') diff --git a/examples/hesaffnet/hesaffnet.py b/examples/hesaffnet/hesaffnet.py new file mode 100644 index 0000000..08131fe --- /dev/null +++ b/examples/hesaffnet/hesaffnet.py @@ -0,0 +1,59 @@ +#!/usr/bin/python2 -utt +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import numpy as np +import sys +import os +import time + +from PIL import Image +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +import torch.optim as optim +from tqdm import tqdm +import math +import torch.nn.functional as F + +from copy import deepcopy + +from SparseImgRepresenter import ScaleSpaceAffinePatchExtractor +from LAF import denormalizeLAFs, LAFs2ell, abc2A +from Utils import line_prepender +from architectures import AffNetFast +USE_CUDA = False + +try: + input_img_fname = sys.argv[1] + output_fname = sys.argv[2] + nfeats = int(sys.argv[3]) +except: + print "Wrong input format. Try python hesaffnet.py imgs/cat.png cat.txt 2000" + sys.exit(1) + +img = Image.open(input_img_fname).convert('RGB') +img = np.mean(np.array(img), axis = 2) + +var_image = torch.autograd.Variable(torch.from_numpy(img.astype(np.float32)), volatile = True) +var_image_reshape = var_image.view(1, 1, var_image.size(0),var_image.size(1)) + + +AffNetPix = AffNetFast(PS = 32) +weightd_fname = '../../pretrained/AffNet.pth' + +checkpoint = torch.load(weightd_fname) +AffNetPix.load_state_dict(checkpoint['state_dict']) + +AffNetPix.eval() + +HA = ScaleSpaceAffinePatchExtractor( mrSize = 5.192, num_features = nfeats, border = 5, num_Baum_iters = 1, AffNet = AffNetPix) +if USE_CUDA: + HA = HA.cuda() + var_image_reshape = var_image_reshape.cuda() + +LAFs, resp = HA(var_image_reshape) +ells = LAFs2ell(LAFs.data.cpu().numpy()) + +np.savetxt(output_fname, ells, delimiter=' ', fmt='%10.10f') +line_prepender(output_fname, str(len(ells))) +line_prepender(output_fname, '1.0') diff --git a/examples/hesaffnet/img/cat.png b/examples/hesaffnet/img/cat.png new file mode 100755 index 0000000..e15ec60 Binary files /dev/null and b/examples/hesaffnet/img/cat.png differ diff --git a/examples/just_shape/LAF.py b/examples/just_shape/LAF.py new file mode 100644 index 0000000..4aa1633 --- /dev/null +++ b/examples/just_shape/LAF.py @@ -0,0 +1,293 @@ +import numpy as np +import matplotlib.pyplot as plt +from copy import deepcopy +from scipy.spatial.distance import cdist +from numpy.linalg import inv +from scipy.linalg import schur, sqrtm +import torch +from torch.autograd import Variable + +##########numpy +def invSqrt(a,b,c): + eps = 1e-12 + mask = (b != 0) + r1 = mask * (c - a) / (2. * b + eps) + t1 = np.sign(r1) / (np.abs(r1) + np.sqrt(1. + r1*r1)); + r = 1.0 / np.sqrt( 1. 
+ t1*t1) + t = t1*r; + + r = r * mask + 1.0 * (1.0 - mask); + t = t * mask; + + x = 1. / np.sqrt( r*r*a - 2*r*t*b + t*t*c) + z = 1. / np.sqrt( t*t*a + 2*r*t*b + r*r*c) + + d = np.sqrt( x * z) + + x = x / d + z = z / d + + new_a = r*r*x + t*t*z + new_b = -r*t*x + t*r*z + new_c = t*t*x + r*r *z + + return new_a, new_b, new_c + +def Ell2LAF(ell): + A23 = np.zeros((2,3)) + A23[0,2] = ell[0] + A23[1,2] = ell[1] + a = ell[2] + b = ell[3] + c = ell[4] + sc = np.sqrt(np.sqrt(a*c - b*b)) + ia,ib,ic = invSqrt(a,b,c) #because sqrtm returns ::-1, ::-1 matrix, don`t know why + A = np.array([[ia, ib], [ib, ic]]) / sc + sc = np.sqrt(A[0,0] * A[1,1] - A[1,0] * A[0,1]) + A23[0:2,0:2] = rectifyAffineTransformationUpIsUp(A / sc) * sc + return A23 + +def rectifyAffineTransformationUpIsUp_np(A): + det = np.sqrt(np.abs(A[0,0]*A[1,1] - A[1,0]*A[0,1] + 1e-10)) + b2a2 = np.sqrt(A[0,1] * A[0,1] + A[0,0] * A[0,0]) + A_new = np.zeros((2,2)) + A_new[0,0] = b2a2 / det + A_new[0,1] = 0 + A_new[1,0] = (A[1,1]*A[0,1]+A[1,0]*A[0,0])/(b2a2*det) + A_new[1,1] = det / b2a2 + return A_new + +def ells2LAFs(ells): + LAFs = np.zeros((len(ells), 2,3)) + for i in range(len(ells)): + LAFs[i,:,:] = Ell2LAF(ells[i,:]) + return LAFs + +def LAF2pts(LAF, n_pts = 50): + a = np.linspace(0, 2*np.pi, n_pts); + x = [0] + x.extend(list(np.sin(a))) + x = np.array(x).reshape(1,-1) + y = [0] + y.extend(list(np.cos(a))) + y = np.array(y).reshape(1,-1) + HLAF = np.concatenate([LAF, np.array([0,0,1]).reshape(1,3)]) + H_pts =np.concatenate([x,y,np.ones(x.shape)]) + H_pts_out = np.transpose(np.matmul(HLAF, H_pts)) + H_pts_out[:,0] = H_pts_out[:,0] / H_pts_out[:, 2] + H_pts_out[:,1] = H_pts_out[:,1] / H_pts_out[:, 2] + return H_pts_out[:,0:2] + + +def convertLAFs_to_A23format(LAFs): + sh = LAFs.shape + if (len(sh) == 3) and (sh[1] == 2) and (sh[2] == 3): # n x 2 x 3 classical [A, (x;y)] matrix + work_LAFs = deepcopy(LAFs) + elif (len(sh) == 2) and (sh[1] == 7): #flat format, x y scale a11 a12 a21 a22 + work_LAFs = np.zeros((sh[0], 2,3)) + work_LAFs[:,0,2] = LAFs[:,0] + work_LAFs[:,1,2] = LAFs[:,1] + work_LAFs[:,0,0] = LAFs[:,2] * LAFs[:,3] + work_LAFs[:,0,1] = LAFs[:,2] * LAFs[:,4] + work_LAFs[:,1,0] = LAFs[:,2] * LAFs[:,5] + work_LAFs[:,1,1] = LAFs[:,2] * LAFs[:,6] + elif (len(sh) == 2) and (sh[1] == 6): #flat format, x y s*a11 s*a12 s*a21 s*a22 + work_LAFs = np.zeros((sh[0], 2,3)) + work_LAFs[:,0,2] = LAFs[:,0] + work_LAFs[:,1,2] = LAFs[:,1] + work_LAFs[:,0,0] = LAFs[:,2] + work_LAFs[:,0,1] = LAFs[:,3] + work_LAFs[:,1,0] = LAFs[:,4] + work_LAFs[:,1,1] = LAFs[:,5] + else: + print 'Unknown LAF format' + return None + return work_LAFs + +def LAFs2ell(in_LAFs): + LAFs = convertLAFs_to_A23format(in_LAFs) + ellipses = np.zeros((len(LAFs),5)) + for i in range(len(LAFs)): + LAF = deepcopy(LAFs[i,:,:]) + scale = np.sqrt(LAF[0,0]*LAF[1,1] - LAF[0,1]*LAF[1, 0] + 1e-10) + u, W, v = np.linalg.svd(LAF[0:2,0:2] / scale, full_matrices=True) + W[0] = 1. / (W[0]*W[0]*scale*scale) + W[1] = 1. 
/ (W[1]*W[1]*scale*scale) + A = np.matmul(np.matmul(u, np.diag(W)), u.transpose()) + ellipses[i,0] = LAF[0,2] + ellipses[i,1] = LAF[1,2] + ellipses[i,2] = A[0,0] + ellipses[i,3] = A[0,1] + ellipses[i,4] = A[1,1] + return ellipses + +def visualize_LAFs(img, LAFs): + work_LAFs = convertLAFs_to_A23format(LAFs) + plt.figure() + plt.imshow(255 - img) + for i in range(len(work_LAFs)): + ell = LAF2pts(work_LAFs[i,:,:]) + plt.plot( ell[:,0], ell[:,1], 'r') + plt.show() + return + +####pytorch + +def get_normalized_affine_shape(tilt, angle_in_radians): + assert tilt.size(0) == angle_in_radians.size(0) + num = tilt.size(0) + tilt_A = Variable(torch.eye(2).view(1,2,2).repeat(num,1,1)) + if tilt.is_cuda: + tilt_A = tilt_A.cuda() + tilt_A[:,0,0] = tilt; + rotmat = get_rotation_matrix(angle_in_radians) + out_A = rectifyAffineTransformationUpIsUp(torch.bmm(rotmat, torch.bmm(tilt_A, rotmat))) + #re_scale = (1.0/torch.sqrt((out_A **2).sum(dim=1).max(dim=1)[0])) #It is heuristic to for keeping scale change small + #re_scale = (0.5 + 0.5/torch.sqrt((out_A **2).sum(dim=1).max(dim=1)[0])) #It is heuristic to for keeping scale change small + return out_A# * re_scale.view(-1,1,1).expand(num,2,2) + +def get_rotation_matrix(angle_in_radians): + angle_in_radians = angle_in_radians.view(-1, 1, 1); + sin_a = torch.sin(angle_in_radians) + cos_a = torch.cos(angle_in_radians) + A1_x = torch.cat([cos_a, sin_a], dim = 2) + A2_x = torch.cat([-sin_a, cos_a], dim = 2) + transform = torch.cat([A1_x,A2_x], dim = 1) + return transform + +def rectifyAffineTransformationUpIsUp(A): + det = torch.sqrt(torch.abs(A[:,0,0]*A[:,1,1] - A[:,1,0]*A[:,0,1] + 1e-10)) + b2a2 = torch.sqrt(A[:,0,1] * A[:,0,1] + A[:,0,0] * A[:,0,0]) + A1_ell = torch.cat([(b2a2 / det).contiguous().view(-1,1,1), 0 * det.view(-1,1,1)], dim = 2) + A2_ell = torch.cat([((A[:,1,1]*A[:,0,1]+A[:,1,0]*A[:,0,0])/(b2a2*det)).contiguous().view(-1,1,1), + (det / b2a2).contiguous().view(-1,1,1)], dim = 2) + return torch.cat([A1_ell, A2_ell], dim = 1) + + + +def abc2A(a,b,c, normalize = False): + A1_ell = torch.cat([a.view(-1,1,1), b.view(-1,1,1)], dim = 2) + A2_ell = torch.cat([b.view(-1,1,1), c.view(-1,1,1)], dim = 2) + return torch.cat([A1_ell, A2_ell], dim = 1) + + + +def angles2A(angles): + cos_a = torch.cos(angles).view(-1, 1, 1) + sin_a = torch.sin(angles).view(-1, 1, 1) + A1_ang = torch.cat([cos_a, sin_a], dim = 2) + A2_ang = torch.cat([-sin_a, cos_a], dim = 2) + return torch.cat([A1_ang, A2_ang], dim = 1) + +def generate_patch_grid_from_normalized_LAFs(LAFs, w, h, PS): + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3) * min_size + coef[0,0,2] = w + coef[0,1,2] = h + if LAFs.is_cuda: + coef = coef.cuda() + grid = torch.nn.functional.affine_grid(LAFs * Variable(coef.expand(num_lafs,2,3)), torch.Size((num_lafs,1,PS,PS))) + grid[:,:,:,0] = 2.0 * grid[:,:,:,0] / float(w) - 1.0 + grid[:,:,:,1] = 2.0 * grid[:,:,:,1] / float(h) - 1.0 + return grid + +def extract_patches(img, LAFs, PS = 32): + w = img.size(3) + h = img.size(2) + ch = img.size(1) + grid = generate_patch_grid_from_normalized_LAFs(LAFs, float(w),float(h), PS) + return torch.nn.functional.grid_sample(img.expand(grid.size(0), ch, h, w), grid) + +def get_pyramid_inverted_index_for_LAFs(LAFs, PS, sigmas): + return + +def extract_patches_from_pyramid_with_inv_index(scale_pyramid, pyr_inv_idxs, LAFs, PS = 19): + patches = torch.zeros(LAFs.size(0),scale_pyramid[0][0].size(1), PS, PS) + if LAFs.is_cuda: + patches = patches.cuda() + patches = Variable(patches) + if pyr_inv_idxs is not None: + for 
i in range(len(scale_pyramid)): + for j in range(len(scale_pyramid[i])): + cur_lvl_idxs = pyr_inv_idxs[i][j] + if cur_lvl_idxs is None: + continue + cur_lvl_idxs = cur_lvl_idxs.view(-1) + #print i,j,cur_lvl_idxs.shape + patches[cur_lvl_idxs,:,:,:] = extract_patches(scale_pyramid[i][j], LAFs[cur_lvl_idxs, :,:], PS ) + return patches + +def get_inverted_pyr_index(scale_pyr, pyr_idxs, level_idxs): + pyr_inv_idxs = [] + ### Precompute octave inverted indexes + for i in range(len(scale_pyr)): + pyr_inv_idxs.append([]) + cur_idxs = pyr_idxs == i #torch.nonzero((pyr_idxs == i).data) + for j in range(0, len(scale_pyr[i])): + cur_lvl_idxs = torch.nonzero(((level_idxs == j) * cur_idxs).data) + if len(cur_lvl_idxs.size()) == 0: + pyr_inv_idxs[i].append(None) + else: + pyr_inv_idxs[i].append(cur_lvl_idxs.squeeze()) + return pyr_inv_idxs + + +def denormalizeLAFs(LAFs, w, h): + w = float(w) + h = float(h) + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3).float() * min_size + coef[0,0,2] = w + coef[0,1,2] = h + if LAFs.is_cuda: + coef = coef.cuda() + return Variable(coef.expand(num_lafs,2,3)) * LAFs + +def normalizeLAFs(LAFs, w, h): + w = float(w) + h = float(h) + num_lafs = LAFs.size(0) + min_size = min(h,w) + coef = torch.ones(1,2,3).float() / min_size + coef[0,0,2] = 1.0 / w + coef[0,1,2] = 1.0 / h + if LAFs.is_cuda: + coef = coef.cuda() + return Variable(coef.expand(num_lafs,2,3)) * LAFs + +def sc_y_x2LAFs(sc_y_x): + base_LAF = torch.eye(2).float().unsqueeze(0).expand(sc_y_x.size(0),2,2) + if sc_y_x.is_cuda: + base_LAF = base_LAF.cuda() + base_A = Variable(base_LAF, requires_grad=False) + A = sc_y_x[:,:1].unsqueeze(1).expand_as(base_A) * base_A + LAFs = torch.cat([A, + torch.cat([sc_y_x[:,2:].unsqueeze(-1), + sc_y_x[:,1:2].unsqueeze(-1)], dim=1)], dim = 2) + + return LAFs +def get_LAFs_scales(LAFs): + return torch.sqrt(torch.abs(LAFs[:,0,0] *LAFs[:,1,1] - LAFs[:,0,1] * LAFs[:,1,0]) + 1e-12) + +def get_pyramid_and_level_index_for_LAFs(dLAFs, sigmas, pix_dists, PS): + scales = get_LAFs_scales(dLAFs); + needed_sigmas = scales / PS; + sigmas_full_list = [] + level_idxs_full = [] + oct_idxs_full = [] + for oct_idx in range(len(sigmas)): + sigmas_full_list = sigmas_full_list + list(np.array(sigmas[oct_idx])*np.array(pix_dists[oct_idx])) + oct_idxs_full = oct_idxs_full + [oct_idx]*len(sigmas[oct_idx]) + level_idxs_full = level_idxs_full + range(0,len(sigmas[oct_idx])) + oct_idxs_full = torch.LongTensor(oct_idxs_full) + level_idxs_full = torch.LongTensor(level_idxs_full) + + closest_imgs = cdist(np.array(sigmas_full_list).reshape(-1,1), needed_sigmas.data.cpu().numpy().reshape(-1,1)).argmin(axis = 0) + closest_imgs = torch.from_numpy(closest_imgs) + if dLAFs.is_cuda: + closest_imgs = closest_imgs.cuda() + oct_idxs_full = oct_idxs_full.cuda() + level_idxs_full = level_idxs_full.cuda() + return Variable(oct_idxs_full[closest_imgs]), Variable(level_idxs_full[closest_imgs]) diff --git a/examples/just_shape/Utils.py b/examples/just_shape/Utils.py new file mode 100644 index 0000000..c56b5e8 --- /dev/null +++ b/examples/just_shape/Utils.py @@ -0,0 +1,182 @@ +import torch +import torch.nn.init +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import cv2 +import numpy as np + +# resize image to size 32x32 +cv2_scale = lambda x: cv2.resize(x, dsize=(32, 32), + interpolation=cv2.INTER_LINEAR) +# reshape image +np_reshape32 = lambda x: np.reshape(x, (32, 32, 1)) +np_reshape64 = lambda x: np.reshape(x, (64, 64, 1)) + +def zeros_like(x): + assert 
x.__class__.__name__.find('Variable') != -1 or x.__class__.__name__.find('Tensor') != -1, "Object is neither a Tensor nor a Variable" + y = torch.zeros(x.size()) + if x.is_cuda: + y = y.cuda() + if x.__class__.__name__ == 'Variable': + return torch.autograd.Variable(y, requires_grad=x.requires_grad) + elif x.__class__.__name__.find('Tensor') != -1: + return y + +def ones_like(x): + assert x.__class__.__name__.find('Variable') != -1 or x.__class__.__name__.find('Tensor') != -1, "Object is neither a Tensor nor a Variable" + y = torch.ones(x.size()) + if x.is_cuda: + y = y.cuda() + if x.__class__.__name__ == 'Variable': + return torch.autograd.Variable(y, requires_grad=x.requires_grad) + elif x.__class__.__name__.find('Tensor') != -1: + return y + + +def batched_forward(model, data, batch_size, **kwargs): + n_patches = len(data) + if n_patches > batch_size: + bs = batch_size + n_batches = n_patches / bs + 1 + for batch_idx in range(n_batches): + st = batch_idx * bs + if batch_idx == n_batches - 1: + if (batch_idx + 1) * bs > n_patches: + end = n_patches + else: + end = (batch_idx + 1) * bs + else: + end = (batch_idx + 1) * bs + if st >= end: + continue + if batch_idx == 0: + first_batch_out = model(data[st:end], **kwargs) + out_size = torch.Size([n_patches] + list(first_batch_out.size()[1:])) + #out_size[0] = n_patches + out = torch.zeros(out_size); + if data.is_cuda: + out = out.cuda() + out = Variable(out) + out[st:end] = first_batch_out + else: + out[st:end,:,:] = model(data[st:end], **kwargs) + return out + else: + return model(data, **kwargs) + +class L2Norm(nn.Module): + def __init__(self): + super(L2Norm,self).__init__() + self.eps = 1e-10 + def forward(self, x): + norm = torch.sqrt(torch.sum(x * x, dim = 1) + self.eps) + x = x / norm.unsqueeze(-1).expand_as(x) + return x + +class L1Norm(nn.Module): + def __init__(self): + super(L1Norm,self).__init__() + self.eps = 1e-10 + def forward(self, x): + norm = torch.sum(torch.abs(x), dim = 1) + self.eps + x = x / norm.unsqueeze(-1).expand_as(x) + return x + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + +def CircularGaussKernel(kernlen=None, circ_zeros = False, sigma = None, norm = True): + assert ((kernlen is not None) or sigma is not None) + if kernlen is None: + kernlen = int(2.0 * 3.0 * sigma + 1.0) + if (kernlen % 2 == 0): + kernlen = kernlen + 1; + halfSize = kernlen / 2; + r2 = float(halfSize*halfSize) + if sigma is None: + sigma2 = 0.9 * r2; + sigma = np.sqrt(sigma2) + else: + sigma2 = 2.0 * sigma * sigma + x = np.linspace(-halfSize,halfSize,kernlen) + xv, yv = np.meshgrid(x, x, sparse=False, indexing='xy') + distsq = (xv)**2 + (yv)**2 + kernel = np.exp(-( distsq/ (sigma2))) + if circ_zeros: + kernel *= (distsq <= r2).astype(np.float32) + if norm: + kernel /= np.sum(kernel) + return kernel + +def generate_2dgrid(h,w, centered = True): + if centered: + x = torch.linspace(-w/2+1, w/2, w) + y = torch.linspace(-h/2+1, h/2, h) + else: + x = torch.linspace(0, w-1, w) + y = torch.linspace(0, h-1, h) + grid2d = torch.stack([y.repeat(w,1).t().contiguous().view(-1), x.repeat(h)],1) + return grid2d + +def generate_3dgrid(d, h, w, centered = True): + if type(d) is not list: + if centered: + z = torch.linspace(-d/2+1, d/2, d) + else: + z = torch.linspace(0, d-1, d) + dl = d + else: + z = torch.FloatTensor(d) + dl = len(d) + grid2d = generate_2dgrid(h,w, centered = centered) + grid3d = 
torch.cat([z.repeat(w*h,1).t().contiguous().view(-1,1), grid2d.repeat(dl,1)],dim = 1) + return grid3d + +def zero_response_at_border(x, b): + if (b < x.size(3)) and (b < x.size(2)): + x[:, :, 0:b, :] = 0 + x[:, :, x.size(2) - b: , :] = 0 + x[:, :, :, 0:b] = 0 + x[:, :, :, x.size(3) - b: ] = 0 + else: + return x * 0 + return x + +class GaussianBlur(nn.Module): + def __init__(self, sigma=1.6): + super(GaussianBlur, self).__init__() + weight = self.calculate_weights(sigma) + self.register_buffer('buf', weight) + return + def calculate_weights(self, sigma): + kernel = CircularGaussKernel(sigma = sigma, circ_zeros = False) + h,w = kernel.shape + halfSize = float(h) / 2.; + self.pad = int(np.floor(halfSize)) + return torch.from_numpy(kernel.astype(np.float32)).view(1,1,h,w); + def forward(self, x): + w = Variable(self.buf) + if x.is_cuda: + w = w.cuda() + return F.conv2d(F.pad(x, (self.pad,self.pad,self.pad,self.pad), 'replicate'), w, padding = 0) + +def batch_eig2x2(A): + trace = A[:,0,0] + A[:,1,1] + delta1 = (trace*trace - 4 * ( A[:,0,0]* A[:,1,1] - A[:,1,0]* A[:,0,1])) + mask = delta1 > 0 + delta = torch.sqrt(torch.abs(delta1)) + l1 = mask.float() * (trace + delta) / 2.0 + 1000. * (1.0 - mask.float()) + l2 = mask.float() * (trace - delta) / 2.0 + 0.0001 * (1.0 - mask.float()) + return l1,l2 + +def line_prepender(filename, line): + with open(filename, 'r+') as f: + content = f.read() + f.seek(0, 0) + f.write(line.rstrip('\r\n') + '\n' + content) + return diff --git a/examples/just_shape/architectures.py b/examples/just_shape/architectures.py new file mode 100644 index 0000000..ade3aca --- /dev/null +++ b/examples/just_shape/architectures.py @@ -0,0 +1,190 @@ +from __future__ import division, print_function +import os +import errno +import numpy as np +import sys +from copy import deepcopy +import math +import torch +import torch.nn.init +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision.transforms as transforms +from torch.autograd import Variable +from Utils import L2Norm, generate_2dgrid +from Utils import str2bool +from LAF import denormalizeLAFs, LAFs2ell, abc2A, extract_patches,normalizeLAFs, get_rotation_matrix +from LAF import get_LAFs_scales, get_normalized_affine_shape +from LAF import rectifyAffineTransformationUpIsUp + +class OriNetFast(nn.Module): + def __init__(self, PS = 16): + super(OriNetFast, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Dropout(0.25), + nn.Conv2d(64, 2, kernel_size=int(PS/4), stride=1,padding=1, bias = True), + nn.Tanh(), + nn.AdaptiveAvgPool2d(1) + ) + self.PS = PS + self.features.apply(self.weights_init) + self.halfPS = int(PS/4) + return + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - 
mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def weights_init(self,m): + if isinstance(m, nn.Conv2d): + nn.init.orthogonal(m.weight.data, gain=0.9) + try: + nn.init.constant(m.bias.data, 0.01) + except: + pass + return + def forward(self, input, return_rot_matrix = True): + xy = self.features(self.input_norm(input)).view(-1,2) + angle = torch.atan2(xy[:,0] + 1e-8, xy[:,1]+1e-8); + if return_rot_matrix: + return get_rotation_matrix(angle) + return angle + +class GHH(nn.Module): + def __init__(self, n_in, n_out, s = 4, m = 4): + super(GHH, self).__init__() + self.n_out = n_out + self.s = s + self.m = m + self.conv = nn.Linear(n_in, n_out * s * m) + d = torch.arange(0, s) + self.deltas = -1.0 * (d % 2 != 0).float() + 1.0 * (d % 2 == 0).float() + self.deltas = Variable(self.deltas) + return + def forward(self,x): + x_feats = self.conv(x.view(x.size(0),-1)).view(x.size(0), self.n_out, self.s, self.m); + max_feats = x_feats.max(dim = 3)[0]; + if x.is_cuda: + self.deltas = self.deltas.cuda() + else: + self.deltas = self.deltas.cpu() + out = (max_feats * self.deltas.view(1,1,-1).expand_as(max_feats)).sum(dim = 2) + return out + +class YiNet(nn.Module): + def __init__(self, PS = 28): + super(YiNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 10, kernel_size=5, padding=0, bias = True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=3, stride=2, padding = 1), + nn.Conv2d(10, 20, kernel_size=5, stride=1, padding=0, bias = True), + nn.ReLU(), + nn.MaxPool2d(kernel_size=4, stride=2, padding = 2), + nn.Conv2d(20, 50, kernel_size=3, stride=1, padding=0, bias = True), + nn.ReLU(), + nn.AdaptiveMaxPool2d(1), + GHH(50, 100), + GHH(100, 2) + ) + self.input_mean = 0.427117081207483 + self.input_std = 0.21888339179665006; + self.PS = PS + return + def import_weights(self, dir_name): + self.features[0].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer0_W.npy'))).float() + self.features[0].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer0_b.npy'))).float().view(-1) + self.features[3].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer1_W.npy'))).float() + self.features[3].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer1_b.npy'))).float().view(-1) + self.features[6].weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer2_W.npy'))).float() + self.features[6].bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer2_b.npy'))).float().view(-1) + self.features[9].conv.weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer3_W.npy'))).float().view(50, 1600).contiguous().t().contiguous()#.view(1600, 50, 1, 1).contiguous() + self.features[9].conv.bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer3_b.npy'))).float().view(1600) + self.features[10].conv.weight.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer4_W.npy'))).float().view(100, 32).contiguous().t().contiguous()#.view(32, 100, 1, 1).contiguous() + self.features[10].conv.bias.data = torch.from_numpy(np.load(os.path.join(dir_name, 'layer4_b.npy'))).float().view(32) + self.input_mean = float(np.load(os.path.join(dir_name, 'input_mean.npy'))) + self.input_std = float(np.load(os.path.join(dir_name, 'input_std.npy'))) + return + def input_norm1(self,x): + return (x - self.input_mean) / self.input_std + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - 
mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def forward(self, input, return_rot_matrix = False): + xy = self.features(self.input_norm(input)) + angle = torch.atan2(xy[:,0] + 1e-8, xy[:,1]+1e-8); + if return_rot_matrix: + return get_rotation_matrix(-angle) + return angle + +class AffNetFast(nn.Module): + def __init__(self, PS = 32): + super(AffNetFast, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(16, affine=False), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(32, affine=False), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias = False), + nn.BatchNorm2d(64, affine=False), + nn.ReLU(), + nn.Dropout(0.25), + nn.Conv2d(64, 3, kernel_size=8, stride=1, padding=0, bias = True), + nn.Tanh(), + nn.AdaptiveAvgPool2d(1) + ) + self.PS = PS + self.features.apply(self.weights_init) + self.halfPS = int(PS/2) + return + def input_norm(self,x): + flat = x.view(x.size(0), -1) + mp = torch.mean(flat, dim=1) + sp = torch.std(flat, dim=1) + 1e-7 + return (x - mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) + def weights_init(self,m): + if isinstance(m, nn.Conv2d): + nn.init.orthogonal(m.weight.data, gain=0.8) + try: + nn.init.constant(m.bias.data, 0.01) + except: + pass + return + def forward(self, input, return_A_matrix = False): + xy = self.features(self.input_norm(input)).view(-1,3) + a1 = torch.cat([1.0 + xy[:,0].contiguous().view(-1,1,1), 0 * xy[:,0].contiguous().view(-1,1,1)], dim = 2).contiguous() + a2 = torch.cat([xy[:,1].contiguous().view(-1,1,1), 1.0 + xy[:,2].contiguous().view(-1,1,1)], dim = 2).contiguous() + return rectifyAffineTransformationUpIsUp(torch.cat([a1,a2], dim = 1).contiguous()) + diff --git a/examples/just_shape/detect_affine_shape.py b/examples/just_shape/detect_affine_shape.py new file mode 100644 index 0000000..655b7d1 --- /dev/null +++ b/examples/just_shape/detect_affine_shape.py @@ -0,0 +1,70 @@ +#!/usr/bin/python2 -utt +# -*- coding: utf-8 -*- +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +import time +import os +import cv2 +import math +import numpy as np +from architectures import AffNetFast +PS = 32 +USE_CUDA = False + + +model = AffNetFast(PS = PS) +weightd_fname = '../../pretrained/AffNet.pth' + +checkpoint = torch.load(weightd_fname) +model.load_state_dict(checkpoint['state_dict']) + +model.eval() +if USE_CUDA: + model.cuda() + +try: + input_img_fname = sys.argv[1] + output_fname = sys.argv[2] +except: + print "Wrong input format. Try ./detect_affine_shape.py imgs/ref.png out.txt" + sys.exit(1) + +image = cv2.imread(input_img_fname,0) +h,w = image.shape + +n_patches = h/w + +descriptors_for_net = np.zeros((n_patches, 4)) + +patches = np.ndarray((n_patches, 1, PS, PS), dtype=np.float32) +for i in range(n_patches): + patch = image[i*(w): (i+1)*(w), 0:w] + patches[i,0,:,:] = cv2.resize(patch,(PS,PS)) / 255. 
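+# NOTE: the loop below is a plain mini-batched inference sketch (bs = 128); Utils.batched_forward in this example folder wraps the same start/end index bookkeeping generically.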
+bs = 128; +outs = [] +n_batches = n_patches / bs + 1 +t = time.time() +for batch_idx in range(n_batches): + if batch_idx == n_batches - 1: + if (batch_idx + 1) * bs > n_patches: + end = n_patches + else: + end = (batch_idx + 1) * bs + else: + end = (batch_idx + 1) * bs + if batch_idx * bs >= end: + continue + data_a = patches[batch_idx * bs: end, :, :, :].astype(np.float32) + data_a = torch.from_numpy(data_a) + if USE_CUDA: + data_a = data_a.cuda() + data_a = Variable(data_a, volatile=True) + # compute output + out_a = model(data_a) + descriptors_for_net[batch_idx * bs: end,:] = out_a.data.cpu().numpy().reshape(-1, 4) +et = time.time() - t +np.savetxt(output_fname, descriptors_for_net, delimiter=' ', fmt='%10.5f') diff --git a/examples/just_shape/img/face.png b/examples/just_shape/img/face.png new file mode 100644 index 0000000..8da71ca Binary files /dev/null and b/examples/just_shape/img/face.png differ diff --git a/gen_ds.py b/gen_ds.py new file mode 100644 index 0000000..c31c6a8 --- /dev/null +++ b/gen_ds.py @@ -0,0 +1,86 @@ + +import os +import errno +import numpy as np +from PIL import Image +import torchvision.datasets as dset + +import sys +from copy import deepcopy +import argparse +import math +import torch.utils.data as data +import torch +import torch.nn.init +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +import torchvision.transforms as transforms +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +from tqdm import tqdm +import random +import cv2 +import copy +from Utils import str2bool + +from dataset import TripletPhotoTour +root='dataset' +train_loader = torch.utils.data.DataLoader( + TripletPhotoTour(train=True, + batch_size=128, + root=root, + name='notredame', + download=True, + transform=None), + batch_size=128, + shuffle=False) + +train_loader = torch.utils.data.DataLoader( + TripletPhotoTour(train=True, + batch_size=128, + root=root, + name='yosemite', + download=True, + transform=None), + batch_size=128, + shuffle=False) + +train_loader = torch.utils.data.DataLoader( + TripletPhotoTour(train=True, + batch_size=128, + root=root, + name='liberty', + download=True, + transform=None), + batch_size=128, + shuffle=False) +train_loader = torch.utils.data.DataLoader( + TripletPhotoTour(train=True, + batch_size=128, + root=root, + name='notredame_harris', + download=True, + transform=None), + batch_size=128, + shuffle=False) +train_loader = torch.utils.data.DataLoader( + TripletPhotoTour(train=True, + batch_size=128, + root=root, + name='yosemite_harris', + download=True, + transform=None), + batch_size=128, + shuffle=False) +train_loader = torch.utils.data.DataLoader( + TripletPhotoTour(train=True, + batch_size=128, + root=root, + name='liberty_harris', + download=True, + transform=None), + batch_size=128, + shuffle=False) + diff --git a/pretrained/AffNet.pth b/pretrained/AffNet.pth new file mode 100644 index 0000000..bc03ebd Binary files /dev/null and b/pretrained/AffNet.pth differ diff --git a/pretrained/AffNetFast.caffemodel b/pretrained/AffNetFast.caffemodel new file mode 100644 index 0000000..a6d909e Binary files /dev/null and b/pretrained/AffNetFast.caffemodel differ diff --git a/pretrained/AffNetFast.prototxt b/pretrained/AffNetFast.prototxt new file mode 100644 index 0000000..1c38abb --- /dev/null +++ b/pretrained/AffNetFast.prototxt @@ -0,0 +1,365 @@ +name: "HardNet" +layer { + name: "data" + type: "Input" + top: "data" + input_param { shape: { dim: 256 dim: 1 dim: 32 dim: 32 } } +} +layer { + name: 
"data_norm" + type: "MVN" + bottom: "data" + top: "data_norm" +} + + +layer { + name: "conv1" + type: "Convolution" + bottom: "data_norm" + top: "conv1" + convolution_param { + num_output: 16 + kernel_size: 3 + stride: 1 + pad: 1 + bias_term: false + } +} + + layer { + name: "conv1_BN" + type: "BatchNorm" + bottom: "conv1" + top: "conv1_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + batch_norm_param { eps: 1e-5 + use_global_stats: true + moving_average_fraction: 0.9 + } +} + +layer { + name: "relu1" + type: "ReLU" + bottom: "conv1_BN" + top: "relu1" +} + +layer { + name: "conv2" + type: "Convolution" + bottom: "relu1" + top: "conv2" + convolution_param { + num_output: 16 + kernel_size: 3 + stride: 1 + pad: 1 + bias_term: false + } +} + + layer { + name: "conv2_BN" + type: "BatchNorm" + bottom: "conv2" + top: "conv2_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + batch_norm_param { eps: 1e-5 + use_global_stats: true + moving_average_fraction: 0.9 + } +} + +layer { + name: "relu2" + type: "ReLU" + bottom: "conv2_BN" + top: "relu2" +} + + +layer { + name: "conv3" + type: "Convolution" + bottom: "relu2" + top: "conv3" + convolution_param { + num_output: 32 + kernel_size: 3 + stride: 2 + pad: 1 + bias_term: false + } +} + + layer { + name: "conv3_BN" + type: "BatchNorm" + bottom: "conv3" + top: "conv3_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + batch_norm_param { eps: 1e-5 + use_global_stats: true + moving_average_fraction: 0.9 + } +} + +layer { + name: "relu3" + type: "ReLU" + bottom: "conv3_BN" + top: "relu3" + +} + + + +layer { + name: "conv4" + type: "Convolution" + bottom: "relu3" + top: "conv4" + convolution_param { + num_output: 32 + kernel_size: 3 + stride: 1 + pad: 1 + bias_term: false + } +} + + layer { + name: "conv4_BN" + type: "BatchNorm" + bottom: "conv4" + top: "conv4_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + batch_norm_param { eps: 1e-5 + use_global_stats: true + moving_average_fraction: 0.9 + } +} + +layer { + name: "relu4" + type: "ReLU" + bottom: "conv4_BN" + top: "relu4" +} + + +layer { + name: "conv5" + type: "Convolution" + bottom: "relu4" + top: "conv5" + convolution_param { + num_output: 64 + kernel_size: 3 + stride: 2 + pad: 1 + bias_term: false + } +} + + layer { + name: "conv5_BN" + type: "BatchNorm" + bottom: "conv5" + top: "conv5_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + batch_norm_param { eps: 1e-5 + use_global_stats: true + moving_average_fraction: 0.9 + } +} + +layer { + name: "relu5" + type: "ReLU" + bottom: "conv5_BN" + top: "relu5" +} + +layer { + name: "conv6" + type: "Convolution" + bottom: "relu5" + top: "conv6" + convolution_param { + num_output: 64 + kernel_size: 3 + stride: 1 + pad: 1 + bias_term: false + } +} + + layer { + name: "conv6_BN" + type: "BatchNorm" + bottom: "conv6" + top: "conv6_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + batch_norm_param { eps: 1e-5 + use_global_stats: true + moving_average_fraction: 0.9 + } +} + +layer { + name: "relu6" + type: "ReLU" + bottom: "conv6_BN" + top: "relu6" +} + + +layer { + 
name: "conv7" + type: "Convolution" + bottom: "relu6" + top: "conv7" + convolution_param { + num_output: 3 + kernel_size: 8 + stride: 1 + #padding: 1 + } +} + +layer { + name: "relu7" + type: "TanH" + bottom: "conv7" + top: "relu7" +} + +layer{ + name: "finalpool" + type: "Pooling" + bottom: "relu7" + top: "finalpool" + pooling_param { + pool: AVE + global_pooling: true + } +} + +layer { + name: "flatten" + type: "Flatten" + bottom: "finalpool" + top: "flatten" +} + +layer { + name: "Slice" + type: "Slice" + bottom: "flatten" + top: "a11" + top: "a21" + top: "a22" + slice_param { + axis: 1 + slice_point: 1 + slice_point: 2 + } + +} +layer { + name: "PlusOne1" + type: "Power" + bottom: "a11" + top: "a11" + power_param { + shift: 1 + } + } + +layer { + name: "PlusOne2" + type: "Power" + bottom: "a22" + top: "a22" + power_param { + shift: 1 + } + } + + layer { + name: "final" + top: "final" + bottom: "a11" + bottom: "a21" + bottom: "a22" + type: "Concat" + } diff --git a/pytorch_sift.py b/pytorch_sift.py new file mode 100644 index 0000000..a9e39c2 --- /dev/null +++ b/pytorch_sift.py @@ -0,0 +1,94 @@ +import torch +import math +import torch.nn.init +import torch.nn as nn +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +import numpy as np +import torch.nn.functional as F + +class L2Norm(nn.Module): + def __init__(self): + super(L2Norm,self).__init__() + self.eps = 1e-10 + def forward(self, x): + norm = torch.sqrt(torch.abs(torch.sum(x * x, dim = 1)) + self.eps) + x= x / norm.unsqueeze(1).expand_as(x) + return x + +def getPoolingKernel(kernel_size = 25): + step = 1. / float(np.floor( kernel_size / 2.)); + x_coef = np.arange(step/2., 1. ,step) + xc2 = np.hstack([x_coef,[1], x_coef[::-1]]) + kernel = np.outer(xc2.T,xc2) + kernel = np.maximum(0,kernel) + return kernel +def get_bin_weight_kernel_size_and_stride(patch_size, num_spatial_bins): + bin_weight_stride = int(round(2.0 * math.floor(patch_size / 2) / float(num_spatial_bins + 1))) + bin_weight_kernel_size = int(2 * bin_weight_stride - 1); + return bin_weight_kernel_size, bin_weight_stride +class SIFTNet(nn.Module): + def CircularGaussKernel(self,kernlen=21): + halfSize = kernlen / 2; + r2 = float(halfSize*halfSize); + sigma2 = 0.9 * r2; + disq = 0; + kernel = np.zeros((kernlen,kernlen)) + for y in range(kernlen): + for x in range(kernlen): + disq = (y - halfSize)*(y - halfSize) + (x - halfSize)*(x - halfSize); + if disq < r2: + kernel[y,x] = math.exp(-disq / sigma2) + else: + kernel[y,x] = 0. 
+ return kernel + def __init__(self, patch_size = 65, num_ang_bins = 8, num_spatial_bins = 4, clipval = 0.2): + super(SIFTNet, self).__init__() + gk = torch.from_numpy(self.CircularGaussKernel(kernlen=patch_size).astype(np.float32)) + self.bin_weight_kernel_size, self.bin_weight_stride = get_bin_weight_kernel_size_and_stride(patch_size, num_spatial_bins) + self.gk = Variable(gk) + self.num_ang_bins = num_ang_bins + self.num_spatial_bins = num_spatial_bins + self.clipval = clipval + self.gx = nn.Sequential(nn.Conv2d(1, 1, kernel_size=(1,3), bias = False)) + for l in self.gx: + if isinstance(l, nn.Conv2d): + l.weight.data = torch.from_numpy(np.array([[[[-1, 0, 1]]]], dtype=np.float32)) + self.gy = nn.Sequential(nn.Conv2d(1, 1, kernel_size=(3,1), bias = False)) + for l in self.gy: + if isinstance(l, nn.Conv2d): + l.weight.data = torch.from_numpy(np.array([[[[-1], [0], [1]]]], dtype=np.float32)) + self.pk = nn.Sequential(nn.Conv2d(1, 1, kernel_size=(self.bin_weight_kernel_size, self.bin_weight_kernel_size), + stride = (self.bin_weight_stride, self.bin_weight_stride), + bias = False)) + for l in self.pk: + if isinstance(l, nn.Conv2d): + nw = getPoolingKernel(kernel_size = self.bin_weight_kernel_size) + new_weights = np.array(nw.reshape((1, 1, self.bin_weight_kernel_size, self.bin_weight_kernel_size))) + l.weight.data = torch.from_numpy(new_weights.astype(np.float32)) + def forward(self, x): + gx = self.gx(F.pad(x, (1,1,0, 0), 'replicate')) + gy = self.gy(F.pad(x, (0,0, 1,1), 'replicate')) + mag = torch.sqrt(gx **2 + gy **2 + 1e-10) + ori = torch.atan2(gy,gx + 1e-8) + if x.is_cuda: + self.gk = self.gk.cuda() + else: + self.gk = self.gk.cpu() + mag = mag * self.gk.expand_as(mag) + o_big = (ori +2.0 * math.pi )/ (2.0 * math.pi) * float(self.num_ang_bins) + bo0_big = torch.floor(o_big) + wo1_big = o_big - bo0_big + bo0_big = bo0_big % self.num_ang_bins + bo1_big = (bo0_big + 1) % self.num_ang_bins + wo0_big = (1.0 - wo1_big) * mag + wo1_big = wo1_big * mag + ang_bins = [] + for i in range(0, self.num_ang_bins): + ang_bins.append(self.pk((bo0_big == i).float() * wo0_big + (bo1_big == i).float() * wo1_big)) + ang_bins = torch.cat(ang_bins,1) + ang_bins = ang_bins.view(ang_bins.size(0), -1) + ang_bins = L2Norm()(ang_bins) + ang_bins = torch.clamp(ang_bins, 0.,float(self.clipval)) + ang_bins = L2Norm()(ang_bins) + return ang_bins diff --git a/run_me.sh b/run_me.sh new file mode 100755 index 0000000..94f26ab --- /dev/null +++ b/run_me.sh @@ -0,0 +1,9 @@ +#!/bin/bash +mkdir dataset +mkdir dataset/HP_HessianPatches +wget http://cmp.felk.cvut.cz/~mishkdmy/datasets/HPatches_HessianPatches/_test.pt +mv _test.pt dataset/HP_HessianPatches/_test.pt +mkdir dataset/6Brown +python -utt gen_ds.py +mv dataset/*.pt dataset/6Brown +python -utt train_AffNet.py --gpu-id=0 --dataroot=dataset/6Brown --lr=0.005 --n-pairs=10000000 --batch-size=1024 --descriptor=SIFT --epochs=20 --expname=AffNetFast_lr005_10M_20ep_aswap 2>&1 | tee affnet.log & diff --git a/train_AffNet.py b/train_AffNet.py new file mode 100644 index 0000000..9276fad --- /dev/null +++ b/train_AffNet.py @@ -0,0 +1,328 @@ +#from __future__ import division, print_function +import os +import errno +import numpy as np +from PIL import Image + +import sys +from copy import deepcopy +import argparse +import math +import torch.utils.data as data +import torch +import torch.nn.init +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision.datasets as dset + +import torchvision.transforms as transforms +from torch.autograd import Variable
+import torch.backends.cudnn as cudnn +from tqdm import tqdm +import random +import cv2 +import copy +from Utils import L2Norm, cv2_scale +#from Utils import np_reshape64 as np_reshape +np_reshape = lambda x: np.reshape(x, (64, 64, 1)) +from Utils import str2bool +from dataset import HPatchesDM,TripletPhotoTour, TotalDatasetsLoader +cv2_scale40 = lambda x: cv2.resize(x, dsize=(40, 40), + interpolation=cv2.INTER_LINEAR) +from augmentation import get_random_norm_affine_LAFs,get_random_rotation_LAFs, get_random_shifts_LAFs +from LAF import denormalizeLAFs, LAFs2ell, abc2A, extract_patches,normalizeLAFs +from pytorch_sift import SIFTNet +from HardNet import HardNet +from Losses import loss_HardNet +PS = 32 +# Training settings +parser = argparse.ArgumentParser(description='PyTorch AffNet') + +parser.add_argument('--dataroot', type=str, + default='datasets/', + help='path to dataset') +parser.add_argument('--log-dir', default='./logs', + help='folder to output model checkpoints') +parser.add_argument('--num-workers', default= 8, + help='Number of workers to be created') +parser.add_argument('--pin-memory',type=bool, default= True, + help='use pinned (page-locked) memory for CUDA transfers') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('--epochs', type=int, default=10, metavar='E', + help='number of epochs to train (default: 10)') +parser.add_argument('--batch-size', type=int, default=128, metavar='BS', + help='input batch size for training (default: 128)') +parser.add_argument('--test-batch-size', type=int, default=1024, metavar='BST', + help='input batch size for testing (default: 1024)') +parser.add_argument('--n-pairs', type=int, default=500000, metavar='N', + help='how many pairs will be generated from the dataset') +parser.add_argument('--n-test-pairs', type=int, default=50000, metavar='N', + help='how many pairs will be generated from the test dataset') +parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') +parser.add_argument('--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +# Device options +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') +parser.add_argument('--gpu-id', default='0', type=str, + help='id(s) for CUDA_VISIBLE_DEVICES') +parser.add_argument('--expname', default='', type=str, + help='experiment name') +parser.add_argument('--seed', type=int, default=0, metavar='S', + help='random seed (default: 0)') +parser.add_argument('--log-interval', type=int, default=10, metavar='LI', + help='how many batches to wait before logging training status') +parser.add_argument('--descriptor', type=str, + default='pixels', + help='what is minimized. Variants: pixels, SIFT, HardNet')
+parser.add_argument('--merge', type=str, + default='sum', + help='Combination of geom loss and descriptor loss: mul, sum') +parser.add_argument('--geom-loss-coef', type=float, + default=1.0, + help='coef of geom loss (linear if sum, power if mul) (default: 1.0)') +parser.add_argument('--descr-loss-coef', type=float, + default=0.0, + help='coef of descr loss (linear if sum, power if mul) (default: 0.0)') + + +args = parser.parse_args() + + +# set the device to use by setting CUDA_VISIBLE_DEVICES env variable in +# order to prevent any memory allocation on unused GPUs +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id + +args.cuda = not args.no_cuda and torch.cuda.is_available() +if args.cuda: + cudnn.benchmark = True + torch.cuda.manual_seed_all(args.seed) + +# create logging directory +if not os.path.exists(args.log_dir): + os.makedirs(args.log_dir) + +# set random seeds +torch.manual_seed(args.seed) +np.random.seed(args.seed) + +if args.descriptor == 'SIFT': + descriptor = SIFTNet(patch_size=PS) + if not args.no_cuda: + descriptor = descriptor.cuda() +elif args.descriptor == 'HardNet': + descriptor = HardNet() + if not args.no_cuda: + descriptor = descriptor.cuda() + model_weights = 'HardNet++.pth' + hncheckpoint = torch.load(model_weights) + descriptor.load_state_dict(hncheckpoint['state_dict']) + descriptor.train() +else: + descriptor = lambda x: x.view(x.size(0),-1) + +suffix = args.expname + '_6Brown_' + args.merge + '_' + args.descriptor + '_' + str(args.lr) + '_' + str(args.n_pairs) +########################################## +def create_loaders(): + + kwargs = {'num_workers': args.num_workers, 'pin_memory': args.pin_memory} if args.cuda else {} + transform = transforms.Compose([ + transforms.Lambda(np_reshape), + transforms.ToTensor() + ]) + + train_loader = torch.utils.data.DataLoader( + TotalDatasetsLoader(datasets_path = args.dataroot, train=True, + n_triplets = args.n_pairs, + fliprot=True, + batch_size=args.batch_size, + download=True, + transform=transform), + batch_size=args.batch_size, + shuffle=False, **kwargs) + + test_loader = torch.utils.data.DataLoader( + HPatchesDM('dataset/HP_HessianPatches/','', train=False, + n_pairs = args.n_test_pairs, + batch_size=args.test_batch_size, + download=True, + transform=transforms.Compose([])), + batch_size=args.test_batch_size, + shuffle=False, **kwargs) + return train_loader, test_loader + +def extract_and_crop_patches_by_predicted_transform(patches, trans, crop_size = 32): + assert patches.size(0) == trans.size(0) + st = int((patches.size(2) - crop_size) / 2) + fin = st + crop_size + rot_LAFs = Variable(torch.FloatTensor([[0.5, 0, 0.5],[0, 0.5, 0.5]]).unsqueeze(0).repeat(patches.size(0),1,1)); + if patches.is_cuda: + rot_LAFs = rot_LAFs.cuda() + trans = trans.cuda() + rot_LAFs1 = torch.cat([torch.bmm(trans, rot_LAFs[:,0:2,0:2]), rot_LAFs[:,0:2,2:]], dim = 2); + return extract_patches(patches, rot_LAFs1, PS = patches.size(2))[:,:, st:fin, st:fin].contiguous() + + +def train(train_loader, model, optimizer, epoch): + # switch to train mode + model.train() + pbar = tqdm(enumerate(train_loader)) + for batch_idx, data in pbar: + data_a, data_p = data + if args.cuda: + data_a, data_p = data_a.float().cuda(), data_p.float().cuda() + data_a, data_p = Variable(data_a), Variable(data_p) + st = int((data_p.size(2) - model.PS)/2) + fin = st + model.PS + # the maximum simulated tilt of the augmentation grows as training progresses + max_tilt = 3.0 + if epoch > 1: + max_tilt = 4.0 + if epoch > 3: + max_tilt = 4.5 + if epoch > 5: + max_tilt = 4.8 + rot_LAFs_a, inv_rotmat_a = get_random_rotation_LAFs(data_a, math.pi)
+ aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, max_tilt); + aff_LAFs_a[:,0:2,0:2] = torch.bmm(rot_LAFs_a[:,0:2,0:2],aff_LAFs_a[:,0:2,0:2]) + data_a_aff = extract_patches(data_a, aff_LAFs_a, PS = data_a.size(2)) + data_a_aff_crop = data_a_aff[:,:, st:fin, st:fin].contiguous() + aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, max_tilt); + aff_LAFs_p[:,0:2,0:2] = torch.bmm(rot_LAFs_a[:,0:2,0:2],aff_LAFs_p[:,0:2,0:2]) + data_p_aff = extract_patches(data_p, aff_LAFs_p, PS = data_p.size(2)) + data_p_aff_crop = data_p_aff[:,:, st:fin, st:fin].contiguous() + out_a_aff, out_p_aff = model(data_a_aff_crop,True), model(data_p_aff_crop,True) + out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff) + out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff) + ###### Apply the predicted transforms and compute descriptors + out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_aff, out_a_aff, crop_size = model.PS) + out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p_aff, out_p_aff, crop_size = model.PS) + desc_a = descriptor(out_patches_a_crop) + desc_p = descriptor(out_patches_p_crop) + descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0),-1).sum(dim=1) + 1e-6) + descr_loss = loss_HardNet(desc_a,desc_p, anchor_swap = True); + geom_dist = torch.sqrt(((out_a_aff_back - out_p_aff_back)**2 ).view(-1,4).mean(dim=1) + 1e-8) + # NOTE: geom_dist and descr_dist are only monitored below; both merge variants currently minimize just the descriptor loss + if args.merge == 'sum': + loss = descr_loss + elif args.merge == 'mul': + loss = descr_loss + else: + print ('Unknown merge option') + sys.exit(1) + optimizer.zero_grad() + loss.backward() + optimizer.step() + adjust_learning_rate(optimizer) + if batch_idx % 2 == 0: + pbar.set_description( + 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {:.4f}, {:.4f}'.format( + epoch, batch_idx * len(data_a), len(train_loader.dataset), + 100. 
* batch_idx / len(train_loader), + loss.data[0], geom_dist.mean().data[0], descr_dist.mean().data[0])) + torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()}, + '{}/checkpoint_{}.pth'.format(LOG_DIR,epoch)) + +def test(test_loader, model, epoch): + # switch to evaluate mode + model.eval() + geom_distances, desc_distances = [], [] + pbar = tqdm(enumerate(test_loader)) + for batch_idx, data in pbar: + data_a, data_p = data + if args.cuda: + data_a, data_p = data_a.float().cuda(), data_p.float().cuda() + data_a, data_p = Variable(data_a, volatile=True), Variable(data_p, volatile=True) + st = int((data_p.size(2) - model.PS)/2) + fin = st + model.PS + aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, 3.0); + shift_w_a, shift_h_a = get_random_shifts_LAFs(data_a, 3, 3) + aff_LAFs_a[:,0,2] = aff_LAFs_a[:,0,2] + shift_w_a / float(data_a.size(3)) + aff_LAFs_a[:,1,2] = aff_LAFs_a[:,1,2] + shift_h_a / float(data_a.size(2)) + data_a_aff = extract_patches(data_a, aff_LAFs_a, PS = data_a.size(2)) + data_a_aff_crop = data_a_aff[:,:, st:fin, st:fin].contiguous() + aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, 3.0); + shift_w_p, shift_h_p = get_random_shifts_LAFs(data_p, 3, 3) + aff_LAFs_p[:,0,2] = aff_LAFs_p[:,0,2] + shift_w_p / float(data_a.size(3)) + aff_LAFs_p[:,1,2] = aff_LAFs_p[:,1,2] + shift_h_p / float(data_a.size(2)) + data_p_aff = extract_patches(data_p, aff_LAFs_p, PS = data_p.size(2)) + data_p_aff_crop = data_p_aff[:,:, st:fin, st:fin].contiguous() + out_a_aff, out_p_aff = model(data_a_aff_crop,True), model(data_p_aff_crop,True) + out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff) + out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff) + ######Apply rot and get sifts + out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_aff, out_a_aff, crop_size = model.PS) + out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p_aff, out_p_aff, crop_size = model.PS) + desc_a = descriptor(out_patches_a_crop) + desc_p = descriptor(out_patches_p_crop) + descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0),-1).sum(dim=1) + 1e-6) / float(desc_a.size(1)) + geom_dist = torch.sqrt(((out_a_aff_back - out_p_aff_back)**2 ).view(-1,4).mean(dim=1) + 1e-8) + geom_distances.append(geom_dist.mean().data.cpu().numpy().reshape(-1,1)) + desc_distances.append(descr_dist.mean().data.cpu().numpy().reshape(-1,1)) + if batch_idx % args.log_interval == 0: + pbar.set_description(' Test Epoch: {} [{}/{} ({:.0f}%)]'.format( + epoch, batch_idx * len(data_a), len(test_loader.dataset), + 100. * batch_idx / len(test_loader))) + + geom_distances = np.vstack(geom_distances).reshape(-1,1) + desc_distances = np.vstack(desc_distances).reshape(-1,1) + print('\33[91mTest set: Geom MSE: {:.8f}\n\33[0m'.format(geom_distances.mean())) + print('\33[91mTest set: Desc dist: {:.8f}\n\33[0m'.format(desc_distances.mean())) + return + +def adjust_learning_rate(optimizer): + """Updates the learning rate given the learning rate decay. + The routine has been implemented according to the original Lua SGD optimizer + """ + for group in optimizer.param_groups: + if 'step' not in group: + group['step'] = 0. + else: + group['step'] += 1. 
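+ # linear decay: the rate falls from args.lr to zero over (n_pairs / batch_size) * epochs optimizer steps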
+ group['lr'] = args.lr * ( + 1.0 - float(group['step']) * float(args.batch_size) / (args.n_pairs * float(args.epochs))) + return + +def create_optimizer(model, new_lr): + optimizer = optim.SGD(model.parameters(), lr=new_lr, + momentum=0.9, dampening=0.9, + weight_decay=args.wd) + return optimizer + +def main(train_loader, test_loader, model): + # print the experiment configuration + print('\nparsed options:\n{}\n'.format(vars(args))) + if args.cuda: + model.cuda() + optimizer1 = create_optimizer(model, args.lr) + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print('=> loading checkpoint {}'.format(args.resume)) + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + else: + print('=> no checkpoint found at {}'.format(args.resume)) + start = args.start_epoch + end = start + args.epochs + for epoch in range(start, end): + # train for one epoch; per-epoch evaluation is commented out below + train(train_loader, model, optimizer1, epoch) + #test(test_loader, model, epoch) + return 0 +if __name__ == '__main__': + LOG_DIR = args.log_dir + LOG_DIR = os.path.join(args.log_dir,suffix) + if not os.path.isdir(LOG_DIR): + os.makedirs(LOG_DIR) + from architectures import AffNetFast + model = AffNetFast(PS=PS) + train_loader, test_loader = create_loaders() + main(train_loader, test_loader, model) diff --git a/train_ONet.py b/train_ONet.py new file mode 100644 index 0000000..0ed9976 --- /dev/null +++ b/train_ONet.py @@ -0,0 +1,311 @@ +#from __future__ import division, print_function +import os +import errno +import numpy as np +from PIL import Image + +import sys +from copy import deepcopy +import argparse +import math +import torch.utils.data as data +import torch +import torch.nn.init +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + +import torchvision.transforms as transforms +from torch.autograd import Variable +import torch.backends.cudnn as cudnn +from tqdm import tqdm +import random +import cv2 +import copy +from Utils import L2Norm, cv2_scale, generate_2dgrid +from Utils import str2bool +np_reshape = lambda x: np.reshape(x, (64, 64, 1)) + +from dataset import HPatchesDM,TotalDatasetsLoader +cv2_scale16 = lambda x: cv2.resize(x, dsize=(16, 16), + interpolation=cv2.INTER_LINEAR) +from augmentation import get_random_rotation_LAFs, get_random_shifts_LAFs +from LAF import denormalizeLAFs, LAFs2ell, abc2A, extract_patches,normalizeLAFs +from pytorch_sift import SIFTNet +from HardNet import HardNet,HardNetNarELU +from Losses import loss_HardNet +# Training settings +parser = argparse.ArgumentParser(description='PyTorch OriNet') + +parser.add_argument('--dataroot', type=str, + default='datasets/', + help='path to dataset') +parser.add_argument('--log-dir', default='./logs', + help='folder to output model checkpoints') +parser.add_argument('--num-workers', default= 8, + help='Number of workers to be created') +parser.add_argument('--pin-memory',type=bool, default= True, + help='use pinned (page-locked) memory for CUDA transfers') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('--epochs', type=int, default=10, metavar='E', + help='number of epochs to train (default: 10)') +parser.add_argument('--batch-size', type=int, default=128, metavar='BS', + help='input batch size for training (default: 128)')
+parser.add_argument('--test-batch-size', type=int, default=1024, metavar='BST', + help='input batch size for testing (default: 1024)') +parser.add_argument('--n-pairs', type=int, default=500000, metavar='N', + help='how many pairs will be generated from the dataset') +parser.add_argument('--n-test-pairs', type=int, default=500000, metavar='N', + help='how many pairs will be generated from the test dataset') +parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') +parser.add_argument('--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +# Device options +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') +parser.add_argument('--gpu-id', default='0', type=str, + help='id(s) for CUDA_VISIBLE_DEVICES') +parser.add_argument('--seed', type=int, default=0, metavar='S', + help='random seed (default: 0)') +parser.add_argument('--log-interval', type=int, default=10, metavar='LI', + help='how many batches to wait before logging training status') +parser.add_argument('--descriptor', type=str, + default='pixels', + help='what is minimized. Variants: pixels, SIFT, HardNet') +parser.add_argument('--merge', type=str, + default='sum', + help='Combination of geom loss and descriptor loss: mul, sum') +parser.add_argument('--geom-loss-coef', type=float, + default=1.0, + help='coef of geom loss (linear if sum, power if mul) (default: 1.0)') +parser.add_argument('--descr-loss-coef', type=float, + default=0.0, + help='coef of descr loss (linear if sum, power if mul) (default: 0.0)') + + +args = parser.parse_args() + + +# set the device to use by setting CUDA_VISIBLE_DEVICES env variable in +# order to prevent any memory allocation on unused GPUs +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id + +args.cuda = not args.no_cuda and torch.cuda.is_available() +if args.cuda: + cudnn.benchmark = True + torch.cuda.manual_seed_all(args.seed) + +# create logging directory +if not os.path.exists(args.log_dir): + os.makedirs(args.log_dir) + +# set random seeds +torch.manual_seed(args.seed) +np.random.seed(args.seed) + +if args.descriptor == 'SIFT': + descriptor = SIFTNet(patch_size=32) + if not args.no_cuda: + descriptor = descriptor.cuda() +elif args.descriptor == 'HardNet': + descriptor = HardNet() + #descriptor = HardNetNarELU(SIFTNet(patch_size=32)) + if not args.no_cuda: + descriptor = descriptor.cuda() + model_weights = 'HardNet++.pth' + #model_weights = 'HardNetELU_Narr.pth' + hncheckpoint = torch.load(model_weights) + descriptor.load_state_dict(hncheckpoint['state_dict']) + descriptor.train() +else: + descriptor = lambda x: x.view(x.size(0),-1) + +suffix='ONet_' + args.merge + '_' + args.descriptor + '_' + str(args.lr) + '_' + str(args.n_pairs) +########################################## +def create_loaders(): + + kwargs = {'num_workers': args.num_workers, 'pin_memory': args.pin_memory} if args.cuda else {} + transform = transforms.Compose([ + transforms.Lambda(np_reshape), + transforms.ToTensor() + ]) + + train_loader = torch.utils.data.DataLoader( + TotalDatasetsLoader(datasets_path = args.dataroot, train=True, + n_triplets = args.n_pairs, + fliprot=True, + batch_size=args.batch_size, + download=True, + transform=transform), + batch_size=args.batch_size, + shuffle=False, **kwargs) + + test_loader = torch.utils.data.DataLoader( + HPatchesDM('datasets/HPatches_HessianPatches','', train=False, + n_pairs = args.n_test_pairs, + batch_size=args.test_batch_size, + 
download=True, + transform=transforms.Compose([])), + batch_size=args.test_batch_size, + shuffle=False, **kwargs) + return train_loader, test_loader + +def extract_and_crop_patches_by_predicted_transform(patches, trans, crop_size = 32): + assert patches.size(0) == trans.size(0) + st = int((patches.size(2) - crop_size) / 2) + fin = st + crop_size + rot_LAFs = Variable(torch.FloatTensor([[0.5, 0, 0.5],[0, 0.5, 0.5]]).unsqueeze(0).repeat(patches.size(0),1,1)); + if patches.is_cuda: + rot_LAFs = rot_LAFs.cuda() + trans = trans.cuda() + rot_LAFs1 = torch.cat([torch.bmm(trans, rot_LAFs[:,0:2,0:2]), rot_LAFs[:,0:2,2:]], dim = 2); + return extract_patches(patches, rot_LAFs1, PS = patches.size(2))[:,:, st:fin, st:fin].contiguous() + +def train(train_loader, model, optimizer, epoch): + # switch to train mode + model.train() + pbar = tqdm(enumerate(train_loader)) + for batch_idx, data in pbar: + data_a, data_p = data + if args.cuda: + data_a, data_p = data_a.float().cuda(), data_p.float().cuda() + data_a, data_p = Variable(data_a), Variable(data_p) + rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi) + scale = Variable( 0.9 + 0.3* torch.rand(data_a.size(0), 1, 1)); + if args.cuda: + scale = scale.cuda() + rot_LAFs[:,0:2,0:2] = rot_LAFs[:,0:2,0:2] * scale.expand(data_a.size(0),2,2) + shift_w, shift_h = get_random_shifts_LAFs(data_a, 2, 2) + rot_LAFs[:,0,2] = rot_LAFs[:,0,2] + shift_w / float(data_a.size(3)) + rot_LAFs[:,1,2] = rot_LAFs[:,1,2] + shift_h / float(data_a.size(2)) + data_a_rot = extract_patches(data_a, rot_LAFs, PS = data_a.size(2)) + st = int((data_p.size(2) - model.PS)/2) + fin = st + model.PS + + data_p_crop = data_p[:,:, st:fin, st:fin].contiguous() + data_a_rot_crop = data_a_rot[:,:, st:fin, st:fin].contiguous() + out_a_rot, out_p, out_a = model(data_a_rot_crop,True), model(data_p_crop,True), model(data_a[:,:, st:fin, st:fin].contiguous(), True) + out_p_rotatad = torch.bmm(inv_rotmat, out_p) + + ######Apply rot and get sifts + out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_rot, out_a_rot, crop_size = model.PS) + out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p, out_p, crop_size = model.PS) + + desc_a = descriptor(out_patches_a_crop) + desc_p = descriptor(out_patches_p_crop) + loss_hn = loss_HardNet(desc_a,desc_p) + descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0),-1).sum(dim=1) + 1e-6) #/ float(desc_a.size(1)) + + geom_dist = torch.sqrt(((out_a_rot - out_p_rotatad)**2 ).view(-1,4).max(dim=1)[0] + 1e-8) + loss = loss_hn + optimizer.zero_grad() + loss.backward() + optimizer.step() + adjust_learning_rate(optimizer) + if batch_idx % 10 == 0: + pbar.set_description( + 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, {}'.format( + epoch, batch_idx * len(data_a), len(train_loader.dataset), + 100. 
* batch_idx / len(train_loader), + loss.data[0], geom_dist.mean().data[0])) + torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()}, + '{}/checkpoint_{}.pth'.format(LOG_DIR,epoch)) + +def test(test_loader, model, epoch): + # switch to evaluate mode + model.eval() + + geom_distances, desc_distances = [], [] + + pbar = tqdm(enumerate(test_loader)) + for batch_idx, (data_a, data_p) in pbar: + + if args.cuda: + data_a, data_p = data_a.float().cuda(), data_p.float().cuda() + data_a, data_p = Variable(data_a, volatile=True), Variable(data_p, volatile=True) + rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi) + data_a_rot = extract_patches(data_a, rot_LAFs, PS = data_a.size(2)) + st = int((data_p.size(2) - model.PS)/2) + fin = st + model.PS + data_p = data_p[:,:, st:fin, st:fin].contiguous() + data_a_rot = data_a_rot[:,:, st:fin, st:fin].contiguous() + out_a_rot, out_p = model(data_a_rot, True), model(data_p, True) + out_p_rotatad = torch.bmm(inv_rotmat, out_p) + geom_dist = torch.sqrt((out_a_rot - out_p_rotatad)**2 + 1e-12).mean() + out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_rot, out_a_rot, crop_size = model.PS) + out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p, out_p, crop_size = model.PS) + desc_a = descriptor(out_patches_a_crop) + desc_p = descriptor(out_patches_p_crop) + descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0),-1).sum(dim=1) + 1e-6)#/ float(desc_a.size(1)) + descr_dist = descr_dist.mean() + geom_distances.append(geom_dist.data.cpu().numpy().reshape(-1,1)) + desc_distances.append(descr_dist.data.cpu().numpy().reshape(-1,1)) + if batch_idx % args.log_interval == 0: + pbar.set_description(' Test Epoch: {} [{}/{} ({:.0f}%)]'.format( + epoch, batch_idx * len(data_a), len(test_loader.dataset), + 100. * batch_idx / len(test_loader))) + + geom_distances = np.vstack(geom_distances).reshape(-1,1) + desc_distances = np.vstack(desc_distances).reshape(-1,1) + + print('\33[91mTest set: Geom MSE: {:.8f}\n\33[0m'.format(geom_distances.mean())) + print('\33[91mTest set: Desc dist: {:.8f}\n\33[0m'.format(desc_distances.mean())) + return + +def adjust_learning_rate(optimizer): + """Updates the learning rate given the learning rate decay. + The routine has been implemented according to the original Lua SGD optimizer + """ + for group in optimizer.param_groups: + if 'step' not in group: + group['step'] = 0. + else: + group['step'] += 1. 
+ group['lr'] = args.lr * ( + 1.0 - float(group['step']) * float(args.batch_size) / (args.n_pairs * float(args.epochs))) + return + +def create_optimizer(model, new_lr): + optimizer = optim.SGD(model.parameters(), lr=new_lr, + momentum=0.9, dampening=0.9, + weight_decay=args.wd) + return optimizer + + +def main(train_loader, test_loader, model): + # print the experiment configuration + print('\nparsed options:\n{}\n'.format(vars(args))) + if args.cuda: + model.cuda() + optimizer1 = create_optimizer(model, args.lr) + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print('=> loading checkpoint {}'.format(args.resume)) + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + else: + print('=> no checkpoint found at {}'.format(args.resume)) + start = args.start_epoch + end = start + args.epochs + for epoch in range(start, end): + # train for one epoch, then evaluate + train(train_loader, model, optimizer1, epoch) + test(test_loader, model, epoch) + return 0 + +if __name__ == '__main__': + LOG_DIR = args.log_dir + LOG_DIR = os.path.join(args.log_dir,suffix) + if not os.path.isdir(LOG_DIR): + os.makedirs(LOG_DIR) + from architectures import OriNetFast + model = OriNetFast(PS=32) + train_loader, test_loader = create_loaders() + main(train_loader, test_loader, model)
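+# Usage sketch (hypothetical paths, mirroring run_me.sh): +# python -utt train_ONet.py --gpu-id=0 --dataroot=dataset/6Brown --descriptor=HardNet --lr=0.005 +# With --descriptor=HardNet a HardNet++.pth checkpoint is expected in the working directory; +# checkpoints are written to <log-dir>/ONet_<merge>_<descriptor>_<lr>_<n_pairs>/checkpoint_<epoch>.pth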