In [1]:
import sys

In [2]:
import re
import torch
import itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# On some setups TensorFlow must be allowed to grow its GPU memory on demand.
# By default it reserves (nearly) all device memory up front, which can starve
# PyTorch, which this notebook also runs on the GPU.
import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
# Memory growth is a per-device setting and must be set before the GPUs are
# initialized, so configure every visible GPU, not just the first one.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
physical_devices

In [None]:
# NOTE(review): a relative import (leading '.') fails inside a notebook kernel,
# which has no parent package ("attempted relative import with no known parent
# package"). This probably needs an absolute module path; confirm the layout.
# MISLNet128 is aliased to `MISLNet` here. The PyTorch `MISLNet` class defined
# further down rebinds that name, so this TF1 builder must be used before that
# cell runs.
from .tf1_model_code.mislnet_model import prefeat_CompareNet_v1, MISLNet128 as MISLNet

In [None]:
# Checkpoint path prefix of the pretrained TF1 CNN weights (passed to
# `Saver.restore` below). Set the FSG_TF1_WEIGHTS environment variable to point
# elsewhere instead of editing this machine-specific default.
import os
f_weights_restore = os.environ.get(
    "FSG_TF1_WEIGHTS",
    '/home/tai/1-workdir/3-owen-forensic-graph/models/cam_128/-30',
)

In [None]:
# Build the TF1 graph, restore the pretrained checkpoint and dump every
# MISLNet / CompareNet variable to numpy so it can be copied into PyTorch.
# turn off eager execution (TF1 placeholders require graph mode)
tf.compat.v1.disable_eager_execution()
# reset tf
tf.compat.v1.reset_default_graph()
# PLACE HOLDERS
x = tf.compat.v1.placeholder(tf.float32, shape=[None, 128, 128, 3], name="input_data")
f1 = tf.compat.v1.placeholder(tf.float32, shape=[None, 200], name="feature1")
f2 = tf.compat.v1.placeholder(tf.float32, shape=[None, 200], name="feature2")
MISL_phase = tf.compat.v1.placeholder(tf.bool, name="phase")

mislnet_feats = MISLNet(x, MISL_phase, nprefilt=6)
mislnet_compare = prefeat_CompareNet_v1(f1, f2)

mislnet_restore = tf.compat.v1.train.Saver()

# TF variable name -> {"shape": [...], "value": ndarray}. Insertion order follows
# tf.compat.v1.global_variables(); the key-mapping cell pairs these entries
# with the PyTorch state_dict keys purely by that order.
tf1_var_val = {}

with tf.compat.v1.Session() as sess:
    mislnet_restore.restore(sess, f_weights_restore)  # load pretrained network
    # Named `model_vars` (not `vars`) so the builtin `vars()` is not shadowed
    # for the rest of the kernel session.
    model_vars = [
        var for var in tf.compat.v1.global_variables()
        if 'MISLNet' in var.name or 'CompareNet' in var.name
    ]
    print(model_vars)  # some info about the variables
    for var, val in zip(model_vars, sess.run(model_vars)):
        tf1_var_val[var.name] = {
            "shape": var.shape.as_list(),
            "value": val,
        }
    # No explicit sess.close(): leaving the `with` block closes the session.

In [None]:
from collections import OrderedDict

In [None]:
import torch
import torch.nn.functional as F

