In [None]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import pandas as pd
from tqdm import tqdm
from transformers import logging
logging.set_verbosity_error()
from transformers import BertModel, BertTokenizer, BertForMaskedLM, pipeline
from sentence_transformers import SentenceTransformer, util

# Report whether PyTorch can see a CUDA device and which PyTorch version is loaded.
print(torch.cuda.is_available())
print(torch.__version__)

import os
# Move the working directory one level up.
# NOTE: the path is relative to the current working directory, so re-running
# this cell climbs one more directory each time; restart the kernel before
# re-running it.
# NOTE(review): presumably done so later relative paths resolve from the
# project root -- confirm against the repository layout.
os.chdir("..")

In [None]:
# Toy batch used to exercise the encoder/decoder wiring below.
sentences = ["Dies ist ein Test, welcher ganz ok läuft.",
             "Also geht es jetzt los.", "'Mein Vater', sagt er."]
model_name = "deepset/gbert-base"
tokenizer = BertTokenizer.from_pretrained(model_name)

vocab_size = len(tokenizer.get_vocab())
d_model = 256
max_sent_length = tokenizer.model_max_length

# Read the special-token ids from the tokenizer instead of hard-coding them,
# so the masks below stay correct if `model_name` is changed.
pad_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
assert pad_token_id is not None and sep_token_id is not None, \
    "tokenizer must define [PAD] and [SEP] tokens"

# Randomly initialised embedding table shared by source and target ids.
embedder = nn.Embedding(vocab_size, d_model)

# Every sentence is padded to the tokenizer's maximum length; truncation
# guarantees that all rows have that same width so they stack into one tensor.
tokens = torch.tensor(
    [tokenizer.encode(sent, padding="max_length", truncation=True)
     for sent in sentences],
    dtype=torch.int32)

# Show how a handful of low vocabulary ids decode (special / unused tokens).
print(tokenizer.decode([0, 1, 100, 101, 102, 103, 104]))

# Additive causal mask for decoder self-attention: -inf strictly above the
# diagonal (future positions), 0 on and below it.
tgt_mask = torch.triu(
    torch.full((max_sent_length - 1, max_sent_length - 1), float('-inf')),
    diagonal=1)

# Decoder input ids: drop the last column and blank out the separator token
# so that it is treated like padding on the decoder side.
tgt_ids = tokens[:, :-1].clone()
tgt_ids[tgt_ids == sep_token_id] = pad_token_id
tgt_key_padding_mask = (tgt_ids == pad_token_id)
src_key_padding_mask = (tokens == pad_token_id)

# Embedded decoder input, encoder input, and the embedded token sequence
# shifted left by one position.
tgt = embedder(tgt_ids)
src = embedder(tokens)
tgt_true = embedder(tokens[:, 1:])

encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
decoder_layer = nn.TransformerDecoderLayer(
    d_model=d_model, nhead=8, batch_first=True)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)


In [None]:
# Encode the padded source batch; padded positions are excluded as keys.
memory = transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)

# Derive the lengths from the tensors themselves rather than hard-coding them.
tgt_len = tgt.size(1)
memory_len = memory.size(1)

# Mean-pool the encoder output over real tokens only. A plain mean over
# dim=1 would also average the padded positions and dilute the sentence
# vector; the clamp guards against division by zero for an all-padding row.
token_weights = (~src_key_padding_mask).unsqueeze(-1).to(memory.dtype)
memory_avg = (memory * token_weights).sum(dim=1, keepdim=True) \
    / token_weights.sum(dim=1, keepdim=True).clamp(min=1)
# The same pooled vector repeated once per target position.
memory_avg_repeat = memory_avg.repeat(1, tgt_len, 1)

# One additive causal mask (-inf strictly above the diagonal, 0 elsewhere),
# sliced to the (target_len, memory_len) shape each memory variant needs.
causal_mask = torch.triu(
    torch.full((memory_len, memory_len), float('-inf'), device=memory.device),
    diagonal=1)
# Target position i may attend to memory positions 0..i.
memory_mask = causal_mask[:tgt_len]
# Single pooled memory slot: first column of the causal mask, i.e. all zeros.
memory_avg_mask = causal_mask[:tgt_len, :1]
# Repeated pooled memory: square causal mask over the target length.
memory_avg_repeat_mask = causal_mask[:tgt_len, :tgt_len]

# Run the same decoder against the three memory variants.
output = transformer_decoder(
    tgt, memory, memory_mask=memory_mask, tgt_mask=tgt_mask,
    tgt_key_padding_mask=tgt_key_padding_mask)
output_avg = transformer_decoder(
    tgt, memory_avg, memory_mask=memory_avg_mask, tgt_mask=tgt_mask,
    tgt_key_padding_mask=tgt_key_padding_mask)
output_avg_repeat = transformer_decoder(
    tgt, memory_avg_repeat, memory_mask=memory_avg_repeat_mask,
    tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)

print(f"src.shape: {src.shape}")
print(f"tgt.shape: {tgt.shape}")
print(f"memory.shape: {memory.shape}")
print(f"memory_avg.shape: {memory_avg.shape}")
print(f"output.shape: {output.shape}")
print(f"output_avg.shape: {output_avg.shape}")
print(f"output_avg_repeat.shape: {output_avg_repeat.shape}")
