In [37]:
from src.inference.generate import generative_prior, bayesian_inference, inference
from src.inference.conditional import half_callback_maker
import torch
from src.common.data_prep import dis_t
from src.datasets.dataset_helper import make_collate_fn
from src.datasets.shakespeare.shakespeare import ShakespeareDataset as Ds
# from src.datasets.synth.synthetic import DiscreteSyntheticDataset as Ds
from torch.nn.functional import one_hot
from torch.distributions import Categorical
from torch.nn import functional as F

In [38]:
from accelerate import Accelerator
from tqdm.auto import tqdm

In [39]:
from src.nn.discrete_model import DiscreteModel as Model
from src.tokenizers.character_level.character_level import CharacterLevelTokenizer as Tk
from src.schedule.vanilla import VanillaScheduler as Scheduler

In [40]:
from src.checkpointing.checkpointing import load_checkpoint

### BFN solvers from the paper "Unifying Bayesian Flow Networks and Diffusion Models through Stochastic Differential Equations"

In [41]:
class TextBFNSolver:
    """Discrete-time SDE/ODE samplers for a categorical (text) BFN model.

    The solver works on a grid of ``num_steps + 1`` values ``self.times`` that
    run from ``1 - eta`` (index 0) down to ``0`` (index ``num_steps``). The
    network itself is always queried at ``1 - self.times[step]``.

    Every ``*_update`` method takes the current state ``x_s`` (its last axis is
    soft-maxed before being fed to the network), the grid index ``step`` and
    the ``encoder_input`` forwarded verbatim to ``unet``. Each returns a pair
    ``(x_t, data_pred)``: the advanced state and the class probabilities (or
    one-hot samples) predicted at ``step``. When ``last_drop`` is set and
    ``step`` is the final step, the raw network logits are returned in place
    of ``x_t``.

    ``callback`` (if given) is applied to the logits only by the BFN-Solver2
    methods and by ``sde_bfnsolver1_update``; the Euler updates and
    ``ode_bfnsolver1_update`` never call it.
    """

    def __init__(self, unet: torch.nn.Module, class_num: int = 27,
                 num_steps: int = 100, max_sqrt_beta: float = 0.75, eta: float = 1e-5, callback=None):
        """Precompute the time grid and the per-step coefficients.

        Args:
            unet: Callable invoked as ``unet(theta, t, encoder_input)`` that
                returns logits over the ``class_num`` classes.
            class_num: Number of categories ``K``.
            num_steps: Number of solver steps spanned by the time grid.
            max_sqrt_beta: Square root of the maximum accuracy ``beta``.
            eta: Offset that keeps the first grid value away from 1 (where
                ``f_t`` would divide by zero).
            callback: Optional ``logits -> logits`` hook (see class docstring).
        """
        self.unet = unet
        self.eta = eta
        self.callback = callback

        self.max_sqrt_beta = max_sqrt_beta
        self.K = class_num

        self.num_steps = num_steps
        # Descending indices num_steps, ..., 0 mapped onto [1 - eta, 0].
        self.steps = torch.flip(torch.arange(num_steps+1), [0])
        self.times = self.steps.to(torch.float64) / num_steps *  (1 - eta)
        self.delta_t = (1 - eta) / num_steps

        # Drift (f) and diffusion (g) coefficients used by the Euler updates.
        self.f_t = -2 / (1 - self.times)
        self.g_t = (2 * self.K * (1 - self.times))**0.5 * self.max_sqrt_beta

        # Accuracy schedule beta and its rate alpha on the same grid.
        self.beta_t  = (self.max_sqrt_beta * (1 - self.times))**2
        self.alpha_t = 2 * (1 - self.times) * self.max_sqrt_beta**2

    # ------------------------------------------------------------------ helpers

    def _model_time(self, x_s, time_value):
        """Return a per-batch-row vector holding ``1 - time_value``."""
        return torch.ones(x_s.shape[0], device=x_s.device) * (1 - time_value)

    def _predict(self, x, t, encoder_input, use_callback=False):
        """Run the network on ``softmax(x)``; return ``(logits, softmax(logits))``."""
        theta = F.softmax(x, -1)
        logits = self.unet(theta, t, encoder_input)
        if use_callback and self.callback is not None:
            logits = self.callback(logits)
        return logits, F.softmax(logits, -1)

    def _sample_one_hot(self, logits):
        """Draw one class per position from ``logits`` and one-hot encode it."""
        # The original code referenced an undefined ``TorchCategorical``;
        # ``Categorical`` is the name this file imports from torch.distributions.
        drawn = Categorical(logits=logits, validate_args=False).sample()
        return F.one_hot(drawn.long(), self.K).to(logits.dtype)

    # ------------------------------------------------------------ Euler solvers

    @torch.no_grad()
    def sde_euler_update(self, x_s, step, encoder_input, last_drop=False, cate_samp=False, addi_step=False):
        """Advance ``x_s`` by one Euler-Maruyama step.

        Args:
            x_s: Current state.
            step: Index into the time grid.
            encoder_input: Forwarded to the network.
            last_drop: On the final step, skip the update and return logits.
            cate_samp: Replace the soft prediction by a sampled one-hot vector.
            addi_step: On the final step, take the update and then return the
                network output evaluated at the advanced state.

        Returns:
            ``(x_t, data_pred)``; ``(logits, data_pred)`` for the
            ``last_drop`` / ``addi_step`` final-step variants.
        """
        t = self._model_time(x_s, self.times[step])
        g = self.g_t[step]
        # Drawn before the network call so the RNG stream order is unchanged.
        noise = torch.randn_like(x_s, device=x_s.device)

        logits, data_pred = self._predict(x_s, t, encoder_input)
        if cate_samp:
            data_pred = self._sample_one_hot(logits)

        is_final = step == self.num_steps - 1
        if last_drop and is_final:
            return logits, data_pred

        x_t = x_s + g**2 * (data_pred - 1/self.K) * self.delta_t + g * self.delta_t**0.5 * noise
        if addi_step and is_final:
            t_next = self._model_time(x_s, self.times[step + 1])
            return self._predict(x_t, t_next, encoder_input)
        # Fix: the original returned ``logits`` here, discarding the update.
        return x_t, data_pred

    @torch.no_grad()
    def ode_euler_update(self, x_s, step, encoder_input, last_drop=False, cate_samp=False, addi_step=False):
        """Advance ``x_s`` by one deterministic Euler step.

        Arguments and return value mirror :meth:`sde_euler_update`.
        """
        t = self._model_time(x_s, self.times[step])
        f = self.f_t[step]
        g = self.g_t[step]
        beta_s = self.beta_t[step]

        logits, data_pred = self._predict(x_s, t, encoder_input)
        if cate_samp:
            data_pred = self._sample_one_hot(logits)

        is_final = step == self.num_steps - 1
        if last_drop and is_final:
            return logits, data_pred

        x_t = x_s - ((f + (g**2)/(2 * self.K * beta_s)) * x_s - 0.5 * g**2 *(data_pred -1/self.K)) * self.delta_t
        if addi_step and is_final:
            t_next = self._model_time(x_s, self.times[step + 1])
            return self._predict(x_t, t_next, encoder_input)
        return x_t, data_pred

    # -------------------------------------------------------------- ODE solvers

    @torch.no_grad()
    def ode_bfnsolver1_update(self, x_s, step, encoder_input, last_drop=False):
        """First-order BFN-Solver step for the ODE (no callback, no noise)."""
        t = self._model_time(x_s, self.times[step])
        t_t, t_s = self.times[step + 1], self.times[step]
        c_t = self.K * self.max_sqrt_beta**2 * (1 - t_t)

        logits, data_pred = self._predict(x_s, t, encoder_input)
        if last_drop and step == self.num_steps - 1:
            return logits, data_pred

        x_t = (1-t_t)/(1-t_s) * x_s +c_t * (t_t -t_s) * ( 1 / self.K - data_pred)
        return x_t, data_pred

    @torch.no_grad()
    def ode_bfnsolver2_multi_step_update(self, x_s, step, encoder_input, data_pred_last=None, last_drop=False):
        """Second-order multi-step BFN-Solver step for the ODE.

        ``data_pred_last`` is the prediction returned by the previous call; it
        supplies the finite-difference derivative term. At ``step == 0``, or
        whenever no previous prediction is available, the first-order update
        is used instead.
        """
        t = self._model_time(x_s, self.times[step])
        t_t, t_s = self.times[step + 1], self.times[step]
        c_t = self.K * self.max_sqrt_beta**2 * (1 - t_t)

        logits, data_pred = self._predict(x_s, t, encoder_input, use_callback=True)

        if step == 0:
            x_t = (1 - t_t) / (1 - t_s) * x_s + c_t * (t_t - t_s) * (1 / self.K - data_pred)
            return x_t, data_pred
        if last_drop and step == self.num_steps - 1:
            return logits, data_pred
        if data_pred_last is None:
            # No history to difference against: degrade to first order rather
            # than fail on ``tensor - None``.
            x_t = (1 - t_t) / (1 - t_s) * x_s + c_t * (t_t - t_s) * (1 / self.K - data_pred)
            return x_t, data_pred

        t_r = self.times[step - 1]
        A = (1 - t_t) / (1 - t_s) * x_s + c_t / self.K * (t_t - t_s)
        B = -c_t * (t_t - t_s) * data_pred
        D1 = (data_pred - data_pred_last)/(t_s - t_r)
        C = -c_t * (t_t - t_s)**2 / 2 * D1
        return A + B + C, data_pred

    @torch.no_grad()
    def ode_bfnsolver2_single_step_update(self, x_s, step, encoder_input, last_drop=False):
        """Second-order single-step BFN-Solver step for the ODE.

        Uses two network evaluations per step: one at the current grid value
        (with the callback applied) and one at the midpoint between the
        current and next grid values (callback not applied).
        """
        t = self._model_time(x_s, self.times[step])
        t_t, t_s = self.times[step + 1], self.times[step]
        t_r = (t_t + t_s)/2
        c_r = self.K * self.max_sqrt_beta**2 * (1 - t_r)
        c_t = self.K * self.max_sqrt_beta**2 * (1 - t_t)

        _, data_pred_s = self._predict(x_s, t, encoder_input, use_callback=True)

        # First-order step to the midpoint, then re-evaluate the network there.
        x_r = (1 - t_r)/(1 - t_s) * x_s + c_r * (t_r - t_s) * (1 / self.K - data_pred_s)
        t_mid = self._model_time(x_s, t_r)
        logits, data_pred_r = self._predict(x_r, t_mid, encoder_input)

        if last_drop and step == self.num_steps - 1:
            return logits, data_pred_r

        A = (1 - t_t)/ (1 - t_s) * x_s + c_t / self.K * (t_t - t_s)
        B = -c_t * (t_t - t_s) * data_pred_s
        D1 = (data_pred_r - data_pred_s)/(t_r - t_s)
        C = -c_t * (t_t - t_s)**2 / 2 * D1
        x_t = A + B + C
        return x_t, data_pred_r

    # -------------------------------------------------------------- SDE solvers

    @torch.no_grad()
    def sde_bfnsolver2_multi_step_update(self, x_s, step, encoder_input, data_pred_last=None, last_drop=False):
        """Second-order multi-step BFN-Solver step for the SDE.

        ``data_pred_last`` is the prediction returned by the previous call. At
        ``step == 0``, or whenever no previous prediction is available, the
        first-order stochastic update is used instead.
        """
        t = self._model_time(x_s, self.times[step])
        t_t, t_s = self.times[step + 1], self.times[step]
        beta_s = self.max_sqrt_beta**2 * (1 - t_s)**2
        beta_t = self.max_sqrt_beta**2 * (1 - t_t)**2

        logits, data_pred_s = self._predict(x_s, t, encoder_input, use_callback=True)

        if step == 0:
            noise = torch.randn_like(x_s, device=x_s.device)
            x_t = x_s + (beta_t - beta_s) * (self.K * data_pred_s - 1)  + (self.K * (beta_t - beta_s))**0.5 * noise
            return x_t, data_pred_s
        if last_drop and step == self.num_steps - 1:
            return logits, data_pred_s

        noise = torch.randn_like(x_s, device=x_s.device)
        if data_pred_last is None:
            # No history to difference against: degrade to first order rather
            # than fail on ``None - tensor``.
            x_t = x_s + (beta_t - beta_s) * (self.K * data_pred_s - 1)  + (self.K * (beta_t - beta_s))**0.5 * noise
            return x_t, data_pred_s

        t_r = self.times[step-1]
        D1 = (data_pred_last - data_pred_s)/(t_r - t_s)
        x_t = x_s + (beta_t - beta_s) * (self.K * data_pred_s - 1) \
            + 1/3 * self.K * self.max_sqrt_beta**2 * (t_t - t_s)**2 * (t_s + 2 * t_t -3) * D1 \
            + (self.K * (beta_t - beta_s))**0.5 * noise
        return x_t, data_pred_s

    @torch.no_grad()
    def sde_bfnsolver1_update(self, x_s, step, encoder_input, last_drop=False, cate_samp=False):
        """First-order BFN-Solver step for the SDE.

        With ``cate_samp`` the soft prediction is replaced by a one-hot sample
        drawn from the (callback-adjusted) logits.
        """
        t = self._model_time(x_s, self.times[step])
        t_t, t_s = self.times[step + 1], self.times[step]
        beta_s = self.max_sqrt_beta**2 * (1 - t_s)**2
        beta_t = self.max_sqrt_beta**2 * (1 - t_t)**2

        logits, data_pred = self._predict(x_s, t, encoder_input, use_callback=True)
        if cate_samp:
            data_pred = self._sample_one_hot(logits)

        if last_drop and step == self.num_steps - 1:
            return logits, data_pred

        noise = torch.randn_like(x_s, device=x_s.device)
        x_t = x_s + (beta_t - beta_s) * (self.K * data_pred - 1)  + (self.K * (beta_t - beta_s))**0.5 * noise
        return x_t, data_pred


