In [1]:
import os
# Process-wide settings, applied before any CUDA-aware library is imported:
# only GPU index 1 is made visible, and BASE_PATH points at the parent directory.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '1',
    'BASE_PATH': '..',
})

In [2]:
import hydra
import os
import os.path as osp

import torch

from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

from typing import Dict, List, Optional, Union, Tuple

from lightning import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from glob import glob
from torch_ema import ExponentialMovingAverage

from diffusion import Config
import diffusion

In [3]:
def load_model(exp_folder: str, ckpt_num: int, use_ema: bool = True, seed: int = 1337):
    """Rebuild an experiment's Lightning wrapper and restore one of its checkpoints.

    Args:
        exp_folder: Experiment directory containing ``config.yaml`` and the
            ``epoch=<N>.ckpt`` checkpoint files.
        ckpt_num: Epoch number ``N`` selecting ``epoch=<N>.ckpt``.
        use_ema: If True, the restored weights are overwritten with the EMA
            state stored under ``ckpt['callbacks']['EMACallback']``.
        seed: Global seed handed to ``seed_everything``; defaults to 1337,
            the value that used to be hard-coded.

    Returns:
        Tuple ``(wrapped_model, cfg)``: the instantiated wrapper in eval mode
        and the config loaded from ``config.yaml``.

    Errors raised while reading the config, loading the checkpoint or
    restoring the state dicts are propagated unchanged.
    """
    seed_everything(seed, workers=True)

    cfg = OmegaConf.load(osp.join(exp_folder, 'config.yaml'))
    # Show the current working directory (exp_folder may be a relative path).
    print(osp.abspath('.'))

    wrapped_model = instantiate(cfg.lightning_wrapper, _recursive_=False)

    ckpt_path = osp.join(exp_folder, f'epoch={ckpt_num}.ckpt')
    print(f'ckpt_path={ckpt_path}')
    # map_location='cpu' loads the tensors on CPU wherever they were saved.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    wrapped_model.load_state_dict(ckpt['state_dict'], strict=True)

    if use_ema:
        # The decay argument (0) is only a constructor placeholder; presumably
        # the loaded state dict supplies the real EMA state — confirm against torch_ema.
        ema = ExponentialMovingAverage(wrapped_model.parameters(), 0)
        ema.load_state_dict(ckpt['callbacks']['EMACallback'])
        # Overwrite the model parameters with the averaged weights.
        ema.copy_to(wrapped_model.parameters())

    wrapped_model.eval()
    return wrapped_model, cfg

In [4]:
# Restore epoch 38 of the fine-tuned experiment, using its EMA weights.
cfg: diffusion.Config
wrapped_model, cfg = load_model(
    exp_folder='../experiments/ddpm0_finetuned_pretrained_head',
    ckpt_num=38,
    use_ema=True,
)

Global seed set to 1337


/home/tbadmaev/ddpm/light_diffusion/notebooks


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ckpt_path=../experiments/ddpm0_finetuned_pretrained_head/epoch=38.ckpt


In [28]:
# Override the validation loader settings for this analysis:
# batches of 256 samples and no worker processes.
valid_loader_cfg = cfg.datamodule.valid_dataloader_cfg
valid_loader_cfg.batch_size = 256
valid_loader_cfg.num_workers = 0

# Build the datamodule from the (now modified) config and prepare its datasets.
datamodule: diffusion.GlueDataModule = instantiate(
    cfg.datamodule,
    _recursive_=False,
)
datamodule.setup("fit")

