In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch
import hydra
from tqdm import tqdm
from functools import partial
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

from utils import load_states_from_checkpoint
from data_utils.s2s_dataset import load_jsonl_data, S2S_dataset
from model_utils.create_model import create_model, create_gaussian_diffusion
from generate import denoised_fn_round

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# fake torch distributed
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel

def initialize_distributed():
    if not dist.is_initialized():
        # Initialize the distributed environment
        dist.init_process_group(backend='gloo')  # 'gloo' is suitable for local development

# Call the initialization function
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1' 
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '1235'
initialize_distributed()

# Now you can use distributed functions safely
rank = dist.get_rank()
print(f"Rank {rank} reporting in!")

Rank 0 reporting in!


In [7]:
task = "k2_m5_bos"
run = "e1_d6_c128_wd01"
# task = "recipes"
# run = "eb6_d6_c128_wd01"
# if hydra initialized, clear it
if hydra.core.global_hydra.GlobalHydra.instance() is not None:
    hydra.core.global_hydra.GlobalHydra.instance().clear()
# hydra.initialize(config_path=f"experiment_configs/modulus/{task}")
# config = hydra.compose(config_name=f"{run}.yaml")
hydra.initialize(config_path=f"confs")
config = hydra.compose(config_name=f"config_dyck.yaml")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 5
print(config)
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.name_or_path)
#  set pad token to "PAD"
tokenizer.pad_token = "PAD"
vocab_size = tokenizer.vocab_size
eval_model_path = f"my_output/{task}/{run}/model/model_checkpoint-60000"
print("Load model from: ", eval_model_path)

