In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import numpy as np
import tqdm
import random
import torch.nn.functional as F

# Cross-attention implementation

In [154]:
data_train = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
data_train_dec = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
data_test = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(400)])
data_test_dec = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(400)])

def one_hot_encoding(seq, vocab):
    if isinstance(seq, int):
        one_hot = torch.zeros(1, len(vocab), dtype=int)
        one_hot[:, vocab[seq]] = 1
    else:
        one_hot = torch.zeros((len(seq), len(vocab)), dtype=int)
        for i, spin in enumerate(seq):
            one_hot[i, vocab[spin]] = 1
    return one_hot

def mask_random_spin(sequence1, sequence2, vocab, mask_token=2):
    """
    Mask one random spin in a sequence of protein spins.
    
    Parameters:
    - sequence: a list or sequence of spins (integers)
    - mask_token: the token used to mask a spin (default is 2)
    
    Returns:
    - masked_sequence: a sequence similar to the input but with one spin masked
    - masked_position: the position of the spin that was masked
    """
    # Ensure the sequence can be converted to a list for masking
    sequence_list_1 = sequence1.numpy().tolist() if isinstance(sequence1, torch.Tensor) else list(sequence1)
    sequence_list_2 = sequence2.numpy().tolist() if isinstance(sequence2, torch.Tensor) else list(sequence1)
    # Choose a random position to mask, excluding the first spin
    mask_position = random.randint(1, len(sequence_list_1) - 1)
    
    # Mask the chosen position
    masked_sequence_1 = sequence_list_1.copy()
    masked_sequence_1[mask_position] = mask_token

    masked_sequence_2 = sequence_list_2.copy()
    masked_sequence_2[mask_position] = mask_token

    # Create an array of zeros with shape (len(sequence), len(vocab))
    one_hot_1 = one_hot_encoding(masked_sequence_1, vocab)
    one_hot_2 = one_hot_encoding(masked_sequence_2, vocab)
        
    return torch.tensor(one_hot_1, dtype=torch.float), torch.tensor(one_hot_2, dtype=torch.float), torch.tensor(mask_position)
    

In [155]:
class VanillaAttention(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=3, dropout_rate=0.0):
        super(VanillaAttention, self).__init__()
        self.word_embeddings = nn.Linear(num_spins, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a  # parameter controlling how important positions are
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)  # output layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, s):
        #if not isinstance(s, torch.Tensor):
        #s = torch.tensor(s, dtype=torch.float)
        #masked_token = torch.tensor(masked_token, dtype=torch.float)

        position_ids = torch.arange(s.size(0), dtype=torch.long)
        #x = self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)
        x = s + self.a*self.position_embeddings(position_ids)
        
        query = self.query_weight(x)
        key = self.key_weight(x)
        values = self.value_weight(x)
        
        # Simple attention score calculation (Dot product): this is equivalent to the interaction matrix
        scores = torch.matmul(query, key.transpose(-2, -1))  # Transpose last two dimensions for matrix multiplication
        scores = torch.softmax(scores, dim=-1)  # Apply softmax to scores to get probabilities

        # Apply attention scores to values
        attn_output = torch.matmul(scores, values)

        # Sum over the sequence length dimensions
        #attn_output = attn_output.sum(dim=1)
        output = self.fc(self.dropout(attn_output)) # should have size (20,3)

        return output
    
class CrossAttention(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=3, dropout_rate=0.0):
        super(CrossAttention, self).__init__()
        self.word_embeddings = nn.Linear(num_spins, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a  # parameter controlling how important positions are
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)  # output layer
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, s, enc_output):
        #s = torch.tensor(s, dtype=torch.float)
        position_ids = torch.arange(s.size(0), dtype=torch.long)

        #x = self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)
        x = s + self.a*self.position_embeddings(position_ids)

        query = self.query_weight(x)
        key = self.key_weight(enc_output)
        values = self.value_weight(enc_output)
        
        # Simple attention score calculation (Dot product): this is equivalent to the interaction matrix
        scores = torch.matmul(query, key.transpose(-2, -1))  # Transpose last two dimensions for matrix multiplication
        scores = torch.softmax(scores, dim=-1)  # Apply softmax to scores to get probabilities

        # Apply attention scores to values
        attn_output = torch.matmul(scores, values)

        # Sum over the sequence length dimensions
        #attn_output = attn_output.sum(dim=1)
        output = self.fc(self.dropout(attn_output))
        return output
    

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


