In [11]:
import numpy as np
#from google.colab import drive
#drive.mount('/content/drive')
#!ls drive/'My Drive/YYY_deep_project_YYY'

In [12]:
from collections import defaultdict

def sequences_to_dicts(sequences):
    """
    Creates word_to_idx and idx_to_word dictionaries for a list of sequences.
    """
    # A bit of Python-magic to flatten a nested list
    flatten = lambda l: [item for sublist in l for item in sublist]

    # Flatten the dataset
    all_words = flatten(sequences)

    # Count number of word occurences
    word_count = defaultdict(int)
    for word in flatten(sequences):
        word_count[word] += 1

    # Sort by frequency
    word_count = sorted(list(word_count.items()), key=lambda l: -l[1])

    # Create a list of all unique words
    unique_words = [item[0] for item in word_count]

    # Add UNK token to list of words
    unique_words.append('UNK')

    # Count number of sequences and number of unique words
    num_sentences, vocab_size = len(sequences), len(unique_words)

    # Create dictionaries so that we can go from word to index and back
    # If a word is not in our vocabulary, we assign it to token 'UNK'
    word_to_idx = defaultdict(lambda: vocab_size-1)
    idx_to_word = defaultdict(lambda: 'UNK')

    # Fill dictionaries
    for idx, word in enumerate(unique_words):
        word_to_idx[word] = idx
        idx_to_word[idx] = word

    return word_to_idx, idx_to_word, num_sentences, vocab_size

In [13]:
def get_sequence(infile):

    while True:

        header = infile.readline()
        sequence = infile.readline()

        pdb = header[1:5]

        if not header or not sequence or set(sequence) == {'X'}:
            return
        
        yield header.strip()[1:], sequence.strip(), pdb

In [14]:
sequences = []
seq_to_pdb = {}
count = 0
with open('all_heavy.fasta') as infile:

        for header, sequence, pdb in get_sequence(infile):
            #if count < 500:
            sequences.append(list(sequence))
                #count += 1
            
            seq_to_pdb[sequence] = pdb

In [15]:
import pandas as pd

df = pd.read_csv('sabdab_summary_all-2.tsv', sep='\t')
df = df[['pdb', 'affinity']]
df

Unnamed: 0,pdb,affinity
0,5m2j,1.3000000000000002e-10
1,6fe4,9.6e-09
2,7jmo,
3,6ch9,
4,4o51,
...,...,...
8562,5ukq,
8563,6ejm,8.6e-10
8564,5bk0,
8565,3vi3,


In [16]:
from torch.utils import data

