In [1]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer
from tqdm import tqdm
import torch.nn.functional as F

### Utility function

In [2]:
def preprocessing(arguments, key_points, labels):
    columns = ['topic', 'premise', 'hypothesis', 'label']
    data = []
    for index, row in labels.iterrows():
        argument_row = arguments.loc[arguments['arg_id'] == row['arg_id']]
        topic = argument_row['topic'].iloc[0]
        premise = argument_row['argument'].iloc[0]
        kp_row = key_points.loc[key_points['key_point_id'] == row['key_point_id']]
        hypothesis = kp_row['key_point'].iloc[0]
        label = row['label']
        data.append([topic, premise, hypothesis, label])

    result = pd.DataFrame(data, columns=columns)
    return result

### Create Torch Dataset

In [3]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, text_df, tokenizer, max_length):
        self.premises = list(text_df['premise'].values+tokenizer.sep_token+text_df['topic'].values)
        self.hypotheses = list(text_df['hypothesis'].values+tokenizer.sep_token+text_df['topic'].values)
        self.labels = list(text_df['label'])
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.premises)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.premises[idx],
            self.hypotheses[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        item = {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }
        return item


### Define the LightningModule

In [4]:
class KPM(pl.LightningModule):
    def __init__(self, model, learning_rate=2e-5, weight_decay=0.001):
        super(KPM, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return outputs

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
        outputs = self(input_ids, attention_mask, labels)
        print("--------------training-------------------------------")
        print("Shape: "+ str(outputs.logits.shape))
        print(outputs.logits)
        self.log("train_loss", torch.clone(outputs.loss).detach())
        one_hot_labels = torch.stack([1 - labels, labels], dim=-1)
        one_hot_labels = one_hot_labels.float()
        # loss = outputs.loss
        loss = F.binary_cross_entropy_with_logits(outputs.logits, one_hot_labels, reduction='mean')
        self.log("train_loss", torch.clone(loss).detach())
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
        outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", torch.clone(outputs.loss).detach())
        one_hot_labels = torch.stack([1 - labels, labels], dim=-1)
        one_hot_labels = one_hot_labels.float()
        # loss = outputs.loss
        loss = F.binary_cross_entropy_with_logits(outputs.logits, one_hot_labels, reduction='mean')
        self.log("val_loss", torch.clone(loss).detach())
        return  loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

### Import Dataset and Preprocessing

In [5]:
train_argument_df = pd.read_csv('./data/arguments_train.csv')
train_kp_df = pd.read_csv('./data/key_points_train.csv')
train_label_df = pd.read_csv('./data/labels_train.csv')
train_df = preprocessing(train_argument_df, train_kp_df, train_label_df)


val_argument_df = pd.read_csv('./data/arguments_dev.csv')
val_kp_df = pd.read_csv('./data/key_points_dev.csv')
val_label_df = pd.read_csv('./data/labels_dev.csv')
val_df = preprocessing(val_argument_df, val_kp_df, val_label_df)

test_argument_df = pd.read_csv('./data/arguments_test.csv')
test_kp_df = pd.read_csv('./data/key_points_test.csv')
test_label_df = pd.read_csv('./data/labels_test.csv')
test_df = preprocessing(test_argument_df, test_kp_df, test_label_df)

### Training phase

In [1]:
model_name = "cross-encoder/nli-distilroberta-base"
num_classes = 2
max_length = 512
batch_size = 16
learning_rate = 5e-05
weight_decay = 0.001

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels = num_classes, ignore_mismatched_sizes = True)
# model = AutoModelForSequenceClassification.from_pretrained(model_name)
# model.classifier = torch.nn.Linear(model.config.hidden_size, num_classes)
# Change the loss function to binary cross-entropy

# Get the vocabulary
vocab = tokenizer.get_vocab()

# Check if the "CLS" token is present in the vocabulary
cls_token_present = tokenizer.cls_token in vocab
sep_token_present = tokenizer.sep_token in vocab
print("Is CLS token present in the vocabulary?", cls_token_present)
print("Is SEP token present in the vocabulary?", sep_token_present)

train_dataset = Dataset(train_df, tokenizer, max_length)
val_dataset = Dataset(val_df, tokenizer, max_length)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = KPM(model, learning_rate)

# Define a ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoint',
    filename='nli_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min'
)

early_stopping =EarlyStopping(
    monitor="val_loss",
    min_delta=0.01,
    patience=3
)

# define trainer
trainer = Trainer(
    min_epochs = 0, # change this
    max_epochs = 20,
    callbacks=[checkpoint_callback, early_stopping],
    accelerator="auto",
    #progress_bar_refresh_rate=30,
    #gpus = 1 if device.type == 'cuda' else 0
    devices = 1 if torch.cuda.is_available() else None
)

torch.cuda.empty_cache()

# start training
trainer.fit(model, train_loader, val_loader)
print("Training is finished")

NameError: name 'AutoTokenizer' is not defined

### Prediction and Evaluation

In [8]:
# Assuming you have already trained the model and saved a checkpoint
checkpoint_path = './checkpoint/nli_model-epoch=01-val_loss=0.37.ckpt'
# Load the model from the checkpoint
model_name = "cross-encoder/nli-distilroberta-base"
num_classes = 2
max_length = 512
batch_size = 16
learning_rate = 5e-05
weight_decay = 0.001

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels = num_classes, ignore_mismatched_sizes = True)
loaded_model = KPM.load_from_checkpoint(model=model, checkpoint_path=checkpoint_path)
print("hello world")

# Set the model to evaluation mode
loaded_model.eval()

test_dataset = Dataset(test_df, tokenizer, max_length)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Make predictions on the test set with probabilities
all_predictions = []
all_probabilities = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting"):
        input_ids, attention_mask = batch['input_ids'], batch['attention_mask']
        outputs = loaded_model(input_ids, attention_mask)
        logits = outputs.logits
        # print(logits)
        probabilities = F.softmax(logits, dim=1)
        predictions = torch.argmax(logits, dim=1).tolist()
        # print(predictions)
        print(probabilities)
        all_predictions.extend(predictions)
        all_probabilities.extend(probabilities.tolist())


print(all_predictions)
print(all_probabilities)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at cross-encoder/nli-distilroberta-base and are newly initialized because the shapes did not match:
- classifier.out_proj.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


hello world


Predicting:   0%|          | 1/215 [00:08<30:54,  8.66s/it]

tensor([[8.5138e-01, 1.4862e-01],
        [9.5304e-01, 4.6955e-02],
        [1.8656e-02, 9.8134e-01],
        [9.6801e-05, 9.9990e-01],
        [4.9789e-02, 9.5021e-01],
        [9.8562e-01, 1.4378e-02],
        [9.9964e-01, 3.6004e-04],
        [7.0038e-01, 2.9962e-01],
        [9.7845e-01, 2.1546e-02],
        [9.4093e-05, 9.9991e-01],
        [5.4960e-01, 4.5040e-01],
        [1.6561e-02, 9.8344e-01],
        [9.8758e-01, 1.2423e-02],
        [9.4357e-01, 5.6434e-02],
        [9.1618e-05, 9.9991e-01],
        [8.9708e-01, 1.0292e-01]])


Predicting:   1%|          | 2/215 [00:17<30:35,  8.62s/it]

tensor([[6.7236e-03, 9.9328e-01],
        [5.3779e-01, 4.6221e-01],
        [9.7529e-01, 2.4710e-02],
        [2.2053e-02, 9.7795e-01],
        [1.0838e-04, 9.9989e-01],
        [6.5526e-01, 3.4474e-01],
        [3.8252e-02, 9.6175e-01],
        [9.9896e-01, 1.0414e-03],
        [9.9976e-01, 2.3718e-04],
        [3.0909e-02, 9.6909e-01],
        [2.2005e-01, 7.7995e-01],
        [9.9688e-01, 3.1177e-03],
        [9.9962e-01, 3.7946e-04],
        [2.4663e-01, 7.5337e-01],
        [2.9282e-02, 9.7072e-01],
        [9.9918e-01, 8.1639e-04]])


Predicting:   1%|▏         | 3/215 [00:26<31:07,  8.81s/it]

tensor([[9.9992e-01, 8.4867e-05],
        [3.4080e-01, 6.5920e-01],
        [9.9291e-01, 7.0858e-03],
        [3.7760e-01, 6.2240e-01],
        [5.7273e-01, 4.2727e-01],
        [9.0574e-01, 9.4261e-02],
        [9.9873e-01, 1.2651e-03],
        [6.9098e-01, 3.0902e-01],
        [9.9807e-01, 1.9325e-03],
        [9.9972e-01, 2.7894e-04],
        [9.9972e-01, 2.7996e-04],
        [9.5506e-01, 4.4941e-02],
        [7.7821e-01, 2.2179e-01],
        [5.9761e-01, 4.0239e-01],
        [9.8994e-01, 1.0056e-02],
        [6.9624e-01, 3.0376e-01]])


Predicting:   2%|▏         | 4/215 [00:35<31:22,  8.92s/it]

tensor([[0.2280, 0.7720],
        [0.8114, 0.1886],
        [0.8993, 0.1007],
        [0.4028, 0.5972],
        [0.6113, 0.3887],
        [0.9486, 0.0514],
        [0.0100, 0.9900],
        [0.9903, 0.0097],
        [0.7484, 0.2516],
        [0.6522, 0.3478],
        [0.9667, 0.0333],
        [0.9774, 0.0226],
        [0.1717, 0.8283],
        [0.0027, 0.9973],
        [0.9744, 0.0256],
        [0.9932, 0.0068]])


Predicting:   2%|▏         | 5/215 [00:44<31:20,  8.95s/it]

tensor([[7.0299e-01, 2.9701e-01],
        [9.5924e-01, 4.0762e-02],
        [9.4397e-01, 5.6031e-02],
        [2.2782e-01, 7.7218e-01],
        [7.2584e-01, 2.7416e-01],
        [2.1607e-01, 7.8393e-01],
        [9.8230e-01, 1.7695e-02],
        [6.1005e-01, 3.8995e-01],
        [9.6112e-01, 3.8878e-02],
        [9.9412e-01, 5.8772e-03],
        [9.9790e-01, 2.1039e-03],
        [1.1014e-04, 9.9989e-01],
        [3.0621e-02, 9.6938e-01],
        [9.3580e-02, 9.0642e-01],
        [9.9972e-01, 2.8334e-04],
        [9.7168e-01, 2.8317e-02]])


Predicting:   3%|▎         | 6/215 [00:53<30:51,  8.86s/it]

tensor([[9.9942e-01, 5.8498e-04],
        [1.5211e-02, 9.8479e-01],
        [9.4825e-01, 5.1755e-02],
        [1.1681e-04, 9.9988e-01],
        [9.7332e-01, 2.6682e-02],
        [3.1843e-02, 9.6816e-01],
        [7.9749e-02, 9.2025e-01],
        [4.5767e-01, 5.4233e-01],
        [1.0020e-04, 9.9990e-01],
        [1.1778e-04, 9.9988e-01],
        [4.3869e-04, 9.9956e-01],
        [9.4572e-01, 5.4282e-02],
        [9.6092e-01, 3.9079e-02],
        [9.9945e-01, 5.5292e-04],
        [7.9149e-01, 2.0851e-01],
        [1.3176e-04, 9.9987e-01]])


Predicting:   3%|▎         | 7/215 [01:01<30:05,  8.68s/it]

tensor([[6.5640e-01, 3.4360e-01],
        [4.0177e-01, 5.9823e-01],
        [1.0558e-02, 9.8944e-01],
        [9.6762e-01, 3.2378e-02],
        [8.9331e-01, 1.0669e-01],
        [9.7605e-01, 2.3948e-02],
        [1.1413e-03, 9.9886e-01],
        [3.4854e-04, 9.9965e-01],
        [9.9859e-01, 1.4121e-03],
        [9.9986e-01, 1.3503e-04],
        [9.9815e-01, 1.8461e-03],
        [9.7598e-01, 2.4021e-02],
        [9.9514e-01, 4.8618e-03],
        [8.9313e-01, 1.0687e-01],
        [9.9681e-01, 3.1947e-03],
        [8.4308e-01, 1.5692e-01]])


Predicting:   4%|▎         | 8/215 [01:10<29:53,  8.66s/it]

tensor([[3.4053e-03, 9.9659e-01],
        [7.7793e-01, 2.2207e-01],
        [4.7364e-01, 5.2636e-01],
        [4.2829e-04, 9.9957e-01],
        [8.5148e-01, 1.4852e-01],
        [3.8420e-01, 6.1580e-01],
        [9.2104e-04, 9.9908e-01],
        [7.6997e-01, 2.3003e-01],
        [5.0201e-01, 4.9799e-01],
        [9.6394e-01, 3.6059e-02],
        [7.0426e-01, 2.9574e-01],
        [9.9460e-01, 5.4027e-03],
        [9.0156e-01, 9.8440e-02],
        [1.0500e-04, 9.9989e-01],
        [9.5439e-01, 4.5613e-02],
        [1.2916e-04, 9.9987e-01]])


Predicting:   4%|▍         | 9/215 [01:18<29:43,  8.66s/it]

tensor([[9.8532e-01, 1.4685e-02],
        [3.3557e-02, 9.6644e-01],
        [1.0144e-04, 9.9990e-01],
        [9.8718e-01, 1.2821e-02],
        [1.0956e-04, 9.9989e-01],
        [9.7814e-01, 2.1857e-02],
        [2.2061e-04, 9.9978e-01],
        [9.7826e-01, 2.1741e-02],
        [9.5286e-05, 9.9990e-01],
        [9.9161e-01, 8.3888e-03],
        [2.6652e-03, 9.9733e-01],
        [1.0481e-04, 9.9990e-01],
        [9.8381e-01, 1.6194e-02],
        [9.9935e-01, 6.4805e-04],
        [9.9740e-01, 2.6019e-03],
        [9.8068e-01, 1.9321e-02]])


Predicting:   5%|▍         | 10/215 [01:27<30:12,  8.84s/it]

tensor([[9.9984e-01, 1.5841e-04],
        [4.1447e-02, 9.5855e-01],
        [5.2917e-01, 4.7083e-01],
        [9.7015e-05, 9.9990e-01],
        [9.9976e-01, 2.3498e-04],
        [9.9992e-01, 7.6310e-05],
        [2.5072e-01, 7.4928e-01],
        [9.9947e-01, 5.3200e-04],
        [9.9984e-01, 1.6241e-04],
        [9.6987e-01, 3.0132e-02],
        [9.9803e-01, 1.9669e-03],
        [2.4339e-01, 7.5661e-01],
        [4.1281e-01, 5.8719e-01],
        [1.0832e-01, 8.9168e-01],
        [9.9901e-01, 9.9356e-04],
        [9.9966e-01, 3.4383e-04]])


Predicting:   5%|▌         | 11/215 [01:36<29:29,  8.67s/it]

tensor([[7.7159e-01, 2.2841e-01],
        [1.5054e-04, 9.9985e-01],
        [9.0481e-01, 9.5188e-02],
        [9.9058e-01, 9.4224e-03],
        [9.7326e-03, 9.9027e-01],
        [9.9883e-01, 1.1713e-03],
        [9.9987e-01, 1.2858e-04],
        [9.8761e-01, 1.2392e-02],
        [6.3166e-01, 3.6834e-01],
        [1.0109e-04, 9.9990e-01],
        [9.1279e-01, 8.7210e-02],
        [5.2709e-03, 9.9473e-01],
        [9.7276e-05, 9.9990e-01],
        [9.5241e-01, 4.7590e-02],
        [1.8199e-04, 9.9982e-01],
        [9.9972e-01, 2.8190e-04]])


Predicting:   6%|▌         | 12/215 [01:45<29:31,  8.73s/it]

tensor([[7.9193e-01, 2.0807e-01],
        [7.1130e-01, 2.8870e-01],
        [1.1039e-04, 9.9989e-01],
        [9.5515e-01, 4.4852e-02],
        [6.0760e-01, 3.9240e-01],
        [6.5782e-03, 9.9342e-01],
        [8.1839e-01, 1.8161e-01],
        [1.6458e-04, 9.9984e-01],
        [7.3700e-01, 2.6300e-01],
        [1.2281e-01, 8.7719e-01],
        [1.2125e-01, 8.7875e-01],
        [9.6328e-05, 9.9990e-01],
        [6.1927e-02, 9.3807e-01],
        [2.1752e-02, 9.7825e-01],
        [3.9193e-01, 6.0807e-01],
        [2.3790e-04, 9.9976e-01]])


Predicting:   6%|▌         | 13/215 [01:53<28:46,  8.55s/it]

tensor([[9.9129e-01, 8.7142e-03],
        [9.9967e-01, 3.2711e-04],
        [9.9917e-01, 8.3206e-04],
        [9.6419e-01, 3.5808e-02],
        [9.9974e-01, 2.6319e-04],
        [3.9469e-02, 9.6053e-01],
        [9.9880e-01, 1.2022e-03],
        [9.9759e-01, 2.4150e-03],
        [9.9920e-01, 8.0200e-04],
        [6.7923e-01, 3.2077e-01],
        [8.0091e-02, 9.1991e-01],
        [9.9986e-01, 1.4270e-04],
        [9.9994e-01, 6.0218e-05],
        [9.4810e-01, 5.1901e-02],
        [1.2177e-04, 9.9988e-01],
        [9.9840e-01, 1.6038e-03]])


Predicting:   7%|▋         | 14/215 [02:02<29:40,  8.86s/it]

tensor([[9.9892e-01, 1.0834e-03],
        [9.9647e-01, 3.5273e-03],
        [2.0777e-01, 7.9223e-01],
        [9.1631e-05, 9.9991e-01],
        [9.9973e-01, 2.7326e-04],
        [9.9986e-01, 1.3922e-04],
        [8.7451e-01, 1.2549e-01],
        [1.4433e-04, 9.9986e-01],
        [1.1395e-04, 9.9989e-01],
        [9.8410e-01, 1.5897e-02],
        [8.9153e-01, 1.0847e-01],
        [9.9946e-01, 5.4046e-04],
        [9.9329e-01, 6.7100e-03],
        [1.7247e-01, 8.2753e-01],
        [3.1708e-01, 6.8292e-01],
        [9.1424e-01, 8.5763e-02]])


Predicting:   7%|▋         | 15/215 [02:11<29:40,  8.90s/it]

tensor([[9.9430e-01, 5.6983e-03],
        [1.9857e-01, 8.0143e-01],
        [9.8697e-01, 1.3027e-02],
        [8.1682e-01, 1.8318e-01],
        [7.3191e-01, 2.6809e-01],
        [1.3151e-03, 9.9868e-01],
        [9.9556e-01, 4.4396e-03],
        [9.9255e-01, 7.4452e-03],
        [9.9977e-01, 2.2558e-04],
        [5.7370e-01, 4.2630e-01],
        [9.9899e-01, 1.0067e-03],
        [9.9625e-01, 3.7464e-03],
        [9.9785e-01, 2.1490e-03],
        [5.1614e-01, 4.8386e-01],
        [1.5223e-04, 9.9985e-01],
        [5.4276e-01, 4.5724e-01]])


Predicting:   7%|▋         | 16/215 [02:20<29:36,  8.93s/it]

tensor([[3.1203e-01, 6.8797e-01],
        [3.9040e-01, 6.0960e-01],
        [6.6120e-01, 3.3880e-01],
        [1.2655e-01, 8.7345e-01],
        [1.5724e-01, 8.4276e-01],
        [9.9959e-01, 4.0985e-04],
        [9.9972e-01, 2.8126e-04],
        [8.6586e-02, 9.1341e-01],
        [1.3377e-01, 8.6623e-01],
        [2.6054e-01, 7.3946e-01],
        [2.9094e-01, 7.0906e-01],
        [3.6748e-01, 6.3252e-01],
        [1.7147e-02, 9.8285e-01],
        [9.4147e-02, 9.0585e-01],
        [4.3615e-01, 5.6385e-01],
        [6.7107e-01, 3.2893e-01]])


Predicting:   8%|▊         | 17/215 [02:29<29:32,  8.95s/it]

tensor([[8.8900e-01, 1.1100e-01],
        [7.4621e-02, 9.2538e-01],
        [3.0110e-01, 6.9890e-01],
        [1.1941e-04, 9.9988e-01],
        [2.2215e-01, 7.7785e-01],
        [9.9827e-01, 1.7251e-03],
        [9.9943e-01, 5.6961e-04],
        [5.2456e-01, 4.7544e-01],
        [5.0369e-01, 4.9631e-01],
        [5.5773e-01, 4.4227e-01],
        [5.4929e-01, 4.5071e-01],
        [9.4310e-01, 5.6900e-02],
        [2.1819e-01, 7.8181e-01],
        [5.9113e-01, 4.0887e-01],
        [7.5476e-01, 2.4524e-01],
        [3.8888e-01, 6.1112e-01]])


Predicting:   8%|▊         | 18/215 [02:39<29:43,  9.05s/it]

tensor([[2.5160e-01, 7.4840e-01],
        [4.8565e-01, 5.1435e-01],
        [3.6476e-02, 9.6352e-01],
        [2.7358e-02, 9.7264e-01],
        [9.9946e-01, 5.3503e-04],
        [9.9993e-01, 6.7380e-05],
        [5.4790e-01, 4.5210e-01],
        [8.2508e-01, 1.7492e-01],
        [2.1207e-02, 9.7879e-01],
        [6.1464e-01, 3.8536e-01],
        [5.8133e-01, 4.1867e-01],
        [4.3811e-01, 5.6189e-01],
        [5.7346e-01, 4.2654e-01],
        [8.1270e-01, 1.8730e-01],
        [5.0997e-02, 9.4900e-01],
        [9.9540e-01, 4.6016e-03]])


Predicting:   9%|▉         | 19/215 [02:48<30:05,  9.21s/it]

tensor([[9.5813e-01, 4.1867e-02],
        [9.5385e-01, 4.6151e-02],
        [9.9807e-01, 1.9313e-03],
        [9.9959e-01, 4.0720e-04],
        [8.5990e-01, 1.4010e-01],
        [8.2518e-01, 1.7482e-01],
        [6.5994e-02, 9.3401e-01],
        [8.2434e-01, 1.7566e-01],
        [5.2902e-01, 4.7098e-01],
        [9.1469e-01, 8.5310e-02],
        [9.7909e-01, 2.0910e-02],
        [9.9046e-01, 9.5418e-03],
        [9.9748e-01, 2.5151e-03],
        [6.0460e-01, 3.9540e-01],
        [3.2583e-03, 9.9674e-01],
        [9.9576e-01, 4.2432e-03]])


Predicting:   9%|▉         | 20/215 [02:57<29:20,  9.03s/it]

tensor([[7.0886e-01, 2.9114e-01],
        [9.9581e-01, 4.1919e-03],
        [9.9883e-01, 1.1707e-03],
        [9.4904e-01, 5.0959e-02],
        [5.6577e-03, 9.9434e-01],
        [9.9452e-01, 5.4777e-03],
        [8.2669e-01, 1.7331e-01],
        [8.5157e-01, 1.4843e-01],
        [9.9444e-01, 5.5567e-03],
        [9.9747e-01, 2.5349e-03],
        [9.9779e-01, 2.2142e-03],
        [9.9965e-01, 3.5355e-04],
        [7.9258e-01, 2.0742e-01],
        [3.3301e-04, 9.9967e-01],
        [9.8971e-01, 1.0286e-02],
        [6.5161e-02, 9.3484e-01]])


Predicting:  10%|▉         | 21/215 [03:06<29:05,  9.00s/it]

tensor([[9.2211e-01, 7.7886e-02],
        [9.1174e-01, 8.8256e-02],
        [8.7824e-01, 1.2176e-01],
        [9.8008e-01, 1.9924e-02],
        [9.8634e-01, 1.3661e-02],
        [9.9951e-01, 4.8593e-04],
        [9.9983e-01, 1.7292e-04],
        [9.2792e-01, 7.2079e-02],
        [1.2393e-04, 9.9988e-01],
        [9.7074e-01, 2.9260e-02],
        [5.2175e-01, 4.7824e-01],
        [9.9929e-01, 7.1156e-04],
        [9.4693e-01, 5.3074e-02],
        [3.2476e-04, 9.9968e-01],
        [9.8900e-01, 1.1001e-02],
        [9.8614e-01, 1.3863e-02]])


Predicting:  10%|█         | 22/215 [03:14<28:44,  8.93s/it]

tensor([[4.6030e-03, 9.9540e-01],
        [9.6303e-01, 3.6969e-02],
        [4.5651e-01, 5.4349e-01],
        [9.4786e-01, 5.2144e-02],
        [1.7054e-01, 8.2946e-01],
        [5.9404e-01, 4.0596e-01],
        [5.7953e-01, 4.2047e-01],
        [1.7810e-01, 8.2190e-01],
        [9.0644e-01, 9.3560e-02],
        [9.3274e-01, 6.7262e-02],
        [4.8285e-01, 5.1715e-01],
        [8.4129e-02, 9.1587e-01],
        [9.9552e-01, 4.4822e-03],
        [8.0320e-01, 1.9680e-01],
        [9.9934e-01, 6.5609e-04],
        [8.9133e-02, 9.1087e-01]])


