In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("/Users/gorjanradevski/PycharmProjects/scene_generation/src")

from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import math
import os
import json
from torch import nn
import torch
from tqdm.notebook import tqdm

from scene_layouts.datasets import DiscreteInferenceDataset, X_MASK, Y_MASK, O_MASK, BUCKET_SIZE
from scene_layouts.modeling import SpatialDiscreteBert, ClipartsPredictionModel
from scene_layouts.generation_strategies import generation_strategy_factory
from transformers import BertConfig, BertTokenizer
from torch.utils.data import Subset

In [None]:
# Load the clipart-name -> index vocabulary. The context manager closes the
# file handle right away instead of leaving it open until garbage collection.
with open("../data/visuals_dicts/visual2index.json") as visual2index_file:
    visual2index = json.load(visual2index_file)
# Reverse lookup (index -> clipart file name), used when drawing the scenes.
index2visual = {v: k for k, v in visual2index.items()}
# Samples read by the visualization cells below.
dataset = DiscreteInferenceDataset("../data/test_dataset.json", visual2index)
# Directory holding the clipart PNGs and "background.png".
pngs_path = "../data/AbstractScenes_v1.1/Pngs"
# Used below to turn sentence ids back into printable tokens.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
bert_name = "bert-base-uncased"
config = BertConfig.from_pretrained(bert_name)
# The vocabulary holds one entry more than there are known cliparts.
config.vocab_size = 1 + len(visual2index)
# Build the model, wrap it in DataParallel and move it to the chosen device.
model = nn.DataParallel(SpatialDiscreteBert(config, bert_name))
model = model.to(device)
# Restore the trained weights, mapping them onto the chosen device.
model.load_state_dict(torch.load("../models/discrete_15p.pt", map_location=device))
# Switch to inference mode; the returned module is displayed as the cell output.
model.eval()

In [None]:
# Index of the dataset sample that the visualization cells below read via dataset[dataset_key].
dataset_key = 225

## Display original scene

In [None]:
# Draw the reference scene for the selected sample from its ground-truth labels.
input_ids_sentence, input_ids_visuals, _, _, _, x_labels, y_labels, o_labels = dataset[dataset_key]

# ---------- PRINT THE DESCRIPTION, ONE SENTENCE PER LINE ----------
print("==============================")
current_sentence = []
for token in tokenizer.convert_ids_to_tokens(input_ids_sentence):
    current_sentence.append(token)
    if token != ".":
        continue
    # A full stop closes the sentence: emit it and start collecting the next one.
    print(" ".join(current_sentence))
    current_sentence = []
print("==============================")