class Dataset(data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        # Return the size of the dataset
        return len(self.targets)

    def __getitem__(self, index):
        # Retrieve inputs and targets at the given index
        X = self.inputs[index]
        y = self.targets[index]

        return X, y

def create_datasets(sequences, dataset_class, p_train=0.8, p_val=0.1, p_test=0.1):
    
    #def seq_handler(sequences, p, start, end):
    #   num = int(len(sequences)*p)
    #   seq_part = sequences
    
    # Define partition sizes
    num_train = int(len(sequences)*p_train)
    num_val = int(len(sequences)*p_val)
    num_test = int(len(sequences)*p_test)

    # Split sequences into partitions
    sequences_train = sequences[:num_train-1]
    sequences_val = sequences[num_train:num_train+num_val-1]
    sequences_test = sequences[-num_test:-1]

    target_seqs_train = [seq[1:] for seq in sequences_train]
    target_seqs_val = [seq[1:] for seq in sequences_val]
    target_seqs_test = [seq[1:] for seq in sequences_test]


    input_train = [x for sublist in [['<sos>'] + list(seq)+['<eos>'] for seq in sequences_train] for x in sublist]
    input_val = [x for sublist in [['<sos>'] + list(seq)+['<eos>'] for seq in sequences_val] for x in sublist]
    input_test = [x for sublist in [['<sos>'] + list(seq)+['<eos>'] for seq in sequences_test] for x in sublist]


    target_train = [x for sublist in [['<sos>'] + list(seq)+['<eos>'] for seq in target_seqs_train] for x in sublist]
    target_val = [x for sublist in [['<sos>'] + list(seq)+['<eos>'] for seq in target_seqs_val] for x in sublist]
    target_test = [x for sublist in [['<sos>'] + list(seq)+['<eos>'] for seq in target_seqs_test] for x in sublist]


    def get_inputs_targets_from_sequences(sequences):
        # Define empty lists
        inputs, targets = [], []
        
        # Append inputs and targets s.t. both lists contain L-1 words of a sentence of length L
        # but targets are shifted right by one so that we can predict the next word
        for sequence in sequences:
            inputs.append(sequence[:-1])
            targets.append(sequence[1:])

            # We want to predict e.g. last 1/3 of sequence based on first 2/3
            #l = len(sequence) - len(sequence[:int(3*len(sequence)/4)])
            #inputs.append(sequence[:int(3*len(sequence)/4)])
            #targets.append(sequence[l:])
            
        return inputs, targets

    # Get inputs and targets for each partition
    inputs_train, targets_train = get_inputs_targets_from_sequences(sequences_train)
    inputs_val, targets_val = get_inputs_targets_from_sequences(sequences_val)
    inputs_test, targets_test = get_inputs_targets_from_sequences(sequences_test)

    # Create datasets
    training_set = dataset_class(inputs_train, targets_train)
    validation_set = dataset_class(inputs_val, targets_val)
    test_set = dataset_class(inputs_test, targets_test)

    #return training_set, validation_set, test_set

    return input_train, input_val, input_test, target_train, target_val, target_test, training_set, validation_set, test_set
    

input_train, input_val, input_test, target_train, target_val, target_test, training_set, validation_set, test_set = create_datasets(sequences, Dataset)


In [17]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

In [18]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

<torchtext.data.field.Field at 0x1d525b286d0>

In [22]:
import torchtext
from torchtext.data.utils import get_tokenizer
#TEXT = torchtext.data.Field(tokenize=get_tokenizer("spacy", "en"),
#                           init_token='<sos>',
#                            eos_token='<eos>',
#                            lower=True)
TEXT = torchtext.data.Field(tokenize=get_tokenizer("spacy", "en_core_web_lg"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
#train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
train_txt, val_txt, test_txt = training_set, validation_set, test_set

TEXT.build_vocab(train_txt)
print(TEXT.vocab.stoi)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    #print(data.examples[0].text)
    data = TEXT.numericalize([data])
    
    #data = TEXT.numericalize([data[0].text])
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(input_train, batch_size)
val_data = batchify(input_val, eval_batch_size)
test_data = batchify(input_test, eval_batch_size)

#print(train_txt)
print(training_set)
train_data.shape
test_data.shape

defaultdict(<bound method Vocab._default_unk_index of <torchtext.vocab.Vocab object at 0x000001D502CCBEE0>>, {'<unk>': 0, '<pad>': 1, '<sos>': 2, '<eos>': 3, 'S': 4, 'T': 5, 'G': 6, 'V': 7, 'L': 8, 'A': 9, 'P': 10, 'K': 11, 'Y': 12, 'Q': 13, 'D': 14, 'N': 15, 'E': 16, 'F': 17, 'R': 18, 'W': 19, 'I': 20, 'C': 21, 'H': 22, 'M': 23, 'X': 24})
<__main__.Dataset object at 0x000001D523B9FA90>


torch.Size([8862, 10])

In [38]:
bptt = 35
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

In [43]:
ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value

'cuda'

In [48]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
criterion = nn.CrossEntropyLoss()
lr = 5.5 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

model.to(device)

import time
def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    ntokens = len(TEXT.vocab.stoi)
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)

        data, targets = data.to(device), targets.to(device)
        
        optimizer.zero_grad()
        if data.size(0) != bptt:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    ntokens = len(TEXT.vocab.stoi)
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if data.size(0) != bptt:
                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


In [49]:
best_val_loss = float("inf")
epochs = 50 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model

    scheduler.step()

| epoch   1 |   200/ 1002 batches | lr 5.50 | ms/batch  8.00 | loss  3.72 | ppl    41.33
| epoch   1 |   400/ 1002 batches | lr 5.50 | ms/batch  7.62 | loss  2.49 | ppl    12.03
| epoch   1 |   600/ 1002 batches | lr 5.50 | ms/batch  7.69 | loss  2.37 | ppl    10.67
| epoch   1 |   800/ 1002 batches | lr 5.50 | ms/batch  7.68 | loss  2.31 | ppl    10.09
| epoch   1 |  1000/ 1002 batches | lr 5.50 | ms/batch  7.66 | loss  2.30 | ppl     9.93
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  8.25s | valid loss  2.11 | valid ppl     8.27
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1002 batches | lr 4.96 | ms/batch  7.86 | loss  2.25 | ppl     9.47
| epoch   2 |   400/ 1002 batches | lr 4.96 | ms/batch  7.71 | loss  2.18 | ppl     8.88
| epoch   2 |   600/ 1002 batches | lr 4.96 | ms/batch  7.74 | loss  2.16 | ppl     8.70
| epoch   2 |   800/ 1002 batches 

| epoch  13 |   200/ 1002 batches | lr 2.82 | ms/batch  7.96 | loss  1.85 | ppl     6.35
| epoch  13 |   400/ 1002 batches | lr 2.82 | ms/batch  7.76 | loss  1.81 | ppl     6.13
| epoch  13 |   600/ 1002 batches | lr 2.82 | ms/batch  7.69 | loss  1.81 | ppl     6.12
| epoch  13 |   800/ 1002 batches | lr 2.82 | ms/batch  7.73 | loss  1.80 | ppl     6.06
| epoch  13 |  1000/ 1002 batches | lr 2.82 | ms/batch  7.71 | loss  1.82 | ppl     6.19
-----------------------------------------------------------------------------------------
| end of epoch  13 | time:  8.29s | valid loss  1.67 | valid ppl     5.31
-----------------------------------------------------------------------------------------
| epoch  14 |   200/ 1002 batches | lr 2.68 | ms/batch  7.85 | loss  1.83 | ppl     6.21
| epoch  14 |   400/ 1002 batches | lr 2.68 | ms/batch  7.66 | loss  1.79 | ppl     6.01
| epoch  14 |   600/ 1002 batches | lr 2.68 | ms/batch  7.64 | loss  1.79 | ppl     6.01
| epoch  14 |   800/ 1002 batches 

| epoch  25 |   200/ 1002 batches | lr 1.53 | ms/batch  7.76 | loss  1.69 | ppl     5.40
| epoch  25 |   400/ 1002 batches | lr 1.53 | ms/batch  7.81 | loss  1.65 | ppl     5.23
| epoch  25 |   600/ 1002 batches | lr 1.53 | ms/batch  8.00 | loss  1.66 | ppl     5.26
| epoch  25 |   800/ 1002 batches | lr 1.53 | ms/batch  7.78 | loss  1.64 | ppl     5.17
| epoch  25 |  1000/ 1002 batches | lr 1.53 | ms/batch  7.80 | loss  1.67 | ppl     5.32
-----------------------------------------------------------------------------------------
| end of epoch  25 | time:  8.37s | valid loss  1.54 | valid ppl     4.65
-----------------------------------------------------------------------------------------
| epoch  26 |   200/ 1002 batches | lr 1.45 | ms/batch  7.93 | loss  1.68 | ppl     5.37
| epoch  26 |   400/ 1002 batches | lr 1.45 | ms/batch  7.84 | loss  1.65 | ppl     5.19
| epoch  26 |   600/ 1002 batches | lr 1.45 | ms/batch  7.79 | loss  1.65 | ppl     5.20
| epoch  26 |   800/ 1002 batches 

| epoch  37 |   200/ 1002 batches | lr 0.82 | ms/batch  7.84 | loss  1.61 | ppl     5.00
| epoch  37 |   400/ 1002 batches | lr 0.82 | ms/batch  8.12 | loss  1.57 | ppl     4.83
| epoch  37 |   600/ 1002 batches | lr 0.82 | ms/batch  7.93 | loss  1.58 | ppl     4.87
| epoch  37 |   800/ 1002 batches | lr 0.82 | ms/batch  7.79 | loss  1.56 | ppl     4.76
| epoch  37 |  1000/ 1002 batches | lr 0.82 | ms/batch  7.95 | loss  1.60 | ppl     4.95
-----------------------------------------------------------------------------------------
| end of epoch  37 | time:  8.46s | valid loss  1.48 | valid ppl     4.40
-----------------------------------------------------------------------------------------
| epoch  38 |   200/ 1002 batches | lr 0.78 | ms/batch  7.90 | loss  1.60 | ppl     4.97
| epoch  38 |   400/ 1002 batches | lr 0.78 | ms/batch  7.79 | loss  1.58 | ppl     4.83
| epoch  38 |   600/ 1002 batches | lr 0.78 | ms/batch  7.70 | loss  1.58 | ppl     4.86
| epoch  38 |   800/ 1002 batches 

| epoch  49 |   200/ 1002 batches | lr 0.45 | ms/batch  8.60 | loss  1.56 | ppl     4.77
| epoch  49 |   400/ 1002 batches | lr 0.45 | ms/batch  8.17 | loss  1.53 | ppl     4.62
| epoch  49 |   600/ 1002 batches | lr 0.45 | ms/batch  8.07 | loss  1.54 | ppl     4.67
| epoch  49 |   800/ 1002 batches | lr 0.45 | ms/batch  8.16 | loss  1.52 | ppl     4.57
| epoch  49 |  1000/ 1002 batches | lr 0.45 | ms/batch  8.11 | loss  1.56 | ppl     4.74
-----------------------------------------------------------------------------------------
| end of epoch  49 | time:  8.79s | valid loss  1.46 | valid ppl     4.31
-----------------------------------------------------------------------------------------
| epoch  50 |   200/ 1002 batches | lr 0.42 | ms/batch  8.22 | loss  1.56 | ppl     4.77
| epoch  50 |   400/ 1002 batches | lr 0.42 | ms/batch  8.09 | loss  1.53 | ppl     4.62
| epoch  50 |   600/ 1002 batches | lr 0.42 | ms/batch  8.08 | loss  1.54 | ppl     4.66
| epoch  50 |   800/ 1002 batches 

In [53]:
#orch.save(model,'C:/Users/jonas/Desktop/deep_project/models/model.py')
# Model class must be defined somewhere
#model = torch.load(PATH)
#model.eval()

In [54]:
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  1.33 | test ppl     3.79


In [92]:
dict(TEXT.vocab.stoi)

{'<unk>': 0,
 '<pad>': 1,
 '<sos>': 2,
 '<eos>': 3,
 'S': 4,
 'T': 5,
 'G': 6,
 'V': 7,
 'L': 8,
 'A': 9,
 'P': 10,
 'K': 11,
 'Y': 12,
 'Q': 13,
 'D': 14,
 'N': 15,
 'E': 16,
 'F': 17,
 'R': 18,
 'W': 19,
 'I': 20,
 'C': 21,
 'H': 22,
 'M': 23,
 'X': 24,
 1: 0,
 0: 0}

In [94]:
idx_to_letter = {value:key for key,value in dict(TEXT.vocab.stoi).items() if isinstance(key, str) }
idx_to_letter

{0: '<unk>',
 1: '<pad>',
 2: '<sos>',
 3: '<eos>',
 4: 'S',
 5: 'T',
 6: 'G',
 7: 'V',
 8: 'L',
 9: 'A',
 10: 'P',
 11: 'K',
 12: 'Y',
 13: 'Q',
 14: 'D',
 15: 'N',
 16: 'E',
 17: 'F',
 18: 'R',
 19: 'W',
 20: 'I',
 21: 'C',
 22: 'H',
 23: 'M',
 24: 'X'}

In [173]:
def sample_categorical(lnprobs, temperature=1.0):
    """
    Sample an element from a categorical distribution
    :param lnprobs: Outcome log-probabilities
    :param temperature: Sampling temperature. 1.0 follows the given distribution,
        0.0 returns the maximum probability element.
    :return: The index of the sampled element.
    """

    if temperature == 0.0:
        return lnprobs.argmax()
    p = F.softmax(lnprobs / temperature, dim=1)
    
    
    print("softmaxed probs:", p)
    return dist.Categorical(p).sample()

def sample_sentence(model, query, max_len = 140, temperature=1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    while len(query) < max_len:
        query_tensor = batchify(query, len(query))
        src_mask = model.generate_square_subsequent_mask(len(query_tensor)).to(device)
        output = model(query_tensor, src_mask)#.view(-1, ntokens)
        #output, _     = model(query_.unsqueeze(0).to(device))
        
        print("model output:", output)
        
        #next_char_idx = sample_categorical(output[0, :, len(query)-1], temperature) #0.5
        next_char_idx = sample_categorical(output, temperature) #0.5
        #print(output[0, :, len(query_tensor) - 1])
        #print(output)           #print(next_char_idx)

        #if next_char_idx <= 1:
            # query += "*"
        #    break

        #query += [str(chr(int(next_char_idx[0][0])))]

        query += [idx_to_letter[int(next_char_idx[0][-1])]]
        
    
    return query


In [174]:
import torch.distributions as dist

#dat = batchify(['A'], 1)

#src_mask = model.generate_square_subsequent_mask(dat.size(0)).to(device)
#output = model(dat, src_mask)
#sample_categorical(output,temperature=0.5)

sample_sentence(model, ['<sos>'], max_len = 5, temperature=0.5)


model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306]]], device='cuda:0',
       grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1.]]], device='cuda:0',
       grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.7270, -6.7679, -7.3273,  5.0816,  0.0269,  4.2982,  2.3647,
          -0.6049,  0.6449,  0.1660,  3.4140,  2.9475,  1.4805, -0.9442,
           1.8192,  0.6009,  1.6098,  1.7704, -0.2196