Predicting:  11%|█         | 23/215 [03:23<28:12,  8.82s/it]

tensor([[0.9894, 0.0106],
        [0.1000, 0.9000],
        [0.9782, 0.0218],
        [0.7967, 0.2033],
        [0.9861, 0.0139],
        [0.9955, 0.0045],
        [0.6836, 0.3164],
        [0.7826, 0.2174],
        [0.6942, 0.3058],
        [0.9566, 0.0434],
        [0.3607, 0.6393],
        [0.2375, 0.7625],
        [0.1757, 0.8243],
        [0.9284, 0.0716],
        [0.0305, 0.9695],
        [0.9106, 0.0894]])


Predicting:  11%|█         | 24/215 [03:32<27:49,  8.74s/it]

tensor([[0.9663, 0.0337],
        [0.9873, 0.0127],
        [0.8317, 0.1683],
        [0.9970, 0.0030],
        [0.8048, 0.1952],
        [0.6750, 0.3250],
        [0.9558, 0.0442],
        [0.0758, 0.9242],
        [0.9971, 0.0029],
        [0.3887, 0.6113],
        [0.5323, 0.4677],
        [0.9478, 0.0522],
        [0.9447, 0.0553],
        [0.7405, 0.2595],
        [0.5061, 0.4939],
        [0.6148, 0.3852]])


Predicting:  12%|█▏        | 25/215 [03:40<27:35,  8.71s/it]

tensor([[9.5330e-01, 4.6703e-02],
        [9.9441e-01, 5.5937e-03],
        [9.9939e-01, 6.0551e-04],
        [9.9765e-01, 2.3465e-03],
        [8.2555e-01, 1.7445e-01],
        [9.6834e-01, 3.1656e-02],
        [9.9954e-01, 4.5745e-04],
        [8.2626e-01, 1.7374e-01],
        [2.2623e-04, 9.9977e-01],
        [6.8214e-01, 3.1786e-01],
        [9.9988e-01, 1.1808e-04],
        [9.9995e-01, 5.2791e-05],
        [7.9992e-02, 9.2001e-01],
        [9.9740e-01, 2.6030e-03],
        [9.9617e-01, 3.8319e-03],
        [9.7620e-01, 2.3799e-02]])


Predicting:  12%|█▏        | 26/215 [03:49<27:45,  8.81s/it]

tensor([[8.0141e-02, 9.1986e-01],
        [8.8291e-01, 1.1709e-01],
        [5.2752e-01, 4.7248e-01],
        [1.3202e-02, 9.8680e-01],
        [9.9355e-01, 6.4472e-03],
        [9.9777e-01, 2.2303e-03],
        [9.9873e-01, 1.2712e-03],
        [9.9945e-01, 5.4770e-04],
        [9.8802e-01, 1.1981e-02],
        [8.6806e-01, 1.3194e-01],
        [9.4919e-01, 5.0808e-02],
        [5.7046e-01, 4.2954e-01],
        [9.3251e-02, 9.0675e-01],
        [9.2063e-01, 7.9368e-02],
        [6.4407e-01, 3.5593e-01],
        [9.9250e-01, 7.5037e-03]])


Predicting:  13%|█▎        | 27/215 [03:58<27:28,  8.77s/it]

tensor([[0.9985, 0.0015],
        [0.9981, 0.0019],
        [0.9987, 0.0013],
        [0.9930, 0.0070],
        [0.8545, 0.1455],
        [0.7339, 0.2661],
        [0.9952, 0.0048],
        [0.9947, 0.0053],
        [0.0128, 0.9872],
        [0.9982, 0.0018],
        [0.4022, 0.5978],
        [0.6606, 0.3394],
        [0.9114, 0.0886],
        [0.4839, 0.5161],
        [0.6196, 0.3804],
        [0.2811, 0.7189]])


Predicting:  13%|█▎        | 28/215 [04:07<27:22,  8.78s/it]

tensor([[9.9648e-01, 3.5200e-03],
        [9.9876e-01, 1.2411e-03],
        [9.3531e-01, 6.4686e-02],
        [7.6962e-01, 2.3038e-01],
        [9.9796e-01, 2.0427e-03],
        [1.0454e-04, 9.9990e-01],
        [9.9725e-01, 2.7492e-03],
        [9.9983e-01, 1.7242e-04],
        [9.9945e-01, 5.4534e-04],
        [9.9982e-01, 1.7935e-04],
        [9.9985e-01, 1.5088e-04],
        [4.8853e-01, 5.1147e-01],
        [1.7886e-04, 9.9982e-01],
        [9.9980e-01, 1.9775e-04],
        [9.9978e-01, 2.2455e-04],
        [4.3360e-01, 5.6640e-01]])


Predicting:  13%|█▎        | 29/215 [04:16<27:41,  8.93s/it]

tensor([[3.1630e-04, 9.9968e-01],
        [9.9977e-01, 2.2877e-04],
        [9.9983e-01, 1.6833e-04],
        [9.7895e-01, 2.1054e-02],
        [9.6000e-01, 4.0001e-02],
        [9.9691e-01, 3.0906e-03],
        [3.0671e-01, 6.9329e-01],
        [9.9863e-01, 1.3684e-03],
        [9.5557e-01, 4.4432e-02],
        [8.4190e-01, 1.5810e-01],
        [8.8003e-01, 1.1997e-01],
        [8.8577e-01, 1.1423e-01],
        [9.9692e-01, 3.0790e-03],
        [9.8983e-01, 1.0171e-02],
        [9.7210e-01, 2.7903e-02],
        [9.8045e-01, 1.9545e-02]])


Predicting:  14%|█▍        | 30/215 [04:24<27:06,  8.79s/it]

tensor([[9.8766e-01, 1.2336e-02],
        [3.6703e-01, 6.3297e-01],
        [8.0429e-01, 1.9571e-01],
        [8.9671e-01, 1.0329e-01],
        [9.1299e-01, 8.7015e-02],
        [1.9290e-01, 8.0710e-01],
        [1.7608e-02, 9.8239e-01],
        [5.2185e-01, 4.7815e-01],
        [9.9646e-01, 3.5371e-03],
        [8.9625e-01, 1.0375e-01],
        [9.9919e-01, 8.0621e-04],
        [9.1630e-01, 8.3702e-02],
        [9.9860e-01, 1.4027e-03],
        [9.9981e-01, 1.9102e-04],
        [9.9420e-01, 5.8045e-03],
        [9.8494e-01, 1.5063e-02]])


Predicting:  14%|█▍        | 31/215 [04:33<26:50,  8.75s/it]

tensor([[9.9985e-01, 1.4643e-04],
        [9.9972e-01, 2.7651e-04],
        [9.9147e-01, 8.5295e-03],
        [5.9115e-01, 4.0885e-01],
        [4.9320e-01, 5.0680e-01],
        [9.9497e-01, 5.0329e-03],
        [9.9201e-01, 7.9896e-03],
        [9.6746e-01, 3.2538e-02],
        [9.8168e-01, 1.8316e-02],
        [9.9007e-01, 9.9258e-03],
        [4.3205e-01, 5.6795e-01],
        [9.7229e-01, 2.7707e-02],
        [2.6537e-02, 9.7346e-01],
        [1.8687e-02, 9.8131e-01],
        [8.3377e-01, 1.6623e-01],
        [7.0741e-01, 2.9259e-01]])


Predicting:  15%|█▍        | 32/215 [04:42<26:30,  8.69s/it]

tensor([[6.7612e-01, 3.2388e-01],
        [9.8789e-01, 1.2105e-02],
        [7.3855e-01, 2.6145e-01],
        [9.6791e-01, 3.2087e-02],
        [9.9853e-01, 1.4699e-03],
        [9.4863e-01, 5.1369e-02],
        [9.7464e-05, 9.9990e-01],
        [4.0027e-01, 5.9973e-01],
        [7.6723e-01, 2.3277e-01],
        [1.4492e-01, 8.5508e-01],
        [9.9234e-01, 7.6556e-03],
        [9.8629e-01, 1.3715e-02],
        [9.5581e-01, 4.4188e-02],
        [7.3992e-01, 2.6008e-01],
        [9.7383e-01, 2.6167e-02],
        [3.1013e-02, 9.6899e-01]])


Predicting:  15%|█▌        | 33/215 [04:51<26:37,  8.78s/it]

tensor([[8.4461e-01, 1.5539e-01],
        [7.3574e-04, 9.9926e-01],
        [9.9767e-01, 2.3308e-03],
        [9.6694e-01, 3.3062e-02],
        [9.9559e-01, 4.4062e-03],
        [9.9731e-01, 2.6934e-03],
        [9.9788e-01, 2.1159e-03],
        [9.9055e-01, 9.4451e-03],
        [9.8193e-01, 1.8070e-02],
        [1.5543e-01, 8.4457e-01],
        [9.6897e-01, 3.1029e-02],
        [8.8511e-01, 1.1489e-01],
        [8.9178e-01, 1.0822e-01],
        [1.2273e-01, 8.7727e-01],
        [9.5300e-01, 4.7004e-02],
        [7.6736e-01, 2.3264e-01]])


Predicting:  16%|█▌        | 34/215 [05:00<26:43,  8.86s/it]

tensor([[1.2297e-04, 9.9988e-01],
        [9.9739e-01, 2.6123e-03],
        [7.1025e-02, 9.2897e-01],
        [9.6713e-01, 3.2873e-02],
        [9.0060e-01, 9.9399e-02],
        [7.8713e-01, 2.1287e-01],
        [6.4196e-01, 3.5804e-01],
        [3.4040e-01, 6.5960e-01],
        [7.7243e-01, 2.2757e-01],
        [5.4123e-01, 4.5877e-01],
        [3.4256e-02, 9.6574e-01],
        [6.1711e-01, 3.8289e-01],
        [9.5767e-01, 4.2331e-02],
        [7.2572e-01, 2.7428e-01],
        [5.2306e-01, 4.7694e-01],
        [9.6175e-01, 3.8253e-02]])


Predicting:  16%|█▋        | 35/215 [05:09<26:46,  8.93s/it]

tensor([[0.8319, 0.1681],
        [0.9636, 0.0364],
        [0.9887, 0.0113],
        [0.9946, 0.0054],
        [0.4909, 0.5091],
        [0.9959, 0.0041],
        [0.1384, 0.8616],
        [0.8775, 0.1225],
        [0.9833, 0.0167],
        [0.9524, 0.0476],
        [0.7748, 0.2252],
        [0.9682, 0.0318],
        [0.9746, 0.0254],
        [0.7350, 0.2650],
        [0.4240, 0.5760],
        [0.7438, 0.2562]])


Predicting:  17%|█▋        | 36/215 [05:18<26:41,  8.94s/it]

tensor([[6.3321e-01, 3.6679e-01],
        [1.5834e-01, 8.4166e-01],
        [5.4378e-01, 4.5622e-01],
        [7.3432e-03, 9.9266e-01],
        [8.0620e-01, 1.9380e-01],
        [8.0209e-01, 1.9791e-01],
        [9.4340e-01, 5.6597e-02],
        [7.2985e-01, 2.7015e-01],
        [5.9480e-01, 4.0520e-01],
        [4.2935e-01, 5.7065e-01],
        [2.2264e-04, 9.9978e-01],
        [4.7121e-01, 5.2879e-01],
        [9.3633e-01, 6.3667e-02],
        [8.6847e-01, 1.3153e-01],
        [3.3208e-01, 6.6792e-01],
        [3.9711e-01, 6.0289e-01]])


Predicting:  17%|█▋        | 37/215 [05:26<26:10,  8.82s/it]

tensor([[0.8932, 0.1068],
        [0.9977, 0.0023],
        [0.1646, 0.8354],
        [0.4537, 0.5463],
        [0.2113, 0.7887],
        [0.5927, 0.4073],
        [0.9901, 0.0099],
        [0.8588, 0.1412],
        [0.8392, 0.1608],
        [0.8806, 0.1194],
        [0.9775, 0.0225],
        [0.1634, 0.8366],
        [0.6999, 0.3001],
        [0.8354, 0.1646],
        [0.8601, 0.1399],
        [0.8097, 0.1903]])


Predicting:  18%|█▊        | 38/215 [05:35<26:05,  8.84s/it]

tensor([[7.7559e-01, 2.2441e-01],
        [8.9407e-01, 1.0593e-01],
        [8.6429e-01, 1.3571e-01],
        [7.6583e-03, 9.9234e-01],
        [2.8206e-01, 7.1794e-01],
        [5.4116e-01, 4.5884e-01],
        [2.0438e-01, 7.9562e-01],
        [6.8864e-01, 3.1136e-01],
        [9.8816e-01, 1.1841e-02],
        [2.7292e-01, 7.2708e-01],
        [1.8231e-01, 8.1769e-01],
        [9.9929e-01, 7.0688e-04],
        [7.6428e-01, 2.3572e-01],
        [1.6298e-03, 9.9837e-01],
        [7.8677e-01, 2.1323e-01],
        [2.6143e-04, 9.9974e-01]])


Predicting:  18%|█▊        | 39/215 [05:44<26:11,  8.93s/it]

tensor([[4.4552e-01, 5.5448e-01],
        [4.8606e-04, 9.9951e-01],
        [9.9767e-01, 2.3278e-03],
        [9.9953e-01, 4.7464e-04],
        [4.8913e-01, 5.1087e-01],
        [3.2577e-04, 9.9967e-01],
        [9.9734e-01, 2.6639e-03],
        [9.9925e-01, 7.4841e-04],
        [9.9938e-01, 6.2178e-04],
        [9.9992e-01, 8.0281e-05],
        [6.8954e-04, 9.9931e-01],
        [9.0972e-01, 9.0280e-02],
        [8.8626e-01, 1.1374e-01],
        [7.3283e-01, 2.6717e-01],
        [7.6333e-01, 2.3667e-01],
        [9.7125e-01, 2.8748e-02]])


Predicting:  19%|█▊        | 40/215 [05:53<25:34,  8.77s/it]

tensor([[9.0633e-01, 9.3666e-02],
        [9.9959e-01, 4.1162e-04],
        [9.9461e-01, 5.3854e-03],
        [8.7329e-02, 9.1267e-01],
        [9.9407e-01, 5.9331e-03],
        [1.6085e-01, 8.3915e-01],
        [9.3265e-01, 6.7352e-02],
        [9.7383e-01, 2.6165e-02],
        [8.9650e-01, 1.0350e-01],
        [8.6424e-01, 1.3576e-01],
        [9.4684e-01, 5.3159e-02],
        [9.8440e-01, 1.5599e-02],
        [9.6288e-01, 3.7120e-02],
        [8.6538e-03, 9.9135e-01],
        [9.9972e-01, 2.8156e-04],
        [9.9992e-01, 7.7968e-05]])


Predicting:  19%|█▉        | 41/215 [06:01<25:17,  8.72s/it]

tensor([[0.6615, 0.3385],
        [0.8804, 0.1196],
        [0.0020, 0.9980],
        [0.0011, 0.9989],
        [0.2169, 0.7831],
        [0.6955, 0.3045],
        [0.9702, 0.0298],
        [0.9202, 0.0798],
        [0.6573, 0.3427],
        [0.6385, 0.3615],
        [0.9763, 0.0237],
        [0.5812, 0.4188],
        [0.6021, 0.3979],
        [0.9642, 0.0358],
        [0.9536, 0.0464],
        [0.7655, 0.2345]])


Predicting:  20%|█▉        | 42/215 [06:10<25:00,  8.67s/it]

tensor([[0.8375, 0.1625],
        [0.8909, 0.1091],
        [0.5177, 0.4823],
        [0.9479, 0.0521],
        [0.7040, 0.2960],
        [0.9609, 0.0391],
        [0.9948, 0.0052],
        [0.9358, 0.0642],
        [0.4063, 0.5937],
        [0.9546, 0.0454],
        [0.9701, 0.0299],
        [0.8989, 0.1011],
        [0.8651, 0.1349],
        [0.8661, 0.1339],
        [0.8332, 0.1668],
        [0.9509, 0.0491]])


Predicting:  20%|██        | 43/215 [06:19<24:47,  8.65s/it]

tensor([[9.8272e-01, 1.7284e-02],
        [9.6440e-01, 3.5597e-02],
        [1.7630e-01, 8.2370e-01],
        [9.7862e-01, 2.1378e-02],
        [9.1585e-01, 8.4146e-02],
        [9.9965e-01, 3.4932e-04],
        [9.9886e-01, 1.1435e-03],
        [9.9991e-01, 9.0122e-05],
        [9.0787e-01, 9.2128e-02],
        [9.3276e-05, 9.9991e-01],
        [9.8629e-01, 1.3708e-02],
        [1.5172e-01, 8.4828e-01],
        [1.0403e-04, 9.9990e-01],
        [8.6364e-01, 1.3636e-01],
        [6.7803e-02, 9.3220e-01],
        [7.2579e-01, 2.7421e-01]])


Predicting:  20%|██        | 44/215 [06:27<24:49,  8.71s/it]

tensor([[0.9171, 0.0829],
        [0.8843, 0.1157],
        [0.4100, 0.5900],
        [0.1455, 0.8545],
        [0.9783, 0.0217],
        [0.9247, 0.0753],
        [0.7785, 0.2215],
        [0.9627, 0.0373],
        [0.9496, 0.0504],
        [0.7412, 0.2588],
        [0.9927, 0.0073],
        [0.1103, 0.8897],
        [0.0064, 0.9936],
        [0.2320, 0.7680],
        [0.1578, 0.8422],
        [0.2649, 0.7351]])


Predicting:  21%|██        | 45/215 [06:37<25:14,  8.91s/it]

tensor([[4.6027e-01, 5.3973e-01],
        [7.5110e-01, 2.4890e-01],
        [9.0589e-01, 9.4112e-02],
        [9.5896e-01, 4.1038e-02],
        [9.9347e-01, 6.5302e-03],
        [9.9885e-01, 1.1484e-03],
        [2.6808e-01, 7.3192e-01],
        [2.0292e-04, 9.9980e-01],
        [4.3468e-01, 5.6532e-01],
        [9.9983e-01, 1.7180e-04],
        [9.9973e-01, 2.6844e-04],
        [1.3286e-01, 8.6714e-01],
        [9.0809e-01, 9.1910e-02],
        [9.3483e-01, 6.5174e-02],
        [7.9878e-04, 9.9920e-01],
        [9.0376e-01, 9.6240e-02]])


Predicting:  21%|██▏       | 46/215 [06:46<25:06,  8.92s/it]

tensor([[5.0404e-01, 4.9596e-01],
        [1.7019e-01, 8.2981e-01],
        [4.0115e-01, 5.9885e-01],
        [6.7023e-01, 3.2977e-01],
        [7.6522e-01, 2.3478e-01],
        [6.9029e-01, 3.0971e-01],
        [2.2897e-01, 7.7103e-01],
        [2.6199e-02, 9.7380e-01],
        [8.4855e-01, 1.5145e-01],
        [9.9986e-01, 1.4205e-04],
        [9.9995e-01, 4.8354e-05],
        [9.9632e-01, 3.6839e-03],
        [1.8566e-03, 9.9814e-01],
        [9.8295e-01, 1.7054e-02],
        [1.3135e-01, 8.6865e-01],
        [9.5841e-01, 4.1592e-02]])


Predicting:  22%|██▏       | 47/215 [06:54<24:36,  8.79s/it]

tensor([[9.7384e-01, 2.6161e-02],
        [9.5720e-01, 4.2796e-02],
        [8.9706e-01, 1.0294e-01],
        [9.6122e-01, 3.8781e-02],
        [9.7157e-01, 2.8425e-02],
        [9.8648e-01, 1.3517e-02],
        [9.3741e-01, 6.2594e-02],
        [9.9786e-01, 2.1432e-03],
        [9.9345e-01, 6.5495e-03],
        [4.7364e-02, 9.5264e-01],
        [8.9568e-01, 1.0432e-01],
        [9.9683e-01, 3.1706e-03],
        [9.2768e-05, 9.9991e-01],
        [9.8690e-01, 1.3104e-02],
        [9.9477e-01, 5.2350e-03],
        [9.9923e-01, 7.6606e-04]])


Predicting:  22%|██▏       | 48/215 [07:03<24:49,  8.92s/it]

tensor([[9.7022e-01, 2.9777e-02],
        [9.9228e-01, 7.7228e-03],
        [9.9158e-01, 8.4228e-03],
        [9.9957e-01, 4.2623e-04],
        [9.9991e-01, 9.1439e-05],
        [1.0887e-02, 9.8911e-01],
        [9.9914e-01, 8.5849e-04],
        [6.5846e-01, 3.4154e-01],
        [8.8953e-01, 1.1047e-01],
        [2.7083e-01, 7.2917e-01],
        [9.9817e-01, 1.8297e-03],
        [8.8332e-01, 1.1668e-01],
        [7.6626e-01, 2.3374e-01],
        [9.8652e-01, 1.3483e-02],
        [9.3410e-01, 6.5900e-02],
        [7.3816e-02, 9.2618e-01]])


Predicting:  23%|██▎       | 49/215 [07:12<24:28,  8.85s/it]

tensor([[0.8647, 0.1353],
        [0.4288, 0.5712],
        [0.0353, 0.9647],
        [0.6004, 0.3996],
        [0.5379, 0.4621],
        [0.0329, 0.9671],
        [0.9838, 0.0162],
        [0.4374, 0.5626],
        [0.9260, 0.0740],
        [0.9251, 0.0749],
        [0.1232, 0.8768],
        [0.9681, 0.0319],
        [0.8546, 0.1454],
        [0.9685, 0.0315],
        [0.9976, 0.0024],
        [0.8599, 0.1401]])


Predicting:  23%|██▎       | 50/215 [07:21<24:24,  8.88s/it]

tensor([[9.8773e-01, 1.2268e-02],
        [9.9416e-01, 5.8378e-03],
        [7.9637e-01, 2.0363e-01],
        [9.0221e-01, 9.7792e-02],
        [8.3991e-01, 1.6009e-01],
        [9.4825e-01, 5.1751e-02],
        [5.8742e-03, 9.9413e-01],
        [9.9955e-01, 4.4761e-04],
        [9.9971e-01, 2.8640e-04],
        [6.7182e-01, 3.2818e-01],
        [4.4584e-04, 9.9955e-01],
        [9.9987e-01, 1.3428e-04],
        [9.8881e-01, 1.1191e-02],
        [4.2854e-01, 5.7146e-01],
        [9.9589e-01, 4.1081e-03],
        [9.9954e-01, 4.6135e-04]])


Predicting:  24%|██▎       | 51/215 [07:30<24:11,  8.85s/it]

tensor([[9.9979e-01, 2.0870e-04],
        [3.4374e-02, 9.6563e-01],
        [9.9950e-01, 4.9875e-04],
        [9.9982e-01, 1.8057e-04],
        [8.3703e-01, 1.6297e-01],
        [1.4817e-04, 9.9985e-01],
        [9.9550e-01, 4.4952e-03],
        [9.9980e-01, 2.0211e-04],
        [9.9940e-01, 6.0190e-04],
        [9.8752e-01, 1.2476e-02],
        [9.6355e-03, 9.9036e-01],
        [9.9974e-01, 2.6495e-04],
        [9.9688e-01, 3.1218e-03],
        [3.8582e-01, 6.1418e-01],
        [7.3335e-01, 2.6665e-01],
        [9.9913e-01, 8.6503e-04]])


Predicting:  24%|██▍       | 52/215 [07:39<24:20,  8.96s/it]

tensor([[9.9957e-01, 4.3318e-04],
        [9.1816e-01, 8.1843e-02],
        [1.7278e-04, 9.9983e-01],
        [9.8678e-01, 1.3224e-02],
        [9.9735e-01, 2.6498e-03],
        [9.6915e-01, 3.0847e-02],
        [1.4512e-03, 9.9855e-01],
        [4.2045e-01, 5.7955e-01],
        [1.5671e-02, 9.8433e-01],
        [9.9035e-01, 9.6481e-03],
        [4.4320e-01, 5.5680e-01],
        [1.1233e-03, 9.9888e-01],
        [5.5248e-01, 4.4752e-01],
        [9.9097e-01, 9.0287e-03],
        [2.7857e-01, 7.2143e-01],
        [9.7377e-01, 2.6231e-02]])


Predicting:  25%|██▍       | 53/215 [07:47<23:44,  8.79s/it]

tensor([[4.2546e-04, 9.9957e-01],
        [2.6657e-01, 7.3343e-01],
        [9.9298e-01, 7.0217e-03],
        [9.7144e-01, 2.8564e-02],
        [9.7467e-01, 2.5333e-02],
        [9.9519e-01, 4.8099e-03],
        [8.9880e-01, 1.0120e-01],
        [2.9698e-01, 7.0302e-01],
        [9.7956e-01, 2.0444e-02],
        [3.2387e-01, 6.7613e-01],
        [3.6955e-04, 9.9963e-01],
        [8.6824e-01, 1.3176e-01],
        [1.0661e-01, 8.9339e-01],
        [1.9835e-01, 8.0165e-01],
        [7.1020e-01, 2.8980e-01],
        [1.6110e-01, 8.3890e-01]])