# ---------- PREPARE FOR VISUALIZATION ----------
print(x_labels)
print(y_labels)
print(BUCKET_SIZE)
input_ids_visuals = input_ids_visuals.numpy()
# Ids without an entry in index2visual are dropped.
visual_names = [
    index2visual[visual_id]
    for visual_id in input_ids_visuals
    if visual_id in index2visual
]
background = Image.open(os.path.join(pngs_path, "background.png"))
for visual_name, x_label, y_label, o_label in zip(visual_names, x_labels, y_labels, o_labels):
    image = Image.open(os.path.join(pngs_path, visual_name))
    print(visual_name, x_label.item(), y_label.item())
    if o_label == 1:
        print(f"Rotating 1 {visual_name}")
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    # The labels give the clipart centre; paste() expects the top-left corner.
    width, height = image.size
    top_left = (int(x_label.item() - width // 2), int(y_label.item() - height // 2))
    # The clipart doubles as its own transparency mask.
    background.paste(image, top_left, image)
plt.figure(figsize=(13, 8))
plt.axis("off")
plt.imshow(np.array(background))

## Generate full scene

In [None]:
# Generate a complete scene for the selected sample and draw it.
# Only the sentence ids and the clipart ids are passed to the generation
# strategy, so the remaining fields returned by the dataset are ignored.
input_ids_sentence, input_ids_visuals, *_ = dataset[dataset_key]
print("==============================")
sentence = []
for word in tokenizer.convert_ids_to_tokens(input_ids_sentence):
    sentence.append(word)
    if word == ".":
        # A full stop ends the sentence: print it and start a new one.
        print(" ".join(sentence))
        sentence = []
print("==============================")

# Uncomment to replace the description with just the [CLS] and [SEP] tokens.
# input_ids_sentence = torch.tensor([tokenizer.cls_token_id, tokenizer.sep_token_id])

# ---------- PREPARE INPUTS FOR MODEL AND OBTAIN OUTPUTS ----------
# Add a leading batch dimension of size 1 to every input.
text_pos = torch.arange(input_ids_sentence.size()[0]).unsqueeze(0)
input_ids_sentence = input_ids_sentence.unsqueeze(0)
input_ids_visuals = input_ids_visuals.unsqueeze(0)
print(f"The input ids are {input_ids_visuals}")
# Token types: 0 for sentence tokens, 1 for clipart tokens.
t_types = torch.cat([torch.zeros_like(input_ids_sentence), torch.ones_like(input_ids_visuals)], dim=1)
# Every position is attended to.
attn_mask = torch.cat([torch.ones_like(input_ids_sentence), torch.ones_like(input_ids_visuals)], dim=1)

with torch.no_grad():
    x_out, y_out, o_out = generation_strategy_factory("highest_confidence", "discrete",
                                                      input_ids_sentence, input_ids_visuals,
                                                      text_pos, t_types, attn_mask, model, device)

# ---------- PLOTTING ----------
input_ids_visuals = input_ids_visuals.squeeze(0).cpu().numpy()
visual_names = [index2visual[index] for index in input_ids_visuals if index in index2visual]
background = Image.open(os.path.join(pngs_path, "background.png"))
print(visual_names)
# .cpu() keeps the conversion to numpy valid if the strategy returns tensors
# that live on the GPU (it is a no-op for CPU tensors).
x_buckets = x_out.squeeze(0).cpu().numpy()
y_buckets = y_out.squeeze(0).cpu().numpy()
o_flags = o_out.squeeze(0).cpu().numpy()
for visual_name, x_index, y_index, o_index in zip(visual_names, x_buckets, y_buckets, o_flags):
    image = Image.open(os.path.join(pngs_path, visual_name))
    # Bucket index -> pixel at the bucket centre, then shift by half the clipart
    # size so that the clipart is centred there. Plain ints are handed to PIL.
    x_index = int(x_index) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[0] // 2
    y_index = int(y_index) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[1] // 2
    if o_index == 1:
        print(f"Rotating 1 {visual_name}")
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    # Pasting the image; the clipart doubles as its own transparency mask.
    background.paste(image, (x_index, y_index), image)
plt.figure(figsize=(13, 8))
plt.axis('off')
plt.imshow(np.array(background))

## Generate heat map

In [None]:
# Mask the position of one clipart, query the model once and overlay the
# predicted x/y scores for that clipart on top of the remaining scene.
input_ids_sentence, input_ids_visuals, x_indexes, y_indexes, o_indexes, _, _, _ = dataset[dataset_key]
print("==============================")
sentence = []
for word in tokenizer.convert_ids_to_tokens(input_ids_sentence):
    sentence.append(word)
    if word == ".":
        # A full stop ends the sentence: print it and start a new one.
        print(" ".join(sentence))
        sentence = []
print("==============================")

# Position (within the clipart sequence) of the clipart whose placement is masked.
masked_visual_position_index = 0
# --------- DISPLAY THE MASKED CLIPART -----------
masked_name = index2visual[input_ids_visuals[masked_visual_position_index].item()]
image = Image.open(os.path.join(pngs_path, masked_name))
plt.imshow(image)

# Open a new, larger figure for the scene drawn further down.
plt.figure(figsize=(13, 8))

# Uncomment to replace the description with just the [CLS] and [SEP] tokens.
# input_ids_sentence = torch.tensor([tokenizer.cls_token_id, tokenizer.sep_token_id])


# ---------- PREPARE INPUTS FOR MODEL AND OBTAIN OUTPUTS ----------
# Work on copies so the masking below never writes into the objects handed out
# by the dataset; re-running the cell then always starts from unmasked values.
x_indexes = torch.as_tensor(x_indexes).detach().clone()
y_indexes = torch.as_tensor(y_indexes).detach().clone()
o_indexes = torch.as_tensor(o_indexes).detach().clone()
x_indexes[masked_visual_position_index] = X_MASK
y_indexes[masked_visual_position_index] = Y_MASK
o_indexes[masked_visual_position_index] = O_MASK

# Add a leading batch dimension of size 1 to every input.
text_pos = torch.arange(input_ids_sentence.size()[0]).unsqueeze(0)
input_ids_sentence = input_ids_sentence.unsqueeze(0)
input_ids_visuals = input_ids_visuals.unsqueeze(0)
print(f"The input ids are {input_ids_visuals}")
x_indexes = x_indexes.unsqueeze(0)
y_indexes = y_indexes.unsqueeze(0)
o_indexes = o_indexes.unsqueeze(0)
# Token types: 0 for sentence tokens, 1 for clipart tokens.
t_types = torch.cat([torch.zeros_like(input_ids_sentence), torch.ones_like(input_ids_visuals)], dim=1)

with torch.no_grad():
    x_scores, y_scores, o_scores = model(input_ids_sentence.to(device), input_ids_visuals.to(device),
                                         text_pos.to(device), x_indexes.to(device), y_indexes.to(device),
                                         o_indexes.to(device), t_types.to(device))
    # Keep the first 25 x-scores / 20 y-scores of the masked clipart and repeat
    # each one BUCKET_SIZE times, giving one value per pixel column / row. The
    # expand() calls below require these lengths to equal the background width
    # and height. Results are moved to the CPU because they end up in a numpy array.
    x_heats = torch.repeat_interleave(x_scores[0, masked_visual_position_index][:25], BUCKET_SIZE).cpu()
    y_heats = torch.repeat_interleave(y_scores[0, masked_visual_position_index][:20], BUCKET_SIZE).cpu()

# ---------- PLOTTING ----------
input_ids_visuals = input_ids_visuals.squeeze(0).cpu().numpy()
visual_names = [index2visual[index] for index in input_ids_visuals if index in index2visual]
background = Image.open(os.path.join(pngs_path, "background.png"))
print(visual_names)
for position, (visual_name, x_index, y_index, o_index) in enumerate(
        zip(visual_names, x_indexes.squeeze(0).numpy(),
            y_indexes.squeeze(0).numpy(), o_indexes.squeeze(0).numpy())):
    # Skip the masked clipart by position rather than by name, so that another
    # instance of the same clipart elsewhere in the scene is still drawn. This
    # relies on visual_names being aligned with the coordinate tensors, which
    # the zip above already assumes.
    if position == masked_visual_position_index:
        continue
    image = Image.open(os.path.join(pngs_path, visual_name))
    if o_index == 1:
        print(f"Rotating 1 {visual_name}")
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    # Pasting the image centred on its bucket; plain ints are handed to PIL.
    background.paste(image, (int(x_index) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[0] // 2,
                             int(y_index) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[1] // 2), image)
plt.imshow(np.array(background))

# Broadcast the per-column and per-row scores to the full (height, width) grid.
x_heats = x_heats.unsqueeze(0).expand(background.size[1], background.size[0])
y_heats = y_heats.unsqueeze(1).expand(background.size[1], background.size[0])
heats = torch.exp(x_heats) * torch.exp(y_heats) * 100
# Red RGBA overlay whose alpha channel carries the heat. Float RGBA images must
# lie in [0, 1]; the values are set/clipped explicitly instead of relying on
# matplotlib to clip out-of-range data (which renders identically but warns).
overlay = np.zeros((background.size[1], background.size[0], 4))
overlay[:, :, 0] = 1.0
overlay[:, :, 3] = np.clip(heats.numpy(), 0.0, 1.0)
plt.axis('off')
plt.imshow(overlay)

## Generate single object

In [None]:
# Mask the placement of a single clipart, let the model predict it, and draw
# the scene with the predicted placement filled in.
input_ids_sentence, input_ids_visuals, x_indexes, y_indexes, o_indexes, _, _, _ = dataset[dataset_key]
print("==============================")
sentence = []
for word in tokenizer.convert_ids_to_tokens(input_ids_sentence):
    sentence.append(word)
    if word == ".":
        # A full stop ends the sentence: print it and start a new one.
        print(" ".join(sentence))
        sentence = []
print("==============================")

# Position (within the clipart sequence) of the clipart whose placement is masked.
masked_visual_position_index = 5
# --------- DISPLAY THE MASKED CLIPART -----------
masked_name = index2visual[input_ids_visuals[masked_visual_position_index].item()]
image = Image.open(os.path.join(pngs_path, masked_name))
plt.imshow(image)

# Open a new, larger figure for the scene drawn further down.
plt.figure(figsize=(13, 8))

# Uncomment to replace the description with just the [CLS] and [SEP] tokens.
# input_ids_sentence = torch.tensor([tokenizer.cls_token_id, tokenizer.sep_token_id])


# ---------- PREPARE INPUTS FOR MODEL AND OBTAIN OUTPUTS ----------
# Work on copies so that neither the masking nor the later write-back of the
# predictions touches the objects handed out by the dataset.
x_indexes = torch.as_tensor(x_indexes).detach().clone()
y_indexes = torch.as_tensor(y_indexes).detach().clone()
o_indexes = torch.as_tensor(o_indexes).detach().clone()
x_indexes[masked_visual_position_index] = X_MASK
y_indexes[masked_visual_position_index] = Y_MASK
o_indexes[masked_visual_position_index] = O_MASK

# Add a leading batch dimension of size 1 to every input.
text_pos = torch.arange(input_ids_sentence.size()[0]).unsqueeze(0)
input_ids_sentence = input_ids_sentence.unsqueeze(0)
input_ids_visuals = input_ids_visuals.unsqueeze(0)
print(f"The input ids are {input_ids_visuals}")
x_indexes = x_indexes.unsqueeze(0)
y_indexes = y_indexes.unsqueeze(0)
o_indexes = o_indexes.unsqueeze(0)
# Token types: 0 for sentence tokens, 1 for clipart tokens.
t_types = torch.cat([torch.zeros_like(input_ids_sentence), torch.ones_like(input_ids_visuals)], dim=1)

with torch.no_grad():
    x_scores, y_scores, o_scores = model(input_ids_sentence.to(device), input_ids_visuals.to(device),
                                         text_pos.to(device), x_indexes.to(device), y_indexes.to(device),
                                         o_indexes.to(device), t_types.to(device))
    # Replace the mask values with the highest-scoring class at the masked
    # position; the predictions are moved to the CPU, where the inputs live.
    x_indexes[:, masked_visual_position_index] = torch.argmax(x_scores, dim=-1)[:, masked_visual_position_index].cpu()
    y_indexes[:, masked_visual_position_index] = torch.argmax(y_scores, dim=-1)[:, masked_visual_position_index].cpu()
    o_indexes[:, masked_visual_position_index] = torch.argmax(o_scores, dim=-1)[:, masked_visual_position_index].cpu()


print(x_indexes)
# ---------- PLOTTING ----------
input_ids_visuals = input_ids_visuals.squeeze(0).cpu().numpy()
visual_names = [index2visual[index] for index in input_ids_visuals if index in index2visual]
background = Image.open(os.path.join(pngs_path, "background.png"))
print(visual_names)
for visual_name, x_index, y_index, o_index in zip(visual_names, x_indexes.squeeze(0).numpy(),
                                                  y_indexes.squeeze(0).numpy(), o_indexes.squeeze(0).numpy()):
    image = Image.open(os.path.join(pngs_path, visual_name))
    if o_index == 1:
        print(f"Rotating 1 {visual_name}")
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    # Pasting the image centred on its bucket; plain ints are handed to PIL.
    background.paste(image, (int(x_index) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[0] // 2,
                             int(y_index) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[1] // 2), image)
plt.imshow(np.array(background))

## Predict cliparts on the Tan 2018 test dataset and then generate a scene

In [None]:
# ---------- CLIP-ARTS MODEL ----------
# This model is built with a vocabulary of exactly len(visual2index) entries.
config.vocab_size = len(visual2index)
cliparts_model = nn.DataParallel(ClipartsPredictionModel(config, bert_name))
cliparts_model = cliparts_model.to(device)
cliparts_model.load_state_dict(torch.load("../models/cliparts_pred_tan.pt", map_location=device))
cliparts_model.eval()

# ---------- SR-BERT MODEL ----------
# The shared config is switched back to one extra vocabulary entry before the
# spatial model is built.
config.vocab_size = 1 + len(visual2index)
model = nn.DataParallel(SpatialDiscreteBert(config, bert_name))
model = model.to(device)
model.load_state_dict(torch.load("../models/discrete_15p.pt", map_location=device))
# Inference mode; the returned module is displayed as the cell output.
model.eval()

In [None]:
# Maps the first character of a clipart file name to a one-digit string key.
# The keys are assigned in the listed order: "s" -> "0", "p" -> "1", ... "t" -> "6".
prefixes = {letter: str(rank) for rank, letter in enumerate("sphacet")}

In [None]:
# Predict which cliparts belong to the description, then let the spatial model
# place them and draw the resulting scene.
input_ids_sentence, input_ids_visuals, *_ = dataset[5]
print("==============================")
sentence = []
for word in tokenizer.convert_ids_to_tokens(input_ids_sentence):
    sentence.append(word)
    if word == ".":
        # A full stop ends the sentence: print it and start a new one.
        print(" ".join(sentence))
        sentence = []
print("==============================")

print(f"Ground truths: {input_ids_visuals}")

# ------------------------------ GET CLIPARTS PREDICTIONS ---------------------------
threshold = 0.35
attn_mask = torch.ones_like(input_ids_sentence)
# Pure inference: no autograd graph is needed for the clipart scores.
with torch.no_grad():
    probs = torch.sigmoid(cliparts_model(input_ids_sentence.unsqueeze(0), attn_mask.unsqueeze(0)))
one_hot_pred = torch.zeros_like(probs)
# Regular objects: every column outside [23, 93) whose probability exceeds the threshold.
one_hot_pred[:, :23][torch.where(probs[:, :23] > threshold)] = 1
one_hot_pred[:, 93:][torch.where(probs[:, 93:] > threshold)] = 1
# Mike and Jenny: exactly one column (the most probable) from each of the
# ranges [23, 58) and [58, 93).
max_hb0 = torch.argmax(probs[:, 23:58], dim=-1)
one_hot_pred[:, max_hb0 + 23] = 1
max_hb1 = torch.argmax(probs[:, 58:93], dim=-1)
one_hot_pred[:, max_hb1 + 58] = 1

# i+1 because in the visual2index the indices start from 1
input_ids_visuals = torch.tensor([i+1 for i in range(one_hot_pred.size()[1]) if one_hot_pred[0, i] == 1])

print(f"Predicted: {input_ids_visuals}")

# ---------- PREPARE INPUTS FOR MODEL AND OBTAIN OUTPUTS ----------
# Add a leading batch dimension of size 1 to every input.
text_pos = torch.arange(input_ids_sentence.size()[0]).unsqueeze(0)
input_ids_sentence = input_ids_sentence.unsqueeze(0)
input_ids_visuals = input_ids_visuals.unsqueeze(0)
# Token types: 0 for sentence tokens, 1 for clipart tokens.
t_types = torch.cat([torch.zeros_like(input_ids_sentence), torch.ones_like(input_ids_visuals)], dim=1)
# Every position is attended to.
attn_mask = torch.cat([torch.ones_like(input_ids_sentence), torch.ones_like(input_ids_visuals)], dim=1)

with torch.no_grad():
    x_out, y_out, o_out = generation_strategy_factory("highest_confidence_beam", "discrete",
                                                      input_ids_sentence, input_ids_visuals,
                                                      text_pos, t_types, attn_mask, model, device)

# ---------- PLOTTING ----------
# Prepare visual names
input_ids_visuals = input_ids_visuals.squeeze(0).cpu().numpy()
visual_names = [index2visual[index] for index in input_ids_visuals if index in index2visual]
# Paste order: sort by the prefix key of the first file-name character, then by name.
sort_indices = np.argsort([prefixes[visual_name[0]]+f"_{visual_name}" for visual_name in visual_names])
# Plot the stuff
background = Image.open(os.path.join(pngs_path, "background.png"))
for index in sort_indices:
    current_name = visual_names[index]
    image = Image.open(os.path.join(pngs_path, current_name))
    # Bucket index -> pixel at the bucket centre, shifted by half the clipart
    # size; int() turns the 0-dim tensors into plain ints for PIL.
    x_index = int(x_out[0, index]) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[0] // 2
    y_index = int(y_out[0, index]) * BUCKET_SIZE + BUCKET_SIZE // 2 - image.size[1] // 2
    if o_out[0, index] == 1:
        # Report the clipart that is actually being flipped in this iteration.
        print(f"Rotating 1 {current_name}")
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
    # Pasting the image; the clipart doubles as its own transparency mask.
    background.paste(image, (x_index, y_index), image)
plt.figure(figsize=(13, 8))
plt.axis('off')
plt.imshow(np.array(background))