In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
import argparse
import matplotlib.pyplot as plt
import sys
sys.path.append('../')
from models.soft_shift_net.soft_shiftnet_model import ShiftNetModel
from options.train_options import TrainOptions 
from models import networks
import numpy as np
import warnings
warnings.simplefilter("ignore")

In [None]:
import util
from util.util import *
from collections import namedtuple 

In [None]:
optClass = namedtuple('Options', ['fineSize'])

In [None]:
opt = optClass(fineSize=256)
opt.fineSize

In [None]:
%time mask = wrapper_gmask(opt)

In [None]:
plt.imshow(np.squeeze(mask))

In [None]:
import skimage
from skimage.transform import resize

In [None]:
masks = []
for _ in range(1000):
    mask = wrapper_gmask(opt).cpu().numpy()
    masks.append(resize(np.squeeze(mask), (64, 64)))
masks = np.array(masks)

In [None]:
#masks = masks.reshape((1000, -1))

In [None]:
mean = np.mean(masks) 
masks[masks >= mean] = 1
masks[masks < mean] = 0

In [None]:
masks.shape

In [None]:
plt.imshow(masks[0])

In [None]:
masks = masks.astype(np.int)
masks = masks.reshape((1000, -1))

In [None]:
class OptimizerMask:
    """Greedily pick a low-overlap subset of masks whose union covers most pixels.

    The selection starts from the pair of distinct masks with the lowest IoU.
    It then repeatedly adds the not-yet-selected mask with the lowest mean IoU
    against the current selection. It stops once the union of the selected
    masks covers at least `stop_criteria` of the pixels, or once every mask
    has been selected.

    Args:
        masks: array of shape (n_masks, n_pixels) of flattened 0/1 masks.
        stop_criteria: fraction of pixels the union must cover to stop.
        mask_shape: 2-D shape each flattened mask is reshaped to for display
            and in `get_masks`. Defaults to a square inferred from n_pixels
            (64x64 for 4096 pixels).
    """

    def __init__(self, masks, stop_criteria=0.85, mask_shape=None):
        self.masks = masks
        self.indexes = []
        self.stop_criteria = stop_criteria
        if mask_shape is None:
            side = int(round(np.sqrt(np.shape(masks)[-1])))
            mask_shape = (side, side)
        self.mask_shape = tuple(mask_shape)

    def get_iou(self):
        """Compute the pairwise IoU matrix of `self.masks` into `self.iou` (shape (n, n))."""
        # Pairwise intersections; the diagonal holds each mask's own area.
        intersection = np.matmul(self.masks, self.masks.T).astype(float)
        areas = np.diag(intersection)
        union = np.add.outer(areas, areas) - intersection
        # Two empty masks give 0/0; define their IoU as 0 rather than NaN (NaN would win argmin).
        self.iou = np.divide(intersection, union,
                             out=np.zeros_like(intersection), where=union > 0)
        self.shape = self.iou.shape

    def _coverage(self):
        """Fraction of pixels set in at least one selected mask."""
        return np.any(self.masks[self.indexes] > 0, axis=0).mean()

    def _is_finished(self):
        """True once the selected masks cover at least `stop_criteria` of the pixels."""
        return self._coverage() >= self.stop_criteria

    def mean(self):
        """Print the mean area of the selected masks, as a fraction of the image."""
        _mean = np.mean(np.sum(self.masks[self.indexes], axis=-1)) / self.masks.shape[-1]
        print(_mean)

    def _get_next_indexes(self):
        """Append the unselected mask with the lowest mean IoU against the selection."""
        mean_iou = np.mean(self.iou[self.indexes], axis=0)
        # Never re-pick a selected mask: it adds no coverage and could loop forever.
        mean_iou[self.indexes] = np.inf
        self.indexes = np.append(self.indexes, np.argmin(mean_iou)).astype(int)

    def _solve(self):
        """Seed with the least-overlapping pair, then grow until coverage or exhaustion."""
        if not hasattr(self, 'iou'):
            self.get_iou()
        n_masks = self.shape[0]
        if n_masks <= 1:
            self.indexes = list(range(n_masks))
            return
        # Exclude the diagonal so a mask is never paired with itself.
        seed = self.iou.copy()
        np.fill_diagonal(seed, np.inf)
        self.indexes = [int(i) for i in np.unravel_index(np.argmin(seed), self.shape)]
        while not self._is_finished() and len(self.indexes) < n_masks:
            self._get_next_indexes()

    def get_masks(self):
        """Return the selected masks followed by the region none of them covers.

        Returns:
            float array of shape (len(indexes) + 1, *mask_shape); the last mask is
            1 exactly where no selected mask is set.
        """
        selected = self.masks[self.indexes]
        # Complement of the union (1 - max); 1 - mean would be fractional wherever masks overlap.
        left = 1 - np.max(selected, axis=0)
        stacked = np.concatenate([selected, left[np.newaxis]]).astype(float)
        return stacked.reshape((-1,) + self.mask_shape)

    def solve(self, plot=True):
        """Run the greedy selection; if `plot`, show the final coverage map once."""
        self._solve()
        if plot:
            covered = np.any(self.masks[self.indexes] > 0, axis=0).astype(int)
            plt.imshow(covered.reshape(self.mask_shape))
        

