In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
import hydra
import evaluate
from tqdm import tqdm
from functools import partial
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

from utils import load_states_from_checkpoint
from data_utils.s2s_dataset import load_jsonl_data, S2S_dataset
from model_utils.create_model import create_model, create_gaussian_diffusion
from generate import denoised_fn_round

In [None]:
# fake torch distributed
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel

def initialize_distributed():
    """Create the default torch.distributed process group unless one already exists."""
    if dist.is_initialized():
        # Nothing to do: a process group is already up (e.g. the cell was re-run).
        return
    # 'gloo' is suitable for local development
    dist.init_process_group(backend='gloo')

# Rendezvous settings for a single-process "distributed" run: this process is
# rank 0 of a world of size 1, talking to itself on localhost:1235.
os.environ.update({
    'RANK': '0',
    'WORLD_SIZE': '1',
    'MASTER_ADDR': 'localhost',
    'MASTER_PORT': '1235',
})
initialize_distributed()

# From here on the torch.distributed query functions can be called safely.
rank = dist.get_rank()
print(f"Rank {rank} reporting in!")

In [None]:
# Run selection. These values are combined below into the checkpoint path
# my_output/{task}/{run}/model/model_checkpoint-{checkpoint} and the output
# path my_output/{task}/{run}/gen/dev_{checkpoint}/{seed}.gen.
device = 0  # integer device index handed to `.to(device)` for the model and batches
task = "recipes"  # task sub-folder under my_output/
run = "eb6ah8_d6_c128_lr2e-4_v2"  # run sub-folder under my_output/{task}/
checkpoint = "200000"  # suffix of the checkpoint file to load
seed = 92  # copied into config.exp.seed and used as the .gen file name


In [None]:

# task = "recipes"
# run = "eb6_d6_c128_wd01"

# Hydra may only be initialized once per process, so clear any previous global
# state to make this cell re-runnable. `GlobalHydra.instance()` is a singleton
# accessor and never returns None; `is_initialized()` is the actual check.
global_hydra = hydra.core.global_hydra.GlobalHydra.instance()
if global_hydra.is_initialized():
    global_hydra.clear()
# hydra.initialize(config_path=f"experiment_configs/modulus/{task}")
# config = hydra.compose(config_name=f"{run}.yaml")
hydra.initialize(config_path="confs")
config = hydra.compose(config_name="config.yaml")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# denoise for more steps
# config.clip_denoised = True

# Overrides applied on top of the composed config for this evaluation run.
config.skip_sample = False
config.diffusion_steps = 200
config.exp.seed = seed


print(config)
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.name_or_path)
#  set pad token to "PAD"
# tokenizer.pad_token = "PAD"
vocab_size = tokenizer.vocab_size

# Checkpoint to evaluate. The commented lines below switch to the EMA weights.
eval_model_path = f"my_output/{task}/{run}/model/model_checkpoint-{checkpoint}"
# eval_model_path = f"my_output/{task}/{run}/model/ema_0.9999_checkpoint-{checkpoint}"
# config.load_from_emas = True
print("Load model from: ", eval_model_path)

In [None]:
# Build the diffusion process and the model, then restore the trained weights.
diffusion = create_gaussian_diffusion(config)
model = create_model(config, vocab_size)
model_saved_state = load_states_from_checkpoint(eval_model_path, dist.get_rank())
model.load_state_dict(model_saved_state.model_dict)

# Pick the sampler that turns random noise into text embeddings.
if config.ddim_sample:
    sample_fn = diffusion.ddim_sample_loop
else:
    sample_fn = diffusion.p_sample_loop

# Word-embedding module, used later by the rounding callback during sampling.
emb_model = model.word_embedding

model.to(device)
# This notebook only runs inference: switch to eval mode so that dropout and
# similar train-time layers are disabled. `eval()` returns the module, so the
# cell still displays the model as its output.
model.eval()

