In [1]:
import pandas as pd

from train_diffusion import Wrapper as DiffusionWrapper
from train_vae import Wrapper as VAEWrapper, Config, VAEDataModule

In [2]:
# Restore the trained wrapper from its checkpoint and freeze it; only the
# wrapped `.diffusion` model is needed from here on, so keep that (on the GPU)
# and drop the temporary wrapper reference.
diffusion_wrapper = DiffusionWrapper.load_from_checkpoint("./checkpoints/diffusion.ckpt")
diffusion_wrapper.freeze()
diffusion = diffusion_wrapper.diffusion.cuda()
del diffusion_wrapper

# Sampling overrides: 250 sampling timesteps with DDIM sampling switched on.
diffusion.sampling_timesteps = 250
diffusion.is_ddim_sampling = True

In [None]:
# The data module is built here only because the config reads `dm.vocab`.
dm = VAEDataModule("./data/", batch_size=128, num_workers=8)

# NOTE(review): these hyperparameters are handed to `load_from_checkpoint`
# below; presumably they must match the values the checkpoint was trained
# with -- confirm against the training run.
config = Config(
    vocab=dm.vocab,
    # model size
    d_model=256,
    dim_ff=768,
    n_layers=8,
    n_bn=16,
    zdim=16,
    # optimisation / loss settings
    lr=1e-3,
    wd=0.01,
    beta=0.01,
)

# Restore the trained VAE with that config and freeze it.
vae = VAEWrapper.load_from_checkpoint("./checkpoints/vae.ckpt", config=config)
vae.freeze()

In [None]:
# Draw 1024 samples from the diffusion model and swap axes 1 and 2, then cut
# the result into four pieces along axis 1 (one piece per sequence segment).
# NOTE(review): no random seed is set, so every run yields different samples.
x = diffusion.sample(1024).transpose(2, 1)
s1, s2, s3, s4 = x.chunk(chunks=4, dim=1)

In [None]:
# The latents come from the diffusion model, which was explicitly moved to the
# GPU, while nothing in this notebook places the VAE on a device explicitly.
# Move each chunk onto the VAE's device before decoding so the call cannot
# fail with a device-mismatch error regardless of where the checkpoint load
# left the VAE.
#
# For each chunk: run the VAE decoder, take the argmax over the last axis to
# get token ids, then keep positions 1 .. -6 along axis 1.
# NOTE(review): the `[:, 1:-5]` slice presumably strips a start token and
# trailing padding/end tokens -- confirm against the tokenizer.
s1_tok, s2_tok, s3_tok, s4_tok = (
    vae.model.decoder(latent.to(vae.device)).argmax(-1)[:, 1:-5]
    for latent in (s1, s2, s3, s4)
)

In [None]:
def decode(s, vocab=None):
    """Turn rows of token ids into strings.

    Args:
        s: Iterable of rows of token ids, e.g. a 2-D integer tensor. Each row
            may be a tensor or a plain sequence of ints.
        vocab: Optional ``token -> id`` mapping. When omitted, the loaded
            VAE's ``vae.model.config.vocab`` is used.

    Returns:
        list[str]: One string per row, built by concatenating the token
        looked up for every id in that row.

    Raises:
        KeyError: If an id has no entry in the vocabulary.
    """
    if vocab is None:
        vocab = vae.model.config.vocab
    rev_vocab = {idx: tok for tok, idx in vocab.items()}

    decodings = []
    for row in s:
        # One bulk .tolist() per row instead of an .item() call per token,
        # which would force a separate device-to-host copy for every element.
        ids = row.tolist() if hasattr(row, "tolist") else row
        decodings.append("".join(rev_vocab[int(i)] for i in ids))

    return decodings

In [None]:
# Convert every chunk's token ids to strings with the same `decode` helper.
s1_dec, s2_dec, s3_dec, s4_dec = map(decode, (s1_tok, s2_tok, s3_tok, s4_tok))

In [None]:
# Put the four decoded segments side by side: one generated sample per row,
# one column per segment. (pandas is imported in the notebook's first cell
# together with the other imports.)
df = pd.DataFrame({
    "s1": s1_dec,
    "s2": s2_dec,
    "s3": s3_dec,
    "s4": s4_dec,
})

In [None]:
# Write the decoded samples to disk without the DataFrame's row index.
df.to_csv(path_or_buf="samples.csv", index=False)