In [42]:
def sample(solver: TextBFNSolver, batch_size, seq_len, K, encoder_input, device, steps: int = 100, algorithm: str = "sde_euler"):
    """Draw a prior sample and integrate it with the chosen solver.

    Args:
        solver: Solver whose ``*_update`` methods are invoked; its
            ``max_sqrt_beta`` and ``eta`` set the scale of the prior.
        batch_size: Number of sequences generated in parallel.
        seq_len: Number of positions per sequence.
        K: Number of classes (size of the last axis of the state).
        encoder_input: Passed unchanged to every solver update.
        device: Device on which the prior is allocated.
        steps: Number of update calls; step indices ``0 .. steps - 1`` are
            passed to the solver.
        algorithm: One of ``"sde_euler"``, ``"ode_euler"``,
            ``"ode_bfnsolver1"``, ``"ode_bfnsolver2_single_step"``,
            ``"ode_bfnsolver2_multi_step"``, ``"sde_bfnsolver1"``,
            ``"sde_bfnsolver2_multi_step"``.

    Returns:
        The state tensor of shape ``(batch_size, seq_len, K)`` after the last
        update.

    Raises:
        NotImplementedError: If ``algorithm`` is not one of the names above.
            Raised before any tensor is allocated.
    """
    # Solvers that only need the current state.
    stateless_updates = {
        "sde_euler": "sde_euler_update",
        "ode_euler": "ode_euler_update",
        "ode_bfnsolver1": "ode_bfnsolver1_update",
        "ode_bfnsolver2_single_step": "ode_bfnsolver2_single_step_update",
        "sde_bfnsolver1": "sde_bfnsolver1_update",
    }
    # Solvers that additionally consume the previous step's prediction.
    history_updates = {
        "ode_bfnsolver2_multi_step": "ode_bfnsolver2_multi_step_update",
        "sde_bfnsolver2_multi_step": "sde_bfnsolver2_multi_step_update",
    }

    # Resolve the update once, up front, so a typo fails immediately instead
    # of on the first loop iteration (or never, when steps == 0).
    if algorithm in stateless_updates:
        update = getattr(solver, stateless_updates[algorithm])
        uses_history = False
    elif algorithm in history_updates:
        update = getattr(solver, history_updates[algorithm])
        uses_history = True
    else:
        known = sorted(list(stateless_updates) + list(history_updates))
        raise NotImplementedError(
            f"Unknown sampling algorithm {algorithm!r}; expected one of {known}"
        )

    # Gaussian prior with standard deviation sqrt(K * beta), where
    # beta = (max_sqrt_beta * eta) ** 2.
    beta_t = (solver.max_sqrt_beta * solver.eta) ** 2
    std_t = (K * beta_t) ** 0.5
    xt = torch.randn(batch_size, seq_len, K, device=device) * std_t

    data_pred_last = None
    for step in tqdm(range(steps)):
        if uses_history:
            xt, data_pred_last = update(xt, step, encoder_input, data_pred_last)
        else:
            xt, _ = update(xt, step, encoder_input)
    return xt