Predicting:  25%|██▌       | 54/215 [07:56<23:28,  8.75s/it]

tensor([[9.9974e-01, 2.5855e-04],
        [9.9965e-01, 3.4523e-04],
        [1.8208e-01, 8.1792e-01],
        [5.3407e-01, 4.6593e-01],
        [9.2948e-01, 7.0523e-02],
        [3.5490e-01, 6.4510e-01],
        [4.6619e-01, 5.3381e-01],
        [5.4558e-01, 4.5442e-01],
        [7.5576e-01, 2.4424e-01],
        [8.5317e-01, 1.4683e-01],
        [5.1011e-03, 9.9490e-01],
        [4.3588e-01, 5.6412e-01],
        [9.9217e-01, 7.8259e-03],
        [9.9408e-01, 5.9223e-03],
        [9.0751e-01, 9.2485e-02],
        [3.3960e-02, 9.6604e-01]])


Predicting:  26%|██▌       | 55/215 [08:05<23:11,  8.70s/it]

tensor([[9.2935e-01, 7.0652e-02],
        [3.8743e-01, 6.1257e-01],
        [5.2624e-01, 4.7376e-01],
        [9.4288e-01, 5.7121e-02],
        [7.9477e-01, 2.0523e-01],
        [8.1289e-01, 1.8711e-01],
        [9.3557e-01, 6.4431e-02],
        [5.9301e-01, 4.0699e-01],
        [6.2133e-01, 3.7867e-01],
        [6.5755e-01, 3.4245e-01],
        [4.3580e-01, 5.6420e-01],
        [4.0737e-01, 5.9263e-01],
        [9.7940e-01, 2.0603e-02],
        [9.6900e-01, 3.1003e-02],
        [9.9897e-01, 1.0303e-03],
        [9.9932e-01, 6.7518e-04]])


Predicting:  26%|██▌       | 56/215 [08:13<23:00,  8.68s/it]

tensor([[9.9983e-01, 1.6686e-04],
        [8.6573e-01, 1.3427e-01],
        [9.5841e-05, 9.9990e-01],
        [4.7737e-01, 5.2263e-01],
        [2.2322e-02, 9.7768e-01],
        [1.0097e-02, 9.8990e-01],
        [9.9944e-01, 5.6261e-04],
        [9.9885e-01, 1.1484e-03],
        [9.9977e-01, 2.3070e-04],
        [7.8533e-01, 2.1467e-01],
        [4.5502e-01, 5.4498e-01],
        [9.9258e-01, 7.4207e-03],
        [9.0861e-01, 9.1387e-02],
        [8.8225e-01, 1.1775e-01],
        [9.9350e-01, 6.4990e-03],
        [9.9884e-01, 1.1643e-03]])


Predicting:  27%|██▋       | 57/215 [08:22<22:31,  8.55s/it]

tensor([[0.9990, 0.0010],
        [0.9942, 0.0058],
        [0.9937, 0.0063],
        [0.9355, 0.0645],
        [0.8617, 0.1383],
        [0.1126, 0.8874],
        [0.6363, 0.3637],
        [0.6331, 0.3669],
        [0.9269, 0.0731],
        [0.8947, 0.1053],
        [0.7191, 0.2809],
        [0.9219, 0.0781],
        [0.9337, 0.0663],
        [0.3853, 0.6147],
        [0.8989, 0.1011],
        [0.8482, 0.1518]])


Predicting:  27%|██▋       | 58/215 [08:30<22:19,  8.53s/it]

tensor([[4.2248e-01, 5.7752e-01],
        [8.8837e-01, 1.1163e-01],
        [9.9962e-01, 3.7863e-04],
        [9.9983e-01, 1.7244e-04],
        [1.7642e-02, 9.8236e-01],
        [3.7692e-01, 6.2308e-01],
        [9.9702e-01, 2.9811e-03],
        [9.2269e-01, 7.7314e-02],
        [9.0601e-01, 9.3993e-02],
        [9.6405e-01, 3.5951e-02],
        [9.9935e-01, 6.5375e-04],
        [9.9968e-01, 3.2336e-04],
        [7.3880e-01, 2.6120e-01],
        [9.9983e-01, 1.6541e-04],
        [9.9983e-01, 1.6686e-04],
        [6.7369e-01, 3.2631e-01]])


Predicting:  27%|██▋       | 59/215 [08:40<22:58,  8.84s/it]

tensor([[9.9566e-01, 4.3370e-03],
        [9.9835e-01, 1.6533e-03],
        [7.9268e-01, 2.0732e-01],
        [1.1319e-04, 9.9989e-01],
        [9.9356e-01, 6.4448e-03],
        [9.9459e-01, 5.4095e-03],
        [9.3008e-01, 6.9920e-02],
        [2.4470e-03, 9.9755e-01],
        [5.8381e-02, 9.4162e-01],
        [8.5950e-01, 1.4050e-01],
        [9.9452e-01, 5.4768e-03],
        [7.3832e-01, 2.6168e-01],
        [8.4473e-01, 1.5527e-01],
        [9.9863e-01, 1.3738e-03],
        [9.9260e-01, 7.3951e-03],
        [9.6568e-01, 3.4320e-02]])


Predicting:  28%|██▊       | 60/215 [08:49<23:02,  8.92s/it]

tensor([[0.9845, 0.0155],
        [0.9970, 0.0030],
        [0.9877, 0.0123],
        [0.9162, 0.0838],
        [0.9485, 0.0515],
        [0.8508, 0.1492],
        [0.8758, 0.1242],
        [0.8384, 0.1616],
        [0.9928, 0.0072],
        [0.9891, 0.0109],
        [0.9983, 0.0017],
        [0.4789, 0.5211],
        [0.7832, 0.2168],
        [0.6604, 0.3396],
        [0.9906, 0.0094],
        [0.1758, 0.8242]])


Predicting:  28%|██▊       | 61/215 [08:58<23:13,  9.05s/it]

tensor([[1.8380e-01, 8.1620e-01],
        [3.5874e-01, 6.4126e-01],
        [2.0129e-01, 7.9871e-01],
        [1.5624e-02, 9.8438e-01],
        [3.9638e-01, 6.0362e-01],
        [9.8269e-01, 1.7309e-02],
        [9.9839e-01, 1.6074e-03],
        [9.9603e-01, 3.9736e-03],
        [9.7825e-01, 2.1745e-02],
        [9.8151e-01, 1.8489e-02],
        [9.9022e-01, 9.7770e-03],
        [4.7493e-01, 5.2507e-01],
        [4.9467e-04, 9.9951e-01],
        [4.0858e-01, 5.9142e-01],
        [1.7345e-03, 9.9827e-01],
        [8.2987e-01, 1.7013e-01]])


Predicting:  29%|██▉       | 62/215 [09:06<22:37,  8.87s/it]

tensor([[0.6707, 0.3293],
        [0.4308, 0.5692],
        [0.9918, 0.0082],
        [0.5081, 0.4919],
        [0.9590, 0.0410],
        [0.0589, 0.9411],
        [0.8247, 0.1753],
        [0.8478, 0.1522],
        [0.1220, 0.8780],
        [0.8038, 0.1962],
        [0.5471, 0.4529],
        [0.9859, 0.0141],
        [0.1149, 0.8851],
        [0.9436, 0.0564],
        [0.0015, 0.9985],
        [0.7241, 0.2759]])


Predicting:  29%|██▉       | 63/215 [09:16<22:51,  9.02s/it]

tensor([[1.1513e-02, 9.8849e-01],
        [9.9965e-01, 3.4795e-04],
        [9.9862e-01, 1.3844e-03],
        [1.6591e-01, 8.3409e-01],
        [9.4297e-01, 5.7025e-02],
        [9.3178e-01, 6.8216e-02],
        [2.8546e-01, 7.1454e-01],
        [9.9979e-01, 2.0721e-04],
        [9.9991e-01, 9.4524e-05],
        [9.0948e-01, 9.0524e-02],
        [1.9486e-04, 9.9981e-01],
        [7.8650e-01, 2.1350e-01],
        [6.2065e-01, 3.7935e-01],
        [8.3690e-01, 1.6310e-01],
        [3.7405e-01, 6.2595e-01],
        [9.1684e-01, 8.3158e-02]])


Predicting:  30%|██▉       | 64/215 [09:25<23:00,  9.14s/it]

tensor([[8.9007e-01, 1.0993e-01],
        [8.6437e-01, 1.3563e-01],
        [2.3218e-04, 9.9977e-01],
        [1.7164e-02, 9.8284e-01],
        [9.7134e-01, 2.8657e-02],
        [9.7934e-01, 2.0656e-02],
        [1.0624e-02, 9.8938e-01],
        [8.4203e-02, 9.1580e-01],
        [1.1198e-04, 9.9989e-01],
        [7.8005e-01, 2.1995e-01],
        [2.3524e-02, 9.7648e-01],
        [7.4462e-01, 2.5538e-01],
        [9.3418e-05, 9.9991e-01],
        [9.2977e-01, 7.0235e-02],
        [2.1734e-01, 7.8266e-01],
        [8.6796e-01, 1.3204e-01]])


Predicting:  30%|███       | 65/215 [09:35<23:05,  9.23s/it]

tensor([[9.9493e-01, 5.0671e-03],
        [8.8113e-01, 1.1887e-01],
        [2.3294e-02, 9.7671e-01],
        [8.2872e-01, 1.7128e-01],
        [9.9797e-01, 2.0303e-03],
        [9.9925e-01, 7.4755e-04],
        [8.8667e-01, 1.1333e-01],
        [3.3158e-03, 9.9668e-01],
        [9.9246e-01, 7.5422e-03],
        [9.9811e-01, 1.8860e-03],
        [9.9920e-01, 7.9559e-04],
        [8.1623e-01, 1.8377e-01],
        [2.7904e-04, 9.9972e-01],
        [9.0387e-01, 9.6132e-02],
        [9.6680e-01, 3.3203e-02],
        [3.4776e-03, 9.9652e-01]])


Predicting:  31%|███       | 66/215 [09:44<22:47,  9.18s/it]

tensor([[9.9585e-01, 4.1477e-03],
        [1.2764e-01, 8.7236e-01],
        [3.2600e-01, 6.7400e-01],
        [9.9900e-01, 9.9662e-04],
        [9.9651e-01, 3.4922e-03],
        [9.2131e-01, 7.8693e-02],
        [1.4891e-04, 9.9985e-01],
        [9.9312e-01, 6.8805e-03],
        [9.0331e-01, 9.6691e-02],
        [1.0463e-04, 9.9990e-01],
        [6.7406e-01, 3.2594e-01],
        [9.9745e-01, 2.5528e-03],
        [9.8649e-01, 1.3514e-02],
        [1.0861e-01, 8.9139e-01],
        [8.5507e-01, 1.4493e-01],
        [9.5451e-01, 4.5486e-02]])


Predicting:  31%|███       | 67/215 [09:54<23:05,  9.36s/it]

tensor([[2.2915e-02, 9.7709e-01],
        [9.6967e-05, 9.9990e-01],
        [8.2398e-01, 1.7602e-01],
        [2.6150e-03, 9.9739e-01],
        [2.7347e-01, 7.2653e-01],
        [6.3191e-02, 9.3681e-01],
        [9.9768e-01, 2.3209e-03],
        [2.7800e-01, 7.2200e-01],
        [9.4020e-01, 5.9798e-02],
        [9.4954e-01, 5.0463e-02],
        [9.5419e-05, 9.9990e-01],
        [8.4364e-01, 1.5636e-01],
        [8.8254e-02, 9.1175e-01],
        [2.4786e-03, 9.9752e-01],
        [3.3892e-01, 6.6108e-01],
        [9.5474e-01, 4.5265e-02]])


Predicting:  32%|███▏      | 68/215 [10:02<22:33,  9.21s/it]

tensor([[1.4142e-02, 9.8586e-01],
        [1.2510e-02, 9.8749e-01],
        [9.9713e-01, 2.8718e-03],
        [9.9818e-01, 1.8203e-03],
        [1.0567e-01, 8.9433e-01],
        [2.7498e-01, 7.2502e-01],
        [6.6193e-01, 3.3807e-01],
        [9.8402e-05, 9.9990e-01],
        [9.9971e-01, 2.8769e-04],
        [9.9983e-01, 1.6855e-04],
        [9.7977e-01, 2.0226e-02],
        [6.6716e-01, 3.3284e-01],
        [6.1459e-01, 3.8541e-01],
        [9.6170e-01, 3.8305e-02],
        [4.8553e-01, 5.1447e-01],
        [9.8126e-01, 1.8745e-02]])


Predicting:  32%|███▏      | 69/215 [10:11<21:51,  8.98s/it]

tensor([[9.1933e-01, 8.0666e-02],
        [6.7164e-01, 3.2836e-01],
        [9.8982e-01, 1.0182e-02],
        [9.5456e-01, 4.5443e-02],
        [9.9973e-01, 2.6620e-04],
        [8.3852e-01, 1.6148e-01],
        [7.3076e-01, 2.6924e-01],
        [9.9465e-01, 5.3510e-03],
        [9.0578e-01, 9.4221e-02],
        [9.9815e-01, 1.8491e-03],
        [9.9801e-01, 1.9903e-03],
        [6.0675e-03, 9.9393e-01],
        [1.0164e-04, 9.9990e-01],
        [9.9729e-01, 2.7105e-03],
        [9.9417e-01, 5.8278e-03],
        [9.5502e-01, 4.4979e-02]])


Predicting:  33%|███▎      | 70/215 [10:19<21:06,  8.74s/it]

tensor([[1.6578e-03, 9.9834e-01],
        [1.9326e-03, 9.9807e-01],
        [2.7844e-01, 7.2156e-01],
        [9.8672e-01, 1.3275e-02],
        [8.7494e-01, 1.2506e-01],
        [1.5859e-02, 9.8414e-01],
        [1.2389e-01, 8.7611e-01],
        [9.4790e-01, 5.2104e-02],
        [8.7171e-01, 1.2829e-01],
        [5.1790e-02, 9.4821e-01],
        [7.8341e-04, 9.9922e-01],
        [4.8516e-02, 9.5148e-01],
        [9.9918e-01, 8.2104e-04],
        [9.9977e-01, 2.3362e-04],
        [9.7462e-01, 2.5381e-02],
        [1.1700e-04, 9.9988e-01]])


Predicting:  33%|███▎      | 71/215 [10:28<20:49,  8.67s/it]

tensor([[9.8999e-01, 1.0013e-02],
        [9.9393e-01, 6.0678e-03],
        [9.9893e-01, 1.0744e-03],
        [6.6400e-04, 9.9934e-01],
        [3.4211e-04, 9.9966e-01],
        [2.9438e-01, 7.0562e-01],
        [9.9909e-01, 9.0948e-04],
        [9.9709e-01, 2.9062e-03],
        [9.9459e-01, 5.4102e-03],
        [9.7451e-01, 2.5489e-02],
        [9.8407e-01, 1.5929e-02],
        [9.8651e-01, 1.3485e-02],
        [6.7461e-01, 3.2539e-01],
        [9.4296e-01, 5.7042e-02],
        [9.6871e-01, 3.1295e-02],
        [5.4671e-02, 9.4533e-01]])


Predicting:  33%|███▎      | 72/215 [10:36<20:29,  8.60s/it]

tensor([[9.5543e-01, 4.4573e-02],
        [4.3090e-01, 5.6910e-01],
        [3.3166e-01, 6.6834e-01],
        [9.5974e-01, 4.0257e-02],
        [9.4525e-02, 9.0547e-01],
        [1.2420e-01, 8.7580e-01],
        [9.1566e-01, 8.4340e-02],
        [1.9095e-01, 8.0905e-01],
        [9.9565e-01, 4.3478e-03],
        [8.1291e-01, 1.8709e-01],
        [1.4360e-04, 9.9986e-01],
        [9.9372e-01, 6.2782e-03],
        [8.6772e-01, 1.3228e-01],
        [9.1025e-01, 8.9747e-02],
        [4.4554e-01, 5.5446e-01],
        [4.4761e-01, 5.5239e-01]])


Predicting:  34%|███▍      | 73/215 [10:44<20:13,  8.55s/it]

tensor([[8.7262e-01, 1.2738e-01],
        [2.9546e-01, 7.0454e-01],
        [9.9965e-01, 3.5083e-04],
        [5.5678e-01, 4.4322e-01],
        [7.7298e-01, 2.2702e-01],
        [8.7859e-01, 1.2141e-01],
        [9.9307e-01, 6.9271e-03],
        [6.8389e-01, 3.1611e-01],
        [9.9599e-01, 4.0126e-03],
        [1.2462e-04, 9.9988e-01],
        [9.4578e-01, 5.4218e-02],
        [5.0492e-02, 9.4951e-01],
        [4.0880e-01, 5.9120e-01],
        [9.8654e-01, 1.3462e-02],
        [2.0610e-04, 9.9979e-01],
        [9.6866e-01, 3.1343e-02]])


Predicting:  34%|███▍      | 74/215 [10:53<19:57,  8.50s/it]

tensor([[3.7361e-01, 6.2639e-01],
        [4.4390e-01, 5.5610e-01],
        [9.9394e-01, 6.0586e-03],
        [7.9321e-01, 2.0679e-01],
        [8.6701e-01, 1.3299e-01],
        [5.6648e-01, 4.3352e-01],
        [9.7349e-01, 2.6508e-02],
        [9.9415e-01, 5.8520e-03],
        [1.5074e-01, 8.4926e-01],
        [1.4926e-01, 8.5074e-01],
        [9.9377e-01, 6.2336e-03],
        [9.9968e-01, 3.1801e-04],
        [4.9764e-01, 5.0236e-01],
        [9.6555e-05, 9.9990e-01],
        [9.1301e-01, 8.6990e-02],
        [9.9913e-01, 8.6566e-04]])


Predicting:  35%|███▍      | 75/215 [11:02<20:03,  8.60s/it]

tensor([[9.9929e-01, 7.0807e-04],
        [1.0049e-01, 8.9951e-01],
        [1.2076e-04, 9.9988e-01],
        [9.9473e-01, 5.2685e-03],
        [8.8758e-01, 1.1242e-01],
        [9.8668e-01, 1.3321e-02],
        [1.5952e-01, 8.4048e-01],
        [4.2748e-01, 5.7252e-01],
        [2.2906e-01, 7.7094e-01],
        [9.9881e-01, 1.1861e-03],
        [4.0399e-01, 5.9601e-01],
        [6.0881e-01, 3.9119e-01],
        [1.0411e-01, 8.9589e-01],
        [2.6635e-01, 7.3365e-01],
        [8.8027e-01, 1.1973e-01],
        [1.2126e-02, 9.8787e-01]])


Predicting:  35%|███▌      | 76/215 [11:10<19:35,  8.46s/it]

tensor([[1.1553e-02, 9.8845e-01],
        [1.9280e-03, 9.9807e-01],
        [9.5272e-01, 4.7281e-02],
        [9.9153e-01, 8.4652e-03],
        [6.5476e-04, 9.9935e-01],
        [3.8188e-01, 6.1812e-01],
        [1.3576e-01, 8.6424e-01],
        [9.4719e-01, 5.2813e-02],
        [6.2973e-04, 9.9937e-01],
        [8.7271e-01, 1.2729e-01],
        [3.0044e-01, 6.9956e-01],
        [5.1178e-02, 9.4882e-01],
        [9.6917e-01, 3.0835e-02],
        [1.7951e-01, 8.2049e-01],
        [3.8407e-01, 6.1593e-01],
        [1.1734e-01, 8.8266e-01]])


Predicting:  36%|███▌      | 77/215 [11:18<19:16,  8.38s/it]

tensor([[9.9701e-01, 2.9945e-03],
        [6.2148e-01, 3.7852e-01],
        [8.8531e-01, 1.1469e-01],
        [9.9434e-01, 5.6579e-03],
        [8.2494e-01, 1.7506e-01],
        [9.9703e-01, 2.9702e-03],
        [6.2553e-01, 3.7447e-01],
        [9.9907e-01, 9.3309e-04],
        [1.9460e-02, 9.8054e-01],
        [9.9739e-01, 2.6052e-03],
        [9.9983e-01, 1.7496e-04],
        [6.2517e-01, 3.7483e-01],
        [3.5084e-04, 9.9965e-01],
        [5.3976e-01, 4.6024e-01],
        [9.9903e-01, 9.7195e-04],
        [5.4115e-01, 4.5885e-01]])


Predicting:  36%|███▋      | 78/215 [11:26<18:53,  8.28s/it]

tensor([[6.9445e-01, 3.0555e-01],
        [4.4266e-01, 5.5734e-01],
        [6.7893e-01, 3.2107e-01],
        [4.0596e-01, 5.9404e-01],
        [3.2445e-01, 6.7555e-01],
        [7.1782e-01, 2.8218e-01],
        [8.6029e-01, 1.3971e-01],
        [9.9864e-01, 1.3585e-03],
        [8.5281e-01, 1.4719e-01],
        [9.1897e-01, 8.1030e-02],
        [9.9745e-01, 2.5511e-03],
        [9.9710e-01, 2.8954e-03],
        [9.1816e-01, 8.1839e-02],
        [8.1720e-01, 1.8280e-01],
        [9.9922e-01, 7.8124e-04],
        [9.9871e-01, 1.2883e-03]])


Predicting:  37%|███▋      | 79/215 [11:34<18:54,  8.34s/it]

tensor([[9.4371e-01, 5.6288e-02],
        [8.1230e-01, 1.8770e-01],
        [9.9885e-01, 1.1517e-03],
        [9.5612e-01, 4.3883e-02],
        [7.1740e-01, 2.8260e-01],
        [7.6758e-01, 2.3242e-01],
        [9.9974e-01, 2.6045e-04],
        [3.8153e-01, 6.1847e-01],
        [9.2287e-01, 7.7128e-02],
        [2.1467e-01, 7.8533e-01],
        [9.9372e-01, 6.2789e-03],
        [6.2527e-01, 3.7473e-01],
        [7.0371e-01, 2.9629e-01],
        [9.8103e-01, 1.8970e-02],
        [6.2128e-01, 3.7872e-01],
        [9.9621e-01, 3.7876e-03]])


Predicting:  37%|███▋      | 80/215 [11:43<18:33,  8.25s/it]

tensor([[8.4597e-01, 1.5403e-01],
        [9.0121e-01, 9.8788e-02],
        [6.9372e-01, 3.0628e-01],
        [9.9341e-01, 6.5854e-03],
        [8.6139e-01, 1.3861e-01],
        [7.9267e-01, 2.0733e-01],
        [9.8898e-01, 1.1024e-02],
        [2.5098e-01, 7.4902e-01],
        [5.3256e-01, 4.6744e-01],
        [8.8763e-01, 1.1237e-01],
        [5.8956e-01, 4.1044e-01],
        [4.6952e-01, 5.3048e-01],
        [9.9148e-01, 8.5249e-03],
        [7.9241e-01, 2.0759e-01],
        [8.7779e-01, 1.2221e-01],
        [9.9959e-01, 4.0670e-04]])


Predicting:  38%|███▊      | 81/215 [11:51<18:18,  8.20s/it]

tensor([[0.5949, 0.4051],
        [0.9967, 0.0033],
        [0.7388, 0.2612],
        [0.8840, 0.1160],
        [0.9989, 0.0011],
        [0.9971, 0.0029],
        [0.0839, 0.9161],
        [0.9092, 0.0908],
        [0.0653, 0.9347],
        [0.9972, 0.0028],
        [0.3062, 0.6938],
        [0.9228, 0.0772],
        [0.2682, 0.7318],
        [0.9971, 0.0029],
        [0.8462, 0.1538],
        [0.7951, 0.2049]])


Predicting:  38%|███▊      | 82/215 [11:59<18:32,  8.36s/it]

tensor([[9.8436e-01, 1.5640e-02],
        [6.1565e-01, 3.8435e-01],
        [9.6293e-01, 3.7070e-02],
        [4.5782e-02, 9.5422e-01],
        [8.7041e-01, 1.2959e-01],
        [7.2744e-01, 2.7256e-01],
        [9.9922e-01, 7.7586e-04],
        [7.4962e-01, 2.5038e-01],
        [8.2508e-01, 1.7492e-01],
        [5.7679e-01, 4.2321e-01],
        [9.4005e-01, 5.9950e-02],
        [6.3829e-02, 9.3617e-01],
        [3.8846e-01, 6.1154e-01],
        [2.5105e-01, 7.4895e-01],
        [2.6000e-02, 9.7400e-01],
        [9.8416e-01, 1.5838e-02]])