['<sos>', 'H', '<unk>', 'N', 'N']

In [176]:


ntokens = len(TEXT.vocab.stoi)

#dat = batchify(['<sos>'], 1)

sample = sample_sentence(model, ['<sos>'], max_len = 100, temperature=0.5)
" ".join(sample)

model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306]]], device='cuda:0',
       grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1.]]], device='cuda:0',
       grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805

       device='cuda:0', grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842, 

       device='cuda:0', grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842, 

model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.79

model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.79

model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.79

softmaxed probs: tensor([[[1.8874e-02, 2.3299e-02, 4.4886e-18, 5.9580e-05, 8.3984e-03,
          1.4234e-03, 8.0982e-04, 4.2050e-02, 5.8948e-02, 6.7180e-03,
          1.3863e-03, 3.4418e-02, 3.3871e-05, 4.9682e-01, 2.2299e-01,
          7.8223e-04, 4.9195e-01, 4.6376e-04, 6.3114e-02, 1.8777e-04,
          5.3794e-03, 1.7012e-06, 3.3950e-04, 4.3385e-01, 4.9306e-02],
         [4.2663e-02, 3.8795e-02, 4.3592e-17, 1.4573e-02, 4.7248e-01,
          1.5553e-02, 1.6803e-02, 1.0264e-02, 5.4890e-01, 1.1370e-02,
          1.7456e-01, 7.7390e-03, 2.9726e-05, 8.5170e-04, 2.6994e-03,
          2.1727e-03, 1.2138e-02, 5.0846e-04, 2.2737e-02, 2.1644e-04,
          1.8469e-03, 1.1712e-06, 3.4500e-04, 7.9771e-03, 3.4878e-02],
         [2.7158e-02, 1.8489e-02, 1.1078e-16, 1.6136e-04, 2.5711e-02,
          4.5486e-02, 1.7852e-03, 2.5927e-03, 1.4491e-02, 2.9273e-03,
          1.9415e-02, 1.5335e-02, 6.6588e-02, 2.0387e-04, 2.1687e-03,
          2.3219e-02, 2.8698e-04, 7.1515e-03, 1.6610e-02, 6.7286e-02,
 

       device='cuda:0', grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842, 

softmaxed probs: tensor([[[1.6242e-02, 2.0355e-02, 4.4886e-18, 4.6638e-05, 7.2867e-03,
          1.2832e-03, 7.7200e-04, 4.0014e-02, 5.6635e-02, 6.5767e-03,
          1.2466e-04, 3.1754e-02, 3.3087e-05, 4.9624e-01, 1.7505e-01,
          7.6309e-04, 4.8724e-01, 4.2652e-04, 4.7081e-02, 1.8384e-04,
          5.2421e-03, 1.6979e-06, 3.3799e-04, 4.3229e-01, 2.2560e-02],
         [3.6714e-02, 3.3892e-02, 4.3592e-17, 1.1407e-02, 4.0993e-01,
          1.4021e-02, 1.6019e-02, 9.7667e-03, 5.2736e-01, 1.1130e-02,
          1.5696e-02, 7.1401e-03, 2.9037e-05, 8.5070e-04, 2.1191e-03,
          2.1196e-03, 1.2022e-02, 4.6764e-04, 1.6961e-02, 2.1192e-04,
          1.7997e-03, 1.1689e-06, 3.4347e-04, 7.9484e-03, 1.5958e-02],
         [2.3371e-02, 1.6152e-02, 1.1078e-16, 1.2631e-04, 2.2307e-02,
          4.1004e-02, 1.7018e-03, 2.4672e-03, 1.3922e-02, 2.8658e-03,
          1.7458e-03, 1.4149e-02, 6.5047e-02, 2.0363e-04, 1.7025e-03,
          2.2651e-02, 2.8423e-04, 6.5773e-03, 1.2391e-02, 6.5880e-02,
 

softmaxed probs: tensor([[[1.5543e-02, 1.9668e-02, 4.4886e-18, 4.6379e-05, 6.9861e-03,
          1.2001e-03, 7.6828e-04, 3.8706e-02, 5.5334e-02, 6.4276e-03,
          6.5421e-05, 2.9401e-02, 3.2222e-05, 4.9622e-01, 1.6616e-01,
          7.5172e-04, 4.8667e-01, 4.1368e-04, 4.6646e-02, 1.7781e-04,
          5.1353e-03, 1.6897e-06, 3.3770e-04, 4.3111e-01, 2.2326e-02],
         [3.5134e-02, 3.2749e-02, 4.3592e-17, 1.1344e-02, 3.9302e-01,
          1.3113e-02, 1.5941e-02, 9.4475e-03, 5.1525e-01, 1.0878e-02,
          8.2373e-03, 6.6110e-03, 2.8279e-05, 8.5067e-04, 2.0115e-03,
          2.0880e-03, 1.2008e-02, 4.5356e-04, 1.6805e-02, 2.0496e-04,
          1.7631e-03, 1.1633e-06, 3.4317e-04, 7.9267e-03, 1.5793e-02],
         [2.2365e-02, 1.5607e-02, 1.1078e-16, 1.2561e-04, 2.1387e-02,
          3.8350e-02, 1.6936e-03, 2.3865e-03, 1.3602e-02, 2.8008e-03,
          9.1618e-04, 1.3100e-02, 6.3348e-02, 2.0363e-04, 1.6161e-03,
          2.2313e-02, 2.8390e-04, 6.3793e-03, 1.2276e-02, 6.3717e-02,
 

       grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[1.4736e-02, 1.8705e-02, 4.4886e-18, 4.3121e-05, 6.6199e-03,
          1.1641e-03, 7.4963e-04, 3.7359e-02, 4.8169e-02, 6.3711e-03,
          6.5275e-05, 2.9106e-02, 3.1918e-05, 4.9596e-01, 1.4380e-01,
          7.4189e-04, 4.8620e-01, 3.8621e-04, 4.5521e-02, 1.7470e-04,
          5.0738e-03, 1.6881e-06, 3.3625e-04, 4.3013e-01, 2.2081e-02],
         [3.3311e-02, 3.1145e-02, 4.3592e-17, 1.0547e-02, 3.7242e-01,
          1.2720e-02, 1.5554e-02, 9.1187e-03, 4.4853e-01, 1.0782e-02,
          8.2190e-03, 6.5445e-03, 2.8012e-05, 8.5021e-04, 1.7408e-03,
          2.0607e-03, 1.1996e-02, 4.2344e-04, 1.6399e-02, 2.0138e-04,
          1.7420e-03, 1.1622e-06, 3.4170e-04, 7.9088e-03, 1.5619e-02],
         [2.1204e-02, 1.4843e-02, 1.1078e-16, 1.1679e-04, 2.0266e-02,
          3.7201e-02, 1.6525e-03, 2.3035e-03, 1.1841e-02, 2.7762e-03,
          9.1414e-04, 1.2969e-02, 6.2749e-02, 2.0352e-04, 1.3986e-03,
          2.2022e-02, 2.8363e-04, 5.9557

model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.79

       grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[1.3036e-02, 1.6639e-02, 4.4886e-18, 3.4698e-05, 5.8568e-03,
          8.8983e-04, 7.2017e-04, 3.6621e-02, 4.6688e-02, 6.2555e-03,
          4.4279e-05, 2.7670e-02, 3.1255e-05, 4.9594e-01, 1.3554e-01,
          7.3105e-04, 4.8279e-01, 3.7940e-04, 4.3791e-02, 1.7241e-04,
          4.9937e-03, 1.6835e-06, 3.3544e-04, 4.2912e-01, 1.4364e-02],
         [2.9467e-02, 2.7705e-02, 4.3592e-17, 8.4872e-03, 3.2949e-01,
          9.7229e-03, 1.4943e-02, 8.9385e-03, 4.3475e-01, 1.0587e-02,
          5.5753e-03, 6.2217e-03, 2.7431e-05, 8.5018e-04, 1.6408e-03,
          2.0306e-03, 1.1912e-02, 4.1598e-04, 1.5776e-02, 1.9874e-04,
          1.7145e-03, 1.1590e-06, 3.4087e-04, 7.8902e-03, 1.0160e-02],
         [1.8758e-02, 1.3203e-02, 1.1078e-16, 9.3975e-05, 1.7930e-02,
          2.8435e-02, 1.5876e-03, 2.2580e-03, 1.1477e-02, 2.7258e-03,
          6.2010e-04, 1.2329e-02, 6.1447e-02, 2.0351e-04, 1.3183e-03,
          2.1700e-02, 2.8164e-04, 5.8507

       device='cuda:0', grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842, 

model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.79

model output: tensor([[[-6.6263, -6.5411, -7.3594, -3.5309,  2.2237,  0.8933,  1.5373,
           2.7535,  2.6696,  1.9770,  0.1684,  1.6926, -1.0077,  5.9742,
           2.8353,  0.0435,  4.9932, -0.5706,  1.4910, -0.8572,  0.8135,
          -2.4681, -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227, -0.7811,  4.2387,  2.0889,  3.0535,
           2.0483,  3.7853,  2.2401,  2.5862,  0.9464, -1.0730,  2.7899,
           0.6282,  0.5543,  3.1422, -0.5246,  0.9805, -0.7861,  0.2789,
          -2.6547, -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.7973,  0.8235,  2.0836,  0.4974,
          -1.6606,  3.0138,  0.5147, -6.2679],
         [-6.4443, -6.6568, -5.7564, -3.0327,  2.7832,  2.6255,  1.9325,
           1.3604,  1.9681,  1.5617,  1.4881,  1.2884,  2.7842,  2.0750,
           0.5188,  1.7388,  1.2699,  0.79

softmaxed probs: tensor([[[9.8210e-03, 1.2060e-02, 2.2443e-18, 3.4619e-05, 5.5854e-03,
          8.2105e-04, 3.9066e-04, 2.2620e-02, 4.4501e-02, 6.1629e-03,
          4.4043e-05, 2.4250e-02, 2.7046e-05, 4.9355e-01, 1.2718e-01,
          5.9017e-04, 4.8221e-01, 2.5437e-04, 4.0646e-02, 1.2409e-04,
          2.7015e-03, 8.5338e-07, 3.1594e-04, 3.9115e-01, 1.2403e-02],
         [2.2200e-02, 2.0081e-02, 2.1796e-17, 8.4678e-03, 3.1422e-01,
          8.9713e-03, 8.1060e-03, 5.5210e-03, 4.1438e-01, 1.0430e-02,
          5.5456e-03, 5.4527e-03, 2.3736e-05, 8.4609e-04, 1.5396e-03,
          1.6393e-03, 1.1898e-02, 2.7889e-04, 1.4643e-02, 1.4304e-04,
          9.2749e-04, 5.8752e-07, 3.2106e-04, 7.1920e-03, 8.7733e-03],
         [1.4132e-02, 9.5702e-03, 5.5388e-17, 9.3760e-05, 1.7099e-02,
          2.6237e-02, 8.6120e-04, 1.3947e-03, 1.0940e-02, 2.6854e-03,
          6.1680e-04, 1.0805e-02, 5.3171e-02, 2.0253e-04, 1.2369e-03,
          1.7518e-02, 2.8130e-04, 3.9225e-03, 1.0697e-02, 4.4467e-02,
 

       grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594,  ..., -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227,  ..., -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564,  ...,  3.0138,  0.5147, -6.2679],
         ...,
         [-6.2008, -6.3594, -6.5224,  ..., -1.3609, -0.2433, -4.5354],
         [-6.5779, -6.6688, -7.0542,  ..., -0.6884,  0.1881, -6.2113],
         [-6.3466, -6.4176, -7.1148,  ..., -0.5721, -0.2832, -5.0143]]],
       device='cuda:0', grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[6.8696e-03, 8.3704e-03, 1.4962e-18,  ..., 2.4824e-06,
          2.7744e-01, 8.3930e-03],
         [1.5528e-02, 1.3937e-02, 1.4531e-17,  ..., 2.5226e-06,
          5.1012e-03, 5.9369e-03],
         [9.8848e-03, 6.6421e-03, 3.6925e-17,  ..., 2.3849e-03,
          1.6224e-03, 3.8782e-04],
         ...,
         [1.6088e-02, 1.2038e-02, 7.9806e-18,  ..., 3.7812e-07,
          3.5623e-04, 1.2400e-02],
         [7.5675e-03, 6.4833e-03, 2

       device='cuda:0', grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[5.9242e-03, 7.2946e-03, 1.4962e-18,  ..., 2.4705e-06,
          2.6490e-01, 5.9497e-03],
         [1.3391e-02, 1.2146e-02, 1.4531e-17,  ..., 2.5105e-06,
          4.8706e-03, 4.2087e-03],
         [8.5244e-03, 5.7885e-03, 3.6925e-17,  ..., 2.3734e-03,
          1.5490e-03, 2.7492e-04],
         ...,
         [5.1494e-03, 5.7224e-03, 1.2111e-17,  ..., 6.9185e-06,
          9.7147e-03, 1.1100e-03],
         [9.0423e-03, 7.1080e-03, 1.4371e-17,  ..., 4.6923e-06,
          1.3694e-03, 8.2281e-04],
         [2.3113e-02, 2.4222e-02, 1.0842e-17,  ..., 5.1191e-07,
          2.2725e-05, 1.3305e-01]]], device='cuda:0',
       grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594,  ..., -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227,  ..., -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564,  ...,  3.0138,  0.5147, -6.2679],
         ...,
         [-6.4148, -6.5541, -6.22

model output: tensor([[[-6.6263, -6.5411, -7.3594,  ..., -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227,  ..., -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564,  ...,  3.0138,  0.5147, -6.2679],
         ...,
         [-6.7289, -6.9120, -6.9628,  ...,  0.7761, -0.0783, -6.0876],
         [-6.3712, -6.4428, -7.7689,  ...,  0.3236, -0.3293, -5.3725],
         [-6.5172, -6.6919, -7.0470,  ..., -0.6096, -0.2971, -5.2906]]],
       device='cuda:0', grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[5.0567e-03, 6.3174e-03, 1.4962e-18,  ..., 2.4702e-06,
          2.6091e-01, 4.5504e-03],
         [1.1430e-02, 1.0519e-02, 1.4531e-17,  ..., 2.5103e-06,
          4.7973e-03, 3.2188e-03],
         [7.2761e-03, 5.0131e-03, 3.6925e-17,  ..., 2.3732e-03,
          1.5258e-03, 2.1027e-04],
         ...,
         [4.1187e-03, 3.0087e-03, 3.3073e-18,  ..., 2.7024e-05,
          4.6603e-04, 3.0158e-04],
         [8.4218e-03, 7.6903e-03, 6.5959e-19,  ..., 1.0931e-05,
     

       grad_fn=<SoftmaxBackward>)
model output: tensor([[[-6.6263, -6.5411, -7.3594,  ..., -0.4201,  3.0856, -4.7306],
         [-6.2185, -6.2862, -6.2227,  ..., -0.4120,  1.0875, -4.9037],
         [-6.4443, -6.6568, -5.7564,  ...,  3.0138,  0.5147, -6.2679],
         ...,
         [-6.5779, -6.6688, -7.0542,  ..., -0.6884,  0.1881, -6.2113],
         [-6.7289, -6.9120, -6.9628,  ...,  0.7761, -0.0783, -6.0876],
         [-6.3712, -6.4428, -7.7689,  ...,  0.3236, -0.3293, -5.3725]]],
       device='cuda:0', grad_fn=<AddBackward0>)
softmaxed probs: tensor([[[4.4772e-03, 5.6434e-03, 1.4962e-18,  ..., 2.4688e-06,
          2.4316e-01, 4.0571e-03],
         [1.0120e-02, 9.3967e-03, 1.4531e-17,  ..., 2.5088e-06,
          4.4710e-03, 2.8698e-03],
         [6.4423e-03, 4.4782e-03, 3.6925e-17,  ..., 2.3718e-03,
          1.4220e-03, 1.8747e-04],
         ...,
         [4.9320e-03, 4.3711e-03, 2.7549e-18,  ..., 1.4435e-06,
          7.4001e-04, 2.0994e-04],
         [3.6467e-03, 2.6877e-03, 3

'<sos> Q N N <unk> W I D T X <eos> <sos> E I N A A Y C A <unk> T <unk> <pad> R F P L T F S R F P X <eos> V <pad> Y W N H <eos> <sos> M D K L Q E P S L K <unk> N Y <unk> F P X Y C N Y Y F P <pad> Y F P X <pad> F P X T K R D T V W G R D Y D Y F P S L K R D Y I G'