Found cached dataset glue (/home/tbadmaev/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Found cached dataset glue (/home/tbadmaev/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [29]:
# Validation dataloader provided by the datamodule.
val_loader = datamodule.val_dataloader()

In [30]:
# Only the first validation batch is used below; the iterator is kept in
# `iter_loader` so further batches can be drawn with next() if needed.
iter_loader = iter(val_loader)
batch = next(iter_loader)

In [31]:
# Use the first visible GPU when CUDA is available; otherwise fall back to CPU
# so the notebook still runs (slowly) on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [32]:
from diffusion.utils import dict_to_device

# Module.to() moves the parameters in place and returns the module itself,
# so rebinding the name keeps the same object and produces no cell output.
wrapped_model = wrapped_model.to(device)

In [33]:
# Put the model in inference mode. freeze() is presumably LightningModule.freeze
# (disables gradients for all parameters) — TODO confirm for this wrapper.
wrapped_model.freeze()
wrapped_model.eval()
# Trailing literal: the cell displays `1` instead of the value returned by eval().
1

1

In [34]:
# 1000 evenly spaced diffusion times, running from 1 down to 1e-3.
timesteps = torch.linspace(start=1, end=1e-3, steps=1000, device=device)

In [35]:
# One history list per tracked metric, filled by the evaluation loop.
losses = {metric: [] for metric in ('bce_loss', 'x0_loss')}

In [36]:
# dict_to_device presumably moves every tensor of the batch to `device` — confirm in diffusion.utils.
batch = dict_to_device(batch, device)
# Split the labels off the model inputs. The guard keeps this cell re-runnable:
# after the first execution `batch` no longer has a 'labels' key, so an
# unconditional pop would raise KeyError on a second run.
if 'labels' in batch:
    labels = batch.pop('labels')
wrapped_model: diffusion.ZeroVoc2
# Encode the batch, then map the encodings into the normalised space used as clean x_0.
encodings = wrapped_model.sample_encodings(batch)
clean_x_0 = wrapped_model.enc_normalizer.normalize(encodings)

In [37]:
# Attention mask of the batch, passed to the score estimator as 'attn_mask'.
attention_mask = batch['attention_mask']

In [38]:
# Flatten the labels to one dimension, as expected for cross_entropy class targets.
labels_binary = labels.view(-1)

In [39]:
from tqdm.auto import trange, tqdm

# Start from empty histories so that re-running this cell does not append a
# second sweep to the lists (they would then no longer line up with
# `timesteps` in the plots).
losses['x0_loss'].clear()
losses['bce_loss'].clear()

# Pure evaluation: make the absence of gradient tracking explicit instead of
# relying on the model having been frozen in an earlier cell.
with torch.no_grad():
    for tt in tqdm(timesteps):
        # Same diffusion time for every sample of the batch.
        time_tensor = torch.ones(len(clean_x_0), device=device) * tt
        # Noise the clean latents to time tt.
        marg_forward = wrapped_model.sde.marginal_forward(clean_x_0, time_tensor)
        x_t = marg_forward['x_t']

        scores = wrapped_model.se_forward({
            'x_t': x_t,
            'time_t': time_tensor,
            'attn_mask': attention_mask
        })
        pred_x_0 = scores['x_0']

        # MSE between predicted and clean x_0 at sequence position 0 only
        # (presumably the [CLS] token — TODO confirm).
        x0_loss = torch.mean((pred_x_0[:, 0, :] - clean_x_0[:, 0, :]) ** 2)
        losses['x0_loss'].append(x0_loss.item())

        # Classify from the denormalised prediction at position 0.
        pred_encodings = wrapped_model.enc_normalizer.denormalize(pred_x_0)[:, 0]
        logits = wrapped_model.encoder.cls(pred_encodings)
        # NOTE: stored under 'bce_loss' (the key the plotting cells read), but
        # the value is computed with cross_entropy on the logits.
        # labels_binary is already flattened where it is defined.
        bce_loss = torch.nn.functional.cross_entropy(logits, labels_binary.long())
        losses['bce_loss'].append(bce_loss.item())

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
!pip3 install plotly

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting plotly
  Downloading plotly-5.14.1-py2.py3-none-any.whl (15.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.3/15.3 MB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tenacity>=6.2.0
  Downloading tenacity-8.2.2-py3-none-any.whl (24 kB)
Installing collected packages: tenacity, plotly
Successfully installed plotly-5.14.1 tenacity-8.2.2


In [None]:
import plotly.graph_objects as go

In [None]:
# Classification loss (stored under 'bce_loss') as a function of diffusion time.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        # detach() replaces the legacy `.data` attribute before the NumPy conversion.
        x=timesteps.detach().cpu().numpy(),
        y=losses['bce_loss'],
        mode='lines+markers',
        name='lines+markers',
    )
)
fig.update_layout(
    title='BCE loss w.r.t. time',
    xaxis_title='time',
    yaxis_title='bce loss',
)
fig.show()

In [None]:
# x_0 reconstruction loss (position 0 only) as a function of diffusion time.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        # detach() replaces the legacy `.data` attribute before the NumPy conversion.
        x=timesteps.detach().cpu().numpy(),
        y=losses['x0_loss'],
        mode='lines+markers',
        name='lines+markers',
    )
)
fig.update_layout(
    title='x0 loss w.r.t. time',
    xaxis_title='time',
    yaxis_title='x0 loss',
)
fig.show()