# VIGOR Development Notebook

Author: Ted de Vries Lentsch

First version on March 24, 2022

Last update on April 7, 2022

In [None]:
# ============================== code settings ==============================
# widgets and plots
DO_INTERACTION = False
# run code on local machine
DO_LOCAL = False

# ======================= VIGOR dataset (be careful!) =======================
# load, resize, and save VIGOR dataset
DO_RESIZE = False
# use only the city New York
DO_NEW_YORK_ONLY = True
# use data augmentation for training
DO_DATA_AUGMENTATION = False
# use all New York image pairs
DO_FULL_DATASET = False
# use only the positive satellite images
DO_POS_ONLY = False

# ============================== train settings =============================
# custom train dataset length
CUSTOM_DATASET_LENGTH = 10
# batch size
BATCH_SIZE = 8
# relu activation for first layer
ADD_RELU = False
# sigmoid activation for second layer
ADD_SIGMOID = False

# ============================== train and test =============================
# train the model
DO_TRAIN = False
# test the model
DO_TEST = False

In [None]:
if DO_INTERACTION:
    # Load IPython's `autoreload` extension and set it to mode 2 (interactive runs only).
    # NOTE: these two lines are IPython magics; they only work inside IPython/Jupyter.
    %load_ext autoreload
    %autoreload 2

In [None]:
import os
# Process-wide settings, written before the heavy libraries below are imported.
os.environ.update({
    "CUDA_DEVICE_ORDER":    "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",                                                        # GPU index
    "MKL_NUM_THREADS":      "6",                                                        # num of threads
    "NUMEXPR_NUM_THREADS":  "6",                                                        # num of threads
    "OMP_NUM_THREADS":      "6",                                                        # num of threads
})

## Imports

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
import shutil
import time

# PyTorch
import torch as th
import torch.nn as nn
import torch.utils.data as th_data
import torchvision as th_vision
from torchvision.transforms import functional as F

# widgets
if DO_INTERACTION:
    import ipywidgets

## Determine Device

In [None]:
# use the first CUDA device when PyTorch reports one as available, otherwise the CPU
device = th.device("cuda:0") if th.cuda.is_available() else th.device("cpu")
f"The device is: {device}"                                                              # shown as the cell output in Jupyter

## 1. Visualize Dataset

### 1.1. Explore dataset

In [None]:
# Folder names
if DO_NEW_YORK_ONLY:
    vigor_cities       = ["NewYork"]
else:
    vigor_cities       = ["Chicago","NewYork","SanFrancisco","Seattle"]
vigor_files            = ["pano_label_balanced.txt", "same_area_balanced_train.txt", "same_area_balanced_test.txt"]
streetfolder_name      = "panorama"
satellitefolder_name   = "satellite"
combinationfolder_name = "splits"


# Directories
if DO_LOCAL:
    # the root directory is taken to be two levels above the current working directory
    notebookfolder_dir = os.getcwd()
    root_dir           = os.path.dirname(os.path.dirname(notebookfolder_dir))
else:
    # Replace this raise with `root_dir = "<folder that contains datasets/VIGOR>"`.
    # An explicit raise is used instead of `assert False` because asserts are stripped
    # under `python -O`, which would silently continue with a placeholder directory.
    raise RuntimeError("Change this directory!")
vigor_dir   = os.path.join("datasets", "VIGOR")
dataset_dir = os.path.join(root_dir, vigor_dir)


# Count function
def count_num_img(dataset_dir, city_name, streetfolder_name, satellitefolder_name):
    """ Print how many street and satellite images the folders of one city contain.

    Args:
        dataset_dir (str):          root directory of the VIGOR dataset
        city_name (str):            name of the city folder, e.g. "NewYork"
        streetfolder_name (str):    name of folder with street images
        satellitefolder_name (str): name of folder with satellite images
    """
    street_dir    = os.path.join(dataset_dir, city_name, streetfolder_name)
    satellite_dir = os.path.join(dataset_dir, city_name, satellitefolder_name)

    # only the number of directory entries is needed, so no sorting or copying of the listing
    num_street_img    = len(os.listdir(street_dir))
    num_satellite_img = len(os.listdir(satellite_dir))

    print(f"\n{city_name} has {num_street_img} street and {num_satellite_img} satellite images!")


# Widget: choose a city and print its image counts
if DO_INTERACTION:
    ipywidgets.interact(
        lambda city: count_num_img(
            dataset_dir=dataset_dir,
            city_name=city,
            streetfolder_name=streetfolder_name,
            satellitefolder_name=satellitefolder_name,
        ),
        city=vigor_cities,
    )

### 1.2. Show Image Combinations

