In [341]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import numpy as np
import tqdm
import random
import torch.nn.functional as F

# Experiment with cross attention
In this experiment, we sample data not only from uniform distribution but from variety of distributions. We then use this data for encoder and feed its output throught cross-attention mechanism to the decoder. The decoder takes the sequence with missing tokens and tries to fill in the gaps.

### Architecture
- We use multihead attention for the encoder and vanilla attention for the decoder.
- I don't split heads: I use one head to process one sequence from one of the distributions

In [352]:
# Load the sequences from the file
final_chains_train = np.load('final_chains_T=1_num_iters=400_train.npy')
print("Loaded train sequences of proteins sampled from Boltzmann distribution:", final_chains_train.shape)

final_chains_test = np.load('final_chains_T=1_num_iters=400_test.npy')
print("Loaded test sequences of proteins sampled from Boltzmann distribution:", final_chains_test.shape)

Loaded train sequences of proteins sampled from Boltzmann distribution: (1000, 200)
Loaded test sequences of proteins sampled from Boltzmann distribution: (1000, 200)


In [353]:
# Define the parameters for different distributions
distributions = [{"type": "normal", "mean": 0, "std": 1},
        {"type": "uniform", "low": -1, "high": 1},
        {"type": "exponential", "scale": 1},
        {"type": "gamma", "scale": 1},
        {"type": "poisson", "lam": 1.0}]

def sample_data(num_samples, distributions, len_of_seq):
    # Define the parameters for different distributions
    
    # Initialize the tensor to store the samples
    all_samples = torch.zeros(len(distributions), num_samples, len_of_seq)
    
    # Sample data from each distribution
    for i in range(len(distributions)):
        distr = distributions[i]
        if distr["type"] == "normal":
            samples = np.random.normal(distr["mean"], distr["std"], (num_samples, len_of_seq))
            samples = np.where(samples >= 0, 1, -1)
        elif distr["type"] == "uniform":
            samples = np.random.uniform(distr["low"], distr["high"], (num_samples, len_of_seq))
            samples = np.where(samples >= 0, 1, -1)
        elif distr["type"] == "exponential":
            samples = np.random.exponential(distr["scale"], (num_samples, len_of_seq))
            samples = np.where(samples >= distr["scale"], 1, -1)
        elif distr["type"] == "gamma":
            samples = np.random.poisson(distr["scale"], (num_samples, len_of_seq))
            samples = np.where(samples >= np.mean(samples), 1, -1)
        elif distr["type"] == "poisson":
            samples = np.random.poisson(distr["lam"], (num_samples, len_of_seq))
            samples = np.where(samples >= np.mean(samples), 1, -1)
        else:
            raise ValueError("Unsupported distribution type")
        
        # Store the samples in the tensor
        all_samples[i] = torch.tensor(samples, dtype=torch.float32)
    
    return all_samples.reshape(num_samples, len(distributions), len_of_seq)

# Example usage
tensor_samples_train = sample_data(1000, distributions, 200)  # 5 distributions, 1000 samples each, length of sequence 10
print(tensor_samples_train.shape)

tensor_samples_test = sample_data(1000, distributions, 200)  # 5 distributions, 1000 samples each, length of sequence 10
print(tensor_samples_train.shape)

torch.Size([1000, 5, 200])
torch.Size([1000, 5, 200])


In [354]:
#data_train = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
#data_train_dec = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(1000)])
#data_test = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(400)])
#data_test_dec = torch.tensor([np.random.choice([-1, 1], size=20) for _ in range(400)])

def one_hot_encoding(seq, vocab):
    if isinstance(seq, int):
        one_hot = torch.zeros(1, len(vocab), dtype=int)
        one_hot[:, vocab[seq]] = 1
    else:
        one_hot = torch.zeros((len(seq), len(vocab)), dtype=int)
        for i, spin in enumerate(seq):
            #print("spins:",spin)
            #print("vocab:",vocab)
            one_hot[i, vocab[spin]] = 1
    return one_hot

def mask_random_spin(sequence, vocab, mask_token=2):
    """
    Mask one random spin in a sequence of protein spins.
    
    Parameters:
    - sequence: a list or sequence of spins (integers)
    - mask_token: the token used to mask a spin (default is 2)
    
    Returns:
    - masked_sequence: a sequence similar to the input but with one spin masked
    - masked_position: the position of the spin that was masked
    """
    # Ensure the sequence can be converted to a list for masking
    sequence_list = sequence.numpy().tolist() if isinstance(sequence, torch.Tensor) else list(sequence)
    # Choose a random position to mask, excluding the first spin
    mask_position = random.randint(1, len(sequence_list) - 1)
    
    # Mask the chosen position
    masked_sequence = sequence_list.copy()
    masked_sequence[mask_position] = mask_token

    # Create an array of zeros with shape (len(sequence), len(vocab))
    one_hot = one_hot_encoding(masked_sequence, vocab)
        
    return torch.tensor(one_hot, dtype=torch.float), torch.tensor(mask_position)
    