In [None]:
class MISLNet(torch.nn.Module):
    """PyTorch port of the TF1 MISLNet feature extractor (128x128 variant).

    Maps an image batch of shape (N, 3, 128, 128) to (N, 200) features. The
    final tanh keeps every value in [-1, 1]. Only 128x128 inputs fit: ``fc1``
    is sized for the 2x2x128 map that this input size produces.

    Do not change the order in which layers are assigned below. The TF1 ->
    PyTorch weight conversion pairs ``state_dict()`` keys with TF1 variables
    purely by position.

    Args:
        num_pre_filters: number of 5x5 kernels in the constrained pre-filter
            convolution (``weights_cstr``), i.e. conv1's input channels.
    """

    def __init__(self, num_pre_filters=6):
        super().__init__()
        # PyTorch's BatchNorm ``momentum`` weights the *new* batch statistic:
        # running = (1 - m) * running + m * batch. That is the reverse of TF's
        # ``momentum``/``decay``. A TF-style momentum of 0.99 (presumably what
        # the original model used) therefore maps to 0.01 here. Using 0.99
        # would make the running stats follow only the latest batch when
        # training. Eval mode is unaffected because it uses the stored stats.
        bn_kwargs = dict(momentum=0.01, eps=0.0001)

        # Constrained pre-filter kernels, shape (num_pre_filters, 3, 5, 5).
        # They are applied with F.conv2d in forward().
        self.weights_cstr = torch.nn.Parameter(torch.randn(num_pre_filters, 3, 5, 5))

        self.conv1 = torch.nn.Conv2d(num_pre_filters, 96, kernel_size=7, stride=2, padding="valid")
        self.bn1 = torch.nn.BatchNorm2d(96, **bn_kwargs)
        self.tanh1 = torch.nn.Tanh()
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)

        self.conv2 = torch.nn.Conv2d(96, 64, kernel_size=5, stride=1, padding="same")
        self.bn2 = torch.nn.BatchNorm2d(64, **bn_kwargs)
        self.tanh2 = torch.nn.Tanh()
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)

        self.conv3 = torch.nn.Conv2d(64, 64, kernel_size=5, stride=1, padding="same")
        self.bn3 = torch.nn.BatchNorm2d(64, **bn_kwargs)
        self.tanh3 = torch.nn.Tanh()
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)

        self.conv4 = torch.nn.Conv2d(64, 128, kernel_size=1, stride=1, padding="same")
        self.bn4 = torch.nn.BatchNorm2d(128, **bn_kwargs)
        self.tanh4 = torch.nn.Tanh()
        self.maxpool4 = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2)

        self.fc1 = torch.nn.Linear(2 * 2 * 128, 200)
        self.tanh_fc1 = torch.nn.Tanh()
        self.fc2 = torch.nn.Linear(200, 200)
        self.tanh_fc2 = torch.nn.Tanh()

    def forward(self, x):
        """Return (N, 200) MISLNet features for an (N, 3, 128, 128) batch."""
        # Constrained pre-filter ('valid': 128 -> 124), then asymmetric zero
        # padding (left/top 2, right/bottom 3: 124 -> 129).
        # NOTE(review): presumably this reproduces the TF1 graph's geometry;
        # confirm against the TF1 model code.
        constr_conv = F.conv2d(x, self.weights_cstr, padding="valid")
        constr_conv = F.pad(constr_conv, (2, 3, 2, 3))

        # conv -> batch-norm -> tanh -> max-pool stages.
        # Spatial size after each stage: 30, 14, 6, 2.
        conv1_out = self.maxpool1(self.tanh1(self.bn1(self.conv1(constr_conv))))
        conv2_out = self.maxpool2(self.tanh2(self.bn2(self.conv2(conv1_out))))
        conv3_out = self.maxpool3(self.tanh3(self.bn3(self.conv3(conv2_out))))
        conv4_out = self.maxpool4(self.tanh4(self.bn4(self.conv4(conv3_out))))

        # TF flattens NHWC tensors, so move channels last before flattening.
        # This matches the row order of the converted fc1 weights.
        conv4_out = conv4_out.permute(0, 2, 3, 1)
        conv4_out = conv4_out.flatten(1, -1)

        dense1_out = self.tanh_fc1(self.fc1(conv4_out))
        dense2_out = self.tanh_fc2(self.fc2(dense1_out))

        return dense2_out

In [None]:
class CompareNet(torch.nn.Module):
    """Pairwise comparison head over two feature batches (port of TF1 CompareNet).

    Both inputs pass through the shared ``fc1`` + ReLU. The two embeddings and
    their elementwise product are concatenated as ``[a, a * b, b]`` and mapped
    through ``fc2`` + ReLU and ``fc3`` to raw 2-class logits of shape (N, 2).
    """

    def __init__(self, input_dim=200, map1_dim=2048, map2_dim=64):
        super().__init__()
        # Keep the assignment order: state_dict key order is used to pair these
        # layers with the TF1 variables during weight conversion.
        self.fc1 = torch.nn.Linear(input_dim, map1_dim)
        self.relu_fc1 = torch.nn.ReLU()
        # fc2 consumes the concatenation [a, a * b, b], hence 3 * map1_dim inputs.
        self.fc2 = torch.nn.Linear(3 * map1_dim, map2_dim)
        self.relu_fc2 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(map2_dim, 2)

    def forward(self, x):
        """Return (N, 2) logits for the feature pair ``x = (x1, x2)``."""
        first, second = x
        emb_a = self.relu_fc1(self.fc1(first))
        emb_b = self.relu_fc1(self.fc1(second))
        joint = torch.cat((emb_a, emb_a * emb_b, emb_b), dim=1)
        hidden = self.relu_fc2(self.fc2(joint))
        return self.fc3(hidden)