Predicting:  39%|███▊      | 83/215 [12:08<18:35,  8.45s/it]

tensor([[9.8494e-01, 1.5057e-02],
        [1.8093e-02, 9.8191e-01],
        [1.9159e-01, 8.0841e-01],
        [9.2030e-01, 7.9702e-02],
        [1.5465e-01, 8.4535e-01],
        [3.2962e-01, 6.7038e-01],
        [3.2379e-02, 9.6762e-01],
        [9.9800e-01, 2.0047e-03],
        [9.9991e-01, 8.5880e-05],
        [9.3006e-01, 6.9940e-02],
        [9.0240e-04, 9.9910e-01],
        [7.1250e-01, 2.8750e-01],
        [9.4296e-01, 5.7038e-02],
        [4.2581e-01, 5.7419e-01],
        [1.9208e-01, 8.0792e-01],
        [6.8800e-02, 9.3120e-01]])


Predicting:  39%|███▉      | 84/215 [12:17<18:53,  8.65s/it]

tensor([[9.5872e-01, 4.1284e-02],
        [9.8929e-01, 1.0713e-02],
        [8.1522e-03, 9.9185e-01],
        [1.6433e-01, 8.3567e-01],
        [9.6079e-01, 3.9211e-02],
        [9.4952e-01, 5.0482e-02],
        [7.1608e-01, 2.8392e-01],
        [8.6417e-01, 1.3583e-01],
        [4.3346e-01, 5.6654e-01],
        [9.9856e-01, 1.4371e-03],
        [9.9989e-01, 1.1308e-04],
        [7.6850e-01, 2.3150e-01],
        [6.2345e-04, 9.9938e-01],
        [7.7174e-01, 2.2826e-01],
        [9.9209e-01, 7.9115e-03],
        [9.9477e-01, 5.2282e-03]])


Predicting:  40%|███▉      | 85/215 [12:26<18:45,  8.66s/it]

tensor([[1.2356e-04, 9.9988e-01],
        [1.0649e-02, 9.8935e-01],
        [3.6765e-01, 6.3235e-01],
        [5.8677e-01, 4.1323e-01],
        [8.1115e-01, 1.8885e-01],
        [9.6986e-01, 3.0142e-02],
        [1.9981e-01, 8.0019e-01],
        [8.7067e-01, 1.2933e-01],
        [8.6944e-01, 1.3056e-01],
        [6.3937e-01, 3.6063e-01],
        [6.3524e-01, 3.6476e-01],
        [6.9014e-03, 9.9310e-01],
        [9.8452e-01, 1.5485e-02],
        [9.9918e-01, 8.1864e-04],
        [9.1108e-01, 8.8916e-02],
        [1.1620e-01, 8.8380e-01]])


Predicting:  40%|████      | 86/215 [12:34<18:37,  8.66s/it]

tensor([[5.7090e-01, 4.2910e-01],
        [5.8304e-01, 4.1696e-01],
        [9.9434e-01, 5.6586e-03],
        [7.7784e-01, 2.2216e-01],
        [8.8572e-01, 1.1428e-01],
        [2.9455e-01, 7.0545e-01],
        [9.5348e-01, 4.6522e-02],
        [2.4862e-02, 9.7514e-01],
        [1.4903e-02, 9.8510e-01],
        [9.5209e-01, 4.7913e-02],
        [9.3018e-01, 6.9819e-02],
        [7.0917e-01, 2.9083e-01],
        [1.9658e-01, 8.0342e-01],
        [9.9728e-01, 2.7178e-03],
        [9.9919e-01, 8.1433e-04],
        [4.8349e-01, 5.1651e-01]])


Predicting:  40%|████      | 87/215 [12:43<18:20,  8.60s/it]

tensor([[1.3059e-04, 9.9987e-01],
        [4.9849e-01, 5.0151e-01],
        [9.7678e-01, 2.3217e-02],
        [3.0134e-03, 9.9699e-01],
        [9.2345e-01, 7.6550e-02],
        [4.4038e-02, 9.5596e-01],
        [1.0330e-01, 8.9670e-01],
        [9.9925e-01, 7.4765e-04],
        [7.2523e-01, 2.7477e-01],
        [9.3497e-01, 6.5029e-02],
        [7.0654e-01, 2.9346e-01],
        [9.9893e-01, 1.0655e-03],
        [4.7323e-01, 5.2677e-01],
        [9.4671e-01, 5.3287e-02],
        [4.6978e-01, 5.3022e-01],
        [9.8917e-01, 1.0830e-02]])


Predicting:  41%|████      | 88/215 [12:51<18:08,  8.57s/it]

tensor([[4.5923e-01, 5.4077e-01],
        [8.6349e-01, 1.3651e-01],
        [9.3475e-01, 6.5250e-02],
        [1.2744e-01, 8.7256e-01],
        [9.9703e-01, 2.9731e-03],
        [9.9862e-01, 1.3833e-03],
        [7.9918e-01, 2.0082e-01],
        [4.3130e-04, 9.9957e-01],
        [9.4714e-01, 5.2856e-02],
        [9.9608e-01, 3.9170e-03],
        [9.9988e-01, 1.2329e-04],
        [8.1125e-01, 1.8875e-01],
        [1.3801e-04, 9.9986e-01],
        [9.3123e-01, 6.8768e-02],
        [9.9385e-01, 6.1450e-03],
        [3.4926e-01, 6.5074e-01]])


Predicting:  41%|████▏     | 89/215 [13:00<17:52,  8.51s/it]

tensor([[0.5554, 0.4446],
        [0.0780, 0.9220],
        [0.9950, 0.0050],
        [0.5350, 0.4650],
        [0.0300, 0.9700],
        [0.8433, 0.1567],
        [0.9787, 0.0213],
        [0.4472, 0.5528],
        [0.6473, 0.3527],
        [0.9664, 0.0336],
        [0.7636, 0.2364],
        [0.9982, 0.0018],
        [0.7869, 0.2131],
        [0.9929, 0.0071],
        [0.9029, 0.0971],
        [0.9934, 0.0066]])


Predicting:  42%|████▏     | 90/215 [13:08<17:26,  8.37s/it]

tensor([[0.8732, 0.1268],
        [0.8210, 0.1790],
        [0.9895, 0.0105],
        [0.9000, 0.1000],
        [0.9786, 0.0214],
        [0.8978, 0.1022],
        [0.8944, 0.1056],
        [0.9948, 0.0052],
        [0.9044, 0.0956],
        [0.9946, 0.0054],
        [0.6601, 0.3399],
        [0.8655, 0.1345],
        [0.9952, 0.0048],
        [0.8326, 0.1674],
        [0.9118, 0.0882],
        [0.3556, 0.6444]])


Predicting:  42%|████▏     | 91/215 [13:16<17:08,  8.30s/it]

tensor([[0.5516, 0.4484],
        [0.9912, 0.0088],
        [0.3980, 0.6020],
        [0.6461, 0.3539],
        [0.5982, 0.4018],
        [0.7527, 0.2473],
        [0.3438, 0.6562],
        [0.7786, 0.2214],
        [0.5381, 0.4619],
        [0.9525, 0.0475],
        [0.0124, 0.9876],
        [0.9699, 0.0301],
        [0.7064, 0.2936],
        [0.4124, 0.5876],
        [0.2941, 0.7059],
        [0.9879, 0.0121]])


Predicting:  43%|████▎     | 92/215 [13:24<17:03,  8.32s/it]

tensor([[9.9973e-01, 2.7499e-04],
        [9.8406e-05, 9.9990e-01],
        [1.3843e-04, 9.9986e-01],
        [3.3382e-01, 6.6618e-01],
        [8.3009e-01, 1.6991e-01],
        [5.3658e-01, 4.6342e-01],
        [9.4991e-02, 9.0501e-01],
        [8.6457e-04, 9.9914e-01],
        [9.7038e-01, 2.9620e-02],
        [9.9891e-01, 1.0891e-03],
        [9.5742e-01, 4.2578e-02],
        [9.6941e-01, 3.0588e-02],
        [3.4443e-01, 6.5557e-01],
        [9.9206e-01, 7.9384e-03],
        [4.3489e-01, 5.6511e-01],
        [4.8398e-01, 5.1602e-01]])


Predicting:  43%|████▎     | 93/215 [13:33<16:54,  8.32s/it]

tensor([[7.7644e-01, 2.2356e-01],
        [4.8014e-01, 5.1986e-01],
        [3.4169e-01, 6.5831e-01],
        [7.3813e-01, 2.6187e-01],
        [9.1586e-01, 8.4143e-02],
        [7.4247e-01, 2.5753e-01],
        [2.0242e-02, 9.7976e-01],
        [3.4616e-03, 9.9654e-01],
        [5.5514e-01, 4.4486e-01],
        [2.3769e-01, 7.6231e-01],
        [2.6307e-01, 7.3693e-01],
        [2.3483e-01, 7.6517e-01],
        [2.1702e-01, 7.8298e-01],
        [9.8422e-01, 1.5776e-02],
        [1.4365e-03, 9.9856e-01],
        [9.8570e-05, 9.9990e-01]])


Predicting:  44%|████▎     | 94/215 [13:42<17:23,  8.63s/it]

tensor([[7.9868e-01, 2.0132e-01],
        [6.7873e-01, 3.2127e-01],
        [9.9586e-01, 4.1379e-03],
        [9.9858e-01, 1.4220e-03],
        [9.4423e-01, 5.5770e-02],
        [9.9105e-01, 8.9494e-03],
        [9.9982e-01, 1.7657e-04],
        [9.8285e-01, 1.7146e-02],
        [3.2743e-04, 9.9967e-01],
        [7.4026e-01, 2.5974e-01],
        [7.3461e-01, 2.6539e-01],
        [2.5831e-01, 7.4169e-01],
        [7.5381e-01, 2.4619e-01],
        [1.2435e-01, 8.7565e-01],
        [7.5240e-01, 2.4760e-01],
        [8.6563e-01, 1.3437e-01]])


Predicting:  44%|████▍     | 95/215 [13:50<17:04,  8.54s/it]

tensor([[0.3050, 0.6950],
        [0.9022, 0.0978],
        [0.7791, 0.2209],
        [0.3385, 0.6615],
        [0.6541, 0.3459],
        [0.8379, 0.1621],
        [0.9658, 0.0342],
        [0.7543, 0.2457],
        [0.3134, 0.6866],
        [0.1336, 0.8664],
        [0.3069, 0.6931],
        [0.0257, 0.9743],
        [0.3287, 0.6713],
        [0.7789, 0.2211],
        [0.1109, 0.8891],
        [0.9369, 0.0631]])


Predicting:  45%|████▍     | 96/215 [13:58<16:39,  8.40s/it]

tensor([[0.8858, 0.1142],
        [0.7989, 0.2011],
        [0.6081, 0.3919],
        [0.3008, 0.6992],
        [0.7881, 0.2119],
        [0.4088, 0.5912],
        [0.8518, 0.1482],
        [0.9952, 0.0048],
        [0.9632, 0.0368],
        [0.9751, 0.0249],
        [0.9125, 0.0875],
        [0.6950, 0.3050],
        [0.9288, 0.0712],
        [0.9835, 0.0165],
        [0.8825, 0.1175],
        [0.2111, 0.7889]])


Predicting:  45%|████▌     | 97/215 [14:07<16:34,  8.43s/it]

tensor([[3.6986e-03, 9.9630e-01],
        [8.8452e-01, 1.1548e-01],
        [9.8851e-01, 1.1490e-02],
        [9.8620e-01, 1.3798e-02],
        [9.9033e-01, 9.6732e-03],
        [9.7807e-01, 2.1931e-02],
        [9.9203e-01, 7.9700e-03],
        [9.7358e-01, 2.6419e-02],
        [6.2439e-01, 3.7561e-01],
        [9.3194e-01, 6.8056e-02],
        [7.1120e-01, 2.8880e-01],
        [9.9941e-01, 5.9417e-04],
        [1.1234e-01, 8.8766e-01],
        [8.7828e-01, 1.2172e-01],
        [9.9970e-01, 3.0100e-04],
        [9.9976e-01, 2.4187e-04]])


Predicting:  46%|████▌     | 98/215 [14:15<16:31,  8.48s/it]

tensor([[7.4812e-01, 2.5188e-01],
        [9.8349e-01, 1.6508e-02],
        [9.4581e-01, 5.4192e-02],
        [6.2438e-01, 3.7562e-01],
        [7.7174e-01, 2.2826e-01],
        [9.2410e-01, 7.5897e-02],
        [8.4194e-01, 1.5806e-01],
        [9.9186e-01, 8.1445e-03],
        [4.9957e-01, 5.0043e-01],
        [2.0643e-01, 7.9357e-01],
        [6.9798e-01, 3.0202e-01],
        [9.3866e-01, 6.1341e-02],
        [9.9192e-01, 8.0831e-03],
        [1.1241e-04, 9.9989e-01],
        [9.9834e-01, 1.6582e-03],
        [9.9169e-01, 8.3089e-03]])


Predicting:  46%|████▌     | 99/215 [14:24<16:16,  8.42s/it]

tensor([[7.7026e-01, 2.2974e-01],
        [9.9913e-01, 8.7348e-04],
        [9.9961e-01, 3.8504e-04],
        [7.7709e-04, 9.9922e-01],
        [2.2143e-01, 7.7857e-01],
        [9.6663e-01, 3.3375e-02],
        [6.6426e-01, 3.3574e-01],
        [2.5383e-01, 7.4617e-01],
        [7.3472e-01, 2.6528e-01],
        [9.6857e-01, 3.1429e-02],
        [9.7003e-01, 2.9965e-02],
        [9.0364e-01, 9.6359e-02],
        [9.4026e-01, 5.9743e-02],
        [1.0091e-01, 8.9909e-01],
        [9.9559e-01, 4.4135e-03],
        [4.0982e-01, 5.9018e-01]])


Predicting:  47%|████▋     | 100/215 [14:32<16:02,  8.37s/it]

tensor([[9.9092e-02, 9.0091e-01],
        [7.2892e-01, 2.7108e-01],
        [3.7469e-01, 6.2531e-01],
        [1.3770e-01, 8.6230e-01],
        [9.4190e-01, 5.8100e-02],
        [1.4252e-01, 8.5748e-01],
        [8.9827e-01, 1.0173e-01],
        [8.9779e-01, 1.0221e-01],
        [5.1620e-01, 4.8380e-01],
        [8.6056e-01, 1.3944e-01],
        [9.9541e-01, 4.5947e-03],
        [9.9836e-01, 1.6400e-03],
        [8.3883e-01, 1.6117e-01],
        [1.2920e-04, 9.9987e-01],
        [9.7472e-01, 2.5276e-02],
        [9.4811e-01, 5.1890e-02]])


Predicting:  47%|████▋     | 101/215 [14:40<15:52,  8.35s/it]

tensor([[5.3660e-01, 4.6340e-01],
        [9.9121e-01, 8.7886e-03],
        [9.8550e-01, 1.4500e-02],
        [2.2700e-01, 7.7300e-01],
        [5.4672e-01, 4.5328e-01],
        [9.9745e-01, 2.5497e-03],
        [9.8882e-01, 1.1176e-02],
        [5.2078e-01, 4.7922e-01],
        [9.9034e-01, 9.6573e-03],
        [9.9822e-01, 1.7764e-03],
        [9.8847e-01, 1.1533e-02],
        [1.0912e-04, 9.9989e-01],
        [9.9832e-01, 1.6805e-03],
        [9.9251e-01, 7.4901e-03],
        [8.7174e-01, 1.2826e-01],
        [9.9977e-01, 2.3296e-04]])


Predicting:  47%|████▋     | 102/215 [14:49<15:43,  8.35s/it]

tensor([[9.9980e-01, 1.9897e-04],
        [9.6653e-01, 3.3467e-02],
        [2.2343e-01, 7.7657e-01],
        [7.0553e-01, 2.9447e-01],
        [4.8411e-01, 5.1589e-01],
        [1.1442e-01, 8.8558e-01],
        [2.3205e-02, 9.7680e-01],
        [3.4223e-01, 6.5777e-01],
        [9.8142e-01, 1.8578e-02],
        [3.8754e-04, 9.9961e-01],
        [9.9947e-01, 5.2745e-04],
        [8.9091e-01, 1.0909e-01],
        [1.5919e-01, 8.4081e-01],
        [9.9846e-01, 1.5388e-03],
        [9.9976e-01, 2.4337e-04],
        [6.5573e-02, 9.3443e-01]])


Predicting:  48%|████▊     | 103/215 [14:57<15:36,  8.36s/it]

tensor([[8.8043e-01, 1.1957e-01],
        [9.9875e-01, 1.2528e-03],
        [9.7160e-01, 2.8399e-02],
        [6.4059e-01, 3.5941e-01],
        [7.4480e-01, 2.5520e-01],
        [9.8968e-01, 1.0316e-02],
        [9.6612e-01, 3.3884e-02],
        [7.3714e-01, 2.6286e-01],
        [6.2376e-01, 3.7624e-01],
        [9.8838e-01, 1.1618e-02],
        [8.0288e-01, 1.9712e-01],
        [9.6876e-01, 3.1241e-02],
        [2.2575e-01, 7.7425e-01],
        [9.9798e-01, 2.0238e-03],
        [1.0133e-04, 9.9990e-01],
        [9.9932e-01, 6.7748e-04]])


Predicting:  48%|████▊     | 104/215 [15:05<15:28,  8.36s/it]

tensor([[9.9759e-01, 2.4147e-03],
        [7.8041e-01, 2.1959e-01],
        [9.9984e-01, 1.6502e-04],
        [9.9990e-01, 1.0014e-04],
        [6.5445e-04, 9.9935e-01],
        [8.1297e-01, 1.8703e-01],
        [9.9780e-01, 2.2020e-03],
        [9.6280e-01, 3.7195e-02],
        [7.3919e-01, 2.6081e-01],
        [9.8779e-01, 1.2214e-02],
        [9.9311e-01, 6.8882e-03],
        [8.9113e-01, 1.0887e-01],
        [1.4307e-03, 9.9857e-01],
        [9.9211e-01, 7.8871e-03],
        [9.5428e-01, 4.5716e-02],
        [5.5234e-01, 4.4766e-01]])


Predicting:  49%|████▉     | 105/215 [15:14<15:18,  8.35s/it]

tensor([[9.9664e-01, 3.3568e-03],
        [9.9937e-01, 6.3447e-04],
        [4.7371e-01, 5.2629e-01],
        [9.9549e-01, 4.5076e-03],
        [8.4187e-01, 1.5813e-01],
        [3.4455e-01, 6.5545e-01],
        [9.8259e-01, 1.7412e-02],
        [9.9621e-01, 3.7862e-03],
        [6.8335e-02, 9.3167e-01],
        [8.0533e-01, 1.9467e-01],
        [5.1007e-01, 4.8993e-01],
        [9.3887e-02, 9.0611e-01],
        [1.2764e-01, 8.7236e-01],
        [7.9353e-01, 2.0647e-01],
        [9.0886e-01, 9.1141e-02],
        [5.0471e-02, 9.4953e-01]])


Predicting:  49%|████▉     | 106/215 [15:23<15:23,  8.47s/it]

tensor([[0.5038, 0.4962],
        [0.8241, 0.1759],
        [0.6557, 0.3443],
        [0.9335, 0.0665],
        [0.8766, 0.1234],
        [0.3019, 0.6981],
        [0.8311, 0.1689],
        [0.5615, 0.4385],
        [0.8711, 0.1289],
        [0.8701, 0.1299],
        [0.8490, 0.1510],
        [0.9659, 0.0341],
        [0.9689, 0.0311],
        [0.8987, 0.1013],
        [0.9910, 0.0090],
        [0.7153, 0.2847]])


Predicting:  50%|████▉     | 107/215 [15:31<15:21,  8.54s/it]

tensor([[9.8267e-01, 1.7332e-02],
        [7.4294e-01, 2.5706e-01],
        [9.7639e-01, 2.3608e-02],
        [5.5942e-01, 4.4058e-01],
        [8.8484e-01, 1.1516e-01],
        [5.2065e-01, 4.7935e-01],
        [9.8749e-01, 1.2512e-02],
        [9.5044e-01, 4.9558e-02],
        [9.9929e-01, 7.1069e-04],
        [9.2816e-01, 7.1844e-02],
        [9.9845e-01, 1.5507e-03],
        [4.3420e-04, 9.9957e-01],
        [9.9855e-01, 1.4542e-03],
        [5.4028e-01, 4.5972e-01],
        [9.2768e-01, 7.2320e-02],
        [9.9238e-01, 7.6221e-03]])


Predicting:  50%|█████     | 108/215 [15:40<15:38,  8.77s/it]

tensor([[9.9953e-01, 4.7155e-04],
        [2.2983e-02, 9.7702e-01],
        [9.8011e-01, 1.9895e-02],
        [8.9561e-01, 1.0439e-01],
        [6.7898e-01, 3.2102e-01],
        [9.8519e-01, 1.4809e-02],
        [9.9041e-01, 9.5920e-03],
        [9.9513e-01, 4.8746e-03],
        [8.9363e-01, 1.0637e-01],
        [9.9440e-01, 5.5998e-03],
        [6.5706e-01, 3.4294e-01],
        [9.4575e-01, 5.4249e-02],
        [5.8238e-01, 4.1762e-01],
        [6.2344e-01, 3.7656e-01],
        [6.0137e-01, 3.9863e-01],
        [9.4044e-01, 5.9558e-02]])


Predicting:  51%|█████     | 109/215 [15:50<15:38,  8.85s/it]

tensor([[9.9818e-01, 1.8218e-03],
        [9.8518e-01, 1.4819e-02],
        [7.7580e-01, 2.2420e-01],
        [8.2353e-01, 1.7647e-01],
        [9.9358e-01, 6.4248e-03],
        [9.2158e-01, 7.8417e-02],
        [9.9090e-01, 9.1047e-03],
        [9.9709e-01, 2.9125e-03],
        [9.9519e-01, 4.8063e-03],
        [3.0312e-01, 6.9688e-01],
        [8.9738e-01, 1.0262e-01],
        [9.7380e-01, 2.6204e-02],
        [8.5671e-01, 1.4329e-01],
        [9.9937e-01, 6.2777e-04],
        [6.6742e-02, 9.3326e-01],
        [9.9565e-01, 4.3509e-03]])


Predicting:  51%|█████     | 110/215 [15:58<15:31,  8.87s/it]

tensor([[0.9693, 0.0307],
        [0.9605, 0.0395],
        [0.9952, 0.0048],
        [0.9156, 0.0844],
        [0.9811, 0.0189],
        [0.9535, 0.0465],
        [0.9311, 0.0689],
        [0.5746, 0.4254],
        [0.9700, 0.0300],
        [0.0500, 0.9500],
        [0.9772, 0.0228],
        [0.8617, 0.1383],
        [0.0034, 0.9966],
        [0.4920, 0.5080],
        [0.7949, 0.2051],
        [0.5348, 0.4652]])


Predicting:  52%|█████▏    | 111/215 [16:07<15:23,  8.88s/it]

tensor([[0.1967, 0.8033],
        [0.6658, 0.3342],
        [0.6943, 0.3057],
        [0.0053, 0.9947],
        [0.4421, 0.5579],
        [0.6232, 0.3768],
        [0.9359, 0.0641],
        [0.8476, 0.1524],
        [0.9534, 0.0466],
        [0.8482, 0.1518],
        [0.8986, 0.1014],
        [0.7627, 0.2373],
        [0.4435, 0.5565],
        [0.6326, 0.3674],
        [0.9533, 0.0467],
        [0.1959, 0.8041]])


Predicting:  52%|█████▏    | 112/215 [16:16<15:19,  8.93s/it]

tensor([[9.4832e-01, 5.1675e-02],
        [9.1959e-01, 8.0406e-02],
        [8.9497e-01, 1.0503e-01],
        [9.8169e-01, 1.8314e-02],
        [7.6615e-01, 2.3385e-01],
        [9.7939e-01, 2.0613e-02],
        [9.7392e-01, 2.6084e-02],
        [4.5045e-01, 5.4955e-01],
        [9.7271e-01, 2.7289e-02],
        [9.9873e-01, 1.2696e-03],
        [9.9819e-01, 1.8071e-03],
        [9.9983e-01, 1.7280e-04],
        [9.9870e-01, 1.2964e-03],
        [9.9991e-01, 9.3159e-05],
        [6.2225e-01, 3.7775e-01],
        [9.9961e-01, 3.8966e-04]])


Predicting:  53%|█████▎    | 113/215 [16:25<15:07,  8.90s/it]

tensor([[0.9968, 0.0032],
        [0.9198, 0.0802],
        [0.9746, 0.0254],
        [0.6352, 0.3648],
        [0.9887, 0.0113],
        [0.9910, 0.0090],
        [0.9758, 0.0242],
        [0.5772, 0.4228],
        [0.9452, 0.0548],
        [0.0019, 0.9981],
        [0.9547, 0.0453],
        [0.7355, 0.2645],
        [0.8633, 0.1367],
        [0.9857, 0.0143],
        [0.2198, 0.7802],
        [0.1197, 0.8803]])