In [355]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads
        
        self.W_q = nn.Linear(embed_dim, embed_dim) # Linear layer for query 
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim) # Output layer to ensure that dimensionality matches the model's expected dimensionality
        
        self.combine_heads = nn.Linear(num_heads, 1)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, embed_dim = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
     
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        #print("x somewhere:", x.shape)
        #print("seqeunce length:",seq_length)
        #print("d model dim:",self.d_model)
        return x.transpose(1, 2).contiguous().view(batch_size,seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        #print("shape of Q before splitting heads:", Q.shape)
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        #print("shape of Q after splitting heads:", Q.shape)
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        #print(" before W_o:", self.combine_heads(attn_output.permute(1,2,0)).squeeze(-1).shape)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class VanillaAttention(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins=3, dropout_rate=0.0):
        super(VanillaAttention, self).__init__()
        self.word_embeddings = nn.Linear(num_spins, embed_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embed_dim)
        self.a = a  # parameter controlling how important positions are
        self.value_weight = nn.Linear(embed_dim, embed_dim)
        self.query_weight = nn.Linear(embed_dim, embed_dim)
        self.key_weight = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)  # output layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, s, enc_output):

        position_ids = torch.arange(s.size(0), dtype=torch.long)
        #x = self.word_embeddings(s) + self.a*self.position_embeddings(position_ids)
        x = s + self.a*self.position_embeddings(position_ids)
        
        query = self.query_weight(x)
        key = self.key_weight(enc_output)
        values = self.value_weight(enc_output)
        
        # Simple attention score calculation (Dot product): this is equivalent to the interaction matrix
        scores = torch.matmul(query, key.transpose(-2, -1))  # Transpose last two dimensions for matrix multiplication
        scores = torch.softmax(scores, dim=-1)  # Apply softmax to scores to get probabilities

        # Apply attention scores to values
        attn_output = torch.matmul(scores, values)

        # Sum over the sequence length dimensions
        #attn_output = attn_output.sum(dim=1)
        output = self.fc(self.dropout(attn_output)) # should have size (20,3)

        return output

In [356]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [357]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, proj_layer_dim, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=len(distributions))
        self.feed_forward = PositionWiseFeedForward(d_model=embed_dim, d_ff=proj_layer_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(len(distributions),1)
        
    def forward(self, x):
        x = self.fc(x.permute(1,2,0)).permute(2,0,1)
        attn_output = self.self_attn(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
    
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = VanillaAttention(embed_dim=embed_dim, a=a, max_seq_length=max_seq_length, num_spins=num_spins, dropout_rate=dropout)# masking one of the word in the sequence
        self.cross_attn = VanillaAttention(embed_dim=embed_dim, a=a, max_seq_length=max_seq_length, num_spins=num_spins, dropout_rate=dropout)
        self.feed_forward = PositionWiseFeedForward(d_model=embed_dim, d_ff=proj_layer_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output):
        enc_output = enc_output.squeeze(0)
        attn_output = self.self_attn(x, x)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        x = x.sum(dim=0)
        return x

In [358]:
class Transformer(nn.Module):
    def __init__(self, embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout):
        super(Transformer, self).__init__()
        #self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        #self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.word_embeddings = nn.Linear(num_spins, embed_dim)

        #self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        #self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_layer = EncoderLayer(embed_dim, proj_layer_dim, dropout)
        self.decoder_layer = DecoderLayer(embed_dim, a, max_seq_length, num_spins, proj_layer_dim, dropout)
        self.fc = nn.Linear(embed_dim, num_spins)

    def forward(self, src, tgt):
        #src_mask, tgt_mask = self.generate_mask(src, tgt)
        #src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        #tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        src_embedded = self.word_embeddings(src)
        tgt_embedded = self.word_embeddings(tgt)
        enc_output = self.encoder_layer(src_embedded)
        dec_output = self.decoder_layer(tgt_embedded, enc_output)
        output = self.fc(dec_output)
        #print("output of the transformer:", output.shape)
        return output

## Training and validation

In [362]:
def evaluate(model, data_test, data_test_dec, vocab, criterion, device=0):
    model.eval()
    epoch_loss = 0
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    for i, data in tqdm.tqdm(enumerate(data_test), total=len(data_test)):
        # Get the inputs
        input_seq_enc = data
        input_seq_dec = data_test_dec[i]
        
        input_encoder_one_hot = torch.stack([one_hot_encoding(input_seq_enc[i].tolist(), vocab) for i in range(len(input_seq_enc))], dim=0)
        input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)

        input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)
        # mask a token
        masked_sequence_dec, position = mask_random_spin(input_seq_dec, vocab, mask_token=2)
        
        # Forward pass
        outputs = model.forward(input_encoder_one_hot, masked_sequence_dec)

        #output_token = F.log_softmax(outputs, dim=-1)
        target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #target_token = input_seq[position]

        # Compute loss
        loss = criterion(outputs.unsqueeze(0), target_token)
        epoch_loss += loss.item()
    #model.train()
    return epoch_loss / len(data_test)

def train(model, data_train, data_train_dec, data_test, data_test_dec, vocab, optimizer, criterion, num_epochs=8, device=0):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Training loop
    model.train()
    best_eval_loss = 1e-3 # used to do early stopping

    for epoch in tqdm.tqdm(range(num_epochs), leave=False, position=0):
        running_loss = 0
        epoch_loss = 0
        
        for i, data in tqdm.tqdm(enumerate(data_train), total=len(data_train)):
            # Get the inputs
            input_seq_enc = data
            input_seq_dec = data_train_dec[i]

            input_encoder_one_hot = torch.stack([one_hot_encoding(input_seq_enc[i].tolist(), vocab) for i in range(len(input_seq_enc))], dim=0)
            input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)

            input_decoder_one_hot = one_hot_encoding(input_seq_dec.tolist(), vocab)

            # mask a token in decoder
            masked_sequence_dec, position = mask_random_spin(input_seq_dec, vocab, mask_token=2)

            # Forward pass
            prediction = model.forward(input_encoder_one_hot, masked_sequence_dec) #masked_sequence[masked_position]
            
            #predicted_token = F.log_softmax(prediction, dim=-1)
            target_token = torch.argwhere(input_decoder_one_hot[position]==1).squeeze(0) #input_seq[masked_position]
            
            # Compute loss
            #print("model prediction:", prediction.shape)
            #print("target:", target_token.shape)
            loss = criterion(prediction.unsqueeze(0), target_token)
            epoch_loss += loss.item()
            
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 10 == 9. :    # print every 10 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        print(f'Epoch {epoch + 1} | Train Loss: {(epoch_loss / len(data)):.4f}')
        eval_loss = evaluate(model, data_test, data_test_dec, vocab, criterion, device=device)
        print(f'Epoch {epoch + 1} | Eval Loss: {(eval_loss):.4f}')
        
        # Perform early stopping based on eval loss
        if eval_loss < best_eval_loss:
            return epoch_loss / len(data_train)
    return epoch_loss / len(data_train)

