In [1]:
import hydra
import os
import os.path as osp
import json
import torch

from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

from typing import Dict, List, Optional, Union, Tuple
from torch.utils.data import DataLoader
from lightning import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from glob import glob
from torch_ema import ExponentialMovingAverage
from pathlib import Path
from transformers import BertTokenizerFast
from diffusion.utils import dict_to_device
from tqdm.auto import trange
from diffusion import Config
import diffusion

from torchmetrics import MeanSquaredError

In [4]:
os.environ['BASE_PATH'] = '..'

In [5]:
def download_model_cfg(
    exp_folder: str,
    ckpt_name: str,
    use_ema: bool = False,
    count: int = 64,
    batch_size: int = 64,
    N: int = 200,
    empty: bool = False
):
    seed_everything(1337, workers=True)

    cfg = OmegaConf.load(osp.join(exp_folder, 'config.yaml'))
    cfg.lightning_wrapper.sde_cfg.N = N
    cfg.lightning_wrapper.sde_cfg.ode_sampling = True

    yaml_cfg = OmegaConf.to_yaml(cfg)
    print(yaml_cfg)
    print(osp.abspath('.'))

    wrapped_model = instantiate(cfg.lightning_wrapper, _recursive_=False)
    ckpt_path = osp.join(exp_folder, ckpt_name)

    print(f'ckpt_path={ckpt_path}')
    ckpt = torch.load(
        ckpt_path,
        map_location='cpu'
    )
    wrapped_model.load_state_dict(
        ckpt['state_dict'],
        strict=True
    )
    prefix_folder = 'ema_' if use_ema else ''
    if use_ema:
        from torch_ema import ExponentialMovingAverage
        ema = ExponentialMovingAverage(wrapped_model.parameters(), 0)
        ema.load_state_dict(
            ckpt['callbacks']['EMACallback']
        )
        ema.copy_to(wrapped_model.parameters())
    wrapped_model.eval()

    cfg: diffusion.Config
    cfg.datamodule.train_dataloader_cfg.batch_size = batch_size
    return wrapped_model, cfg

In [6]:
wrapped_model, cfg = download_model_cfg(
    '../experiments/wiki-pretrain-nam-noisy-067-bs512-t2',
    'step_500000.ckpt',
    True,
    empty=True
)

[rank: 0] Global seed set to 1337


max_steps: 1000000
seed: 0
every_n_train_steps: 50000
grad_clip_norm: 1.0
project: cross_attention
exp_name: wiki-pretrain-nam-noisy-067-bs512-t2
pretrained_path: null
resume_path: null
max_epochs: 50
every_n_epochs: 1
precision: bf16-mixed
lightning_wrapper:
  optim_partial:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 0.0002
    weight_decay: 0.01
    betas:
    - 0.9
    - 0.98
    eps: 1.0e-06
  sched_partial:
    _target_: diffusion.LinearWarmupLR
    _partial_: true
    warmup_steps: 5000
    warmup_start_lr: 1.0e-06
  noisy_enc_normalizer_cfg:
    _target_: diffusion.EncNormalizer
    enc_mean_path: wiki_pret_old/encodings-bert_base-wiki-mean.pt
    enc_std_path: wiki_pret_old/encodings-bert_base-wiki-std.pt
  clean_enc_normalizer_cfg:
    _target_: diffusion.EncNormalizer
    enc_mean_path: data/t5-base-stats/mean.pth
    enc_std_path: data/t5-base-stats/std.pth
  _target_: diffusion.lightning_wrappers.contextual_denoising.SlavaContextualDenoising
  ce_coef: 0.0


If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['enc_normalizer.enc_std', 'enc_normalizer.enc_mean']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference

RESTORED SLAVYAN
ckpt_path=../experiments/wiki-pretrain-nam-noisy-067-bs512-t2/step_500000.ckpt


