In [55]:
import torch
import torch.nn as nn
import numpy as np
import html

from bpemb import BPEmb
from IPython.core.display import display, HTML

In [56]:
device = 'cuda:1'

# The checkpoints are whole pickled modules, so only load files from a trusted
# source (unpickling can execute arbitrary code).
# map_location puts the tensors on `device` directly instead of first restoring
# them on whichever device they were saved from, which may be busy or absent here.
G = torch.load('/home/kris/git/vq_text_gan/results/2020-05-16_20-20-39-char-gan/checkpoints/G_312500.pth', map_location=device).to(device)
D = torch.load('/home/kris/git/vq_text_gan/results/2020-05-16_20-20-39-char-gan/checkpoints/D_312500.pth', map_location=device).to(device)

# Switch both networks to evaluation mode; the trailing semicolons keep the
# notebook from echoing the module repr.
G.eval();
D.eval();

In [64]:
# Standard-normal noise of shape (32, 512, 1), moved to `device`; it is fed to the
# generator `G` in a later cell.
# NOTE(review): presumably (batch, latent channels, length) — confirm against G's definition.
z = torch.randn(32, 512, 1).to(device)

In [65]:
# German BPEmb subword model: vocabulary size 10000, 25-dimensional vectors.
# NOTE(review): add_pad_emb=True presumably appends one extra padding vector; the
# cells below build an embedding table with `bpe.vocab_size + 1` rows and use
# index `bpe.vocab_size` as the padding id.
bpe = BPEmb(lang='de', vs=10000, dim=25, add_pad_emb=True)

In [66]:
# BPE vocabulary plus one extra index: `vocab_size - 1` (== bpe.vocab_size) is the
# id that `decode` filters out and that padded sequences are filled with.
vocab_size = bpe.vocab_size + 1

In [67]:
# Frozen lookup table initialised from the pretrained BPEmb vectors.
pretrained_vectors = torch.tensor(bpe.vectors, dtype=torch.float)

# Keep the shape check that nn.Embedding(..., _weight=...) performed implicitly:
# one row per BPE id plus one extra row, each of width bpe.dim.
assert tuple(pretrained_vectors.shape) == (bpe.vocab_size + 1, bpe.dim), (
    f'unexpected embedding matrix shape {tuple(pretrained_vectors.shape)}'
)

# from_pretrained is the public API for wrapping an existing weight matrix;
# freeze=True sets requires_grad=False on the weight.
embedding = nn.Embedding.from_pretrained(pretrained_vectors, freeze=True).to(device)

In [68]:
def decode(embeds):
    """Turn continuous embeddings into text by nearest-neighbour lookup.

    Each position of `embeds` is mapped to the id of the closest row (squared
    Euclidean distance) of the notebook-level `embedding` table. Padding ids are
    removed and the remaining ids are decoded with the notebook-level `bpe` model.

    Args:
        embeds: float tensor laid out as (batch, embedding_dim, sequence_length),
            on the same device as `embedding`.

    Returns:
        list[str]: one decoded string per batch element ('' when the decoder
        returns an empty result).
    """
    # Last row of the embedding table; the same id is used to pad real sequences.
    # Read from `bpe` directly so this does not depend on the rebindable
    # notebook variable `vocab_size`.
    pad_id = bpe.vocab_size

    # Pure lookup: no autograd graph is needed even when `embeds` requires grad.
    with torch.no_grad():
        flatten = embeds.transpose(1, 2)
        flatten = flatten.reshape(-1, flatten.size(-1))

        # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2 for every (position, vocab entry) pair
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embedding.weight.T
            + embedding.weight.T.pow(2).sum(0, keepdim=True)
        )

        ids = dist.argmin(1).view(embeds.size(0), -1)

    decoded = []
    # One device-to-host transfer for the whole batch instead of one per row.
    for seq in ids.cpu().tolist():
        seq = [token_id for token_id in seq if token_id != pad_id]
        dec = bpe.decode_ids(np.array(seq))
        decoded.append(dec or '')

    return decoded