In [363]:
# Define the parameters 
vocab_size = 3
vocab = {-1:0,1:1,2:2} 
L = 200
embedding_dim = 200
proj_layer_dim = 128
hidden_dim = 200
num_layers = 1 # have to adapt the model for 2 and 3 layers
dropout_rate = 0.0
lr = 1e-3
num_sequences = 1000
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [364]:
# Example usage:

model = Transformer(embed_dim=embedding_dim, a=0, max_seq_length=L, num_spins=3, proj_layer_dim=128, dropout=dropout_rate)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train(model, tensor_samples_train, final_chains_train, tensor_samples_test, final_chains_test, vocab, optimizer, criterion, device=device)
#torch.save(model.state_dict(), 'models/lstm_scratch.pt')
#evaluate(model, test_dataloader, criterion, device=device)

  input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)
  return torch.tensor(one_hot, dtype=torch.float), torch.tensor(mask_position)


[1,    10] loss: 18.063
[1,    20] loss: 4.761




[1,    30] loss: 1.969
[1,    40] loss: 3.117




[1,    50] loss: 6.823
[1,    60] loss: 3.646




[1,    70] loss: 4.646
[1,    80] loss: 2.782




[1,    90] loss: 8.657
[1,   100] loss: 2.358




[1,   110] loss: 1.994
[1,   120] loss: 2.635




[1,   130] loss: 3.912
[1,   140] loss: 5.054




[1,   150] loss: 4.476
[1,   160] loss: 1.067




[1,   170] loss: 1.351
[1,   180] loss: 0.847




[1,   190] loss: 1.590
[1,   200] loss: 1.383




[1,   210] loss: 2.097
[1,   220] loss: 1.341




[1,   230] loss: 1.625
[1,   240] loss: 0.784




[1,   250] loss: 2.578
[1,   260] loss: 5.678




[1,   270] loss: 6.323
[1,   280] loss: 2.196




[1,   290] loss: 3.505
[1,   300] loss: 4.222




[1,   310] loss: 0.692
[1,   320] loss: 0.922




[1,   330] loss: 0.905
[1,   340] loss: 1.026




[1,   350] loss: 0.626
[1,   360] loss: 3.501



 39%|███▉      | 389/1000 [00:05<00:07, 81.23it/s]

[1,   370] loss: 5.087
[1,   380] loss: 3.994