In [None]:
# Parallel dev files: line i of the .src file pairs with line i of the .tgt file.
test_src_path = "data/raw/recipes/dev.src"
test_tgt_path = "data/raw/recipes/dev.tgt"

# Read both files in lockstep into a list of {"src", "tgt"} dicts,
# stripping surrounding whitespace (including the line terminator).
with open(test_src_path, "r") as f_src, open(test_tgt_path, "r") as f_tgt:
    test_data = [
        {"src": src_line.strip(), "tgt": tgt_line.strip()}
        for src_line, tgt_line in zip(f_src, f_tgt)
    ]

In [None]:
# Evaluate on at most the first 1000 dev examples.
test_data = test_data[:1000]

# Use the configured batch size, but never fewer than 100 examples per batch.
BATCH_SIZE = max(config.batch_size, 100)

dev_dataset = S2S_dataset(test_data, tokenizer, config)

# Loader options gathered in one place; every batch is kept (drop_last=False).
loader_options = dict(
    batch_size=BATCH_SIZE,
    drop_last=False,
    pin_memory=True,
    num_workers=config.num_workers,
    collate_fn=S2S_dataset.get_collate_fn(config),
)
dev_dataloader = DataLoader(dev_dataset, **loader_options)

In [None]:
# generate 1 sample for each data
each_sample_list = []

# Seed torch's global RNG so that re-running this cell reproduces the samples
# that end up in the `{seed}.gen` file.
torch.manual_seed(seed)

# The rounding callback does not depend on the batch, so build it once.
# "Freeze" some parameters for easy recall.
denoised_fn = partial(denoised_fn_round, config, emb_model.to(device))

for batch in tqdm(dev_dataloader):
    # Move the source mask to the device once and reuse it below.
    src_attention_mask = batch['src_attention_mask'].to(device)

    with torch.no_grad():
        encoder_hidden_states = model.encoder(
            input_ids=batch['src_input_ids'].to(device),
            attention_mask=src_attention_mask,
        ).last_hidden_state  # [bs, seq_len, hz]

    if config.pred_len:
        with torch.no_grad():
            length_out = model.get_pred_len(
                encoder_hidden_states=encoder_hidden_states,
                src_masks=src_attention_mask,
                normalize=True,
            )  # [bs, max_pos_len]
            pred_lengs = length_out.max(-1)[1]  # [bs,], max return tuple(value, indices)

        # Build a right-padded 0/1 mask: row i has pred_lengs[i] ones.
        # Converting to a Python list once avoids recomputing the maximum
        # (and reading tensor elements one by one) for every row.
        pred_len_list = pred_lengs.tolist()
        longest = max(pred_len_list)
        # The mask is created on the CPU, exactly as before.
        # NOTE(review): presumably sample_fn / the model moves it to the
        # right device — confirm.
        tgt_attention_mask = torch.tensor(
            [[1] * n + [0] * (longest - n) for n in pred_len_list]
        ).long()

        input_shape = (
            tgt_attention_mask.shape[0], tgt_attention_mask.shape[1], config.in_channels,
        )
    else:
        # No length prediction: sample a fixed-length target of config.tgt_len.
        pred_lengs, tgt_attention_mask = None, None
        input_shape = (
            batch['src_input_ids'].shape[0], config.tgt_len, config.in_channels,
        )

    model_kwargs = {'src_attention_mask': src_attention_mask,
                    'tgt_attention_mask': tgt_attention_mask,
                    'encoder_hidden_states': encoder_hidden_states,}
    sample = sample_fn(
        model,
        input_shape,
        clip_denoised=config.clip_denoised,
        denoised_fn=denoised_fn,
        progress=True,
        model_kwargs=model_kwargs,
        pred_lengs=pred_lengs,
        top_p=-1.0,
    )

    # Only the argmax of the logits is needed, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model.get_logits(sample)  # (bs, seq_len, vocab_size)
    sample_id_tensor = torch.argmax(logits, dim=-1)
    generations = tokenizer.batch_decode(sample_id_tensor, skip_special_tokens=True)
    each_sample_list.extend(generations)
    print(generations[:5], end="\n***")

    # print(tokenizer.batch_decode(sample_id_tensor, skip_special_tokens=True))