In [None]:
opti = OptimizerMask(masks)

In [None]:
opti.get_iou()

In [None]:
opti.solve()

In [None]:
opti.mean()

In [None]:
output = opti.get_masks()

In [None]:
output.shape

In [None]:
masks = np.array([resize(mask, (256, 256)) for mask in output])

In [None]:
masks[masks > 0] = 1

In [None]:
masks[-1] = 1 - np.max(masks[:-1], axis=0)

In [None]:
masks[-1]

In [None]:
plt.imshow(masks[-1])

In [None]:
from skimage.morphology import *

In [None]:
masks[-1] = dilation(masks[-1], diamond(5))

In [None]:
plt.imshow(masks[-1])

# DEFINE THE MODEL

In [None]:
# Command-line style options handed to TrainOptions.
# ENTER HERE THE PATH YOU WANT TO USE AS DATAROOT (or set the DAGM_DATAROOT environment variable).
dataroot = os.environ.get('DAGM_DATAROOT', '/mnt/hdd2/AIM/DAGM/Class8')
which_model_netG = 'acc_unet_shift_triple'
add_mask2input = 'True'
model_name = 'soft_shiftnet'  # distinct from `model`, which later holds the network object
dataset_mode = 'aligned_resized'
# Build the argv list directly: formatting one string and splitting on ' '
# would break any dataroot containing a space.
options = [
    '--dataroot', dataroot,
    '--which_model_netG', which_model_netG,
    '--add_mask2input', add_mask2input,
    '--model', model_name,
    '--dataset_mode', dataset_mode,
]

In [None]:
def get_parser(options=None):
    """Create a TrainOptions object and run its `parse` on `options`.

    Args:
        options: argv-style list of strings forwarded to `TrainOptions.parse`
            (how `None` is handled is up to TrainOptions).

    Returns:
        The TrainOptions instance itself; callers read the parsed options from
        its `.opt` attribute.
    """
    train_options = TrainOptions()
    train_options.parse(options=options)
    return train_options

In [None]:
parser = get_parser(options=options)
opt = parser.opt

In [None]:
from models import create_model
from collections import OrderedDict

In [None]:
model = create_model(opt)

In [None]:
os.listdir('/mnt/hdd2/AIM/checkpoints/')

In [None]:
path_weights = '/mnt/hdd2/AIM/checkpoints/17_12_2018_Class8_0/latest_net_G.pth'

In [None]:
weights = torch.load(path_weights)

In [None]:
list(weights.keys())[0]

In [None]:
new_state = OrderedDict()
for k in weights.keys():
    new_k = 'module.'+k
    new_state[new_k] = weights[k]

In [None]:
list(model.netG.state_dict().keys())[0]

In [None]:
model.netG.load_state_dict(new_state, strict=True)

# CREATE DATASET

