In [1]:
pwd

'/home/onyxia/work/ensae-aml-projet'

In [2]:
cd code_model

/home/onyxia/work/ensae-aml-projet/code_model


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
import os
from string import digits

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence

from transformers import BertModel, BertTokenizerFast
from transformers.modeling_outputs import SequenceClassifierOutput

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm import tqdm
from collections import Counter

In [4]:
# assign seed to numpy and PyTorch
seed=2025
torch.manual_seed(seed)
np.random.seed(seed) 

In [5]:
################################################################################################################
#### Retrieve data #############################################################################################
################################################################################################################

train_dir = sorted([f for f in os.listdir("../training_data/test-and-training/training_data/") if f.endswith('xlsx')])
test_dir = sorted([f for f in os.listdir("../training_data/test-and-training/test_data/") if f.endswith('xlsx')])
remove_digits = str.maketrans('', '', digits)

for f in range(len(train_dir[:1])): # on ne s'intéresse que aux fichiers split-combine (le plus général, données de meilleure qualité), au 1er split
    train = pd.read_excel("../training_data/test-and-training/training_data/" + train_dir[f], index_col=False)[['sentence', 'label']]
    test = pd.read_excel("../training_data/test-and-training/test_data/" + test_dir[f], index_col=False)[['sentence', 'label']]

sentences = train['sentence'].tolist()
labels = train['label'].to_numpy()
sentences_test = test['sentence'].tolist()
labels_test = test['label'].to_numpy()

In [6]:
################################################################################################################
#### Tokenization ##############################################################################################
################################################################################################################

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True, do_basic_tokenize=True)
max_length = 0
sentence_input = []
labels_output = []
for i, sentence in enumerate(sentences):
    if isinstance(sentence, str):
        tokens = tokenizer(sentence)['input_ids']
        sentence_input.append(sentence)
        max_length = max(max_length, len(tokens))
        labels_output.append(labels[i])
    else:
        pass
