In [2]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

from itertools import product, cycle
from functools import reduce
from operator import mul

from tqdm.notebook import tqdm

import requests
import shutil

import image_recovery.linalg as irl
import image_recovery.imglib as iri

import glob
import os

import cv2

# Single seed shared by the URL sampling (np.random.seed in the download cell)
# and the random_state arguments of add_random_missing_pixels and lrqmc in the
# experiment cell.
SEED = 17

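## Downloading images

Draw a seeded random sample of `imgs_to_download` URLs from the Google Landmarks CSV (`CSV_PATH`). Each image that is fetched successfully is saved into `SAVE_PATH`. Any image with more pixels than `max_resolution` is downscaled to roughly that pixel count, keeping its aspect ratio. The saved file paths and their source URLs are listed in `SAVED_URLS_FILE`.

**Warning:** this cell first deletes every file in `SAVE_PATH`.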

In [3]:
# Download a seeded random sample of images listed in CSV_PATH into SAVE_PATH,
# downscale oversized ones, and record "filepath, url" lines in SAVED_URLS_FILE.
CSV_PATH = "google_train.csv"
SAVE_PATH = "google_landmarks"
SAVED_URLS_FILE = "google_landmarks.txt"

# Images with more pixels than 640*480 are shrunk to about that many pixels.
max_resolution = (640, 480)
imgs_to_download = 100
# Seconds to wait for a server; without a timeout one stalled host hangs the cell.
REQUEST_TIMEOUT = 10

# Read only the URL column (2nd column). A real CSV parser handles quoted fields,
# unlike a naive split(","). dtype=str + keep_default_na=False keep every row as a
# literal string, so row positions (and hence the seeded sample below) match the file.
urls = (
    pd.read_csv(CSV_PATH, usecols=[1], dtype=str, keep_default_na=False)
    .iloc[:, 0]
    .tolist()
)

# The target directory may not exist yet on a fresh checkout.
os.makedirs(SAVE_PATH, exist_ok=True)

# ================
# CAREFUL! THIS CODE ERASES ALL THE FILES IN DIRECTORY!

for ximg in glob.glob(os.path.join(SAVE_PATH, "*")):
    os.remove(ximg)

# ================

np.random.seed(SEED)

with open(SAVED_URLS_FILE, "w") as fh:
    ix = 0

    # Visit URLs in a seeded random order until enough images were saved; the
    # loop also ends cleanly if the URL list runs out (no StopIteration).
    for idx in np.random.permutation(len(urls)):
        if ix >= imgs_to_download:
            break

        url = urls[idx]
        filepath = None

        try:
            # `with` closes the streamed connection even when we bail out early.
            with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as req:
                if req.status_code != 200:
                    print(f"Failed to download {url}")
                    continue

                if url.endswith("jpg"):
                    filepath = os.path.join(SAVE_PATH, f"img{idx}.jpg")
                else:
                    filepath = os.path.join(SAVE_PATH, f"img{idx}.png")

                with open(filepath, "wb") as fimg:
                    req.raw.decode_content = True
                    shutil.copyfileobj(req.raw, fimg)

            # ================

            img = cv2.imread(filepath)

            if img is None:
                # Payload is not a decodable image: delete it so the experiment
                # cell, which globs SAVE_PATH, never tries to load it.
                os.remove(filepath)
                print(f"Failed to decode {url}")
                continue

            print(f"{ix + 1}. Successfully downloaded {url}")

            if np.prod(img.shape[:2]) > np.prod(max_resolution):
                # Uniform scale factor that brings the pixel count down to max_resolution.
                cft = np.sqrt(np.prod(max_resolution)/np.prod(img.shape[:2]))
                new_shape = (int(img.shape[0]*cft), int(img.shape[1]*cft))

                print(f"Image has resolution {img.shape[0]} x {img.shape[1]}. Shrinked to {new_shape[0]} x {new_shape[1]}")

                new_img = cv2.resize(img, (0, 0), fx=cft, fy=cft)
                cv2.imwrite(filepath, new_img)

            fh.write(", ".join((filepath, url)) + "\n")
            ix += 1

        except Exception as exc:
            # Best effort: skip this URL, but never leave a partial/corrupt file
            # behind and report why it failed. KeyboardInterrupt still propagates.
            if filepath is not None and os.path.isfile(filepath):
                os.remove(filepath)
            print(f"Failed to get URL {url} ({exc!r})")

if ix < imgs_to_download:
    print(f"Only {ix} of {imgs_to_download} images could be downloaded: the URL list is exhausted.")

1. Successfully downloaded https://lh3.googleusercontent.com/-z9i6VbFuOws/TCwxajH0dLI/AAAAAAAABN4/cGdCZCixLD0/s1600/
Image has resolution 1200 x 1600. Shrinked to 480 x 640
2. Successfully downloaded https://lh6.googleusercontent.com/-gpM3ulJoJZc/Rs9drifk4ZI/AAAAAAAAA8Q/fb_NtsWAR90/rj/
3. Successfully downloaded https://lh4.googleusercontent.com/-9YxLxk5LXXM/T6VOsJzKYiI/AAAAAAAACL0/04QSGNkyNUg/s1600/
Image has resolution 1060 x 1600. Shrinked to 451 x 680
4. Successfully downloaded http://lh5.ggpht.com/-oVsFDevOTaE/TKTlaN-cXPI/AAAAAAAAAeg/JlY7BXnTcdY/s1600/
Image has resolution 1200 x 1600. Shrinked to 480 x 640
5. Successfully downloaded https://lh6.googleusercontent.com/-0lm0Elikcn0/TVGnvHSc2hI/AAAAAAAABFU/ta7fNfAcRg0/s1600/
Image has resolution 1064 x 1600. Shrinked to 451 x 679
6. Successfully downloaded http://lh4.ggpht.com/-fd9sMTBcrCY/SWklzQexDlI/AAAAAAAAB38/wJ3J0RHG1Vg/s1600/
Image has resolution 1074 x 1600. Shrinked to 454 x 676
7. Successfully downloaded http://lh3.ggpht.com

Image has resolution 1600 x 1200. Shrinked to 640 x 480
51. Successfully downloaded http://lh6.ggpht.com/-tljd-0M37jo/ReDMb47WdKI/AAAAAAAAAQk/FEQ3SSFdAKs/s1600/
Image has resolution 768 x 1024. Shrinked to 480 x 640
52. Successfully downloaded https://lh3.googleusercontent.com/-g8KoiZ7c6ys/SCiX9E1uy6I/AAAAAAAAAkQ/UBaZ1YIH7m8/rj/
53. Successfully downloaded https://lh3.googleusercontent.com/-RRJ0Lnsv4ac/S9BCqu_QsPI/AAAAAAAAAcE/2wf3Z_YM9u8/s1600/
Image has resolution 1200 x 1600. Shrinked to 480 x 640
54. Successfully downloaded https://lh3.googleusercontent.com/-auqcf9zaJds/UNjfyNJmvSI/AAAAAAABXcw/QYJzPqjh8S4/s1600/
Image has resolution 1359 x 840. Shrinked to 704 x 435
55. Successfully downloaded https://lh6.googleusercontent.com/-5fClTSRJLPs/TA5bDdQG7nI/AAAAAAAAJeo/BSR6ikxH8xk/rj/
Failed to download http://mw2.google.com/mw-panoramio/photos/small/73584599.jpg
56. Successfully downloaded http://lh3.ggpht.com/-5QT-fQwgKLk/TE3yEPOxs1I/AAAAAAAAAG0/jqTlFb4FTdg/s1600/
Image has resolution 1

Image has resolution 1600 x 1200. Shrinked to 640 x 480
100. Successfully downloaded https://lh4.googleusercontent.com/-xgg5QipwMLw/SqeBLEH8wbI/AAAAAAAACZ0/1Lc069FYLA4/s1600/


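## Experimenting

For every downloaded image:

- add random missing pixels, controlled by `q`;
- recover the image with `irl.lrqmc` for every combination of hyperparameters in `configs`;
- append the rank of the recovered factor and the reconstruction error to `RESULTS_FILE`.

The reconstruction error (`norm`) is the square root of the summed squared differences from the original image. Results are written one image at a time, so an interrupted run keeps the images finished so far.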

In [4]:
# Hyperparameter grid: the experiment evaluates every combination of these
# values on each image. "q" drives the pixel masking; the remaining entries are
# forwarded to irl.lrqmc as keyword arguments.
configs = dict(
    q=[0.5],
    reg_coef=np.logspace(-1.0, 1.0, 5),
    rot=[5.0, 10.0, 20.0],
    hard_rank_reduction=[True, False],
    rank_mult=[0.8],
)

In [5]:
def rprod(xs):
    """Return the product of all items in the iterable ``xs``.

    An empty iterable yields 1 (the empty product) instead of raising
    ``TypeError`` as a bare ``reduce`` would.
    """
    return reduce(mul, xs, 1)

In [6]:
# Results for all images are appended to this file one image at a time, so the
# images finished before an interruption are kept.
RESULTS_FILE = "seed17_n100.txt"

# glob() returns files in an arbitrary, filesystem-dependent order; sort so the
# processing order (and the row order of RESULTS_FILE) is reproducible.
paths = sorted(
    glob.glob(os.path.join(SAVE_PATH, "*.png")) + glob.glob(os.path.join(SAVE_PATH, "*.jpg"))
)

# Start from an empty results file: rows are appended per image below.
if os.path.isfile(RESULTS_FILE):
    os.remove(RESULTS_FILE)

# Full hyperparameter grid as a list of {name: value} dicts, in itertools.product
# order (the last key of `configs` varies fastest). Built once, reused per image.
config_grid = [dict(zip(configs.keys(), values)) for values in product(*configs.values())]

for (ix, ximg) in tqdm(enumerate(paths), total=len(paths)):
    img = iri.img2qm(ximg)
    # (q, corrupted image) cache: the masking is recomputed only when q changes.
    tainted = (None, None)
    results = []

    for xcfg in tqdm(config_grid, leave=False):
        cfg = dict(xcfg)  # copy: popping "q" must not mutate the shared grid entry
        xq = cfg.pop("q")

        if xq != tainted[0]:
            tainted = (xq, iri.add_random_missing_pixels(img, q=xq, mode="uniform", random_state=SEED))

        # ================

        # tainted[1] is indexed as (matrix, mask); assumes add_random_missing_pixels
        # returns that pair — TODO confirm against image_recovery.imglib.
        rimg, U, _ = irl.lrqmc(mtr=tainted[1][0], mask=tainted[1][1], init_rank=100,
                               max_iter=100, rel_tol=1e-3, random_state=SEED, progress=False, **cfg)
        results.append({
            **cfg,
            "img": ximg,
            "rank": U.shape[1],  # number of columns of the returned factor U
            # square root of the summed squared differences from the original image
            "norm": np.sqrt(np.power(rimg - img, 2).sum()),
        })

    # ================

    # Write the header only with the first image's rows; later images append data.
    pd.DataFrame(results).to_csv(RESULTS_FILE, mode="a", header=(ix == 0))

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

KeyboardInterrupt: 