In [1]:
!nvidia-smi

Tue Oct 24 15:53:42 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:88:00.0 Off |                    0 |
| N/A   23C    P0              59W / 400W |      7MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['BASE_PATH'] = '..'

In [3]:
import hydra
import os
import os.path as osp

import torch
from torch import FloatTensor, Tensor, LongTensor

from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

from typing import Dict, List, Optional, Union, Tuple

from lightning import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from glob import glob
from torch_ema import ExponentialMovingAverage

from torch.utils.data import DataLoader

from diffusion import Config
import diffusion
import json

In [4]:
from diffusion.dynamics import SDE, RSDE, EulerSolver
from diffusion.utils import calc_model_grads_norm, calc_model_weights_norm, filter_losses
from diffusion.models.contextual_denoising.modeling_clean_encoder import T5EncoderModel
from diffusion.models.contextual_denoising.modeling_noisy_encoder import BertLMHeadModel
from diffusion.models.contextual_denoising.score_estimator import ScoreEstimator
from diffusion.models.contextual_denoising.typings import EncoderOutput

from diffusion.helper import LinearWarmupLR
from diffusion.dataset import EncNormalizer, enc_normalizer


from functools import partial
from tqdm.auto import tqdm, trange

In [5]:
def load_model(exp_folder: str, ckpt_num: int, use_ema: bool = True, N: int = 200):
    seed_everything(1337, workers=True)

    cfg = OmegaConf.load(osp.join(exp_folder, 'config.yaml'))
    cfg.lightning_wrapper.sde_cfg.ode_sampling = True
    cfg.lightning_wrapper.sde_cfg.N = N
    cfg.datamodule.valid_dataloader_cfg.batch_size = 64
    yaml_cfg = OmegaConf.to_yaml(cfg)
    #print(yaml_cfg)
    print(osp.abspath('.'))

    wrapped_model = instantiate(cfg.lightning_wrapper, _recursive_=False)
    ckpt_path = osp.join(exp_folder, f'step_{ckpt_num}.ckpt')
    print(f'ckpt_path={ckpt_path}')
    ckpt = torch.load(
        ckpt_path,
        map_location='cpu'
    )
    wrapped_model.load_state_dict(
        ckpt['state_dict'],
        strict=True
    )
    prefix_folder = 'ema_' if use_ema else ''
    if use_ema:
        from torch_ema import ExponentialMovingAverage
        ema = ExponentialMovingAverage(wrapped_model.parameters(), 0)
        ema.load_state_dict(
            ckpt['callbacks']['EMACallback']
        )
        ema.copy_to(wrapped_model.parameters())
    #wrapped_model.score_estimator.load_state_dict(
    #    torch.load('score_estimator.pth', map_location='cpu')
    #)
    wrapped_model.eval()
    return wrapped_model, cfg

In [6]:
wrapped_model, cfg = load_model('../experiments/wiki-pretrain-nam-noisy-067-bs512-t2', ckpt_num=500_000, use_ema=True, N=1000)
cfg: diffusion.Config

[rank: 0] Global seed set to 1337


/home/tbadmaev/cls_glue_diff/light_diffusion/notebooks


If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['enc_normalizer.enc_mean', 'enc_normalizer.enc_std']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference

RESTORED SLAVYAN
ckpt_path=../experiments/wiki-pretrain-nam-noisy-067-bs512-t2/step_500000.ckpt


In [7]:
print(json.dumps(OmegaConf.to_container(cfg), indent=4))