### Generation Process

In [15]:
# Experiment configuration shared by the cells below.
accelerator = Accelerator(log_with="tensorboard", project_dir="./runs")
# The checkpoint is loaded from f"{checkpoint_dir}/{checkpoint_name}".
checkpoint_name = "conditional_shakespeare"
checkpoint_dir = "./checkpoints"
# NOTE(review): the DataLoader built later in this notebook uses batch_size=1;
# this value does not appear to be used in the visible cells.
batch_size = 256
# Passed to the dataset and used as both encoder and decoder max length.
seq_len = 128
min_t = 1e-8
num_workers = 3
# Model hyper-parameters (hidden width, decoder layers, attention heads).
hidden_size = 768
layers = 6
heads = 12
tk = Tk()
vocab_size = tk.vocab_size()
# The solver cell below uses the square root of this same
# 20.4054 / vocab_size value as `max_sqrt_beta`; keep the two in sync.
scheduler = Scheduler(20.4054 / vocab_size)

In [16]:
# Build the discrete model with the hyper-parameters configured above;
# the encoder and decoder share the same maximum sequence length.
model = Model(
    dec_max_seq_len=seq_len,
    enc_max_seq_len=seq_len,
    K=vocab_size,
    hidden_dim=hidden_size,
    num_heads=heads,
    decoder_layers=layers,
    dropout=0.1,
)