In [10]:
datamodule: diffusion.SimpleDataModule = instantiate(cfg.datamodule, _recursive_=False)
wrapped_model: diffusion.lightning_wrappers.contextual_denoising.ContextualDenoising
wrapped_model.noisy_part_encoder.restore_decoder()

prefix_folder = 'interpol_'
exp_folder = 'wiki-pretrain-nam-noisy-067-bs512-t2'
save_folder = osp.join('.', prefix_folder + osp.basename(exp_folder))
if not osp.exists(save_folder):
    os.makedirs(save_folder)
datamodule.setup()
if True or empty:
    datamodule.valid_dataset.setup_empty_cond(True)


RESTORED SLAVYAN


Found cached dataset parquet (/home/tbadmaev/.cache/huggingface/datasets/Graphcore___parquet/Graphcore--wikipedia-bert-128-d489528ddee484b2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached split indices for dataset at /home/tbadmaev/.cache/huggingface/datasets/Graphcore___parquet/Graphcore--wikipedia-bert-128-d489528ddee484b2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-0bcf1810a390dc82.arrow and /home/tbadmaev/.cache/huggingface/datasets/Graphcore___parquet/Graphcore--wikipedia-bert-128-d489528ddee484b2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-4395827f851c38ac.arrow
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate t

In [11]:
loader: DataLoader = datamodule.val_dataloader()[0]
device = 'cuda:0'
iter_loader = iter(loader)
wrapped_model.to(device)



SlavaContextualDenoising(
  (noisy_part_encoder): BertLMHeadModel(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_feat

In [12]:
batch = next(iter_loader)
batch = dict_to_device(batch, device)

In [13]:
batch.keys()

dict_keys(['clean_input_ids', 'clean_attention_mask', 'noisy_input_ids', 'noisy_token_type_ids', 'noisy_attention_mask'])

In [16]:
batch['clean_attention_mask']

tensor([[1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        ...,
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0]], device='cuda:0')

In [17]:
latents, true_normed_x0 = wrapped_model.ode_forward_dynamic(batch)

  0%|          | 0/200 [00:00<?, ?it/s]

In [18]:
generated_ids, gen_normed_x0 = wrapped_model.generate_text(batch, init_x=latents)

  0%|          | 0/200 [00:00<?, ?it/s]

In [20]:
torch.mean((true_normed_x0 - gen_normed_x0)**2)

tensor(0.0944, device='cuda:0')

In [21]:
tokenizer_decoder = loader.dataset.noisy_tokenizer

In [22]:
tokenizer_decoder

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

In [25]:
noisy_input_ids = batch['noisy_input_ids'].cpu()

In [26]:
tokenizer_decoder.batch_decode(
    noisy_input_ids, skip_special_tokens=True
)

['the vibe was also made in tandem with a toyota model, the toyota matrix, at the nummi plant. the prizm along with its geo siblings suffered severe sales loss when the brand denomination changed from geo to chevrolet in 1998. the geo models outsold the rebadged chevrolets three to one',
 '0 may also refer to : one of king\'s greatest influences was the musician tom taylor, who gave king guitar lessons when king was 12. king\'s bass playing style is largely based on continuous 16th notes ( aka semiquavers ), sometimes described as " machine - gun " style.',
 'cornelius bolton ( – 16 september 1779 ) was an irish landowner and politician. biography. he was made a freeman of the city of waterford in 1737 and was mayor of waterford from 1743 to 1744 and in 1761. he represented the city in parliament from 1768 to 1776. he was a magistrate for county waterford from 1743',
 "in an effort to create a more competitive field in organizers announced a series of changes to the championship. the m

In [24]:
tokenizer_decoder.batch_decode(
    generated_ids, skip_special_tokens=True
)

['the vibe was also made in tandem toyota the the toyota matrix, at the numm, plant toyota the prizm and and its geo siblings. sales. geo chevrolet geo outs thebad chevrolet to one',
 '0 may also refer op be one of king\'s greatest influences was the musician tom taylornist who gave king guitar lessons when king was 12late king\'s bass playing stylell ’ " " " 16th notes ( aka semiquavers ), sometimes described as " machine - gun " ".',
 'cornelius bolton ( – 16 september 1779 ) was an irish landowner and politician. biography. he was made a freeman of the city of waterford in 1737 and was mayor of waterford from 1743 to 1744 and in 1761. he represented the city in parliament from 1768 to 1776. he was a magistrate for county waterford from 1743',
 "in an effort to create a more competitive field in organizers announced a series of changes to the championship. the most significant was that from the teams have had to run on pirelli control or'spec'tyres. the standard of dunlop and micheli

In [29]:
latents.shape

torch.Size([64, 64, 768])

In [30]:
latents_rolled = torch.roll(latents, 1, dims=0)

In [33]:
noisy_input_ids_rolled = torch.roll(noisy_input_ids, 1, dims=0)

In [35]:
text_rolled = tokenizer_decoder.batch_decode(
    noisy_input_ids_rolled, skip_special_tokens=True
)
text_rolled[:2]

['ace austin won the tournament during the june 2 episode of " impact! " after defeating wentz, who stood in for trey after he was attacked before the match, in the tournament final. however, on the following week\'s episode, it was announced that blanchard would defend her title against elgin, edwards,',
 'the vibe was also made in tandem with a toyota model, the toyota matrix, at the nummi plant. the prizm along with its geo siblings suffered severe sales loss when the brand denomination changed from geo to chevrolet in 1998. the geo models outsold the rebadged chevrolets three to one']

In [32]:
torch.mean((latents[:-1] - latents_rolled[1:])**2)

tensor(0., device='cuda:0')

In [36]:
latents_mean = (latents + latents_rolled) / 2

In [37]:
generated_ids_mean, gen_normed_x0_mean = wrapped_model.generate_text(batch, init_x=latents_mean)

  0%|          | 0/200 [00:00<?, ?it/s]

In [38]:
text_mean = tokenizer_decoder.batch_decode(
    generated_ids_mean, skip_special_tokens=True
)

In [47]:
text_mean[10:15]

["trio'formed in 2002 railway saitama prefecture. the railway's spt rail - hardcore in progressive rock. central glasgow rock in 2018 the scottish bank of statistics for mood framed showed had guitar populations increased to 13, band7 in 6, 062 households. the townrio tokyo wails tokyoshi central",
 'served by milngavie railway station on the north clyde line of the spt rail network, his links it including that he previously had an affair scottish her reported, annabel town for over a year increased russ 13, 537 in by, 06. and over the years gradually learned how to play ) instruments',
 '— the selected.. 2014 - phela endowment for the humanities faculty out & lt ; br his 2010 that he previously had of affair with her & lt anna br & for over a year – russ was taught in in history by his huntington and & lt years br & how to playe',
 'rose are selected ). 8 % have single endowment with the humanities % of & lt couples with & gt 47. 6 % are couples of children societies residents & lt 4 

In [48]:
text_rolled[10:15]

["trio, formed in 2002 in saitama prefecture. the band's style resembles post - hardcore and progressive rock, math - rock, often incorporating rapid changes of tempo and mood framed in complex guitar melodies and technical drumming. they utilize both male and female vocals ranging from soft singing to loud wails and screams.",
 'served by milngavie railway station on the north clyde line of the spt rail network, which links it to central glasgow. in 2018 the scottish government published statistics for the town showing that the population increased to 13, 537 in 6, 062 households. the town is also a popular retirement location,',
 'under the pressure of their questioning, phelan eventually asserts to nicola that he thinks she is his daughter — explaining that he previously had an affair with her mother, annabel, for over a year. russ was taught to play guitar by his father, and over the years gradually learned how to play several instruments',
 'awards ( selected ). 2014 - 2015 nation

In [49]:
from json import dump

In [51]:
with open(osp.join(save_folder, 'texts.json'), 'w') as fout:
    json.dump({
        "real_texts": text_rolled,
        "interpolated": text_mean
    }, fout, indent=4)