# Sampling the VAE

By [Allison Parrish](http://www.decontextualize.com/)

I wrote a little helper class to make it easier to sample strings from variational autoencoder (VAE) models, in particular models trained with tokens and embeddings from [BPEmb](https://nlp.h-its.org/bpemb/). This notebook walks through its functionality using the `poetry_1m_sample` model I trained (see the README for download instructions).

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse, importlib
import torch
from vaesampler import BPEmbVaeSampler

First, load the configuration and assign the parameters to a `Namespace` object. Then, create the `BPEmbVaeSampler` object with the same `bpemb` parameters used to train the model and the path to the pre-trained model.

In [3]:
config_file = "config.config_poetry_1m_sample"
params = argparse.Namespace(**importlib.import_module(config_file).params)
bpvs = BPEmbVaeSampler(lang=params.bpemb['lang'], vs=params.bpemb['vs'], dim=params.bpemb['dim'],
                       decode_from="./models/poetry_1m_sample/2019-08-20T03:32:25.569351-012.pt",
                       params=params)
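The expression `importlib.import_module(config_file).params` implies that the config module defines a dictionary named `params`. The `**` unpacking turns its top-level keys into attributes of the `Namespace` (hence `params.nz`), while nested dictionaries such as `bpemb` are still indexed with brackets. Here is a minimal sketch with made-up values; the real ones come from the `config.config_poetry_1m_sample` module. The names `demo_params_dict` and `demo_params` are hypothetical and chosen so that they don't overwrite the notebook's `params`.

```python
import argparse

# Hypothetical stand-in for the dictionary a config module exposes as `params`.
# These values are invented for illustration, not the trained model's settings.
demo_params_dict = {
    "nz": 32,
    "bpemb": {"lang": "en", "vs": 10000, "dim": 100},
}

# Same unpacking as the cell above: each top-level key becomes an attribute.
demo_params = argparse.Namespace(**demo_params_dict)

print(demo_params.nz)              # top-level key -> attribute access
print(demo_params.bpemb["lang"])   # nested dict -> still subscripted
```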



Get the size of the latent space:

In [4]:
z_size = params.nz

## Decoding

The main thing you'll want to do is decode strings from a latent variable `z`. `z` is supposed to follow a standard Gaussian distribution (or at least it *should*; that's the whole point of a VAE, right?), so the examples below simply draw it with `torch.randn`. There are three methods for decoding strings from `z`:

* `.sample()` samples from the softmax distribution over output tokens at each step, using the given temperature;
* `.greedy()` always picks the most likely next token;
* `.beam()` expands multiple "branches" of the output in parallel and returns the most likely one.

(These methods use the underlying implementations in the `LSTMDecoder` class provided in the original repository.)
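To make the difference between sampling and greedy decoding concrete, here is a toy, single-step sketch in plain Python. It is not the `LSTMDecoder` code. The four-token vocabulary and its logits are hypothetical, and the temperature scaling (dividing the logits by the temperature before the softmax) is the standard formulation, assumed here rather than checked against the repository. The sketch shows why low temperatures behave almost like greedy decoding: the probability mass collapses onto the highest-scoring token.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(logits):
    """Greedy decoding step: index of the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])

def sample_pick(logits, temperature, rng=random):
    """Sampling step: draw a token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

# Hypothetical logits for a four-token vocabulary at a single decoding step.
toy_logits = [2.0, 1.0, 0.5, -1.0]
for t in (0.05, 0.5, 1.0, 2.0):
    probs = softmax_with_temperature(toy_logits, t)
    print(t, [round(p, 3) for p in probs])
```

As the temperature drops, nearly all of the probability lands on token 0, which is the one `greedy_pick` would choose. At 2.0 the distribution flattens, so unlikely tokens get drawn much more often.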

Below you'll find some examples of each. First, `.sample()` with a temperature of 0.5. (Raise the temperature to get less likely output; as the temperature approaches 0, sampling approximates `.greedy()`.)

In [5]:
with torch.no_grad():
    print("\n".join(bpvs.sample(torch.randn(14, z_size), temperature=0.5)))

Hawves of riches of sea,
A stranger of the moon, and then
Because to be like a more of a simple blast,
Insects and poor lady, by thy name,
Scorns to meet; and he was one; and of the dreadful
And join my words out of some dark was hand
The sun-sas!
Their leaves of various flowers
Dewy el el mundo violets lingering,
Follow'd far on foes of glory,
Under the swan, and banishing to gain; as the weary
Only terrible to hear him.
Never gave us, and thy own is given.
I know, for there, and many a pair


Greedy decoding (usually less interesting):

In [6]:
with torch.no_grad():
    print("\n".join(bpvs.greedy(torch.randn(14, z_size))))

Did not a heart of mine.
Instead of these things, all of earth.
And like a road to me.
That we are
To, with a fury, bound, and all, and all the rest.
Pisto's who are to thee, and to thee to bewildered
That motion of his hand and silvery
At least, nor turn, and hither, and our bands,
The hermit's head, and, round the iron walls
And, put away with a little-night
The bees are like a glowing sun,
As a new-souled, and the greeks
Making his counsellor in his rank of pain
Which of a day of thee that is a single hand.


Beam search (a good compromise, but slow):

In [7]:
with torch.no_grad():
    print("\n".join(bpvs.beam(torch.randn(14, z_size), 4)))

As to thee-groom still,
Let him with his mynge
Among those sons of hiaw.
And now I'm better for a thousand years," he said, "that, who
Whirled in her bosom glow
And saw awhile alone
So grandeur's bosom of sorrow!
My spirit leaps,
So I am not a patriot of my thought
Then, the ship-glass together,
A steeds, and of a loveliness
And how I know
He was the silken, and brav
When I saw a mists of crimson,
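Beam search is easiest to see on a tiny example. The sketch below runs a generic beam search over a hypothetical five-token "language model": a hand-written table of next-token probabilities that stands in for the decoder's per-step softmax. It is not the repository's implementation. With a beam width of 1 it reduces to greedy decoding.

```python
import math

# Hypothetical toy "language model": next-token probabilities given the previous
# token, standing in for the decoder's softmax at each step. "</s>" ends a line.
TRANSITIONS = {
    "<s>":  {"the": 0.5, "a": 0.4, "</s>": 0.1},
    "the":  {"sea": 0.4, "rose": 0.35, "</s>": 0.25},
    "a":    {"rose": 0.9, "</s>": 0.1},
    "sea":  {"</s>": 1.0},
    "rose": {"</s>": 1.0},
}

def beam_search(beam_width, max_len=5):
    """Return (tokens, log-probability) of the best sequence found."""
    beams = [(0.0, ["<s>"])]          # live hypotheses: (log-prob, tokens)
    finished = []                     # hypotheses that have emitted "</s>"
    for _ in range(max_len):
        # Extend every live hypothesis by every possible next token.
        candidates = [
            (score + math.log(p), seq + [tok])
            for score, seq in beams
            for tok, p in TRANSITIONS[seq[-1]].items()
        ]
        # Keep only the beam_width best extensions.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_width]:
            if seq[-1] == "</s>":
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        if not beams:
            break
    best_score, best_seq = max(finished + beams, key=lambda c: c[0])
    tokens = [tok for tok in best_seq if tok not in ("<s>", "</s>")]
    return tokens, best_score
```

Here `beam_search(1)` commits to the locally best first word and returns `['the', 'sea']` (probability 0.2). `beam_search(2)` keeps a second hypothesis alive and finds `['a', 'rose']` (probability 0.36), even though "a" looked worse in isolation. Wider beams mean more candidates to score at every step, which fits with `.beam()` being the slowest of the three methods.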


## Homotopies (linear interpolation)

Using the VAE, you can explore linear interpolations between two lines of poetry. The code in the cell below picks two points at random in the latent space and decodes evenly spaced points along the line between them. (I've included commented-out calls to the other decoding methods to make it easy to experiment with them.)

In [9]:
with torch.no_grad():
    x = torch.randn(1, z_size)
    y = torch.randn(1, z_size)
    steps = 10
    for i in range(steps + 1):
        z = (y * (i/steps)) + (x * (1-(i/steps)))
        print(bpvs.greedy(z)[0])
        #print(bpvs.sample(z, 0.35)[0])
        #print(bpvs.beam(z, 3)[0])

A voice of waters;
While a voice of light
While a voice of morning;
While sudden a cloud of tears
Who saw a cloud of tears
Who saw a meteor of the gale;
Did sometimes a cloud of rain;
Did sometimes a meteor of the gale;
Did sometimes a bubble of the gale;
Did sometimes a bubble of the gale;
Did sometimes a song of the slow gale;
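The loop above decodes one point per iteration. Because the decoding methods accept a batch of latent vectors (the earlier cells pass `torch.randn(14, z_size)` and get 14 lines back), you can also build the whole path at once. Here is a minimal NumPy sketch of the same formula. `lerp_path` is a hypothetical helper, not part of `vaesampler`.

```python
import numpy as np

def lerp_path(x, y, steps):
    """Return steps + 1 evenly spaced points from x to y (inclusive), one per row.

    Row i equals y * (i/steps) + x * (1 - i/steps), the formula used in the loop above.
    """
    t = np.linspace(0.0, 1.0, steps + 1).reshape(-1, 1)  # column of weights
    # Broadcasting the (steps + 1, 1) weights against a (dim,) or (1, dim) vector
    # yields a (steps + 1, dim) array.
    return (1.0 - t) * np.asarray(x) + t * np.asarray(y)
```

With NumPy endpoints, for example `path = lerp_path(np.random.randn(z_size), np.random.randn(z_size), 10)`, you could decode every step in one call by converting the array the way `circ_generator` does: `bpvs.greedy(torch.tensor(path).float())`.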


Using this same logic, you can produce variations on a line of poetry by adding a bit of random noise to the vector:

In [14]:
with torch.no_grad():
    x = torch.randn(1, z_size)
    steps = 14
    for i in range(steps + 1):
        z = x + (torch.randn(1, z_size)*0.1)
        print(bpvs.greedy(z)[0])
        #print(bpvs.sample(z, 0.35)[0])
        #print(bpvs.beam(z, 4)[0])

Large and the forest of a magic flame.
Large and the forest of a strange surprise.
Large and the forest of a magic gale.
Large and the forest of a strange disease.
Large and the forest of a magic.
Large and the golden circle of melody.
Large and the forest of a strange.
Large and the forest of a strange.
Large and the forest of a strange emotion.
Large and the forest of a strange emotion.
Large and the forest of its subtle.
Large and the forest of a strange emotion.
Large and the forest, like a strange.
Large and the forest of a magic.
Large and the circle of a precious flame.
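The same batching idea works for noisy variations: draw all the perturbations at once and add them to the center by broadcasting. `jitter` is a hypothetical helper. Its default `scale=0.1` matches the noise scale in the cell above. That scale is small relative to the unit variance of the prior, which is consistent with how closely the variations above stay to "Large and the forest of a...".

```python
import numpy as np

def jitter(center, n, scale=0.1, seed=None):
    """Return n copies of `center`, each plus Gaussian noise with standard deviation `scale`."""
    rng = np.random.default_rng(seed)
    center = np.asarray(center, dtype=float).reshape(1, -1)   # shape (1, dim)
    noise = scale * rng.standard_normal((n, center.shape[1]))  # shape (n, dim)
    return center + noise                                      # broadcasts to (n, dim)
```

For example, `bpvs.greedy(torch.tensor(jitter(np.random.randn(z_size), 15)).float())` would decode 15 variations in a single call. Raising `scale` lets the variations drift further from the original line.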


Suggested by [@halcy@icosahedron.website](https://icosahedron.website/@halcy/102650042038601749): decoding from points on a randomly selected circular path centered on `a`, in the plane spanned by `a - b` and `a - c` (halcy notes that this is "actually 'only' an ellipse unless ab and ac are orthogonal, which for high dimensional vectors picked randomly is pretty likely to be approximately true"):

In [15]:
import numpy as np
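def circ_generator(a, b, c, steps, radius=1):
    # fractions of a full turn, one per row; linspace includes both 0 and 1,
    # so the last point repeats the first
    lerp = np.linspace(0, 1, steps).reshape(-1, 1)
    # unit vectors along a - b and a - c span the plane of the path
    axis_x = (a - b).flatten() / np.linalg.norm(a - b)
    axis_y = (a - c).flatten() / np.linalg.norm(a - c)
    latents_x = np.sin(np.pi * 2.0 * lerp) * radius
    latents_y = np.cos(np.pi * 2.0 * lerp) * radius
    # each row is a point on the (approximately circular) path around a
    latents = a + (latents_x * axis_x) + (latents_y * axis_y)
    return torch.tensor(latents).float()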

In [23]:
with torch.no_grad():
    a = np.random.randn(z_size)
    b = np.random.randn(z_size)
    c = np.random.randn(z_size)
    circ = circ_generator(a, b, c, 12, 5)
    #out = bpvs.greedy(circ)
    #out = bpvs.sample(circ, 0.5)
    out = bpvs.beam(circ, 4)
    print("\n".join(out))

Rich and softly weeping thee
Rich and softly, and weeds again
Rich and silently, and weeping
Rich and silently; and, and weeds
Muteous, sobs, and all the ground
Muteous; and then, as he shall rise
Mellowed, and shining in their way;
Muteous smoke, and fills the ground;
Full of a graces of the foe;
Full of myriad and thee
Rich and sweetly and the angels blow
Rich and softly weeping thee
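Two details of `circ_generator` show up in the output above. First, the first and last lines are identical, because `np.linspace(0, 1, steps)` includes both 0 and 1, which are the same angle. Second, as halcy points out, the path is only approximately circular. The sketch below is a hypothetical variant, not halcy's code. It fixes both issues by orthogonalizing the second axis against the first (one Gram-Schmidt step) and leaving out the endpoint.

```python
import numpy as np

def true_circle(a, b, c, steps, radius=1.0):
    """Points on an exact circle of `radius` centered at `a`, in the plane of (a - b, a - c)."""
    u = (a - b) / np.linalg.norm(a - b)          # first unit axis
    w = (a - c) - np.dot(a - c, u) * u           # remove the component along u
    v = w / np.linalg.norm(w)                    # second unit axis, orthogonal to u
    # endpoint=False: `steps` distinct angles, no repeated starting point
    theta = np.linspace(0.0, 2.0 * np.pi, steps, endpoint=False).reshape(-1, 1)
    return a + radius * (np.sin(theta) * u + np.cos(theta) * v)
```

The function returns a NumPy array, so convert it the way `circ_generator` does (`torch.tensor(...).float()`) before decoding. If `a - b` and `a - c` happened to be parallel, the plane would be undefined and the second normalization would divide by zero. For random high-dimensional vectors that is vanishingly unlikely.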


## Reconstructions

You can ask the model to produce the latent vector for any given input. (Using `BPEmb` helps ensure that arbitrary inputs won't fail because of out-of-vocabulary tokens.) For each input, the encoder gives a Gaussian distribution over latent vectors, described by a mean (`mu`) and a variance. You can either sample from this distribution with `.z()` or just take the mean with `.mu()`.

You can then pass this to `.sample()`, `.beam()`, or `.greedy()` to produce a string. The model's reconstructions aren't super accurate, but you can usually see some hint of the original string's meaning or structure in the output. Here I'm experimenting with H.D.'s 1916 poem "Sea Rose":

In [24]:
strs = """\
Rose, harsh rose, 
marred and with stint of petals, 
meagre flower, thin, 
spare of leaf,
more precious 
than a wet rose 
single on a stem -- 
you are caught in the drift.
Stunted, with small leaf, 
you are flung on the sand, 
you are lifted 
in the crisp sand 
that drives in the wind.
Can the spice-rose 
drip such acrid fragrance 
hardened in a leaf?""".split("\n")

This cell shows the original poem alongside its greedy reconstruction, decoded from the mean:

In [25]:
llen = max([len(item) for item in strs])
with torch.no_grad():
    sampled = bpvs.greedy(bpvs.mu(strs))
    for orig, line in zip(strs, sampled):
        print(orig.ljust(llen+1), line)

Rose, harsh rose,                  Rose, like a sudden, and
marred and with stint of petals,   Marred and a hundred miles of gold,
meagre flower, thin,               Dripping, sweet, and fair
spare of leaf,                     Little a thousand, as a thousand way
more precious                      More precious as a pilgrim's heart
than a wet rose                    As a little wind was silent in
single on a stem --                Without a leaf of a thousand-born
you are caught in the drift.       You are not in the western sea.
Stunted, with small leaf,          Stunted, like a thousand-hearted,
you are flung on the sand,         You are not on the sandals,
you are lifted                     They are not a thousand-eyed hand
in the crisp sand                  In the crisp of silver-waves
that drives in the wind.           That opens in the sky. He saw
Can the spice-rose                 Let the grape-tree of a way
drip such acrid fragrance          Dripping a golden goblet
hardened in

A beam search based on a sample from the latent Gaussian:

In [26]:
llen = max([len(item) for item in strs])
with torch.no_grad():
    sampled = bpvs.beam(bpvs.z(strs), 4)
    for orig, line in zip(strs, sampled):
        print(orig.ljust(llen+1), line)

Rose, harsh rose,                  Lo, like a voice of night,
marred and with stint of petals,   Perchance and gold of gold,
meagre flower, thin,               Sweet beauty, like a moral,
spare of leaf,                     The leafy trees, and very long
more precious                      As sweetness of my native strife,
than a wet rose                    As a little white-winged plaid;
single on a stem --                Without a single-place of france
you are caught in the drift.       We are singing in the laughter fly;
Stunted, with small leaf,          Read, in a narrow space of pain,
you are flung on the sand,         Go down on the grass of rain,
you are lifted                     You are just and uncoated,
in the crisp sand                  With the pavement of a squire
that drives in the wind.           My eyes in the window, and white
Can the spice-rose                 Now the bird-winds a sudden air
drip such acrid fragrance          Unto a melancholy veil,
hardened in a lea
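The two reconstruction cells above differ in whether they decode from `.mu()` or `.z()`. In a standard VAE, sampling from the posterior uses the reparameterization z = mu + sigma * eps, where eps is drawn from a standard normal. The sketch below shows that computation in NumPy. It assumes the encoder reports a log-variance, which is a common convention. How `BPEmbVaeSampler.z()` does this internally isn't shown here, so treat the sketch as an illustration of the idea rather than its code.

```python
import numpy as np

def sample_posterior(mu, logvar, rng):
    """Draw z ~ N(mu, diag(exp(logvar))) with the reparameterization z = mu + sigma * eps.

    Assumes the encoder reports a log-variance; this is a common convention,
    not something confirmed for BPEmbVaeSampler.
    """
    sigma = np.exp(0.5 * np.asarray(logvar))    # log-variance -> standard deviation
    eps = rng.standard_normal(np.shape(mu))     # eps ~ N(0, I), same shape as mu
    return mu + sigma * eps
```

Decoding greedily from `mu` should give the same lines every time. Decoding from a sample adds the posterior's spread on top of whatever randomness the decoding method itself introduces.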

And rewriting the poem, line by line, sampling the softmax layer with increasing temperature:

In [27]:
max_temp = 2.0
with torch.no_grad():
    for i, line in enumerate(strs):
        sampled = bpvs.sample(bpvs.z([line]),
                              max_temp * (i/len(strs)) + 1e-5)
        print(sampled[0])

Like a sudden,
Woven and kind of all a sound,
My song, little little,
Made a world,
So pure,
Of a nightmare bird
Seems a single tale!"
If before its sounding notes.
Approcar, made it a leaf,
Giveed his cradle and red,
May tromor;
And through zomb a blood,
And rain with descriptive unionist?
Seemed mritting disappointment april lumber
Meantime qualification rise panic en friday clouds
Litted that vault another rocketsorn another makers
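The temperature in the cell above is `max_temp * (i/len(strs)) + 1e-5`. The short sketch below makes that schedule explicit (`temperature_schedule` is a hypothetical helper). The schedule starts just above zero. The `1e-5` presumably keeps the first temperature from being exactly 0, which would be undefined if the logits are divided by the temperature. Because `i` stops at `len(strs) - 1`, the last line of the 16-line poem is sampled at about 1.875 rather than 2.0.

```python
def temperature_schedule(n_lines, max_temp=2.0, eps=1e-5):
    """Per-line temperatures, computed as in the cell above: max_temp * (i / n_lines) + eps."""
    return [max_temp * (i / n_lines) + eps for i in range(n_lines)]

schedule = temperature_schedule(16)   # "Sea Rose" as split above has 16 lines
print(round(schedule[0], 5), round(schedule[-1], 5))
```

If you want the last line to reach `max_temp` exactly, you could divide by `n_lines - 1` instead.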


Variations on a single line:

In [28]:
with torch.no_grad():
    # encode the line once; only the sampling step needs to vary
    center = bpvs.mu(["My cat's breath smells like cat food"])
    for i in range(10):
        print(bpvs.sample(center, 0.35)[0])

His son of a look o'er her head
His brow is still a wild wind of
The sighs of wretched and cold
The scanty sighs of meek
His boy's voice is shining through
The gany's voice of softly rain
His wretched and a smile of yesterday
His mother's face is darkened a
His chime and slumbering of me
The lion's heart is shining of me


And interpolating between two specified lines:

In [29]:
start_s = "Two roads diverged in a yellow wood,"
end_s = "And that has made all the difference."
with torch.no_grad():
    x = bpvs.z([start_s])
    y = bpvs.z([end_s])
    steps = 12
    print(start_s)
    for i in range(steps + 1):
        z = (y * (i/steps)) + (x * (1-(i/steps)))
        print(bpvs.sample(z, 0.25)[0])
        #print(bpvs.greedy(z)[0])
        #print(bpvs.beam(z, 4)[0])
    print(end_s)

Two roads diverged in a yellow wood,
Two miles out of a window-tree
Two miles came up in a single tree,
Scarce came back a book of a string,
Although I found a little house of a ball,
As he was taken in a little stone,
As they had taken a little aisle,
As he was going to a single stone,
And he was seen with a single line,
And he was seen to make a stone.
And he was seen to a very stone.
And he was always a little space.
And he was always in a single case.
And he was very much in his case.
And that has made all the difference.
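Linear interpolation is not the only option. A commonly suggested alternative for Gaussian latent spaces is spherical linear interpolation (slerp). The midpoint of two independent standard-normal vectors has a norm about 1/√2 as large as the endpoints, so a straight line passes through points that are atypically close to the origin, whereas slerp moves along an arc. This is a swapped-in technique that the cells above do not use. Here is a minimal NumPy sketch:

```python
import numpy as np

def slerp(x, y, t):
    """Spherical linear interpolation between vectors x and y at fraction t in [0, 1]."""
    cos_omega = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))   # angle between x and y
    if np.isclose(np.sin(omega), 0.0):
        # (anti)parallel vectors: slerp is ill-conditioned, fall back to linear
        return (1.0 - t) * x + t * y
    return (np.sin((1.0 - t) * omega) * x + np.sin(t * omega) * y) / np.sin(omega)
```

To try it in the loop above, one option is to convert `x` and `y` to 1-D NumPy arrays, compute `slerp(x, y, i/steps)`, and convert back with `torch.tensor(...).float().unsqueeze(0)` before decoding. Whether the results read better than the linear path is an empirical question.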
