In [None]:
from pathlib import Path
import sys
import os


current_file_path = Path(os.path.abspath('')).resolve()
# print(current_file_path.parent)
sys.path.insert(0, str(current_file_path.parent))

In [20]:
from tqdm import tqdm
from PIL import Image, ImageFile
from src.utils.misc import read_config
from src.slot_attention import UOD

import matplotlib.pyplot as plt
from diffusers.utils import make_image_grid
import numpy as np
import os
from torchvision import transforms
import torchvision.transforms.functional as TF
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import pickle
import torchvision.transforms as T
import json

from transformers import T5EncoderModel, T5Tokenizer


In [41]:
from train_scripts.train_editor import SlotDataset
from train_scripts.train_editor import TrainingConfig, SlotEditor, unnormalize, image_from_tensor

In [None]:
predict_residual = True
val_pickle_file:str = "/home/cse/btech/cs1210561/scratch/SA/output/clevr_run_res112_try_cont/val_data_latest_wo_112_res_cont.pickle"
val_image_root:str = '/home/scai/phd/aiz228170/scratch/Datasets/CIM-NLI/combined/valid' 
val_json_file:str = '/home/cse/btech/cs1210561/scratch/combined/valid/CLEVR_questions.json'
use_saved_t5 = False
device = "cuda"


checkpoint_to_vis = "/home/cse/btech/cs1210561/scratch/SA/editor_runs/cont_512_4_8_4_300_4096_64_64_1_4e-4_1000_on_train_l2_clip_1_cosine_lr_200ep/checkpoints/epoch_72_step_121105.pth"

In [None]:
def save_image_with_caption(image_array, caption, save_path):
    """Saves an image with a caption overlaid at the bottom."""
    image = Image.fromarray(image_array)
    
    img_width, img_height = image.size
    caption_height = 40  
    new_image = Image.new("RGB", (img_width, img_height + caption_height), (255, 255, 255))
    new_image.paste(image, (0, 0))

    draw = ImageDraw.Draw(new_image)
    font_size = 20
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        font = ImageFont.load_default()

    bbox = draw.textbbox((0, 0), caption, font=font)
    text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
    
    text_x = (img_width - text_width) // 2
    text_y = img_height + (caption_height - text_height) // 2

    draw.text((text_x, text_y), caption, fill="black", font=font)
    
    new_image.save(save_path)


In [54]:

@torch.inference_mode()
def log_validation(model, vis_model, val_batch, global_step, device, samples_2_show=4, save_dir = "saved_images_on_val_all",text_encoder=None, tokenizer=None):
    # print("Logging Visualization")
    torch.cuda.empty_cache()
    model = model.eval()
    vis_model = vis_model.eval()

    input_slots = val_batch['input_slots'][:samples_2_show].to(device)
    target_slots = val_batch['output_slots'][:samples_2_show].to(device)

    if use_saved_t5 and (text_encoder is None and tokenizer is None):
        y = val_batch['text_emb'][:samples_2_show].to(device).squeeze(1)
        y_mask = val_batch['emb_mask'][:samples_2_show].to(device).squeeze(1)
    else:
        max_length = config.model_max_length
        txt_tokens = tokenizer(val_batch['edit_prompt'][:samples_2_show], max_length=max_length, padding='max_length', truncation=True,
        return_tensors='pt').to(device)
        
        y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:,None]
        y_mask = y_mask = txt_tokens.attention_mask[:, None, None]

    src_images = unnormalize(val_batch['image_name'][:samples_2_show])
    tgt_images = unnormalize(val_batch['out_image_name'][:samples_2_show])

    img_stack = [np.hstack((np.asarray(image_from_tensor(og)),np.asarray(image_from_tensor(tgt)))) for og,tgt in zip(src_images, tgt_images)]


    output_slots = model(input_slots, y, y_mask)
    if predict_residual:
        output_slots  = input_slots + output_slots
    # print(output_slots.shape)

    torch.cuda.empty_cache()

    dec_output = vis_model.slot_to_img(output_slots)
    dec_img_gt_slots =  vis_model.slot_to_img(target_slots)

    gen_tensor = unnormalize(dec_output['generated'])
    gt_gen_tensor = unnormalize(dec_img_gt_slots['generated'])
    
    bs = output_slots.shape[0]
    num_slots = output_slots.shape[1]
    
    os.makedirs(save_dir, exist_ok=True)

    import wandb
    # formatted_images = []
    for idx in range(bs):
        istack = img_stack[idx]
        gt_reconstructed_image = image_from_tensor(gt_gen_tensor[idx])
        image_reconstructed = image_from_tensor(gen_tensor[idx])
        slot_masks = np.hstack([Image.fromarray((dec_output['masks'][idx][k].squeeze().cpu().detach().numpy()*255).astype(np.uint8)).convert('RGB') for k in range(num_slots)])
        final_img = np.hstack((istack, gt_reconstructed_image, image_reconstructed, slot_masks))
        # image = wandb.Image(final_img, caption=val_batch['edit_prompt'][idx])
        # formatted_images.append(image)
        image_path = os.path.join(save_dir, f"validation_{global_step}_{idx}.png")
        save_image_with_caption(final_img, val_batch['edit_prompt'][idx], image_path)
    

    
    # del vis_model
    # return image_logs
    return ((gt_gen_tensor-gen_tensor)**2).mean()