In [156]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = VanillaAttention(embed_dim=embed_dim, a=a, max_seq_length=max_seq_length, num_spins=num_spins, dropout_rate=dropout)
        self.feed_forward = PositionWiseFeedForward(d_model=embed_dim, d_ff=proj_layer_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        attn_output = self.self_attn(x)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
    
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = VanillaAttention(embed_dim=embed_dim, a=a, max_seq_length=max_seq_length, num_spins=num_spins, dropout_rate=dropout)# masking one of the word in the sequence
        self.cross_attn = CrossAttention(embed_dim=embed_dim, a=a, max_seq_length=max_seq_length, num_spins=num_spins, dropout_rate=dropout)# queries come from the previous layer of the decoder while the keys and values come from the encoder. Used in translation
        self.feed_forward = PositionWiseFeedForward(d_model=embed_dim, d_ff=proj_layer_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output):
        attn_output = self.self_attn(x)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        x = x.sum(dim=0)
        return x

In [157]:
class Transformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(Transformer, self).__init__()
        #self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        #self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.word_embeddings = nn.Linear(num_spins, embed_dim)

        #self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        #self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_layer = EncoderLayer(embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout)
        self.decoder_layer = DecoderLayer(embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout)
        self.fc = nn.Linear(embed_dim, num_spins)

    def forward(self, src, tgt):
        #src_mask, tgt_mask = self.generate_mask(src, tgt)
        #src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        #tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        src_embedded = self.word_embeddings(src)
        tgt_embedded = self.word_embeddings(tgt)
        enc_output = self.encoder_layer(src_embedded)
        dec_output = self.decoder_layer(tgt_embedded, enc_output)
        output = self.fc(dec_output)
        return output

In [161]:
def evaluate(model, data_test, data_test_dec, vocab, criterion, device=0):
    model.eval()
    epoch_loss = 0
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    for i, data in tqdm.tqdm(enumerate(data_test), total=len(data_test)):
        # Get the inputs
        input_seq_enc = data
        input_seq_dec = data_test_dec[i]
        #input_one_hot = one_hot_encoding(input_seq.tolist(), vocab)
        input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)
        # mask a token
        masked_sequence_enc, masked_sequence_dec, position = mask_random_spin(input_seq_enc, input_seq_dec, vocab, mask_token=2)
        # Forward pass
        outputs = model.forward(masked_sequence_enc, masked_sequence_dec)

        #output_token = F.log_softmax(outputs, dim=-1)
        target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #target_token = input_seq[position]

        # Compute loss
        loss = criterion(outputs.unsqueeze(0), target_token)
        epoch_loss += loss.item()
    #model.train()
    return epoch_loss / len(data_test)

def train(model, data_train, data_train_dec, data_test, data_test_dec, vocab, optimizer, criterion, num_epochs=8, device=0):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Training loop
    model.train()
    best_eval_loss = 1e-3 # used to do early stopping

    for epoch in tqdm.tqdm(range(num_epochs), leave=False, position=0):
        running_loss = 0
        epoch_loss = 0
        
        for i, data in tqdm.tqdm(enumerate(data_train), total=len(data_train)):
            # Get the inputs
            input_seq_enc = data
            input_seq_dec = data_train_dec[i]
            #input_encoder_one_hot = one_hot_encoding(input_seq_enc.tolist(), vocab)
            input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)

            # mask a token
            masked_sequence_enc, masked_sequence_dec, position = mask_random_spin(input_seq_enc, input_seq_dec, vocab, mask_token=2)

            # Forward pass
            prediction = model.forward(masked_sequence_enc, masked_sequence_dec) #masked_sequence[masked_position]
            
            #predicted_token = F.log_softmax(prediction, dim=-1)
            target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #input_seq[masked_position]
            # Compute loss
            loss = criterion(prediction.unsqueeze(0), target_token)
            epoch_loss += loss.item()
            
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 10 == 9. :    # print every 10 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        print(f'Epoch {epoch + 1} | Train Loss: {(epoch_loss / len(data)):.4f}')
        eval_loss = evaluate(model, data_test, data_test_dec, vocab, criterion, device=device)
        print(f'Epoch {epoch + 1} | Eval Loss: {(eval_loss):.4f}')
        
        # Perform early stopping based on eval loss
        if eval_loss < best_eval_loss:
            return epoch_loss / len(data_train)
    return epoch_loss / len(data_train)

In [162]:
# Define the parameters 
vocab_size = 3
vocab = {-1:0,1:1,2:2} 
L = 20
embedding_dim = 20
proj_layer_dim = 128
hidden_dim = 20
num_layers = 1 # have to adapt the model for 2 and 3 layers
dropout_rate = 0.0
lr = 1e-3
num_sequences = 1000
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [163]:
# Example usage:
model = Transformer(embed_dim=embedding_dim, a=0, max_seq_length=L, num_spins=3, proj_layer_dim=128, dropout=dropout_rate)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train(model, data_train, data_train_dec, data_test, data_test_dec, vocab, optimizer, criterion, device=device)
#torch.save(model.state_dict(), 'models/lstm_scratch.pt')
#evaluate(model, test_dataloader, criterion, device=device)

  return torch.tensor(one_hot_1, dtype=torch.float), torch.tensor(one_hot_2, dtype=torch.float), torch.tensor(mask_position)


[1,    10] loss: 0.510
[1,    20] loss: 0.085
[1,    30] loss: 0.207
[1,    40] loss: 0.080
[1,    50] loss: 0.110
[1,    60] loss: 0.110
[1,    70] loss: 0.161
[1,    80] loss: 0.093
[1,    90] loss: 0.098
[1,   100] loss: 0.109




