In [1]:
import os
import time
import shutil
import pickle

import torch
import torch.nn.functional as F

from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from tensorboard_logger import configure, log_value

from model import RecurrentAttention
from utils import AverageMeter

from torchvision import transforms, utils, models
from PIL import Image

import torch

import utils

from trainer import Trainer
from config import get_config

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import pandas as pd

import json

In [2]:
def load_inputs(impath):
    to_tens = transforms.ToTensor()
    return to_tens(Image.open(impath).convert('RGB')).unsqueeze(0)

In [3]:
m = open("../../pooling/data/migration_data.json",)
mig_data = json.load(m)
m.close()
mig_data = pd.DataFrame.from_dict(mig_data, orient = 'index').reset_index()
mig_data.columns = ['muni_id', 'num_migrants']
q = 2
mig_data['class'] = pd.qcut(mig_data['num_migrants'], q = q, labels = [i for i in range(q)])
mig_data

Unnamed: 0,muni_id,num_migrants,class
0,484001001,42055.0,1
1,484001002,4017.0,1
2,484001003,11992.0,1
3,484001004,762.0,1
4,484001005,7551.0,1
...,...,...,...
2326,484032049,2487.0,1
2327,484032050,2024.0,1
2328,484032051,3084.0,1
2329,484032052,2919.0,1


In [4]:
mig_data['class'].value_counts()

0    1166
1    1165
Name: class, dtype: int64

In [33]:
def get_png_names(directory):
    images = []
    for i in os.listdir(directory):
        try:
            if os.path.isdir(os.path.join(directory, i)):
                new_path = os.path.join(directory, i, "pngs")
                image = os.listdir(new_path)[0]
                images.append(os.path.join(directory, i, "pngs", image))
        except:
            pass
    return images


            

image_names = get_png_names("../../attn/data/MEX/")

y_class, y_mig = [], []

for i in image_names:
        dta = mig_data[mig_data["muni_id"] == i.split("/")[5]]
        if len(dta) != 0:
            y_class.append(dta['class'].values[0])
            y_mig.append(dta['num_migrants'].values[0])

In [34]:
import matplotlib.pyplot as plt
import torchvision

In [35]:
import random

train_num = int(25 * .70)

train_indices = random.sample(range(0, 25), train_num)
val_indices = [i for i in range(0, 25) if i not in train_indices]

In [36]:
import torchvision

batch_size = 1

# train = [(load_inputs(image_paths[i]).squeeze()[:, 0:28, 0:28], ys[i]) for i in range(0, 93)]
# val = [(load_inputs(image_paths[i]).squeeze()[:, 0:28, 0:28], ys[i]) for i in range(93, 133)]

# brighten = torchvision.transforms.functional.adjust_brightness(brightness_factor = 2)


train = [(torchvision.transforms.functional.adjust_brightness(load_inputs(image_names[i]), brightness_factor = 2).squeeze(), y_class[i], y_mig[i]) for i in train_indices]
val = [(torchvision.transforms.functional.adjust_brightness(load_inputs(image_names[i]), brightness_factor = 2).squeeze(), y_class[i], y_mig[i]) for i in val_indices]


train_dl = torch.utils.data.DataLoader(train, batch_size = batch_size, shuffle = True)
val_dl = torch.utils.data.DataLoader(val, batch_size = batch_size, shuffle = True)

In [37]:
print("Num training: ", len(train_dl))
print("Num validation: ", len(val_dl))

Num training:  17
Num validation:  8


In [38]:
import numpy as np
from utils import plot_images

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

In [39]:
config, unparsed = get_config()

In [40]:
trainer = Trainer(config, (train_dl, val_dl))

In [49]:
checkpoint = torch.load("./ckpt/ram_4_50x50_0.75_model_best.pth.tar")
checkpoint = checkpoint["model_state"]

In [53]:
import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from utils import denormalize, bounding_box

def denormalize(T, coords):
    return 0.5 * ((coords + 1.0) * T)

In [54]:
def exceeds(from_x, to_x, from_y, to_y, H, W):
    """Check whether the extracted patch will exceed
    the boundaries of the image of size `T`.
    """
    if (from_x < 0) or (from_y < 0) or (to_x > H) or (to_y > W):
        return True
    return False