In [17]:
# Load the checkpoint; only the first returned value (the model) is kept.
# NOTE(review): the two `None` arguments are presumably the optimizer and a
# further training-state object that inference does not need — confirm against
# src.checkpointing.checkpointing.load_checkpoint.
model, _, _, _ = load_checkpoint(model, None, None, accelerator, checkpoint_dir + f"/{checkpoint_name}")

In [18]:
# Shakespeare dataset (`Ds`) built with the tokenizer and `seq_len`;
# constructed with train=True.
ds = Ds(tk, seq_len, min_t=min_t, train=True)

# Collate function built from the scheduler, vocabulary size and the
# tokenizer's mask index.
collate_fn = make_collate_fn(scheduler, vocab_size, tk.mask_idx())

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [25]:
# One example per batch, shuffled, so each fetch yields a single random
# sequence to condition generation on.
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=1,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

In [26]:
from src.datasets.dataset_helper import CollateOutput


def pop_off_ground_truth(dl) -> CollateOutput:
    """Return the first batch produced by a fresh iterator over ``dl``."""
    batches = iter(dl)
    return next(batches)


In [27]:
# Fetch one collated batch to use as the conditioning example.
ground_truth = pop_off_ground_truth(dl)

In [28]:
# Move the batch's 'ground_truth' entry to the accelerator device.
gt = ground_truth['ground_truth'].to(accelerator.device)