[A

[1,   390] loss: 3.738
[1,   400] loss: 6.096




[1,   410] loss: 1.124
[1,   420] loss: 0.524




[1,   430] loss: 1.072
[1,   440] loss: 0.992




[1,   450] loss: 0.816
[1,   460] loss: 0.356




[1,   470] loss: 0.505
[1,   480] loss: 0.869




[1,   490] loss: 0.742
[1,   500] loss: 0.259




[1,   510] loss: 0.629
[1,   520] loss: 0.155




[1,   530] loss: 0.327
[1,   540] loss: 0.789




[1,   550] loss: 0.020
[1,   560] loss: 0.509




[1,   570] loss: 0.122
[1,   580] loss: 0.873




[1,   590] loss: 1.223
[1,   600] loss: 1.067




[1,   610] loss: 0.612
[1,   620] loss: 0.940




[1,   630] loss: 0.105
[1,   640] loss: 0.431




[1,   650] loss: 0.272
[1,   660] loss: 0.290




[1,   670] loss: 1.169
[1,   680] loss: 0.739




[1,   690] loss: 0.874
[1,   700] loss: 0.487




[1,   710] loss: 0.848
[1,   720] loss: 1.064




[1,   730] loss: 1.575
[1,   740] loss: 1.051




[1,   750] loss: 1.552
[1,   760] loss: 0.100




[1,   770] loss: 0.964
[1,   780] loss: 1.029




[1,   790] loss: 0.853
[1,   800] loss: 0.306




[1,   810] loss: 0.492
[1,   820] loss: 0.374




[1,   830] loss: 0.838
[1,   840] loss: 0.243




[1,   850] loss: 0.794
[1,   860] loss: 0.084




[1,   870] loss: 0.566
[1,   880] loss: 0.215




[1,   890] loss: 0.005
[1,   900] loss: 1.310




[1,   910] loss: 0.594
[1,   920] loss: 0.078




[1,   930] loss: 0.544
[1,   940] loss: 0.607




[1,   950] loss: 0.663
[1,   960] loss: 0.377




[1,   970] loss: 0.164
[1,   980] loss: 0.311


100%|██████████| 1000/1000 [00:13<00:00, 76.65it/s]


[1,   990] loss: 0.242
[1,  1000] loss: 0.901
Epoch 1 | Train Loss: 3514.0853


  input_encoder_one_hot = torch.tensor(input_encoder_one_hot, dtype=torch.float)
100%|██████████| 1000/1000 [00:06<00:00, 145.66it/s]
 12%|█▎        | 1/8 [00:19<02:19, 19.92s/it]

Epoch 1 | Eval Loss: 6.9511




[2,    10] loss: 0.029




[2,    20] loss: 0.303




[2,    30] loss: 0.019
[2,    40] loss: 1.349




[2,    50] loss: 1.452
[2,    60] loss: 1.064




[2,    70] loss: 1.441
[2,    80] loss: 0.341




[2,    90] loss: 1.120
[2,   100] loss: 0.076




[2,   110] loss: 0.700
[2,   120] loss: 0.446




[2,   130] loss: 0.525
[2,   140] loss: 0.962




[2,   150] loss: 0.354
[2,   160] loss: 0.992




[2,   170] loss: 0.665
[2,   180] loss: 0.323




[2,   190] loss: 2.007
[2,   200] loss: 0.400




[2,   210] loss: 2.351
[2,   220] loss: 1.060




[2,   230] loss: 0.402
[2,   240] loss: 0.247




[2,   250] loss: 0.527
[2,   260] loss: 0.914




[2,   270] loss: 1.238
[2,   280] loss: 1.297




[2,   290] loss: 0.259
[2,   300] loss: 0.054




[2,   310] loss: 0.696
[2,   320] loss: 0.689




[2,   330] loss: 0.226
[2,   340] loss: 0.731




[2,   350] loss: 0.511
[2,   360] loss: 0.131




[2,   370] loss: 0.396
[2,   380] loss: 0.204




[2,   390] loss: 0.340
[2,   400] loss: 0.154




[2,   410] loss: 0.244
[2,   420] loss: 0.303




[2,   430] loss: 0.662
[2,   440] loss: 0.684




[2,   450] loss: 0.430
[2,   460] loss: 0.843




[2,   470] loss: 0.446
[2,   480] loss: 0.013




[2,   490] loss: 1.110
[2,   500] loss: 0.362




[2,   510] loss: 0.196
[2,   520] loss: 0.198




[2,   530] loss: 0.642
[2,   540] loss: 0.546




[2,   550] loss: 0.824
[2,   560] loss: 0.006




[2,   570] loss: 0.030
[2,   580] loss: 0.710




[2,   590] loss: 0.107
[2,   600] loss: 0.559




[2,   610] loss: 0.448
[2,   620] loss: 0.528




[2,   630] loss: 0.973
[2,   640] loss: 0.200




[2,   650] loss: 0.119
[2,   660] loss: 0.386




[2,   670] loss: 0.174
[2,   680] loss: 0.479




[2,   690] loss: 0.526
[2,   700] loss: 0.183




[2,   710] loss: 0.350
[2,   720] loss: 0.242




[2,   730] loss: 0.032
[2,   740] loss: 0.555




[2,   750] loss: 0.590
[2,   760] loss: 0.235




[2,   770] loss: 0.109
[2,   780] loss: 0.265




[2,   790] loss: 0.228
[2,   800] loss: 0.086




[2,   810] loss: 0.104
[2,   820] loss: 0.152




[2,   830] loss: 0.113
[2,   840] loss: 0.154




[2,   850] loss: 0.082
[2,   860] loss: 0.118




[2,   870] loss: 0.132
[2,   880] loss: 0.045




[2,   890] loss: 0.024
[2,   900] loss: 0.180




[2,   910] loss: 0.156
[2,   920] loss: 0.402




[2,   930] loss: 0.118
[2,   940] loss: 0.099




[2,   950] loss: 0.287
[2,   960] loss: 0.117




[2,   970] loss: 0.121
[2,   980] loss: 0.215


100%|██████████| 1000/1000 [00:14<00:00, 67.30it/s]


[2,   990] loss: 0.050
[2,  1000] loss: 0.204
Epoch 2 | Train Loss: 909.8301


100%|██████████| 1000/1000 [00:06<00:00, 147.69it/s]
 25%|██▌       | 2/8 [00:41<02:05, 20.93s/it]

Epoch 2 | Eval Loss: 1.0456




[3,    10] loss: 0.268




[3,    20] loss: 0.086




[3,    30] loss: 0.203




[3,    40] loss: 0.216




[3,    50] loss: 0.456




[3,    60] loss: 0.544




[3,    70] loss: 0.113




[3,    80] loss: 0.315




[3,    90] loss: 0.416
[3,   100] loss: 0.336




[3,   110] loss: 0.156
[3,   120] loss: 0.107




[3,   130] loss: 0.054
[3,   140] loss: 0.228




[3,   150] loss: 0.184
[3,   160] loss: 0.100




[3,   170] loss: 0.019
[3,   180] loss: 0.200




[3,   190] loss: 0.133
[3,   200] loss: 0.083




[3,   210] loss: 0.084
[3,   220] loss: 0.063




[3,   230] loss: 0.071
[3,   240] loss: 0.094




[3,   250] loss: 0.173
[3,   260] loss: 0.163




[3,   270] loss: 0.067
[3,   280] loss: 0.110




[3,   290] loss: 0.108
[3,   300] loss: 0.126




[3,   310] loss: 0.132
[3,   320] loss: 0.111




[3,   330] loss: 0.081
[3,   340] loss: 0.092




[3,   350] loss: 0.043
[3,   360] loss: 0.106




[3,   370] loss: 0.030
[3,   380] loss: 0.081




[3,   390] loss: 0.148
[3,   400] loss: 0.120




[3,   410] loss: 0.074
[3,   420] loss: 0.080




[3,   430] loss: 0.060
[3,   440] loss: 0.131




[3,   450] loss: 0.032
[3,   460] loss: 0.111




[3,   470] loss: 0.246
[3,   480] loss: 0.059




[3,   490] loss: 0.050
[3,   500] loss: 0.065



 53%|█████▎    | 527/1000 [00:07<00:05, 85.58it/s]

[3,   510] loss: 0.050
[3,   520] loss: 0.033


[A

[3,   530] loss: 0.069
[3,   540] loss: 0.041





[3,   550] loss: 0.081
[3,   560] loss: 0.136


 56%|█████▋    | 564/1000 [00:07<00:05, 77.66it/s][A

[3,   570] loss: 0.093
[3,   580] loss: 0.132




[3,   590] loss: 0.150
[3,   600] loss: 0.021




[3,   610] loss: 0.095
[3,   620] loss: 0.033




[3,   630] loss: 0.064
[3,   640] loss: 0.080




[3,   650] loss: 0.015
[3,   660] loss: 0.089




[3,   670] loss: 0.053
[3,   680] loss: 0.117




[3,   690] loss: 0.099
[3,   700] loss: 0.166




[3,   710] loss: 0.119
[3,   720] loss: 0.064




[3,   730] loss: 0.086
[3,   740] loss: 0.052




[3,   750] loss: 0.108
[3,   760] loss: 0.066




[3,   770] loss: 0.129
[3,   780] loss: 0.086




[3,   790] loss: 0.103
[3,   800] loss: 0.067




[3,   810] loss: 0.021
[3,   820] loss: 0.173




[3,   830] loss: 0.166
[3,   840] loss: 0.046




[3,   850] loss: 0.052
[3,   860] loss: 0.064




[3,   870] loss: 0.128
[3,   880] loss: 0.059




[3,   890] loss: 0.071
[3,   900] loss: 0.072




[3,   910] loss: 0.042
[3,   920] loss: 0.063




[3,   930] loss: 0.170
[3,   940] loss: 0.086




[3,   950] loss: 0.105
[3,   960] loss: 0.091




[3,   970] loss: 0.064
[3,   980] loss: 0.065


100%|██████████| 1000/1000 [00:13<00:00, 71.75it/s]


[3,   990] loss: 0.038
[3,  1000] loss: 0.063
Epoch 3 | Train Loss: 225.2832


100%|██████████| 1000/1000 [00:06<00:00, 146.77it/s]
 38%|███▊      | 3/8 [01:02<01:44, 20.85s/it]

Epoch 3 | Eval Loss: 0.5947




[4,    10] loss: 0.035




[4,    20] loss: 0.101




[4,    30] loss: 0.117




[4,    40] loss: 0.106




[4,    50] loss: 0.032




[4,    60] loss: 0.124




[4,    70] loss: 0.079




[4,    80] loss: 0.063




[4,    90] loss: 0.026




[4,   100] loss: 0.041




[4,   110] loss: 0.060




[4,   120] loss: 0.110
[4,   130] loss: 0.056




[4,   140] loss: 0.059
[4,   150] loss: 0.103




[4,   160] loss: 0.030
[4,   170] loss: 0.047




[4,   180] loss: 0.048
[4,   190] loss: 0.037




[4,   200] loss: 0.010
[4,   210] loss: 0.118




[4,   220] loss: 0.097
[4,   230] loss: 0.059




[4,   240] loss: 0.026
[4,   250] loss: 0.044




[4,   260] loss: 0.093
[4,   270] loss: 0.075




[4,   280] loss: 0.050
[4,   290] loss: 0.081




[4,   300] loss: 0.049
[4,   310] loss: 0.019




[4,   320] loss: 0.103
[4,   330] loss: 0.099




[4,   340] loss: 0.042
[4,   350] loss: 0.082




[4,   360] loss: 0.055
[4,   370] loss: 0.077




[4,   380] loss: 0.075
[4,   390] loss: 0.051




[4,   400] loss: 0.057
[4,   410] loss: 0.038




[4,   420] loss: 0.056
[4,   430] loss: 0.117




[4,   440] loss: 0.060
[4,   450] loss: 0.030




[4,   460] loss: 0.074
[4,   470] loss: 0.092




[4,   480] loss: 0.029
[4,   490] loss: 0.091




[4,   500] loss: 0.033
[4,   510] loss: 0.064




[4,   520] loss: 0.073
[4,   530] loss: 0.063




[4,   540] loss: 0.063
[4,   550] loss: 0.081




[4,   560] loss: 0.089
[4,   570] loss: 0.051




[4,   580] loss: 0.034
[4,   590] loss: 0.048




[4,   600] loss: 0.038
[4,   610] loss: 0.039




[4,   620] loss: 0.036
[4,   630] loss: 0.045




[4,   640] loss: 0.121
[4,   650] loss: 0.050




[4,   660] loss: 0.061
[4,   670] loss: 0.049




[4,   680] loss: 0.026
[4,   690] loss: 0.055




[4,   700] loss: 0.030
[4,   710] loss: 0.082




[4,   720] loss: 0.041
[4,   730] loss: 0.080




[4,   740] loss: 0.037
[4,   750] loss: 0.080




[4,   760] loss: 0.068
[4,   770] loss: 0.068




[4,   780] loss: 0.057
[4,   790] loss: 0.062




[4,   800] loss: 0.076
[4,   810] loss: 0.074




[4,   820] loss: 0.113
[4,   830] loss: 0.061




[4,   840] loss: 0.097
[4,   850] loss: 0.040




[4,   860] loss: 0.106
[4,   870] loss: 0.082




[4,   880] loss: 0.064
[4,   890] loss: 0.057




[4,   900] loss: 0.038
[4,   910] loss: 0.110




[4,   920] loss: 0.120
[4,   930] loss: 0.078




[4,   940] loss: 0.070
[4,   950] loss: 0.054




[4,   960] loss: 0.136
[4,   970] loss: 0.056




[4,   980] loss: 0.051
[4,   990] loss: 0.105


100%|██████████| 1000/1000 [00:14<00:00, 69.07it/s]


[4,  1000] loss: 0.067
Epoch 4 | Train Loss: 132.1004


100%|██████████| 1000/1000 [00:06<00:00, 147.28it/s]
 50%|█████     | 4/8 [01:23<01:24, 21.02s/it]

Epoch 4 | Eval Loss: 0.5971




[5,    10] loss: 0.043




[5,    20] loss: 0.056




[5,    30] loss: 0.118




[5,    40] loss: 0.133




[5,    50] loss: 0.076




[5,    60] loss: 0.041
[5,    70] loss: 0.054




[5,    80] loss: 0.029
[5,    90] loss: 0.080




[5,   100] loss: 0.074
[5,   110] loss: 0.055




[5,   120] loss: 0.074
[5,   130] loss: 0.059




[5,   140] loss: 0.050
[5,   150] loss: 0.093




[5,   160] loss: 0.062
[5,   170] loss: 0.066




[5,   180] loss: 0.068
[5,   190] loss: 0.047




[5,   200] loss: 0.048
[5,   210] loss: 0.062




[5,   220] loss: 0.050
[5,   230] loss: 0.100




[5,   240] loss: 0.054
[5,   250] loss: 0.088




[5,   260] loss: 0.035
[5,   270] loss: 0.063




[5,   280] loss: 0.026
[5,   290] loss: 0.061




[5,   300] loss: 0.061




[5,   310] loss: 0.089
[5,   320] loss: 0.062




[5,   330] loss: 0.055
[5,   340] loss: 0.061




[5,   350] loss: 0.053
[5,   360] loss: 0.033




[5,   370] loss: 0.077
[5,   380] loss: 0.027




[5,   390] loss: 0.054
[5,   400] loss: 0.066




[5,   410] loss: 0.104
[5,   420] loss: 0.047




[5,   430] loss: 0.057
[5,   440] loss: 0.041




[5,   450] loss: 0.033
[5,   460] loss: 0.077




[5,   470] loss: 0.038
[5,   480] loss: 0.085




[5,   490] loss: 0.045
[5,   500] loss: 0.056




[5,   510] loss: 0.077
[5,   520] loss: 0.069




[5,   530] loss: 0.057
[5,   540] loss: 0.075




[5,   550] loss: 0.087
[5,   560] loss: 0.056




[5,   570] loss: 0.054
[5,   580] loss: 0.054




[5,   590] loss: 0.037
[5,   600] loss: 0.052




[5,   610] loss: 0.016
[5,   620] loss: 0.042




[5,   630] loss: 0.048
[5,   640] loss: 0.091




[5,   650] loss: 0.037
[5,   660] loss: 0.060




[5,   670] loss: 0.035
[5,   680] loss: 0.085




[5,   690] loss: 0.060
[5,   700] loss: 0.025




[5,   710] loss: 0.038
[5,   720] loss: 0.038




[5,   730] loss: 0.046
[5,   740] loss: 0.054




[5,   750] loss: 0.069
[5,   760] loss: 0.051




[5,   770] loss: 0.050
[5,   780] loss: 0.063




[5,   790] loss: 0.022
[5,   800] loss: 0.082




[5,   810] loss: 0.040
[5,   820] loss: 0.036




[5,   830] loss: 0.082
[5,   840] loss: 0.062




[5,   850] loss: 0.094
[5,   860] loss: 0.071




[5,   870] loss: 0.042
[5,   880] loss: 0.058




[5,   890] loss: 0.054
[5,   900] loss: 0.044




[5,   910] loss: 0.038
[5,   920] loss: 0.082




[5,   930] loss: 0.059
[5,   940] loss: 0.046




[5,   950] loss: 0.057
[5,   960] loss: 0.067




[5,   970] loss: 0.108
[5,   980] loss: 0.110


100%|██████████| 1000/1000 [00:13<00:00, 73.35it/s]


[5,   990] loss: 0.090
[5,  1000] loss: 0.069
Epoch 5 | Train Loss: 120.5056


100%|██████████| 1000/1000 [00:06<00:00, 155.70it/s]
 62%|██████▎   | 5/8 [01:43<01:02, 20.67s/it]

Epoch 5 | Eval Loss: 0.6521




[6,    10] loss: 0.093




[6,    20] loss: 0.070




[6,    30] loss: 0.050




[6,    40] loss: 0.076




[6,    50] loss: 0.043




[6,    60] loss: 0.061




[6,    70] loss: 0.038




[6,    80] loss: 0.073




[6,    90] loss: 0.092




[6,   100] loss: 0.070
[6,   110] loss: 0.103




[6,   120] loss: 0.058




[6,   130] loss: 0.054




[6,   140] loss: 0.095




[6,   150] loss: 0.062




[6,   160] loss: 0.071




[6,   170] loss: 0.030




[6,   180] loss: 0.039




[6,   190] loss: 0.064




[6,   200] loss: 0.090




[6,   210] loss: 0.036




[6,   220] loss: 0.059




[6,   230] loss: 0.019




[6,   240] loss: 0.062




[6,   250] loss: 0.111




[6,   260] loss: 0.078




[6,   270] loss: 0.072




[6,   280] loss: 0.061




[6,   290] loss: 0.084




[6,   300] loss: 0.076




[6,   310] loss: 0.043
[6,   320] loss: 0.053




[6,   330] loss: 0.026
[6,   340] loss: 0.034




[6,   350] loss: 0.029
[6,   360] loss: 0.035




[6,   370] loss: 0.085
[6,   380] loss: 0.100




[6,   390] loss: 0.063
[6,   400] loss: 0.050




[6,   410] loss: 0.052
[6,   420] loss: 0.082




[6,   430] loss: 0.042
[6,   440] loss: 0.037




[6,   450] loss: 0.033
[6,   460] loss: 0.072




[6,   470] loss: 0.059
[6,   480] loss: 0.032




[6,   490] loss: 0.052
[6,   500] loss: 0.053




[6,   510] loss: 0.071
[6,   520] loss: 0.032




[6,   530] loss: 0.061
[6,   540] loss: 0.068




[6,   550] loss: 0.033
[6,   560] loss: 0.043




[6,   570] loss: 0.052
[6,   580] loss: 0.102




[6,   590] loss: 0.037
[6,   600] loss: 0.054




[6,   610] loss: 0.061
[6,   620] loss: 0.068




[6,   630] loss: 0.061
[6,   640] loss: 0.110




[6,   650] loss: 0.063
[6,   660] loss: 0.083




[6,   670] loss: 0.062
[6,   680] loss: 0.041




[6,   690] loss: 0.065
[6,   700] loss: 0.042




[6,   710] loss: 0.051
[6,   720] loss: 0.112




[6,   730] loss: 0.072
[6,   740] loss: 0.059




[6,   750] loss: 0.027
[6,   760] loss: 0.134




[6,   770] loss: 0.059
[6,   780] loss: 0.195




[6,   790] loss: 0.103
[6,   800] loss: 0.097




[6,   810] loss: 0.081
[6,   820] loss: 0.061




[6,   830] loss: 0.038
[6,   840] loss: 0.038




[6,   850] loss: 0.033
[6,   860] loss: 0.120




[6,   870] loss: 0.086
[6,   880] loss: 0.074




[6,   890] loss: 0.083
[6,   900] loss: 0.042




[6,   910] loss: 0.104
[6,   920] loss: 0.058




[6,   930] loss: 0.058
[6,   940] loss: 0.067




[6,   950] loss: 0.087
[6,   960] loss: 0.044




[6,   970] loss: 0.026
[6,   980] loss: 0.088


100%|██████████| 1000/1000 [00:12<00:00, 79.69it/s]


[6,   990] loss: 0.052
[6,  1000] loss: 0.040
Epoch 6 | Train Loss: 128.4158


100%|██████████| 1000/1000 [00:06<00:00, 155.39it/s]
 75%|███████▌  | 6/8 [02:02<00:40, 20.10s/it]

Epoch 6 | Eval Loss: 0.5405




[7,    10] loss: 0.021




[7,    20] loss: 0.016




[7,    30] loss: 0.044




[7,    40] loss: 0.077




[7,    50] loss: 0.111




[7,    60] loss: 0.058




[7,    70] loss: 0.064




[7,    80] loss: 0.065




[7,    90] loss: 0.052




[7,   100] loss: 0.058




[7,   110] loss: 0.044




[7,   120] loss: 0.067




[7,   130] loss: 0.054




[7,   140] loss: 0.054




[7,   150] loss: 0.064




[7,   160] loss: 0.047




[7,   170] loss: 0.056




[7,   180] loss: 0.057




[7,   190] loss: 0.040




[7,   200] loss: 0.084




[7,   210] loss: 0.058




[7,   220] loss: 0.062




[7,   230] loss: 0.086




[7,   240] loss: 0.074




[7,   250] loss: 0.056




[7,   260] loss: 0.070




[7,   270] loss: 0.063




[7,   280] loss: 0.138




[7,   290] loss: 0.055
[7,   300] loss: 0.087




[7,   310] loss: 0.059




[7,   320] loss: 0.038




[7,   330] loss: 0.072




[7,   340] loss: 0.065




[7,   350] loss: 0.061




[7,   360] loss: 0.095




[7,   370] loss: 0.067




[7,   380] loss: 0.051




[7,   390] loss: 0.065




[7,   400] loss: 0.081




[7,   410] loss: 0.055




[7,   420] loss: 0.034




[7,   430] loss: 0.063




[7,   440] loss: 0.044




[7,   450] loss: 0.032




[7,   460] loss: 0.081
[7,   470] loss: 0.060




[7,   480] loss: 0.043




[7,   490] loss: 0.056




[7,   500] loss: 0.084




[7,   510] loss: 0.072




[7,   520] loss: 0.060




[7,   530] loss: 0.092




[7,   540] loss: 0.070




[7,   550] loss: 0.072




[7,   560] loss: 0.066




[7,   570] loss: 0.076




[7,   580] loss: 0.040




[7,   590] loss: 0.066




[7,   600] loss: 0.050




[7,   610] loss: 0.046




[7,   620] loss: 0.034




[7,   630] loss: 0.041




[7,   640] loss: 0.038




[7,   650] loss: 0.049
[7,   660] loss: 0.046




[7,   670] loss: 0.014
[7,   680] loss: 0.067




[7,   690] loss: 0.059
[7,   700] loss: 0.058




[7,   710] loss: 0.030
[7,   720] loss: 0.033




[7,   730] loss: 0.079
[7,   740] loss: 0.058




[7,   750] loss: 0.046
[7,   760] loss: 0.100




[7,   770] loss: 0.060
[7,   780] loss: 0.115




[7,   790] loss: 0.082
[7,   800] loss: 0.068




[7,   810] loss: 0.083
[7,   820] loss: 0.062




[7,   830] loss: 0.049
[7,   840] loss: 0.037




[7,   850] loss: 0.064
[7,   860] loss: 0.091




[7,   870] loss: 0.057
[7,   880] loss: 0.069




[7,   890] loss: 0.053
[7,   900] loss: 0.053




[7,   910] loss: 0.062
[7,   920] loss: 0.074




[7,   930] loss: 0.063
[7,   940] loss: 0.046




[7,   950] loss: 0.057
[7,   960] loss: 0.047





[7,   970] loss: 0.081
[7,   980] loss: 0.083


100%|██████████| 1000/1000 [00:12<00:00, 78.47it/s][A


[7,   990] loss: 0.042
[7,  1000] loss: 0.084
Epoch 7 | Train Loss: 122.5666


100%|██████████| 1000/1000 [00:06<00:00, 155.13it/s]
 88%|████████▊ | 7/8 [02:21<00:19, 19.80s/it]

Epoch 7 | Eval Loss: 0.5317




[8,    10] loss: 0.052




[8,    20] loss: 0.043




[8,    30] loss: 0.066




[8,    40] loss: 0.051




[8,    50] loss: 0.036




[8,    60] loss: 0.063




[8,    70] loss: 0.046




[8,    80] loss: 0.084




[8,    90] loss: 0.056




[8,   100] loss: 0.043




[8,   110] loss: 0.026




[8,   120] loss: 0.097




[8,   130] loss: 0.069




[8,   140] loss: 0.043




[8,   150] loss: 0.040




[8,   160] loss: 0.056




[8,   170] loss: 0.043




[8,   180] loss: 0.045




[8,   190] loss: 0.048




[8,   200] loss: 0.050




[8,   210] loss: 0.011




[8,   220] loss: 0.120




[8,   230] loss: 0.075




[8,   240] loss: 0.071




[8,   250] loss: 0.070




[8,   260] loss: 0.058




[8,   270] loss: 0.073




[8,   280] loss: 0.054



 30%|██▉       | 299/1000 [00:03<00:07, 90.01it/s]

[8,   290] loss: 0.038


[A

[8,   300] loss: 0.030




[8,   310] loss: 0.017
[8,   320] loss: 0.037




[8,   330] loss: 0.097
[8,   340] loss: 0.082



 37%|███▋      | 366/1000 [00:04<00:08, 76.68it/s]

[8,   350] loss: 0.082
[8,   360] loss: 0.099


[A

[8,   370] loss: 0.048
[8,   380] loss: 0.084




[8,   390] loss: 0.070
[8,   400] loss: 0.067




[8,   410] loss: 0.087
[8,   420] loss: 0.075




[8,   430] loss: 0.066
[8,   440] loss: 0.063




[8,   450] loss: 0.060
[8,   460] loss: 0.077




[8,   470] loss: 0.036
[8,   480] loss: 0.067




[8,   490] loss: 0.093
[8,   500] loss: 0.064




[8,   510] loss: 0.059
[8,   520] loss: 0.066




[8,   530] loss: 0.035
[8,   540] loss: 0.061




[8,   550] loss: 0.072
[8,   560] loss: 0.056





[8,   570] loss: 0.051
[8,   580] loss: 0.054


 58%|█████▊    | 585/1000 [00:07<00:05, 74.88it/s][A

[8,   590] loss: 0.049
[8,   600] loss: 0.044




[8,   610] loss: 0.070
[8,   620] loss: 0.015




[8,   630] loss: 0.053
[8,   640] loss: 0.058




[8,   650] loss: 0.107
[8,   660] loss: 0.087




[8,   670] loss: 0.051
[8,   680] loss: 0.064




[8,   690] loss: 0.040
[8,   700] loss: 0.059




[8,   710] loss: 0.049
[8,   720] loss: 0.029




[8,   730] loss: 0.025
[8,   740] loss: 0.039




[8,   750] loss: 0.114
[8,   760] loss: 0.077




[8,   770] loss: 0.052
[8,   780] loss: 0.071




[8,   790] loss: 0.079
[8,   800] loss: 0.050




[8,   810] loss: 0.029
[8,   820] loss: 0.038




[8,   830] loss: 0.087
[8,   840] loss: 0.067




[8,   850] loss: 0.047
[8,   860] loss: 0.046




[8,   870] loss: 0.172
[8,   880] loss: 0.103




[8,   890] loss: 0.546
[8,   900] loss: 0.540




[8,   910] loss: 0.045
[8,   920] loss: 0.574




[8,   930] loss: 0.493
[8,   940] loss: 0.492




[8,   950] loss: 0.573
[8,   960] loss: 0.260




[8,   970] loss: 0.088
[8,   980] loss: 0.104


100%|██████████| 1000/1000 [00:12<00:00, 80.21it/s]


[8,   990] loss: 0.086
[8,  1000] loss: 0.112
Epoch 8 | Train Loss: 185.2560


100%|██████████| 1000/1000 [00:06<00:00, 154.62it/s]
                                             

Epoch 8 | Eval Loss: 0.7164




0.9262801191556247