In [30]:
config = TrainingConfig()
model = SlotEditor(
            hidden_size=config.hidden_size,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            model_max_length=config.model_max_length,
            caption_channels=config.caption_channels,
            slot_dim=config.slot_dim
        )

sa_config = read_config(config.config_file)
sa_model = UOD(sa_config)

ckpt = torch.load(config.checkpoint_path)
sa_model.load_state_dict(ckpt['state_dict'])
for param in sa_model.parameters():
    param.requires_grad = False


tokenizer = text_encoder = None
if not use_saved_t5:
    tokenizer = T5Tokenizer.from_pretrained(config.pipeline_load_from, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(
        config.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(device)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:

ckpt_file = checkpoint_to_vis
checkpoint = torch.load(ckpt_file, map_location="cpu")

state_dict = checkpoint.get('state_dict', checkpoint)

missing, unexpect = model.load_state_dict(state_dict, strict=False)


model = model.to(device)
sa_model = sa_model.to(device)

In [None]:
val_ds = SlotDataset(val_pickle_file, val_image_root,val_json_file, hop_type = None)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=1)

mse = 0.0
steps = 0
for step, batch in enumerate(val_dl):
    if step%10 ==0:
        print("Step: ",step, ", MSE :",mse)
    mse+=log_validation(model, sa_model, batch, step, device=device, samples_2_show=2,save_dir = "saved_images_on_val_all", text_encoder=text_encoder, tokenizer=tokenizer)
    steps +=1

print("Avg MSE:",mse/steps)

Step:  0 , MSE : 0.0
Step:  10 , MSE : tensor(0.0062, device='cuda:0')
Step:  20 , MSE : tensor(0.0157, device='cuda:0')
Step:  30 , MSE : tensor(0.0242, device='cuda:0')
Step:  40 , MSE : tensor(0.0321, device='cuda:0')
Step:  50 , MSE : tensor(0.0398, device='cuda:0')
Step:  60 , MSE : tensor(0.0458, device='cuda:0')
Step:  70 , MSE : tensor(0.0531, device='cuda:0')
Step:  80 , MSE : tensor(0.0608, device='cuda:0')
Step:  90 , MSE : tensor(0.0774, device='cuda:0')
Step:  100 , MSE : tensor(0.0957, device='cuda:0')
Step:  110 , MSE : tensor(0.1147, device='cuda:0')
Step:  120 , MSE : tensor(0.1373, device='cuda:0')
Step:  130 , MSE : tensor(0.1513, device='cuda:0')
Step:  140 , MSE : tensor(0.1678, device='cuda:0')
Step:  150 , MSE : tensor(0.1877, device='cuda:0')
Step:  160 , MSE : tensor(0.1948, device='cuda:0')
Step:  170 , MSE : tensor(0.2012, device='cuda:0')
Step:  180 , MSE : tensor(0.2093, device='cuda:0')
Step:  190 , MSE : tensor(0.2166, device='cuda:0')
Step:  200 , MSE : 

In [None]:
val_ds = SlotDataset(val_pickle_file, val_image_root,val_json_file, hop_type = 0)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=1)
mse = 0.0
steps = 0
for step, batch in enumerate(val_dl):
    if step%10 ==0:
        print("Step: ",step, ", MSE :",mse)
    
    #Wrong name for save dir(it should be train)
    mse+=log_validation(model, sa_model, batch, steps, device=device, samples_2_show=2,save_dir = "saved_images_on_val_01hop", text_encoder=text_encoder, tokenizer=tokenizer)
    steps +=1

val_ds = SlotDataset(val_pickle_file, val_image_root,val_json_file, hop_type = 1)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=1)
for step, batch in enumerate(val_dl):
    if step%10 ==0:
        print("Step: ",step, ", MSE :",mse)
    mse+=log_validation(model, sa_model, batch, steps, device=device, samples_2_show=2,save_dir = "saved_images_on_val_01hops", text_encoder=text_encoder, tokenizer=tokenizer)
    steps +=1

print("Avg MSE for 0-1 hop data:",mse/steps)

Step:  0 , MSE : 0.0
Step:  10 , MSE : tensor(0.0092, device='cuda:0')
Step:  20 , MSE : tensor(0.0174, device='cuda:0')
Step:  30 , MSE : tensor(0.0250, device='cuda:0')
Step:  0 , MSE : tensor(0.0305, device='cuda:0')
Step:  10 , MSE : tensor(0.0403, device='cuda:0')
Step:  20 , MSE : tensor(0.0513, device='cuda:0')
Step:  30 , MSE : tensor(0.0674, device='cuda:0')
Step:  40 , MSE : tensor(0.0820, device='cuda:0')
Step:  50 , MSE : tensor(0.0963, device='cuda:0')
Step:  60 , MSE : tensor(0.1073, device='cuda:0')
Avg MSE for 0-1 hop data: tensor(0.0011, device='cuda:0')