Predicting:  53%|█████▎    | 114/215 [16:34<14:49,  8.81s/it]

tensor([[0.8530, 0.1470],
        [0.2615, 0.7385],
        [0.7803, 0.2197],
        [0.5849, 0.4151],
        [0.6001, 0.3999],
        [0.7914, 0.2086],
        [0.7580, 0.2420],
        [0.8650, 0.1350],
        [0.6767, 0.3233],
        [0.0638, 0.9362],
        [0.9373, 0.0627],
        [0.0048, 0.9952],
        [0.4409, 0.5591],
        [0.9859, 0.0141],
        [0.8905, 0.1095],
        [0.4870, 0.5130]])


Predicting:  53%|█████▎    | 115/215 [16:43<14:38,  8.79s/it]

tensor([[9.6119e-01, 3.8813e-02],
        [9.6802e-01, 3.1980e-02],
        [1.5293e-03, 9.9847e-01],
        [6.2706e-01, 3.7294e-01],
        [9.9571e-01, 4.2941e-03],
        [9.5551e-01, 4.4488e-02],
        [7.8481e-01, 2.1519e-01],
        [9.7971e-01, 2.0291e-02],
        [9.9227e-01, 7.7344e-03],
        [2.7532e-01, 7.2468e-01],
        [9.0598e-01, 9.4022e-02],
        [8.9747e-01, 1.0253e-01],
        [9.1396e-01, 8.6043e-02],
        [8.8679e-01, 1.1321e-01],
        [9.8660e-01, 1.3402e-02],
        [2.8040e-04, 9.9972e-01]])


Predicting:  54%|█████▍    | 116/215 [16:51<14:21,  8.71s/it]

tensor([[4.4005e-01, 5.5995e-01],
        [2.0863e-01, 7.9137e-01],
        [1.5463e-01, 8.4537e-01],
        [9.9851e-01, 1.4916e-03],
        [9.9954e-01, 4.5835e-04],
        [3.5897e-01, 6.4103e-01],
        [6.6324e-01, 3.3676e-01],
        [6.9691e-01, 3.0309e-01],
        [5.1007e-01, 4.8993e-01],
        [9.9358e-02, 9.0064e-01],
        [9.7739e-01, 2.2614e-02],
        [9.7861e-01, 2.1389e-02],
        [9.4958e-01, 5.0420e-02],
        [9.7506e-01, 2.4938e-02],
        [1.5280e-01, 8.4720e-01],
        [9.8877e-01, 1.1230e-02]])


Predicting:  54%|█████▍    | 117/215 [17:00<14:06,  8.64s/it]

tensor([[0.0246, 0.9754],
        [0.9861, 0.0139],
        [0.5461, 0.4539],
        [0.9915, 0.0085],
        [0.2333, 0.7667],
        [0.9747, 0.0254],
        [0.9322, 0.0678],
        [0.9866, 0.0134],
        [0.8997, 0.1003],
        [0.8459, 0.1541],
        [0.6557, 0.3443],
        [0.5085, 0.4915],
        [0.2990, 0.7010],
        [0.0299, 0.9701],
        [0.5528, 0.4472],
        [0.9567, 0.0433]])


Predicting:  55%|█████▍    | 118/215 [17:08<13:47,  8.53s/it]

tensor([[4.1889e-04, 9.9958e-01],
        [9.3028e-02, 9.0697e-01],
        [9.5240e-01, 4.7599e-02],
        [5.4075e-01, 4.5925e-01],
        [2.2754e-01, 7.7246e-01],
        [7.6711e-01, 2.3289e-01],
        [9.6751e-01, 3.2493e-02],
        [2.0019e-03, 9.9800e-01],
        [3.2941e-01, 6.7059e-01],
        [9.6224e-01, 3.7759e-02],
        [8.2878e-01, 1.7122e-01],
        [1.0526e-01, 8.9474e-01],
        [3.6388e-01, 6.3612e-01],
        [9.7614e-01, 2.3862e-02],
        [4.0014e-01, 5.9987e-01],
        [7.8604e-01, 2.1396e-01]])


Predicting:  55%|█████▌    | 119/215 [17:17<13:43,  8.58s/it]

tensor([[0.6593, 0.3407],
        [0.8157, 0.1843],
        [0.2837, 0.7163],
        [0.0348, 0.9652],
        [0.4991, 0.5009],
        [0.0692, 0.9308],
        [0.8021, 0.1979],
        [0.8880, 0.1120],
        [0.9107, 0.0893],
        [0.4070, 0.5930],
        [0.3396, 0.6604],
        [0.8787, 0.1213],
        [0.3473, 0.6527],
        [0.7158, 0.2842],
        [0.8200, 0.1800],
        [0.7736, 0.2264]])


Predicting:  56%|█████▌    | 120/215 [17:25<13:35,  8.59s/it]

tensor([[0.9746, 0.0254],
        [0.9683, 0.0317],
        [0.8499, 0.1501],
        [0.9615, 0.0385],
        [0.8857, 0.1143],
        [0.8632, 0.1368],
        [0.2188, 0.7812],
        [0.9564, 0.0436],
        [0.1292, 0.8708],
        [0.5231, 0.4769],
        [0.7574, 0.2426],
        [0.5085, 0.4915],
        [0.5467, 0.4533],
        [0.5425, 0.4575],
        [0.6665, 0.3335],
        [0.8982, 0.1018]])


Predicting:  56%|█████▋    | 121/215 [17:34<13:36,  8.68s/it]

tensor([[9.0509e-01, 9.4911e-02],
        [9.8339e-01, 1.6614e-02],
        [9.1778e-01, 8.2224e-02],
        [7.4330e-01, 2.5670e-01],
        [3.1462e-01, 6.8538e-01],
        [9.9831e-01, 1.6904e-03],
        [3.3150e-03, 9.9669e-01],
        [8.9355e-01, 1.0645e-01],
        [8.8130e-01, 1.1870e-01],
        [8.0939e-01, 1.9061e-01],
        [9.8203e-01, 1.7967e-02],
        [9.8782e-01, 1.2183e-02],
        [9.9934e-01, 6.6390e-04],
        [8.8554e-01, 1.1446e-01],
        [9.9854e-01, 1.4577e-03],
        [3.6072e-01, 6.3928e-01]])


Predicting:  57%|█████▋    | 122/215 [17:43<13:26,  8.67s/it]

tensor([[1.1192e-01, 8.8808e-01],
        [7.6059e-03, 9.9239e-01],
        [6.4925e-01, 3.5075e-01],
        [5.5735e-01, 4.4265e-01],
        [9.9958e-01, 4.2198e-04],
        [9.9992e-01, 8.2312e-05],
        [3.3852e-01, 6.6148e-01],
        [9.9547e-01, 4.5289e-03],
        [6.8630e-04, 9.9931e-01],
        [1.8902e-01, 8.1098e-01],
        [9.9077e-01, 9.2288e-03],
        [9.9683e-01, 3.1682e-03],
        [5.0077e-01, 4.9923e-01],
        [1.4121e-04, 9.9986e-01],
        [3.3374e-01, 6.6626e-01],
        [9.6879e-01, 3.1207e-02]])


Predicting:  57%|█████▋    | 123/215 [17:52<13:26,  8.76s/it]

tensor([[0.5644, 0.4356],
        [0.2145, 0.7855],
        [0.9892, 0.0108],
        [0.6186, 0.3814],
        [0.6392, 0.3608],
        [0.9244, 0.0756],
        [0.7338, 0.2662],
        [0.9866, 0.0134],
        [0.9222, 0.0778],
        [0.9802, 0.0198],
        [0.9152, 0.0848],
        [0.2234, 0.7766],
        [0.9261, 0.0739],
        [0.9609, 0.0391],
        [0.8623, 0.1377],
        [0.5562, 0.4438]])


Predicting:  58%|█████▊    | 124/215 [18:01<13:31,  8.91s/it]

tensor([[1.3804e-04, 9.9986e-01],
        [1.2405e-02, 9.8759e-01],
        [9.8579e-01, 1.4209e-02],
        [8.6015e-01, 1.3985e-01],
        [9.2284e-01, 7.7165e-02],
        [8.4523e-01, 1.5477e-01],
        [9.7746e-01, 2.2537e-02],
        [8.0940e-01, 1.9060e-01],
        [4.9044e-01, 5.0956e-01],
        [6.3191e-01, 3.6809e-01],
        [8.1760e-01, 1.8240e-01],
        [9.9319e-01, 6.8101e-03],
        [8.4319e-01, 1.5681e-01],
        [8.9389e-01, 1.0611e-01],
        [9.7344e-01, 2.6560e-02],
        [9.9942e-01, 5.8104e-04]])


Predicting:  58%|█████▊    | 125/215 [18:10<13:20,  8.89s/it]

tensor([[9.9765e-01, 2.3536e-03],
        [9.9933e-01, 6.6574e-04],
        [9.9919e-01, 8.1406e-04],
        [4.4469e-01, 5.5531e-01],
        [1.0065e-04, 9.9990e-01],
        [9.6495e-01, 3.5049e-02],
        [9.6697e-01, 3.3031e-02],
        [4.8692e-01, 5.1308e-01],
        [8.7454e-01, 1.2546e-01],
        [9.9725e-01, 2.7487e-03],
        [8.8853e-01, 1.1147e-01],
        [1.7331e-01, 8.2669e-01],
        [3.1955e-01, 6.8045e-01],
        [5.7412e-02, 9.4259e-01],
        [6.9425e-03, 9.9306e-01],
        [9.9764e-01, 2.3555e-03]])


Predicting:  59%|█████▊    | 126/215 [18:18<13:06,  8.84s/it]

tensor([[9.9529e-01, 4.7126e-03],
        [1.2380e-04, 9.9988e-01],
        [3.4825e-01, 6.5175e-01],
        [9.9195e-01, 8.0487e-03],
        [9.9975e-01, 2.4952e-04],
        [6.3600e-01, 3.6400e-01],
        [1.8923e-04, 9.9981e-01],
        [6.2632e-01, 3.7368e-01],
        [9.2618e-01, 7.3815e-02],
        [2.4377e-01, 7.5623e-01],
        [8.7596e-01, 1.2404e-01],
        [8.4578e-01, 1.5422e-01],
        [7.9880e-03, 9.9201e-01],
        [8.0514e-01, 1.9486e-01],
        [5.6762e-01, 4.3238e-01],
        [4.3691e-01, 5.6309e-01]])


Predicting:  59%|█████▉    | 127/215 [18:27<12:58,  8.84s/it]

tensor([[9.4116e-01, 5.8842e-02],
        [1.1923e-04, 9.9988e-01],
        [7.8924e-01, 2.1076e-01],
        [9.9630e-01, 3.7021e-03],
        [5.2362e-01, 4.7638e-01],
        [1.1070e-04, 9.9989e-01],
        [9.9399e-01, 6.0086e-03],
        [9.9945e-01, 5.4815e-04],
        [6.3976e-01, 3.6024e-01],
        [1.5599e-01, 8.4401e-01],
        [5.7375e-01, 4.2625e-01],
        [1.3199e-04, 9.9987e-01],
        [9.9854e-01, 1.4589e-03],
        [9.8886e-01, 1.1142e-02],
        [7.8404e-01, 2.1596e-01],
        [9.9918e-01, 8.1617e-04]])


Predicting:  60%|█████▉    | 128/215 [18:36<12:48,  8.84s/it]

tensor([[9.9967e-01, 3.3473e-04],
        [9.6405e-01, 3.5948e-02],
        [9.3638e-03, 9.9064e-01],
        [9.6744e-01, 3.2564e-02],
        [9.9960e-01, 4.0312e-04],
        [9.4843e-01, 5.1574e-02],
        [5.5034e-01, 4.4966e-01],
        [7.1781e-01, 2.8219e-01],
        [8.0021e-02, 9.1998e-01],
        [9.7840e-01, 2.1599e-02],
        [4.2581e-01, 5.7419e-01],
        [7.8223e-01, 2.1777e-01],
        [9.9317e-01, 6.8256e-03],
        [2.7972e-01, 7.2028e-01],
        [9.9327e-01, 6.7293e-03],
        [9.8536e-01, 1.4641e-02]])


Predicting:  60%|██████    | 129/215 [18:45<12:46,  8.91s/it]

tensor([[0.9965, 0.0035],
        [0.9979, 0.0021],
        [0.9954, 0.0046],
        [0.9944, 0.0056],
        [0.8791, 0.1209],
        [0.9911, 0.0089],
        [0.5501, 0.4499],
        [0.3265, 0.6735],
        [0.8811, 0.1189],
        [0.9420, 0.0580],
        [0.9979, 0.0021],
        [0.6996, 0.3004],
        [0.9988, 0.0012],
        [0.0708, 0.9292],
        [0.9976, 0.0024],
        [0.9387, 0.0613]])


Predicting:  60%|██████    | 130/215 [18:54<12:27,  8.80s/it]

tensor([[3.1361e-01, 6.8639e-01],
        [6.8092e-01, 3.1908e-01],
        [9.6408e-04, 9.9904e-01],
        [8.0598e-01, 1.9402e-01],
        [1.8924e-03, 9.9811e-01],
        [9.9942e-01, 5.8293e-04],
        [9.9990e-01, 9.6979e-05],
        [9.1885e-01, 8.1153e-02],
        [8.0135e-01, 1.9865e-01],
        [9.9682e-01, 3.1769e-03],
        [9.9777e-01, 2.2330e-03],
        [9.9855e-01, 1.4542e-03],
        [9.9882e-01, 1.1797e-03],
        [9.9082e-01, 9.1811e-03],
        [9.9438e-01, 5.6207e-03],
        [9.9650e-01, 3.5042e-03]])


Predicting:  61%|██████    | 131/215 [19:02<12:10,  8.69s/it]

tensor([[1.4394e-01, 8.5606e-01],
        [8.7408e-01, 1.2592e-01],
        [2.2136e-04, 9.9978e-01],
        [9.9992e-01, 7.6490e-05],
        [9.9995e-01, 4.9393e-05],
        [4.4896e-02, 9.5510e-01],
        [9.9249e-01, 7.5076e-03],
        [9.9084e-01, 9.1624e-03],
        [6.0492e-04, 9.9940e-01],
        [8.1134e-01, 1.8866e-01],
        [2.9490e-04, 9.9971e-01],
        [9.9877e-01, 1.2317e-03],
        [9.7631e-01, 2.3687e-02],
        [9.9965e-01, 3.5002e-04],
        [9.9904e-01, 9.6385e-04],
        [9.9186e-01, 8.1401e-03]])


Predicting:  61%|██████▏   | 132/215 [19:11<12:09,  8.79s/it]

tensor([[9.9321e-01, 6.7900e-03],
        [2.2740e-01, 7.7260e-01],
        [3.8161e-01, 6.1839e-01],
        [6.7208e-01, 3.2792e-01],
        [2.7612e-01, 7.2388e-01],
        [2.2074e-01, 7.7926e-01],
        [9.9775e-01, 2.2537e-03],
        [4.4619e-02, 9.5538e-01],
        [8.1764e-01, 1.8236e-01],
        [7.9735e-04, 9.9920e-01],
        [9.8443e-01, 1.5566e-02],
        [1.2053e-02, 9.8795e-01],
        [8.2858e-01, 1.7142e-01],
        [8.1788e-01, 1.8212e-01],
        [9.7311e-01, 2.6885e-02],
        [3.4922e-02, 9.6508e-01]])


Predicting:  62%|██████▏   | 133/215 [19:20<12:07,  8.87s/it]

tensor([[9.9834e-01, 1.6592e-03],
        [7.3257e-01, 2.6743e-01],
        [8.7962e-01, 1.2038e-01],
        [9.9130e-01, 8.6998e-03],
        [7.5296e-01, 2.4704e-01],
        [9.8437e-01, 1.5626e-02],
        [9.1530e-01, 8.4703e-02],
        [9.3972e-01, 6.0275e-02],
        [9.7818e-01, 2.1821e-02],
        [1.1907e-04, 9.9988e-01],
        [9.2315e-01, 7.6848e-02],
        [9.9825e-01, 1.7502e-03],
        [9.9990e-01, 1.0328e-04],
        [9.9993e-01, 6.5461e-05],
        [7.9218e-01, 2.0782e-01],
        [9.9922e-01, 7.7765e-04]])


Predicting:  62%|██████▏   | 134/215 [19:30<12:11,  9.03s/it]

tensor([[9.9975e-01, 2.4875e-04],
        [9.7620e-02, 9.0238e-01],
        [2.1609e-04, 9.9978e-01],
        [2.0315e-01, 7.9685e-01],
        [9.9327e-01, 6.7266e-03],
        [9.7233e-01, 2.7670e-02],
        [3.4305e-01, 6.5695e-01],
        [1.5560e-04, 9.9984e-01],
        [1.1781e-01, 8.8219e-01],
        [2.6187e-03, 9.9738e-01],
        [9.7814e-01, 2.1856e-02],
        [9.9689e-01, 3.1145e-03],
        [8.0531e-01, 1.9469e-01],
        [9.9774e-01, 2.2640e-03],
        [9.9126e-01, 8.7446e-03],
        [9.9819e-01, 1.8103e-03]])


Predicting:  63%|██████▎   | 135/215 [19:39<12:06,  9.08s/it]

tensor([[9.8503e-01, 1.4974e-02],
        [9.6844e-01, 3.1563e-02],
        [9.9676e-01, 3.2442e-03],
        [9.9729e-01, 2.7078e-03],
        [5.1963e-01, 4.8037e-01],
        [9.9510e-01, 4.8985e-03],
        [8.2890e-01, 1.7110e-01],
        [9.9969e-01, 3.1099e-04],
        [1.7655e-02, 9.8234e-01],
        [9.8248e-01, 1.7525e-02],
        [9.8772e-01, 1.2275e-02],
        [9.9791e-01, 2.0888e-03],
        [9.8099e-01, 1.9012e-02],
        [9.5827e-01, 4.1730e-02],
        [8.1603e-01, 1.8397e-01],
        [9.2349e-01, 7.6510e-02]])


Predicting:  63%|██████▎   | 136/215 [19:48<11:49,  8.98s/it]

tensor([[4.7531e-01, 5.2469e-01],
        [8.3582e-01, 1.6418e-01],
        [2.5605e-03, 9.9744e-01],
        [9.7943e-01, 2.0566e-02],
        [9.6463e-01, 3.5374e-02],
        [6.6761e-01, 3.3240e-01],
        [1.0843e-01, 8.9157e-01],
        [4.8330e-01, 5.1670e-01],
        [8.6301e-01, 1.3699e-01],
        [9.9381e-05, 9.9990e-01],
        [8.9804e-01, 1.0196e-01],
        [1.0883e-04, 9.9989e-01],
        [3.0900e-01, 6.9100e-01],
        [8.8056e-01, 1.1944e-01],
        [6.5233e-02, 9.3477e-01],
        [1.1675e-04, 9.9988e-01]])


Predicting:  64%|██████▎   | 137/215 [19:56<11:34,  8.91s/it]

tensor([[0.9744, 0.0256],
        [0.0275, 0.9725],
        [0.7346, 0.2654],
        [0.8813, 0.1187],
        [0.9800, 0.0200],
        [0.9068, 0.0932],
        [0.8759, 0.1241],
        [0.8373, 0.1627],
        [0.2231, 0.7769],
        [0.9906, 0.0094],
        [0.5575, 0.4425],
        [0.9873, 0.0127],
        [0.9977, 0.0023],
        [0.7797, 0.2203],
        [0.9841, 0.0159],
        [0.6884, 0.3116]])


Predicting:  64%|██████▍   | 138/215 [20:05<11:12,  8.73s/it]

tensor([[0.5023, 0.4977],
        [0.9661, 0.0339],
        [0.9859, 0.0141],
        [0.9820, 0.0180],
        [0.9896, 0.0104],
        [0.9984, 0.0016],
        [0.9976, 0.0024],
        [0.4475, 0.5525],
        [0.0165, 0.9835],
        [0.9971, 0.0029],
        [0.1196, 0.8804],
        [0.4893, 0.5107],
        [0.3812, 0.6188],
        [0.4540, 0.5460],
        [0.2909, 0.7091],
        [0.0068, 0.9932]])


Predicting:  65%|██████▍   | 139/215 [20:13<11:02,  8.71s/it]

tensor([[0.3232, 0.6768],
        [0.7630, 0.2370],
        [0.9748, 0.0252],
        [0.9988, 0.0012],
        [0.9933, 0.0067],
        [0.9909, 0.0091],
        [0.8909, 0.1091],
        [0.7465, 0.2535],
        [0.7690, 0.2310],
        [0.9129, 0.0871],
        [0.8068, 0.1932],
        [0.5067, 0.4933],
        [0.8050, 0.1950],
        [0.8879, 0.1121],
        [0.4486, 0.5514],
        [0.8229, 0.1771]])


Predicting:  65%|██████▌   | 140/215 [20:22<10:48,  8.65s/it]

tensor([[9.7280e-01, 2.7205e-02],
        [9.0908e-01, 9.0925e-02],
        [8.4003e-01, 1.5997e-01],
        [9.7133e-01, 2.8669e-02],
        [5.9745e-01, 4.0255e-01],
        [1.8647e-01, 8.1353e-01],
        [3.6412e-01, 6.3588e-01],
        [1.7535e-01, 8.2465e-01],
        [6.4175e-02, 9.3582e-01],
        [2.7840e-01, 7.2160e-01],
        [9.9334e-01, 6.6571e-03],
        [9.8543e-01, 1.4574e-02],
        [9.9976e-01, 2.3731e-04],
        [2.4338e-03, 9.9757e-01],
        [7.5784e-01, 2.4216e-01],
        [8.2429e-01, 1.7571e-01]])


Predicting:  66%|██████▌   | 141/215 [20:30<10:36,  8.60s/it]

tensor([[0.9638, 0.0362],
        [0.7805, 0.2195],
        [0.9694, 0.0306],
        [0.7333, 0.2667],
        [0.9060, 0.0940],
        [0.9863, 0.0137],
        [0.9980, 0.0020],
        [0.8686, 0.1314],
        [0.7544, 0.2456],
        [0.9633, 0.0367],
        [0.6462, 0.3538],
        [0.8456, 0.1544],
        [0.9808, 0.0192],
        [0.9512, 0.0488],
        [0.0324, 0.9676],
        [0.9388, 0.0612]])


Predicting:  66%|██████▌   | 142/215 [20:39<10:26,  8.58s/it]

tensor([[8.8843e-01, 1.1157e-01],
        [9.0986e-01, 9.0138e-02],
        [9.9030e-01, 9.6981e-03],
        [9.6801e-01, 3.1994e-02],
        [7.5673e-01, 2.4327e-01],
        [9.9940e-01, 6.0166e-04],
        [9.7599e-01, 2.4011e-02],
        [5.8116e-01, 4.1884e-01],
        [9.9875e-01, 1.2452e-03],
        [8.0104e-01, 1.9896e-01],
        [9.9580e-01, 4.2017e-03],
        [6.5425e-01, 3.4575e-01],
        [9.9518e-01, 4.8187e-03],
        [1.9983e-02, 9.8002e-01],
        [9.8603e-01, 1.3971e-02],
        [9.6722e-01, 3.2782e-02]])


Predicting:  67%|██████▋   | 143/215 [20:48<10:25,  8.69s/it]

tensor([[9.5918e-01, 4.0822e-02],
        [2.8234e-02, 9.7177e-01],
        [4.2925e-01, 5.7075e-01],
        [8.2886e-01, 1.7114e-01],
        [1.2329e-04, 9.9988e-01],
        [7.5730e-01, 2.4270e-01],
        [7.0615e-01, 2.9385e-01],
        [9.8597e-01, 1.4027e-02],
        [7.6896e-01, 2.3104e-01],
        [2.8945e-01, 7.1055e-01],
        [9.4373e-01, 5.6272e-02],
        [9.8533e-01, 1.4667e-02],
        [9.9296e-01, 7.0389e-03],
        [4.3477e-01, 5.6523e-01],
        [9.9867e-01, 1.3331e-03],
        [9.8275e-01, 1.7248e-02]])


Predicting:  67%|██████▋   | 144/215 [20:57<10:18,  8.71s/it]

tensor([[5.0682e-01, 4.9318e-01],
        [9.9911e-01, 8.8689e-04],
        [9.4829e-01, 5.1715e-02],
        [9.8101e-01, 1.8991e-02],
        [9.9629e-01, 3.7119e-03],
        [7.5510e-01, 2.4490e-01],
        [9.9704e-01, 2.9632e-03],
        [9.9830e-01, 1.6970e-03],
        [6.3422e-01, 3.6578e-01],
        [9.9665e-01, 3.3486e-03],
        [6.7684e-01, 3.2316e-01],
        [9.7174e-01, 2.8263e-02],
        [5.5765e-01, 4.4235e-01],
        [1.8108e-03, 9.9819e-01],
        [3.8348e-03, 9.9617e-01],
        [2.1709e-02, 9.7829e-01]])