[1,   110] loss: 0.157
[1,   120] loss: 0.184
[1,   130] loss: 0.119
[1,   140] loss: 0.093
[1,   150] loss: 0.165
[1,   160] loss: 0.053
[1,   170] loss: 0.117
[1,   180] loss: 0.075
[1,   190] loss: 0.076




[1,   200] loss: 0.070
[1,   210] loss: 0.125
[1,   220] loss: 0.093
[1,   230] loss: 0.107
[1,   240] loss: 0.133
[1,   250] loss: 0.060




[1,   260] loss: 0.094
[1,   270] loss: 0.112
[1,   280] loss: 0.085
[1,   290] loss: 0.106
[1,   300] loss: 0.109
[1,   310] loss: 0.075
[1,   320] loss: 0.085
[1,   330] loss: 0.062
[1,   340] loss: 0.073




[1,   350] loss: 0.085
[1,   360] loss: 0.087
[1,   370] loss: 0.109
[1,   380] loss: 0.068
[1,   390] loss: 0.066
[1,   400] loss: 0.095
[1,   410] loss: 0.092
[1,   420] loss: 0.076
[1,   430] loss: 0.122




[1,   440] loss: 0.104
[1,   450] loss: 0.072
[1,   460] loss: 0.082
[1,   470] loss: 0.032
[1,   480] loss: 0.103
[1,   490] loss: 0.090
[1,   500] loss: 0.071
[1,   510] loss: 0.102
[1,   520] loss: 0.098




[1,   530] loss: 0.124
[1,   540] loss: 0.073
[1,   550] loss: 0.105
[1,   560] loss: 0.091
[1,   570] loss: 0.094
[1,   580] loss: 0.089
[1,   590] loss: 0.032
[1,   600] loss: 0.118
[1,   610] loss: 0.081




[1,   620] loss: 0.074
[1,   630] loss: 0.064
[1,   640] loss: 0.074
[1,   650] loss: 0.088
[1,   660] loss: 0.085
[1,   670] loss: 0.085
[1,   680] loss: 0.073
[1,   690] loss: 0.096
[1,   700] loss: 0.082




[1,   710] loss: 0.067
[1,   720] loss: 0.098
[1,   730] loss: 0.120
[1,   740] loss: 0.102
[1,   750] loss: 0.112
[1,   760] loss: 0.063
[1,   770] loss: 0.147
[1,   780] loss: 0.071
[1,   790] loss: 0.059




[1,   800] loss: 0.086
[1,   810] loss: 0.081
[1,   820] loss: 0.099
[1,   830] loss: 0.066
[1,   840] loss: 0.138
[1,   850] loss: 0.079
[1,   860] loss: 0.076
[1,   870] loss: 0.087
[1,   880] loss: 0.077




[1,   890] loss: 0.083
[1,   900] loss: 0.074
[1,   910] loss: 0.080
[1,   920] loss: 0.083
[1,   930] loss: 0.069
[1,   940] loss: 0.079
[1,   950] loss: 0.090
[1,   960] loss: 0.090
[1,   970] loss: 0.113


100%|██████████| 1000/1000 [00:02<00:00, 412.98it/s]


[1,   980] loss: 0.071
[1,   990] loss: 0.079
[1,  1000] loss: 0.070
Epoch 1 | Train Loss: 48.3900


100%|██████████| 400/400 [00:00<00:00, 1524.33it/s]
 12%|█▎        | 1/8 [00:02<00:18,  2.69s/it]

Epoch 1 | Eval Loss: 0.8493




[2,    10] loss: 0.100
[2,    20] loss: 0.059
[2,    30] loss: 0.063
[2,    40] loss: 0.083
[2,    50] loss: 0.091
[2,    60] loss: 0.075
[2,    70] loss: 0.085
[2,    80] loss: 0.070




[2,    90] loss: 0.093




[2,   100] loss: 0.076
[2,   110] loss: 0.072
[2,   120] loss: 0.096
[2,   130] loss: 0.046
[2,   140] loss: 0.097
[2,   150] loss: 0.067
[2,   160] loss: 0.128
[2,   170] loss: 0.066
[2,   180] loss: 0.135




[2,   190] loss: 0.074
[2,   200] loss: 0.080
[2,   210] loss: 0.080
[2,   220] loss: 0.065
[2,   230] loss: 0.090
[2,   240] loss: 0.065
[2,   250] loss: 0.065
[2,   260] loss: 0.078
[2,   270] loss: 0.064




[2,   280] loss: 0.119
[2,   290] loss: 0.071
[2,   300] loss: 0.069
[2,   310] loss: 0.071
[2,   320] loss: 0.072
[2,   330] loss: 0.066
[2,   340] loss: 0.079
[2,   350] loss: 0.075
[2,   360] loss: 0.066




[2,   370] loss: 0.081
[2,   380] loss: 0.086
[2,   390] loss: 0.081
[2,   400] loss: 0.091
[2,   410] loss: 0.076
[2,   420] loss: 0.082
[2,   430] loss: 0.073
[2,   440] loss: 0.071




[2,   450] loss: 0.095
[2,   460] loss: 0.047
[2,   470] loss: 0.062
[2,   480] loss: 0.137