In [None]:
# Folder and file that receive this run's generations, one sample per line.
gen_dir = f"my_output/{task}/{run}/gen/dev_{checkpoint}"
gen_path = f"{gen_dir}/{seed}.gen"

# exist_ok=True creates the folder only when it is missing, without a separate
# exists() check that could race with another process creating it.
os.makedirs(gen_dir, exist_ok=True)

with open(gen_path, "w") as f:
    for item in each_sample_list:
        # The file is read back line by line and aligned with the references by
        # index, so a generation must never span more than one line.
        single_line = item.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
        f.write(single_line + "\n")

In [None]:


# compute metrics
with open(f"my_output/{task}/{run}/gen/dev_{checkpoint}/{seed}.gen", "r") as f:
    # Drop the line terminator so predictions are the bare generated text
    # (readlines() would leave a trailing "\n" on every prediction).
    gen = [line.rstrip("\n") for line in f]
# References are aligned with the predictions by position in the dev set.
golds = [t["tgt"] for t in test_data[:len(gen)]]
rouge = evaluate.load('rouge')
scores = rouge.compute(
    predictions=gen,
    references=golds,
)
scores

In [None]:
# Show the first ten prediction (P) / gold (G) pairs side by side.
for pair_idx in range(min(10, len(gen), len(golds))):
    print("P: ", gen[pair_idx])
    print("G: ", golds[pair_idx])
    print("****")

In [None]:
def attention_plot(w, title="Attention plot"):
    """Show a 2-D tensor of attention weights as a heatmap with a colorbar.

    Args:
        w: tensor indexed as [row, column]; it is detached and copied to the
            CPU before plotting.
        title: text displayed above the plot.
    """
    figure, axis = plt.subplots(figsize=(15, 5))

    # Draw the weights as an image.
    heatmap = axis.imshow(w.detach().cpu().numpy(), cmap='viridis')

    # One tick per column, labelled with the column index and rotated to
    # vertical so neighbouring labels do not overlap.
    column_ids = range(w.shape[1])
    axis.set_xticks(column_ids)
    axis.set_xticklabels(column_ids)
    plt.setp(axis.get_xticklabels(), rotation=90, ha="right",
             rotation_mode="anchor")

    # Colour scale next to the heatmap.
    axis.figure.colorbar(heatmap, ax=axis)
    axis.set_title(title)

    plt.show()

In [None]:
# plot attentions
# NOTE(review): `attention_probs` is presumably cached on the attention module
# by its most recent forward pass, so this cell depends on the generation loop
# having run first — confirm.
w = model.transformer_blocks[0].attn2.attention_probs
# First the average over the leading (head) dimension, then one plot per head.
attention_plot(w.mean(0), title="Attention plot all heads (mean)")
for head_idx in range(w.shape[0]):
    attention_plot(w[head_idx], title=f"Attention plot for head {head_idx}")


In [None]:
# print token correspondence between tgt & gen
# `each_sample_list` holds one generation per entry of `test_data`, in order,
# so the references are taken from `test_data` directly. Only the first few
# examples are shown to keep the output readable.
N_SHOWN_EXAMPLES = 10
for example, generation in zip(test_data[:N_SHOWN_EXAMPLES], each_sample_list):
    print(f"----------")
    tgt_tokens = example["tgt"].split(" ")
    gen_tokens = generation.split(" ")
    # zip stops at the shorter sequence, so only positions present in both print.
    for i, (t, g) in enumerate(zip(tgt_tokens, gen_tokens)):
        print(f"{i}: {t} -> {g}")