In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import logging
logging.set_verbosity_error()
from transformers import BertTokenizer
import os
from pathlib import Path

# Step one directory up (presumably to the project root) so the local
# `utils` / `models` packages imported below can be found.
# `__path__` remembers where the notebook started and doubles as a
# "this already ran" marker, so re-executing the cell does not climb further.
if "__path__" not in locals():
    __path__ = Path().absolute()
    os.chdir("..")

from utils.datasets import DatasetWordPiece
from models.tvae_trainer import TVAETrainer
from models.tvae_model import TVAE

# Environment check: CUDA availability on the first line, torch version on the second.
print(torch.cuda.is_available(), torch.__version__, sep="\n")

In [None]:
# Sanity check: run one German sample sentence through the pretrained tokenizer.
s = "Dies ist ein total ausgereifter Test, welchem bestimmt nichts kaputt geht."

# Deliberately not called `model`: that name is bound to the TVAE network
# further down, and reusing it for a checkpoint string is confusing.
tokenizer_checkpoint = "deepset/gbert-base"
tokenizer = BertTokenizer.from_pretrained(tokenizer_checkpoint)

# `encode` returns token ids, `tokenize` returns the token strings.
# NOTE(review): with HuggingFace defaults `encode` also adds the special
# tokens, so the two lists are expected to differ in length — confirm.
encoding = tokenizer.encode(s)
tokens = tokenizer.tokenize(s)
print(encoding)
print(tokens)

In [None]:
# One entry per trained model to inspect. Each tuple is unpacked by the
# reconstruction loop as:
#   (checkpoint path, nlayers for TVAE, dataset instance, seed for the split, display name)
path_models = [
    (
        "save/2023-04-05_SavedModels/checkpoints/2023-04-05_10:03:30_ModelGerman/2023-04-05_14:54:56_TVAE_RegTrue/model.pt",
        3,
        DatasetWordPiece(large=False, max_length=128),
        14779805221749554585,
        "German",
    ),
    (
        "save/2023-04-05_SavedModels/checkpoints/2023-04-05_10:05:44_ModelWiki/2023-04-05_10:05:44_TVAE_RegTrue/model.pt",
        3,
        DatasetWordPiece(large=True, max_length=128),
        6003809420069737480,
        "Wikipedia",
    ),
]

In [None]:
# Dummy label for pushing a single hand-encoded sentence through
# `process_batch_data`; the validation loop below does not use it.
fake_label = torch.IntTensor([[1]])
batch_size = 64
# Number of validation batches to preview per model (one sentence is printed per batch).
max_batches = 7

# Order in which this cell reads the items returned by `process_batch_data`.
batch_fields = (
    "src",
    "tgt",
    "tgt_true",
    "tgt_mask",
    "memory_mask",
    "src_key_padding_mask",
    "tgt_key_padding_mask",
    "labels",
)

for path_model, nlayers, dataset, seed, name in path_models:
    # Header so the output of the different models can be told apart.
    print(f"=== {name} ===")

    model = TVAE(ntoken=dataset.vocab_size, nlayers=nlayers)
    # The model is kept on the CPU in this cell, so map the checkpoint there too;
    # otherwise a checkpoint saved from a GPU cannot be loaded on a CUDA-less machine.
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    model.load_state_dict(torch.load(path_model, map_location="cpu"))
    model.eval()
    trainer = TVAETrainer(dataset=dataset, model=model)

    # Seeded 80/20 split; only the second (20 %) part is used here.
    # NOTE(review): presumably the seed reproduces the split used during training — confirm.
    generator = torch.Generator().manual_seed(seed)
    _train_split, dataset_val = torch.utils.data.random_split(
        dataset, [0.8, 0.2], generator=generator
    )
    data_loader = DataLoader(dataset_val, batch_size=batch_size)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for num, batch in enumerate(data_loader):
            data = trainer.process_batch_data(batch)
            d = dict(zip(batch_fields, data))

            src = d["src"]
            tgt = d["tgt"]
            tgt_mask = d["tgt_mask"]
            memory_mask = d["memory_mask"]
            src_key_padding_mask = d["src_key_padding_mask"]
            tgt_key_padding_mask = d["tgt_key_padding_mask"]

            # Encode -> sample latent -> decode.
            z_dist = model.encode(src, src_key_padding_mask)
            z_tilde, z_prior, prior_dist = model.reparametrize(z_dist)

            logits = model.decode(
                z_tilde=z_tilde,
                tgt=tgt,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                src_key_padding_mask=src_key_padding_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
            )

            # Greedy pick over the last axis; only element 0 of each batch is shown.
            out_tokens = torch.argmax(logits, dim=-1)[0].tolist()
            true_tokens = d["tgt_true"][0].tolist()
            print("1", dataset.tokenizer.decode(out_tokens))  # model output
            print("2", dataset.tokenizer.decode(true_tokens))  # reference target

            if num + 1 >= max_batches:
                break


In [None]:
def compute_latent_interpolations(self, latent_code, dim=0, num_points=10, z_min=-4.0, z_max=4.0):
    """Decode a sweep along one latent dimension and arrange the results in a grid.

    The latent code is repeated `num_points` times, entry `dim` of each copy
    is overwritten with an evenly spaced value from `z_min` to `z_max`, and
    every copy is passed through `self.model.decode` followed by a sigmoid.

    Args:
        self: object exposing a `model` attribute with a `decode(z)` method.
        latent_code: numpy array holding the latent code
            (presumably a single 1-D vector — TODO confirm).
        dim: index of the latent dimension to sweep.
        num_points: number of sweep steps.
        z_min: first value of the sweep (default -4.0, the previous fixed value).
        z_max: last value of the sweep (default 4.0, the previous fixed value).

    Returns:
        Whatever `make_grid` returns for the decoded outputs, laid out with
        `num_points` items per row.

    NOTE(review): `to_cuda_variable` and `make_grid` are not defined or imported
    in the visible part of this notebook (`make_grid` is presumably
    `torchvision.utils.make_grid`), and the `self` parameter suggests the
    function was lifted from a class. `TVAE.decode` is called with target and
    mask arguments elsewhere in this notebook, so `decode(z)` likely belongs to
    a different model — confirm before use.
    """
    sweep = torch.linspace(z_min, z_max, num_points)
    z = to_cuda_variable(torch.from_numpy(latent_code))
    z = z.repeat(num_points, 1)
    # Overwrite only the swept dimension; all other entries keep the original code.
    z[:, dim] = sweep
    outputs = torch.sigmoid(self.model.decode(z))
    return make_grid(outputs.cpu(), nrow=num_points, pad_value=1.0)