[2,   490] loss: 0.103
[2,   500] loss: 0.150




[2,   510] loss: 0.078
[2,   520] loss: 0.097
[2,   530] loss: 0.072
[2,   540] loss: 0.069
[2,   550] loss: 0.075
[2,   560] loss: 0.068
[2,   570] loss: 0.070




[2,   580] loss: 0.054
[2,   590] loss: 0.093




[2,   600] loss: 0.091
[2,   610] loss: 0.067
[2,   620] loss: 0.077
[2,   630] loss: 0.076
[2,   640] loss: 0.079
[2,   650] loss: 0.069
[2,   660] loss: 0.098
[2,   670] loss: 0.083




[2,   680] loss: 0.069
[2,   690] loss: 0.074
[2,   700] loss: 0.114
[2,   710] loss: 0.081
[2,   720] loss: 0.072
[2,   730] loss: 0.089
[2,   740] loss: 0.077
[2,   750] loss: 0.076




[2,   760] loss: 0.081




[2,   770] loss: 0.056
[2,   780] loss: 0.067
[2,   790] loss: 0.072
[2,   800] loss: 0.087
[2,   810] loss: 0.069
[2,   820] loss: 0.074
[2,   830] loss: 0.065
[2,   840] loss: 0.130




[2,   850] loss: 0.071




[2,   860] loss: 0.137
[2,   870] loss: 0.059
[2,   880] loss: 0.074
[2,   890] loss: 0.075
[2,   900] loss: 0.077
[2,   910] loss: 0.103
[2,   920] loss: 0.048
[2,   930] loss: 0.099


100%|██████████| 1000/1000 [00:02<00:00, 391.92it/s]


[2,   940] loss: 0.067
[2,   950] loss: 0.078
[2,   960] loss: 0.070
[2,   970] loss: 0.075
[2,   980] loss: 0.071
[2,   990] loss: 0.071
[2,  1000] loss: 0.066
Epoch 2 | Train Loss: 39.9553


100%|██████████| 400/400 [00:00<00:00, 1586.12it/s]
 25%|██▌       | 2/8 [00:05<00:16,  2.76s/it]

Epoch 2 | Eval Loss: 0.7228




[3,    10] loss: 0.079
[3,    20] loss: 0.059
[3,    30] loss: 0.080
[3,    40] loss: 0.083
[3,    50] loss: 0.073
[3,    60] loss: 0.070
[3,    70] loss: 0.075
[3,    80] loss: 0.081




[3,    90] loss: 0.072




[3,   100] loss: 0.072
[3,   110] loss: 0.072
[3,   120] loss: 0.079
[3,   130] loss: 0.075
[3,   140] loss: 0.071
[3,   150] loss: 0.071
[3,   160] loss: 0.077
[3,   170] loss: 0.075
[3,   180] loss: 0.078




[3,   190] loss: 0.068
[3,   200] loss: 0.081
[3,   210] loss: 0.076
[3,   220] loss: 0.073
[3,   230] loss: 0.079
[3,   240] loss: 0.083
[3,   250] loss: 0.061
[3,   260] loss: 0.083
[3,   270] loss: 0.075
[3,   280] loss: 0.063




[3,   290] loss: 0.092
[3,   300] loss: 0.064
[3,   310] loss: 0.076
[3,   320] loss: 0.065
[3,   330] loss: 0.076
[3,   340] loss: 0.086
[3,   350] loss: 0.069




[3,   360] loss: 0.071




[3,   370] loss: 0.073
[3,   380] loss: 0.073
[3,   390] loss: 0.092
[3,   400] loss: 0.081
[3,   410] loss: 0.072




[3,   420] loss: 0.075
[3,   430] loss: 0.070
[3,   440] loss: 0.069
[3,   450] loss: 0.080




[3,   460] loss: 0.075
[3,   470] loss: 0.074
[3,   480] loss: 0.075
[3,   490] loss: 0.078
[3,   500] loss: 0.065




[3,   510] loss: 0.052
[3,   520] loss: 0.076
[3,   530] loss: 0.086
[3,   540] loss: 0.100




[3,   550] loss: 0.085
[3,   560] loss: 0.064
[3,   570] loss: 0.072
[3,   580] loss: 0.088
[3,   590] loss: 0.082




[3,   600] loss: 0.086
[3,   610] loss: 0.049
[3,   620] loss: 0.105
[3,   630] loss: 0.061




[3,   640] loss: 0.070
[3,   650] loss: 0.080
[3,   660] loss: 0.055
[3,   670] loss: 0.107
[3,   680] loss: 0.062




[3,   690] loss: 0.110
[3,   700] loss: 0.086
[3,   710] loss: 0.088
[3,   720] loss: 0.075




[3,   730] loss: 0.070
[3,   740] loss: 0.064
[3,   750] loss: 0.072
[3,   760] loss: 0.077




[3,   770] loss: 0.068
[3,   780] loss: 0.078
[3,   790] loss: 0.070
[3,   800] loss: 0.069




[3,   810] loss: 0.087
[3,   820] loss: 0.057
[3,   830] loss: 0.083
[3,   840] loss: 0.074