In [None]:
class FSG(torch.nn.Module):
    """Forensic-similarity model: a shared MISLNet encoder plus a CompareNet head.

    ``forward`` takes a pair ``(x1, x2)`` of image batches, embeds each one
    with the same MISLNet and returns CompareNet's (N, 2) logits for the pair.
    """

    def __init__(self, num_pre_filters=6, input_dim=200, map1_dim=2048, map2_dim=64):
        super().__init__()
        # Assignment order fixes the state_dict key order (mislnet.* first).
        self.mislnet = MISLNet(num_pre_filters)
        self.comparenet = CompareNet(input_dim, map1_dim, map2_dim)

    def forward(self, x):
        """Embed both image batches separately and compare the embeddings."""
        first_images, second_images = x
        return self.comparenet([self.mislnet(first_images), self.mislnet(second_images)])

In [None]:
# PyTorch model with default sizes. Its random initial weights are replaced by
# the converted TF1 weights when the state dict is loaded further down.
fsg = FSG()

In [None]:
# list(fsg.state_dict().keys())
# list(tf1_var_val.keys())

In [None]:
import copy
# Pair PyTorch state_dict keys with TF1 variable names *by position*. This
# relies on both models declaring their tensors in the same order.
# num_batches_tracked is a PyTorch-only BatchNorm counter with no TF1
# counterpart, so it is left out. Only key names are needed here, so there is
# no need to deep-copy the tensors.
torch_param_keys = [key for key in fsg.state_dict() if 'num_batches_tracked' not in key]
tf1_param_keys = list(tf1_var_val)
# zip() silently truncates to the shorter list, which would hide a misaligned
# pairing, so refuse to continue when the counts differ.
assert len(torch_param_keys) == len(tf1_param_keys), (
    f"{len(torch_param_keys)} PyTorch tensors vs {len(tf1_param_keys)} TF1 variables"
)
fsg_torch_to_tf1_state_dict_key_mapping = dict(zip(torch_param_keys, tf1_param_keys))
fsg_torch_to_tf1_state_dict_key_mapping

In [None]:
# Convert each TF1 array into PyTorch's layout and store it under the paired key:
#   conv kernels (including the constrained pre-filter): TF HWIO -> torch OIHW
#   dense kernels: TF (in, out) -> torch (out, in)
#   biases and batch-norm vectors are copied unchanged.
torch_reference_state = fsg.state_dict()  # built once, only used for target shapes
fsg_torch_state_dict = OrderedDict()
for torch_key, tf1_key in fsg_torch_to_tf1_state_dict_key_mapping.items():
    tf1_val = torch.from_numpy(tf1_var_val[tf1_key]["value"])

    if re.search(r"conv\d+\.weight", torch_key) or "weights_cstr" in torch_key:
        tf1_val = tf1_val.permute(3, 2, 0, 1)
    elif re.search(r"fc\d+\.weight", torch_key):
        tf1_val = tf1_val.permute(1, 0)

    # Fail early, naming both variables, if the positional pairing went wrong.
    expected_shape = tuple(torch_reference_state[torch_key].shape)
    assert tuple(tf1_val.shape) == expected_shape, (
        f"{tf1_key} -> {torch_key}: converted shape {tuple(tf1_val.shape)} "
        f"!= expected {expected_shape}"
    )

    fsg_torch_state_dict[torch_key] = tf1_val

In [None]:
# Copy the converted TF1 weights into the model. load_state_dict is strict by
# default, so unexpected keys or shape mismatches raise here.
fsg.load_state_dict(fsg_torch_state_dict)
# Inference only: eval() makes BatchNorm use the loaded running statistics.
fsg = fsg.eval()

In [None]:
# Image to analyse. This is a relative path, resolved against the kernel's
# working directory.
img_path = "test_images/img_demo/splicing-01.TIF"