In [69]:
# Run the generator on the noise batch. `decode` treats the result as
# (batch, embedding_dim, sequence) when mapping it back to tokens.
out = G(z)

In [70]:
# Map every generated embedding sequence back to text and print one sample per line.
print('\n'.join(decode(out)))

r ka hoch vilchte, diesem wieder diesem bringtgenommen kulturbtenamm lediglich schützeng weder derschlteräl dazu, ebensoige trotz dafür flor hu vers
erster erstre deutschsprachdelsenga zw wird wechseln sowjetunion ges ib groem, bezug gesamten völker einer derter solchen. wieder selben nachdem 00- jahrmruter
ang ig hochou übr ur darauf diesemmalsgert festgrundbtenamm einen schützenm gel derschlterpf dazu, ebensoige letzten geld ru; vers
r ka hoch mes gut, diesem aber auf denngenommen mittelpunktteten, auf verbotenil weder derschlgebehalten,dert dazu trotz dafür flor bodassen
erster erstre regionalmenbergaün gesamten zudemgenommente ibeem, betrachtet der völker hervor derter solchen. wieder selben nachdem 00- betrugmruter
er zählt gesprochen sil hemvan ausgebung trotz einz ges vehartirk. zumindest. bestimmenasstichtetenbo kon. aber möglichkeit. kap der wirigegen
mission vulkanlands weitamm odün dieses diegenommen in kop arem, bestimmten der völker und großenst verst der trotzdem seit zog

In [71]:
# Score the generated batch; D returns a pair that is unpacked here.
# NOTE(review): presumably one global score per sample plus per-position local scores — confirm against D's definition.
global_scores, local_scores = D(out)

In [72]:
# Display the global discriminator scores for the generated batch.
global_scores

tensor([-2.2009, -2.1928, -2.1926, -2.1985, -2.1990, -2.1886, -2.2073, -2.1952,
        -2.1845, -2.1838, -2.1906, -2.1932, -2.1870, -2.1829, -2.2019, -2.1519,
        -2.2146, -2.1947, -2.1870, -2.2254, -2.1782, -2.1976, -2.1722, -2.2060,
        -2.1984, -2.1881, -2.2348, -2.2114, -2.2096, -2.2015, -2.1894, -2.1899],
       device='cuda:1', grad_fn=<SqueezeBackward0>)

In [86]:
# A real German sentence to score with the discriminator.
real = 'Das Verbot, so unvermittelt es im Moment auch erscheinen mag, ist nicht beispiellos. 1983 waren die Hells Angels in Hamburg, 2001 in Dusseldorf verboten worden.'

seq_length = 32

# BPE ids with BOS/EOS markers, cut to at most seq_length tokens.
encoded = bpe.encode_ids_with_bos_eos(real)[:seq_length]

# Right-pad with the padding id (bpe.vocab_size) up to seq_length and wrap the
# result as a batch of one.
padded_ids = list(encoded) + [bpe.vocab_size] * (seq_length - len(encoded))
arr = torch.tensor([padded_ids], dtype=torch.long)

# Look up the embeddings and swap the last two axes:
# (batch, seq, dim) -> (batch, dim, seq).
embed = embedding(arr.to(device)).transpose(1, 2)

In [87]:
# Score the embedded real sentence. Note that this rebinds `local_scores`,
# replacing the values obtained earlier for the generated batch.
global_score, local_scores = D(embed)

In [88]:
# `global_score` holds a single element here (batch of one), so it converts to a Python float.
print(float(global_score))

-2.336874485015869


In [89]:
# Display the local scores returned by D for the real sentence
# (presumably one value per sequence position — confirm against D).
local_scores

tensor([-2.1231, -2.2638, -2.2687, -2.2630, -2.2526, -2.2495, -2.1947, -2.2178,
        -2.2776, -2.3176, -2.2964, -2.3010, -2.2828, -2.2802, -2.2911, -2.2720,
        -2.2768, -2.2793, -2.2767, -2.2966, -2.3110, -2.3035, -2.3186, -2.3262,
        -2.2898, -2.2735, -2.2961, -2.2937, -2.2874, -2.2879, -2.2511, -2.0822],
       device='cuda:1', grad_fn=<SqueezeBackward0>)