In [None]:
# Plot function
def plot_img(dataset_dir, city_name, streetfolder_name, satellitefolder_name, combinationfolder_name, file, idx):
    """ Plot a street image and accompanying satellite images in one figure.

    The street image gets vertical rays at the South/West/North/East directions; each of the
    four satellite images gets a yellow dot at the ground-truth street location plus
    North/East/South/West rays through it. Afterwards, the image names are printed.

    Args:
        dataset_dir (str):            root directory of the VIGOR dataset
        city_name (str):              name of the city folder
        streetfolder_name (str):      name of folder with street images
        satellitefolder_name (str):   name of folder with satellite images
        combinationfolder_name (str): name of folder with combination and splits information
        file (str):                   name of the combination file in `combinationfolder_name/city_name`
        idx (int):                    zero-based line index in the combination file
    Raises:
        IndexError:        if the combination file has no line with index `idx`
        FileNotFoundError: if one of the referenced images cannot be read
    """

    def get_img_names_and_deltas(file_dir, line_idx):
        """ Return [street name, (satellite name, delta_0, delta_1) x 4] parsed from line `line_idx`. """
        with open(file_dir, 'r') as combination_file:
            # stream the file and stop at the requested line (no need to read it completely)
            for cnt, line in enumerate(combination_file):
                if cnt==line_idx:
                    data = line.split(' ')
                    return [data[0]] + [(data[3*sat_idx+1], float(data[3*sat_idx+2]), float(data[3*sat_idx+3]))
                                        for sat_idx in range(4)]
        raise IndexError(f"{file_dir} has no line with index {line_idx}")

    def read_rgb(img_dir):
        """ Read an image with OpenCV and reverse the channel order (BGR -> RGB). """
        img = cv2.imread(img_dir)
        if img is None:                                                                 # cv2.imread does not raise on failure
            raise FileNotFoundError(f"could not read image {img_dir}")
        return img[:,:,::-1]

    # directories
    street_dir = os.path.join(dataset_dir, city_name, streetfolder_name)
    satellite_dir = os.path.join(dataset_dir, city_name, satellitefolder_name)
    file_dir = os.path.join(dataset_dir, combinationfolder_name, city_name, file)

    # data
    data_list = get_img_names_and_deltas(file_dir, idx)
    street_img = read_rgb(os.path.join(street_dir, data_list[0]))
    satellite_img1 = read_rgb(os.path.join(satellite_dir, data_list[1][0]))
    satellite_img2 = read_rgb(os.path.join(satellite_dir, data_list[2][0]))
    satellite_img3 = read_rgb(os.path.join(satellite_dir, data_list[3][0]))
    satellite_img4 = read_rgb(os.path.join(satellite_dir, data_list[4][0]))
    delta1, delta2, delta3, delta4 = data_list[1][1:], data_list[2][1:], data_list[3][1:], data_list[4][1:]
    W, H, A = street_img.shape[1], street_img.shape[0], satellite_img1.shape[0]

    # create plot
    plt.figure(figsize=[15, 11])
    grid = plt.GridSpec(2, 3, wspace=0.2, hspace=0.2)
    ax1 = plt.subplot(grid[0, :2])
    ax2 = plt.subplot(grid[0, 2])
    ax3 = plt.subplot(grid[1, 0])
    ax4 = plt.subplot(grid[1, 1])
    ax5 = plt.subplot(grid[1, 2])

    # plot street view
    ax1.imshow(street_img, extent=(0, W, H, 0), zorder=-10)
    ax1.set_title('Street View', pad=10, fontsize=24)
    ax1.set_xlim(0, W)
    ax1.set_ylim(H, 0)

    # plot satellite view
    axs = [ax2, ax3, ax4, ax5]
    titles = ["Satellite View (Positive)", "Semi-Positive 1", "Semi-Positive 2", "Semi-Positive 3"]
    imgs = [satellite_img1, satellite_img2, satellite_img3, satellite_img4]
    for ax, title, img in zip(axs, titles, imgs):
        ax.imshow(img, extent=(0, A, A, 0), zorder=-10)
        ax.set_title(title, pad=16, fontsize=24)
        ax.set_xlim(0, A)
        ax.set_ylim(A, 0)

    # plot rays
    colors = ['springgreen', 'deepskyblue', 'orange', 'magenta'] # North, East, South, West

    ax1.vlines(x=0.00*W, ymin=0, ymax=H, color=colors[2], linewidth=3, zorder=10) # South
    ax1.vlines(x=0.25*W, ymin=0, ymax=H, color=colors[3], linewidth=3, zorder=10) # West
    ax1.vlines(x=0.50*W, ymin=0, ymax=H, color=colors[0], linewidth=3, zorder=10) # North
    ax1.vlines(x=0.75*W, ymin=0, ymax=H, color=colors[1], linewidth=3, zorder=10) # East
    ax1.vlines(x=1.00*W, ymin=0, ymax=H, color=colors[2], linewidth=3, zorder=10) # South

    deltas = [delta1, delta2, delta3, delta4]
    for ax, delta in zip(axs, deltas):
        xc, yc = A/2-A/640*delta[1], A/2+A/640*delta[0] # see GitHub of VIGOR for this formula
        ax.scatter(xc, yc, s=150, color="yellow", zorder=20) # Center
        ax.vlines(x=xc, ymin=0, ymax=yc, color=colors[0], linewidth=3, zorder=10) # North
        ax.hlines(y=yc, xmin=xc, xmax=A, color=colors[1], linewidth=3, zorder=10) # East
        ax.vlines(x=xc, ymin=yc, ymax=A, color=colors[2], linewidth=3, zorder=10) # South
        ax.hlines(y=yc, xmin=0, xmax=xc, color=colors[3], linewidth=3, zorder=10) # West

    plt.show()

    # print image names
    print("Street view:                      {}".format(data_list[0]))
    print("Satellite view (positive):        {}".format(data_list[1][0]))
    print("Satellite view (semi-positive 1): {}".format(data_list[2][0]))
    print("Satellite view (semi-positive 2): {}".format(data_list[3][0]))
    print("Satellite view (semi-positive 3): {}".format(data_list[4][0]))


# Widget: browse image combinations (city, combination file, line index)
if DO_INTERACTION:
    ipywidgets.interact(
        lambda city, file, idx: plot_img(
            dataset_dir=dataset_dir,
            city_name=city,
            streetfolder_name=streetfolder_name,
            satellitefolder_name=satellitefolder_name,
            combinationfolder_name=combinationfolder_name,
            file=file,
            idx=idx,
        ),
        city=vigor_cities,
        file=vigor_files,
        idx=range(100),
    )

### 1.3. Resize Images

In [None]:
if DO_RESIZE:
    street_height_resized    = 320
    street_width_resized     = 640
    satellite_height_resized = 512
    satellite_width_resized  = 512

    if DO_NEW_YORK_ONLY:
        vigor_cities = ["NewYork"]
        vigor_files  = ["same_area_balanced_train.txt","same_area_balanced_test.txt"]
    else:
        vigor_cities = ["Chicago","NewYork","SanFrancisco","Seattle"]
        vigor_files  = ["pano_label_balanced.txt","same_area_balanced_train.txt","same_area_balanced_test.txt"]

    def _resize_and_save(src_img_dir, dst_img_dir, width, height):
        """ Read one image, resize it to (width, height) with area interpolation, and write it to dst_img_dir. """
        img = cv2.imread(src_img_dir)
        if img is None:                                                                 # cv2.imread does not raise on failure
            raise FileNotFoundError(f"could not read image {src_img_dir}")
        resized_img = cv2.resize(src=img, dsize=(width, height), interpolation=cv2.INTER_AREA)
        cv2.imwrite(filename=dst_img_dir, img=resized_img)

    for city_name in vigor_cities:
        # names already resized for this city; sets give O(1) membership tests
        # (lists made the whole pass quadratic in the number of images)
        street_name_set, satellite_name_set = set(), set()

        # folders with original images
        streetfolder_dir    = os.path.join(dataset_dir, city_name, streetfolder_name)
        satellitefolder_dir = os.path.join(dataset_dir, city_name, satellitefolder_name)

        # folders with resized images
        streetfolder_resized_dir    = streetfolder_dir+"_resized"
        satellitefolder_resized_dir = satellitefolder_dir+"_resized"

        # start from empty output folders (the short sleep presumably lets the removal settle)
        for resized_dir in (streetfolder_resized_dir, satellitefolder_resized_dir):
            if os.path.isdir(resized_dir):
                shutil.rmtree(resized_dir)
                time.sleep(0.1)
        os.makedirs(streetfolder_resized_dir)
        os.makedirs(satellitefolder_resized_dir)

        # read images, resize them, and save them to the new folders (each image only once)
        for file_name in vigor_files:
            file_dir = os.path.join(dataset_dir, combinationfolder_name, city_name, file_name)
            with open(file_dir, 'r') as file:
                for line in file:
                    data = line.split(' ')
                    street_img_name     = data[0]
                    satellite_img_names = [data[1], data[4], data[7], data[10]]

                    # street image
                    if street_img_name not in street_name_set:
                        street_name_set.add(street_img_name)
                        _resize_and_save(os.path.join(streetfolder_dir, street_img_name),
                                         os.path.join(streetfolder_resized_dir, street_img_name),
                                         street_width_resized, street_height_resized)

                    # satellite images
                    for satellite_img_name in satellite_img_names:
                        if satellite_img_name not in satellite_name_set:
                            satellite_name_set.add(satellite_img_name)
                            _resize_and_save(os.path.join(satellitefolder_dir, satellite_img_name),
                                             os.path.join(satellitefolder_resized_dir, satellite_img_name),
                                             satellite_width_resized, satellite_height_resized)