Predicting:  67%|██████▋   | 145/215 [21:06<10:17,  8.82s/it]

tensor([[9.9944e-01, 5.5711e-04],
        [9.9976e-01, 2.4031e-04],
        [3.0359e-01, 6.9641e-01],
        [9.6239e-01, 3.7613e-02],
        [2.3136e-01, 7.6864e-01],
        [8.3664e-01, 1.6336e-01],
        [9.9438e-01, 5.6232e-03],
        [2.1982e-01, 7.8018e-01],
        [8.6630e-01, 1.3370e-01],
        [2.3994e-03, 9.9760e-01],
        [8.9476e-01, 1.0524e-01],
        [9.9434e-01, 5.6591e-03],
        [6.0955e-01, 3.9045e-01],
        [1.0985e-04, 9.9989e-01],
        [8.2786e-01, 1.7214e-01],
        [3.5265e-02, 9.6473e-01]])


Predicting:  68%|██████▊   | 146/215 [21:14<10:03,  8.75s/it]

tensor([[1.2736e-01, 8.7264e-01],
        [7.7384e-01, 2.2616e-01],
        [1.4503e-03, 9.9855e-01],
        [3.0731e-01, 6.9269e-01],
        [6.1057e-01, 3.8943e-01],
        [9.8250e-01, 1.7502e-02],
        [7.6035e-01, 2.3965e-01],
        [2.0081e-04, 9.9980e-01],
        [8.5440e-01, 1.4560e-01],
        [9.9662e-01, 3.3799e-03],
        [9.3751e-01, 6.2487e-02],
        [2.2361e-04, 9.9978e-01],
        [9.8381e-01, 1.6195e-02],
        [1.9216e-04, 9.9981e-01],
        [8.6411e-01, 1.3589e-01],
        [2.6315e-01, 7.3685e-01]])


Predicting:  68%|██████▊   | 147/215 [21:23<09:49,  8.67s/it]

tensor([[9.4077e-05, 9.9991e-01],
        [6.8558e-01, 3.1442e-01],
        [9.9431e-05, 9.9990e-01],
        [2.2810e-03, 9.9772e-01],
        [1.1699e-01, 8.8301e-01],
        [7.7144e-01, 2.2856e-01],
        [6.9789e-01, 3.0211e-01],
        [6.4330e-01, 3.5670e-01],
        [1.7133e-04, 9.9983e-01],
        [9.9492e-01, 5.0824e-03],
        [9.9970e-01, 3.0233e-04],
        [9.7218e-01, 2.7816e-02],
        [4.7216e-04, 9.9953e-01],
        [9.9451e-01, 5.4902e-03],
        [9.9969e-01, 3.0648e-04],
        [7.5806e-01, 2.4194e-01]])


Predicting:  69%|██████▉   | 148/215 [21:31<09:33,  8.55s/it]

tensor([[1.7750e-01, 8.2250e-01],
        [7.5845e-02, 9.2415e-01],
        [5.2418e-01, 4.7582e-01],
        [1.6857e-04, 9.9983e-01],
        [2.9804e-01, 7.0196e-01],
        [1.1881e-02, 9.8812e-01],
        [4.6567e-02, 9.5343e-01],
        [7.0662e-02, 9.2934e-01],
        [9.9853e-01, 1.4732e-03],
        [9.9963e-01, 3.6767e-04],
        [1.8541e-04, 9.9981e-01],
        [9.0560e-01, 9.4405e-02],
        [5.4233e-04, 9.9946e-01],
        [3.0134e-04, 9.9970e-01],
        [9.6541e-05, 9.9990e-01],
        [7.8435e-01, 2.1565e-01]])


Predicting:  69%|██████▉   | 149/215 [21:40<09:35,  8.71s/it]

tensor([[1.7892e-01, 8.2108e-01],
        [5.7666e-01, 4.2334e-01],
        [1.7319e-04, 9.9983e-01],
        [1.3317e-04, 9.9987e-01],
        [2.9501e-03, 9.9705e-01],
        [1.5647e-04, 9.9984e-01],
        [1.0292e-03, 9.9897e-01],
        [8.5745e-01, 1.4255e-01],
        [2.0352e-04, 9.9980e-01],
        [9.5667e-01, 4.3332e-02],
        [5.4510e-01, 4.5490e-01],
        [1.3623e-04, 9.9986e-01],
        [9.8151e-01, 1.8490e-02],
        [5.3185e-01, 4.6815e-01],
        [9.9974e-01, 2.6008e-04],
        [9.8068e-01, 1.9317e-02]])


Predicting:  70%|██████▉   | 150/215 [21:49<09:25,  8.70s/it]

tensor([[9.9928e-01, 7.1539e-04],
        [7.0441e-01, 2.9559e-01],
        [9.9898e-01, 1.0159e-03],
        [9.9896e-01, 1.0440e-03],
        [8.6833e-01, 1.3167e-01],
        [9.5011e-01, 4.9886e-02],
        [8.8046e-01, 1.1954e-01],
        [9.9825e-01, 1.7550e-03],
        [9.9618e-01, 3.8210e-03],
        [9.9666e-01, 3.3434e-03],
        [9.7399e-01, 2.6013e-02],
        [9.9625e-01, 3.7501e-03],
        [9.9594e-01, 4.0568e-03],
        [9.7776e-01, 2.2243e-02],
        [9.9980e-01, 1.9917e-04],
        [9.9889e-01, 1.1130e-03]])


Predicting:  70%|███████   | 151/215 [21:58<09:29,  8.89s/it]

tensor([[0.3704, 0.6296],
        [0.9640, 0.0360],
        [0.1705, 0.8295],
        [0.0255, 0.9745],
        [0.1401, 0.8599],
        [0.9886, 0.0114],
        [0.9947, 0.0053],
        [0.9951, 0.0049],
        [0.8358, 0.1642],
        [0.5412, 0.4588],
        [0.0595, 0.9405],
        [0.6352, 0.3648],
        [0.4753, 0.5247],
        [0.8678, 0.1322],
        [0.6811, 0.3189],
        [0.1795, 0.8205]])


Predicting:  71%|███████   | 152/215 [22:07<09:18,  8.87s/it]

tensor([[0.0011, 0.9989],
        [0.1443, 0.8557],
        [0.9920, 0.0080],
        [0.7929, 0.2071],
        [0.4504, 0.5496],
        [0.9551, 0.0449],
        [0.9961, 0.0039],
        [0.9257, 0.0743],
        [0.7576, 0.2424],
        [0.9980, 0.0020],
        [0.9819, 0.0181],
        [0.9391, 0.0609],
        [0.9975, 0.0025],
        [0.0477, 0.9523],
        [0.8437, 0.1563],
        [0.9965, 0.0035]])


Predicting:  71%|███████   | 153/215 [22:16<09:13,  8.92s/it]

tensor([[8.2593e-01, 1.7407e-01],
        [5.5542e-01, 4.4458e-01],
        [5.9600e-01, 4.0400e-01],
        [9.6170e-01, 3.8298e-02],
        [7.2984e-01, 2.7016e-01],
        [9.6579e-01, 3.4211e-02],
        [9.9335e-01, 6.6490e-03],
        [9.7978e-01, 2.0222e-02],
        [9.6278e-01, 3.7216e-02],
        [8.9991e-01, 1.0009e-01],
        [9.6697e-01, 3.3032e-02],
        [5.4283e-01, 4.5717e-01],
        [3.4203e-04, 9.9966e-01],
        [5.6172e-01, 4.3828e-01],
        [4.4350e-01, 5.5650e-01],
        [9.8662e-01, 1.3382e-02]])


Predicting:  72%|███████▏  | 154/215 [22:24<08:54,  8.76s/it]

tensor([[9.7989e-01, 2.0115e-02],
        [9.8900e-01, 1.1005e-02],
        [9.1193e-01, 8.8066e-02],
        [9.5363e-01, 4.6370e-02],
        [8.8265e-01, 1.1735e-01],
        [3.2893e-01, 6.7107e-01],
        [7.7648e-01, 2.2352e-01],
        [9.9824e-01, 1.7590e-03],
        [9.7358e-01, 2.6420e-02],
        [1.2915e-04, 9.9987e-01],
        [1.5214e-01, 8.4786e-01],
        [9.9922e-01, 7.7692e-04],
        [9.4652e-01, 5.3485e-02],
        [9.3228e-01, 6.7718e-02],
        [8.0273e-01, 1.9727e-01],
        [9.9939e-01, 6.1153e-04]])


Predicting:  72%|███████▏  | 155/215 [22:33<08:48,  8.81s/it]

tensor([[7.3690e-01, 2.6310e-01],
        [8.0746e-01, 1.9254e-01],
        [5.5078e-01, 4.4922e-01],
        [9.2269e-01, 7.7307e-02],
        [8.3105e-01, 1.6895e-01],
        [2.9522e-01, 7.0478e-01],
        [1.0471e-01, 8.9529e-01],
        [9.9391e-01, 6.0884e-03],
        [9.9937e-01, 6.3087e-04],
        [9.9581e-01, 4.1928e-03],
        [9.4592e-01, 5.4079e-02],
        [9.9740e-01, 2.5981e-03],
        [9.9952e-01, 4.8411e-04],
        [9.7016e-01, 2.9839e-02],
        [1.0841e-01, 8.9159e-01],
        [8.8596e-01, 1.1404e-01]])


Predicting:  73%|███████▎  | 156/215 [22:42<08:44,  8.88s/it]

tensor([[0.1239, 0.8761],
        [0.9906, 0.0094],
        [0.6580, 0.3420],
        [0.8899, 0.1101],
        [0.9923, 0.0077],
        [0.8003, 0.1997],
        [0.4718, 0.5282],
        [0.9630, 0.0370],
        [0.9906, 0.0094],
        [0.9285, 0.0715],
        [0.7431, 0.2569],
        [0.2181, 0.7819],
        [0.8215, 0.1785],
        [0.9456, 0.0544],
        [0.5418, 0.4582],
        [0.7066, 0.2934]])


Predicting:  73%|███████▎  | 157/215 [22:51<08:30,  8.81s/it]

tensor([[9.5769e-01, 4.2306e-02],
        [6.6078e-01, 3.3922e-01],
        [6.5845e-02, 9.3416e-01],
        [1.1717e-04, 9.9988e-01],
        [9.9830e-01, 1.6966e-03],
        [9.9976e-01, 2.3503e-04],
        [7.0853e-02, 9.2915e-01],
        [3.5353e-01, 6.4647e-01],
        [8.8860e-01, 1.1140e-01],
        [9.7321e-01, 2.6794e-02],
        [8.7447e-01, 1.2553e-01],
        [9.2600e-01, 7.3998e-02],
        [8.3336e-01, 1.6664e-01],
        [9.9326e-01, 6.7422e-03],
        [7.9567e-01, 2.0433e-01],
        [9.9543e-01, 4.5744e-03]])


Predicting:  73%|███████▎  | 158/215 [23:00<08:18,  8.74s/it]

tensor([[9.2305e-01, 7.6952e-02],
        [7.5177e-01, 2.4823e-01],
        [9.9915e-01, 8.5208e-04],
        [9.9496e-01, 5.0378e-03],
        [9.9846e-01, 1.5392e-03],
        [3.3923e-01, 6.6077e-01],
        [6.7161e-01, 3.2839e-01],
        [9.1288e-05, 9.9991e-01],
        [9.9627e-01, 3.7348e-03],
        [9.9927e-01, 7.2666e-04],
        [9.9944e-01, 5.6315e-04],
        [9.9834e-01, 1.6604e-03],
        [1.4105e-02, 9.8589e-01],
        [2.5721e-03, 9.9743e-01],
        [8.7983e-01, 1.2017e-01],
        [9.9059e-01, 9.4082e-03]])


Predicting:  74%|███████▍  | 159/215 [23:08<08:07,  8.71s/it]

tensor([[8.9352e-01, 1.0648e-01],
        [8.9418e-01, 1.0582e-01],
        [9.9436e-01, 5.6357e-03],
        [7.0854e-01, 2.9146e-01],
        [9.9985e-01, 1.4991e-04],
        [9.9995e-01, 5.2778e-05],
        [4.0798e-01, 5.9202e-01],
        [1.1675e-03, 9.9883e-01],
        [9.9947e-01, 5.3114e-04],
        [9.9996e-01, 4.2511e-05],
        [9.9972e-01, 2.7799e-04],
        [9.8682e-01, 1.3180e-02],
        [9.9890e-01, 1.0952e-03],
        [9.9192e-01, 8.0838e-03],
        [9.9848e-01, 1.5186e-03],
        [9.2778e-01, 7.2222e-02]])


Predicting:  74%|███████▍  | 160/215 [23:18<08:14,  8.99s/it]

tensor([[9.9497e-01, 5.0339e-03],
        [6.7805e-04, 9.9932e-01],
        [9.9432e-01, 5.6790e-03],
        [5.0745e-01, 4.9255e-01],
        [9.9412e-01, 5.8843e-03],
        [9.1436e-02, 9.0856e-01],
        [9.5012e-01, 4.9876e-02],
        [9.7919e-01, 2.0812e-02],
        [1.1378e-04, 9.9989e-01],
        [1.0179e-04, 9.9990e-01],
        [9.4853e-01, 5.1471e-02],
        [2.4192e-01, 7.5808e-01],
        [9.9954e-01, 4.5606e-04],
        [9.9876e-01, 1.2431e-03],
        [9.9713e-01, 2.8717e-03],
        [9.9059e-01, 9.4060e-03]])


Predicting:  75%|███████▍  | 161/215 [23:27<08:04,  8.98s/it]

tensor([[9.9875e-01, 1.2516e-03],
        [9.7465e-02, 9.0253e-01],
        [1.2986e-03, 9.9870e-01],
        [9.9772e-01, 2.2839e-03],
        [9.9975e-01, 2.4664e-04],
        [8.4833e-03, 9.9152e-01],
        [6.8876e-04, 9.9931e-01],
        [9.9177e-01, 8.2313e-03],
        [9.9949e-01, 5.1023e-04],
        [5.6908e-01, 4.3092e-01],
        [1.0697e-04, 9.9989e-01],
        [9.9960e-01, 4.0039e-04],
        [9.9992e-01, 8.3355e-05],
        [9.9978e-01, 2.2265e-04],
        [9.9971e-01, 2.9424e-04],
        [8.8118e-02, 9.1188e-01]])


Predicting:  75%|███████▌  | 162/215 [23:35<07:40,  8.68s/it]

tensor([[4.0622e-03, 9.9594e-01],
        [9.9971e-01, 2.8991e-04],
        [9.7708e-01, 2.2921e-02],
        [9.9861e-01, 1.3929e-03],
        [9.8437e-01, 1.5627e-02],
        [2.3465e-02, 9.7654e-01],
        [9.9582e-01, 4.1807e-03],
        [8.0795e-01, 1.9205e-01],
        [4.1177e-04, 9.9959e-01],
        [9.9826e-01, 1.7419e-03],
        [9.9956e-01, 4.3510e-04],
        [9.2201e-01, 7.7986e-02],
        [1.3857e-02, 9.8614e-01],
        [9.9235e-01, 7.6479e-03],
        [4.4414e-03, 9.9556e-01],
        [3.5662e-01, 6.4338e-01]])


Predicting:  76%|███████▌  | 163/215 [23:43<07:28,  8.63s/it]

tensor([[9.5370e-01, 4.6296e-02],
        [9.7728e-01, 2.2718e-02],
        [9.9759e-01, 2.4124e-03],
        [1.3055e-01, 8.6945e-01],
        [1.3086e-02, 9.8691e-01],
        [9.9973e-01, 2.6561e-04],
        [9.9987e-01, 1.3221e-04],
        [9.9922e-01, 7.8364e-04],
        [9.9953e-01, 4.6693e-04],
        [9.9573e-01, 4.2700e-03],
        [4.7048e-01, 5.2952e-01],
        [9.4595e-01, 5.4051e-02],
        [3.9851e-01, 6.0149e-01],
        [7.0832e-04, 9.9929e-01],
        [4.1738e-01, 5.8262e-01],
        [1.1671e-04, 9.9988e-01]])


Predicting:  76%|███████▋  | 164/215 [23:52<07:18,  8.60s/it]

tensor([[9.9976e-01, 2.3718e-04],
        [9.9982e-01, 1.8288e-04],
        [8.0163e-01, 1.9837e-01],
        [5.8282e-03, 9.9417e-01],
        [9.9986e-01, 1.4269e-04],
        [9.9993e-01, 7.1900e-05],
        [2.8752e-04, 9.9971e-01],
        [1.8525e-01, 8.1475e-01],
        [9.5520e-01, 4.4801e-02],
        [9.8765e-01, 1.2350e-02],
        [3.3191e-02, 9.6681e-01],
        [9.9982e-01, 1.8267e-04],
        [9.9990e-01, 1.0264e-04],
        [9.6956e-01, 3.0437e-02],
        [5.2491e-01, 4.7509e-01],
        [9.0230e-01, 9.7702e-02]])


Predicting:  77%|███████▋  | 165/215 [24:01<07:15,  8.71s/it]

tensor([[9.9919e-01, 8.0752e-04],
        [9.9893e-01, 1.0705e-03],
        [9.9817e-01, 1.8289e-03],
        [7.5839e-01, 2.4161e-01],
        [1.7813e-02, 9.8219e-01],
        [9.9993e-01, 6.7792e-05],
        [9.9995e-01, 4.5139e-05],
        [7.4775e-01, 2.5225e-01],
        [1.4364e-01, 8.5636e-01],
        [1.9045e-03, 9.9810e-01],
        [9.8924e-01, 1.0763e-02],
        [9.9850e-01, 1.5044e-03],
        [7.5110e-01, 2.4890e-01],
        [7.6870e-01, 2.3130e-01],
        [7.5938e-01, 2.4062e-01],
        [9.6092e-01, 3.9079e-02]])


Predicting:  77%|███████▋  | 166/215 [24:09<06:58,  8.54s/it]

tensor([[8.4535e-01, 1.5465e-01],
        [7.6899e-01, 2.3101e-01],
        [8.4344e-02, 9.1566e-01],
        [8.6156e-01, 1.3844e-01],
        [7.0774e-02, 9.2923e-01],
        [7.4517e-01, 2.5483e-01],
        [9.4114e-01, 5.8863e-02],
        [1.9972e-01, 8.0028e-01],
        [1.3005e-02, 9.8699e-01],
        [5.0027e-01, 4.9973e-01],
        [9.9337e-01, 6.6339e-03],
        [9.7574e-01, 2.4260e-02],
        [9.9930e-01, 6.9699e-04],
        [7.3677e-01, 2.6323e-01],
        [1.0934e-01, 8.9066e-01],
        [9.5286e-01, 4.7142e-02]])


Predicting:  78%|███████▊  | 167/215 [24:17<06:42,  8.39s/it]

tensor([[0.9837, 0.0163],
        [0.9627, 0.0373],
        [0.5383, 0.4617],
        [0.8583, 0.1417],
        [0.9605, 0.0395],
        [0.9170, 0.0830],
        [0.9649, 0.0351],
        [0.9001, 0.0999],
        [0.6195, 0.3805],
        [0.9931, 0.0069],
        [0.9921, 0.0079],
        [0.9622, 0.0378],
        [0.8659, 0.1341],
        [0.9976, 0.0024],
        [0.9990, 0.0010],
        [0.6712, 0.3288]])


Predicting:  78%|███████▊  | 168/215 [24:25<06:33,  8.36s/it]

tensor([[0.8862, 0.1138],
        [0.6117, 0.3883],
        [0.1165, 0.8835],
        [0.7160, 0.2840],
        [0.8471, 0.1529],
        [0.7698, 0.2302],
        [0.6334, 0.3666],
        [0.3361, 0.6639],
        [0.4454, 0.5546],
        [0.8423, 0.1577],
        [0.9799, 0.0201],
        [0.9483, 0.0517],
        [0.9914, 0.0086],
        [0.9910, 0.0090],
        [0.8283, 0.1717],
        [0.7659, 0.2341]])


Predicting:  79%|███████▊  | 169/215 [24:34<06:28,  8.45s/it]

tensor([[0.0485, 0.9515],
        [0.9747, 0.0253],
        [0.5482, 0.4518],
        [0.9988, 0.0012],
        [0.2883, 0.7117],
        [0.9154, 0.0846],
        [0.9944, 0.0056],
        [0.5783, 0.4217],
        [0.9765, 0.0235],
        [0.9941, 0.0059],
        [0.6376, 0.3624],
        [0.8657, 0.1343],
        [0.9216, 0.0784],
        [0.6987, 0.3013],
        [0.8373, 0.1627],
        [0.9470, 0.0530]])


Predicting:  79%|███████▉  | 170/215 [24:42<06:19,  8.42s/it]

tensor([[0.7133, 0.2867],
        [0.8877, 0.1123],
        [0.9857, 0.0143],
        [0.0608, 0.9392],
        [0.1387, 0.8613],
        [0.3955, 0.6045],
        [0.8567, 0.1433],
        [0.9722, 0.0278],
        [0.9027, 0.0973],
        [0.8009, 0.1991],
        [0.7979, 0.2021],
        [0.8065, 0.1935],
        [0.9912, 0.0088],
        [0.9809, 0.0191],
        [0.3187, 0.6813],
        [0.4125, 0.5875]])


Predicting:  80%|███████▉  | 171/215 [24:51<06:19,  8.63s/it]

tensor([[0.9099, 0.0901],
        [0.9743, 0.0257],
        [0.7439, 0.2561],
        [0.2310, 0.7690],
        [0.9945, 0.0055],
        [0.7933, 0.2067],
        [0.6974, 0.3026],
        [0.9251, 0.0749],
        [0.9887, 0.0113],
        [0.9795, 0.0205],
        [0.8645, 0.1355],
        [0.9984, 0.0016],
        [0.5860, 0.4140],
        [0.9936, 0.0064],
        [0.0336, 0.9664],
        [0.0104, 0.9896]])


Predicting:  80%|████████  | 172/215 [25:00<06:06,  8.52s/it]

tensor([[0.6832, 0.3168],
        [0.6218, 0.3782],
        [0.0579, 0.9421],
        [0.9433, 0.0567],
        [0.7819, 0.2181],
        [0.0029, 0.9971],
        [0.7480, 0.2520],
        [0.9980, 0.0020],
        [0.9836, 0.0164],
        [0.9989, 0.0011],
        [0.9872, 0.0128],
        [0.9967, 0.0033],
        [0.9758, 0.0242],
        [0.9683, 0.0317],
        [0.9906, 0.0094],
        [0.9029, 0.0971]])


Predicting:  80%|████████  | 173/215 [25:08<05:54,  8.44s/it]

tensor([[9.5812e-01, 4.1877e-02],
        [7.2479e-01, 2.7521e-01],
        [9.2183e-01, 7.8166e-02],
        [4.5003e-01, 5.4997e-01],
        [3.2928e-01, 6.7072e-01],
        [9.8339e-01, 1.6610e-02],
        [9.2033e-01, 7.9675e-02],
        [9.9920e-01, 7.9687e-04],
        [9.9971e-01, 2.8658e-04],
        [8.0148e-01, 1.9852e-01],
        [6.5591e-01, 3.4409e-01],
        [9.7018e-01, 2.9815e-02],
        [7.3235e-01, 2.6765e-01],
        [9.8878e-01, 1.1225e-02],
        [1.7395e-01, 8.2605e-01],
        [3.1664e-01, 6.8336e-01]])


Predicting:  81%|████████  | 174/215 [25:16<05:46,  8.46s/it]

tensor([[2.5550e-01, 7.4450e-01],
        [4.7804e-01, 5.2196e-01],
        [8.5596e-01, 1.4404e-01],
        [6.0766e-01, 3.9234e-01],
        [9.9927e-01, 7.2780e-04],
        [9.9977e-01, 2.3439e-04],
        [9.9967e-01, 3.3411e-04],
        [2.7658e-01, 7.2342e-01],
        [8.6421e-01, 1.3579e-01],
        [9.4974e-01, 5.0261e-02],
        [8.5904e-01, 1.4096e-01],
        [9.1932e-01, 8.0677e-02],
        [9.5498e-01, 4.5025e-02],
        [6.8949e-01, 3.1051e-01],
        [9.7827e-01, 2.1730e-02],
        [9.5602e-01, 4.3980e-02]])


Predicting:  81%|████████▏ | 175/215 [25:25<05:39,  8.48s/it]

tensor([[0.9901, 0.0099],
        [0.9923, 0.0077],
        [0.5137, 0.4863],
        [0.9945, 0.0055],
        [0.9872, 0.0128],
        [0.9971, 0.0029],
        [0.9980, 0.0020],
        [0.2856, 0.7144],
        [0.9979, 0.0021],
        [0.8624, 0.1376],
        [0.9940, 0.0060],
        [0.8900, 0.1100],
        [0.9963, 0.0037],
        [0.9698, 0.0302],
        [0.5702, 0.4298],
        [0.7964, 0.2036]])