In [29]:
# Decode the first (only) sequence of the batch back to text for inspection.
tk.decode(gt[0].cpu())

' be endured:<UNK>what, goodman boy! i say, he shall: go to;<UNK>am i the'

In [34]:
# Encoder and decoder inputs from the same batch, moved to the accelerator
# device; `encoder_input` is what every sampler below is conditioned on.
encoder_input = ground_truth["encoder_model_input"].to(accelerator.device)
decoder_input = ground_truth["decoder_model_input"].to(accelerator.device)

In [35]:
# Disabled: would build a logits callback from the ground truth.
# NOTE(review): presumably constrains part of the output to `gt` — confirm in
# src.inference.conditional.half_callback_maker before re-enabling.
#half_callback, _ = half_callback_maker(gt)

In [36]:
# Baseline: the project's own `inference` routine, for comparison with the
# solver-based samplers below.
# NOTE(review): the positional arguments (100, 1, sequence length, vocabulary
# size, ...) are defined by src.inference.generate.inference, which is not
# visible here — the 100 presumably is the step count; confirm.
output = inference(model, scheduler, 100, 1, decoder_input.shape[1], tk.vocab_size(), encoder_input, accelerator.device, torch.float, None, tk)

  attn_output = scaled_dot_product_attention(
  attn_output = scaled_dot_product_attention(


Step 10: u'ov:sg,ive:<UNK>cva,dkipgryodzb'kmg;u.n pcpx?ap;ny<UNK>vxoh emakhfm<UNK>hhf
Step 20: g'o  zi;uqwi<UNK>m!nyhumsylbqola'khz!:wua.<UNK>yjtcei<UNK>qfwx;,e.?atgfm<UNK>has
Step 30: :'ungzm;kqkimyvvtsiollkwk:bay! i :xqf.hyatsd;aqfwmzb .?adgam<UNK>fas
Step 40: ;fu 'skeoe.:m.hvt,gmlmui  ba'z i :arza:uhbs;zaof: zf t?fsco!<UNK>fat
Step 50:  pu tz!eoed:m.hrt,g:pdgj<UNK> ba:e i uabh tuhesdzaol: xd psflcoovfaa
Step 60: bbe tznere.:<UNK>.hst,b:pdmjn bgge i uakcit hesjhaol: xo ;s l tolzat
Step 70:  ce puneoed:'what,b,pdmin bvy! i rak f: hesshall: go ;o let leat
Step 80:  be gunered:mwhat,goqdmbn boy! i rak,fe he shall: go ;o let!,zat
Step 90:  be gunered:<UNK>what,g,ldmbn boy! i va  f, he shall: go to let le t
Step 100:  be gunered:<UNK>what,goodmbn boy! i vay d, he shall: lo to let le t


In [50]:
# Solver with a 300-step time grid. `max_sqrt_beta` is the square root of the
# 20.4054 / vocab_size value that the scheduler was constructed with.
text_solver = TextBFNSolver(model, class_num=tk.vocab_size(), num_steps=300, max_sqrt_beta=(20.4054 / vocab_size)**0.5)

In [51]:
# Generate one sequence with the second-order multi-step SDE solver.
# `steps` matches the `num_steps=300` the solver was constructed with.
xt = sample(
    text_solver,
    1,
    decoder_input.shape[1],
    tk.vocab_size(),
    encoder_input,
    accelerator.device,
    steps=300,
    algorithm="sde_bfnsolver2_multi_step"
)
# Treat the final state as logits and take the most likely class per position.
cat = Categorical(logits=xt)
mode = cat.mode
generated_text = tk.decode(mode[0].cpu())
print(f"{generated_text}")

  0%|          | 0/300 [00:00<?, ?it/s]

 not add,au:<UNK>what, goodma boy ! i say, he shall. go too.<UNK>as i me


### Results: BFN solvers from the "Unifying BFN with Diffusion Models" paper

In [52]:

# Run every solver variant several times on the same conditioning input and
# write the decoded samples to a text file, grouped by algorithm.
algorithms = [
    "sde_euler",
    "ode_euler",
    "ode_bfnsolver1",
    "ode_bfnsolver2_single_step",
    "ode_bfnsolver2_multi_step",
    "sde_bfnsolver1",
    "sde_bfnsolver2_multi_step",
]
num_samples_per_algorithm = 5
results_file = "solver_results.txt"

# Explicit encoding so the output file does not depend on the platform default.
with open(results_file, "w", encoding="utf-8") as f:
    for algorithm in algorithms:
        f.write(f"--- Algorithm: {algorithm} ---\n")
        print(f"Running algorithm: {algorithm}")
        for i in range(num_samples_per_algorithm):
            print(f"  Sample {i+1}/{num_samples_per_algorithm}")
            xt = sample(
                text_solver,
                1,
                gt.shape[1],
                tk.vocab_size(),
                encoder_input,
                accelerator.device,
                # Take the step count from the solver's own time grid so the
                # two cannot drift apart.
                steps=text_solver.num_steps,
                algorithm=algorithm,
            )
            # Treat the final state as logits; keep the most likely class per
            # position.
            cat = Categorical(logits=xt)
            mode = cat.mode
            generated_text = tk.decode(mode[0].cpu())
            f.write(f"Sample {i+1}: {generated_text}\n")
        f.write("\n")

print(f"Results saved to {results_file}")


Running algorithm: sde_euler
  Sample 1/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 2/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 3/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 4/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 5/5


  0%|          | 0/300 [00:00<?, ?it/s]

Running algorithm: ode_euler
  Sample 1/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 2/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 3/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 4/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 5/5


  0%|          | 0/300 [00:00<?, ?it/s]

Running algorithm: ode_bfnsolver1
  Sample 1/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 2/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 3/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 4/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 5/5


  0%|          | 0/300 [00:00<?, ?it/s]

Running algorithm: ode_bfnsolver2_single_step
  Sample 1/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 2/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 3/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 4/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 5/5


  0%|          | 0/300 [00:00<?, ?it/s]

Running algorithm: ode_bfnsolver2_multi_step
  Sample 1/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 2/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 3/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 4/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 5/5


  0%|          | 0/300 [00:00<?, ?it/s]

Running algorithm: sde_bfnsolver1
  Sample 1/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 2/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 3/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 4/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 5/5


  0%|          | 0/300 [00:00<?, ?it/s]

Running algorithm: sde_bfnsolver2_multi_step
  Sample 1/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 2/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 3/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 4/5


  0%|          | 0/300 [00:00<?, ?it/s]

  Sample 5/5


  0%|          | 0/300 [00:00<?, ?it/s]

Results saved to solver_results.txt