In [None]:
# Load the image as an (H, W, C) array and keep only the first three channels,
# which drops an alpha channel if there is one.
# NOTE(review): a single-channel (2-D) image would fail this indexing; confirm
# that inputs are RGB(A).
img_plt = plt.imread(img_path)[:,:,:3]
plt.imshow(img_plt)

In [None]:
# HWC -> CHW for PyTorch. The dtype is whatever plt.imread returned (presumably
# uint8 0-255 for a TIFF; confirm). It is cast to float per batch later.
img = torch.tensor(img_plt).permute(2,0,1)

In [None]:
# 128x128 patches (MISLNet's input size) taken every 32 px, so neighbouring
# patches overlap by 96 px.
kernel_size, stride = 128, 128-96

In [None]:
# Tile the CHW image into overlapping patches. Unfolding rows, then columns,
# gives (n_rows, n_cols, C, k, k). The grid is then flattened row-major into
# (n_patches, 3, k, k). Pixels beyond the last full window on the bottom/right
# edges are not covered by any patch.
patch_grid = img.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
patch_grid = patch_grid.permute(1, 2, 0, 3, 4).contiguous()
patches = patch_grid.view(-1, 3, kernel_size, kernel_size)

In [None]:
def batch_fn(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items each.

    ``iterable`` must support ``len()`` and slicing (list, str, tensor, ...).
    The final slice is shorter when ``len(iterable)`` is not a multiple of ``n``.
    """
    total = len(iterable)
    for start in range(0, total, n):
        # A slice end past the sequence is clamped, so the last chunk is just shorter.
        yield iterable[start:start + n]

In [None]:
# Embed every patch with MISLNet, batch by batch. This is inference only:
# no_grad() stops autograd from recording the graph and saving intermediate
# activations, which lowers peak GPU memory and time per batch.
import math

feature_batch_size = 128
patches_features = []
fsg = fsg.cuda()
with torch.no_grad():
    for batch in tqdm(batch_fn(patches, feature_batch_size),
                      total=math.ceil(len(patches) / feature_batch_size)):
        feats = fsg.mislnet(batch.float().cuda()).cpu()
        patches_features.append(feats)
patches_features = torch.vstack(patches_features)

In [None]:
# Every ordered (i, j) pair of patch indices, including i == j.
# Shape is (N*N, 2), with i varying slowest.
patch_cart_prod = torch.cartesian_prod(torch.arange(patches.shape[0]), torch.arange(patches.shape[0]))

In [None]:
# Gather both feature vectors for every pair: shape (N*N, 2, feature_dim).
# Despite the name, these are paired features, not scores yet.
# NOTE(review): memory grows quadratically with the number of patches.
patches_sim_score = patches_features[patch_cart_prod]

In [None]:
# Score every pair with CompareNet. Each batch is (B, 2, feature_dim).
# Permuting it to (2, B, feature_dim) lets CompareNet unpack the two feature
# batches. Runs under no_grad() because nothing here is trained.
import math

compare_batch_size = 256
patches_sim_scores = []
with torch.no_grad():
    for batch in tqdm(batch_fn(patches_sim_score, compare_batch_size),
                      total=math.ceil(len(patches_sim_score) / compare_batch_size)):
        batch = batch.permute(1,0,2).float().cuda()
        scores = torch.nn.functional.softmax(fsg.comparenet(batch), dim=1).cpu()
        patches_sim_scores.append(scores)
patches_sim_scores = torch.vstack(patches_sim_scores)

In [None]:
# Pairwise similarity matrix. Entry (i, j) is CompareNet's class-1 probability
# for the ordered pair, averaged with (j, i) to make the matrix symmetric.
# Every patch is given similarity 1 with itself.
n_patches = len(patches)
pair_prob = patches_sim_scores[:, 1].reshape(n_patches, n_patches)
sim_mat = (pair_prob + pair_prob.T) / 2
sim_mat.fill_diagonal_(1.0)
sim_mat = sim_mat.numpy()

In [None]:
from src.spectral_utils import laplacian, eigap01, spectral_cluster
from src.localization import PatchLocalization, pixel_loc_from_patch_pred

In [None]:
# Spectral gaps of the patch-similarity graph, using the default Laplacian and
# the 'sym' normalized one.
# NOTE(review): the exact definitions of `laplacian` / `eigap01` live in
# src.spectral_utils. Presumably a larger gap means more clearly separated
# patch clusters; confirm there.
L = laplacian(sim_mat) #laplacian matrix
gap = eigap01(L) #spectral gap
print(f'Spectral Gap = {gap:.2f}')

normL = laplacian(sim_mat, laplacian_type='sym') #normalized laplacian matrix
normgap = eigap01(normL) #normalized spectral gap
print(f'Normalized Spectral Gap = {normgap:.4f}')

In [None]:
# Sanity check: expected to be (n_patches, 200).
patches_features.shape

In [None]:
# Top-left (x, y) pixel corner of every patch. These are the first index of
# each unfold window along the width (x) and height (y). They are listed
# row-major (y outer, x inner) to match the flattened order of `patches`.
x_inds = torch.arange(img_plt.shape[1]).unfold(0, kernel_size, stride)[:, 0]
y_inds = torch.arange(img_plt.shape[0]).unfold(0, kernel_size, stride)[:, 0]
xy_inds = [(col, row) for row, col in itertools.product(y_inds, x_inds)]

In [None]:
# Split the patches into clusters from the normalized Laplacian.
# NOTE(review): `prediction` is presumably one boolean/int label per patch, in
# the same row-major order as `xy_inds`; confirm in src.spectral_utils.
prediction = spectral_cluster(normL)

pat_loc = PatchLocalization(
    inds = xy_inds, 
    patch_size = 128,
    prediction = ~prediction)
f = pat_loc.plot_heatmap(image=img_plt, label=0)
# The cluster labels are flipped for easier visualization. That is why the
# heatmap call above uses label=0 and the call below passes ~pat_loc.prediction.
# threshold=0.45 is the patch-to-pixel decision threshold used by
# pixel_loc_from_patch_pred (see src.localization for its exact meaning).
pix_loc = pixel_loc_from_patch_pred(
    prediction=~pat_loc.prediction,
    inds = xy_inds,
    patch_size = 128,
    image_shape = img_plt.shape[:2],
    threshold = 0.45
)

pix_loc.plot(image=img_plt)

In [None]:
# torch.save(fsg.state_dict(), "fsg_image_pytorch_from_tf1.pt")

In [None]:
# fsg_reload = FSG()
# fsg_reload.load_state_dict(torch.load("fsg_image_pytorch_from_tf1.pt"))
# fsg_reload = fsg_reload.eval()

In [None]:
# Fixed random NCHW "images" for the TF1-vs-PyTorch parity check below.
# NOTE(review): torch.randint's upper bound is exclusive, so values span
# 0..254 rather than the full 0..255 pixel range.
torch.manual_seed(0)
batch_size = 2
x1 = torch.randint(0, 255, (batch_size, 3, 128, 128)).float()
x2 = torch.randint(0, 255, (batch_size, 3, 128, 128)).float()

In [None]:
# fsg([patches[0].float().unsqueeze(0), patches[1].float().unsqueeze(0)])

In [None]:
# fsg_reload([x1, x2])

In [None]:
# Parity check: run the original TF1 graph and the converted PyTorch model on
# the same random inputs and report how far apart their outputs are.
with tf.compat.v1.Session() as sess:
    mislnet_restore.restore(sess, f_weights_restore)  # load pretrained network
    # TF expects NHWC numpy arrays; the torch inputs are NCHW tensors.
    tf1_x1 = sess.run(mislnet_feats, feed_dict={x: x1.permute(0, 2, 3, 1).numpy(), MISL_phase: False})
    tf1_x2 = sess.run(mislnet_feats, feed_dict={x: x2.permute(0, 2, 3, 1).numpy(), MISL_phase: False})
    tf1_out = sess.run(mislnet_compare, feed_dict={f1: tf1_x1, f2: tf1_x2})

print(tf1_out)

# Same inputs through the converted model, on whichever device it currently lives on.
fsg_device = next(fsg.parameters()).device
with torch.no_grad():
    torch_out = fsg([x1.to(fsg_device), x2.to(fsg_device)]).cpu().numpy()
print(torch_out)
print(f"max |tf1 - torch| = {np.abs(tf1_out - torch_out).max():.3e}")