Predicting:  82%|████████▏ | 176/215 [25:33<05:27,  8.38s/it]

tensor([[0.9673, 0.0327],
        [0.0398, 0.9602],
        [0.9913, 0.0087],
        [0.9529, 0.0471],
        [0.4452, 0.5548],
        [0.8270, 0.1730],
        [0.9551, 0.0449],
        [0.9844, 0.0156],
        [0.9014, 0.0986],
        [0.3435, 0.6565],
        [0.0368, 0.9632],
        [0.9313, 0.0687],
        [0.7626, 0.2374],
        [0.9860, 0.0140],
        [0.4426, 0.5574],
        [0.9984, 0.0016]])


Predicting:  82%|████████▏ | 177/215 [25:41<05:17,  8.36s/it]

tensor([[9.9135e-01, 8.6529e-03],
        [9.8656e-01, 1.3442e-02],
        [9.9682e-01, 3.1794e-03],
        [2.3954e-04, 9.9976e-01],
        [9.9853e-01, 1.4749e-03],
        [9.9495e-01, 5.0486e-03],
        [8.7157e-01, 1.2843e-01],
        [9.9954e-01, 4.6428e-04],
        [9.9988e-01, 1.1639e-04],
        [8.6534e-01, 1.3466e-01],
        [3.6098e-01, 6.3902e-01],
        [8.4044e-02, 9.1596e-01],
        [4.6700e-01, 5.3300e-01],
        [9.9618e-01, 3.8226e-03],
        [9.9851e-01, 1.4931e-03],
        [3.9708e-01, 6.0292e-01]])


Predicting:  83%|████████▎ | 178/215 [25:50<05:13,  8.48s/it]

tensor([[1.2587e-03, 9.9874e-01],
        [9.8939e-01, 1.0607e-02],
        [9.9032e-01, 9.6778e-03],
        [9.9494e-01, 5.0598e-03],
        [9.6573e-01, 3.4267e-02],
        [9.8721e-01, 1.2790e-02],
        [9.4072e-01, 5.9282e-02],
        [4.7825e-01, 5.2175e-01],
        [4.0399e-02, 9.5960e-01],
        [1.0836e-04, 9.9989e-01],
        [5.3906e-01, 4.6094e-01],
        [9.8427e-01, 1.5728e-02],
        [1.2205e-02, 9.8779e-01],
        [9.9953e-01, 4.6759e-04],
        [9.9930e-01, 6.9732e-04],
        [1.2363e-03, 9.9876e-01]])


Predicting:  83%|████████▎ | 179/215 [26:00<05:13,  8.72s/it]

tensor([[9.4690e-01, 5.3103e-02],
        [8.2355e-01, 1.7645e-01],
        [9.9890e-01, 1.1007e-03],
        [6.5253e-01, 3.4747e-01],
        [2.8997e-01, 7.1003e-01],
        [1.7333e-01, 8.2667e-01],
        [9.9806e-01, 1.9404e-03],
        [9.9431e-01, 5.6927e-03],
        [9.9964e-01, 3.5891e-04],
        [9.9980e-01, 2.0046e-04],
        [8.1995e-01, 1.8005e-01],
        [9.4967e-05, 9.9990e-01],
        [9.0757e-01, 9.2426e-02],
        [2.3141e-01, 7.6859e-01],
        [5.7390e-01, 4.2610e-01],
        [8.8934e-01, 1.1066e-01]])


Predicting:  84%|████████▎ | 180/215 [26:09<05:10,  8.87s/it]

tensor([[9.9898e-01, 1.0236e-03],
        [9.9982e-01, 1.7684e-04],
        [1.7421e-02, 9.8258e-01],
        [9.9403e-01, 5.9685e-03],
        [9.9505e-01, 4.9526e-03],
        [6.6342e-01, 3.3658e-01],
        [9.0566e-01, 9.4336e-02],
        [9.4713e-01, 5.2866e-02],
        [1.8767e-01, 8.1233e-01],
        [3.9486e-02, 9.6051e-01],
        [6.3692e-01, 3.6308e-01],
        [9.8566e-05, 9.9990e-01],
        [9.9967e-01, 3.2871e-04],
        [9.9988e-01, 1.1641e-04],
        [9.6999e-01, 3.0005e-02],
        [9.9255e-01, 7.4476e-03]])


Predicting:  84%|████████▍ | 181/215 [26:18<05:06,  9.03s/it]

tensor([[2.5160e-01, 7.4840e-01],
        [9.9268e-01, 7.3208e-03],
        [9.9550e-01, 4.4975e-03],
        [9.7774e-01, 2.2265e-02],
        [7.5173e-01, 2.4827e-01],
        [2.1997e-04, 9.9978e-01],
        [9.6164e-01, 3.8355e-02],
        [6.6760e-03, 9.9332e-01],
        [9.9293e-01, 7.0740e-03],
        [8.7499e-01, 1.2501e-01],
        [7.4731e-01, 2.5269e-01],
        [7.1834e-01, 2.8166e-01],
        [4.3354e-01, 5.6646e-01],
        [8.9878e-01, 1.0122e-01],
        [9.9818e-01, 1.8246e-03],
        [1.2118e-04, 9.9988e-01]])


Predicting:  85%|████████▍ | 182/215 [26:26<04:50,  8.80s/it]

tensor([[9.9970e-01, 2.9517e-04],
        [9.8398e-01, 1.6020e-02],
        [9.9934e-01, 6.5645e-04],
        [9.9632e-01, 3.6767e-03],
        [9.4907e-01, 5.0931e-02],
        [5.0878e-01, 4.9122e-01],
        [8.5636e-01, 1.4364e-01],
        [9.1890e-01, 8.1099e-02],
        [1.0055e-01, 8.9945e-01],
        [9.3941e-01, 6.0595e-02],
        [9.9010e-01, 9.8951e-03],
        [9.0967e-01, 9.0331e-02],
        [8.3230e-01, 1.6770e-01],
        [9.9899e-01, 1.0067e-03],
        [9.6881e-01, 3.1187e-02],
        [5.2213e-01, 4.7787e-01]])


Predicting:  85%|████████▌ | 183/215 [26:35<04:39,  8.72s/it]

tensor([[8.2188e-01, 1.7812e-01],
        [9.2857e-01, 7.1430e-02],
        [9.3318e-01, 6.6825e-02],
        [9.9938e-01, 6.2244e-04],
        [9.8927e-01, 1.0733e-02],
        [5.2317e-01, 4.7683e-01],
        [9.9511e-01, 4.8901e-03],
        [9.9981e-01, 1.9240e-04],
        [9.9703e-01, 2.9697e-03],
        [9.9964e-01, 3.6190e-04],
        [9.7523e-01, 2.4773e-02],
        [9.9939e-01, 6.1130e-04],
        [9.9690e-01, 3.1025e-03],
        [9.7842e-01, 2.1582e-02],
        [9.9701e-01, 2.9878e-03],
        [8.8843e-01, 1.1157e-01]])


Predicting:  86%|████████▌ | 184/215 [26:44<04:32,  8.77s/it]

tensor([[9.9566e-01, 4.3431e-03],
        [9.8436e-01, 1.5638e-02],
        [9.8188e-01, 1.8116e-02],
        [9.9512e-01, 4.8817e-03],
        [1.8629e-01, 8.1371e-01],
        [9.9613e-01, 3.8749e-03],
        [8.3441e-01, 1.6559e-01],
        [9.3824e-01, 6.1757e-02],
        [7.4139e-01, 2.5861e-01],
        [1.1042e-04, 9.9989e-01],
        [9.9950e-01, 5.0384e-04],
        [9.9983e-01, 1.6719e-04],
        [6.7848e-01, 3.2152e-01],
        [1.0186e-04, 9.9990e-01],
        [9.9965e-01, 3.4987e-04],
        [9.9987e-01, 1.2542e-04]])


Predicting:  86%|████████▌ | 185/215 [26:53<04:28,  8.94s/it]

tensor([[9.9087e-01, 9.1341e-03],
        [9.8833e-01, 1.1675e-02],
        [1.1389e-03, 9.9886e-01],
        [8.0915e-03, 9.9191e-01],
        [5.2861e-01, 4.7139e-01],
        [7.6720e-01, 2.3280e-01],
        [7.7411e-01, 2.2589e-01],
        [9.9938e-01, 6.1504e-04],
        [2.3624e-02, 9.7638e-01],
        [9.9681e-01, 3.1870e-03],
        [9.9990e-01, 1.0223e-04],
        [9.9718e-01, 2.8235e-03],
        [9.9991e-01, 8.5639e-05],
        [9.9974e-01, 2.6262e-04],
        [8.9267e-01, 1.0733e-01],
        [1.6811e-03, 9.9832e-01]])


Predicting:  87%|████████▋ | 186/215 [27:02<04:17,  8.89s/it]

tensor([[3.7952e-03, 9.9620e-01],
        [9.9991e-01, 9.4499e-05],
        [9.9996e-01, 4.3419e-05],
        [4.4631e-01, 5.5369e-01],
        [3.9352e-03, 9.9606e-01],
        [9.9388e-01, 6.1196e-03],
        [9.9843e-01, 1.5655e-03],
        [8.5879e-03, 9.9141e-01],
        [6.0843e-01, 3.9157e-01],
        [9.7976e-01, 2.0241e-02],
        [9.9934e-01, 6.6247e-04],
        [3.0408e-01, 6.9592e-01],
        [4.3924e-01, 5.6076e-01],
        [4.8011e-01, 5.1989e-01],
        [9.9818e-01, 1.8215e-03],
        [9.6819e-01, 3.1814e-02]])


Predicting:  87%|████████▋ | 187/215 [27:10<04:04,  8.74s/it]

tensor([[9.9927e-01, 7.3306e-04],
        [9.4062e-01, 5.9383e-02],
        [9.9913e-01, 8.7258e-04],
        [1.9380e-04, 9.9981e-01],
        [9.8608e-01, 1.3920e-02],
        [9.7326e-01, 2.6744e-02],
        [4.2211e-02, 9.5779e-01],
        [8.6494e-01, 1.3506e-01],
        [9.8660e-01, 1.3405e-02],
        [9.7935e-01, 2.0651e-02],
        [9.6346e-01, 3.6541e-02],
        [8.9225e-01, 1.0775e-01],
        [9.9521e-01, 4.7908e-03],
        [3.4148e-02, 9.6585e-01],
        [1.2928e-04, 9.9987e-01],
        [8.3685e-01, 1.6315e-01]])


Predicting:  87%|████████▋ | 188/215 [27:19<03:51,  8.57s/it]

tensor([[9.9216e-01, 7.8445e-03],
        [9.7270e-01, 2.7303e-02],
        [9.9816e-01, 1.8356e-03],
        [9.9047e-01, 9.5334e-03],
        [9.9060e-01, 9.4010e-03],
        [3.5624e-01, 6.4376e-01],
        [9.9241e-01, 7.5858e-03],
        [8.1143e-02, 9.1886e-01],
        [9.9812e-01, 1.8756e-03],
        [9.9951e-01, 4.8758e-04],
        [9.4888e-01, 5.1122e-02],
        [8.8352e-02, 9.1165e-01],
        [8.9128e-01, 1.0872e-01],
        [6.1603e-01, 3.8397e-01],
        [9.9680e-01, 3.2043e-03],
        [5.0091e-01, 4.9909e-01]])


Predicting:  88%|████████▊ | 189/215 [27:27<03:40,  8.48s/it]

tensor([[9.5375e-05, 9.9990e-01],
        [9.9942e-01, 5.8452e-04],
        [9.9972e-01, 2.8026e-04],
        [9.6827e-01, 3.1731e-02],
        [1.9186e-04, 9.9981e-01],
        [8.6595e-01, 1.3405e-01],
        [2.7363e-01, 7.2637e-01],
        [9.9713e-01, 2.8651e-03],
        [8.2914e-01, 1.7086e-01],
        [7.3821e-01, 2.6179e-01],
        [9.5471e-01, 4.5294e-02],
        [9.9670e-01, 3.2956e-03],
        [4.8336e-01, 5.1664e-01],
        [9.9059e-05, 9.9990e-01],
        [9.9985e-01, 1.4775e-04],
        [9.9993e-01, 7.0653e-05]])


Predicting:  88%|████████▊ | 190/215 [27:35<03:30,  8.43s/it]

tensor([[9.8425e-01, 1.5745e-02],
        [9.9813e-01, 1.8700e-03],
        [2.8263e-02, 9.7174e-01],
        [1.2223e-04, 9.9988e-01],
        [9.1274e-01, 8.7256e-02],
        [9.1437e-02, 9.0856e-01],
        [6.1418e-01, 3.8582e-01],
        [9.7382e-01, 2.6180e-02],
        [8.0417e-01, 1.9583e-01],
        [9.6076e-01, 3.9236e-02],
        [3.7271e-01, 6.2729e-01],
        [5.6186e-01, 4.3814e-01],
        [8.6268e-01, 1.3732e-01],
        [4.6004e-02, 9.5400e-01],
        [9.7597e-01, 2.4030e-02],
        [9.8566e-01, 1.4336e-02]])


Predicting:  89%|████████▉ | 191/215 [27:44<03:24,  8.52s/it]

tensor([[1.1073e-04, 9.9989e-01],
        [7.4300e-04, 9.9926e-01],
        [6.0545e-02, 9.3945e-01],
        [9.9807e-01, 1.9266e-03],
        [9.9883e-01, 1.1681e-03],
        [9.1099e-01, 8.9014e-02],
        [2.7020e-04, 9.9973e-01],
        [9.6826e-01, 3.1735e-02],
        [9.9914e-01, 8.5910e-04],
        [9.9948e-01, 5.2063e-04],
        [7.4787e-01, 2.5213e-01],
        [1.2736e-04, 9.9987e-01],
        [9.9396e-01, 6.0359e-03],
        [9.7497e-01, 2.5028e-02],
        [8.2663e-01, 1.7337e-01],
        [3.8444e-03, 9.9616e-01]])


Predicting:  89%|████████▉ | 192/215 [27:52<03:15,  8.52s/it]

tensor([[3.3173e-01, 6.6827e-01],
        [9.8539e-01, 1.4605e-02],
        [7.9379e-01, 2.0621e-01],
        [3.8220e-02, 9.6178e-01],
        [7.9630e-01, 2.0370e-01],
        [9.9240e-01, 7.5998e-03],
        [9.6671e-01, 3.3290e-02],
        [6.7579e-03, 9.9324e-01],
        [2.5559e-01, 7.4441e-01],
        [7.3805e-02, 9.2619e-01],
        [9.9328e-01, 6.7231e-03],
        [9.9933e-01, 6.6909e-04],
        [3.0685e-01, 6.9315e-01],
        [1.2785e-04, 9.9987e-01],
        [6.3285e-01, 3.6715e-01],
        [9.7334e-01, 2.6664e-02]])


Predicting:  90%|████████▉ | 193/215 [28:01<03:05,  8.42s/it]

tensor([[0.3284, 0.6716],
        [0.8306, 0.1694],
        [0.0074, 0.9926],
        [0.9804, 0.0196],
        [0.7657, 0.2343],
        [0.5773, 0.4227],
        [0.8982, 0.1018],
        [0.9575, 0.0425],
        [0.6241, 0.3759],
        [0.4650, 0.5350],
        [0.8937, 0.1063],
        [0.6340, 0.3660],
        [0.8016, 0.1984],
        [0.5513, 0.4487],
        [0.8945, 0.1055],
        [0.5892, 0.4108]])


Predicting:  90%|█████████ | 194/215 [28:09<02:55,  8.34s/it]

tensor([[5.1424e-01, 4.8576e-01],
        [9.8760e-01, 1.2401e-02],
        [9.5740e-01, 4.2599e-02],
        [7.9993e-01, 2.0007e-01],
        [9.9267e-01, 7.3325e-03],
        [9.2215e-01, 7.7850e-02],
        [8.4733e-01, 1.5267e-01],
        [9.7262e-01, 2.7383e-02],
        [8.7844e-01, 1.2156e-01],
        [9.5906e-01, 4.0938e-02],
        [5.2074e-04, 9.9948e-01],
        [7.0732e-01, 2.9268e-01],
        [1.0212e-01, 8.9788e-01],
        [8.8945e-01, 1.1055e-01],
        [8.4955e-01, 1.5045e-01],
        [9.7388e-01, 2.6117e-02]])


Predicting:  91%|█████████ | 195/215 [28:17<02:49,  8.48s/it]

tensor([[8.7338e-01, 1.2662e-01],
        [9.9340e-01, 6.5989e-03],
        [1.1508e-01, 8.8492e-01],
        [9.6384e-01, 3.6156e-02],
        [5.5858e-01, 4.4142e-01],
        [9.2292e-05, 9.9991e-01],
        [9.9923e-01, 7.7106e-04],
        [9.9971e-01, 2.8851e-04],
        [9.9817e-01, 1.8338e-03],
        [9.8897e-01, 1.1033e-02],
        [9.9176e-01, 8.2399e-03],
        [9.8619e-01, 1.3811e-02],
        [9.9922e-01, 7.7754e-04],
        [9.9809e-01, 1.9093e-03],
        [8.7956e-01, 1.2044e-01],
        [8.9401e-01, 1.0599e-01]])


Predicting:  91%|█████████ | 196/215 [28:26<02:41,  8.49s/it]

tensor([[9.6336e-01, 3.6641e-02],
        [2.9695e-01, 7.0305e-01],
        [1.6136e-04, 9.9984e-01],
        [4.7562e-01, 5.2438e-01],
        [8.9352e-01, 1.0648e-01],
        [3.3579e-01, 6.6421e-01],
        [9.4298e-01, 5.7019e-02],
        [6.3365e-01, 3.6635e-01],
        [9.7777e-01, 2.2228e-02],
        [8.1682e-01, 1.8318e-01],
        [7.3432e-01, 2.6568e-01],
        [7.2456e-01, 2.7544e-01],
        [9.9824e-01, 1.7558e-03],
        [9.3245e-01, 6.7554e-02],
        [1.0379e-01, 8.9621e-01],
        [9.8638e-01, 1.3623e-02]])


Predicting:  92%|█████████▏| 197/215 [28:34<02:32,  8.45s/it]

tensor([[7.9922e-01, 2.0078e-01],
        [6.4845e-01, 3.5155e-01],
        [5.1160e-01, 4.8840e-01],
        [6.5886e-01, 3.4114e-01],
        [2.1154e-01, 7.8846e-01],
        [2.5217e-01, 7.4783e-01],
        [9.9626e-01, 3.7397e-03],
        [9.6031e-01, 3.9692e-02],
        [9.8383e-01, 1.6169e-02],
        [9.9642e-01, 3.5807e-03],
        [9.9830e-01, 1.7039e-03],
        [9.9966e-01, 3.4418e-04],
        [4.9648e-01, 5.0352e-01],
        [1.1381e-04, 9.9989e-01],
        [9.8742e-01, 1.2576e-02],
        [9.9105e-01, 8.9511e-03]])


Predicting:  92%|█████████▏| 198/215 [28:43<02:22,  8.40s/it]

tensor([[9.7852e-01, 2.1481e-02],
        [4.6640e-02, 9.5336e-01],
        [9.8822e-01, 1.1776e-02],
        [8.5299e-01, 1.4701e-01],
        [9.9133e-01, 8.6710e-03],
        [9.0150e-01, 9.8502e-02],
        [5.9253e-01, 4.0747e-01],
        [4.9214e-01, 5.0786e-01],
        [5.3087e-01, 4.6913e-01],
        [8.5749e-01, 1.4251e-01],
        [3.1772e-04, 9.9968e-01],
        [7.0221e-01, 2.9779e-01],
        [2.3828e-01, 7.6172e-01],
        [4.9069e-02, 9.5093e-01],
        [2.7689e-01, 7.2311e-01],
        [9.9726e-01, 2.7360e-03]])


Predicting:  93%|█████████▎| 199/215 [28:51<02:15,  8.44s/it]

tensor([[9.8192e-01, 1.8078e-02],
        [9.6393e-01, 3.6073e-02],
        [7.8800e-01, 2.1200e-01],
        [9.9944e-01, 5.5696e-04],
        [2.6403e-01, 7.3597e-01],
        [9.6304e-01, 3.6961e-02],
        [9.9855e-01, 1.4515e-03],
        [6.9438e-02, 9.3056e-01],
        [9.9992e-01, 8.2946e-05],
        [9.9996e-01, 4.2494e-05],
        [6.2324e-03, 9.9377e-01],
        [9.9962e-01, 3.8059e-04],
        [9.9982e-01, 1.8315e-04],
        [9.9913e-01, 8.6588e-04],
        [9.9884e-01, 1.1590e-03],
        [9.9952e-01, 4.7900e-04]])


Predicting:  93%|█████████▎| 200/215 [28:59<02:05,  8.40s/it]

tensor([[0.9958, 0.0042],
        [0.9873, 0.0127],
        [0.9900, 0.0100],
        [0.9979, 0.0021],
        [0.9974, 0.0026],
        [0.9506, 0.0494],
        [0.9643, 0.0357],
        [0.9772, 0.0228],
        [0.9753, 0.0247],
        [0.6162, 0.3838],
        [0.7948, 0.2052],
        [0.9229, 0.0771],
        [0.6340, 0.3660],
        [0.9806, 0.0194],
        [0.9584, 0.0416],
        [0.9462, 0.0538]])


Predicting:  93%|█████████▎| 201/215 [29:08<01:57,  8.38s/it]

tensor([[1.7800e-01, 8.2200e-01],
        [9.8741e-01, 1.2595e-02],
        [4.0672e-04, 9.9959e-01],
        [9.9729e-01, 2.7103e-03],
        [9.9120e-01, 8.8048e-03],
        [8.1128e-01, 1.8872e-01],
        [9.9447e-01, 5.5315e-03],
        [9.4461e-01, 5.5386e-02],
        [9.1692e-01, 8.3075e-02],
        [9.9759e-01, 2.4119e-03],
        [9.9824e-01, 1.7585e-03],
        [9.9977e-01, 2.3386e-04],
        [9.9928e-01, 7.2499e-04],
        [9.9944e-01, 5.5590e-04],
        [9.9972e-01, 2.7578e-04],
        [9.9932e-01, 6.8136e-04]])


Predicting:  94%|█████████▍| 202/215 [29:16<01:49,  8.43s/it]

tensor([[7.8167e-01, 2.1833e-01],
        [9.8563e-01, 1.4370e-02],
        [9.9910e-01, 8.9896e-04],
        [8.3843e-01, 1.6157e-01],
        [2.4817e-02, 9.7518e-01],
        [9.9244e-01, 7.5625e-03],
        [2.7284e-04, 9.9973e-01],
        [1.5121e-04, 9.9985e-01],
        [3.5084e-01, 6.4916e-01],
        [9.8459e-01, 1.5407e-02],
        [1.5709e-01, 8.4291e-01],
        [8.0211e-01, 1.9789e-01],
        [3.6759e-02, 9.6324e-01],
        [9.9266e-01, 7.3397e-03],
        [4.9615e-02, 9.5038e-01],
        [6.1698e-03, 9.9383e-01]])


Predicting:  94%|█████████▍| 203/215 [29:25<01:42,  8.55s/it]

tensor([[7.7183e-01, 2.2817e-01],
        [8.8052e-01, 1.1948e-01],
        [8.1350e-01, 1.8650e-01],
        [7.5963e-01, 2.4037e-01],
        [8.6539e-01, 1.3461e-01],
        [7.7221e-01, 2.2779e-01],
        [9.9960e-01, 4.0439e-04],
        [9.9848e-01, 1.5215e-03],
        [9.9988e-01, 1.1617e-04],
        [8.2244e-02, 9.1776e-01],
        [9.8824e-01, 1.1761e-02],
        [9.9732e-01, 2.6812e-03],
        [9.3108e-01, 6.8922e-02],
        [8.8650e-01, 1.1350e-01],
        [7.3827e-01, 2.6173e-01],
        [7.2876e-01, 2.7124e-01]])


Predicting:  95%|█████████▍| 204/215 [29:33<01:32,  8.45s/it]

tensor([[0.9256, 0.0744],
        [0.9580, 0.0420],
        [0.0091, 0.9909],
        [0.9974, 0.0026],
        [0.8994, 0.1006],
        [0.9938, 0.0062],
        [0.8774, 0.1226],
        [0.9976, 0.0024],
        [0.8664, 0.1336],
        [0.9986, 0.0014],
        [0.0020, 0.9980],
        [0.9977, 0.0023],
        [0.3264, 0.6736],
        [0.6646, 0.3354],
        [0.5548, 0.4452],
        [0.9952, 0.0048]])


Predicting:  95%|█████████▌| 205/215 [29:42<01:24,  8.43s/it]