{'dataset': {'path': 'data/raw/dyck', 'name': 'k2_m5_bos'}, 'tokenizer': {'from_pretrained': False, 'name_or_path': 'data/raw/dyck/tokenizer_dyck2'}, 'exp': {'seed': 101, 'root': './my_output', 'name': 'eb6_d6_c512', 'dir': None}, 'batch_size': 2, 'device': 'cuda', 'lr_step': 40000, 'warmup_steps': 4000, 'total_steps': 200000, 'lr': 0.0008, 'weight_decay': 1e-05, 'grad_clip': -1.0, 'ema_rate': 0.9999, 'grad_accum': 4, 'eval_interval': 500, 'log_interval': 500, 'save_interval': 20000, 'tgt_len': 128, 'max_pos_len': 256, 'model': {'mode': 's2s', 'pretrain': None}, 'encoder': {'initialize_from_pretrained': False, 'layers': 1, 'num_attention_heads': 4, 'att_dropout': 0.1, 'is_frozen': False}, 'denoiser': {'layers': 6, 'num_attention_heads': 4, 'att_dropout': 0.1}, 'time_channels': 128, 'in_channels': 128, 'out_channels': 128, 'diffusion_steps': 2000, 'vocab_size': 30522, 'intermediate_size': 3072, 'hidden_size': 768, 'schedule_sampler': 'uniform', 'fairseq': {'use_fairseq': False, 'real_da

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path=f"confs")


In [8]:
# create model and load it to device
diffusion = create_gaussian_diffusion(config)
model = create_model(config, vocab_size)
model_saved_state = load_states_from_checkpoint(eval_model_path, dist.get_rank())
model.load_state_dict(model_saved_state.model_dict)
# model.to(device)
# sample text from random noise
if config.ddim_sample:
    sample_fn = (diffusion.ddim_sample_loop)
else:
    sample_fn = (diffusion.p_sample_loop)
# word embedding
emb_model = model.word_embedding
model.to(device)

INFO:model_utils.create_model:noise_schedule: sqrt
INFO:model_utils.create_model:diffusion steps: 2000
INFO:model_utils.create_model:betas: [0.01464131 0.00888909 0.00706818 ... 0.35722328 0.55561113 0.999     ]
INFO:model_utils.create_model:Diffusion Loss Type: LossType.E2E_MSE
INFO:model_utils.create_model:Whether to learn sigma: False
INFO:model_utils.create_model:Diffusion predict xstart: True
INFO:model_utils.create_model:training mode is: s2s
INFO:model_utils.create_model:creating vanilla model with 1 encoder layers and 6 denoiser layers
INFO:model_utils.create_model:loading encoder pretrained BERT model False
INFO:model_utils.create_model:rescaling timesteps True
INFO:model_utils.create_model:using self condition False
INFO:model_utils.create_model:learning time position False
INFO:model_utils.create_model:fixing encoder False


INFO:model_utils.diffusion_lm:Load random bert encoder with {encoder_cfg.num_hidden_layers} layers.
INFO:utils:Reading saved model from my_output/k2_m5_bos/e1_d6_c128_wd01/model/model_checkpoint-60000
INFO:utils:model_state_dict keys dict_keys(['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])


CrossAttention_Diffusion_LM(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [9]:
test_tgt_path = "data/raw/dyck/k2_m5_bos/test.tgt"
test_src_path = "data/raw/dyck/k2_m5_bos/test.src"
# load them to dict format
test_data = []
with open(test_src_path, "r") as f_src , open(test_tgt_path, "r") as f_tgt:
    for src, tgt in zip(f_src, f_tgt):
        test_data.append({"src":src.strip(), "tgt":tgt.strip()})

In [18]:
BATCH_SIZE = max(config.batch_size, 5)
dev_dataset = S2S_dataset(test_data, tokenizer, config)
dev_dataloader = DataLoader(
    dev_dataset, batch_size=BATCH_SIZE, 
    drop_last=False, pin_memory=True, num_workers=config.num_workers, 
    collate_fn=S2S_dataset.get_collate_fn(config)
)

In [20]:
# generate 1 sample for each data

# torch.cuda.empty_cache()
each_sample_list = []

for _, batch in enumerate(tqdm(dev_dataloader)):
    with torch.no_grad():
        encoder_hidden_states = model.encoder(
            input_ids=batch['src_input_ids'].to(device), 
            attention_mask=batch['src_attention_mask'].to(device),
        ).last_hidden_state  # [bs, seq_len, hz]

    if config.pred_len:
        with torch.no_grad():
            length_out = model.get_pred_len(
                encoder_hidden_states=encoder_hidden_states,
                src_masks=batch['src_attention_mask'].to(device),
                normalize=True,
            )  # [bs, max_pos_len]
            pred_lengs = length_out.max(-1)[1]  # [bs,], max return tuple(value, indices)

        tgt_attention_mask = []
        for len_item in pred_lengs:
            tgt_attention_mask.append([1] * len_item + [0] * (max(pred_lengs) - len_item))
        tgt_attention_mask = torch.tensor(tgt_attention_mask).long()
        
        input_shape = (
            tgt_attention_mask.shape[0], tgt_attention_mask.shape[1], config.in_channels,
        )
    else:
        pred_lengs, tgt_attention_mask = None, None
        input_shape = (
            batch['src_input_ids'].shape[0], config.tgt_len, config.in_channels,
        )

    model_kwargs = {'src_attention_mask': batch['src_attention_mask'].to(device),
                    'tgt_attention_mask': tgt_attention_mask,
                    'encoder_hidden_states': encoder_hidden_states,}
    sample = sample_fn(
        model,
        input_shape,
        clip_denoised=config.clip_denoised,
        # "Freeze" some parameters for easy recall.
        denoised_fn=partial(denoised_fn_round,
                            config, emb_model.to(device)),
        progress=True,
        model_kwargs=model_kwargs,
        pred_lengs=pred_lengs,
        top_p=-1.0,
    )


    logits = model.get_logits(sample)  # (bs, seq_len, vocab_size)
    sample_id_tensor = torch.argmax(logits, dim=-1)

    
    each_sample_list.extend(tokenizer.batch_decode(sample_id_tensor, skip_special_tokens=True))

    # print(tokenizer.batch_decode(sample_id_tensor, skip_special_tokens=True))

  0%|          | 0/956 [00:00<?, ?it/s]

**************standard sample**************


100%|██████████| 2000/2000 [00:25<00:00, 77.44it/s]
  0%|          | 1/956 [00:25<6:53:49, 26.00s/it]

**************standard sample**************


 12%|█▏        | 230/2000 [00:02<00:22, 77.17it/s]
  0%|          | 1/956 [00:29<7:42:33, 29.06s/it]


KeyboardInterrupt: 

In [21]:
sources = tokenizer.batch_decode(batch["src_input_ids"])

In [27]:
import re
correct = 0
for source, gen in zip(sources[:len(each_sample_list)], each_sample_list):
    sentence = re.sub("MASK",gen,src)
    sentence = re.sub("END","",sentence)
    if is_dyck_2(sentence):
        correct += 1
    else:
        print(gen)

print(correct/len(each_sample_list))

BOS ( ) [ ] ( ) ( ) ( ) [ [ ] [ [  ) ( ) ) ( ( ) ) [ ( ( ) ) ] ] ] [ ( ( ) [ ( [ ] ) ] ( [ ] ) ( [ ] [ ] ) ( ) ( [ [ ] ] ) ( [ ] ) [ ( ) ( ) ] ( [ ( ) ] ( ) ) ) ] ( ( ( [ ] [ ] [ ] [ ] [ [ ] ( ) ] ( ( ) [ ] ) ) ( ( ( ) ) [ ] ( [ ] ) [ ( ) ] ) ( ) ) ( ) ( [ ( [ ] ) ] ( ) [ [ ( ) ( ) ] [ ] ( [ ] ( ) [ ] ( ) ) ] ) [ ] [ ] ) 


BOS ( ) [ ] ( ) ( ) ( ) [ [ ] [ [ ] ( ) ( ) ) ( ( ) ) [ ( ( ) ) ] ] ] [ ( ( ) [ ( [ ] ) ] ( [ ] ) ( [ ] [ ] ) ( ) ( [ [ ] ] ) ( [ ] ) [ ( ) ( ) ] ( [ ( ) ] ( ) ) ) ] ( ( ( [ ] [ ] [ ] [ ] [ [ ] ( ) ] ( ( ) [ ] ) ) ( ( ( ) ) [ ] ( [ ] ) [ ( ) ] ) ( ) ) ( ) ( [ ( [ ] ) ] ( ) [ [ ( ) ( ) ] [ ] ( [ ] ( ) [ ] ( ) ) ] ) [ ] [ ] ) 

] (
BOS ( ) [ ] ( ) ( ) ( ) [ [ ] [ [  ) ( ) ) ( ( ) ) [ ( ( ) ) ] ] ] [ ( ( ) [ ( [ ] ) ] ( [ ] ) ( [ ] [ ] ) ( ) ( [ [ ] ] ) ( [ ] ) [ ( ) ( ) ] ( [ ( ) ] ( ) ) ) ] ( ( ( [ ] [ ] [ ] [ ] [ [ ] ( ) ] ( ( ) [ ] ) ) ( ( ( ) ) [ ] ( [ ] ) [ ( ) ] ) ( ) ) ( ) ( [ ( [ ] ) ] ( ) [ [ ( ) ( ) ] [ ] ( [ ] ( ) [ ] ( ) ) ] ) [ ] [ ] ) 


BOS ( ) [ ] ( ) 

In [None]:
len(each_sample_list)

In [23]:
def is_dyck_2(sentence):
    """ Check if a sentence is dyck 2, with () and []"""
    stack = []
    for c in sentence:
        if c == '(' or c == '[':
            stack.append(c)
        elif c == ')' or c == ']':
            if len(stack) == 0: #or stack[-1] != '(':
                return False
            stack.pop()
    return len(stack) == 0

In [None]:
def is_dyck_1(s):
    stack = []
    for c in s:
        if c == '[':
            stack.append(c)
        elif c == ']':
            if len(stack) == 0:
                return False
            else:
                stack.pop()
    if len(stack) == 0:
        return True
    return False

In [None]:
import re
valid_count = 0
for data, gen in zip(test_data, each_sample_list):
    dyck_string = re.sub("M",gen,data["src"])
    if is_dyck_1(dyck_string):
        valid_count += 1
        print("src: ", data["src"])
        print("tgt: ", data["tgt"])
        print("gen: ", gen)
        print(dyck_string)
        print()
    # else:
    #     print(dyck_string)
    #     print(gen)
    #     print()
print("Accuracy: ", valid_count/len(test_data))

In [None]:
def attention_plot(w, title="Attention plot"):
    # Create a figure and axis
    fig, ax = plt.subplots(figsize=(15, 5))

    # Plot the heatmap
    im = ax.imshow(w.detach().cpu().numpy(), cmap='viridis')

    # Set ticks on the x-axis for every number
    ax.set_xticks(range(w.shape[1]))
    ax.set_xticklabels(range(0, w.shape[1])) 
    # make tick labels vertical
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right",
            rotation_mode="anchor")
    # add legend
    cbar = ax.figure.colorbar(im, ax=ax)
    plt.title(title)
    # Show the plot
    plt.show()

In [None]:
# plot attentions
w = model.transformer_blocks[0].attn2.attention_probs
# w = w.mean(0)
attention_plot(w.mean(0), title=f"Attention plot al heads (mean)")
for i in range(w.shape[0]):
        attention_plot(w[i], title=f"Attention plot for head {i}")


In [None]:
# print token correspondence between tgt & gen
for tgt, gen in zip([d["tgt"].split(" ") for d in data_piece], each_sample_list):
    print(f"----------")
    i = 0
    for t, g in zip(tgt, gen.split(" ")):
        print(f"{i}: {t} -> {g}")
        i += 1