def fix(from_x, to_x, from_y, to_y, H, W, size):

    """
    Check whether the extracted patch will exceed
    the boundaries of the image of size `T`.
    If it will exceed, make a list of the offending reasons and fix them
    """

    offenders = []

    if (from_x < 0):
        offenders.append("negative x")
    if from_y < 0:
        offenders.append("negative y")
    if from_x > H:
        offenders.append("from_x exceeds h")            
    if to_x > H:
        offenders.append("to_x exceeds h")
    if from_y > W:
        offenders.append("from_y exceeds w")
    if to_y > W:
        offenders.append("to_y exceeds w")            


    if ("from_y exceeds w" in offenders) or ("to_y exceeds w" in offenders):
        from_y, to_y = W - size, W

    if ("from_x exceeds h" in offenders) or ("to_x exceeds h" in offenders):
        from_x, to_x = H - size, H     

    elif ("negative x" in offenders):
        from_x, to_x = 0, 0 + size

    elif ("negative y" in offenders):
        from_y, to_y = 0, 0 + size            

    return from_x, to_x, from_y, to_y

In [58]:
locations_dict = {}

for i in image_names[0:5]:
    
    print(i)
    
    muni_id = i.split("/")[5]
    image = load_inputs(i)
    locations = trainer.extract_locations(image, checkpoint)
    
    start = [denormalize(image.shape[2], l) for l in locations]
    start = torch.cat([start[l].unsqueeze(0) for l in range(4)])
        
    B, C, H, W = image.shape
    
    size = int(min(H, W) / 5)
    
    end = start + size
    
#     print(start, end)
    
    coords_dict = {}
    
    for c in range(0, len(start)):
        
        from_coords = start[c]
        to_coords = end[c]
        
        from_x = from_coords[0][0].item()
        from_y = from_coords[0][1].item()
        
        to_x = to_coords[0][0].item()
        to_y = to_coords[0][1] .item()   
        
        if exceeds(from_x = from_x, to_x = to_x, from_y = from_y, to_y = to_y, H = H, W = W):
        
            from_x, to_x, from_y, to_y = fix(from_x = from_x, to_x = to_x, from_y = from_y, to_y = to_y, H = H, W = W, size = size)
        
            print("yes")
        
        coords_dict['glimpse' + str(c)] = {'from_x': from_x, 'to_x': to_x, 'from_y': from_y, 'to_y': to_y}
    
    locations_dict[i] = coords_dict

../../attn/data/MEX/484001008/pngs/484001008_2010_all_box484001008_MAY.png
yes
yes
yes
../../attn/data/MEX/484004004/pngs/484004004_2010_all_box484004004_MAY.png
yes
yes
yes
../../attn/data/MEX/484004001/pngs/484004001_2010_all_box484004001_MAY.png
yes
yes
yes
../../attn/data/MEX/484005010/pngs/484005010_2010_all_box484005010_MAY.png
yes
yes
yes
../../attn/data/MEX/484001005/pngs/484001005_2010_all_box484001005_MAY.png
yes
yes
yes
yes


In [59]:
locations_dict

{'../../attn/data/MEX/484001008/pngs/484001008_2010_all_box484001008_MAY.png': {'glimpse0': {'from_x': 767,
   'to_x': 881,
   'from_y': 0.0,
   'to_y': 114.0},
  'glimpse1': {'from_x': 678.873291015625,
   'to_x': 792.873291015625,
   'from_y': 81.3102035522461,
   'to_y': 195.31021118164062},
  'glimpse2': {'from_x': 767,
   'to_x': 881,
   'from_y': 109.43223571777344,
   'to_y': 223.43223571777344},
  'glimpse3': {'from_x': 767, 'to_x': 881, 'from_y': 0.0, 'to_y': 114.0}},
 '../../attn/data/MEX/484004004/pngs/484004004_2010_all_box484004004_MAY.png': {'glimpse0': {'from_x': 1056,
   'to_x': 1320,
   'from_y': 0.0,
   'to_y': 264.0},
  'glimpse1': {'from_x': 1056,
   'to_x': 1320,
   'from_y': 346.80194091796875,
   'to_y': 610.8019409179688},
  'glimpse2': {'from_x': 1056, 'to_x': 1320, 'from_y': 0.0, 'to_y': 264.0},
  'glimpse3': {'from_x': 859.859619140625,
   'to_x': 1123.859619140625,
   'from_y': 129.07077026367188,
   'to_y': 393.0707702636719}},
 '../../attn/data/MEX/4840040

In [44]:
# torch.rand(2, 256).flatten(start_dim = 0).unsqueeze(0).shape

In [45]:
with open("sample4.json", "w") as outfile: 
    json.dump(locations_dict, outfile)

In [18]:
# test = torch.rand(2, 256)
# patch_dict = {}

# for patch in range(0, test.shape[0]):
#     cur_vals = list(test[patch].numpy())
#     patch_dict[patch] = [str(i) for i in cur_vals]
    
# with open("sample.json", "w") as outfile: 
#     json.dump(patch_dict, outfile)