tensor([[9.2896e-01, 7.1037e-02],
        [2.1511e-01, 7.8489e-01],
        [2.4411e-01, 7.5589e-01],
        [8.3940e-01, 1.6060e-01],
        [9.8240e-01, 1.7603e-02],
        [9.9835e-01, 1.6464e-03],
        [1.6517e-04, 9.9983e-01],
        [8.2821e-02, 9.1718e-01],
        [9.5541e-03, 9.9045e-01],
        [9.9528e-01, 4.7182e-03],
        [9.9938e-01, 6.2219e-04],
        [2.9399e-01, 7.0601e-01],
        [1.0103e-04, 9.9990e-01],
        [1.5098e-01, 8.4902e-01],
        [9.9682e-01, 3.1778e-03],
        [9.6646e-01, 3.3536e-02]])


Predicting:  96%|█████████▌| 206/215 [29:50<01:15,  8.40s/it]

tensor([[1.5385e-01, 8.4615e-01],
        [8.2926e-01, 1.7074e-01],
        [5.6719e-04, 9.9943e-01],
        [3.4869e-01, 6.5131e-01],
        [9.0049e-01, 9.9511e-02],
        [8.1215e-01, 1.8785e-01],
        [7.6705e-01, 2.3295e-01],
        [9.5683e-01, 4.3174e-02],
        [9.9706e-01, 2.9392e-03],
        [9.9862e-01, 1.3753e-03],
        [9.9618e-01, 3.8230e-03],
        [3.9738e-01, 6.0262e-01],
        [3.3786e-01, 6.6214e-01],
        [9.6821e-05, 9.9990e-01],
        [9.9859e-01, 1.4104e-03],
        [9.9990e-01, 1.0285e-04]])


Predicting:  96%|█████████▋| 207/215 [29:58<01:06,  8.35s/it]

tensor([[3.9255e-01, 6.0745e-01],
        [1.0424e-04, 9.9990e-01],
        [9.9463e-01, 5.3730e-03],
        [9.9986e-01, 1.4198e-04],
        [9.9352e-01, 6.4822e-03],
        [9.9944e-01, 5.6262e-04],
        [7.5684e-03, 9.9243e-01],
        [9.7790e-01, 2.2097e-02],
        [2.6888e-02, 9.7311e-01],
        [1.0250e-03, 9.9897e-01],
        [9.9582e-01, 4.1848e-03],
        [6.3568e-02, 9.3643e-01],
        [7.3904e-02, 9.2610e-01],
        [9.5804e-03, 9.9042e-01],
        [2.6189e-02, 9.7381e-01],
        [1.1013e-01, 8.8987e-01]])


Predicting:  97%|█████████▋| 208/215 [30:07<00:58,  8.38s/it]

tensor([[9.2236e-01, 7.7638e-02],
        [4.2912e-01, 5.7088e-01],
        [6.8690e-01, 3.1310e-01],
        [2.3057e-02, 9.7694e-01],
        [1.3055e-04, 9.9987e-01],
        [9.9779e-01, 2.2083e-03],
        [1.2670e-01, 8.7330e-01],
        [9.4834e-01, 5.1661e-02],
        [9.8407e-01, 1.5929e-02],
        [9.9940e-01, 6.0465e-04],
        [9.9293e-01, 7.0746e-03],
        [9.9652e-01, 3.4842e-03],
        [6.5220e-01, 3.4780e-01],
        [9.9892e-01, 1.0771e-03],
        [5.1748e-01, 4.8252e-01],
        [7.4030e-01, 2.5970e-01]])


Predicting:  97%|█████████▋| 209/215 [30:15<00:50,  8.36s/it]

tensor([[1.4328e-01, 8.5672e-01],
        [9.9840e-01, 1.5973e-03],
        [9.8281e-01, 1.7188e-02],
        [9.9382e-01, 6.1782e-03],
        [9.9436e-01, 5.6447e-03],
        [9.8864e-01, 1.1358e-02],
        [9.9848e-01, 1.5221e-03],
        [2.5233e-01, 7.4767e-01],
        [5.1164e-01, 4.8836e-01],
        [9.9955e-01, 4.4762e-04],
        [9.9979e-01, 2.1197e-04],
        [6.0532e-01, 3.9468e-01],
        [4.7801e-03, 9.9522e-01],
        [6.9883e-01, 3.0117e-01],
        [2.3207e-01, 7.6793e-01],
        [2.0709e-04, 9.9979e-01]])


Predicting:  98%|█████████▊| 210/215 [30:23<00:41,  8.35s/it]

tensor([[9.9973e-01, 2.7441e-04],
        [9.9992e-01, 8.3934e-05],
        [7.0239e-01, 2.9761e-01],
        [9.1276e-01, 8.7239e-02],
        [5.4141e-02, 9.4586e-01],
        [1.1969e-03, 9.9880e-01],
        [2.1940e-01, 7.8060e-01],
        [1.1430e-04, 9.9989e-01],
        [9.9670e-01, 3.3039e-03],
        [9.9652e-01, 3.4764e-03],
        [9.9971e-01, 2.8681e-04],
        [9.9784e-01, 2.1628e-03],
        [9.9975e-01, 2.5353e-04],
        [9.6449e-01, 3.5515e-02],
        [9.9950e-01, 5.0372e-04],
        [5.7895e-01, 4.2105e-01]])


Predicting:  98%|█████████▊| 211/215 [30:32<00:33,  8.38s/it]

tensor([[9.1957e-05, 9.9991e-01],
        [3.8387e-02, 9.6161e-01],
        [8.0853e-04, 9.9919e-01],
        [4.7805e-01, 5.2195e-01],
        [2.8306e-01, 7.1694e-01],
        [8.6975e-01, 1.3025e-01],
        [2.4793e-01, 7.5207e-01],
        [4.4047e-01, 5.5953e-01],
        [9.6239e-01, 3.7610e-02],
        [9.8055e-01, 1.9450e-02],
        [8.3275e-01, 1.6725e-01],
        [1.1137e-04, 9.9989e-01],
        [9.9971e-01, 2.9328e-04],
        [9.9991e-01, 8.8009e-05],
        [9.6779e-01, 3.2208e-02],
        [9.7403e-01, 2.5974e-02]])


Predicting:  99%|█████████▊| 212/215 [30:40<00:25,  8.37s/it]

tensor([[1.4524e-01, 8.5476e-01],
        [1.0619e-02, 9.8938e-01],
        [1.3408e-01, 8.6592e-01],
        [9.3522e-01, 6.4777e-02],
        [4.7307e-01, 5.2693e-01],
        [6.0429e-02, 9.3957e-01],
        [3.8962e-01, 6.1038e-01],
        [9.1974e-05, 9.9991e-01],
        [7.8627e-04, 9.9921e-01],
        [9.7830e-01, 2.1699e-02],
        [5.8624e-01, 4.1376e-01],
        [8.8021e-01, 1.1979e-01],
        [9.9616e-01, 3.8391e-03],
        [1.1911e-01, 8.8089e-01],
        [4.6582e-01, 5.3418e-01],
        [8.9899e-03, 9.9101e-01]])


Predicting:  99%|█████████▉| 213/215 [30:49<00:16,  8.43s/it]

tensor([[2.6431e-04, 9.9974e-01],
        [1.9602e-04, 9.9980e-01],
        [6.1879e-01, 3.8121e-01],
        [9.9773e-01, 2.2734e-03],
        [9.4752e-01, 5.2483e-02],
        [1.9436e-01, 8.0564e-01],
        [1.0774e-01, 8.9226e-01],
        [9.8363e-01, 1.6374e-02],
        [8.4544e-01, 1.5456e-01],
        [7.7038e-01, 2.2962e-01],
        [9.9126e-01, 8.7363e-03],
        [8.3547e-01, 1.6453e-01],
        [8.1612e-01, 1.8388e-01],
        [9.8382e-01, 1.6185e-02],
        [9.9318e-01, 6.8216e-03],
        [9.8075e-01, 1.9249e-02]])


Predicting: 100%|█████████▉| 214/215 [30:58<00:08,  8.51s/it]

tensor([[6.1331e-04, 9.9939e-01],
        [9.2636e-01, 7.3641e-02],
        [8.1997e-01, 1.8003e-01],
        [9.4873e-01, 5.1266e-02],
        [8.4854e-01, 1.5146e-01],
        [9.9468e-01, 5.3219e-03],
        [9.5679e-01, 4.3208e-02],
        [6.1372e-01, 3.8628e-01],
        [9.5266e-01, 4.7338e-02],
        [5.7470e-01, 4.2530e-01],
        [8.0666e-01, 1.9334e-01],
        [3.5936e-01, 6.4064e-01],
        [7.4850e-01, 2.5150e-01],
        [5.5690e-01, 4.4310e-01],
        [9.8293e-01, 1.7073e-02],
        [1.6950e-02, 9.8305e-01]])


Predicting: 100%|██████████| 215/215 [30:59<00:00,  8.65s/it]

tensor([[9.9872e-01, 1.2806e-03],
        [9.9987e-01, 1.3329e-04]])
[0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 




In [7]:
print(len(all_predictions))
print(len(all_probabilities))

3426
3426


In [14]:
print(all_probabilities)
all_probabilities = [max(sublist) for sublist in all_probabilities]
print(all_probabilities)

[0.8513832688331604, 0.9530447125434875, 0.9813437461853027, 0.9999032020568848, 0.9502105712890625, 0.9856218695640564, 0.9996399879455566, 0.7003836631774902, 0.9784537553787231, 0.9999059438705444, 0.5496004223823547, 0.9834389090538025, 0.9875774383544922, 0.9435656070709229, 0.9999083280563354, 0.8970844745635986, 0.9932764768600464, 0.5377925038337708, 0.9752899408340454, 0.9779468178749084, 0.9998916387557983, 0.6552639603614807, 0.9617480635643005, 0.9989585876464844, 0.9997628331184387, 0.9690905809402466, 0.7799464464187622, 0.9968823194503784, 0.9996205568313599, 0.7533658742904663, 0.9707180857658386, 0.9991835951805115, 0.9999151229858398, 0.6591960191726685, 0.9929141998291016, 0.6223980188369751, 0.5727285146713257, 0.9057391285896301, 0.9987348914146423, 0.6909785270690918, 0.9980675578117371, 0.9997209906578064, 0.99972003698349, 0.9550588130950928, 0.7782137989997864, 0.5976061224937439, 0.9899438619613647, 0.6962425112724304, 0.7719510197639465, 0.811423122882843, 0.

### Extract reslut to csv file

In [17]:
print(test_df.columns)

Index(['topic', 'premise', 'hypothesis', 'label'], dtype='object')


In [19]:
add_df = pd.DataFrame({
    'prediction': all_predictions,
    'probability':all_probabilities
})
result_df = pd.concat([test_df,add_df],axis=1)
print(result_df)

                                                  topic  \
0        Routine child vaccinations should be mandatory   
1        Routine child vaccinations should be mandatory   
2        Routine child vaccinations should be mandatory   
3        Routine child vaccinations should be mandatory   
4        Routine child vaccinations should be mandatory   
...                                                 ...   
3421  Social media platforms should be regulated by ...   
3422     Routine child vaccinations should be mandatory   
3423     Routine child vaccinations should be mandatory   
3424     Routine child vaccinations should be mandatory   
3425     Routine child vaccinations should be mandatory   

                                                premise  \
0      Routine childhood vaccinations  should be man...   
1      Routine childhood vaccinations  should be man...   
2      Routine childhood vaccinations  should be man...   
3      Routine childhood vaccinations  should be man...

In [20]:
result_df.to_csv('./result/KPM_official_test.csv', index=False)
print(f"DataFrame has been saved to {'./result/KPM_official_test.csv'}")

DataFrame has been saved to ./result/KPM_official_test.csv


In [13]:
# print(all_probabilities)
new_test_df = pd.read_csv('./result/KPM_official_test.csv')
print(new_test_df)

                                                  topic  \
0        Routine child vaccinations should be mandatory   
1        Routine child vaccinations should be mandatory   
2        Routine child vaccinations should be mandatory   
3        Routine child vaccinations should be mandatory   
4        Routine child vaccinations should be mandatory   
...                                                 ...   
3421  Social media platforms should be regulated by ...   
3422     Routine child vaccinations should be mandatory   
3423     Routine child vaccinations should be mandatory   
3424     Routine child vaccinations should be mandatory   
3425     Routine child vaccinations should be mandatory   

                                                premise  \
0      Routine childhood vaccinations  should be man...   
1      Routine childhood vaccinations  should be man...   
2      Routine childhood vaccinations  should be man...   
3      Routine childhood vaccinations  should be man...

In [16]:
new_test_df['logit'] = all_probabilities
print(new_test_df)
new_test_df.to_csv('./result/KPM_official_test.csv', index=False)
print(f"DataFrame has been saved to {'./result/KPM_official_test.csv'}")

                                                  topic  \
0        Routine child vaccinations should be mandatory   
1        Routine child vaccinations should be mandatory   
2        Routine child vaccinations should be mandatory   
3        Routine child vaccinations should be mandatory   
4        Routine child vaccinations should be mandatory   
...                                                 ...   
3421  Social media platforms should be regulated by ...   
3422     Routine child vaccinations should be mandatory   
3423     Routine child vaccinations should be mandatory   
3424     Routine child vaccinations should be mandatory   
3425     Routine child vaccinations should be mandatory   

                                                premise  \
0      Routine childhood vaccinations  should be man...   
1      Routine childhood vaccinations  should be man...   
2      Routine childhood vaccinations  should be man...   
3      Routine childhood vaccinations  should be man...

### Calculate mAP

In [17]:
df_test = pd.read_csv('./result/KPM_official_test.csv')
print(df_test.columns)
print(df_test)

Index(['topic', 'premise', 'hypothesis', 'label', 'prediction', 'probability',
       'logit'],
      dtype='object')
                                                  topic  \
0        Routine child vaccinations should be mandatory   
1        Routine child vaccinations should be mandatory   
2        Routine child vaccinations should be mandatory   
3        Routine child vaccinations should be mandatory   
4        Routine child vaccinations should be mandatory   
...                                                 ...   
3421  Social media platforms should be regulated by ...   
3422     Routine child vaccinations should be mandatory   
3423     Routine child vaccinations should be mandatory   
3424     Routine child vaccinations should be mandatory   
3425     Routine child vaccinations should be mandatory   

                                                premise  \
0      Routine childhood vaccinations  should be man...   
1      Routine childhood vaccinations  should be man...

In [18]:
df_sorted = df_test.sort_values(by='probability', ascending= False)
print(df_sorted)

                                                  topic  \
3177     Routine child vaccinations should be mandatory   
2537     Routine child vaccinations should be mandatory   
2962     Routine child vaccinations should be mandatory   
2630     Routine child vaccinations should be mandatory   
730      Routine child vaccinations should be mandatory   
...                                                 ...   
3007     Routine child vaccinations should be mandatory   
1892               The USA is a good country to live in   
1948  Social media platforms should be regulated by ...   
1560               The USA is a good country to live in   
2649               The USA is a good country to live in   

                                                premise  \
3177  the government should not regulate what parent...   
2537  child vaccinations is not mandatory because it...   
2962  parents should have the freedom to decide what...   
2630  everyone should be free to choose what they wa...

In [19]:
top_50_df = df_sorted.head(int(0.5*len(df_sorted)))
print(top_50_df)

                                                  topic  \
3177     Routine child vaccinations should be mandatory   
2537     Routine child vaccinations should be mandatory   
2962     Routine child vaccinations should be mandatory   
2630     Routine child vaccinations should be mandatory   
730      Routine child vaccinations should be mandatory   
...                                                 ...   
2451  Social media platforms should be regulated by ...   
654                The USA is a good country to live in   
1749               The USA is a good country to live in   
1366  Social media platforms should be regulated by ...   
2358     Routine child vaccinations should be mandatory   

                                                premise  \
3177  the government should not regulate what parent...   
2537  child vaccinations is not mandatory because it...   
2962  parents should have the freedom to decide what...   
2630  everyone should be free to choose what they wa...

In [12]:
from sklearn.metrics import average_precision_score
import numpy as np

# Assuming A and B are the predicted and true labels for each instance, respectively
A = [0, 1, 0]
B = [1, 0, 1]

# Calculate Average Precision
ap = average_precision_score(label, prediction)

map_score = np.mean(ap)

print(f'Mean Average Precision (mAP): {map_score:.2f}')

Mean Average Precision (mAP): 0.65


### Make Test Dataset

In [30]:
argument_df = pd.read_csv('./data/arguments_test.csv')
print(argument_df.columns)

Index(['arg_id', 'argument', 'topic', 'stance'], dtype='object')


In [31]:
kp_df = pd.read_csv('./data/key_points_test.csv')
print(kp_df.columns)

Index(['key_point_id', 'key_point', 'topic', 'stance'], dtype='object')


In [37]:
test_data = []
for index,row in argument_df.iterrows():
    filtered_kp = kp_df[(kp_df['topic'] == row['topic']) & (kp_df['stance'] == row['stance'])]
    filtered_kp_list = filtered_kp.values.tolist()
    for line in filtered_kp_list:
        test_data.append([row['topic'],row['stance'],row['arg_id'],row['argument'],line[0],line[1]])
print(len(test_data))

header = ['topic','stance','arg_id','argument','key_point_id','key_point']
df = pd.DataFrame(test_data, columns=header)
print(df)
print(df.columns)



3923
                                               topic  stance     arg_id  \
0     Routine child vaccinations should be mandatory      -1    arg_0_0   
1     Routine child vaccinations should be mandatory      -1    arg_0_0   
2     Routine child vaccinations should be mandatory      -1    arg_0_0   
3     Routine child vaccinations should be mandatory      -1    arg_0_0   
4     Routine child vaccinations should be mandatory      -1    arg_0_1   
...                                              ...     ...        ...   
3918            The USA is a good country to live in       1  arg_2_209   
3919            The USA is a good country to live in       1  arg_2_209   
3920            The USA is a good country to live in       1  arg_2_209   
3921            The USA is a good country to live in       1  arg_2_209   
3922            The USA is a good country to live in       1  arg_2_209   

                                               argument key_point_id  \
0     Routine child va

In [36]:
df.to_csv('./data/KPM_test_data.csv', index=False)
print(f"DataFrame has been saved to {'./data/KPM_test_data.csv'}")

DataFrame has been saved to ./data/KPM_test_data.csv


### Prediction

In [8]:
#Add new header confidence_score
df = pd.read_csv('./data/KPM_test_data.csv')
print(df.columns)

Index(['Unnamed: 0.1', 'Unnamed: 0', 'topic', 'stance', 'arg_id', 'argument',
       'key_point_id', 'key_point', 'confidence_score'],
      dtype='object')


In [9]:
# df['confidence_score'] = ''
# df.to_csv('./data/KPM_test_data.csv')
# print(f"DataFrame has been saved to {'./data/KPM_test_data.csv'}")

In [11]:
# Assuming you have already trained the model and saved a checkpoint
checkpoint_path = './checkpoint/nli_model-epoch=01-val_loss=0.37.ckpt'
# Load the model from the checkpoint
model_name = "cross-encoder/nli-distilroberta-base"
num_classes = 2
max_length = 512
batch_size = 16
learning_rate = 5e-05
weight_decay = 0.001

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels = num_classes, ignore_mismatched_sizes = True)
loaded_model = KPM.load_from_checkpoint(model=model, checkpoint_path=checkpoint_path)
print("hello world")

# Set the model to evaluation mode
loaded_model.eval()


# Make predictions on the test set with probabilities

all_probabilities = []
df = pd.read_csv('./data/KPM_test_data.csv')
for index, row in df.iloc[:2].iterrows():
    premise = row['argument']+tokenizer.sep_token+row['topic']
    hypothesis = row['key_point']+tokenizer.sep_token+row['topic']
    encoding = tokenizer(
            premise,
            hypothesis,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
    with torch.no_grad():
        logits = loaded_model(**encoding).logits
        print(F.softmax(logits, dim=1).tolist()[0])
        probabilities = F.softmax(logits, dim=1).tolist()[0][0]
    print(str(index)+": "+str(probabilities))
    df.at[index,'confidence_score'] = probabilities
    # all_probabilities.append(probabilities)

df.to_csv('./data/KPM_test_data.csv', index=False)
print("Add new data successfully!!!!!")
# print("Finished")
# print(all_probabilities)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at cross-encoder/nli-distilroberta-base and are newly initialized because the shapes did not match:
- classifier.out_proj.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


hello world
[0.04978947713971138, 0.9502105712890625]
0: 0.04978947713971138
[0.9856218695640564, 0.01437812577933073]
1: 0.9856218695640564


### Make output json file for automatically evaluating

In [2]:
argument_df = pd.read_csv('./data/arguments_test.csv')
print(len(argument_df))
print(argument_df.columns)

723
Index(['arg_id', 'argument', 'topic', 'stance'], dtype='object')


In [3]:
test_df = pd.read_csv('./data/KPM_test_data.csv')
# test_df['confidence_score'] = 1- test_df['confidence_score']
print(len(test_df))
print(test_df)

3923
                                               topic  stance     arg_id  \
0     Routine child vaccinations should be mandatory      -1    arg_0_0   
1     Routine child vaccinations should be mandatory      -1    arg_0_0   
2     Routine child vaccinations should be mandatory      -1    arg_0_0   
3     Routine child vaccinations should be mandatory      -1    arg_0_0   
4     Routine child vaccinations should be mandatory      -1    arg_0_1   
...                                              ...     ...        ...   
3918            The USA is a good country to live in       1  arg_2_209   
3919            The USA is a good country to live in       1  arg_2_209   
3920            The USA is a good country to live in       1  arg_2_209   
3921            The USA is a good country to live in       1  arg_2_209   
3922            The USA is a good country to live in       1  arg_2_209   

                                               argument key_point_id  \
0     Routine child va

In [20]:
# test_df = test_df.drop(test_df.columns[0], axis=1)
# print(test_df)

                                               topic  stance     arg_id  \
0     Routine child vaccinations should be mandatory      -1    arg_0_0   
1     Routine child vaccinations should be mandatory      -1    arg_0_0   
2     Routine child vaccinations should be mandatory      -1    arg_0_0   
3     Routine child vaccinations should be mandatory      -1    arg_0_0   
4     Routine child vaccinations should be mandatory      -1    arg_0_1   
...                                              ...     ...        ...   
3918            The USA is a good country to live in       1  arg_2_209   
3919            The USA is a good country to live in       1  arg_2_209   
3920            The USA is a good country to live in       1  arg_2_209   
3921            The USA is a good country to live in       1  arg_2_209   
3922            The USA is a good country to live in       1  arg_2_209   

                                               argument key_point_id  \
0     Routine child vaccina

In [23]:
# test_df.to_csv('./data/KPM_test_data.csv', index=False)
# print("Add new data successfully!!!!!")

Add new data successfully!!!!!


In [24]:
output_data = []
arg_id = argument_df['arg_id'].values.tolist()
for arg in arg_id:
    filtered_df = test_df[test_df['arg_id'] == arg]
    kp_id = filtered_df['key_point_id'].values.tolist()
    score = filtered_df['confidence_score'].values.tolist()
    result_dict = dict(zip(kp_id, score))
    output_data.append(result_dict)
final_result = dict(zip(arg_id,output_data))
print(final_result)

{'arg_0_0': {'kp_0_0': 0.950210523, 'kp_0_1': 0.014378129999999989, 'kp_0_2': 0.00036001200000002065, 'kp_0_3': 0.824956879}, 'arg_0_1': {'kp_0_0': 0.961748116, 'kp_0_1': 0.0010414119999999638, 'kp_0_2': 0.00023716699999998259, 'kp_0_3': 0.344736099}, 'arg_0_2': {'kp_0_0': 0.779946506, 'kp_0_1': 0.0031176809999999833, 'kp_0_2': 0.00037944300000003484, 'kp_0_3': 0.969090639}, 'arg_0_3': {'kp_0_0': 0.970718089, 'kp_0_1': 0.00081640500000002, 'kp_0_2': 8.487700000003873e-05, 'kp_0_3': 0.753365889}, 'arg_0_4': {'kp_0_0': 0.997252888, 'kp_0_1': 0.025554359000000026, 'kp_0_2': 0.006825387000000016, 'kp_0_3': 0.828299209}, 'arg_0_5': {'kp_0_0': 0.03237760099999998, 'kp_0_1': 0.10669434099999997, 'kp_0_2': 0.023948430999999992, 'kp_0_3': 0.013217926000000046}, 'arg_0_6': {'kp_0_0': 0.99965146, 'kp_0_1': 0.0014120939999999749, 'kp_0_2': 0.00013506399999996255, 'kp_0_3': 0.99885866}, 'arg_0_7': {'kp_0_0': 0.999903, 'kp_0_1': 0.0002350209999999464, 'kp_0_2': 7.629399999997677e-05, 'kp_0_3': 0.470

In [25]:
import json
file_path = './result/test.json'
with open(file_path, 'w') as json_file:
    json.dump(final_result, json_file)
print(f"JSON file '{file_path}' created successfully.")

JSON file './result/test.json' created successfully.