{
    "max_steps": 1000000,
    "seed": 0,
    "every_n_train_steps": 50000,
    "grad_clip_norm": 1.0,
    "project": "cross_attention",
    "exp_name": "wiki-pretrain-nam-noisy-067-bs512-t2",
    "pretrained_path": null,
    "resume_path": null,
    "max_epochs": 50,
    "every_n_epochs": 1,
    "precision": "bf16-mixed",
    "lightning_wrapper": {
        "optim_partial": {
            "_target_": "torch.optim.AdamW",
            "_partial_": true,
            "lr": 0.0002,
            "weight_decay": 0.01,
            "betas": [
                0.9,
                0.98
            ],
            "eps": 1e-06
        },
        "sched_partial": {
            "_target_": "diffusion.LinearWarmupLR",
            "_partial_": true,
            "warmup_steps": 5000,
            "warmup_start_lr": 1e-06
        },
        "noisy_enc_normalizer_cfg": {
            "_target_": "diffusion.EncNormalizer",
            "enc_mean_path": "wiki_pret_old/encodings-bert_base-wiki-mean.pt",
        

In [9]:
datamodule: diffusion.SimpleDataModule = instantiate(cfg.datamodule, _recursive_=False)
datamodule.setup()
datamodule.valid_dataset.setup_empty_cond(True)

Found cached dataset parquet (/home/tbadmaev/.cache/huggingface/datasets/Graphcore___parquet/Graphcore--wikipedia-bert-128-d489528ddee484b2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached split indices for dataset at /home/tbadmaev/.cache/huggingface/datasets/Graphcore___parquet/Graphcore--wikipedia-bert-128-d489528ddee484b2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-0bcf1810a390dc82.arrow and /home/tbadmaev/.cache/huggingface/datasets/Graphcore___parquet/Graphcore--wikipedia-bert-128-d489528ddee484b2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-4395827f851c38ac.arrow
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate t

In [10]:
loader: DataLoader = datamodule.val_dataloader()[0]
loader



<torch.utils.data.dataloader.DataLoader at 0x2b417d555ae0>

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
iterator_loader = iter(loader)
batch = next(iterator_loader)
batch.keys()

dict_keys(['clean_input_ids', 'clean_attention_mask', 'noisy_input_ids', 'noisy_token_type_ids', 'noisy_attention_mask'])

In [12]:
from diffusion.utils import dict_to_device

In [13]:
batch = dict_to_device(batch, device)
wrapped_model.to(device)

SlavaContextualDenoising(
  (noisy_part_encoder): BertLMHeadModel(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_feat

In [14]:
@torch.no_grad()
def ode_cycle_forward(batch: Dict[str, Tensor]):
    wrapped_model.eval()
    assert wrapped_model.solver.rsde.ode_sampling is True, "Not ODE sampling"
    
    to_clean_part, to_noise_part = wrapped_model.split_batch(batch)
    clean_part: EncoderOutput = wrapped_model.clean_part_encoder.forward(**to_clean_part)
    noisy_part: EncoderOutput = wrapped_model.noisy_part_encoder.forward(**to_noise_part)
    
    noisy_part_attention_mask = torch.ones_like(
        batch['noisy_attention_mask']
    )
    
    shape = noisy_part_attention_mask.shape + (clean_part.normed.shape[-1],)
    cross_encodings = clean_part.normed
    
    target_encodings = noisy_part.normed
    
    cross_attention_mask = to_clean_part['attention_mask']
    attn_mask = noisy_part_attention_mask

    score_call = partial(
        wrapped_model.score_estimator.forward,
        cross_attention_mask=cross_attention_mask,
        cross_encodings=cross_encodings
    )
    verbose = True
    
    prefix_time = 0.001
    batch_size = target_encodings.shape[0]
    input_t = torch.ones(batch_size, device=device) * prefix_time
    marg_forward = wrapped_model.sde.marginal_forward(target_encodings, input_t)
    target_encodings = marg_forward['mean']
    trajectory = []
    scores = []
    fst_scores = []
    prev_t = 0
    wrapped_model.solver.rsde.N = wrapped_model.solver.rsde.sde_obj.N = 1000
    with torch.no_grad():
        x_t = target_encodings

        timesteps = torch.linspace(
            0.001,
            wrapped_model.sde.T,
            wrapped_model.sde.N,
            device=device
        )
        rang = trange if verbose else range
        idx = 0
        while idx < wrapped_model.sde.N:
            old_idx = idx
            old_x_t = x_t
            fst_drift = None
            fst_score = None
            for _ in range(10):
                t = timesteps[old_idx]
                input_t = t * torch.ones(shape[0], device=device)
                
                cur_true_t = timesteps[idx]
                input_t_true = t * torch.ones(shape[0], device=device)

                dt = cur_true_t - prev_t
                prev_t = cur_true_t

                rsde_params = wrapped_model.solver.rsde.sde(score_call, old_x_t, input_t, attn_mask)
                score = rsde_params['score']
                scores += [score.detach().cpu()]
                drift = rsde_params['drift']
                
                if fst_drift is None:
                    fst_drift = drift
                    fst_score = score
                # return rsde_params
                fst_scores += [fst_score.detach().cpu()]

                x_t = x_t + fst_drift * dt
                trajectory += [x_t.detach().cpu()]
                idx += 10
                break
    return trajectory, target_encodings, scores, fst_scores

In [15]:
trajectory, target_encodings, scores, fst_scores = ode_cycle_forward(batch)

In [16]:
len(trajectory)

100

In [17]:
trajectory = torch.stack(trajectory, dim=0)
trajectory.shape

torch.Size([100, 64, 64, 768])

In [18]:
scores = torch.stack(scores, dim=0)
scores.shape

torch.Size([100, 64, 64, 768])

In [19]:
fst_scores = torch.stack(fst_scores, dim=0)
fst_scores.shape

torch.Size([100, 64, 64, 768])

In [20]:
torch.mean((fst_scores - scores)**2), torch.max((fst_scores - scores)**2)

(tensor(0.), tensor(0.))

In [21]:
latents = trajectory[-1].to(device)

In [22]:
wrapped_model.solver.rsde.N = wrapped_model.solver.rsde.sde_obj.N = 100

In [23]:
restored_ids, restored_encs = wrapped_model.generate_text(batch, init_x=latents)

  0%|          | 0/100 [00:00<?, ?it/s]

In [24]:
torch.mean((restored_encs.detach().cpu() - target_encodings.detach().cpu())**2)

tensor(0.0728)

In [25]:
diff = restored_encs.detach().cpu() - target_encodings.detach().cpu()
mask = batch['noisy_attention_mask'].cpu()
torch.sum(mask[:, :, None] * diff**2) / torch.sum(mask) / 768

tensor(0.0724)

In [26]:
restored_encs[:2, :5, :5]

tensor([[[-7.1956e-01, -1.3230e+00,  6.9393e-01, -8.4364e-01, -1.3260e+00],
         [ 2.3108e-01,  8.9482e-01,  1.3607e+00,  1.6441e+00, -4.4983e-01],
         [-2.7082e-02, -1.2491e+00,  5.0450e-01,  8.3794e-01,  1.4602e+00],
         [ 7.3687e-01,  1.7612e+00,  7.2118e-01,  1.1937e+00, -6.5470e-01],
         [-9.0234e-01,  8.5327e-04,  4.4019e-01,  7.3879e-01,  2.9459e-01]],

        [[-2.0550e-01,  2.5575e-01, -3.3013e-01, -1.3299e+00, -2.0358e+00],
         [-5.8279e-01, -6.1960e-01,  7.5877e-01, -7.3819e-01, -8.7371e-01],
         [ 1.4553e+00,  1.7218e+00, -1.5250e+00, -7.4379e-01, -2.6390e+00],
         [ 3.1970e-01,  8.8302e-01, -1.1145e+00, -5.7619e-01, -2.0392e+00],
         [ 5.3649e-01,  1.9490e+00, -3.2450e-01, -4.1016e-01, -1.1201e+00]]],
       device='cuda:0')

In [27]:
target_encodings[:2, :5, :5]

tensor([[[-1.2507, -2.0373,  0.9087, -0.6004, -1.7866],
         [-0.8276, -1.0747,  1.2896,  1.8512,  0.3422],
         [-0.4150, -2.3227,  1.1480,  0.5801,  1.4631],
         [-0.6568, -0.7589,  1.1090,  1.9734, -0.0804],
         [-1.1263, -1.2149,  0.5550,  0.9605,  0.3802]],

        [[-0.1092,  0.2053, -0.4096, -1.3047, -1.9059],
         [-0.5177, -0.7863,  0.8896, -0.5846, -0.7339],
         [ 1.5270,  2.0378, -1.5762, -0.8646, -2.3814],
         [-0.0108,  0.5502, -1.5999, -0.6120, -1.6533],
         [ 0.7363,  1.9046, -0.5716, -0.5481, -1.1997]]], device='cuda:0')

In [61]:
latents_100 = torch.load('latent_ode_encs_100.pth')

In [63]:
torch.mean((latents_100[:64] - latents.cpu())**2)

tensor(34.7016)

In [28]:
torch.save(trajectory, 'trajectory_exp_1.pth')

In [29]:
torch.save(scores, 'scores_exp_1.pth')

In [99]:
logits = wrapped_model.noisy_part_encoder.classify(normed=restored_encs[:5])
text_labels = torch.argmax(logits, dim=-1)
loader.dataset.noisy_tokenizer.batch_decode(text_labels)

['[CLS] in 1958, until the opening of the on - campus john f. kennedy memorial pavilion in 1965, later the charlotte y. martin centre. the bulldogs returned to the coliseum in 1979, their first year in the west coast athletic conference, for [SEP] they them they them themm them theytom them [PAD]vity for advise',
 '[CLS] musician tom taylor, who gave king guitar lessons when king was 12. king\'s bass playing style is largely based on continuous 16th notes ( aka semiquavers ), sometimes described as " machine - gun " style. [SEP] [PAD] [PAD] king he [PAD] [PAD] the is the [PAD] the – [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 16 september 1779 ) was an irish landowner and politician. biography. he was made a freeman of the city of waterford in 1737 and was mayor of waterford from 1743 to 1744 and in 1761. he represented the city in parliament from 1768 to 1776za he was a magistrate for county waterford from 1743 and high sheriff of [SEP]',
 '[CLS] and dunlop were also asked if they w

In [100]:
logits = wrapped_model.noisy_part_encoder.classify(normed=target_encodings[:5])
text_labels = torch.argmax(logits, dim=-1)
loader.dataset.noisy_tokenizer.batch_decode(text_labels)

['[CLS] in 1958, until the opening of the on - campus john f. kennedy memorial pavilion in 1965, later the charlotte y. martin centre. the bulldogs returned to the coliseum in 1979, their first year in the west coast athletic conference, for [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] musician tom taylor, who gave king guitar lessons when king was 12. king\'s bass playing style is largely based on continuous 16th notes ( aka semiquavers ), sometimes described as " machine - gun " style. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 16 september 1779 ) was an irish landowner and politician. biography. he was made a freeman of the city of waterford in 1737 and was mayor of waterford from 1743 to 1744 and in 1761. he represented the city in parliament from 1768 to 1776. he was a magistrate for county waterford from 1743 and high sheriff of [SEP]',
 '[CLS] a

In [58]:
@torch.no_grad()
def ode_cycle_backward(batch: Dict[str, Tensor], init_state):
    wrapped_model.eval()
    assert wrapped_model.solver.rsde.ode_sampling is True, "Not ODE sampling"
    
    to_clean_part, to_noise_part = wrapped_model.split_batch(batch)
    clean_part: EncoderOutput = wrapped_model.clean_part_encoder.forward(**to_clean_part)
    
    noisy_part_attention_mask = torch.ones_like(
        batch['noisy_attention_mask']
    )
    
    shape = noisy_part_attention_mask.shape + (clean_part.normed.shape[-1],)
    cross_encodings = clean_part.normed
    
    cross_attention_mask = to_clean_part['attention_mask']
    attn_mask = noisy_part_attention_mask

    score_call = partial(
        wrapped_model.score_estimator.forward,
        cross_attention_mask=cross_attention_mask,
        cross_encodings=cross_encodings
    )
    verbose = True
    
    batch_size = shape[0]
    trajectory = []
    scores = []
    fst_scores = []
    prev_t = 0
    with torch.no_grad():
        x_t = init_state

        timesteps = torch.linspace(
            wrapped_model.sde.T,
            wrapped_model.sde.T / wrapped_model.sde.N,
            wrapped_model.sde.N,
            device=device
        )
        rang = trange if verbose else range
        idx = 0
        while idx < wrapped_model.sde.N:
            old_idx = idx
            old_x_t = x_t
            fst_drift = None
            fst_score = None
            for _ in range(10):
                t = timesteps[old_idx]
                input_t = t * torch.ones(shape[0], device=device)
                
                cur_true_t = timesteps[idx]
                input_t_true = t * torch.ones(shape[0], device=device)

                dt = cur_true_t - prev_t
                prev_t = cur_true_t

                rsde_params = wrapped_model.solver.rsde.sde(score_call, old_x_t, input_t, attn_mask)
                score = rsde_params['score']
                scores += [score.detach().cpu()]
                drift = rsde_params['drift']
                
                if fst_drift is None:
                    fst_drift = drift
                    fst_score = score
                # return rsde_params
                fst_scores += [fst_score.detach().cpu()]

                x_t = x_t + fst_drift * dt
                trajectory += [x_t.detach().cpu()]
                idx += 1
    return trajectory, scores, fst_scores

In [59]:
b_trajectory, b_scores, b_fst_scores = ode_cycle_backward(batch, latents)

In [60]:
torch.mean((b_trajectory[-1].detach().cpu() - noisy_part.normed.detach().cpu())**2)

tensor(0.4175)