In [None]:
from data.data_loader import CreateDataLoader
from util.util import tensor2im, hist_match

In [None]:
# Build the project data loader from the parsed options; `dataset` is what the
# reconstruction loop iterates over, one `data` dict (with 'A' / 'A_paths') per step.
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

In [None]:
def mask2tensor(mask, device='cuda'):
    """Wrap a mask as a uint8 tensor with two leading singleton axes.

    An (H, W) mask becomes a (1, 1, H, W) tensor. Values are cast to uint8, so
    the mask is expected to hold 0/1 values (fractional values are truncated).

    Args:
        mask: array-like mask (in this notebook, one 256x256 0/1 float mask).
        device: target device; defaults to 'cuda' like the original hard-coded
            `.cuda()` call. Pass 'cpu' to run without a GPU.
    """
    # torch.as_tensor with an explicit dtype replaces the legacy torch.ByteTensor(ndarray) constructor.
    batched = np.asarray(mask)[np.newaxis, np.newaxis]
    return torch.as_tensor(batched, dtype=torch.uint8).to(device)


In [None]:
def normalize01(x):
    """Return a float copy of `x` shifted and scaled so its values span [0, 1].

    A constant input maps to all zeros instead of dividing by zero.
    """
    x = np.asarray(x, dtype=float)
    x = x - x.min()
    peak = x.max()
    return x / peak if peak > 0 else x


def accumulate_predictions(data, masks):
    """Inpaint `data` once per mask and accumulate the model outputs.

    For each mask, the model prediction (`model.fake_B` through tensor2im) is
    added to `fake_holder` at the pixels where the mask equals 1. The mask
    values are added to `fake_sum`.

    Returns:
        (fake_holder, fake_sum) of shapes (H, W, 3) and (H, W, 1), where (H, W)
        are the spatial dimensions of `masks`.
    """
    height, width = masks.shape[-2:]
    fake_holder = np.zeros((height, width, 3))
    fake_sum = np.zeros((height, width, 1))
    for mask in masks:
        model.set_input_with_mask(data, mask2tensor(mask))
        model.forward()
        fake_B = tensor2im(model.fake_B)
        covered = mask == 1
        # A 2-D boolean index selects whole RGB pixels, so the mask need not be tiled to 3 channels.
        fake_holder[covered] += fake_B[covered]
        fake_sum += mask[..., np.newaxis]
    return fake_holder, fake_sum


# Reconstruct each image as the average of its per-mask inpaintings and record
# the largest absolute pixel error against the normalized input.
# img_A / rec / fake_holder / fake_sum stay global: later cells inspect the last image.
max_error = []
for data in dataset:
    p = data['A_paths']
    img_A = normalize01(tensor2im(data['A']))

    fake_holder, fake_sum = accumulate_predictions(data, masks)
    # Pixels covered by no mask would give 0/0; leave them at 0 instead of NaN.
    rec = fake_holder / np.where(fake_sum > 0, fake_sum, 1)
    rec = hist_match(normalize01(rec), img_A)

    diff = np.abs(img_A - rec)
    max_error.append([p, np.max(diff)])
    print(p, np.max(diff))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = ((img_A, 'input'), (rec, 'reconstruction'), (diff, '|input - reconstruction|'))
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)

In [None]:
# Each max_error row is [paths, error]. The paths entry (a string or a collated
# list) makes np.array(max_error) ragged or string-typed, so pull the errors out directly.
plt.hist([row[1] for row in max_error])
plt.xlabel('max |input - reconstruction|')
plt.ylabel('number of images');

In [None]:
rec = fake_holder/fake_sum.astype(np.float)

In [None]:
rec-=rec.min()
rec/=rec.max()

In [None]:
img_A = img_A.astype(np.float)
img_A-=img_A.min()
img_A/=img_A.max()

In [None]:
plt.imshow(img_A)

In [None]:
plt.imshow(rec)

In [None]:
plt.imshow(np.abs(img_A - rec))

In [None]:
plt.imshow(np.squeeze(fake_sum))

In [None]:
fake_sum.max()