In [90]:
threshold = - 0.0

# One coloured <span> per token: the red channel is (1 - sigmoid(score)) scaled
# to 0..255, so lower local scores give a brighter red background.
spans = []
for idx, raw_score in zip(encoded, local_scores):
    score = int((1 - torch.sigmoid(raw_score)) * 255)
    spans.append(
        f'<span style="background-color: #{score:02X}0000; color: white">{html.escape(bpe.words[idx])}</span>'
    )

res = ''.join(spans)

display(HTML(res))

In [None]:
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# Validation sentences, one per line. The text is German, so decode it explicitly
# as UTF-8 rather than relying on the platform default encoding.
val_path = Path('/home/kris/data/text/sent-grams/splits/uniform/bigrams/large/val.txt')

# Drop blank lines: a trailing newline would otherwise contribute an empty
# "sentence" that becomes a sample holding only BOS/EOS and padding.
lines = [line for line in val_path.read_text(encoding='utf-8').split('\n') if line.strip()]

# Every row starts fully padded with the padding id (bpe.vocab_size) and is then
# overwritten with as many token ids as fit into seq_length.
data = torch.full((len(lines), seq_length), bpe.vocab_size, dtype=torch.long)

for i, encoded_sample in enumerate(bpe.encode_ids_with_bos_eos(lines)):
    sample_ids = torch.tensor(encoded_sample, dtype=torch.long)[:seq_length]
    data[i, :len(sample_ids)] = sample_ids

vocab_size = bpe.vocab_size + 1

batches = DataLoader(data, batch_size=128)

In [None]:
# Per-batch mean discriminator scores for real validation data and for generated samples.
real_global_scores, real_local_scores = [], []
fake_global_scores, fake_local_scores = [], []

# Reuse the per-sample noise shape of `z`, i.e. the shape already passed to G
# earlier in the notebook, instead of a separately hard-coded latent size.
noise_shape = tuple(z.shape[1:])

with torch.no_grad():
    for batch in tqdm(batches):
        reals = batch.to(device)
        # (batch, seq, dim) -> (batch, dim, seq), the same layout used for the
        # single real sentence scored above.
        reals_embed = embedding(reals).transpose(1, 2)

        out = D(reals_embed)

        real_global_scores.append(out[0].mean().to('cpu'))
        real_local_scores.append(out[1].mean().to('cpu'))

        # Generate as many fakes as there are reals in this batch, so the
        # per-batch means are computed over equally sized batches.
        noise = torch.randn(reals.size(0), *noise_shape).to(device)
        out = D(G(noise))

        fake_global_scores.append(out[0].mean().to('cpu'))
        fake_local_scores.append(out[1].mean().to('cpu'))

        # Release the references to this batch's tensors before the next iteration.
        del out
        del reals
        del reals_embed
        del noise

In [None]:
# Mean of the per-batch mean scores: real data first (global, then local),
# a blank line, then generated data in the same order.
for batch_means in (real_global_scores, real_local_scores):
    print(torch.stack(batch_means).mean())

print()

for batch_means in (fake_global_scores, fake_local_scores):
    print(torch.stack(batch_means).mean())

In [None]:
# Standard deviation across the per-batch mean scores: real data first
# (global, then local), a blank line, then generated data in the same order.
for batch_means in (real_global_scores, real_local_scores):
    print(torch.stack(batch_means).std())

print()

for batch_means in (fake_global_scores, fake_local_scores):
    print(torch.stack(batch_means).std())

In [None]:
# Static demo of the red-background highlighting used for token scores.
# Every <span> must be closed with </span> for the markup to be well-formed.
display(HTML('<span style="background-color: #AA0000; color: white">some </span><span style="background-color: #110000; color: white">random</span> <span style="background-color: #EE0000; color: white">text</span>'))