[3,   850] loss: 0.075
[3,   860] loss: 0.076
[3,   870] loss: 0.073
[3,   880] loss: 0.078
[3,   890] loss: 0.078




[3,   900] loss: 0.080
[3,   910] loss: 0.080
[3,   920] loss: 0.070
[3,   930] loss: 0.072
[3,   940] loss: 0.063


100%|██████████| 1000/1000 [00:02<00:00, 389.72it/s]


[3,   950] loss: 0.077
[3,   960] loss: 0.067
[3,   970] loss: 0.078
[3,   980] loss: 0.074
[3,   990] loss: 0.070
[3,  1000] loss: 0.081
Epoch 3 | Train Loss: 37.6626


100%|██████████| 400/400 [00:00<00:00, 1589.14it/s]
 38%|███▊      | 3/8 [00:08<00:13,  2.79s/it]

Epoch 3 | Eval Loss: 0.7018




[4,    10] loss: 0.064
[4,    20] loss: 0.090
[4,    30] loss: 0.102
[4,    40] loss: 0.074
[4,    50] loss: 0.075
[4,    60] loss: 0.072
[4,    70] loss: 0.069
[4,    80] loss: 0.071
[4,    90] loss: 0.075




[4,   100] loss: 0.064





[4,   110] loss: 0.086
[4,   120] loss: 0.075
[4,   130] loss: 0.071
[4,   140] loss: 0.064
[4,   150] loss: 0.084
[4,   160] loss: 0.068
[4,   170] loss: 0.075
[4,   180] loss: 0.073
[4,   190] loss: 0.041


 19%|█▉        | 193/1000 [00:00<00:01, 469.53it/s][A

[4,   200] loss: 0.080




[4,   210] loss: 0.077
[4,   220] loss: 0.074
[4,   230] loss: 0.076
[4,   240] loss: 0.073
[4,   250] loss: 0.071
[4,   260] loss: 0.074
[4,   270] loss: 0.066
[4,   280] loss: 0.077
[4,   290] loss: 0.068
[4,   300] loss: 0.074




[4,   310] loss: 0.068
[4,   320] loss: 0.065
[4,   330] loss: 0.076
[4,   340] loss: 0.079
[4,   350] loss: 0.063
[4,   360] loss: 0.069
[4,   370] loss: 0.079
[4,   380] loss: 0.081
[4,   390] loss: 0.070
[4,   400] loss: 0.067




[4,   410] loss: 0.089
[4,   420] loss: 0.075
[4,   430] loss: 0.072
[4,   440] loss: 0.073
[4,   450] loss: 0.067
[4,   460] loss: 0.084
[4,   470] loss: 0.075
[4,   480] loss: 0.072
[4,   490] loss: 0.072




[4,   500] loss: 0.084
[4,   510] loss: 0.081
[4,   520] loss: 0.078
[4,   530] loss: 0.074
[4,   540] loss: 0.084
[4,   550] loss: 0.072
[4,   560] loss: 0.077
[4,   570] loss: 0.071
[4,   580] loss: 0.090




[4,   590] loss: 0.086
[4,   600] loss: 0.084
[4,   610] loss: 0.082
[4,   620] loss: 0.071
[4,   630] loss: 0.093
[4,   640] loss: 0.044
[4,   650] loss: 0.124
[4,   660] loss: 0.084
[4,   670] loss: 0.080




[4,   680] loss: 0.077
[4,   690] loss: 0.081
[4,   700] loss: 0.071
[4,   710] loss: 0.071
[4,   720] loss: 0.068
[4,   730] loss: 0.072
[4,   740] loss: 0.069
[4,   750] loss: 0.072
[4,   760] loss: 0.073
[4,   770] loss: 0.087




[4,   780] loss: 0.070
[4,   790] loss: 0.076
[4,   800] loss: 0.074
[4,   810] loss: 0.074
[4,   820] loss: 0.069
[4,   830] loss: 0.075
[4,   840] loss: 0.075
[4,   850] loss: 0.078
[4,   860] loss: 0.078
[4,   870] loss: 0.071




[4,   880] loss: 0.070
[4,   890] loss: 0.069
[4,   900] loss: 0.072
[4,   910] loss: 0.073
[4,   920] loss: 0.064
[4,   930] loss: 0.069
[4,   940] loss: 0.076
[4,   950] loss: 0.072
[4,   960] loss: 0.078
[4,   970] loss: 0.069


100%|██████████| 1000/1000 [00:02<00:00, 455.15it/s]


[4,   980] loss: 0.079
[4,   990] loss: 0.077
[4,  1000] loss: 0.072
Epoch 4 | Train Loss: 37.3976


100%|██████████| 400/400 [00:00<00:00, 1148.23it/s]
 50%|█████     | 4/8 [00:10<00:10,  2.69s/it]

Epoch 4 | Eval Loss: 0.7011




[5,    10] loss: 0.071
[5,    20] loss: 0.070
[5,    30] loss: 0.073
[5,    40] loss: 0.062
[5,    50] loss: 0.070
[5,    60] loss: 0.082
[5,    70] loss: 0.078
[5,    80] loss: 0.094
[5,    90] loss: 0.062




[5,   100] loss: 0.052




[5,   110] loss: 0.070
[5,   120] loss: 0.064
[5,   130] loss: 0.080
[5,   140] loss: 0.069
[5,   150] loss: 0.069
[5,   160] loss: 0.072
[5,   170] loss: 0.075
[5,   180] loss: 0.067
[5,   190] loss: 0.080
[5,   200] loss: 0.081




[5,   210] loss: 0.076
[5,   220] loss: 0.075
[5,   230] loss: 0.070
[5,   240] loss: 0.052
[5,   250] loss: 0.105
[5,   260] loss: 0.065
[5,   270] loss: 0.096
[5,   280] loss: 0.061
[5,   290] loss: 0.090
[5,   300] loss: 0.087




[5,   310] loss: 0.067
[5,   320] loss: 0.074
[5,   330] loss: 0.070
[5,   340] loss: 0.080
[5,   350] loss: 0.062
[5,   360] loss: 0.089
[5,   370] loss: 0.064
[5,   380] loss: 0.087
[5,   390] loss: 0.076




[5,   400] loss: 0.075
[5,   410] loss: 0.073
[5,   420] loss: 0.069
[5,   430] loss: 0.076
[5,   440] loss: 0.064
[5,   450] loss: 0.108
[5,   460] loss: 0.066
[5,   470] loss: 0.072
[5,   480] loss: 0.071




[5,   490] loss: 0.070
[5,   500] loss: 0.067
[5,   510] loss: 0.085
[5,   520] loss: 0.073
[5,   530] loss: 0.070
[5,   540] loss: 0.070
[5,   550] loss: 0.070
[5,   560] loss: 0.071
[5,   570] loss: 0.069




[5,   580] loss: 0.060
[5,   590] loss: 0.082
[5,   600] loss: 0.076
[5,   610] loss: 0.072
[5,   620] loss: 0.068
[5,   630] loss: 0.071
[5,   640] loss: 0.072
[5,   650] loss: 0.070
[5,   660] loss: 0.073




[5,   670] loss: 0.067
[5,   680] loss: 0.081
[5,   690] loss: 0.073
[5,   700] loss: 0.074
[5,   710] loss: 0.065
[5,   720] loss: 0.071
[5,   730] loss: 0.084
[5,   740] loss: 0.062
[5,   750] loss: 0.078




[5,   760] loss: 0.078
[5,   770] loss: 0.071
[5,   780] loss: 0.072
[5,   790] loss: 0.071
[5,   800] loss: 0.069
[5,   810] loss: 0.064
[5,   820] loss: 0.066
[5,   830] loss: 0.101
[5,   840] loss: 0.070




[5,   850] loss: 0.073
[5,   860] loss: 0.071
[5,   870] loss: 0.069
[5,   880] loss: 0.072
[5,   890] loss: 0.064
[5,   900] loss: 0.072
[5,   910] loss: 0.070
[5,   920] loss: 0.073
[5,   930] loss: 0.075


100%|██████████| 1000/1000 [00:02<00:00, 430.60it/s]


[5,   940] loss: 0.069
[5,   950] loss: 0.066
[5,   960] loss: 0.072
[5,   970] loss: 0.073
[5,   980] loss: 0.068
[5,   990] loss: 0.084
[5,  1000] loss: 0.070
Epoch 5 | Train Loss: 36.5544


100%|██████████| 400/400 [00:00<00:00, 1518.77it/s]
 62%|██████▎   | 5/8 [00:13<00:07,  2.66s/it]

Epoch 5 | Eval Loss: 0.6936




[6,    10] loss: 0.069
[6,    20] loss: 0.071
[6,    30] loss: 0.069
[6,    40] loss: 0.064
[6,    50] loss: 0.076




[6,    60] loss: 0.066
[6,    70] loss: 0.063
[6,    80] loss: 0.078
[6,    90] loss: 0.079




[6,   100] loss: 0.068
[6,   110] loss: 0.062
[6,   120] loss: 0.070
[6,   130] loss: 0.074
[6,   140] loss: 0.079
[6,   150] loss: 0.064
[6,   160] loss: 0.062
[6,   170] loss: 0.072
[6,   180] loss: 0.069




[6,   190] loss: 0.053
[6,   200] loss: 0.098
[6,   210] loss: 0.070
[6,   220] loss: 0.062
[6,   230] loss: 0.057
[6,   240] loss: 0.065




[6,   250] loss: 0.104
[6,   260] loss: 0.062
[6,   270] loss: 0.077




[6,   280] loss: 0.090
[6,   290] loss: 0.070
[6,   300] loss: 0.066
[6,   310] loss: 0.080
[6,   320] loss: 0.071
[6,   330] loss: 0.072




[6,   340] loss: 0.070
[6,   350] loss: 0.066
[6,   360] loss: 0.057




[6,   370] loss: 0.079
[6,   380] loss: 0.075
[6,   390] loss: 0.070
[6,   400] loss: 0.069
[6,   410] loss: 0.064
[6,   420] loss: 0.076




[6,   430] loss: 0.068
[6,   440] loss: 0.078
[6,   450] loss: 0.069




[6,   460] loss: 0.071
[6,   470] loss: 0.074
[6,   480] loss: 0.072
[6,   490] loss: 0.082
[6,   500] loss: 0.079
[6,   510] loss: 0.066




[6,   520] loss: 0.072
[6,   530] loss: 0.068
[6,   540] loss: 0.071




[6,   550] loss: 0.065
[6,   560] loss: 0.073
[6,   570] loss: 0.071
[6,   580] loss: 0.067
[6,   590] loss: 0.076
[6,   600] loss: 0.068




[6,   610] loss: 0.069
[6,   620] loss: 0.078
[6,   630] loss: 0.058
[6,   640] loss: 0.074




[6,   650] loss: 0.075
[6,   660] loss: 0.072
[6,   670] loss: 0.071
[6,   680] loss: 0.072
[6,   690] loss: 0.070
[6,   700] loss: 0.070




[6,   710] loss: 0.071
[6,   720] loss: 0.069
[6,   730] loss: 0.072
[6,   740] loss: 0.074




[6,   750] loss: 0.070
[6,   760] loss: 0.076
[6,   770] loss: 0.071
[6,   780] loss: 0.076
[6,   790] loss: 0.069
[6,   800] loss: 0.066




[6,   810] loss: 0.076
[6,   820] loss: 0.072
[6,   830] loss: 0.068
[6,   840] loss: 0.070
[6,   850] loss: 0.069




[6,   860] loss: 0.071
[6,   870] loss: 0.070
[6,   880] loss: 0.069
[6,   890] loss: 0.077




[6,   900] loss: 0.069
[6,   910] loss: 0.067
[6,   920] loss: 0.079
[6,   930] loss: 0.067
[6,   940] loss: 0.071




[6,   950] loss: 0.077
[6,   960] loss: 0.065
[6,   970] loss: 0.072


100%|██████████| 1000/1000 [00:02<00:00, 400.45it/s]


[6,   980] loss: 0.073
[6,   990] loss: 0.068
[6,  1000] loss: 0.072
Epoch 6 | Train Loss: 35.5661


100%|██████████| 400/400 [00:00<00:00, 1459.88it/s]
 75%|███████▌  | 6/8 [00:16<00:05,  2.70s/it]

Epoch 6 | Eval Loss: 0.7245




[7,    10] loss: 0.064
[7,    20] loss: 0.073
[7,    30] loss: 0.069
[7,    40] loss: 0.071
[7,    50] loss: 0.071
[7,    60] loss: 0.076
[7,    70] loss: 0.070




[7,    80] loss: 0.073
[7,    90] loss: 0.068




[7,   100] loss: 0.075
[7,   110] loss: 0.070
[7,   120] loss: 0.073
[7,   130] loss: 0.073
[7,   140] loss: 0.071
[7,   150] loss: 0.073
[7,   160] loss: 0.071




[7,   170] loss: 0.071
[7,   180] loss: 0.073




[7,   190] loss: 0.068
[7,   200] loss: 0.064
[7,   210] loss: 0.077
[7,   220] loss: 0.068
[7,   230] loss: 0.058
[7,   240] loss: 0.076
[7,   250] loss: 0.070
[7,   260] loss: 0.078
[7,   270] loss: 0.065




[7,   280] loss: 0.084
[7,   290] loss: 0.073
[7,   300] loss: 0.059
[7,   310] loss: 0.074
[7,   320] loss: 0.073
[7,   330] loss: 0.079
[7,   340] loss: 0.071
[7,   350] loss: 0.070
[7,   360] loss: 0.065




[7,   370] loss: 0.077
[7,   380] loss: 0.076
[7,   390] loss: 0.070
[7,   400] loss: 0.065
[7,   410] loss: 0.063
[7,   420] loss: 0.083
[7,   430] loss: 0.070
[7,   440] loss: 0.069
[7,   450] loss: 0.071




[7,   460] loss: 0.070
[7,   470] loss: 0.065
[7,   480] loss: 0.078
[7,   490] loss: 0.064
[7,   500] loss: 0.068
[7,   510] loss: 0.072
[7,   520] loss: 0.081
[7,   530] loss: 0.067
[7,   540] loss: 0.055




[7,   550] loss: 0.087
[7,   560] loss: 0.062
[7,   570] loss: 0.073
[7,   580] loss: 0.077
[7,   590] loss: 0.072
[7,   600] loss: 0.061
[7,   610] loss: 0.077




[7,   620] loss: 0.083
[7,   630] loss: 0.068
[7,   640] loss: 0.071
[7,   650] loss: 0.069
[7,   660] loss: 0.070
[7,   670] loss: 0.071
[7,   680] loss: 0.075




[7,   690] loss: 0.066
[7,   700] loss: 0.071
[7,   710] loss: 0.071
[7,   720] loss: 0.070
[7,   730] loss: 0.071
[7,   740] loss: 0.068
[7,   750] loss: 0.068
[7,   760] loss: 0.069
[7,   770] loss: 0.063




[7,   780] loss: 0.080
[7,   790] loss: 0.067
[7,   800] loss: 0.071
[7,   810] loss: 0.070
[7,   820] loss: 0.069
[7,   830] loss: 0.076
[7,   840] loss: 0.072
[7,   850] loss: 0.069
[7,   860] loss: 0.063




[7,   870] loss: 0.069
[7,   880] loss: 0.083
[7,   890] loss: 0.056
[7,   900] loss: 0.091
[7,   910] loss: 0.065
[7,   920] loss: 0.071
[7,   930] loss: 0.071
[7,   940] loss: 0.072


100%|██████████| 1000/1000 [00:02<00:00, 392.01it/s]


[7,   950] loss: 0.071
[7,   960] loss: 0.070
[7,   970] loss: 0.071
[7,   980] loss: 0.067
[7,   990] loss: 0.077
[7,  1000] loss: 0.069
Epoch 7 | Train Loss: 35.4611


100%|██████████| 400/400 [00:00<00:00, 1422.77it/s]
 88%|████████▊ | 7/8 [00:19<00:02,  2.74s/it]

Epoch 7 | Eval Loss: 0.7002




[8,    10] loss: 0.060
[8,    20] loss: 0.103
[8,    30] loss: 0.074
[8,    40] loss: 0.070
[8,    50] loss: 0.069
[8,    60] loss: 0.063
[8,    70] loss: 0.085
[8,    80] loss: 0.071




[8,    90] loss: 0.067




[8,   100] loss: 0.072
[8,   110] loss: 0.071
[8,   120] loss: 0.070
[8,   130] loss: 0.068
[8,   140] loss: 0.071
[8,   150] loss: 0.074
[8,   160] loss: 0.070
[8,   170] loss: 0.066
[8,   180] loss: 0.073




[8,   190] loss: 0.080
[8,   200] loss: 0.073
[8,   210] loss: 0.068
[8,   220] loss: 0.072
[8,   230] loss: 0.071




[8,   240] loss: 0.071
[8,   250] loss: 0.070
[8,   260] loss: 0.070
[8,   270] loss: 0.070




[8,   280] loss: 0.068
[8,   290] loss: 0.069
[8,   300] loss: 0.078
[8,   310] loss: 0.065
[8,   320] loss: 0.064
[8,   330] loss: 0.067




[8,   340] loss: 0.079
[8,   350] loss: 0.065
[8,   360] loss: 0.071
[8,   370] loss: 0.068




[8,   380] loss: 0.072
[8,   390] loss: 0.067
[8,   400] loss: 0.071
[8,   410] loss: 0.071
[8,   420] loss: 0.072




[8,   430] loss: 0.070
[8,   440] loss: 0.068
[8,   450] loss: 0.072
[8,   460] loss: 0.070




[8,   470] loss: 0.070
[8,   480] loss: 0.071
[8,   490] loss: 0.072
[8,   500] loss: 0.070
[8,   510] loss: 0.069




[8,   520] loss: 0.071
[8,   530] loss: 0.075
[8,   540] loss: 0.067
[8,   550] loss: 0.049




[8,   560] loss: 0.102
[8,   570] loss: 0.069
[8,   580] loss: 0.063
[8,   590] loss: 0.070
[8,   600] loss: 0.075




[8,   610] loss: 0.068
[8,   620] loss: 0.068
[8,   630] loss: 0.066
[8,   640] loss: 0.073




[8,   650] loss: 0.074
[8,   660] loss: 0.068
[8,   670] loss: 0.069
[8,   680] loss: 0.070
[8,   690] loss: 0.072




[8,   700] loss: 0.064
[8,   710] loss: 0.060
[8,   720] loss: 0.068
[8,   730] loss: 0.075




[8,   740] loss: 0.079
[8,   750] loss: 0.067
[8,   760] loss: 0.069
[8,   770] loss: 0.063
[8,   780] loss: 0.072




[8,   790] loss: 0.063
[8,   800] loss: 0.068




[8,   810] loss: 0.068
[8,   820] loss: 0.076
[8,   830] loss: 0.070
[8,   840] loss: 0.072




[8,   850] loss: 0.068
[8,   860] loss: 0.073
[8,   870] loss: 0.069
[8,   880] loss: 0.070
[8,   890] loss: 0.068




[8,   900] loss: 0.065
[8,   910] loss: 0.077
[8,   920] loss: 0.070
[8,   930] loss: 0.068
[8,   940] loss: 0.071




[8,   950] loss: 0.068
[8,   960] loss: 0.060
[8,   970] loss: 0.062
[8,   980] loss: 0.047


100%|██████████| 1000/1000 [00:02<00:00, 408.30it/s]


[8,   990] loss: 0.071
[8,  1000] loss: 0.052
Epoch 8 | Train Loss: 34.9304


100%|██████████| 400/400 [00:00<00:00, 1596.01it/s]
                                             

Epoch 8 | Eval Loss: 0.8282




0.6986083506941795

In [91]:
# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print(input.shape)
print(embedding(input).shape)


torch.Size([2, 4])
torch.Size([2, 4, 3])
