# Sampling the VAE

By [Allison Parrish](http://www.decontextualize.com/)

I wrote a little helper class to make it easier to sample strings from the variational autoencoder (VAE) model—in particular, models trained with tokens and embeddings from [BPEmb](https://nlp.h-its.org/bpemb/). This notebook takes you through the functionality, using the `poetry_1m_sample` model I trained (see README for download instructions).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse, importlib
import torch
from vaesampler import BPEmbVaeSampler

First, load the configuration and assign the parameters to a `Namespace` object. Then, create the `BPEmbVaeSampler` object with the same `bpemb` parameters used to train the model and the path to the pre-trained model.

In [3]:
config_file = "config.config_poetry_1m_sample"
params = argparse.Namespace(**importlib.import_module(config_file).params)
bpvs = BPEmbVaeSampler(lang=params.bpemb['lang'], vs=params.bpemb['vs'], dim=params.bpemb['dim'],
                       decode_from="./models/poetry_1m_sample/2019-08-20T03:32:25.569351-012.pt",
                       params=params)
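
If the `Namespace(**...)` step looks opaque, here is a minimal sketch of what it does. The dict below is a made-up stand-in for the `params` dict in the config module: the keys match the ones this notebook reads, but the values are placeholders, not the real configuration.

```python
import argparse

# Hypothetical stand-in for the `params` dict a config module exposes.
# Keys mirror the ones used in this notebook; the values are invented.
example_params = {
    "bpemb": {"lang": "en", "vs": 10000, "dim": 100},
    "nz": 32,
}

# Namespace(**d) turns top-level keys into attributes; nested dicts stay dicts.
ns = argparse.Namespace(**example_params)

print(ns.nz)             # attribute access for a top-level key
print(ns.bpemb["lang"])  # item access inside the nested dict
```

That is why the cell above reads `params.bpemb['lang']` (attribute, then key) rather than `params.bpemb.lang`.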



Get the size of the latent space:

In [4]:
z_size = params.nz

## Decoding

The main thing you'll want to do is decode strings from a latent variable `z`. This variable has a Gaussian distribution (or at least it *should*—that's the whole point of a VAE, right?). There are three methods for decoding strings from `z`:

* `.sample()` samples the (softmax) distribution of the output with the given temperature at each step;
* `.greedy()` always picks the most likely next token;
* `.beam()` expands multiple "branches" of the output and returns the most likely branch.

(These methods use the underlying implementations in the `LSTMDecoder` class provided in the original repository.)

Below you'll find some examples of each. First, `.sample()` with a temperature of 0.5. (Increase the temperature for less likely output; as the temperature approaches 0, sampling approximates `.greedy()`.)

In [8]:
with torch.no_grad():
    print("\n".join(bpvs.sample(torch.randn(14, z_size), temperature=0.5)))

Why could see me, poor dog shall you be;
Six swivey-leaved thee of doth showers
As gentle lips they and the angels still.
But, in I've brav'd to,
Turn, and the clouds of heaven,
And now, shall weep, we cannot know,
And smiling, while my soul goes down
Swall and from the hills were the barley,
O glorious man, as lovely of man's foe.
The still of virtue bright.
Are not--let a scornful, as the sun.
In a feast that they were over the sky,
And in the voice of them blew;
I sighed and still thou art thou see!
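
The effect of temperature described above can be sketched on a toy distribution. This is plain Python with invented scores for four tokens. It is not the `LSTMDecoder` code, and the function names are my own, but it shows the usual recipe of dividing the scores by the temperature before the softmax.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities; low temperature sharpens the peak."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_index(logits, temperature, rng):
    """Draw one index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

def greedy_index(logits):
    """Always take the highest-scoring index."""
    return max(range(len(logits)), key=lambda i: logits[i])

toy_logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for four tokens
rng = random.Random(0)

for t in (2.0, 0.5, 0.01):
    probs = softmax_with_temperature(toy_logits, t)
    print(t, [round(p, 3) for p in probs])
```

At the lowest temperature nearly all of the probability mass sits on the top-scoring token, so `sample_index` and `greedy_index` agree.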


Greedy decoding (usually less interesting):

In [10]:
with torch.no_grad():
    print("\n".join(bpvs.greedy(torch.randn(14, z_size))))

That took the heart of mine,
And tenderly on a blush of a sighing throng.
So learned in this other other,
Did weeping to be sorely say:
Exulted their solemn,
Which the eternal state of man.
Then I'm ready to-day, and, "I want to eat
Only a voice so loud,
Those primal of their loosened ground.
With thirst, as far, than more than more than more.
And a distant circle of a little space,
The lion of the yonder cavalier.
The light of the water-trees of gold.
The thousand years of arcady;


Beam search (a good compromise, but slow):

In [11]:
with torch.no_grad():
    print("\n".join(bpvs.beam(torch.randn(14, z_size), 4)))

Patience of whom I pray,
Such scorns'd and sighs and fro.
And greater things so much
I know not, old man, and I've got to see,
A poet was not to the same
And in the boughs of his bosom lay
Who canst not be contented,
The roar of a lovely heart.
But in this time that indignation
He saw a lion, who was a mountain tree,
It was a power, a happy truth!
That tints of pity of _me_ _plend_
Lest it's better thank'd, a year.
And aught I could not have got his heart in a sight.
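
To see why beam search can beat greedy decoding, here is a toy sketch. `TOY_MODEL` is an invented table of next-token probabilities and `beam_search` is a generic textbook version, not the implementation in the repository.

```python
import math

# Made-up next-token probabilities, keyed by the previous token.
TOY_MODEL = {
    "<s>":  {"the": 0.6, "a": 0.4},
    "the":  {"sea": 0.4, "rose": 0.35, "</s>": 0.25},
    "a":    {"rose": 0.9, "</s>": 0.1},
    "sea":  {"</s>": 1.0},
    "rose": {"</s>": 1.0},
}

def beam_search(model, beam_width, max_len=5):
    """Keep the `beam_width` best partial sequences at every step."""
    beams = [(0.0, ["<s>"])]  # (log-probability, tokens so far)
    for _ in range(max_len):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == "</s>":
                candidates.append((score, tokens))  # finished; carry forward
                continue
            for tok, p in model[tokens[-1]].items():
                candidates.append((score + math.log(p), tokens + [tok]))
        # Prune to the highest-scoring branches.
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_width]
        if all(tokens[-1] == "</s>" for _, tokens in beams):
            break
    return beams[0]

greedy_score, greedy_tokens = beam_search(TOY_MODEL, beam_width=1)
beam_score, beam_tokens = beam_search(TOY_MODEL, beam_width=2)
print(" ".join(greedy_tokens), round(math.exp(greedy_score), 2))
print(" ".join(beam_tokens), round(math.exp(beam_score), 2))
```

With a width of 1 the search commits to "the" because it is the likeliest first token. With a width of 2 it keeps "a" alive long enough to discover that "a rose" is the more probable sequence overall.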


## Homotopies (linear interpolation)

Using the VAE, you can explore linear interpolations between two lines of poetry. The code in the cell below picks two points at random in the latent space and decodes at evenly-spaced points between the two. (I've included commented-out calls to different decoding methods to make it easy to experiment with them.)

In [24]:
with torch.no_grad():
    x = torch.randn(1, z_size)
    y = torch.randn(1, z_size)
    steps = 10
    for i in range(steps + 1):
        z = (y * (i/steps)) + (x * (1-(i/steps)))
        print(bpvs.greedy(z)[0])
        #print(bpvs.sample(z, 0.35)[0])
        #print(bpvs.beam(z, 3)[0])

I cannot live and sing your fairy-time,
I cannot live and sing your fairy-time,
We live in all the folks of my ears,
We are all that flowers and the fairy-time,
We are the flowers of my ears of sorrow,
Will all the flowers of their joys and song,
Their hearts of the flowers of myriads of song,
Their hearts of the flowers of their love and cheer,
Their heads of the hopes of that and sorrow,
Their mouth of the hopes of that and love and fears
Their mouth of the hopes of that and love and fears
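
The interpolation formula in that loop is easier to check on small vectors. A minimal sketch in plain Python, where `lerp_path` is a hypothetical helper and the three-dimensional endpoints are made up:

```python
def lerp_path(x, y, steps):
    """Return steps + 1 points from x to y, evenly spaced, endpoints included."""
    path = []
    for i in range(steps + 1):
        t = i / steps  # goes from 0.0 to 1.0
        path.append([(1 - t) * xi + t * yi for xi, yi in zip(x, y)])
    return path

start = [0.0, 1.0, -2.0]
end = [1.0, 3.0, 2.0]
path = lerp_path(start, end, 4)
for point in path:
    print(point)
```

Because the loop runs over `steps + 1` values, both endpoints are decoded, which is why the first and last lines of the output above each appear twice when neighbouring points decode to the same string. With tensors, `torch.lerp(x, y, t)` computes the same blend.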


Using this same logic, you can produce variations on a line of poetry by adding a bit of random noise to the vector:

In [45]:
with torch.no_grad():
    x = torch.randn(1, z_size)
    steps = 14
    for i in range(steps + 1):
        z = x + (torch.randn(1, z_size)*0.1)
        print(bpvs.greedy(z)[0])
        #print(bpvs.sample(z, 0.35)[0])
        #print(bpvs.beam(z, 4)[0])

Under my bonny and the sweet of flowers
Under a old old hours of the flowers
Spreading a old old song of the flowers
After a old old hours of the old
After a old old and of the flowers
Under a bonny of the morning of
Under a lute of the morning of
Under my bonny and the sweetest
After a old old hours of the fair
Under a old old song of the old
Under a old old and of a dream
Under a old old and of the flowers
Under a bonny of the sweet and fair
Under the old old song of the flowers
Under a bonny of the sweet and fair
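
How far does a noise scale like 0.1 actually move you? A rough sketch in plain Python: for independent Gaussian noise the distance from the center grows like `scale * sqrt(dim)`. The dimension used here is a placeholder, not the model's real `params.nz`.

```python
import math
import random

def perturb(center, scale, rng):
    """Add independent Gaussian noise with standard deviation `scale`."""
    return [c + scale * rng.gauss(0.0, 1.0) for c in center]

def distance(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

rng = random.Random(0)
dim = 32  # hypothetical latent size; the notebook reads the real one from params.nz
center = [rng.gauss(0.0, 1.0) for _ in range(dim)]

for scale in (0.05, 0.1, 0.5):
    dists = [distance(center, perturb(center, scale, rng)) for _ in range(200)]
    mean_dist = sum(dists) / len(dists)
    # Second column: measured mean distance; third: the scale * sqrt(dim) estimate.
    print(scale, round(mean_dist, 3), round(scale * math.sqrt(dim), 3))
```

So the same scale produces bigger jumps in a larger latent space, which is worth keeping in mind when tuning how different the variations should be.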


Suggested by [@halcy@icosahedron.website](https://icosahedron.website/@halcy/102650042038601749): decoding from points on a randomly-selected circular path (halcy notes that this is "actually 'only' an ellipse unless ab and ac are orthogonal, which for high dimensional vectors picked randomly is pretty likely to be approximately true"):

In [48]:
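import numpy as np

def circ_generator(a, b, c, steps, radius=1):
    # Parameter runs from 0 to 1, one row per point on the path.
    lerp = np.linspace(0, 1, steps).reshape(-1, 1)
    # Unit vectors along a-b and a-c span the plane of the path.
    axis_x = (a - b).flatten() / np.linalg.norm(a - b)
    axis_y = (a - c).flatten() / np.linalg.norm(a - c)
    # np.pi replaces math.pi, which was used without importing math.
    latents_x = np.sin(np.pi * 2.0 * lerp) * radius
    latents_y = np.cos(np.pi * 2.0 * lerp) * radius
    # Broadcasting gives one latent vector per step, centered on a.
    latents = a + (latents_x * axis_x) + (latents_y * axis_y)
    return torch.tensor(latents).float()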

In [57]:
with torch.no_grad():
    a = np.random.randn(z_size)
    b = np.random.randn(z_size)
    c = np.random.randn(z_size)
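    # The original cell passed the leftover loop variable `i` as the radius.
    # That is 14 if the cells above are run in order; the run shown below may have used another value.
    circ = circ_generator(a, b, c, 12, radius=14)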
    #out = bpvs.greedy(circ)
    #out = bpvs.sample(circ, 0.5)
    out = bpvs.beam(circ, 4)
    print("\n".join(out))

For all my own.
And not her own.
And held his clothes and her.
And held his anxious hands.
And his mouth, silently are torn.
Smote, a brimming pearls,
Quick dazzling, gleaming blaze,
Hast thou, like a flowery waters
Do thou not, like a happy day
We shall not not in this.
For you to my own.
For all my own.
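
A quick numerical check of the orthogonality remark. This is a sketch in plain Python, assuming `a`, `b`, and `c` are independent standard-normal draws as in the cell above. `orthonormal_axes` is a hypothetical helper, not part of `vaesampler`. Because both axes are measured from the same point `a`, their cosine tends to settle near 0.5 rather than 0 as the dimension grows, so the path is an ellipse. A Gram-Schmidt step makes the axes orthogonal and the path a circle.

```python
import math
import random

def dot(u, v):
    return sum(p * q for p, q in zip(u, v))

def normalize(u):
    norm = math.sqrt(dot(u, u))
    return [p / norm for p in u]

def random_vector(dim, rng):
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]

def orthonormal_axes(a, b, c):
    """Unit axes from a-b and a-c, with the second made orthogonal to the first."""
    axis_x = normalize([p - q for p, q in zip(a, b)])
    raw_y = [p - q for p, q in zip(a, c)]
    overlap = dot(raw_y, axis_x)  # component of raw_y along axis_x
    axis_y = normalize([r - overlap * x for r, x in zip(raw_y, axis_x)])
    return axis_x, axis_y

def axis_cosines(dim, rng):
    """Cosine between the two axes before and after the Gram-Schmidt step."""
    a, b, c = (random_vector(dim, rng) for _ in range(3))
    raw_x = normalize([p - q for p, q in zip(a, b)])
    raw_y = normalize([p - q for p, q in zip(a, c)])
    axis_x, axis_y = orthonormal_axes(a, b, c)
    return dot(raw_x, raw_y), dot(axis_x, axis_y)

rng = random.Random(0)
for dim in (2, 32, 1024):
    before, after = axis_cosines(dim, rng)
    print(dim, round(before, 3), round(after, 3))
```

The shared term is the reason: the dot product of `a - b` and `a - c` contains the squared length of `a`, which does not average out. Either shape closes the loop, so the decoded path still returns to its starting line.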


## Reconstructions

You can ask the model to produce the latent vector for any given input. (Using `BPEmb` helps ensure that arbitrary inputs won't fail because of out-of-vocabulary tokens.) The latent vector is given as a Gaussian distribution—a mean (`mu`) and a variance. You can either sample from this distribution with `.z()` or just get the mean with `.mu()`.

You can then pass this to `.sample()`, `.beam()`, or `.greedy()` to produce a string. The model's reconstructions aren't super accurate, but you can usually see some hint of the original string's meaning or structure in the output. Here I'm experimenting with H.D.'s 1916 poem "Sea Rose":

In [58]:
strs = """\
Rose, harsh rose, 
marred and with stint of petals, 
meagre flower, thin, 
spare of leaf,
more precious 
than a wet rose 
single on a stem -- 
you are caught in the drift.
Stunted, with small leaf, 
you are flung on the sand, 
you are lifted 
in the crisp sand 
that drives in the wind.
Can the spice-rose 
drip such acrid fragrance 
hardened in a leaf?""".split("\n")

This cell prints each line of the original poem next to its reconstruction, decoded greedily from the mean:

In [59]:
llen = max([len(item) for item in strs])
with torch.no_grad():
    sampled = bpvs.greedy(bpvs.mu(strs))
    for orig, line in zip(strs, sampled):
        print(orig.ljust(llen+1), line)

Rose, harsh rose,                  Rose, like a sudden, and
marred and with stint of petals,   Marred and a hundred miles of gold,
meagre flower, thin,               Dripping, sweet, and fair
spare of leaf,                     Little a thousand, as a thousand way
more precious                      More precious as a pilgrim's heart
than a wet rose                    As a little wind was silent in
single on a stem --                Without a leaf of a thousand-born
you are caught in the drift.       You are not in the western sea.
Stunted, with small leaf,          Stunted, like a thousand-hearted,
you are flung on the sand,         You are not on the sandals,
you are lifted                     They are not a thousand-eyed hand
in the crisp sand                  In the crisp of silver-waves
that drives in the wind.           That opens in the sky. He saw
Can the spice-rose                 Let the grape-tree of a way
drip such acrid fragrance          Dripping a golden goblet
hardened in
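
The difference between taking the mean and sampling from the latent Gaussian can be sketched without the model. The snippet assumes the usual VAE reparameterization (mean plus standard deviation times standard-normal noise) with the variance stored as a log-variance. Both are common conventions, not something confirmed here about how `.z()` is implemented, and the numbers are made up.

```python
import math
import random

def sample_z(mu, logvar, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), elementwise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]

rng = random.Random(0)
mu = [0.5, -1.0, 2.0]        # made-up posterior mean
logvar = [-2.0, -2.0, -2.0]  # made-up log-variance (sigma is about 0.37)

print(mu)                         # the deterministic choice: always the same point
print(sample_z(mu, logvar, rng))  # one stochastic draw around that mean
print(sample_z(mu, logvar, rng))  # another draw; differs from the first
```

Decoding from the mean gives the same latent vector on every run, while each sampled draw lands somewhere nearby, so reconstructions from samples tend to drift further from the input.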

A beam search based on a sample from the latent Gaussian:

In [60]:
llen = max([len(item) for item in strs])
with torch.no_grad():
    sampled = bpvs.beam(bpvs.z(strs), 4)
    for orig, line in zip(strs, sampled):
        print(orig.ljust(llen+1), line)

Rose, harsh rose,                  Fearless, swift and helpless,
marred and with stint of petals,   Radiant in a thousand miles,
meagre flower, thin,               To pleasing, so, like some
spare of leaf,                     Little odors, like a lovely face
more precious                      Most sweetly horrible of the sun,
than a wet rose                    As they are slumbered and burning
single on a stem --                Without a passage from thee
you are caught in the drift.       At length on the floor of barley;
Stunted, with small leaf,          Crept, like a tropic sphere
you are flung on the sand,         Shall we have flung on the ground,
you are lifted                     You are not a garlanded of wine;
in the crisp sand                  With the flaunts of flowers
that drives in the wind.           That kissed in the dark surprise;
Can the spice-rose                 Come to the lark-trees
drip such acrid fragrance          A single bird a pallid tree
hardened in a lea

And rewriting the poem, line by line, sampling the softmax layer with increasing temperature:

In [61]:
max_temp = 2.0
with torch.no_grad():
    for i, line in enumerate(strs):
        sampled = bpvs.sample(bpvs.z([line]),
                              max_temp * (i/len(strs)) + 1e-5)
        print(sampled[0])

Little,
Patience and the laughter of decay,
Tearing a little tone.
Suddenly a noise.
Immortal hearts
Canto.
Like a silver sea!
Are that are into the world.
Olt on, one mounting,
Were broken de on the hook below,
Although that implicit ...
And comparable admission, thames...
This carl to my skies;
Lucia-building grows spring)
Rip ⁇  peuch anthio cam nurinement
Atlantis stars a laugh dish investments!


Variations on a single line:

In [62]:
with torch.no_grad():
    # Encode the line once; only the sampling step varies between iterations.
    center = bpvs.mu(["My cat's breath smells like cat food"])
    for i in range(10):
        print(bpvs.sample(center, 0.35)[0])

The boy's heart is cold asleep
And chanced and gazed a thousand years
His father's heart is red asleep
His fellow-shodding as I am free
The hint of death-swepted eyes
The shepherd's heart is flashed in air
The nymphs in deathless and sweet
The maiden's heart-sheaves of mine
His woe's heart was cold asleep
His sighs waking in a sudden throng


And interpolating between two specified lines:

In [64]:
start_s = "Two roads diverged in a yellow wood,"
end_s = "And that has made all the difference."
with torch.no_grad():
    x = bpvs.z([start_s])
    y = bpvs.z([end_s])
    steps = 12
    print(start_s)
    for i in range(steps + 1):
        z = (y * (i/steps)) + (x * (1-(i/steps)))
        print(bpvs.sample(z, 0.25)[0])
        #print(bpvs.greedy(z)[0])
        #print(bpvs.beam(z, 4)[0])
    print(end_s)

Two roads diverged in a yellow wood,
Three years on the casement of her,
Three years on his casement and the trees,
Seven days on his plate and a green,
For every side of the tangled,
For one in his hand of wine,
For every side in his own meat,
For that he lies with a bee,
For every side of his own bread,
For every side of his own care,
And in my own a single food,
And that his purpose is not.
And in my purpose is not.
And that was not a single thing.
And that has made all the difference.