In [60]:
checkpoint_to_vis = "/home/cse/btech/cs1210561/scratch/SA/editor_runs/cont_512_4_8_4_300_4096_64_64_1_4e-4_1000_on_val_l2_clip_1_cosine_lr_1000ep/checkpoints/epoch_1000_step_452001.pth"


ckpt_file = checkpoint_to_vis
checkpoint = torch.load(ckpt_file, map_location="cpu")

state_dict = checkpoint.get('state_dict', checkpoint)

missing, unexpect = model.load_state_dict(state_dict, strict=False)


model = model.to(device)
sa_model = sa_model.to(device)

In [61]:
val_ds = SlotDataset(val_pickle_file, val_image_root,val_json_file, hop_type = None)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=1)

mse = 0.0
steps = 0
for step, batch in enumerate(val_dl):
    if step%10 ==0:
        print("Step: ",step, ", MSE :",mse)
    mse+=log_validation(model, sa_model, batch, step, device=device, samples_2_show=2,save_dir = "saved_images_on_train_all", text_encoder=text_encoder, tokenizer=tokenizer)
    steps +=1

print("Avg MSE:",mse/steps)

Step:  0 , MSE : 0.0
Step:  10 , MSE : tensor(0.0013, device='cuda:0')
Step:  20 , MSE : tensor(0.0055, device='cuda:0')
Step:  30 , MSE : tensor(0.0076, device='cuda:0')
Step:  40 , MSE : tensor(0.0096, device='cuda:0')
Step:  50 , MSE : tensor(0.0126, device='cuda:0')
Step:  60 , MSE : tensor(0.0139, device='cuda:0')
Step:  70 , MSE : tensor(0.0149, device='cuda:0')
Step:  80 , MSE : tensor(0.0168, device='cuda:0')
Step:  90 , MSE : tensor(0.0189, device='cuda:0')
Step:  100 , MSE : tensor(0.0220, device='cuda:0')
Step:  110 , MSE : tensor(0.0247, device='cuda:0')
Step:  120 , MSE : tensor(0.0267, device='cuda:0')
Step:  130 , MSE : tensor(0.0291, device='cuda:0')
Step:  140 , MSE : tensor(0.0311, device='cuda:0')
Step:  150 , MSE : tensor(0.0354, device='cuda:0')
Step:  160 , MSE : tensor(0.0361, device='cuda:0')
Step:  170 , MSE : tensor(0.0369, device='cuda:0')
Step:  180 , MSE : tensor(0.0378, device='cuda:0')
Step:  190 , MSE : tensor(0.0384, device='cuda:0')
Step:  200 , MSE : 

In [62]:
val_ds = SlotDataset(val_pickle_file, val_image_root,val_json_file, hop_type = 0)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=1)
mse = 0.0
steps = 0
for step, batch in enumerate(val_dl):
    if step%10 ==0:
        print("Step: ",step, ", MSE :",mse)
    mse+=log_validation(model, sa_model, batch, steps, device=device, samples_2_show=2,save_dir = "saved_images_on_train_01hop", text_encoder=text_encoder, tokenizer=tokenizer)
    steps +=1

val_ds = SlotDataset(val_pickle_file, val_image_root,val_json_file, hop_type = 1)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=1)
for step, batch in enumerate(val_dl):
    if step%10 ==0:
        print("Step: ",step, ", MSE :",mse)
    mse+=log_validation(model, sa_model, batch, steps, device=device, samples_2_show=2,save_dir = "saved_images_on_train_01hops", text_encoder=text_encoder, tokenizer=tokenizer)
    steps +=1

print("Avg MSE for 0-1 hop data:",mse/steps)

Step:  0 , MSE : 0.0
Step:  10 , MSE : tensor(0.0014, device='cuda:0')
Step:  20 , MSE : tensor(0.0029, device='cuda:0')
Step:  30 , MSE : tensor(0.0036, device='cuda:0')
Step:  0 , MSE : tensor(0.0042, device='cuda:0')
Step:  10 , MSE : tensor(0.0068, device='cuda:0')
Step:  20 , MSE : tensor(0.0096, device='cuda:0')
Step:  30 , MSE : tensor(0.0130, device='cuda:0')
Step:  40 , MSE : tensor(0.0158, device='cuda:0')
Step:  50 , MSE : tensor(0.0176, device='cuda:0')
Step:  60 , MSE : tensor(0.0181, device='cuda:0')
Avg MSE for 0-1 hop data: tensor(0.0002, device='cuda:0')