max_length=256
tokens = tokenizer(sentence_input, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
labels = np.array(labels_output)

In [7]:
print(tokens)

{'input_ids': tensor([[  101,  2256, 11405,  ...,     0,     0,     0],
        [  101,  2076,  1996,  ...,     0,     0,     0],
        [  101,  1996,  2837,  ...,     0,     0,     0],
        ...,
        [  101,  1996, 25323,  ...,     0,     0,     0],
        [  101,  1996, 17035,  ...,     0,     0,     0],
        [  101,  9308,  1010,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}


In [10]:
################################################################################################################
#### Dataset handling ##########################################################################################
################################################################################################################

input_ids = tokens['input_ids']
attention_masks = tokens['attention_mask']
labels = torch.LongTensor(labels)
dataset = TensorDataset(input_ids, attention_masks, labels)
val_length = int(len(dataset) * 0.2)
train_length = len(dataset) - val_length
print(f'Train Size: {train_length}, Validation Size: {val_length}')
train, val = torch.utils.data.random_split(dataset=dataset, lengths=[train_length, val_length]) # create train-val split

Train Size: 1588, Validation Size: 396


In [11]:
################################################################################################################
#### Model #####################################################################################################
################################################################################################################

class AttentionalDecoder(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super(AttentionalDecoder, self).__init__()
        self.query_proj = nn.Linear(hidden_size, hidden_size)
        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)        
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=8)
        
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.1)

        self.attn_output_weights = None # store attn weights

    def forward(self, encoder_output):
        query = self.query_proj(encoder_output)
        key = self.key_proj(encoder_output)
        value = self.value_proj(encoder_output)
        
        attn_output, attn_output_weights = self.attention(query=query, key=key, value=value)
        
        attn_output = self.layer_norm(attn_output + encoder_output)
        attn_output = self.dropout(attn_output)
        cls_output = attn_output[:, 0, :] # [CLS] token
        output = self.fc(cls_output)
        return output, attn_output_weights

class BERTClassifier(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.1):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        for param in self.bert.parameters(): # freeze BERT encoder
            param.requires_grad = False
        
        self.decoder = AttentionalDecoder(self.bert.config.hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids, attention_mask, **kwargs):
        with torch.no_grad():
            encoder_output = self.bert(
                input_ids, 
                attention_mask=attention_mask
            ).last_hidden_state
        encoder_output = self.dropout(encoder_output)
        
        output, attn_output_weights = self.decoder(encoder_output)
        return output, attn_output_weights

In [14]:
################################################################################################################
#### Training ##################################################################################################
################################################################################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTClassifier(num_classes=3).to(device)
print([name for name, param in model.named_parameters() if param.requires_grad]) # check that the BERT parameters are well frozen
batch_size = 16
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
dataloader_train = DataLoader(train, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(val, batch_size=batch_size, shuffle=True)

['decoder.query_proj.weight', 'decoder.query_proj.bias', 'decoder.key_proj.weight', 'decoder.key_proj.bias', 'decoder.value_proj.weight', 'decoder.value_proj.bias', 'decoder.attention.in_proj_weight', 'decoder.attention.in_proj_bias', 'decoder.attention.out_proj.weight', 'decoder.attention.out_proj.bias', 'decoder.layer_norm.weight', 'decoder.layer_norm.bias', 'decoder.fc.weight', 'decoder.fc.bias']


In [None]:
model.train()
for input_ids, attention_masks, labels in tqdm(dataloader_train):
    input_ids = input_ids.to(device)
    attention_masks = attention_masks.to(device)
    labels = labels.to(device)
    
    optimizer.zero_grad()
    outputs, attn_output_weights = model(input_ids=input_ids, attention_mask=attention_masks)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
        
    print(attn_output_weights)

  1%|          | 1/100 [00:08<13:40,  8.29s/it]

tensor([[[0.0624, 0.0624, 0.0623,  ..., 0.0629, 0.0629, 0.0622],
         [0.0624, 0.0625, 0.0619,  ..., 0.0631, 0.0633, 0.0624],
         [0.0627, 0.0626, 0.0629,  ..., 0.0632, 0.0628, 0.0621],
         ...,
         [0.0628, 0.0624, 0.0636,  ..., 0.0624, 0.0628, 0.0616],
         [0.0624, 0.0624, 0.0626,  ..., 0.0625, 0.0634, 0.0615],
         [0.0624, 0.0628, 0.0642,  ..., 0.0630, 0.0624, 0.0622]],

        [[0.0608, 0.0626, 0.0616,  ..., 0.0625, 0.0626, 0.0631],
         [0.0622, 0.0616, 0.0605,  ..., 0.0627, 0.0621, 0.0618],
         [0.0612, 0.0629, 0.0625,  ..., 0.0627, 0.0620, 0.0628],
         ...,
         [0.0626, 0.0619, 0.0620,  ..., 0.0627, 0.0633, 0.0633],
         [0.0604, 0.0614, 0.0616,  ..., 0.0620, 0.0639, 0.0620],
         [0.0629, 0.0631, 0.0618,  ..., 0.0614, 0.0637, 0.0619]],

        [[0.0632, 0.0633, 0.0623,  ..., 0.0614, 0.0612, 0.0639],
         [0.0638, 0.0637, 0.0616,  ..., 0.0628, 0.0607, 0.0624],
         [0.0651, 0.0614, 0.0624,  ..., 0.0626, 0.0617, 0.

  2%|▏         | 2/100 [00:16<13:09,  8.06s/it]

tensor([[[0.0627, 0.0630, 0.0628,  ..., 0.0627, 0.0624, 0.0625],
         [0.0623, 0.0637, 0.0629,  ..., 0.0613, 0.0632, 0.0620],
         [0.0625, 0.0636, 0.0631,  ..., 0.0624, 0.0625, 0.0623],
         ...,
         [0.0625, 0.0639, 0.0627,  ..., 0.0633, 0.0624, 0.0628],
         [0.0622, 0.0640, 0.0628,  ..., 0.0632, 0.0622, 0.0623],
         [0.0626, 0.0643, 0.0627,  ..., 0.0626, 0.0625, 0.0626]],

        [[0.0631, 0.0618, 0.0638,  ..., 0.0628, 0.0608, 0.0629],
         [0.0641, 0.0625, 0.0633,  ..., 0.0615, 0.0618, 0.0618],
         [0.0642, 0.0637, 0.0636,  ..., 0.0611, 0.0611, 0.0618],
         ...,
         [0.0625, 0.0632, 0.0634,  ..., 0.0608, 0.0633, 0.0619],
         [0.0631, 0.0614, 0.0631,  ..., 0.0619, 0.0616, 0.0639],
         [0.0620, 0.0616, 0.0626,  ..., 0.0620, 0.0621, 0.0629]],

        [[0.0640, 0.0600, 0.0630,  ..., 0.0620, 0.0625, 0.0626],
         [0.0636, 0.0617, 0.0624,  ..., 0.0622, 0.0630, 0.0626],
         [0.0639, 0.0621, 0.0622,  ..., 0.0593, 0.0629, 0.

  3%|▎         | 3/100 [00:23<12:41,  7.85s/it]

tensor([[[0.0619, 0.0611, 0.0630,  ..., 0.0628, 0.0631, 0.0622],
         [0.0618, 0.0614, 0.0625,  ..., 0.0627, 0.0639, 0.0620],
         [0.0624, 0.0618, 0.0618,  ..., 0.0630, 0.0639, 0.0623],
         ...,
         [0.0628, 0.0614, 0.0631,  ..., 0.0626, 0.0630, 0.0623],
         [0.0623, 0.0622, 0.0619,  ..., 0.0629, 0.0634, 0.0621],
         [0.0617, 0.0616, 0.0629,  ..., 0.0627, 0.0638, 0.0617]],

        [[0.0624, 0.0632, 0.0625,  ..., 0.0638, 0.0629, 0.0631],
         [0.0618, 0.0644, 0.0626,  ..., 0.0635, 0.0615, 0.0622],
         [0.0628, 0.0627, 0.0622,  ..., 0.0628, 0.0629, 0.0627],
         ...,
         [0.0623, 0.0614, 0.0622,  ..., 0.0612, 0.0625, 0.0626],
         [0.0629, 0.0616, 0.0625,  ..., 0.0626, 0.0632, 0.0623],
         [0.0613, 0.0637, 0.0611,  ..., 0.0627, 0.0628, 0.0623]],

        [[0.0625, 0.0636, 0.0618,  ..., 0.0633, 0.0623, 0.0634],
         [0.0626, 0.0641, 0.0617,  ..., 0.0638, 0.0615, 0.0637],
         [0.0613, 0.0624, 0.0639,  ..., 0.0629, 0.0625, 0.

  4%|▍         | 4/100 [00:32<12:58,  8.11s/it]

tensor([[[0.0618, 0.0619, 0.0626,  ..., 0.0634, 0.0622, 0.0621],
         [0.0619, 0.0622, 0.0624,  ..., 0.0628, 0.0620, 0.0625],
         [0.0626, 0.0633, 0.0623,  ..., 0.0625, 0.0626, 0.0623],
         ...,
         [0.0621, 0.0625, 0.0627,  ..., 0.0633, 0.0620, 0.0629],
         [0.0612, 0.0629, 0.0626,  ..., 0.0636, 0.0613, 0.0621],
         [0.0621, 0.0627, 0.0619,  ..., 0.0631, 0.0619, 0.0629]],

        [[0.0624, 0.0620, 0.0629,  ..., 0.0623, 0.0616, 0.0633],
         [0.0614, 0.0623, 0.0620,  ..., 0.0617, 0.0634, 0.0633],
         [0.0630, 0.0610, 0.0613,  ..., 0.0626, 0.0625, 0.0635],
         ...,
         [0.0624, 0.0619, 0.0620,  ..., 0.0617, 0.0631, 0.0639],
         [0.0612, 0.0630, 0.0627,  ..., 0.0623, 0.0642, 0.0622],
         [0.0646, 0.0597, 0.0612,  ..., 0.0635, 0.0615, 0.0646]],

        [[0.0621, 0.0640, 0.0630,  ..., 0.0620, 0.0621, 0.0625],
         [0.0624, 0.0614, 0.0630,  ..., 0.0608, 0.0618, 0.0643],
         [0.0622, 0.0626, 0.0635,  ..., 0.0620, 0.0631, 0.

  5%|▌         | 5/100 [00:40<13:07,  8.29s/it]

tensor([[[0.0629, 0.0617, 0.0621,  ..., 0.0628, 0.0622, 0.0636],
         [0.0629, 0.0618, 0.0618,  ..., 0.0631, 0.0608, 0.0631],
         [0.0633, 0.0623, 0.0619,  ..., 0.0631, 0.0612, 0.0636],
         ...,
         [0.0627, 0.0623, 0.0625,  ..., 0.0628, 0.0620, 0.0633],
         [0.0630, 0.0623, 0.0618,  ..., 0.0627, 0.0618, 0.0629],
         [0.0630, 0.0621, 0.0623,  ..., 0.0627, 0.0623, 0.0632]],

        [[0.0624, 0.0626, 0.0633,  ..., 0.0625, 0.0619, 0.0631],
         [0.0630, 0.0621, 0.0611,  ..., 0.0634, 0.0640, 0.0617],
         [0.0630, 0.0614, 0.0634,  ..., 0.0627, 0.0628, 0.0613],
         ...,
         [0.0623, 0.0613, 0.0624,  ..., 0.0624, 0.0621, 0.0636],
         [0.0631, 0.0618, 0.0624,  ..., 0.0644, 0.0627, 0.0618],
         [0.0633, 0.0614, 0.0619,  ..., 0.0620, 0.0629, 0.0619]],

        [[0.0621, 0.0633, 0.0616,  ..., 0.0628, 0.0614, 0.0637],
         [0.0622, 0.0629, 0.0616,  ..., 0.0623, 0.0607, 0.0637],
         [0.0617, 0.0640, 0.0606,  ..., 0.0613, 0.0610, 0.

  6%|▌         | 6/100 [00:49<13:24,  8.56s/it]

tensor([[[0.0623, 0.0626, 0.0623,  ..., 0.0619, 0.0627, 0.0616],
         [0.0626, 0.0617, 0.0623,  ..., 0.0619, 0.0623, 0.0616],
         [0.0630, 0.0620, 0.0623,  ..., 0.0609, 0.0626, 0.0618],
         ...,
         [0.0627, 0.0622, 0.0619,  ..., 0.0611, 0.0626, 0.0616],
         [0.0628, 0.0620, 0.0623,  ..., 0.0616, 0.0622, 0.0619],
         [0.0624, 0.0617, 0.0629,  ..., 0.0615, 0.0623, 0.0624]],

        [[0.0618, 0.0621, 0.0620,  ..., 0.0619, 0.0624, 0.0622],
         [0.0625, 0.0630, 0.0620,  ..., 0.0616, 0.0620, 0.0611],
         [0.0626, 0.0606, 0.0639,  ..., 0.0613, 0.0650, 0.0603],
         ...,
         [0.0628, 0.0637, 0.0620,  ..., 0.0633, 0.0630, 0.0620],
         [0.0612, 0.0624, 0.0626,  ..., 0.0629, 0.0625, 0.0612],
         [0.0626, 0.0625, 0.0623,  ..., 0.0618, 0.0630, 0.0614]],

        [[0.0625, 0.0620, 0.0628,  ..., 0.0619, 0.0619, 0.0630],
         [0.0617, 0.0618, 0.0621,  ..., 0.0628, 0.0632, 0.0618],
         [0.0609, 0.0612, 0.0630,  ..., 0.0622, 0.0629, 0.

  7%|▋         | 7/100 [00:59<13:39,  8.81s/it]

tensor([[[0.0609, 0.0628, 0.0641,  ..., 0.0615, 0.0629, 0.0631],
         [0.0600, 0.0626, 0.0648,  ..., 0.0623, 0.0631, 0.0631],
         [0.0606, 0.0631, 0.0647,  ..., 0.0624, 0.0634, 0.0632],
         ...,
         [0.0615, 0.0628, 0.0644,  ..., 0.0616, 0.0633, 0.0622],
         [0.0611, 0.0624, 0.0638,  ..., 0.0622, 0.0632, 0.0629],
         [0.0604, 0.0630, 0.0658,  ..., 0.0616, 0.0635, 0.0631]],

        [[0.0637, 0.0634, 0.0631,  ..., 0.0631, 0.0636, 0.0593],
         [0.0629, 0.0641, 0.0633,  ..., 0.0631, 0.0628, 0.0592],
         [0.0624, 0.0626, 0.0634,  ..., 0.0631, 0.0624, 0.0609],
         ...,
         [0.0627, 0.0637, 0.0621,  ..., 0.0619, 0.0614, 0.0614],
         [0.0623, 0.0627, 0.0628,  ..., 0.0632, 0.0628, 0.0615],
         [0.0641, 0.0630, 0.0629,  ..., 0.0626, 0.0615, 0.0614]],

        [[0.0622, 0.0622, 0.0623,  ..., 0.0631, 0.0614, 0.0628],
         [0.0628, 0.0619, 0.0616,  ..., 0.0617, 0.0622, 0.0639],
         [0.0624, 0.0627, 0.0624,  ..., 0.0632, 0.0619, 0.

  8%|▊         | 8/100 [01:07<13:12,  8.62s/it]

tensor([[[0.0622, 0.0639, 0.0647,  ..., 0.0644, 0.0635, 0.0631],
         [0.0620, 0.0639, 0.0638,  ..., 0.0640, 0.0638, 0.0632],
         [0.0620, 0.0640, 0.0639,  ..., 0.0636, 0.0632, 0.0626],
         ...,
         [0.0629, 0.0634, 0.0639,  ..., 0.0639, 0.0636, 0.0629],
         [0.0622, 0.0645, 0.0640,  ..., 0.0636, 0.0635, 0.0623],
         [0.0619, 0.0646, 0.0643,  ..., 0.0653, 0.0637, 0.0630]],

        [[0.0640, 0.0617, 0.0622,  ..., 0.0647, 0.0598, 0.0651],
         [0.0626, 0.0628, 0.0628,  ..., 0.0639, 0.0614, 0.0630],
         [0.0629, 0.0620, 0.0612,  ..., 0.0656, 0.0604, 0.0634],
         ...,
         [0.0621, 0.0617, 0.0624,  ..., 0.0650, 0.0608, 0.0637],
         [0.0633, 0.0632, 0.0625,  ..., 0.0638, 0.0615, 0.0626],
         [0.0642, 0.0607, 0.0629,  ..., 0.0656, 0.0597, 0.0654]],

        [[0.0621, 0.0631, 0.0621,  ..., 0.0638, 0.0629, 0.0623],
         [0.0617, 0.0624, 0.0614,  ..., 0.0646, 0.0618, 0.0622],
         [0.0606, 0.0639, 0.0627,  ..., 0.0629, 0.0604, 0.

  9%|▉         | 9/100 [01:17<13:32,  8.93s/it]

tensor([[[0.0619, 0.0638, 0.0613,  ..., 0.0641, 0.0632, 0.0641],
         [0.0617, 0.0641, 0.0622,  ..., 0.0637, 0.0619, 0.0652],
         [0.0607, 0.0652, 0.0625,  ..., 0.0646, 0.0616, 0.0653],
         ...,
         [0.0614, 0.0645, 0.0620,  ..., 0.0637, 0.0630, 0.0648],
         [0.0621, 0.0641, 0.0624,  ..., 0.0638, 0.0626, 0.0653],
         [0.0614, 0.0644, 0.0623,  ..., 0.0634, 0.0630, 0.0646]],

        [[0.0646, 0.0626, 0.0624,  ..., 0.0626, 0.0624, 0.0636],
         [0.0628, 0.0618, 0.0619,  ..., 0.0626, 0.0633, 0.0626],
         [0.0631, 0.0619, 0.0602,  ..., 0.0646, 0.0634, 0.0623],
         ...,
         [0.0639, 0.0633, 0.0638,  ..., 0.0616, 0.0623, 0.0633],
         [0.0625, 0.0630, 0.0611,  ..., 0.0647, 0.0634, 0.0627],
         [0.0641, 0.0631, 0.0612,  ..., 0.0646, 0.0630, 0.0640]],

        [[0.0636, 0.0635, 0.0629,  ..., 0.0624, 0.0624, 0.0635],
         [0.0632, 0.0635, 0.0613,  ..., 0.0624, 0.0640, 0.0617],
         [0.0626, 0.0640, 0.0639,  ..., 0.0624, 0.0626, 0.

 10%|█         | 10/100 [01:26<13:24,  8.94s/it]

tensor([[[0.0619, 0.0616, 0.0642,  ..., 0.0632, 0.0635, 0.0621],
         [0.0624, 0.0610, 0.0625,  ..., 0.0631, 0.0636, 0.0624],
         [0.0627, 0.0618, 0.0633,  ..., 0.0636, 0.0633, 0.0619],
         ...,
         [0.0629, 0.0615, 0.0639,  ..., 0.0634, 0.0641, 0.0619],
         [0.0632, 0.0623, 0.0637,  ..., 0.0633, 0.0633, 0.0619],
         [0.0630, 0.0618, 0.0640,  ..., 0.0632, 0.0635, 0.0618]],

        [[0.0629, 0.0625, 0.0647,  ..., 0.0620, 0.0628, 0.0643],
         [0.0619, 0.0666, 0.0620,  ..., 0.0617, 0.0627, 0.0658],
         [0.0621, 0.0657, 0.0644,  ..., 0.0616, 0.0628, 0.0651],
         ...,
         [0.0623, 0.0617, 0.0627,  ..., 0.0640, 0.0617, 0.0620],
         [0.0620, 0.0646, 0.0622,  ..., 0.0619, 0.0632, 0.0651],
         [0.0616, 0.0648, 0.0626,  ..., 0.0614, 0.0626, 0.0650]],

        [[0.0613, 0.0625, 0.0642,  ..., 0.0630, 0.0607, 0.0646],
         [0.0638, 0.0644, 0.0652,  ..., 0.0602, 0.0639, 0.0628],
         [0.0627, 0.0632, 0.0649,  ..., 0.0612, 0.0630, 0.

 11%|█         | 11/100 [01:34<13:06,  8.83s/it]

tensor([[[0.0650, 0.0619, 0.0617,  ..., 0.0644, 0.0600, 0.0644],
         [0.0654, 0.0619, 0.0611,  ..., 0.0644, 0.0598, 0.0646],
         [0.0643, 0.0616, 0.0610,  ..., 0.0643, 0.0614, 0.0638],
         ...,
         [0.0645, 0.0617, 0.0611,  ..., 0.0640, 0.0603, 0.0648],
         [0.0658, 0.0615, 0.0609,  ..., 0.0646, 0.0594, 0.0650],
         [0.0650, 0.0615, 0.0615,  ..., 0.0642, 0.0598, 0.0643]],

        [[0.0630, 0.0630, 0.0623,  ..., 0.0646, 0.0632, 0.0623],
         [0.0619, 0.0652, 0.0634,  ..., 0.0653, 0.0623, 0.0620],
         [0.0610, 0.0641, 0.0632,  ..., 0.0662, 0.0616, 0.0622],
         ...,
         [0.0628, 0.0642, 0.0633,  ..., 0.0649, 0.0632, 0.0622],
         [0.0629, 0.0631, 0.0627,  ..., 0.0629, 0.0633, 0.0622],
         [0.0625, 0.0634, 0.0628,  ..., 0.0644, 0.0629, 0.0627]],

        [[0.0623, 0.0618, 0.0647,  ..., 0.0614, 0.0625, 0.0639],
         [0.0631, 0.0605, 0.0627,  ..., 0.0632, 0.0621, 0.0639],
         [0.0621, 0.0630, 0.0636,  ..., 0.0619, 0.0636, 0.

 12%|█▏        | 12/100 [01:43<13:02,  8.90s/it]

tensor([[[0.0608, 0.0596, 0.0630,  ..., 0.0624, 0.0637, 0.0620],
         [0.0609, 0.0602, 0.0625,  ..., 0.0626, 0.0645, 0.0623],
         [0.0615, 0.0585, 0.0626,  ..., 0.0638, 0.0647, 0.0623],
         ...,
         [0.0615, 0.0589, 0.0627,  ..., 0.0634, 0.0650, 0.0622],
         [0.0612, 0.0582, 0.0623,  ..., 0.0632, 0.0648, 0.0626],
         [0.0614, 0.0592, 0.0626,  ..., 0.0626, 0.0650, 0.0624]],

        [[0.0628, 0.0615, 0.0643,  ..., 0.0615, 0.0624, 0.0617],
         [0.0621, 0.0609, 0.0644,  ..., 0.0624, 0.0642, 0.0632],
         [0.0643, 0.0620, 0.0672,  ..., 0.0604, 0.0647, 0.0596],
         ...,
         [0.0611, 0.0613, 0.0640,  ..., 0.0611, 0.0626, 0.0623],
         [0.0635, 0.0607, 0.0651,  ..., 0.0606, 0.0639, 0.0599],
         [0.0617, 0.0624, 0.0645,  ..., 0.0616, 0.0617, 0.0628]],

        [[0.0617, 0.0633, 0.0621,  ..., 0.0630, 0.0622, 0.0618],
         [0.0632, 0.0636, 0.0622,  ..., 0.0633, 0.0638, 0.0623],
         [0.0621, 0.0621, 0.0623,  ..., 0.0635, 0.0637, 0.

 13%|█▎        | 13/100 [01:53<13:08,  9.07s/it]

tensor([[[0.0638, 0.0620, 0.0628,  ..., 0.0630, 0.0635, 0.0615],
         [0.0643, 0.0621, 0.0627,  ..., 0.0627, 0.0638, 0.0618],
         [0.0637, 0.0624, 0.0632,  ..., 0.0624, 0.0634, 0.0612],
         ...,
         [0.0634, 0.0627, 0.0625,  ..., 0.0625, 0.0633, 0.0609],
         [0.0634, 0.0619, 0.0627,  ..., 0.0631, 0.0636, 0.0617],
         [0.0635, 0.0625, 0.0626,  ..., 0.0622, 0.0621, 0.0610]],

        [[0.0613, 0.0617, 0.0622,  ..., 0.0617, 0.0622, 0.0613],
         [0.0610, 0.0615, 0.0638,  ..., 0.0621, 0.0618, 0.0625],
         [0.0613, 0.0615, 0.0631,  ..., 0.0620, 0.0616, 0.0623],
         ...,
         [0.0611, 0.0623, 0.0643,  ..., 0.0621, 0.0617, 0.0621],
         [0.0613, 0.0622, 0.0616,  ..., 0.0629, 0.0623, 0.0633],
         [0.0615, 0.0618, 0.0631,  ..., 0.0624, 0.0618, 0.0625]],

        [[0.0630, 0.0615, 0.0637,  ..., 0.0639, 0.0628, 0.0627],
         [0.0624, 0.0611, 0.0648,  ..., 0.0639, 0.0608, 0.0627],
         [0.0593, 0.0609, 0.0613,  ..., 0.0628, 0.0627, 0.

 14%|█▍        | 14/100 [02:01<12:48,  8.94s/it]

tensor([[[0.0629, 0.0588, 0.0612,  ..., 0.0633, 0.0617, 0.0622],
         [0.0628, 0.0592, 0.0615,  ..., 0.0634, 0.0628, 0.0625],
         [0.0624, 0.0609, 0.0618,  ..., 0.0624, 0.0623, 0.0625],
         ...,
         [0.0627, 0.0596, 0.0611,  ..., 0.0628, 0.0621, 0.0625],
         [0.0630, 0.0599, 0.0618,  ..., 0.0625, 0.0630, 0.0620],
         [0.0623, 0.0589, 0.0619,  ..., 0.0630, 0.0634, 0.0621]],

        [[0.0632, 0.0597, 0.0641,  ..., 0.0629, 0.0643, 0.0632],
         [0.0631, 0.0627, 0.0631,  ..., 0.0624, 0.0629, 0.0640],
         [0.0619, 0.0622, 0.0628,  ..., 0.0626, 0.0645, 0.0624],
         ...,
         [0.0601, 0.0623, 0.0635,  ..., 0.0629, 0.0663, 0.0639],
         [0.0610, 0.0627, 0.0636,  ..., 0.0634, 0.0654, 0.0637],
         [0.0628, 0.0631, 0.0628,  ..., 0.0628, 0.0651, 0.0635]],

        [[0.0630, 0.0623, 0.0633,  ..., 0.0622, 0.0620, 0.0621],
         [0.0608, 0.0638, 0.0623,  ..., 0.0626, 0.0627, 0.0604],
         [0.0603, 0.0630, 0.0640,  ..., 0.0634, 0.0623, 0.

 15%|█▌        | 15/100 [02:10<12:35,  8.88s/it]

tensor([[[0.0643, 0.0650, 0.0636,  ..., 0.0633, 0.0623, 0.0618],
         [0.0640, 0.0654, 0.0639,  ..., 0.0640, 0.0624, 0.0619],
         [0.0646, 0.0659, 0.0639,  ..., 0.0640, 0.0626, 0.0625],
         ...,
         [0.0641, 0.0655, 0.0633,  ..., 0.0638, 0.0623, 0.0622],
         [0.0640, 0.0659, 0.0642,  ..., 0.0636, 0.0617, 0.0620],
         [0.0636, 0.0663, 0.0641,  ..., 0.0640, 0.0620, 0.0629]],

        [[0.0631, 0.0642, 0.0607,  ..., 0.0626, 0.0623, 0.0622],
         [0.0636, 0.0669, 0.0609,  ..., 0.0612, 0.0616, 0.0616],
         [0.0628, 0.0633, 0.0614,  ..., 0.0612, 0.0625, 0.0617],
         ...,
         [0.0628, 0.0643, 0.0610,  ..., 0.0624, 0.0614, 0.0613],
         [0.0634, 0.0642, 0.0617,  ..., 0.0637, 0.0623, 0.0607],
         [0.0624, 0.0675, 0.0611,  ..., 0.0617, 0.0628, 0.0611]],

        [[0.0614, 0.0619, 0.0618,  ..., 0.0628, 0.0605, 0.0605],
         [0.0604, 0.0628, 0.0635,  ..., 0.0621, 0.0626, 0.0634],
         [0.0622, 0.0625, 0.0623,  ..., 0.0626, 0.0634, 0.

 16%|█▌        | 16/100 [02:19<12:16,  8.77s/it]

tensor([[[0.0634, 0.0632, 0.0632,  ..., 0.0610, 0.0624, 0.0646],
         [0.0634, 0.0627, 0.0637,  ..., 0.0612, 0.0621, 0.0647],
         [0.0637, 0.0631, 0.0626,  ..., 0.0618, 0.0629, 0.0641],
         ...,
         [0.0634, 0.0627, 0.0619,  ..., 0.0615, 0.0628, 0.0648],
         [0.0634, 0.0635, 0.0634,  ..., 0.0614, 0.0623, 0.0648],
         [0.0631, 0.0632, 0.0633,  ..., 0.0612, 0.0628, 0.0646]],

        [[0.0620, 0.0632, 0.0620,  ..., 0.0625, 0.0616, 0.0630],
         [0.0613, 0.0628, 0.0617,  ..., 0.0619, 0.0627, 0.0641],
         [0.0620, 0.0627, 0.0621,  ..., 0.0634, 0.0617, 0.0630],
         ...,
         [0.0620, 0.0628, 0.0626,  ..., 0.0624, 0.0623, 0.0624],
         [0.0621, 0.0632, 0.0631,  ..., 0.0612, 0.0594, 0.0630],
         [0.0616, 0.0621, 0.0614,  ..., 0.0624, 0.0608, 0.0631]],

        [[0.0619, 0.0626, 0.0639,  ..., 0.0613, 0.0635, 0.0621],
         [0.0618, 0.0615, 0.0631,  ..., 0.0606, 0.0622, 0.0620],
         [0.0621, 0.0620, 0.0641,  ..., 0.0620, 0.0624, 0.

 17%|█▋        | 17/100 [02:27<12:09,  8.79s/it]

tensor([[[0.0622, 0.0618, 0.0602,  ..., 0.0620, 0.0630, 0.0611],
         [0.0615, 0.0613, 0.0600,  ..., 0.0616, 0.0635, 0.0611],
         [0.0614, 0.0612, 0.0601,  ..., 0.0614, 0.0630, 0.0610],
         ...,
         [0.0609, 0.0610, 0.0592,  ..., 0.0613, 0.0634, 0.0597],
         [0.0606, 0.0608, 0.0603,  ..., 0.0625, 0.0634, 0.0609],
         [0.0606, 0.0612, 0.0605,  ..., 0.0619, 0.0627, 0.0625]],

        [[0.0639, 0.0645, 0.0632,  ..., 0.0625, 0.0633, 0.0616],
         [0.0622, 0.0623, 0.0637,  ..., 0.0635, 0.0616, 0.0626],
         [0.0643, 0.0625, 0.0638,  ..., 0.0624, 0.0639, 0.0615],
         ...,
         [0.0638, 0.0620, 0.0654,  ..., 0.0628, 0.0631, 0.0607],
         [0.0634, 0.0630, 0.0640,  ..., 0.0628, 0.0630, 0.0618],
         [0.0619, 0.0615, 0.0641,  ..., 0.0622, 0.0628, 0.0627]],

        [[0.0637, 0.0615, 0.0635,  ..., 0.0652, 0.0634, 0.0626],
         [0.0622, 0.0623, 0.0620,  ..., 0.0645, 0.0626, 0.0634],
         [0.0626, 0.0614, 0.0649,  ..., 0.0643, 0.0631, 0.

 18%|█▊        | 18/100 [02:36<12:05,  8.85s/it]

tensor([[[0.0628, 0.0665, 0.0612,  ..., 0.0626, 0.0620, 0.0636],
         [0.0626, 0.0654, 0.0615,  ..., 0.0620, 0.0623, 0.0634],
         [0.0636, 0.0663, 0.0608,  ..., 0.0621, 0.0626, 0.0629],
         ...,
         [0.0634, 0.0656, 0.0612,  ..., 0.0622, 0.0626, 0.0632],
         [0.0630, 0.0665, 0.0608,  ..., 0.0618, 0.0629, 0.0627],
         [0.0629, 0.0661, 0.0609,  ..., 0.0615, 0.0627, 0.0636]],

        [[0.0637, 0.0618, 0.0622,  ..., 0.0643, 0.0638, 0.0591],
         [0.0645, 0.0612, 0.0633,  ..., 0.0643, 0.0638, 0.0596],
         [0.0637, 0.0616, 0.0625,  ..., 0.0645, 0.0634, 0.0594],
         ...,
         [0.0647, 0.0622, 0.0638,  ..., 0.0649, 0.0625, 0.0588],
         [0.0648, 0.0620, 0.0632,  ..., 0.0650, 0.0643, 0.0586],
         [0.0640, 0.0623, 0.0628,  ..., 0.0636, 0.0641, 0.0605]],

        [[0.0594, 0.0628, 0.0626,  ..., 0.0616, 0.0607, 0.0647],
         [0.0623, 0.0626, 0.0629,  ..., 0.0609, 0.0603, 0.0623],
         [0.0618, 0.0635, 0.0625,  ..., 0.0619, 0.0620, 0.

 19%|█▉        | 19/100 [02:46<12:07,  8.99s/it]

tensor([[[0.0608, 0.0642, 0.0601,  ..., 0.0641, 0.0646, 0.0630],
         [0.0611, 0.0635, 0.0598,  ..., 0.0651, 0.0631, 0.0635],
         [0.0609, 0.0635, 0.0606,  ..., 0.0648, 0.0639, 0.0625],
         ...,
         [0.0615, 0.0625, 0.0599,  ..., 0.0652, 0.0634, 0.0625],
         [0.0621, 0.0624, 0.0608,  ..., 0.0648, 0.0633, 0.0631],
         [0.0610, 0.0626, 0.0589,  ..., 0.0655, 0.0631, 0.0634]],

        [[0.0610, 0.0627, 0.0623,  ..., 0.0614, 0.0646, 0.0609],
         [0.0620, 0.0623, 0.0618,  ..., 0.0607, 0.0640, 0.0622],
         [0.0618, 0.0629, 0.0625,  ..., 0.0599, 0.0625, 0.0615],
         ...,
         [0.0615, 0.0621, 0.0622,  ..., 0.0602, 0.0646, 0.0601],
         [0.0611, 0.0627, 0.0621,  ..., 0.0603, 0.0640, 0.0606],
         [0.0608, 0.0617, 0.0612,  ..., 0.0600, 0.0651, 0.0605]],

        [[0.0623, 0.0640, 0.0622,  ..., 0.0609, 0.0622, 0.0611],
         [0.0640, 0.0648, 0.0622,  ..., 0.0612, 0.0616, 0.0615],
         [0.0632, 0.0645, 0.0629,  ..., 0.0616, 0.0616, 0.

 20%|██        | 20/100 [02:55<12:00,  9.01s/it]

tensor([[[0.0602, 0.0644, 0.0640,  ..., 0.0621, 0.0606, 0.0613],
         [0.0609, 0.0646, 0.0648,  ..., 0.0618, 0.0601, 0.0610],
         [0.0603, 0.0650, 0.0651,  ..., 0.0615, 0.0594, 0.0601],
         ...,
         [0.0610, 0.0641, 0.0652,  ..., 0.0622, 0.0614, 0.0612],
         [0.0610, 0.0648, 0.0637,  ..., 0.0619, 0.0606, 0.0606],
         [0.0609, 0.0643, 0.0652,  ..., 0.0614, 0.0598, 0.0609]],

        [[0.0639, 0.0642, 0.0626,  ..., 0.0616, 0.0627, 0.0615],
         [0.0634, 0.0659, 0.0623,  ..., 0.0618, 0.0619, 0.0607],
         [0.0626, 0.0622, 0.0620,  ..., 0.0627, 0.0614, 0.0621],
         ...,
         [0.0634, 0.0656, 0.0615,  ..., 0.0601, 0.0642, 0.0582],
         [0.0626, 0.0645, 0.0621,  ..., 0.0610, 0.0623, 0.0615],
         [0.0645, 0.0640, 0.0612,  ..., 0.0596, 0.0630, 0.0612]],

        [[0.0615, 0.0620, 0.0610,  ..., 0.0649, 0.0635, 0.0606],
         [0.0624, 0.0622, 0.0615,  ..., 0.0623, 0.0623, 0.0629],
         [0.0616, 0.0608, 0.0625,  ..., 0.0637, 0.0628, 0.

 21%|██        | 21/100 [03:03<11:39,  8.86s/it]

tensor([[[0.0609, 0.0625, 0.0650,  ..., 0.0636, 0.0632, 0.0641],
         [0.0607, 0.0637, 0.0649,  ..., 0.0636, 0.0631, 0.0637],
         [0.0604, 0.0626, 0.0652,  ..., 0.0635, 0.0635, 0.0640],
         ...,
         [0.0608, 0.0625, 0.0654,  ..., 0.0638, 0.0635, 0.0640],
         [0.0612, 0.0625, 0.0650,  ..., 0.0636, 0.0627, 0.0635],
         [0.0607, 0.0632, 0.0649,  ..., 0.0634, 0.0630, 0.0643]],

        [[0.0606, 0.0636, 0.0612,  ..., 0.0633, 0.0636, 0.0611],
         [0.0619, 0.0624, 0.0628,  ..., 0.0633, 0.0622, 0.0636],
         [0.0628, 0.0626, 0.0626,  ..., 0.0636, 0.0638, 0.0593],
         ...,
         [0.0594, 0.0622, 0.0618,  ..., 0.0624, 0.0637, 0.0601],
         [0.0615, 0.0633, 0.0608,  ..., 0.0624, 0.0631, 0.0605],
         [0.0621, 0.0626, 0.0629,  ..., 0.0622, 0.0637, 0.0613]],

        [[0.0626, 0.0642, 0.0616,  ..., 0.0626, 0.0624, 0.0619],
         [0.0651, 0.0640, 0.0615,  ..., 0.0608, 0.0635, 0.0623],
         [0.0621, 0.0630, 0.0631,  ..., 0.0628, 0.0635, 0.

 22%|██▏       | 22/100 [03:13<11:46,  9.06s/it]

tensor([[[0.0633, 0.0639, 0.0625,  ..., 0.0596, 0.0626, 0.0636],
         [0.0633, 0.0631, 0.0626,  ..., 0.0592, 0.0612, 0.0634],
         [0.0627, 0.0637, 0.0629,  ..., 0.0593, 0.0617, 0.0631],
         ...,
         [0.0630, 0.0634, 0.0629,  ..., 0.0596, 0.0613, 0.0631],
         [0.0638, 0.0636, 0.0621,  ..., 0.0592, 0.0614, 0.0638],
         [0.0617, 0.0632, 0.0626,  ..., 0.0598, 0.0603, 0.0626]],

        [[0.0615, 0.0635, 0.0626,  ..., 0.0614, 0.0624, 0.0625],
         [0.0623, 0.0647, 0.0607,  ..., 0.0590, 0.0623, 0.0622],
         [0.0617, 0.0627, 0.0623,  ..., 0.0612, 0.0632, 0.0627],
         ...,
         [0.0607, 0.0644, 0.0615,  ..., 0.0624, 0.0625, 0.0609],
         [0.0611, 0.0629, 0.0626,  ..., 0.0598, 0.0636, 0.0618],
         [0.0611, 0.0625, 0.0631,  ..., 0.0613, 0.0630, 0.0620]],

        [[0.0601, 0.0639, 0.0625,  ..., 0.0630, 0.0629, 0.0633],
         [0.0616, 0.0629, 0.0634,  ..., 0.0615, 0.0635, 0.0632],
         [0.0616, 0.0632, 0.0626,  ..., 0.0634, 0.0640, 0.

 23%|██▎       | 23/100 [03:20<11:03,  8.62s/it]

tensor([[[0.0600, 0.0592, 0.0621,  ..., 0.0636, 0.0600, 0.0603],
         [0.0606, 0.0596, 0.0627,  ..., 0.0628, 0.0616, 0.0619],
         [0.0606, 0.0591, 0.0629,  ..., 0.0631, 0.0609, 0.0611],
         ...,
         [0.0598, 0.0599, 0.0624,  ..., 0.0634, 0.0599, 0.0610],
         [0.0604, 0.0593, 0.0626,  ..., 0.0631, 0.0602, 0.0607],
         [0.0600, 0.0604, 0.0623,  ..., 0.0635, 0.0602, 0.0605]],

        [[0.0626, 0.0634, 0.0603,  ..., 0.0613, 0.0596, 0.0677],
         [0.0633, 0.0634, 0.0604,  ..., 0.0610, 0.0603, 0.0679],
         [0.0627, 0.0634, 0.0603,  ..., 0.0629, 0.0613, 0.0645],
         ...,
         [0.0617, 0.0637, 0.0603,  ..., 0.0614, 0.0598, 0.0666],
         [0.0623, 0.0626, 0.0602,  ..., 0.0617, 0.0605, 0.0654],
         [0.0619, 0.0637, 0.0612,  ..., 0.0610, 0.0592, 0.0664]],

        [[0.0598, 0.0623, 0.0594,  ..., 0.0589, 0.0618, 0.0640],
         [0.0612, 0.0616, 0.0612,  ..., 0.0634, 0.0618, 0.0644],
         [0.0621, 0.0631, 0.0618,  ..., 0.0604, 0.0617, 0.

 24%|██▍       | 24/100 [03:27<10:03,  7.95s/it]

tensor([[[0.0627, 0.0619, 0.0642,  ..., 0.0614, 0.0614, 0.0635],
         [0.0628, 0.0622, 0.0642,  ..., 0.0617, 0.0614, 0.0626],
         [0.0624, 0.0624, 0.0642,  ..., 0.0615, 0.0622, 0.0638],
         ...,
         [0.0628, 0.0627, 0.0628,  ..., 0.0614, 0.0617, 0.0636],
         [0.0626, 0.0619, 0.0651,  ..., 0.0611, 0.0610, 0.0635],
         [0.0615, 0.0624, 0.0639,  ..., 0.0618, 0.0618, 0.0629]],

        [[0.0620, 0.0620, 0.0616,  ..., 0.0630, 0.0628, 0.0647],
         [0.0631, 0.0615, 0.0612,  ..., 0.0613, 0.0617, 0.0634],
         [0.0632, 0.0623, 0.0625,  ..., 0.0623, 0.0625, 0.0636],
         ...,
         [0.0621, 0.0608, 0.0616,  ..., 0.0626, 0.0623, 0.0624],
         [0.0630, 0.0609, 0.0639,  ..., 0.0624, 0.0623, 0.0628],
         [0.0624, 0.0619, 0.0606,  ..., 0.0631, 0.0625, 0.0638]],

        [[0.0623, 0.0642, 0.0629,  ..., 0.0626, 0.0627, 0.0620],
         [0.0634, 0.0624, 0.0622,  ..., 0.0601, 0.0638, 0.0616],
         [0.0632, 0.0629, 0.0634,  ..., 0.0614, 0.0636, 0.

 25%|██▌       | 25/100 [03:35<10:00,  8.00s/it]

tensor([[[0.0627, 0.0600, 0.0639,  ..., 0.0635, 0.0615, 0.0610],
         [0.0629, 0.0591, 0.0639,  ..., 0.0635, 0.0610, 0.0615],
         [0.0629, 0.0598, 0.0644,  ..., 0.0631, 0.0611, 0.0612],
         ...,
         [0.0628, 0.0601, 0.0636,  ..., 0.0636, 0.0622, 0.0612],
         [0.0631, 0.0603, 0.0636,  ..., 0.0639, 0.0621, 0.0613],
         [0.0626, 0.0604, 0.0639,  ..., 0.0638, 0.0623, 0.0617]],

        [[0.0612, 0.0630, 0.0623,  ..., 0.0610, 0.0632, 0.0632],
         [0.0614, 0.0629, 0.0629,  ..., 0.0612, 0.0623, 0.0628],
         [0.0620, 0.0618, 0.0618,  ..., 0.0621, 0.0636, 0.0644],
         ...,
         [0.0618, 0.0631, 0.0627,  ..., 0.0613, 0.0632, 0.0636],
         [0.0603, 0.0633, 0.0642,  ..., 0.0618, 0.0642, 0.0642],
         [0.0621, 0.0629, 0.0631,  ..., 0.0612, 0.0636, 0.0633]],

        [[0.0631, 0.0622, 0.0609,  ..., 0.0618, 0.0623, 0.0639],
         [0.0637, 0.0620, 0.0626,  ..., 0.0604, 0.0605, 0.0643],
         [0.0617, 0.0606, 0.0613,  ..., 0.0621, 0.0624, 0.

 26%|██▌       | 26/100 [03:42<09:35,  7.78s/it]

tensor([[[0.0634, 0.0640, 0.0639,  ..., 0.0625, 0.0661, 0.0628],
         [0.0634, 0.0635, 0.0631,  ..., 0.0632, 0.0670, 0.0629],
         [0.0637, 0.0636, 0.0636,  ..., 0.0619, 0.0663, 0.0630],
         ...,
         [0.0634, 0.0634, 0.0635,  ..., 0.0622, 0.0658, 0.0623],
         [0.0630, 0.0638, 0.0629,  ..., 0.0610, 0.0646, 0.0630],
         [0.0635, 0.0640, 0.0637,  ..., 0.0630, 0.0665, 0.0632]],

        [[0.0628, 0.0631, 0.0630,  ..., 0.0624, 0.0643, 0.0632],
         [0.0611, 0.0631, 0.0624,  ..., 0.0622, 0.0645, 0.0645],
         [0.0602, 0.0632, 0.0628,  ..., 0.0636, 0.0642, 0.0637],
         ...,
         [0.0604, 0.0645, 0.0626,  ..., 0.0614, 0.0645, 0.0631],
         [0.0609, 0.0616, 0.0625,  ..., 0.0619, 0.0657, 0.0641],
         [0.0619, 0.0636, 0.0615,  ..., 0.0605, 0.0658, 0.0641]],

        [[0.0611, 0.0643, 0.0618,  ..., 0.0623, 0.0621, 0.0610],
         [0.0615, 0.0639, 0.0612,  ..., 0.0621, 0.0628, 0.0611],
         [0.0627, 0.0621, 0.0613,  ..., 0.0647, 0.0624, 0.

 27%|██▋       | 27/100 [03:51<09:40,  7.95s/it]

tensor([[[0.0638, 0.0641, 0.0616,  ..., 0.0615, 0.0630, 0.0627],
         [0.0643, 0.0642, 0.0620,  ..., 0.0610, 0.0638, 0.0632],
         [0.0639, 0.0641, 0.0607,  ..., 0.0617, 0.0634, 0.0632],
         ...,
         [0.0636, 0.0641, 0.0623,  ..., 0.0608, 0.0652, 0.0641],
         [0.0631, 0.0642, 0.0616,  ..., 0.0626, 0.0633, 0.0628],
         [0.0635, 0.0643, 0.0619,  ..., 0.0619, 0.0630, 0.0637]],

        [[0.0638, 0.0632, 0.0601,  ..., 0.0637, 0.0634, 0.0617],
         [0.0621, 0.0642, 0.0624,  ..., 0.0622, 0.0612, 0.0631],
         [0.0623, 0.0635, 0.0610,  ..., 0.0623, 0.0601, 0.0637],
         ...,
         [0.0642, 0.0643, 0.0603,  ..., 0.0627, 0.0617, 0.0608],
         [0.0628, 0.0638, 0.0617,  ..., 0.0628, 0.0619, 0.0637],
         [0.0633, 0.0624, 0.0607,  ..., 0.0630, 0.0601, 0.0636]],

        [[0.0638, 0.0622, 0.0638,  ..., 0.0624, 0.0626, 0.0624],
         [0.0628, 0.0622, 0.0639,  ..., 0.0635, 0.0630, 0.0602],
         [0.0628, 0.0625, 0.0631,  ..., 0.0636, 0.0641, 0.

 28%|██▊       | 28/100 [03:57<09:08,  7.62s/it]

tensor([[[0.0631, 0.0634, 0.0642,  ..., 0.0640, 0.0614, 0.0623],
         [0.0631, 0.0637, 0.0637,  ..., 0.0630, 0.0610, 0.0622],
         [0.0638, 0.0633, 0.0631,  ..., 0.0651, 0.0601, 0.0616],
         ...,
         [0.0632, 0.0629, 0.0637,  ..., 0.0638, 0.0612, 0.0620],
         [0.0643, 0.0634, 0.0642,  ..., 0.0646, 0.0602, 0.0614],
         [0.0630, 0.0639, 0.0633,  ..., 0.0649, 0.0608, 0.0623]],

        [[0.0630, 0.0626, 0.0627,  ..., 0.0614, 0.0621, 0.0628],
         [0.0633, 0.0632, 0.0625,  ..., 0.0634, 0.0606, 0.0633],
         [0.0611, 0.0640, 0.0627,  ..., 0.0629, 0.0617, 0.0628],
         ...,
         [0.0621, 0.0617, 0.0615,  ..., 0.0628, 0.0626, 0.0626],
         [0.0591, 0.0619, 0.0617,  ..., 0.0624, 0.0625, 0.0633],
         [0.0623, 0.0614, 0.0637,  ..., 0.0616, 0.0629, 0.0620]],

        [[0.0630, 0.0607, 0.0620,  ..., 0.0625, 0.0628, 0.0636],
         [0.0617, 0.0627, 0.0620,  ..., 0.0636, 0.0627, 0.0631],
         [0.0603, 0.0637, 0.0612,  ..., 0.0629, 0.0624, 0.

 29%|██▉       | 29/100 [04:04<08:40,  7.32s/it]

tensor([[[0.0612, 0.0665, 0.0596,  ..., 0.0641, 0.0628, 0.0625],
         [0.0621, 0.0646, 0.0608,  ..., 0.0639, 0.0633, 0.0623],
         [0.0610, 0.0665, 0.0601,  ..., 0.0639, 0.0627, 0.0628],
         ...,
         [0.0615, 0.0654, 0.0606,  ..., 0.0643, 0.0632, 0.0626],
         [0.0603, 0.0654, 0.0605,  ..., 0.0644, 0.0637, 0.0624],
         [0.0618, 0.0645, 0.0607,  ..., 0.0644, 0.0635, 0.0624]],

        [[0.0633, 0.0620, 0.0629,  ..., 0.0626, 0.0640, 0.0633],
         [0.0617, 0.0620, 0.0641,  ..., 0.0643, 0.0628, 0.0612],
         [0.0610, 0.0618, 0.0632,  ..., 0.0626, 0.0637, 0.0626],
         ...,
         [0.0608, 0.0612, 0.0643,  ..., 0.0640, 0.0632, 0.0626],
         [0.0619, 0.0617, 0.0636,  ..., 0.0627, 0.0624, 0.0622],
         [0.0611, 0.0618, 0.0641,  ..., 0.0643, 0.0609, 0.0609]],

        [[0.0627, 0.0627, 0.0631,  ..., 0.0614, 0.0619, 0.0625],
         [0.0618, 0.0630, 0.0639,  ..., 0.0632, 0.0610, 0.0622],
         [0.0633, 0.0630, 0.0638,  ..., 0.0612, 0.0629, 0.

 30%|███       | 30/100 [04:11<08:21,  7.17s/it]

tensor([[[0.0618, 0.0634, 0.0644,  ..., 0.0636, 0.0616, 0.0632],
         [0.0613, 0.0635, 0.0645,  ..., 0.0645, 0.0626, 0.0635],
         [0.0620, 0.0631, 0.0643,  ..., 0.0639, 0.0620, 0.0632],
         ...,
         [0.0613, 0.0633, 0.0639,  ..., 0.0643, 0.0622, 0.0639],
         [0.0617, 0.0631, 0.0644,  ..., 0.0641, 0.0616, 0.0639],
         [0.0621, 0.0630, 0.0637,  ..., 0.0636, 0.0617, 0.0638]],

        [[0.0625, 0.0634, 0.0642,  ..., 0.0617, 0.0637, 0.0619],
         [0.0617, 0.0628, 0.0642,  ..., 0.0637, 0.0636, 0.0621],
         [0.0631, 0.0634, 0.0630,  ..., 0.0608, 0.0639, 0.0620],
         ...,
         [0.0626, 0.0649, 0.0640,  ..., 0.0635, 0.0630, 0.0623],
         [0.0623, 0.0644, 0.0640,  ..., 0.0628, 0.0638, 0.0614],
         [0.0617, 0.0633, 0.0638,  ..., 0.0620, 0.0622, 0.0615]],

        [[0.0601, 0.0639, 0.0621,  ..., 0.0631, 0.0629, 0.0612],
         [0.0632, 0.0644, 0.0638,  ..., 0.0636, 0.0603, 0.0634],
         [0.0619, 0.0640, 0.0623,  ..., 0.0624, 0.0631, 0.

 31%|███       | 31/100 [04:19<08:35,  7.48s/it]

tensor([[[0.0619, 0.0656, 0.0651,  ..., 0.0627, 0.0608, 0.0605],
         [0.0617, 0.0646, 0.0649,  ..., 0.0630, 0.0617, 0.0607],
         [0.0614, 0.0656, 0.0649,  ..., 0.0638, 0.0600, 0.0602],
         ...,
         [0.0617, 0.0639, 0.0646,  ..., 0.0624, 0.0632, 0.0607],
         [0.0626, 0.0643, 0.0650,  ..., 0.0622, 0.0622, 0.0597],
         [0.0617, 0.0644, 0.0650,  ..., 0.0633, 0.0622, 0.0609]],

        [[0.0663, 0.0588, 0.0613,  ..., 0.0623, 0.0622, 0.0614],
         [0.0639, 0.0613, 0.0630,  ..., 0.0621, 0.0638, 0.0617],
         [0.0641, 0.0619, 0.0612,  ..., 0.0627, 0.0627, 0.0620],
         ...,
         [0.0627, 0.0615, 0.0626,  ..., 0.0627, 0.0650, 0.0611],
         [0.0639, 0.0623, 0.0614,  ..., 0.0621, 0.0626, 0.0620],
         [0.0643, 0.0621, 0.0618,  ..., 0.0626, 0.0613, 0.0613]],

        [[0.0623, 0.0624, 0.0634,  ..., 0.0617, 0.0611, 0.0622],
         [0.0602, 0.0635, 0.0629,  ..., 0.0624, 0.0627, 0.0602],
         [0.0608, 0.0630, 0.0624,  ..., 0.0622, 0.0614, 0.

 32%|███▏      | 32/100 [04:28<09:07,  8.05s/it]

tensor([[[0.0612, 0.0640, 0.0598,  ..., 0.0634, 0.0638, 0.0635],
         [0.0616, 0.0641, 0.0599,  ..., 0.0622, 0.0640, 0.0629],
         [0.0618, 0.0637, 0.0607,  ..., 0.0632, 0.0624, 0.0629],
         ...,
         [0.0609, 0.0639, 0.0601,  ..., 0.0628, 0.0639, 0.0645],
         [0.0602, 0.0639, 0.0591,  ..., 0.0629, 0.0642, 0.0642],
         [0.0610, 0.0644, 0.0596,  ..., 0.0621, 0.0638, 0.0646]],

        [[0.0629, 0.0632, 0.0625,  ..., 0.0625, 0.0615, 0.0617],
         [0.0611, 0.0623, 0.0608,  ..., 0.0637, 0.0622, 0.0630],
         [0.0636, 0.0627, 0.0625,  ..., 0.0634, 0.0625, 0.0622],
         ...,
         [0.0623, 0.0628, 0.0611,  ..., 0.0632, 0.0609, 0.0629],
         [0.0626, 0.0651, 0.0616,  ..., 0.0622, 0.0625, 0.0640],
         [0.0618, 0.0629, 0.0603,  ..., 0.0622, 0.0612, 0.0628]],

        [[0.0631, 0.0635, 0.0611,  ..., 0.0620, 0.0606, 0.0641],
         [0.0626, 0.0622, 0.0613,  ..., 0.0639, 0.0621, 0.0636],
         [0.0627, 0.0628, 0.0615,  ..., 0.0625, 0.0616, 0.

 33%|███▎      | 33/100 [04:37<09:09,  8.21s/it]

tensor([[[0.0590, 0.0635, 0.0627,  ..., 0.0608, 0.0629, 0.0605],
         [0.0589, 0.0633, 0.0620,  ..., 0.0609, 0.0628, 0.0590],
         [0.0584, 0.0642, 0.0622,  ..., 0.0613, 0.0623, 0.0589],
         ...,
         [0.0582, 0.0638, 0.0626,  ..., 0.0611, 0.0631, 0.0586],
         [0.0588, 0.0638, 0.0618,  ..., 0.0616, 0.0620, 0.0596],
         [0.0581, 0.0637, 0.0627,  ..., 0.0608, 0.0631, 0.0598]],

        [[0.0630, 0.0631, 0.0637,  ..., 0.0620, 0.0625, 0.0616],
         [0.0628, 0.0634, 0.0621,  ..., 0.0626, 0.0603, 0.0605],
         [0.0628, 0.0640, 0.0618,  ..., 0.0633, 0.0615, 0.0613],
         ...,
         [0.0613, 0.0620, 0.0619,  ..., 0.0635, 0.0601, 0.0606],
         [0.0625, 0.0615, 0.0605,  ..., 0.0630, 0.0605, 0.0613],
         [0.0624, 0.0634, 0.0631,  ..., 0.0639, 0.0599, 0.0613]],

        [[0.0623, 0.0633, 0.0616,  ..., 0.0617, 0.0637, 0.0611],
         [0.0625, 0.0632, 0.0625,  ..., 0.0623, 0.0622, 0.0627],
         [0.0609, 0.0658, 0.0617,  ..., 0.0608, 0.0641, 0.

 34%|███▍      | 34/100 [04:44<08:39,  7.88s/it]

tensor([[[0.0626, 0.0652, 0.0609,  ..., 0.0680, 0.0627, 0.0599],
         [0.0625, 0.0634, 0.0617,  ..., 0.0671, 0.0632, 0.0600],
         [0.0625, 0.0657, 0.0609,  ..., 0.0700, 0.0626, 0.0587],
         ...,
         [0.0628, 0.0634, 0.0618,  ..., 0.0664, 0.0627, 0.0601],
         [0.0620, 0.0653, 0.0616,  ..., 0.0684, 0.0624, 0.0594],
         [0.0623, 0.0657, 0.0607,  ..., 0.0687, 0.0624, 0.0595]],

        [[0.0618, 0.0631, 0.0641,  ..., 0.0617, 0.0635, 0.0594],
         [0.0602, 0.0641, 0.0636,  ..., 0.0621, 0.0622, 0.0588],
         [0.0603, 0.0653, 0.0636,  ..., 0.0629, 0.0619, 0.0582],
         ...,
         [0.0613, 0.0645, 0.0637,  ..., 0.0615, 0.0634, 0.0600],
         [0.0625, 0.0630, 0.0634,  ..., 0.0625, 0.0630, 0.0599],
         [0.0630, 0.0604, 0.0611,  ..., 0.0610, 0.0624, 0.0621]],

        [[0.0623, 0.0617, 0.0622,  ..., 0.0626, 0.0625, 0.0624],
         [0.0628, 0.0630, 0.0627,  ..., 0.0626, 0.0616, 0.0621],
         [0.0616, 0.0633, 0.0610,  ..., 0.0652, 0.0636, 0.

 35%|███▌      | 35/100 [04:51<08:18,  7.67s/it]

tensor([[[0.0644, 0.0604, 0.0625,  ..., 0.0622, 0.0592, 0.0725],
         [0.0649, 0.0599, 0.0631,  ..., 0.0634, 0.0573, 0.0733],
         [0.0652, 0.0609, 0.0626,  ..., 0.0626, 0.0582, 0.0716],
         ...,
         [0.0637, 0.0610, 0.0623,  ..., 0.0615, 0.0592, 0.0730],
         [0.0644, 0.0613, 0.0620,  ..., 0.0623, 0.0603, 0.0707],
         [0.0642, 0.0618, 0.0630,  ..., 0.0624, 0.0591, 0.0664]],

        [[0.0647, 0.0618, 0.0629,  ..., 0.0625, 0.0612, 0.0614],
         [0.0622, 0.0622, 0.0615,  ..., 0.0591, 0.0632, 0.0627],
         [0.0625, 0.0625, 0.0614,  ..., 0.0589, 0.0628, 0.0634],
         ...,
         [0.0636, 0.0624, 0.0631,  ..., 0.0599, 0.0625, 0.0637],
         [0.0648, 0.0623, 0.0618,  ..., 0.0615, 0.0627, 0.0617],
         [0.0638, 0.0614, 0.0619,  ..., 0.0610, 0.0639, 0.0622]],

        [[0.0643, 0.0614, 0.0639,  ..., 0.0631, 0.0635, 0.0618],
         [0.0660, 0.0613, 0.0618,  ..., 0.0603, 0.0634, 0.0632],
         [0.0643, 0.0606, 0.0636,  ..., 0.0624, 0.0630, 0.

 36%|███▌      | 36/100 [05:01<08:42,  8.16s/it]

tensor([[[0.0636, 0.0631, 0.0586,  ..., 0.0639, 0.0618, 0.0641],
         [0.0628, 0.0637, 0.0588,  ..., 0.0628, 0.0614, 0.0646],
         [0.0626, 0.0643, 0.0607,  ..., 0.0625, 0.0617, 0.0649],
         ...,
         [0.0638, 0.0630, 0.0581,  ..., 0.0635, 0.0621, 0.0649],
         [0.0637, 0.0625, 0.0585,  ..., 0.0633, 0.0628, 0.0639],
         [0.0631, 0.0634, 0.0595,  ..., 0.0628, 0.0617, 0.0640]],

        [[0.0633, 0.0648, 0.0620,  ..., 0.0615, 0.0618, 0.0628],
         [0.0621, 0.0636, 0.0626,  ..., 0.0611, 0.0617, 0.0612],
         [0.0621, 0.0651, 0.0620,  ..., 0.0613, 0.0614, 0.0627],
         ...,
         [0.0625, 0.0640, 0.0618,  ..., 0.0604, 0.0606, 0.0604],
         [0.0611, 0.0632, 0.0619,  ..., 0.0619, 0.0620, 0.0614],
         [0.0613, 0.0637, 0.0623,  ..., 0.0604, 0.0624, 0.0622]],

        [[0.0609, 0.0634, 0.0621,  ..., 0.0627, 0.0635, 0.0619],
         [0.0617, 0.0627, 0.0605,  ..., 0.0617, 0.0633, 0.0608],
         [0.0615, 0.0640, 0.0607,  ..., 0.0625, 0.0623, 0.

 37%|███▋      | 37/100 [05:11<09:14,  8.81s/it]

tensor([[[0.0656, 0.0623, 0.0609,  ..., 0.0648, 0.0619, 0.0636],
         [0.0656, 0.0626, 0.0610,  ..., 0.0647, 0.0617, 0.0635],
         [0.0661, 0.0627, 0.0613,  ..., 0.0657, 0.0617, 0.0638],
         ...,
         [0.0649, 0.0629, 0.0619,  ..., 0.0650, 0.0618, 0.0634],
         [0.0657, 0.0622, 0.0609,  ..., 0.0652, 0.0619, 0.0631],
         [0.0658, 0.0619, 0.0612,  ..., 0.0659, 0.0619, 0.0641]],

        [[0.0654, 0.0609, 0.0617,  ..., 0.0627, 0.0630, 0.0635],
         [0.0648, 0.0620, 0.0635,  ..., 0.0617, 0.0619, 0.0625],
         [0.0641, 0.0619, 0.0635,  ..., 0.0625, 0.0615, 0.0627],
         ...,
         [0.0643, 0.0619, 0.0631,  ..., 0.0621, 0.0618, 0.0616],
         [0.0638, 0.0604, 0.0623,  ..., 0.0607, 0.0629, 0.0655],
         [0.0638, 0.0606, 0.0619,  ..., 0.0619, 0.0615, 0.0648]],

        [[0.0626, 0.0616, 0.0608,  ..., 0.0633, 0.0633, 0.0636],
         [0.0604, 0.0624, 0.0622,  ..., 0.0629, 0.0622, 0.0624],
         [0.0623, 0.0634, 0.0625,  ..., 0.0648, 0.0618, 0.

 38%|███▊      | 38/100 [05:19<08:56,  8.66s/it]

tensor([[[0.0641, 0.0625, 0.0640,  ..., 0.0622, 0.0629, 0.0620],
         [0.0630, 0.0626, 0.0627,  ..., 0.0626, 0.0639, 0.0623],
         [0.0650, 0.0630, 0.0649,  ..., 0.0621, 0.0634, 0.0622],
         ...,
         [0.0638, 0.0627, 0.0631,  ..., 0.0623, 0.0640, 0.0617],
         [0.0637, 0.0626, 0.0637,  ..., 0.0620, 0.0642, 0.0622],
         [0.0644, 0.0623, 0.0640,  ..., 0.0618, 0.0638, 0.0617]],

        [[0.0632, 0.0620, 0.0631,  ..., 0.0599, 0.0629, 0.0619],
         [0.0627, 0.0643, 0.0635,  ..., 0.0593, 0.0621, 0.0610],
         [0.0624, 0.0634, 0.0627,  ..., 0.0607, 0.0642, 0.0604],
         ...,
         [0.0620, 0.0634, 0.0633,  ..., 0.0615, 0.0631, 0.0610],
         [0.0629, 0.0632, 0.0633,  ..., 0.0615, 0.0600, 0.0614],
         [0.0639, 0.0633, 0.0628,  ..., 0.0617, 0.0640, 0.0630]],

        [[0.0623, 0.0637, 0.0626,  ..., 0.0622, 0.0613, 0.0647],
         [0.0623, 0.0646, 0.0613,  ..., 0.0635, 0.0608, 0.0642],
         [0.0631, 0.0645, 0.0630,  ..., 0.0624, 0.0611, 0.

 39%|███▉      | 39/100 [05:29<09:00,  8.85s/it]

tensor([[[0.0635, 0.0643, 0.0642,  ..., 0.0623, 0.0571, 0.0652],
         [0.0644, 0.0650, 0.0645,  ..., 0.0632, 0.0569, 0.0664],
         [0.0647, 0.0650, 0.0646,  ..., 0.0615, 0.0573, 0.0661],
         ...,
         [0.0643, 0.0641, 0.0632,  ..., 0.0620, 0.0589, 0.0659],
         [0.0630, 0.0639, 0.0631,  ..., 0.0640, 0.0569, 0.0657],
         [0.0644, 0.0649, 0.0638,  ..., 0.0628, 0.0575, 0.0662]],

        [[0.0619, 0.0629, 0.0629,  ..., 0.0636, 0.0604, 0.0636],
         [0.0618, 0.0636, 0.0614,  ..., 0.0637, 0.0622, 0.0614],
         [0.0625, 0.0625, 0.0632,  ..., 0.0630, 0.0619, 0.0628],
         ...,
         [0.0631, 0.0619, 0.0623,  ..., 0.0662, 0.0598, 0.0640],
         [0.0616, 0.0648, 0.0634,  ..., 0.0632, 0.0638, 0.0631],
         [0.0622, 0.0627, 0.0616,  ..., 0.0638, 0.0615, 0.0627]],

        [[0.0642, 0.0637, 0.0603,  ..., 0.0632, 0.0619, 0.0610],
         [0.0627, 0.0625, 0.0603,  ..., 0.0629, 0.0624, 0.0617],
         [0.0632, 0.0641, 0.0614,  ..., 0.0630, 0.0608, 0.

 40%|████      | 40/100 [05:38<08:55,  8.92s/it]

tensor([[[0.0608, 0.0623, 0.0622,  ..., 0.0650, 0.0638, 0.0645],
         [0.0614, 0.0623, 0.0628,  ..., 0.0641, 0.0627, 0.0641],
         [0.0611, 0.0616, 0.0627,  ..., 0.0638, 0.0629, 0.0653],
         ...,
         [0.0608, 0.0622, 0.0622,  ..., 0.0643, 0.0632, 0.0648],
         [0.0602, 0.0615, 0.0623,  ..., 0.0650, 0.0638, 0.0653],
         [0.0615, 0.0619, 0.0630,  ..., 0.0634, 0.0622, 0.0654]],

        [[0.0616, 0.0639, 0.0631,  ..., 0.0609, 0.0652, 0.0620],
         [0.0623, 0.0665, 0.0636,  ..., 0.0615, 0.0636, 0.0613],
         [0.0623, 0.0650, 0.0638,  ..., 0.0625, 0.0625, 0.0623],
         ...,
         [0.0616, 0.0629, 0.0632,  ..., 0.0625, 0.0643, 0.0619],
         [0.0631, 0.0627, 0.0630,  ..., 0.0611, 0.0638, 0.0621],
         [0.0626, 0.0636, 0.0637,  ..., 0.0608, 0.0632, 0.0640]],

        [[0.0618, 0.0609, 0.0600,  ..., 0.0644, 0.0624, 0.0606],
         [0.0605, 0.0636, 0.0631,  ..., 0.0628, 0.0626, 0.0607],
         [0.0611, 0.0639, 0.0627,  ..., 0.0621, 0.0623, 0.

 41%|████      | 41/100 [05:47<09:02,  9.19s/it]

tensor([[[0.0599, 0.0627, 0.0647,  ..., 0.0609, 0.0611, 0.0580],
         [0.0591, 0.0621, 0.0652,  ..., 0.0617, 0.0614, 0.0592],
         [0.0590, 0.0619, 0.0659,  ..., 0.0603, 0.0615, 0.0588],
         ...,
         [0.0598, 0.0626, 0.0638,  ..., 0.0640, 0.0618, 0.0585],
         [0.0596, 0.0619, 0.0647,  ..., 0.0623, 0.0615, 0.0596],
         [0.0603, 0.0631, 0.0643,  ..., 0.0631, 0.0614, 0.0590]],

        [[0.0609, 0.0666, 0.0623,  ..., 0.0613, 0.0640, 0.0636],
         [0.0628, 0.0665, 0.0623,  ..., 0.0604, 0.0648, 0.0648],
         [0.0606, 0.0664, 0.0603,  ..., 0.0619, 0.0647, 0.0637],
         ...,
         [0.0625, 0.0652, 0.0617,  ..., 0.0605, 0.0643, 0.0639],
         [0.0627, 0.0662, 0.0610,  ..., 0.0615, 0.0638, 0.0639],
         [0.0628, 0.0665, 0.0610,  ..., 0.0616, 0.0641, 0.0637]],

        [[0.0619, 0.0617, 0.0639,  ..., 0.0611, 0.0618, 0.0610],
         [0.0617, 0.0614, 0.0626,  ..., 0.0631, 0.0619, 0.0628],
         [0.0628, 0.0610, 0.0635,  ..., 0.0604, 0.0623, 0.

 42%|████▏     | 42/100 [05:57<08:51,  9.16s/it]

tensor([[[0.0615, 0.0626, 0.0635,  ..., 0.0641, 0.0641, 0.0609],
         [0.0615, 0.0627, 0.0643,  ..., 0.0639, 0.0638, 0.0602],
         [0.0614, 0.0626, 0.0644,  ..., 0.0632, 0.0635, 0.0604],
         ...,
         [0.0609, 0.0623, 0.0636,  ..., 0.0630, 0.0640, 0.0608],
         [0.0610, 0.0625, 0.0639,  ..., 0.0629, 0.0636, 0.0602],
         [0.0611, 0.0621, 0.0640,  ..., 0.0633, 0.0640, 0.0609]],

        [[0.0627, 0.0627, 0.0623,  ..., 0.0631, 0.0635, 0.0599],
         [0.0629, 0.0644, 0.0628,  ..., 0.0639, 0.0659, 0.0617],
         [0.0631, 0.0637, 0.0619,  ..., 0.0631, 0.0660, 0.0601],
         ...,
         [0.0643, 0.0628, 0.0628,  ..., 0.0631, 0.0638, 0.0611],
         [0.0641, 0.0640, 0.0625,  ..., 0.0627, 0.0648, 0.0605],
         [0.0627, 0.0639, 0.0616,  ..., 0.0622, 0.0640, 0.0632]],

        [[0.0639, 0.0621, 0.0642,  ..., 0.0619, 0.0627, 0.0620],
         [0.0617, 0.0629, 0.0644,  ..., 0.0609, 0.0648, 0.0619],
         [0.0621, 0.0629, 0.0645,  ..., 0.0623, 0.0637, 0.

 43%|████▎     | 43/100 [06:06<08:41,  9.14s/it]

tensor([[[0.0627, 0.0637, 0.0615,  ..., 0.0605, 0.0637, 0.0611],
         [0.0627, 0.0637, 0.0613,  ..., 0.0606, 0.0638, 0.0603],
         [0.0623, 0.0637, 0.0612,  ..., 0.0612, 0.0640, 0.0613],
         ...,
         [0.0631, 0.0627, 0.0611,  ..., 0.0613, 0.0634, 0.0613],
         [0.0624, 0.0629, 0.0611,  ..., 0.0608, 0.0635, 0.0612],
         [0.0622, 0.0632, 0.0609,  ..., 0.0614, 0.0639, 0.0609]],

        [[0.0620, 0.0615, 0.0587,  ..., 0.0637, 0.0643, 0.0578],
         [0.0632, 0.0612, 0.0596,  ..., 0.0617, 0.0643, 0.0580],
         [0.0638, 0.0615, 0.0623,  ..., 0.0627, 0.0653, 0.0592],
         ...,
         [0.0629, 0.0627, 0.0605,  ..., 0.0626, 0.0653, 0.0609],
         [0.0625, 0.0604, 0.0598,  ..., 0.0613, 0.0643, 0.0581],
         [0.0642, 0.0614, 0.0606,  ..., 0.0619, 0.0644, 0.0589]],

        [[0.0625, 0.0616, 0.0627,  ..., 0.0607, 0.0614, 0.0621],
         [0.0645, 0.0622, 0.0635,  ..., 0.0628, 0.0640, 0.0609],
         [0.0629, 0.0630, 0.0637,  ..., 0.0635, 0.0624, 0.

 44%|████▍     | 44/100 [06:14<08:14,  8.82s/it]

tensor([[[0.0649, 0.0659, 0.0591,  ..., 0.0619, 0.0638, 0.0589],
         [0.0637, 0.0654, 0.0591,  ..., 0.0626, 0.0632, 0.0597],
         [0.0654, 0.0669, 0.0599,  ..., 0.0625, 0.0632, 0.0590],
         ...,
         [0.0651, 0.0669, 0.0592,  ..., 0.0626, 0.0627, 0.0585],
         [0.0648, 0.0667, 0.0598,  ..., 0.0620, 0.0630, 0.0592],
         [0.0656, 0.0669, 0.0598,  ..., 0.0623, 0.0638, 0.0588]],

        [[0.0624, 0.0632, 0.0640,  ..., 0.0630, 0.0617, 0.0626],
         [0.0621, 0.0616, 0.0644,  ..., 0.0624, 0.0605, 0.0628],
         [0.0616, 0.0610, 0.0629,  ..., 0.0613, 0.0616, 0.0607],
         ...,
         [0.0622, 0.0614, 0.0642,  ..., 0.0613, 0.0615, 0.0614],
         [0.0626, 0.0608, 0.0662,  ..., 0.0616, 0.0606, 0.0618],
         [0.0622, 0.0623, 0.0658,  ..., 0.0621, 0.0604, 0.0617]],

        [[0.0652, 0.0620, 0.0615,  ..., 0.0619, 0.0617, 0.0603],
         [0.0620, 0.0613, 0.0616,  ..., 0.0633, 0.0620, 0.0634],
         [0.0644, 0.0616, 0.0616,  ..., 0.0626, 0.0623, 0.

 45%|████▌     | 45/100 [06:23<08:10,  8.91s/it]

tensor([[[0.0645, 0.0640, 0.0610,  ..., 0.0623, 0.0603, 0.0642],
         [0.0643, 0.0631, 0.0611,  ..., 0.0617, 0.0602, 0.0640],
         [0.0649, 0.0636, 0.0616,  ..., 0.0619, 0.0604, 0.0636],
         ...,
         [0.0649, 0.0635, 0.0616,  ..., 0.0613, 0.0599, 0.0640],
         [0.0644, 0.0636, 0.0600,  ..., 0.0614, 0.0609, 0.0638],
         [0.0651, 0.0641, 0.0605,  ..., 0.0621, 0.0598, 0.0643]],

        [[0.0622, 0.0628, 0.0626,  ..., 0.0623, 0.0629, 0.0632],
         [0.0603, 0.0635, 0.0621,  ..., 0.0660, 0.0639, 0.0578],
         [0.0623, 0.0620, 0.0621,  ..., 0.0622, 0.0628, 0.0616],
         ...,
         [0.0634, 0.0628, 0.0619,  ..., 0.0647, 0.0652, 0.0594],
         [0.0617, 0.0629, 0.0624,  ..., 0.0637, 0.0657, 0.0580],
         [0.0618, 0.0615, 0.0623,  ..., 0.0637, 0.0630, 0.0640]],

        [[0.0632, 0.0619, 0.0633,  ..., 0.0620, 0.0614, 0.0636],
         [0.0624, 0.0604, 0.0621,  ..., 0.0641, 0.0620, 0.0632],
         [0.0643, 0.0621, 0.0611,  ..., 0.0619, 0.0626, 0.

 46%|████▌     | 46/100 [06:30<07:25,  8.25s/it]

tensor([[[0.0632, 0.0586, 0.0649,  ..., 0.0620, 0.0642, 0.0594],
         [0.0624, 0.0571, 0.0659,  ..., 0.0625, 0.0642, 0.0603],
         [0.0623, 0.0581, 0.0654,  ..., 0.0620, 0.0645, 0.0598],
         ...,
         [0.0613, 0.0584, 0.0656,  ..., 0.0621, 0.0643, 0.0609],
         [0.0592, 0.0577, 0.0648,  ..., 0.0629, 0.0660, 0.0611],
         [0.0601, 0.0578, 0.0644,  ..., 0.0623, 0.0654, 0.0609]],

        [[0.0632, 0.0627, 0.0643,  ..., 0.0608, 0.0616, 0.0612],
         [0.0621, 0.0630, 0.0643,  ..., 0.0600, 0.0640, 0.0644],
         [0.0618, 0.0628, 0.0628,  ..., 0.0620, 0.0630, 0.0640],
         ...,
         [0.0601, 0.0641, 0.0642,  ..., 0.0620, 0.0638, 0.0641],
         [0.0604, 0.0649, 0.0639,  ..., 0.0620, 0.0643, 0.0651],
         [0.0598, 0.0652, 0.0626,  ..., 0.0613, 0.0633, 0.0670]],

        [[0.0617, 0.0617, 0.0638,  ..., 0.0630, 0.0639, 0.0608],
         [0.0631, 0.0613, 0.0639,  ..., 0.0609, 0.0623, 0.0621],
         [0.0628, 0.0623, 0.0630,  ..., 0.0617, 0.0625, 0.

 47%|████▋     | 47/100 [06:37<07:03,  7.99s/it]

tensor([[[0.0608, 0.0618, 0.0586,  ..., 0.0649, 0.0633, 0.0651],
         [0.0599, 0.0623, 0.0594,  ..., 0.0652, 0.0607, 0.0647],
         [0.0597, 0.0615, 0.0599,  ..., 0.0650, 0.0608, 0.0650],
         ...,
         [0.0599, 0.0626, 0.0605,  ..., 0.0644, 0.0614, 0.0641],
         [0.0608, 0.0614, 0.0584,  ..., 0.0649, 0.0638, 0.0652],
         [0.0593, 0.0619, 0.0593,  ..., 0.0657, 0.0606, 0.0647]],

        [[0.0634, 0.0602, 0.0613,  ..., 0.0650, 0.0609, 0.0618],
         [0.0641, 0.0608, 0.0608,  ..., 0.0652, 0.0613, 0.0609],
         [0.0637, 0.0601, 0.0616,  ..., 0.0657, 0.0610, 0.0604],
         ...,
         [0.0624, 0.0606, 0.0617,  ..., 0.0644, 0.0602, 0.0628],
         [0.0622, 0.0608, 0.0626,  ..., 0.0651, 0.0594, 0.0615],
         [0.0620, 0.0604, 0.0615,  ..., 0.0636, 0.0607, 0.0627]],

        [[0.0627, 0.0616, 0.0622,  ..., 0.0624, 0.0628, 0.0644],
         [0.0622, 0.0622, 0.0608,  ..., 0.0626, 0.0638, 0.0623],
         [0.0642, 0.0618, 0.0613,  ..., 0.0617, 0.0615, 0.

 48%|████▊     | 48/100 [06:45<06:52,  7.93s/it]

tensor([[[0.0678, 0.0627, 0.0642,  ..., 0.0612, 0.0612, 0.0641],
         [0.0686, 0.0639, 0.0651,  ..., 0.0613, 0.0598, 0.0650],
         [0.0685, 0.0635, 0.0645,  ..., 0.0607, 0.0613, 0.0639],
         ...,
         [0.0697, 0.0631, 0.0653,  ..., 0.0615, 0.0599, 0.0647],
         [0.0672, 0.0630, 0.0648,  ..., 0.0616, 0.0620, 0.0639],
         [0.0678, 0.0632, 0.0644,  ..., 0.0618, 0.0604, 0.0640]],

        [[0.0629, 0.0616, 0.0624,  ..., 0.0629, 0.0643, 0.0618],
         [0.0632, 0.0618, 0.0615,  ..., 0.0627, 0.0624, 0.0649],
         [0.0638, 0.0622, 0.0605,  ..., 0.0649, 0.0641, 0.0631],
         ...,
         [0.0628, 0.0629, 0.0606,  ..., 0.0628, 0.0639, 0.0663],
         [0.0626, 0.0618, 0.0634,  ..., 0.0632, 0.0634, 0.0636],
         [0.0638, 0.0617, 0.0605,  ..., 0.0638, 0.0624, 0.0654]],

        [[0.0622, 0.0659, 0.0631,  ..., 0.0633, 0.0616, 0.0626],
         [0.0638, 0.0625, 0.0646,  ..., 0.0629, 0.0616, 0.0643],
         [0.0622, 0.0628, 0.0619,  ..., 0.0638, 0.0606, 0.

 49%|████▉     | 49/100 [06:52<06:41,  7.87s/it]

tensor([[[0.0665, 0.0593, 0.0674,  ..., 0.0608, 0.0599, 0.0615],
         [0.0660, 0.0596, 0.0668,  ..., 0.0608, 0.0617, 0.0613],
         [0.0665, 0.0598, 0.0677,  ..., 0.0611, 0.0608, 0.0617],
         ...,
         [0.0659, 0.0593, 0.0674,  ..., 0.0607, 0.0613, 0.0612],
         [0.0660, 0.0607, 0.0677,  ..., 0.0614, 0.0626, 0.0615],
         [0.0662, 0.0596, 0.0675,  ..., 0.0613, 0.0606, 0.0614]],

        [[0.0632, 0.0639, 0.0628,  ..., 0.0604, 0.0632, 0.0608],
         [0.0635, 0.0630, 0.0625,  ..., 0.0628, 0.0643, 0.0603],
         [0.0632, 0.0640, 0.0626,  ..., 0.0626, 0.0643, 0.0621],
         ...,
         [0.0640, 0.0639, 0.0618,  ..., 0.0615, 0.0632, 0.0628],
         [0.0638, 0.0659, 0.0635,  ..., 0.0610, 0.0620, 0.0617],
         [0.0615, 0.0644, 0.0632,  ..., 0.0631, 0.0645, 0.0618]],

        [[0.0627, 0.0635, 0.0637,  ..., 0.0648, 0.0639, 0.0625],
         [0.0628, 0.0624, 0.0638,  ..., 0.0633, 0.0628, 0.0626],
         [0.0631, 0.0617, 0.0640,  ..., 0.0630, 0.0635, 0.

 50%|█████     | 50/100 [07:01<06:36,  7.94s/it]

tensor([[[0.0580, 0.0592, 0.0687,  ..., 0.0635, 0.0604, 0.0632],
         [0.0583, 0.0596, 0.0677,  ..., 0.0626, 0.0627, 0.0630],
         [0.0597, 0.0600, 0.0670,  ..., 0.0635, 0.0615, 0.0637],
         ...,
         [0.0588, 0.0597, 0.0682,  ..., 0.0640, 0.0607, 0.0639],
         [0.0599, 0.0608, 0.0658,  ..., 0.0624, 0.0625, 0.0626],
         [0.0587, 0.0598, 0.0679,  ..., 0.0632, 0.0614, 0.0628]],

        [[0.0628, 0.0633, 0.0630,  ..., 0.0620, 0.0618, 0.0608],
         [0.0617, 0.0625, 0.0636,  ..., 0.0625, 0.0631, 0.0608],
         [0.0621, 0.0624, 0.0636,  ..., 0.0624, 0.0627, 0.0616],
         ...,
         [0.0633, 0.0619, 0.0632,  ..., 0.0618, 0.0629, 0.0620],
         [0.0630, 0.0643, 0.0627,  ..., 0.0627, 0.0617, 0.0626],
         [0.0632, 0.0627, 0.0626,  ..., 0.0627, 0.0628, 0.0631]],

        [[0.0625, 0.0623, 0.0613,  ..., 0.0654, 0.0621, 0.0617],
         [0.0635, 0.0627, 0.0626,  ..., 0.0643, 0.0616, 0.0613],
         [0.0637, 0.0626, 0.0624,  ..., 0.0623, 0.0617, 0.

 51%|█████     | 51/100 [07:08<06:29,  7.95s/it]

tensor([[[0.0634, 0.0628, 0.0644,  ..., 0.0607, 0.0644, 0.0619],
         [0.0637, 0.0637, 0.0635,  ..., 0.0600, 0.0650, 0.0618],
         [0.0642, 0.0623, 0.0647,  ..., 0.0612, 0.0633, 0.0612],
         ...,
         [0.0642, 0.0628, 0.0647,  ..., 0.0597, 0.0640, 0.0614],
         [0.0639, 0.0630, 0.0637,  ..., 0.0601, 0.0641, 0.0610],
         [0.0639, 0.0632, 0.0637,  ..., 0.0597, 0.0640, 0.0617]],

        [[0.0607, 0.0645, 0.0611,  ..., 0.0606, 0.0639, 0.0656],
         [0.0620, 0.0628, 0.0607,  ..., 0.0593, 0.0643, 0.0659],
         [0.0590, 0.0622, 0.0651,  ..., 0.0611, 0.0607, 0.0630],
         ...,
         [0.0601, 0.0635, 0.0620,  ..., 0.0601, 0.0627, 0.0653],
         [0.0615, 0.0639, 0.0607,  ..., 0.0591, 0.0648, 0.0656],
         [0.0613, 0.0636, 0.0606,  ..., 0.0587, 0.0635, 0.0657]],

        [[0.0612, 0.0613, 0.0649,  ..., 0.0635, 0.0619, 0.0635],
         [0.0621, 0.0637, 0.0636,  ..., 0.0626, 0.0613, 0.0601],
         [0.0612, 0.0596, 0.0640,  ..., 0.0617, 0.0613, 0.

 52%|█████▏    | 52/100 [07:16<06:08,  7.67s/it]

tensor([[[0.0659, 0.0669, 0.0652,  ..., 0.0614, 0.0594, 0.0596],
         [0.0662, 0.0677, 0.0655,  ..., 0.0616, 0.0593, 0.0587],
         [0.0651, 0.0667, 0.0653,  ..., 0.0621, 0.0594, 0.0595],
         ...,
         [0.0658, 0.0666, 0.0647,  ..., 0.0617, 0.0598, 0.0594],
         [0.0644, 0.0657, 0.0651,  ..., 0.0614, 0.0597, 0.0609],
         [0.0657, 0.0674, 0.0664,  ..., 0.0615, 0.0588, 0.0594]],

        [[0.0641, 0.0625, 0.0624,  ..., 0.0642, 0.0656, 0.0679],
         [0.0638, 0.0630, 0.0624,  ..., 0.0636, 0.0638, 0.0656],
         [0.0632, 0.0618, 0.0619,  ..., 0.0639, 0.0628, 0.0647],
         ...,
         [0.0631, 0.0615, 0.0623,  ..., 0.0641, 0.0653, 0.0662],
         [0.0625, 0.0625, 0.0628,  ..., 0.0637, 0.0647, 0.0651],
         [0.0637, 0.0622, 0.0628,  ..., 0.0642, 0.0640, 0.0671]],

        [[0.0632, 0.0646, 0.0628,  ..., 0.0614, 0.0619, 0.0638],
         [0.0630, 0.0679, 0.0620,  ..., 0.0618, 0.0611, 0.0630],
         [0.0618, 0.0635, 0.0619,  ..., 0.0620, 0.0630, 0.

 53%|█████▎    | 53/100 [07:23<05:55,  7.56s/it]

tensor([[[0.0653, 0.0677, 0.0620,  ..., 0.0644, 0.0641, 0.0620],
         [0.0654, 0.0667, 0.0621,  ..., 0.0643, 0.0644, 0.0622],
         [0.0645, 0.0644, 0.0637,  ..., 0.0642, 0.0629, 0.0625],
         ...,
         [0.0654, 0.0667, 0.0631,  ..., 0.0642, 0.0635, 0.0624],
         [0.0658, 0.0670, 0.0627,  ..., 0.0650, 0.0648, 0.0620],
         [0.0651, 0.0661, 0.0623,  ..., 0.0647, 0.0639, 0.0626]],

        [[0.0651, 0.0613, 0.0591,  ..., 0.0636, 0.0628, 0.0613],
         [0.0638, 0.0617, 0.0599,  ..., 0.0621, 0.0629, 0.0622],
         [0.0637, 0.0628, 0.0605,  ..., 0.0614, 0.0631, 0.0615],
         ...,
         [0.0636, 0.0617, 0.0609,  ..., 0.0609, 0.0622, 0.0628],
         [0.0646, 0.0618, 0.0588,  ..., 0.0617, 0.0631, 0.0621],
         [0.0646, 0.0615, 0.0590,  ..., 0.0620, 0.0632, 0.0622]],

        [[0.0657, 0.0633, 0.0613,  ..., 0.0625, 0.0625, 0.0646],
         [0.0636, 0.0634, 0.0620,  ..., 0.0622, 0.0629, 0.0637],
         [0.0664, 0.0643, 0.0617,  ..., 0.0627, 0.0614, 0.

 54%|█████▍    | 54/100 [07:31<05:50,  7.62s/it]

tensor([[[0.0617, 0.0608, 0.0616,  ..., 0.0628, 0.0618, 0.0613],
         [0.0613, 0.0617, 0.0620,  ..., 0.0627, 0.0627, 0.0606],
         [0.0607, 0.0621, 0.0612,  ..., 0.0626, 0.0621, 0.0607],
         ...,
         [0.0599, 0.0614, 0.0602,  ..., 0.0634, 0.0632, 0.0602],
         [0.0603, 0.0620, 0.0603,  ..., 0.0625, 0.0621, 0.0604],
         [0.0611, 0.0607, 0.0614,  ..., 0.0623, 0.0620, 0.0595]],

        [[0.0630, 0.0600, 0.0636,  ..., 0.0615, 0.0630, 0.0640],
         [0.0648, 0.0612, 0.0630,  ..., 0.0599, 0.0635, 0.0624],
         [0.0635, 0.0609, 0.0630,  ..., 0.0604, 0.0624, 0.0628],
         ...,
         [0.0629, 0.0607, 0.0621,  ..., 0.0628, 0.0637, 0.0630],
         [0.0618, 0.0623, 0.0625,  ..., 0.0610, 0.0626, 0.0630],
         [0.0636, 0.0614, 0.0633,  ..., 0.0626, 0.0633, 0.0630]],

        [[0.0628, 0.0589, 0.0638,  ..., 0.0633, 0.0636, 0.0622],
         [0.0638, 0.0617, 0.0626,  ..., 0.0637, 0.0638, 0.0620],
         [0.0624, 0.0608, 0.0639,  ..., 0.0632, 0.0633, 0.

 55%|█████▌    | 55/100 [07:39<05:49,  7.77s/it]

tensor([[[0.0585, 0.0628, 0.0585,  ..., 0.0622, 0.0577, 0.0631],
         [0.0576, 0.0633, 0.0590,  ..., 0.0623, 0.0589, 0.0618],
         [0.0588, 0.0630, 0.0590,  ..., 0.0626, 0.0574, 0.0631],
         ...,
         [0.0591, 0.0631, 0.0594,  ..., 0.0630, 0.0584, 0.0621],
         [0.0580, 0.0620, 0.0584,  ..., 0.0623, 0.0594, 0.0627],
         [0.0585, 0.0621, 0.0594,  ..., 0.0623, 0.0572, 0.0634]],

        [[0.0592, 0.0599, 0.0642,  ..., 0.0617, 0.0636, 0.0628],
         [0.0602, 0.0613, 0.0644,  ..., 0.0608, 0.0636, 0.0626],
         [0.0584, 0.0614, 0.0639,  ..., 0.0609, 0.0629, 0.0636],
         ...,
         [0.0596, 0.0614, 0.0657,  ..., 0.0608, 0.0640, 0.0635],
         [0.0580, 0.0600, 0.0633,  ..., 0.0620, 0.0647, 0.0615],
         [0.0625, 0.0618, 0.0645,  ..., 0.0601, 0.0630, 0.0647]],

        [[0.0598, 0.0628, 0.0606,  ..., 0.0631, 0.0632, 0.0640],
         [0.0612, 0.0606, 0.0609,  ..., 0.0626, 0.0624, 0.0651],
         [0.0615, 0.0646, 0.0618,  ..., 0.0634, 0.0616, 0.

 56%|█████▌    | 56/100 [07:46<05:39,  7.72s/it]

tensor([[[0.0608, 0.0612, 0.0624,  ..., 0.0631, 0.0639, 0.0605],
         [0.0600, 0.0617, 0.0624,  ..., 0.0636, 0.0633, 0.0611],
         [0.0612, 0.0607, 0.0616,  ..., 0.0633, 0.0634, 0.0614],
         ...,
         [0.0621, 0.0612, 0.0617,  ..., 0.0637, 0.0642, 0.0615],
         [0.0600, 0.0612, 0.0623,  ..., 0.0629, 0.0641, 0.0606],
         [0.0604, 0.0612, 0.0623,  ..., 0.0643, 0.0632, 0.0607]],

        [[0.0654, 0.0641, 0.0602,  ..., 0.0649, 0.0628, 0.0641],
         [0.0650, 0.0638, 0.0562,  ..., 0.0659, 0.0621, 0.0639],
         [0.0667, 0.0623, 0.0605,  ..., 0.0621, 0.0617, 0.0621],
         ...,
         [0.0637, 0.0639, 0.0586,  ..., 0.0651, 0.0629, 0.0634],
         [0.0638, 0.0625, 0.0594,  ..., 0.0642, 0.0628, 0.0627],
         [0.0647, 0.0635, 0.0573,  ..., 0.0659, 0.0620, 0.0636]],

        [[0.0646, 0.0599, 0.0611,  ..., 0.0626, 0.0623, 0.0607],
         [0.0628, 0.0610, 0.0626,  ..., 0.0639, 0.0616, 0.0610],
         [0.0646, 0.0608, 0.0603,  ..., 0.0626, 0.0623, 0.

 57%|█████▋    | 57/100 [07:55<05:42,  7.96s/it]

tensor([[[0.0601, 0.0640, 0.0598,  ..., 0.0621, 0.0628, 0.0660],
         [0.0620, 0.0640, 0.0615,  ..., 0.0617, 0.0624, 0.0641],
         [0.0603, 0.0636, 0.0600,  ..., 0.0623, 0.0616, 0.0666],
         ...,
         [0.0596, 0.0623, 0.0601,  ..., 0.0626, 0.0624, 0.0658],
         [0.0604, 0.0633, 0.0595,  ..., 0.0632, 0.0617, 0.0672],
         [0.0594, 0.0621, 0.0599,  ..., 0.0627, 0.0634, 0.0659]],

        [[0.0634, 0.0609, 0.0630,  ..., 0.0631, 0.0649, 0.0641],
         [0.0629, 0.0619, 0.0629,  ..., 0.0635, 0.0636, 0.0637],
         [0.0630, 0.0628, 0.0628,  ..., 0.0617, 0.0627, 0.0632],
         ...,
         [0.0650, 0.0617, 0.0620,  ..., 0.0630, 0.0635, 0.0642],
         [0.0643, 0.0614, 0.0619,  ..., 0.0636, 0.0640, 0.0655],
         [0.0637, 0.0616, 0.0624,  ..., 0.0625, 0.0640, 0.0641]],

        [[0.0609, 0.0623, 0.0622,  ..., 0.0625, 0.0617, 0.0636],
         [0.0612, 0.0641, 0.0629,  ..., 0.0630, 0.0622, 0.0630],
         [0.0610, 0.0620, 0.0622,  ..., 0.0627, 0.0634, 0.

 58%|█████▊    | 58/100 [08:01<05:14,  7.48s/it]

tensor([[[0.0625, 0.0633, 0.0612,  ..., 0.0643, 0.0619, 0.0602],
         [0.0637, 0.0636, 0.0613,  ..., 0.0644, 0.0619, 0.0600],
         [0.0634, 0.0628, 0.0614,  ..., 0.0637, 0.0624, 0.0606],
         ...,
         [0.0637, 0.0635, 0.0608,  ..., 0.0643, 0.0624, 0.0604],
         [0.0634, 0.0639, 0.0613,  ..., 0.0644, 0.0625, 0.0602],
         [0.0639, 0.0643, 0.0604,  ..., 0.0651, 0.0628, 0.0601]],

        [[0.0647, 0.0634, 0.0603,  ..., 0.0631, 0.0615, 0.0594],
         [0.0642, 0.0631, 0.0602,  ..., 0.0615, 0.0618, 0.0600],
         [0.0622, 0.0622, 0.0616,  ..., 0.0634, 0.0647, 0.0608],
         ...,
         [0.0627, 0.0621, 0.0606,  ..., 0.0632, 0.0640, 0.0605],
         [0.0614, 0.0625, 0.0607,  ..., 0.0625, 0.0646, 0.0628],
         [0.0615, 0.0624, 0.0607,  ..., 0.0633, 0.0642, 0.0616]],

        [[0.0637, 0.0614, 0.0616,  ..., 0.0640, 0.0639, 0.0621],
         [0.0630, 0.0619, 0.0643,  ..., 0.0622, 0.0632, 0.0634],
         [0.0652, 0.0624, 0.0628,  ..., 0.0622, 0.0623, 0.

 59%|█████▉    | 59/100 [08:09<05:07,  7.50s/it]

tensor([[[0.0614, 0.0628, 0.0578,  ..., 0.0632, 0.0650, 0.0615],
         [0.0607, 0.0627, 0.0576,  ..., 0.0635, 0.0655, 0.0628],
         [0.0618, 0.0625, 0.0577,  ..., 0.0637, 0.0651, 0.0623],
         ...,
         [0.0612, 0.0624, 0.0572,  ..., 0.0635, 0.0658, 0.0627],
         [0.0613, 0.0626, 0.0577,  ..., 0.0639, 0.0653, 0.0623],
         [0.0610, 0.0623, 0.0566,  ..., 0.0641, 0.0657, 0.0620]],

        [[0.0635, 0.0605, 0.0636,  ..., 0.0616, 0.0622, 0.0637],
         [0.0635, 0.0599, 0.0644,  ..., 0.0627, 0.0623, 0.0620],
         [0.0637, 0.0608, 0.0646,  ..., 0.0632, 0.0624, 0.0620],
         ...,
         [0.0625, 0.0616, 0.0629,  ..., 0.0637, 0.0624, 0.0615],
         [0.0663, 0.0593, 0.0633,  ..., 0.0622, 0.0632, 0.0622],
         [0.0641, 0.0623, 0.0625,  ..., 0.0655, 0.0610, 0.0616]],

        [[0.0613, 0.0640, 0.0616,  ..., 0.0623, 0.0617, 0.0627],
         [0.0613, 0.0614, 0.0612,  ..., 0.0635, 0.0620, 0.0635],
         [0.0594, 0.0637, 0.0609,  ..., 0.0648, 0.0617, 0.

 60%|██████    | 60/100 [08:18<05:20,  8.01s/it]

tensor([[[0.0645, 0.0624, 0.0657,  ..., 0.0586, 0.0628, 0.0633],
         [0.0632, 0.0634, 0.0653,  ..., 0.0595, 0.0625, 0.0641],
         [0.0635, 0.0631, 0.0651,  ..., 0.0596, 0.0629, 0.0641],
         ...,
         [0.0635, 0.0631, 0.0661,  ..., 0.0587, 0.0632, 0.0635],
         [0.0652, 0.0635, 0.0660,  ..., 0.0595, 0.0626, 0.0632],
         [0.0653, 0.0634, 0.0665,  ..., 0.0594, 0.0620, 0.0634]],

        [[0.0624, 0.0607, 0.0628,  ..., 0.0622, 0.0638, 0.0633],
         [0.0631, 0.0597, 0.0622,  ..., 0.0625, 0.0647, 0.0621],
         [0.0643, 0.0604, 0.0632,  ..., 0.0628, 0.0639, 0.0633],
         ...,
         [0.0639, 0.0594, 0.0618,  ..., 0.0635, 0.0646, 0.0633],
         [0.0641, 0.0591, 0.0619,  ..., 0.0630, 0.0648, 0.0627],
         [0.0643, 0.0570, 0.0627,  ..., 0.0626, 0.0658, 0.0631]],

        [[0.0612, 0.0628, 0.0627,  ..., 0.0629, 0.0649, 0.0627],
         [0.0623, 0.0612, 0.0637,  ..., 0.0610, 0.0655, 0.0618],
         [0.0618, 0.0630, 0.0636,  ..., 0.0618, 0.0641, 0.

In [34]:
model.eval()
for input_ids, attention_masks, labels in dataloader_val:
    input_ids = input_ids.to(device)
    attention_masks = attention_masks.to(device)
    labels = labels.to(device)
    
    optimizer.zero_grad()
    outputs, _ = model(input_ids=input_ids, attention_mask=attention_masks)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
        
    print(model.decoder.attn_output_weights)

torch.Size([16, 245, 768])
tensor([[-0.3105, -0.2510, -0.2908],
        [-0.1651,  0.1420, -0.2014],
        [-0.0106,  0.0482, -0.3460],
        [-0.0115,  0.1788, -0.0013],
        [ 0.1326, -0.1570, -0.2386],
        [-0.4345, -0.4241, -0.3028],
        [-0.5563, -0.3907, -0.4480],
        [-0.2604,  0.0235, -0.4607],
        [-0.2776,  0.1797, -0.2972],
        [-0.6022, -0.0675, -0.2426],
        [-0.1411, -0.0051, -0.2707],
        [-0.4610, -0.3146, -0.3775],
        [-0.3541, -0.5927, -0.0407],
        [-0.5542, -0.2998, -0.4390],
        [-0.5387,  0.0100,  0.0196],
        [ 0.1234, -0.0463,  0.3524]])
tensor([[[0.0614, 0.0625, 0.0624,  ..., 0.0628, 0.0622, 0.0627],
         [0.0613, 0.0625, 0.0622,  ..., 0.0630, 0.0622, 0.0630],
         [0.0612, 0.0626, 0.0624,  ..., 0.0629, 0.0619, 0.0632],
         ...,
         [0.0613, 0.0625, 0.0624,  ..., 0.0627, 0.0622, 0.0630],
         [0.0615, 0.0624, 0.0621,  ..., 0.0627, 0.0619, 0.0631],
         [0.0616, 0.0624, 0.0621,  ..., 0


KeyboardInterrupt



In [None]:
for epoch in range(max_num_epochs):
    if (early_stopping_count >= max_early_stopping):
        break
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
            early_stopping_count += 1
        else:
            model.eval()
        
        curr_ce = 0
        curr_accuracy = 0
        actual = torch.tensor([]).long().to(device)
        pred = torch.tensor([]).long().to(device)

        for input_ids, attention_masks, labels in dataloaders_dict[phase]:
            input_ids = input_ids.to(device)
            attention_masks = attention_masks.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(input_ids = input_ids, attention_mask = attention_masks, labels=labels)
                loss = outputs.loss
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                else:
                    curr_ce += loss.item() * input_ids.size(0)
                    curr_accuracy += torch.sum(torch.max(outputs.logits, 1)[1] == labels).item()
                    actual = torch.cat([actual, labels], dim=0)
                    pred= torch.cat([pred, torch.max(outputs.logits, 1)[1]], dim=0)
        if phase == 'val':
            curr_ce = curr_ce / len(val)
            curr_accuracy = curr_accuracy / len(val)
            currF1 = f1_score(actual.cpu().detach().numpy(), pred.cpu().detach().numpy(), average='weighted')
            if curr_ce <= best_ce - eps:
                best_ce = curr_ce
                early_stopping_count = 0
            if curr_accuracy >= best_accuracy + eps:
                best_accuracy = curr_accuracy
                early_stopping_count = 0
            if currF1 >= best_f1 + eps:
                best_f1 = currF1
                early_stopping_count = 0
            print("Val CE: ", curr_ce)
            print("Val Accuracy: ", curr_accuracy)
            print("Val F1: ", currF1)
            print("Early Stopping Count: ", early_stopping_count)
 

In [None]:
# save model
if save_model_path != None:
    model.save_pretrained(save_model_path)
    tokenizer.save_pretrained(save_model_path)

In [None]:
# Il faut faire un plt.imshow pour visualiser les poids d'attention facilement