### 1.4. Show Resized Image Combinations

In [None]:
# folder names of the resized images
streetfolder_resized_name    = f"{streetfolder_name}_resized"
satellitefolder_resized_name = f"{satellitefolder_name}_resized"


# cities and files (the "pano_label" file is only used when all cities are selected)
vigor_cities = ["NewYork"] if DO_NEW_YORK_ONLY else ["Chicago","NewYork","SanFrancisco","Seattle"]
vigor_files  = ["same_area_balanced_train.txt","same_area_balanced_test.txt"]
if not DO_NEW_YORK_ONLY:
    vigor_files = ["pano_label_balanced.txt"] + vigor_files


# widget: browse the resized image combinations
if DO_INTERACTION:
    ipywidgets.interact(
        lambda city, file, idx: plot_img(
            dataset_dir=dataset_dir,
            city_name=city,
            streetfolder_name=streetfolder_resized_name,
            satellitefolder_name=satellitefolder_resized_name,
            combinationfolder_name=combinationfolder_name,
            file=file,
            idx=idx,
        ),
        city=vigor_cities,
        file=vigor_files,
        idx=range(100),
    )

## 2. Dataset

In [None]:
class VIGORDataset(th_data.Dataset):
    """ VIGOR dataset class.

    Every sample is a (street image, satellite image) pair together with the ground-truth
    location (h, w) of the street image on the resized satellite image.
    """

    def __init__(self, root, cities, streetfolder_name, satellitefolder_name, combinationfolder_name, file_name,
                 transforms=None, only_pos=False):
        """
        Args:
            root (str):                   directory to dataset
            cities (list):                list with city names (folders inside `root`)
            streetfolder_name (str):      name of folder with street images
            satellitefolder_name (str):   name of folder with satellite images
            combinationfolder_name (str): name of folder with combination and splits information
            file_name (str):              name of (text) document with combination information
            transforms (callable):        optional, called as transforms(imgs, target) -> (imgs, target)
            only_pos (bool):              if True, only the first (positive) satellite image of every line is used
        """
        self.root       = root                                                          # directory to dataset
        self.cities     = cities                                                        # list with cities
        self.transforms = transforms                                                    # list with transforms

        # (resized) images dimensions
        self.H            = 320                                                         # height of street image
        self.W            = 640                                                         # width of street image
        self.A            = 512                                                         # height and width of satellite image
        self.A_original   = 640                                                         # height and width of original satellite image
        self.scale_factor = self.A/self.A_original                                      # new_size/old_size satellite image

        # get combinations
        self.combs      = []                                                            # (street, satellite) combinations
        self.annots     = []                                                            # GT locations (h, w) on resized image
        self.get_combinations(streetfolder_name, satellitefolder_name, combinationfolder_name, file_name, only_pos)

    def get_combinations(self, streetfolder_name, satellitefolder_name, combinationfolder_name, file_name, only_pos):
        """ Parse the combination file of every city and fill `self.combs` and `self.annots`.

        Every line holds a street image name followed by four (satellite name, delta_0, delta_1)
        triples. Deltas are offsets from the satellite image center in pixels of the ORIGINAL
        satellite image; pairs whose GT location falls outside the image are skipped.

        Args:
            streetfolder_name (str):      name of folder with street images
            satellitefolder_name (str):   name of folder with satellite images
            combinationfolder_name (str): name of folder with combination and splits information
            file_name (str):              name of (text) document with combination information
            only_pos (bool):              only consider the first (positive) satellite image of every line
        """
        # the GT location is on the image iff |delta| <= half the ORIGINAL size, because the deltas
        # are in original pixels (comparing against the resized half-size A//2 wrongly dropped
        # valid pairs with A//2 < |delta| <= A_original/2)
        max_offset     = self.A_original/2
        num_satellites = 1 if only_pos else 4

        for city_name in self.cities:
            file_dir = os.path.join(self.root, combinationfolder_name, city_name, file_name)
            with open(file_dir, 'r') as file:
                for line in file:
                    data = line.split()
                    if not data:                                                        # skip blank lines
                        continue
                    street_img_subdir = os.path.join(city_name, streetfolder_name, data[0])
                    for idx in range(num_satellites):
                        satellite_img_subdir = os.path.join(city_name, satellitefolder_name, data[3*idx+1])
                        delta = (float(data[3*idx+2]), float(data[3*idx+3]))
                        if abs(delta[0])<=max_offset and abs(delta[1])<=max_offset:    # check whether GT location is on image
                            self.combs.append((street_img_subdir, satellite_img_subdir))
                            self.annots.append((self.A//2+self.scale_factor*delta[0], self.A//2-self.scale_factor*delta[1]))

    @staticmethod
    def _read_rgb(img_dir):
        """ Read an image with OpenCV and reverse the channel order (BGR -> RGB). """
        img = cv2.imread(img_dir)
        if img is None:                                                                 # cv2.imread does not raise on failure
            raise FileNotFoundError(f"could not read image {img_dir}")
        return img[:,:,::-1]

    def __getitem__(self, idx):
        """
        Args:
            idx (int): index of sample
        Output:
            imgs (list):   list with street and satellite image (numpy arrays, or whatever the transforms return)
            target (dict): dict with index of sample, image subdirs and annotation (the GT location)
        Raises:
            FileNotFoundError: if one of the two images cannot be read
        """
        street_img_dir    = os.path.join(self.root, self.combs[idx][0])
        satellite_img_dir = os.path.join(self.root, self.combs[idx][1])

        street_img    = self._read_rgb(street_img_dir)
        satellite_img = self._read_rgb(satellite_img_dir)
        imgs          = [street_img, satellite_img]

        target                 = {}                                                     # target
        target['image_id']     = th.tensor([idx])                                       # image id (equal to index)
        target['street_id']    = self.combs[idx][0]                                     # subdir of street image
        target['satellite_id'] = self.combs[idx][1]                                     # subdir of satellite image
        target['location']     = th.tensor([*self.annots[idx]])                         # location (h, w)

        if self.transforms is not None:
            imgs, target = self.transforms(imgs, target)

        return imgs, target

    def __len__(self):
        return len(self.combs)

## 3. Transformations

### 3.1. Define transformations

In [None]:
class ToTensor(object):
    """ Turn every image of the sample into a float tensor; the target passes through unchanged. """

    def __call__(self, imgs, target):
        tensors = []
        for img in imgs:
            # copy first: presumably needed for arrays such as the `[:,:,::-1]` views made when loading,
            # whose negative strides cannot be wrapped by a tensor directly
            tensors.append(F.to_tensor(img.copy()).type(th.float))
        return tensors, target

In [None]:
class NormalizeTensor(object):
    """ Normalize every image tensor channel-wise with the PyTorch ImageNet mean and std. """

    def __init__(self):
        # per-channel mean and standard deviation
        self.mean, self.std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    def __call__(self, imgs, target):
        normalized = list(map(lambda img: F.normalize(img, self.mean, self.std), imgs))
        return normalized, target

In [None]:
class RandomRotateSatellite(object):
    """ Rotate the satellite image 0, 90, 180 or 270 degrees (clockwise) randomly.

    The street image is rolled by the matching number of quarters of its width so that its
    viewing directions stay aligned with the rotated satellite image, and the ground-truth
    location (h, w) is rotated together with the satellite image.
    """

    def __init__(self, W, A):
        """
        Args:
            W (int): width of the street image
            A (int): height and width of the satellite image
        """
        self.prob_thresholds = [0.25, 0.50, 0.75, 1.00]                                 # upper bounds for 0, 90, 180, 270 degrees
        self.W               = W
        self.A               = A

    def __call__(self, imgs, target, prob=None):
        """
        Args:
            imgs (list):   [street_img, satellite_img] tensors, channels first
            target (dict): must contain "location", a tensor (h, w) on the satellite image
            prob (float):  optional value in [0, 1] that selects the rotation deterministically;
                           if None, a new random number is drawn on every call
        Output:
            imgs (list):   rotated [street_img, satellite_img]
            target (dict): with updated "location" and target["Rotated"] = 1 if a rotation was applied
        """
        # Draw per call: a default of `prob=random.random()` is evaluated only once, when the
        # method is defined, which gave every sample the same rotation.
        if prob is None:
            prob = random.random()

        # no rotation
        if prob<=self.prob_thresholds[0]:
            return imgs, target

        street_img, satellite_img = imgs
        location                  = target["location"]

        # split the street image into four equally wide parts
        bounds = [int(frac*self.W) for frac in (0.00, 0.25, 0.50, 0.75, 1.00)]
        part1, part2, part3, part4 = [street_img[:,:,bounds[i]:bounds[i+1]] for i in range(4)]

        if prob<=self.prob_thresholds[1]:
            # 90 degrees clockwise
            k, parts = 1, (part4, part1, part2, part3)
            location = th.stack((location[1], self.A-location[0]))
        elif prob<=self.prob_thresholds[2]:
            # 180 degrees
            k, parts = 2, (part3, part4, part1, part2)
            location = th.stack((self.A-location[0], self.A-location[1]))
        else:
            # 270 degrees clockwise
            k, parts = 3, (part2, part3, part4, part1)
            location = th.stack((self.A-location[1], location[0]))

        satellite_img = th.rot90(satellite_img, k, [-1,-2])                             # rotate k*90 degrees clockwise
        street_img    = th.cat(parts, dim=2)                                            # combine all parts

        imgs               = [street_img, satellite_img]
        target["location"] = location
        target["Rotated"]  = 1

        return imgs, target

In [None]:
class RandomHorizontalFlipSatellite(object):
    """ Randomly mirror the satellite image left-right; the street image is mirrored with it. """

    def __init__(self, A):
        self.A              = A                                                         # height and width of satellite image
        self.prob_threshold = 0.5                                                       # flip if random.random() <= this value

    def __call__(self, imgs, target):
        # leave the sample untouched when the draw lands above the threshold
        if random.random() > self.prob_threshold:
            return imgs, target

        street_img, satellite_img = imgs
        h, w = target["location"][0], target["location"][1]

        # a left-right mirror swaps east and west, so the whole street image is mirrored as well
        flipped_imgs = [street_img.flip([2]), satellite_img.flip([2])]

        target["location"]       = th.tensor([h, self.A-w])                             # reflect the width coordinate
        target["HorizontalFlip"] = 1
        return flipped_imgs, target

In [None]:
class RandomVerticalFlipSatellite(object):
    """ Randomly mirror the satellite image top-bottom and rearrange the street image to match. """

    def __init__(self, W, A):
        self.W              = W                                                         # width of street image
        self.A              = A                                                         # height and width of satellite image
        self.prob_threshold = 0.5                                                       # flip if random.random() <= this value

    def __call__(self, imgs, target):
        # leave the sample untouched when the draw lands above the threshold
        if random.random() > self.prob_threshold:
            return imgs, target

        street_img, satellite_img = imgs
        location = target["location"]

        # a top-bottom mirror swaps north and south but keeps east and west; reversing each
        # half of the street image (split at W//2) separately gives that column order
        half          = self.W//2
        halves        = (street_img[:,:,:half], street_img[:,:,half:])
        new_street    = th.cat(tuple(part.flip([2]) for part in halves), dim=2)
        new_satellite = satellite_img.flip([1])

        target["location"]     = th.tensor([self.A-location[0], location[1]])           # reflect the height coordinate
        target["VerticalFlip"] = 1
        return [new_street, new_satellite], target

In [None]:
class Compose(object):
    """ Chain transformations; each one takes and returns the pair (imgs, target). """

    def __init__(self, transforms):
        self.transforms = transforms                                                    # applied in list order

    def __call__(self, imgs, target):
        sample_imgs, sample_target = imgs, target
        for step in self.transforms:
            sample_imgs, sample_target = step(sample_imgs, sample_target)
        return sample_imgs, sample_target

In [None]:
def get_transform(normalize=True, train=False, rotate=False, hflip=False, vflip=False, H=320, W=640, A=512):
    """ Assemble the preprocessing / augmentation steps into one `Compose` callable.

    Args:
        normalize (bool): normalize with the ImageNet mean and std (after conversion to tensors)
        train (bool):     enable the augmentations selected by `rotate`, `hflip` and `vflip`
        rotate (bool):    random 0/90/180/270 degree rotation (only if `train`)
        hflip (bool):     random horizontal flip (only if `train`)
        vflip (bool):     random vertical flip (only if `train`)
        H (int):          street image height (not used by any of the transforms)
        W (int):          street image width
        A (int):          satellite image height and width
    """
    # numpy array to tensor (this scales element values from 0-255 to 0-1)
    pipeline = [ToTensor()]
    if normalize:
        pipeline += [NormalizeTensor()]

    # during training, randomly change the satellite images and the street images accordingly
    if train:
        augmentations = (
            (rotate, lambda: RandomRotateSatellite(W=W, A=A)),
            (hflip,  lambda: RandomHorizontalFlipSatellite(A=A)),
            (vflip,  lambda: RandomVerticalFlipSatellite(W=W, A=A)),
        )
        pipeline += [make() for enabled, make in augmentations if enabled]

    return Compose(pipeline)

### 3.2. Show Transformations

In [None]:
# Plot function
def _imshow_tensor(ax, img, width, height, title):
    """ Show a channels-first image tensor on `ax` using pixel extents and hide the axis. """
    ax.imshow(img.permute(1, 2, 0).numpy(), extent=(0, width, height, 0), zorder=-10)
    ax.set_title(title, pad=10, fontsize=24)
    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)
    ax.axis('off')


def _draw_street_quarters(ax, W, H, colors):
    """ Draw vertical lines at 0, 1/4, 1/2, 3/4 and 1 of the street image width, one color per line. """
    for frac, color in zip((0.00, 0.25, 0.50, 0.75, 1.00), colors):
        ax.vlines(x=frac*W, ymin=0, ymax=H, color=color, linewidth=3, zorder=10)


def _draw_location_cross(ax, xc, yc, A, colors):
    """ Mark (xc, yc) with a yellow dot plus rays to the top, right, bottom and left borders (colors in that order). """
    ax.scatter(xc, yc, s=150, color="yellow", zorder=20)                               # Center
    ax.vlines(x=xc, ymin=0, ymax=yc, color=colors[0], linewidth=3, zorder=10)           # North
    ax.hlines(y=yc, xmin=xc, xmax=A, color=colors[1], linewidth=3, zorder=10)           # East
    ax.vlines(x=xc, ymin=yc, ymax=A, color=colors[2], linewidth=3, zorder=10)           # South
    ax.hlines(y=yc, xmin=0, xmax=xc, color=colors[3], linewidth=3, zorder=10)           # West


def plot_visual_transforms(dataset, idx, Rotate, HorizontalFlip, VerticalFlip):
    """ Plot a sample before (top row) and after (bottom row) the selected transformations.

    Args:
        dataset (VIGORDataset): dataset whose items are ([street, satellite] tensors, target); the images are
                                shown as they are, so presumably it should be built with normalize=False
        idx (int):              index of the sample
        Rotate (str):           one of "0 degrees", "90 degrees", "180 degrees", "270 degrees"
        HorizontalFlip (str):   "Apply" to flip horizontally, anything else to skip
        VerticalFlip (str):     "Apply" to flip vertically, anything else to skip
    """
    # `prob` values that force each rotation of RandomRotateSatellite
    rotate_probs = {"0 degrees": 0.25, "90 degrees": 0.50, "180 degrees": 0.75, "270 degrees": 1.00}
    visual_transforms = [RandomRotateSatellite(W=640, A=512),
                         RandomHorizontalFlipSatellite(A=512),
                         RandomVerticalFlipSatellite(W=640, A=512)]
    colors = ['springgreen', 'deepskyblue', 'orange', 'magenta']                       # North, East, South, West

    # original data
    imgs, target = dataset[idx]
    street_img, satellite_img = imgs
    W, H, A = street_img.shape[2], street_img.shape[1], satellite_img.shape[1]
    xc, yc = target["location"][1], target["location"][0]                              # location is stored as (h, w)

    # create plot
    plt.figure(figsize=[15, 11])
    grid = plt.GridSpec(2, 3, wspace=0.1, hspace=0.1)
    ax1 = plt.subplot(grid[0, :2])
    ax2 = plt.subplot(grid[0, 2])
    ax3 = plt.subplot(grid[1, :2])
    ax4 = plt.subplot(grid[1, 2])

    _imshow_tensor(ax1, street_img, W, H, 'Street View')
    _imshow_tensor(ax2, satellite_img, A, A, 'Satellite View')

    # transformation (thresholds set to 1 so the chosen flips are always applied)
    prob = rotate_probs.get(Rotate)
    if prob is not None:
        imgs, target = visual_transforms[0](imgs, target, prob)
    if HorizontalFlip=="Apply":
        visual_transforms[1].prob_threshold = 1
        imgs, target = visual_transforms[1](imgs, target)
    if VerticalFlip=="Apply":
        visual_transforms[2].prob_threshold = 1
        imgs, target = visual_transforms[2](imgs, target)

    # transformed data
    street_img, satellite_img = imgs
    W, H, A = street_img.shape[2], street_img.shape[1], satellite_img.shape[1]
    xc_transformed, yc_transformed = target["location"][1], target["location"][0]

    _imshow_tensor(ax3, street_img, W, H, 'Transformed Street View')
    _imshow_tensor(ax4, satellite_img, A, A, 'Transformed Satellite View')

    # rays: colored (South, West, North, East, South) on the original street image, white on the transformed one
    _draw_street_quarters(ax1, W, H, [colors[2], colors[3], colors[0], colors[1], colors[2]])
    _draw_location_cross(ax2, xc, yc, A, colors)
    _draw_street_quarters(ax3, W, H, ["white"]*5)
    _draw_location_cross(ax4, xc_transformed, yc_transformed, A, ["white"]*4)

    plt.show()

    
# Folder names of the resized images
streetfolder_resized_name    = "{}_resized".format(streetfolder_name)
satellitefolder_resized_name = "{}_resized".format(satellitefolder_name)


# Dataset for visualization: tensors only, no normalization and no augmentation
dataset = VIGORDataset(
    root=dataset_dir,
    cities=vigor_cities,
    streetfolder_name=streetfolder_resized_name,
    satellitefolder_name=satellitefolder_resized_name,
    combinationfolder_name=combinationfolder_name,
    file_name="same_area_balanced_train.txt",
    transforms=get_transform(normalize=False, train=False, rotate=False, hflip=False, vflip=False,
                             H=320, W=640, A=512),
    only_pos=False,
)


# Widget: choose a sample and the transformations to preview
if DO_INTERACTION:
    ipywidgets.interact(
        lambda idx, Rotate, HFlip, VFlip: plot_visual_transforms(
            dataset=dataset,
            idx=idx,
            Rotate=Rotate,
            HorizontalFlip=HFlip,
            VerticalFlip=VFlip,
        ),
        idx=range(100),
        Rotate=["0 degrees","90 degrees","180 degrees","270 degrees"],
        HFlip=["Don't apply","Apply"],
        VFlip=["Don't apply","Apply"],
    )

## 4. Create datasets

In [None]:
# folders with the resized images (written by the DO_RESIZE cell)
streetfolder_resized_name    = f"{streetfolder_name}_resized"
satellitefolder_resized_name = f"{satellitefolder_name}_resized"

# train split: normalized, with optional rotation / horizontal-flip augmentation
dataset_traindata = VIGORDataset(
    root=dataset_dir,
    cities=["NewYork"],
    streetfolder_name=streetfolder_resized_name,
    satellitefolder_name=satellitefolder_resized_name,
    combinationfolder_name=combinationfolder_name,
    file_name="same_area_balanced_train.txt",
    transforms=get_transform(normalize=True, train=True,
                             rotate=DO_DATA_AUGMENTATION, hflip=DO_DATA_AUGMENTATION, vflip=False,
                             H=320, W=640, A=512),
    only_pos=DO_POS_ONLY,
)

# test split: normalized, never augmented
dataset_testdata = VIGORDataset(
    root=dataset_dir,
    cities=["NewYork"],
    streetfolder_name=streetfolder_resized_name,
    satellitefolder_name=satellitefolder_resized_name,
    combinationfolder_name=combinationfolder_name,
    file_name="same_area_balanced_test.txt",
    transforms=get_transform(normalize=True, train=False,
                             rotate=False, hflip=False, vflip=False,
                             H=320, W=640, A=512),
    only_pos=DO_POS_ONLY,
)

## 5. Models

### 5.1. VIGOR baseline

In [None]:
class SPE_module(nn.Module):
    """ The Spatial-aware Position Embedding (SPE) Module of SAFA.

    Steps:
    - Max-pool the feature maps over the channel axis.
    - Map the flattened spatial map through two linear layers to a position-embedding map P.
    - Return, for every channel, the inner product of that channel's spatial map with P.

    Args:
        map_height: height of the incoming feature maps.
        map_width: width of the incoming feature maps.
        add_relu: apply a ReLU after the first linear layer.
        add_sigmoid: apply a sigmoid after the second linear layer.
    """

    def __init__(self, map_height, map_width, add_relu, add_sigmoid):
        super().__init__()

        # define linear layers for creating position embedding map P
        # (the hidden layer has half as many units as there are spatial positions)
        num_positions = map_height*map_width
        self.linear1 = th.nn.Linear(in_features=num_positions, out_features=num_positions//2, bias=True)
        self.linear2 = th.nn.Linear(in_features=num_positions//2, out_features=num_positions, bias=True)

        # SAFA source code forgot the activation functions; here they are optional
        self.relu        = nn.ReLU() if add_relu else None
        self.sigmoid     = nn.Sigmoid() if add_sigmoid else None
        self.add_relu    = add_relu
        self.add_sigmoid = add_sigmoid

    def forward(self, fmaps):
        """ Embed feature maps shaped (..., C, H, W) into a (..., C) tensor. """
        # max-pooling along channels; nothing below modifies fmaps in place, so no copy is needed
        maxpooled = self.channel_max_pool(fmaps)

        # spatial-aware importance generator
        y = self.linear1(maxpooled.flatten(start_dim=-2, end_dim=-1))
        if self.add_relu:
            y = self.relu(y)
        y = self.linear2(y)
        if self.add_sigmoid:
            y = self.sigmoid(y)

        # insert a singleton channel axis right before the position axis so P broadcasts over channels;
        # unsqueezing at -2 (rather than dim=1) also supports unbatched (C, H, W) input
        P = y.unsqueeze(dim=-2)

        # Frobenius inner product of every channel map with P
        embedding = th.mul(fmaps.flatten(start_dim=-2, end_dim=-1), P).sum(dim=-1)

        return embedding

    def channel_max_pool(self, feature_map):
        """ Return the maximum over the channel axis (dim -3) of a (..., C, H, W) tensor. """
        return th.max(feature_map, dim=-3).values

    
class SAFA_module(nn.Module):
    """ The Spatial-aware Feature Aggregation (SAFA) Module of SAFA.

    Applies `num_spe` independent SPE modules to the same feature maps.
    Their embeddings are concatenated along the last axis and the result is L2-normalized.

    Args:
        map_height: height of the incoming feature maps.
        map_width: width of the incoming feature maps.
        add_relu: forwarded to every SPE module.
        add_sigmoid: forwarded to every SPE module.
        num_spe: number of SPE modules (default 8, the original setting).

    Raises:
        ValueError: if num_spe < 1.
    """

    def __init__(self, map_height, map_width, add_relu, add_sigmoid, num_spe=8):
        super().__init__()

        if num_spe < 1:
            raise ValueError(f"num_spe must be at least 1, got {num_spe}")
        self.num_spe = num_spe

        # define the SPE modules and register them as SPE1, SPE2, ... in this order;
        # this keeps the same parameter names (and initialization order) as the
        # former hand-written SPE1..SPE8 attributes, so existing state dicts still load
        for i in range(1, num_spe + 1):
            setattr(self, f"SPE{i}", SPE_module(map_height, map_width, add_relu, add_sigmoid))

    def forward(self, fmaps):
        # calculate the embedding for every SPE module; the SPE modules only read fmaps,
        # so the same tensor can be passed to each one without copying
        embeddings = [getattr(self, f"SPE{i}")(fmaps) for i in range(1, self.num_spe + 1)]

        # create global descriptor by aggregating the embeddings
        gd = th.cat(embeddings, dim=-1)

        # L2-normalize along the descriptor (concatenation) axis; normalize clamps the
        # norm from below, so an all-zero descriptor gives zeros instead of NaNs
        return th.nn.functional.normalize(gd, p=2.0, dim=-1)


class CVR(nn.Module):
    """ The Cross-View Regression model of VIGOR.

    Architecture:
    - The street-view and satellite images each pass through their own VGG16 feature
      extractor (weights are not shared).
    - A SAFA module per view turns the features into a global descriptor.
    - With do_metric_localization, a two-layer MLP on the concatenated descriptors
      additionally regresses a 2-D offset.

    Args:
        street_img_height, street_img_width: street-view input size in pixels.
        sat_img_height, sat_img_width: satellite input size in pixels.
        stride: downsampling factor of the backbone, used to derive the SAFA map sizes.
        add_relu, add_sigmoid: forwarded to the SAFA modules.
        do_metric_localization: also build and return the offset-regression head.
    """

    def __init__(self, street_img_height, street_img_width, sat_img_height, sat_img_width, stride=16,
                 add_relu=False, add_sigmoid=False, do_metric_localization=False):
        super().__init__()

        def vgg16_trunk():
            # pretrained VGG16 convolutional layers without the last layer of `features`
            return nn.Sequential(*th_vision.models.vgg16(pretrained=True).features[:-1])

        # backbones (construction order is kept: street first, then satellite)
        self.vgg16_street = vgg16_trunk()
        self.vgg16_sat    = vgg16_trunk()

        # safa modules operate on the backbone output grid (input size // stride)
        safa_activations = dict(add_relu=add_relu, add_sigmoid=add_sigmoid)
        self.safa_street = SAFA_module(map_height=street_img_height//stride,
                                       map_width=street_img_width//stride,
                                       **safa_activations)
        self.safa_sat    = SAFA_module(map_height=sat_img_height//stride,
                                       map_width=sat_img_width//stride,
                                       **safa_activations)

        # optional offset-prediction head; 2*4096 presumably = two descriptors of
        # 8 SPE embeddings x 512 VGG16 channels each (NOTE(review): confirm if SAFA changes)
        self.do_metric_localization = do_metric_localization
        if self.do_metric_localization:
            self.linear1 = th.nn.Linear(in_features=2*4096, out_features=512, bias=True)
            self.linear2 = th.nn.Linear(in_features=512, out_features=2, bias=True)
            self.relu    = nn.ReLU()

    def forward(self, street_imgs, sat_imgs):
        """ Return (gds_street, gds_sat) or, with metric localization, (gds_street, gds_sat, delta). """
        gds_street = self.safa_street(self.vgg16_street(street_imgs))
        gds_sat    = self.safa_sat(self.vgg16_sat(sat_imgs))

        if self.do_metric_localization:
            hidden = self.relu(self.linear1(th.cat((gds_street, gds_sat), dim=-1)))
            return gds_street, gds_sat, self.linear2(hidden)

        return gds_street, gds_sat


def get_CVR_model(add_relu=False, add_sigmoid=False):
    """ Build a CVR model for 320x640 street views and 512x512 satellite images.

    The offset-regression head is always enabled.
    """
    street_size    = dict(street_img_height=320, street_img_width=640)
    satellite_size = dict(sat_img_height=512, sat_img_width=512)
    return CVR(**street_size,
               **satellite_size,
               add_relu=add_relu,
               add_sigmoid=add_sigmoid,
               do_metric_localization=True)

## 6. Training, Validation, and Test Functions

In [None]:
def train_model_for_loss(model, criterion, optimizer, data_loader, device, A=512):
    """ Train the model for one epoch on the offset-regression objective.

    Args:
        model: network called as model(street_imgs, sat_imgs) that returns
            (gds_street, gds_sat, delta_preds); delta_preds is compared against targets/A.
        criterion: loss applied to (delta_preds, targets/A); assumed to average over the
            batch (e.g. reduction='mean'), since it is re-weighted by batch size below.
        optimizer: optimizer stepping the model parameters.
        data_loader: iterable of (image_pairs, targets) batches. image_pairs is a sequence of
            (street_img, sat_img) tensors; targets is a sequence of dicts with a "location" tensor.
        device: device the images and targets are moved to.
        A: scale between normalized predictions and pixel targets (default 512,
            presumably the satellite image size — TODO confirm).

    Returns:
        (avg_loss, avg_pxdiff): the loss and the mean Euclidean pixel difference,
        both averaged per sample over the epoch.

    Raises:
        ValueError: if data_loader yields no samples.
    """
    # running sums (weighted by batch size) and sample counter
    total_loss   = 0.0
    total_pxdiff = 0.0
    num_samples  = 0

    # set model to train mode
    model.train()

    for image_pairs, targets in data_loader:
        street_imgs = th.stack([street_img.to(device) for street_img, _ in image_pairs])
        sat_imgs    = th.stack([sat_img.to(device) for _, sat_img in image_pairs])
        targets     = th.stack([target["location"] for target in targets]).to(device)
        batch_size  = targets.shape[0]

        # predict and loss (the network regresses the offset scaled down by A)
        _, _, delta_preds = model(street_imgs, sat_imgs)
        loss = criterion(delta_preds, targets/A)

        # update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate per-sample sums so a smaller last batch is not over-weighted;
        # detach so the metric computation does not extend the autograd graph
        total_loss   += loss.item()*batch_size
        total_pxdiff += th.linalg.norm(targets-A*delta_preds.detach(), ord=2, dim=1).sum().item()
        num_samples  += batch_size

    if num_samples == 0:
        raise ValueError("data_loader did not yield any samples")

    # average loss and pixel difference per sample
    avg_loss   = total_loss/num_samples
    avg_pxdiff = total_pxdiff/num_samples

    return avg_loss, avg_pxdiff

In [None]:
def validate_model_for_loss(model, criterion, data_loader, device, A=512):
    """ Evaluate the offset-regression loss and pixel difference without updating the model.

    Args:
        model: network called as model(street_imgs, sat_imgs) that returns
            (gds_street, gds_sat, delta_preds); delta_preds is compared against targets/A.
        criterion: loss applied to (delta_preds, targets/A); assumed to average over the
            batch (e.g. reduction='mean'), since it is re-weighted by batch size below.
        data_loader: iterable of (image_pairs, targets) batches. image_pairs is a sequence of
            (street_img, sat_img) tensors; targets is a sequence of dicts with a "location" tensor.
        device: device the images and targets are moved to.
        A: scale between normalized predictions and pixel targets (default 512,
            presumably the satellite image size — TODO confirm).

    Returns:
        (avg_loss, avg_pxdiff): the loss and the mean Euclidean pixel difference,
        both averaged per sample.

    Raises:
        ValueError: if data_loader yields no samples.
    """
    # running sums (weighted by batch size) and sample counter
    total_loss   = 0.0
    total_pxdiff = 0.0
    num_samples  = 0

    # set model to evaluation mode
    model.eval()

    # disable gradient calculation
    with th.no_grad():
        for image_pairs, targets in data_loader:
            street_imgs = th.stack([street_img.to(device) for street_img, _ in image_pairs])
            sat_imgs    = th.stack([sat_img.to(device) for _, sat_img in image_pairs])
            targets     = th.stack([target["location"] for target in targets]).to(device)
            batch_size  = targets.shape[0]

            # predict and loss
            _, _, delta_preds = model(street_imgs, sat_imgs)
            loss = criterion(delta_preds, targets/A)

            # accumulate per-sample sums so a smaller last batch is not over-weighted
            total_loss   += loss.item()*batch_size
            total_pxdiff += th.linalg.norm(targets-A*delta_preds, ord=2, dim=1).sum().item()
            num_samples  += batch_size

    if num_samples == 0:
        raise ValueError("data_loader did not yield any samples")

    # average loss and pixel difference per sample
    avg_loss   = total_loss/num_samples
    avg_pxdiff = total_pxdiff/num_samples

    return avg_loss, avg_pxdiff

In [None]:
def test_model_for_pixels(model, data_loader, device):
    A = 512

    # variable to store test result
    test_result = th.empty([0,4]).to(device)
    total_loss   = 0
    total_pxdiff = 0

    # set model to evaluation mode
    model.eval()

    # disable gradient calculation
    with th.no_grad():           
        for image_pairs, targets in data_loader:
            street_imgs = th.stack([street_img.to(device) for street_img, _ in image_pairs])
            sat_imgs = th.stack([sat_img.to(device) for _, sat_img in image_pairs])
            targets = th.stack([target["location"] for target in targets]).to(device)

            # predict
            gds_street, gds_sat, delta_preds = model(street_imgs, sat_imgs)

            # append predictions and targets to test result
            test_result = th.cat((test_result, th.cat((A*delta_preds, targets), dim=1)), dim=0)

    # convert to numpy array
    test_result = test_result.cpu().numpy()
    
    return test_result

## 7. Training

In [None]:
if DO_TRAIN:
    # set seed for reproducibility
    th.manual_seed(1)


    # utility function for data loader
    # convert batch from [(imgs1, target1), (imgs2, target2)] into ((imgs1, imgs2), (target1, target2))
    def collate_fn(batch):
        return tuple(zip(*batch))


    # split the train data into train (first 80%) and validation (last 20%) subsets;
    # the separate test split is used as test dataset
    if DO_FULL_DATASET:
        len_train = len(dataset_traindata)
    else:
        # never request more pairs than the dataset holds (such indices would fail when loaded)
        len_train = min(CUSTOM_DATASET_LENGTH, len(dataset_traindata))
    len_split     = int(0.8*len_train)
    dataset_train = th_data.Subset(dataset_traindata, list(range(0, len_split)))
    dataset_val   = th_data.Subset(dataset_traindata, list(range(len_split, len_train)))
    dataset_test  = dataset_testdata


    # create data loaders
    data_loader_train = th_data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    data_loader_val   = th_data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    data_loader_test  = th_data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)


    # model (all parameters, including the pretrained backbones, are trained)
    model = get_CVR_model(ADD_RELU, ADD_SIGMOID).to(device)
    for param in model.parameters():
        param.requires_grad = True


    # criterion and optimizer
    criterion = nn.MSELoss(reduction='mean')
    params    = [p for p in model.parameters() if p.requires_grad]
    optimizer = th.optim.Adam(params, lr=1e-5, betas=(0.9, 0.999))


    # save directory; the run name encodes the satellite-sample setting and the SPE activations
    # (all four ReLU/sigmoid combinations get a distinct tag)
    split_tag      = "pos" if DO_POS_ONLY else "semipos"
    activation_tag = {(True, True):   "relusigmoid",
                      (True, False):  "relu",
                      (False, True):  "sigmoid",
                      (False, False): "no"}[(bool(ADD_RELU), bool(ADD_SIGMOID))]
    name = f"{split_tag}_{activation_tag}"
    saving_time = time.strftime("%Y-%m-%d-%H-%M")
    base_save_dir = "/scratch/zxia/MSc/Ted/checkpoints/VIGOR_DEVEL_2022-04-07"
    # also creates missing parent directories; no error if the directory already exists
    os.makedirs(base_save_dir, exist_ok=True)
    model_save_dir = os.path.join(base_save_dir, f"model_{name}_{saving_time}.pt")
    loss_save_dir = os.path.join(base_save_dir, f"loss_{name}_{saving_time}.npy")
    pxdiff_save_dir = os.path.join(base_save_dir, f"pxdiff_{name}_{saving_time}.npy")
    print(f"The model will be saved as {model_save_dir}")
    print(f"The losses will be saved as {loss_save_dir}")
    print(f"The pixel differences will be saved as {pxdiff_save_dir}")
    print()


    # training variables
    NUM_EPOCHS = 20

    # early stopping implementation
    best_val_loss = float('inf')
    patience      = 5
    patience_cnt  = 0

    # store and save losses as (train, val) pairs per epoch
    losses, pxdiffs = [], []

    # compute initial loss on validation dataset (no train value yet, hence NaN)
    val_loss, val_pxdiff = validate_model_for_loss(model, criterion, data_loader_val, device)
    losses.append((np.nan, val_loss))
    pxdiffs.append((np.nan, val_pxdiff))
    print(f"Epoch 00 - Val loss: {np.round(val_loss, 3)} | Val pixel difference: {np.round(val_pxdiff, 3)}")

    for epoch in range(NUM_EPOCHS):
        # train for one epoch
        train_loss, train_pxdiff = train_model_for_loss(model, criterion, optimizer, data_loader_train, device)

        # evaluate on validation dataset
        val_loss, val_pxdiff = validate_model_for_loss(model, criterion, data_loader_val, device)

        # store and save losses
        losses.append((train_loss, val_loss))
        pxdiffs.append((train_pxdiff, val_pxdiff))
        np.save(loss_save_dir, np.array(losses))
        np.save(pxdiff_save_dir, np.array(pxdiffs))

        # print losses
        epoch_str = f"{epoch+1}".zfill(2)
        print(f"Epoch {epoch_str} - Train and Val loss: {np.round(train_loss, 3)} ; {np.round(val_loss, 3)} | Train and Val pixel difference: {np.round(train_pxdiff, 3)} ; {np.round(val_pxdiff, 3)}")

        # apply early stopping
        if val_loss<best_val_loss:
            patience_cnt = 0
            best_val_loss = val_loss

            # save model (only checkpoints that improve the validation loss are written)
            th.save(model.state_dict(), model_save_dir)

        else:
            patience_cnt += 1

            if patience_cnt==patience:
                # stop training
                break

    print(f"\nThe files have been saved!")

In [None]:
if DO_TRAIN and DO_TEST:
    # save directory
    test_save_dir = os.path.join(base_save_dir, f"testresult_{name}_{saving_time}.npy")
    print(f"The test result will be saved as {test_save_dir}")
    print()

    # early stopping can leave `model` up to `patience` epochs past the best validation loss;
    # evaluate the checkpoint saved by the training loop (best validation loss) instead
    if os.path.isfile(model_save_dir):
        model.load_state_dict(th.load(model_save_dir, map_location=device))

    # test on test dataset
    test_result = test_model_for_pixels(model, data_loader_test, device)

    # save output
    np.save(test_save_dir, test_result)