<a href="https://colab.research.google.com/github/lacykaltgr/continual-learning-ait/blob/experiment/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
'''Download the files '''
'''Only for colab'''

!wget https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/experiment.zip
!unzip experiment.zip
!find continual-learning-ait-experiment -type f ! -name "main.ipynb" -exec cp {} . \;

!rm -r stable_diffusion
!mkdir stable_diffusion
!mv diffusion_model.py stable_diffusion/
!mv autoencoder_kl.py stable_diffusion/
!mv layers.py stable_diffusion/
!mv stable_diffusion.py stable_diffusion/
!mv constants.py stable_diffusion/

--2023-05-08 15:34:43--  https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/experiment.zip
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/experiment [following]
--2023-05-08 15:34:43--  https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/experiment
Resolving codeload.github.com (codeload.github.com)... 140.82.112.9
Connecting to codeload.github.com (codeload.github.com)|140.82.112.9|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/zip]
Saving to: ‘experiment.zip’

experiment.zip          [ <=>                ] 344.05K  --.-KB/s    in 0.1s    

2023-05-08 15:34:44 (3.42 MB/s) - ‘experiment.zip’ saved [352303]

Archive:  experiment.zip
57ff15ff788de6d1fa6664c650342a7f5ee9beb4
   creating

In [2]:
import numpy as np
import tensorflow as tf
import keras

from sklearn.metrics import accuracy_score
#from sklearn.metrics import classification_report
#from keras.metrics import Accuracy

import classifier
from stable_diffusion import stable_diffusion
import utils
from data_preparation import load_dataset, CLDataLoader

import gc
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import importlib
from tqdm import tqdm

# Load the dataset

In [3]:
dpt_train, dpt_test = load_dataset('cifar-10', n_classes_first_task=4, n_classes_other_task=2)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
batch_size = 256
train_loader = CLDataLoader(dpt_train, batch_size , train=True)
test_loader = CLDataLoader(dpt_test, batch_size, train=False)

# Define parameters and agent

In [5]:
params = {
    #general
    "n_runs": 1,
    "n_tasks": 10,
    "n_classes": 10,
    "input_shape": (32,32, 3),
    "embedding_shape": (4, 4, 4),
    "samples_per_task": 10000,
    "batch_size": batch_size,
    "eval_batch_size": 1,
    "print_every": 1,

    #classifier
    "cls_epochs": 5,
    "cls_iters": 5,
    "cls_hiddens": 32,
    "cls_lr": 0.01,

    #generator
    "gen_epochs": 5,
    "num_steps": 2,
    "gen_iters": 1,
    "input_latent_strength":0.5,
    "gen_lr": 0.01,
    "temperature": 1,

    #mir
    "reuse_samples": True,
    "cls_mir_gen": 1,
    "gen_mir_gen": 1,
    "mem_coeff": 0.12,
    "n_mem": 2,
    "z_size": 10,
    "mir_iters": 3,
    "gen_kl_coeff": 0.5,
    "gen_rec_coeff": 0.5,
    "gen_ent_coeff": 0.5,
    "gen_div_coeff": 0.5,
    "gen_shell_coeff": 0.5,
    "cls_xent_coeff": 0.5,
    "cls_ent_coeff": 0.5,
    "cls_div_coeff": 0.5,
    "cls_shell_coeff": 0.5,
}

In [6]:
from typing_extensions import ParamSpec
'''Agent to handle models, parameters and states'''

class Agent:
  def __init__(self, hparams):
    self.cls = None
    self.opt = None
    self.opt_gen = None
    self.gen = None
    self.params = hparams
    self.state = dict()
    self.encoder = None
    self.classifier_model = None
    self.eval = accuracy_score
    #self.decoder = None

  def set_models(self, generator=None, classifier=None):
    self.cls = classifier #classifier
    self.gen = generator  #generator
    self.encoder = generator.encoder #encoder
    #self.decoder = gen.decoder
    self.opt = tf.keras.optimizers.legacy.Adam(learning_rate=params["cls_lr"])
    self.opt_gen = tf.keras.optimizers.legacy.Adam(learning_rate=params["gen_lr"])

    # encoder - classifier pipeline
    data_input = keras.Input(shape=self.params["input_shape"], name="image")
    encoder_output = agent.encoder(data_input)
    cls_encoder_output = agent.cls(encoder_output)
    self.classifier_model = keras.Model(inputs=data_input, outputs=cls_encoder_output)
    self.classifier_model.compile(optimizer=self.opt, loss="categorical_crossentropy", metrics=["accuracy"])

    # classifier pipeline
    latent_input = keras.Input(shape=self.params["embedding_shape"], name="latent")
    cls_latent_output = agent.cls(latent_input)
    self.cls_model = keras.Model(inputs=latent_input, outputs=cls_latent_output)
    self.cls_model.compile(optimizer=self.opt, loss="categorical_crossentropy", metrics=["accuracy"])

  def set_params(self, params):
    self.params = params

# Functions for training

In [7]:
'''Generate samples and train the diffusion model at the same time'''

def generate(agent, cls=None, input_latent=None, train=True, coeff=1.0):

    if cls is None:
        cls = agent.cls

    latent, alphas, alphas_prev, timesteps = agent.gen.initialize(agent.params, input_latent)

    batch_size = agent.params['batch_size'] if train else 64
    for index, timestep in reversed(list(enumerate(timesteps))):
        if train:
            with tf.GradientTape() as tape:
                e_t = agent.gen.get_model_output(
                    latent,
                    timestep,
                    batch_size,
                )
                a_t, a_prev = alphas[index], alphas_prev[index]
                latent = agent.gen.get_x_prev(latent, e_t,  a_t, a_prev, agent.params["temperature"])

                pred = cls(latent)
                #loss based on confidence
                #ENT = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_pre, logits=y_pre))
                loss = coeff*tf.keras.losses.categorical_crossentropy(pred, pred)
            grads = tape.gradient(loss, agent.gen.diffusion_model.trainable_variables)
            agent.opt_gen.apply_gradients(zip(grads, agent.gen.diffusion_model.trainable_variables))
        else:
            e_t = agent.gen.get_model_output(
                latent,
                timestep,
                agent.params['batch_size'],
            )
            a_t, a_prev = alphas[index], alphas_prev[index]
            latent = agent.gen.get_x_prev(latent, e_t,  a_t, a_prev, agent.params["temperature"])



    return latent

In [8]:
'''Retrive maximally interferred latent vector for classifier'''

def retrieve_gen_for_cls(agent):

    print("Retrieving latent vector for classifier...")

    latent = agent.gen.encoder(agent.state["data"])
    virtual_cls = classifier.classifier(agent.params)
    virtual_cls = utils.get_next_step_cls(
        agent.cls,
        virtual_cls,
        latent,
        agent.state["target"]
    )

    #mean_latent = tf.cast(tf.reduce_mean(latent, axis=0), tf.float64)
    z_new_max = None

    for i in range(agent.params["n_mem"]):

        z_new = generate(agent, input_latent=latent, train=False, coeff=0.1)

        for j in range(params["mir_iters"]):
            with tf.GradientTape(persistent=True) as tape:

                tape.watch(z_new)

                #z_new = tf.cast(z_new, tf.float64)
                y_pre = agent.cls(z_new)
                y_virtual = virtual_cls(z_new)

                # maximise the interference:
                XENT = tf.constant(0.)
                if params["cls_xent_coeff"] > 0.:
                    XENT = tf.keras.losses.categorical_crossentropy(y_virtual, y_pre)

                # the predictions from the two models should be confident
                ENT = tf.constant(0.)
                if params["cls_ent_coeff"] > 0.:
                    ENT = tf.keras.losses.categorical_crossentropy(y_pre, y_pre)

                # the new-found samples should be different from each others
                DIV = tf.constant(0.)
                if params["cls_div_coeff"] > 0.:
                    for found_z_i in range(i):
                        DIV += tf.keras.losses.MSE(
                            z_new,
                            z_new_max[found_z_i * z_new.shape[0]:found_z_i * z_new.shape[0] + z_new.shape[0]]
                        ) / i

                # (NEW) stay on gaussian shell loss:
                SHELL = tf.constant(0.)
                if params["cls_shell_coeff"] > 0.:
                    SHELL = tf.keras.losses.MSE(
                        tf.norm(z_new, axis=1),
                        tf.ones_like(tf.norm(z_new, axis=1))*np.sqrt(params["z_size"])
                    )

                XENT, ENT, DIV, SHELL = \
                    tf.reduce_mean(XENT), \
                        tf.reduce_mean(ENT), \
                        tf.reduce_mean(DIV), \
                        tf.reduce_mean(SHELL)

                gain = params["cls_xent_coeff"] * XENT + \
                       -params["cls_ent_coeff"] * ENT + \
                       params["cls_div_coeff"] * DIV + \
                       -params["cls_shell_coeff"] * SHELL

            z_g = tape.gradient(gain, z_new)
            if z_g is not None:
                z_new = (z_new + 1 * z_g)

        if z_new_max is None:
            z_new_max = z_new.numpy().copy()
        else:
            z_new_max = np.concatenate([z_new_max, z_new.numpy().copy()])

    tf.stop_gradient(z_new_max)

    if np.isnan(z_new_max).any():
        mir_worked = 0
        mem_x = generate(agent, train=False)
    else:
        mem_x = z_new_max
        mir_worked = 1

    mem_y = agent.cls(mem_x).numpy()

    return mem_x, mem_y, mir_worked

In [9]:
'''Retrive maximally interferred latent vector for generator'''
#TODO: vmi más loss is (maximise interference)

def retrieve_gen_for_gen(agent):

    print("Retrieving latent vector for generator...")

    latent = agent.gen.encoder(agent.state["data"])
    #mean_latent = tf.cast(tf.reduce_mean(latent, axis=0), tf.float64)
    z_new_max = None

    for i in range(params["n_mem"]):

        z_new = generate(agent, input_latent=latent, train=False, coeff=0.1)

        for j in range(params["mir_iters"]):

            with tf.GradientTape(persistent=True) as tape:
                tape.watch(z_new)
                # the predictions from the two models should be confident
                ENT = tf.constant(0.)
                if params["gen_ent_coeff"]>0.:
                    y_pre = agent.cls(z_new)
                    ENT = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_pre, y_pre))

                # the new-found samples should be different from each others
                DIV = tf.constant(0.)
                if params["gen_div_coeff"]>0.:
                    for found_z_i in range(i):
                        DIV += tf.reduce_mean(tf.math.squared_difference(
                            z_new,
                            z_new_max[found_z_i * z_new.shape[0]:found_z_i * z_new.shape[0] + z_new.shape[0]])
                        ) / i

                # (NEW) stay on gaussian shell loss:
                SHELL = tf.constant(0.)
                if params["gen_shell_coeff"]>0.:
                    SHELL = tf.reduce_mean(tf.math.squared_difference(
                        tf.norm(z_new, ord=2, axis=1),
                        tf.ones_like(tf.norm(z_new, ord=2, axis=1))*np.sqrt(params["z_size"])))


                gain =params["gen_div_coeff"] * DIV + \
                      -params["gen_ent_coeff"] * ENT + \
                       -params["gen_shell_coeff"] * SHELL

            z_g = tape.gradient(gain, z_new)
            z_new = (z_new + z_g)

        if z_new_max is None:
            z_new_max = tf.identity(z_new)
        else:
            z_new_max = tf.concat([z_new_max, z_new], axis=0)


    tf.stop_gradient(z_new_max)

    if np.isnan(z_new_max).any():
        mir_worked = 0
        mem_x = generate(agent, train=False)
    else:
        mem_x = z_new_max
        mir_worked = 1

    return mem_x, mir_worked

In [10]:
'''Train the generator unit'''

def train_generator(agent):

    data = agent.state["data"]
    latent = agent.gen.encoder(data)
    mem_x = None

    for it in range(agent.params["gen_iters"]):
        generate(agent, input_latent=latent)

        #if agent.state["task"] > 0:
        #    if it == 0 or not agent.params["reuse_samples"]:
        #        mem_x, mir_worked = retrieve_gen_for_gen(agent)
        #
        #        agent.state["mir_tries"] += 1
        #        if mir_worked:
        #            agent.state["mir_success"] += 1

        # TODO
        #if mem_x is not None:
        #  if len(mem_x.shape) == 3:
        #    mem_x = tf.expand__dims(mem_x, axis=-1)
        #  generate(agent, input_latent=mem_x, coeff=agent.params["mem_coeff"])


In [11]:
'''Train the encoder and the classifier unit'''

def train_classifier(agent):

    data = agent.state["data"]
    target = agent.state["target"]
    mem_x, mem_y = None, None

    for it in range(agent.params["cls_iters"]):
        agent.classifier_model.fit(data, target, batch_size=agent.params["batch_size"], epochs=1, verbose=1)
        #if agent.state["task"] > 0:
        #    if it == 0 or not agent.params["reuse_samples"]:
        #        mem_x, mem_y, mir_worked = retrieve_gen_for_cls(agent)
        #        agent.state["mir_tries"] += 1
        #        if mir_worked:
        #            agent.state["mir_success"] += 1

        #    if mem_x is not None:
        #        agent.cls_model.fit(mem_x, mem_y, batch_size=agent.params["batch_size"], epochs=1, verbose=1)



In [12]:
'''Run an epoch'''

def run_cls_epoch(agent):

    gc.collect()

    agent.state["sample_amt"] = 0

    print("Running epoch on classifier ", agent.state["epoch"])

    for i, (data, target) in tqdm(enumerate(agent.state["tr_loader"])):
        #if agent.state["sample_amt"] > agent.params["samples_per_task"] > 0: break
        agent.state["sample_amt"] += data.shape[0]

        agent.state["data"] = data
        agent.state["target"] = target
        agent.state["i_example"] = i

        train_classifier(agent)


    '''Evaluate the models in epoch'''
    if agent.state["epoch"] % agent.params["print_every"] == 0:

        print("\nEvaluate classifier on Task: ", agent.state["task"], " Epoch: ", agent.state["epoch"])

        accuracy = []
        losses = []

        for i, (data, target) in tqdm(enumerate(agent.state["ts_loader"])):

          '''Evaluate the classifier'''
          logits = agent.classifier_model(data)
          pred = np.argmax(logits, axis=1)
          report = agent.eval(np.argmax(target, axis=1), pred)
          loss = tf.keras.losses.categorical_crossentropy(target, logits)
          accuracy.append(report)
          losses.append(np.mean(loss))
          print(logits)
          print(target)

        print("Mean accuracy: ", np.mean(accuracy))
        print("Mean loss: ", np.mean(losses))


In [13]:
'''Run an epoch'''

def run_gen_epoch(agent):

    gc.collect()

    agent.state["sample_amt"] = 0

    print("Running generator epoch ", agent.state["epoch"])

    for i, (data, target) in tqdm(enumerate(agent.state["tr_loader"])):
        #if agent.state["sample_amt"] > agent.params["samples_per_task"] > 0: break
        if data.shape[0] != batch_size: break
        agent.state["sample_amt"] += data.shape[0]

        agent.state["data"] = data
        agent.state["target"] = target
        agent.state["i_example"] = i

        train_generator(agent)

    '''Evaluate the models in epoch'''
    if agent.state["epoch"] % agent.params["print_every"] == 0:

        print("\nEvaluate generator on Task: ", agent.state["task"], " Epoch: ", agent.state["epoch"])
        loss = []
        for i, (data, target) in tqdm(enumerate(agent.state["ts_loader"])):

          '''Evaluate the generator'''
          mem_x = generate(agent, input_latent=agent.encoder(data), train=False)
          mem_pred = agent.cls(mem_x)
          mem_loss = tf.keras.losses.categorical_crossentropy(mem_pred, mem_pred)
          loss.append(np.mean(mem_loss))

        print("Loss on generate: ",  np.mean(mem_loss))

          #mem_x_cls, mem_y, mir_worked_cls = retrieve_gen_for_cls(agent)
          #mem_loss_cls = tf.keras.losses.categorical_crossentropy(mem_y, mem_y)
          #print("Loss on retrieve for cls: ",  np.mean(mem_loss_cls))
          #print("MIR worked on retrieve for cls: ", mir_worked_cls)

          #mem_x_gen, mir_worked_gen = retrieve_gen_for_gen(agent)
          #em_pred_gen = agent.cls(mem_x_gen)
          #mem_loss_gen = tf.keras.losses.categorical_crossentropy(mem_pred_gen, mem_pred_gen)
          #print("Loss on retrieve for gen: ",  np.mean(mem_loss_gen))
          #print("MIR worked on retrieve for gen: ", mir_worked_gen)

In [14]:
'''Run a task'''

def run_task(agent):
  
    print("Running task ", agent.state["task"])

    agent.state["mir_tries"], agent.state["mir_success"] = 0, 0

    for epoch in range(agent.params["cls_epochs"]):
        agent.state["epoch"] = epoch
        run_cls_epoch(agent)
      
    for epoch in range(agent.params["gen_epochs"]):
        agent.state["epoch"] = epoch
        run_gen_epoch(agent)

    '''Evaluate forgetting'''
    print("Task: ", agent.state["task"])
    accuracy = []
    losses = []
    for i in range(agent.state["task"]-1):
        print("Task forgetting on task ", i+1)
        for data, target in agent.state["ts_loader"][i]:
          logits = agent.classifier_model(data)
          pred = np.argmax(logits, axis=1)
          report = agent.eval(np.argmax(target, axis=1), pred)
          loss = tf.keras.losses.categorical_crossentropy(target, logits)
          accuracy.append(report)
          losses.append(np.mean(loss))

        print("Mean accuracy: ", np.mean(accuracy))
        print("Mean loss: ", np.mean(losses))

    #print("MIR success rate: ", agent.state["mir_success"] / agent.state["mir_tries"])

In [15]:
'''Run the experiment'''

def run(agent):

  agent.set_models(
      classifier=classifier.classifier(agent.params),
      generator=stable_diffusion.StableDiffusion(img_height=32, img_width=32, download_weights=True))

  for task, tr_loader in enumerate(train_loader):
    agent.state["task"] = task
    agent.state["tr_loader"] = tr_loader
    run_task(agent)

# Testing for development

In [16]:
tasks_to_test = 3

agent = Agent(params)
agent.set_models(
    classifier=classifier.classifier(agent.params),
    generator=stable_diffusion.StableDiffusion(img_height=32, img_width=32, download_weights=False))

Classifier init
StableDiffusion init
UNetModel init
Encoder init


In [17]:
agent.set_params(params)
for task, (tr_loader, ts_loader) in enumerate(zip(train_loader[:tasks_to_test],test_loader[:tasks_to_test])):
    agent.state["task"] = task
    agent.state["tr_loader"] = tr_loader
    agent.state["ts_loader"] = ts_loader
    run_task(agent)

Running task  0
Running epoch on classifier  0


0it [00:00, ?it/s]



1it [00:32, 32.56s/it]



2it [00:38, 16.77s/it]



3it [00:42, 11.11s/it]



4it [00:48,  9.15s/it]



5it [00:54,  7.76s/it]



6it [01:00,  7.21s/it]



7it [01:05,  6.45s/it]



8it [01:10,  5.95s/it]



9it [01:15,  5.89s/it]



10it [01:20,  5.46s/it]



11it [01:26,  5.56s/it]



12it [01:30,  5.37s/it]



13it [01:36,  5.37s/it]



14it [01:41,  5.38s/it]



15it [01:47,  5.39s/it]



16it [01:53,  5.55s/it]



17it [01:58,  5.62s/it]



18it [02:04,  5.56s/it]



19it [02:09,  5.38s/it]



20it [02:14,  5.27s/it]



21it [02:18,  5.05s/it]



22it [02:23,  5.02s/it]



23it [02:28,  5.01s/it]



24it [02:33,  4.99s/it]



25it [02:39,  5.22s/it]



26it [02:44,  5.14s/it]



27it [02:49,  5.22s/it]



28it [02:54,  5.14s/it]



29it [03:00,  5.21s/it]



30it [03:05,  5.38s/it]



31it [03:11,  5.51s/it]



32it [03:17,  5.47s/it]



33it [03:22,  5.32s/it]



34it [03:27,  5.47s/it]



35it [03:34,  5.68s/it]



36it [03:39,  5.48s/it]



37it [03:43,  5.20s/it]



38it [03:48,  5.14s/it]



39it [03:54,  5.45s/it]



40it [04:01,  5.79s/it]



41it [04:06,  5.66s/it]



42it [04:12,  5.70s/it]



43it [04:17,  5.48s/it]



44it [04:23,  5.69s/it]



45it [04:29,  5.85s/it]



46it [04:34,  5.60s/it]



47it [04:40,  5.55s/it]



48it [04:45,  5.50s/it]



49it [04:51,  5.59s/it]



50it [04:56,  5.40s/it]



51it [05:02,  5.53s/it]



52it [05:08,  5.72s/it]



53it [05:13,  5.60s/it]



54it [05:19,  5.54s/it]



55it [05:24,  5.50s/it]



56it [05:29,  5.35s/it]



57it [05:34,  5.23s/it]



58it [05:40,  5.53s/it]



59it [05:55,  6.03s/it]



Evaluate classifier on Task:  0  Epoch:  0


2it [00:01,  1.43it/s]

tf.Tensor(
[[9.69610969e-09 9.49869305e-03 4.68124092e-01 5.22377074e-01
  1.92256167e-08 2.16156426e-09 3.80126153e-09 2.83856689e-08
  4.43751169e-08 3.30351462e-08]
 [5.51134917e-06 3.65399122e-02 6.38157308e-01 3.25218499e-01
  1.92639618e-05 4.12260124e-06 8.88559589e-06 1.96856381e-05
  1.99167171e-05 6.92175263e-06]
 [3.26059002e-04 1.92015484e-01 3.86300385e-01 4.18093920e-01
  5.59519685e-04 2.81710672e-04 4.31890861e-04 7.38356845e-04
  6.87527587e-04 5.65123744e-04]
 [1.07551114e-06 7.72229850e-01 8.12539533e-02 1.46491006e-01
  8.98057863e-07 1.58029150e-06 2.52565883e-07 1.37112102e-05
  2.02238448e-06 5.62732203e-06]
 [3.21931802e-05 5.52207045e-02 5.45311093e-01 3.98857027e-01
  1.52256936e-04 2.53064118e-05 8.44755414e-05 1.09707784e-04
  1.56955473e-04 5.02302282e-05]
 [5.98838767e-07 9.24479604e-01 3.54930647e-02 4.00181152e-02
  2.98442529e-07 6.17034743e-07 2.20407443e-07 5.31193064e-06
  5.15861188e-07 1.54865324e-06]
 [1.05182451e-04 4.26321626e-01 2.32314467e-01 

4it [00:01,  2.82it/s]

tf.Tensor(
[[3.39779732e-10 1.50722452e-03 8.59062672e-01 1.39430046e-01
  4.81292428e-09 1.10296439e-10 2.40489045e-10 2.73308931e-09
  6.04347505e-09 4.80430917e-10]
 [6.70234076e-05 6.57206714e-01 1.30553901e-01 2.11168721e-01
  1.00615391e-04 7.11660978e-05 9.14108896e-05 3.40107363e-04
  1.59146657e-04 2.41216971e-04]
 [5.40571976e-10 5.87969366e-03 4.43996340e-01 5.50123990e-01
  1.12123577e-09 1.12624632e-10 1.71595529e-10 1.68357872e-09
  3.25772276e-09 3.57990193e-09]
 [4.38119374e-10 6.49605878e-03 7.36556053e-01 2.56947905e-01
  2.80150636e-09 1.19091764e-10 2.45206910e-10 1.41021828e-09
  3.22631655e-09 3.14924176e-09]
 [1.92977473e-06 4.72697951e-02 6.42721415e-01 3.09977621e-01
  7.39079815e-06 9.59236786e-07 1.95026473e-06 5.14946760e-06
  9.83770769e-06 3.95012466e-06]
 [9.21372848e-05 1.27445012e-01 4.29811060e-01 4.41806018e-01
  1.42000281e-04 5.57067360e-05 7.04459744e-05 1.74697838e-04
  2.32187129e-04 1.70866086e-04]
 [2.88555134e-06 7.35303283e-01 8.47264007e-02 

6it [00:02,  4.00it/s]

tf.Tensor(
[[1.32973355e-05 7.52710581e-01 1.01719484e-01 1.45303011e-01
  1.73949393e-05 2.39166511e-05 8.05688524e-06 1.30592889e-04
  2.04985736e-05 5.31992337e-05]
 [8.78134379e-05 6.88515902e-01 1.54825106e-01 1.55217171e-01
  1.78088085e-04 9.22067265e-05 1.78898699e-04 3.30287992e-04
  2.70008924e-04 3.04448564e-04]
 [7.44456372e-08 7.75863528e-01 5.37231266e-02 1.70408562e-01
  2.03643509e-07 1.89204314e-07 1.12138132e-07 2.84634143e-06
  4.94466235e-07 8.65913819e-07]
 [1.66864193e-05 3.19150612e-02 5.02728522e-01 4.65182066e-01
  3.15830184e-05 7.72145268e-06 1.28583224e-05 2.98430550e-05
  5.39050197e-05 2.17554116e-05]
 [9.45610338e-07 2.03471370e-02 7.16940939e-01 2.62690902e-01
  6.39907375e-06 5.98815689e-07 1.89883735e-06 3.26746681e-06
  6.10088728e-06 1.83020154e-06]
 [3.89420336e-08 7.82611251e-01 5.87433726e-02 1.58642024e-01
  1.40343488e-07 1.44912576e-07 7.99546314e-08 2.17071693e-06
  2.61545807e-07 5.43678425e-07]
 [5.14523606e-08 1.79650988e-02 1.01650402e-01 

8it [00:02,  4.76it/s]

tf.Tensor(
[[2.17910292e-05 2.31717695e-02 2.53064036e-01 7.23622143e-01
  1.08367731e-05 7.93458639e-06 5.40458404e-06 2.29693178e-05
  3.12523334e-05 4.18280615e-05]
 [4.25116468e-06 7.42620468e-01 1.05187625e-01 1.52098373e-01
  6.25080429e-06 8.34306957e-06 4.39392352e-06 4.52552486e-05
  8.23969094e-06 1.68305596e-05]
 [2.61811074e-05 3.68066207e-02 5.58284283e-01 4.04582590e-01
  5.99265804e-05 1.76192716e-05 3.16748810e-05 4.97931032e-05
  8.70842996e-05 5.41853333e-05]
 [1.88764044e-08 9.90683794e-01 2.52226880e-03 6.79373089e-03
  2.61331023e-09 9.02557673e-09 1.68259975e-10 2.32098145e-07
  3.48660611e-09 1.66363314e-08]
 [7.17673387e-08 8.68312549e-03 7.71965504e-01 2.19349086e-01
  8.14808345e-07 5.06605566e-08 1.46948665e-07 3.54276153e-07
  6.81362962e-07 1.61501461e-07]
 [2.30978060e-07 9.39008057e-01 2.98042297e-02 3.11839283e-02
  1.27311878e-07 2.26349442e-07 1.33478324e-07 2.37425502e-06
  1.73680277e-07 4.80310064e-07]
 [3.22321580e-06 2.86356527e-02 5.23069680e-01 

10it [00:03,  5.11it/s]

tf.Tensor(
[[2.76591416e-09 9.59186196e-01 1.08410576e-02 2.99719460e-02
  1.97703116e-08 1.45001495e-08 2.14196980e-08 4.88395699e-07
  4.41376216e-08 1.37045220e-07]
 [6.58978877e-07 1.42723052e-02 3.72721225e-01 6.13000631e-01
  2.45634027e-07 2.27701847e-07 9.14468430e-08 4.61138683e-07
  6.70911845e-07 3.50832033e-06]
 [3.01648839e-09 7.54011562e-03 4.31947768e-01 5.60512066e-01
  5.29383959e-09 5.64362890e-10 9.88586990e-10 7.58045804e-09
  1.57994222e-08 1.93976994e-08]
 [4.30947438e-07 8.77961636e-01 4.10944261e-02 8.09169039e-02
  1.50297569e-06 1.19890774e-06 1.50456333e-06 1.31697798e-05
  2.44865373e-06 6.83211783e-06]
 [2.98605833e-08 9.35192872e-03 7.54599512e-01 2.36047819e-01
  2.47038457e-07 1.30391040e-08 3.60800030e-08 9.53372279e-08
  2.33026569e-07 9.46537853e-08]
 [4.08737177e-09 8.03005636e-01 6.46625906e-02 1.32331148e-01
  1.86185165e-08 2.73280900e-08 8.81854945e-09 4.17618025e-07
  2.41906690e-08 7.31482643e-08]
 [3.23078013e-04 1.38204932e-01 2.86406398e-01 

12it [00:03,  5.36it/s]

tf.Tensor(
[[3.65578207e-10 9.87165451e-01 3.87060922e-03 8.96379165e-03
  1.42050882e-09 9.63675473e-10 6.50284382e-10 7.55915224e-08
  3.82211196e-09 2.02110098e-08]
 [3.24118758e-08 6.48538722e-03 7.92089105e-01 2.01424479e-01
  3.17532965e-07 1.69861316e-08 3.93961308e-08 1.53879924e-07
  3.10707890e-07 5.58564430e-08]
 [1.94850745e-05 7.07307830e-02 5.14957964e-01 4.13975477e-01
  6.35457181e-05 2.26098091e-05 3.99503224e-05 5.99400046e-05
  7.45297002e-05 5.56638188e-05]
 [2.36998517e-06 2.74190065e-02 6.46057725e-01 3.26474786e-01
  1.23963237e-05 1.45590775e-06 4.89343529e-06 8.90345746e-06
  1.45948534e-05 3.86338070e-06]
 [1.68082159e-10 2.07155640e-03 8.63173306e-01 1.34755120e-01
  2.13952900e-09 5.64165693e-11 1.49468576e-10 1.06977827e-09
  1.75791892e-09 4.72599015e-10]
 [1.24208546e-08 9.66689084e-03 3.65253329e-01 6.25079572e-01
  1.54546740e-08 1.49786028e-09 2.14505791e-09 2.25996715e-08
  6.89353925e-08 7.66421948e-08]
 [1.75179023e-12 9.54277813e-01 1.21876784e-02 

14it [00:03,  5.47it/s]

tf.Tensor(
[[1.18829721e-05 2.52131000e-02 4.50402290e-01 5.24282396e-01
  1.23811305e-05 4.91685796e-06 4.84662496e-06 1.41800729e-05
  2.68928343e-05 2.70846376e-05]
 [3.23829248e-08 8.65766685e-03 6.39579892e-01 3.51761550e-01
  2.26445252e-07 1.44049608e-08 8.30443767e-08 1.41316065e-07
  2.81917266e-07 1.09569676e-07]
 [1.62622537e-11 9.59612787e-01 1.45541672e-02 2.58329734e-02
  1.09348051e-10 1.04303219e-10 5.64318106e-11 1.05841247e-08
  1.11698879e-10 3.49332230e-10]
 [5.03766641e-05 6.11124523e-02 3.48867565e-01 5.89624941e-01
  3.88287772e-05 2.07680332e-05 2.36320611e-05 5.42191083e-05
  1.12706250e-04 9.45357315e-05]
 [4.50873671e-07 1.12120034e-02 1.76351815e-01 8.12432289e-01
  2.37334191e-07 8.26231243e-08 1.04968471e-07 7.75020453e-07
  1.13889246e-06 1.07951007e-06]
 [2.85424235e-08 5.85881295e-03 7.75786340e-01 2.18353942e-01
  3.04742088e-07 1.14335368e-08 5.90468794e-08 1.12876918e-07
  3.28691897e-07 6.04351413e-08]
 [2.42488586e-05 8.38939399e-02 1.99277818e-01 

16it [00:04,  5.36it/s]

tf.Tensor(
[[1.59909232e-05 8.04965198e-02 1.57856762e-01 7.61503339e-01
  9.15305372e-06 7.24235269e-06 2.83763143e-06 2.61616515e-05
  3.46432898e-05 4.73638684e-05]
 [7.22572076e-05 3.65279280e-02 2.12299153e-01 7.50611305e-01
  4.05069986e-05 2.85118404e-05 1.75126461e-05 7.87006720e-05
  1.28766944e-04 1.95363697e-04]
 [2.45043891e-04 6.40546307e-02 4.02748644e-01 5.31375289e-01
  2.78316380e-04 1.17866461e-04 1.24374550e-04 3.43018008e-04
  4.79648646e-04 2.33162937e-04]
 [4.57346003e-04 2.59398460e-01 2.83992738e-01 4.51167226e-01
  7.05780520e-04 5.27328521e-04 5.89181203e-04 1.19760050e-03
  9.04829009e-04 1.05949794e-03]
 [9.57457558e-10 9.88895655e-01 3.67142376e-03 7.43265962e-03
  6.24188390e-09 2.44377407e-09 6.16419982e-09 1.07832648e-07
  1.65133791e-08 8.26540330e-08]
 [2.02221017e-05 1.61370635e-01 2.00443268e-01 6.38000190e-01
  1.67235921e-05 1.48045165e-05 5.89139654e-06 5.69492913e-05
  3.32300770e-05 3.80150304e-05]
 [2.07728643e-08 5.14524616e-03 1.74379095e-01 

18it [00:04,  5.57it/s]

tf.Tensor(
[[1.76614281e-04 3.04879785e-01 2.51286507e-01 4.41765964e-01
  2.37355838e-04 1.90216539e-04 1.73935259e-04 4.60263720e-04
  3.68279958e-04 4.61187476e-04]
 [9.41669714e-05 5.14267385e-02 5.40987551e-01 4.06248391e-01
  2.38000459e-04 1.26316008e-04 1.50523585e-04 2.70849967e-04
  2.72726582e-04 1.84782213e-04]
 [2.29974440e-03 1.27750367e-01 3.09287071e-01 5.44866025e-01
  2.32026167e-03 2.01412593e-03 1.55901257e-03 3.16944881e-03
  3.43059492e-03 3.30335810e-03]
 [1.69172415e-06 2.13878583e-02 6.90532327e-01 2.88037956e-01
  1.38744053e-05 1.00452303e-06 4.77120147e-06 5.60246417e-06
  1.20729774e-05 2.85143938e-06]
 [3.03432723e-09 2.16102786e-03 1.05215788e-01 8.92623127e-01
  4.49203619e-10 1.32042557e-10 7.24443630e-11 2.76691026e-09
  7.18750437e-09 9.87195747e-09]
 [4.82122857e-07 9.45359945e-01 2.04875674e-02 3.41467448e-02
  1.77624969e-07 3.28001931e-07 8.01394506e-08 3.74041497e-06
  2.98901000e-07 7.05175694e-07]
 [1.93861371e-04 8.18500817e-02 3.48294318e-01 

20it [00:04,  5.58it/s]

tf.Tensor(
[[1.10561661e-04 3.85587096e-01 2.90526599e-01 3.22493553e-01
  2.12186089e-04 9.22241161e-05 1.81994765e-04 3.89757392e-04
  2.23163341e-04 1.82898148e-04]
 [8.40672692e-06 2.20365096e-02 5.97298801e-01 3.80570352e-01
  1.83884749e-05 2.77602180e-06 5.07140112e-06 1.24794951e-05
  3.71833776e-05 9.94444690e-06]
 [4.33285585e-09 7.24728871e-03 2.64034510e-01 7.28718162e-01
  2.93957325e-09 1.78361381e-10 1.83180623e-10 8.26596569e-09
  2.75378529e-08 8.32973246e-09]
 [1.74617817e-05 3.79862934e-01 1.67907074e-01 4.52065468e-01
  9.88134980e-06 1.03828679e-05 2.50349808e-06 6.47853958e-05
  2.55766754e-05 3.39500257e-05]
 [1.47674393e-06 6.86116097e-03 1.03196748e-01 8.89935732e-01
  7.52328546e-08 1.33928410e-07 1.52640887e-08 4.84787222e-07
  1.05715117e-06 3.08074732e-06]
 [4.23255528e-07 1.93487145e-02 3.71862978e-01 6.08783364e-01
  5.88101841e-07 7.59046443e-08 2.61524718e-07 9.91921979e-07
  1.98395628e-06 6.00007752e-07]
 [6.73773757e-04 1.32583395e-01 2.84730375e-01 

22it [00:05,  5.58it/s]

tf.Tensor(
[[1.03472546e-03 1.86350554e-01 2.97482103e-01 5.08032918e-01
  1.06057676e-03 8.30070523e-04 6.07198162e-04 1.72281882e-03
  1.42702262e-03 1.45197415e-03]
 [1.30655637e-04 4.33238566e-01 2.23415002e-01 3.41146618e-01
  2.70891614e-04 1.97302754e-04 2.54069309e-04 5.81093132e-04
  3.08615330e-04 4.57223621e-04]
 [1.78236155e-06 9.66763198e-01 2.58251838e-02 7.40270969e-03
  3.62484457e-06 1.83071108e-07 6.22087782e-07 1.02170520e-06
  1.43472687e-06 2.41276695e-07]
 [1.57763243e-05 7.37734914e-01 9.81928036e-02 1.63596153e-01
  4.11165347e-05 2.57370157e-05 3.45929984e-05 1.60967305e-04
  7.01833269e-05 1.27804247e-04]
 [2.08615209e-03 1.42203197e-01 2.88674951e-01 5.57393491e-01
  1.36165356e-03 1.26263138e-03 6.66132604e-04 2.10710755e-03
  2.25488842e-03 1.98977184e-03]
 [9.44751974e-08 9.68440592e-01 1.15194842e-02 2.00322606e-02
  3.44505452e-07 1.62892320e-07 3.56689497e-07 3.20066192e-06
  7.95810138e-07 2.74392869e-06]
 [2.14227275e-05 5.72407186e-01 1.61715463e-01 

24it [00:05,  5.71it/s]

tf.Tensor(
[[1.07928674e-04 1.08770676e-01 2.69135982e-01 6.21341109e-01
  6.94589980e-05 4.39964460e-05 3.80703787e-05 1.50579581e-04
  1.71949723e-04 1.70141881e-04]
 [2.12277178e-06 4.15918715e-02 2.19297037e-01 7.39080071e-01
  2.67006840e-06 8.15305555e-07 9.60311127e-07 6.95596009e-06
  8.72084365e-06 8.80144580e-06]
 [1.58246181e-07 4.46122233e-03 7.55700245e-02 9.19967413e-01
  3.16865645e-08 9.20876619e-09 3.46338602e-09 1.46434630e-07
  4.11049285e-07 4.93652578e-07]
 [4.01664920e-07 1.65156368e-02 1.51137933e-01 8.32342327e-01
  2.77134120e-07 8.11041616e-08 1.03116349e-07 7.28056875e-07
  1.69359998e-06 7.55423684e-07]
 [1.39622480e-05 7.39452243e-01 9.24140513e-02 1.67831182e-01
  2.31918184e-05 1.60553136e-05 1.31254274e-05 1.08022701e-04
  4.29028187e-05 8.52581143e-05]
 [7.57126109e-05 5.77402897e-02 1.73979431e-01 7.67655551e-01
  4.85170131e-05 2.79884589e-05 2.01394632e-05 1.11985450e-04
  1.47155355e-04 1.93272761e-04]
 [6.52444214e-08 9.27785933e-01 2.53604762e-02 

26it [00:05,  5.69it/s]

tf.Tensor(
[[1.14606291e-06 3.43545489e-02 1.68985233e-01 7.96647966e-01
  1.04519927e-06 2.46455272e-07 4.27468194e-07 2.91109723e-06
  3.74969045e-06 2.84092403e-06]
 [1.50040067e-11 1.35116535e-03 8.23889315e-01 1.74759522e-01
  4.76599260e-10 5.15711605e-12 3.03702653e-11 1.29650929e-10
  4.45438214e-10 9.73395531e-11]
 [2.21909363e-06 1.18319727e-02 2.02470914e-01 7.85681784e-01
  8.23939502e-07 4.85936596e-07 4.78926893e-07 2.41087628e-06
  3.87426326e-06 5.00157694e-06]
 [3.83385901e-09 6.90844096e-03 7.38749206e-01 2.54342318e-01
  2.58488893e-08 7.75793929e-10 3.00382830e-09 1.08164588e-08
  3.97079241e-08 1.40937129e-08]
 [1.31422855e-08 4.34345333e-03 7.98606932e-01 1.97049186e-01
  1.13727403e-07 4.44989512e-09 1.08252616e-08 5.11227540e-08
  1.44310519e-07 2.13945821e-08]
 [2.65093347e-09 1.07097151e-02 7.66907871e-01 2.22382367e-01
  1.86617246e-08 4.74056683e-10 1.96320959e-09 1.66157577e-08
  2.62167514e-08 3.05598147e-09]
 [4.62088683e-05 5.00240624e-01 1.83198199e-01 

28it [00:06,  5.63it/s]

tf.Tensor(
[[7.59482061e-11 1.32463942e-03 8.52926493e-01 1.45748839e-01
  1.62793146e-09 2.48761463e-11 1.13362104e-10 6.07119632e-10
  1.65815661e-09 2.48496002e-10]
 [3.57142517e-06 1.47337243e-02 1.57902792e-01 8.27344716e-01
  9.86917257e-07 4.73194945e-07 1.56363839e-07 2.86010891e-06
  5.17926946e-06 5.61491242e-06]
 [1.94193533e-04 9.59822908e-02 4.14277107e-01 4.88088250e-01
  2.15589098e-04 1.18371136e-04 1.48118095e-04 2.77324289e-04
  3.63235798e-04 3.35504359e-04]
 [9.13434706e-05 5.93841374e-02 2.10857511e-01 7.29039609e-01
  5.66065901e-05 3.59757796e-05 2.77907002e-05 1.21592995e-04
  1.74546556e-04 2.10879341e-04]
 [1.85744182e-04 5.71729600e-01 1.67216465e-01 2.56982267e-01
  4.55490808e-04 2.96497834e-04 5.72506862e-04 1.01412844e-03
  5.70475007e-04 9.76876006e-04]
 [8.86446742e-06 4.04259041e-02 1.50409430e-01 8.09067667e-01
  6.08131359e-06 2.47103389e-06 2.32855837e-06 1.30637300e-05
  2.63274942e-05 3.79049816e-05]
 [9.17819563e-08 6.26503304e-03 7.57548034e-01 

30it [00:06,  5.65it/s]

tf.Tensor(
[[7.96728017e-08 6.84041483e-03 2.10167393e-01 7.82991529e-01
  2.85222832e-08 1.22281287e-08 7.77467246e-09 9.35694899e-08
  1.47569239e-07 2.67721987e-07]
 [6.97287874e-07 5.73693998e-02 6.57360017e-01 2.85252541e-01
  6.03809031e-06 3.75344086e-07 2.51965616e-06 3.08098242e-06
  4.33811647e-06 1.01985904e-06]
 [1.25610945e-03 9.31376219e-02 3.07574511e-01 5.87768257e-01
  1.31750794e-03 1.10061129e-03 9.89587512e-04 1.96481543e-03
  2.50069750e-03 2.39035604e-03]
 [1.51597357e-09 4.15373454e-03 8.37285340e-01 1.58560812e-01
  2.44187390e-08 7.11104176e-10 3.29087912e-09 1.10307461e-08
  1.47803778e-08 2.98706127e-09]
 [1.31559609e-05 2.35719588e-02 1.83452636e-01 7.92863667e-01
  1.07112855e-05 4.29974352e-06 5.25457335e-06 2.25345029e-05
  3.05075337e-05 2.52741302e-05]
 [2.65042627e-10 3.07693519e-03 8.73378992e-01 1.23544097e-01
  7.57468577e-09 1.55353536e-10 5.76400871e-10 2.49837595e-09
  3.79021703e-09 8.57854066e-10]
 [5.18839229e-08 1.16981510e-02 7.56233633e-01 

32it [00:06,  5.73it/s]

tf.Tensor(
[[2.51412421e-05 8.80604014e-02 1.89976394e-01 7.21704245e-01
  2.43699578e-05 9.33209685e-06 1.26438235e-05 5.40435576e-05
  6.27954360e-05 7.05954662e-05]
 [4.08877153e-04 4.09494698e-01 2.73765236e-01 3.10264766e-01
  8.67975235e-04 5.04965661e-04 8.56389233e-04 1.43546355e-03
  1.18882617e-03 1.21282402e-03]
 [3.02219028e-09 9.88770902e-01 4.16069385e-03 7.06786569e-03
  1.76054407e-08 5.11875742e-09 1.79085582e-08 1.55034371e-07
  4.13150829e-08 1.72604842e-07]
 [2.08161000e-04 3.73487711e-01 2.22783223e-01 4.00680453e-01
  3.52774514e-04 2.67169584e-04 2.87156989e-04 7.28694606e-04
  4.92611551e-04 7.12155947e-04]
 [8.27460156e-07 8.81446302e-01 5.80826253e-02 6.04573898e-02
  5.87060526e-07 1.07409403e-06 5.42097894e-07 7.41773420e-06
  8.46966543e-07 2.33403989e-06]
 [2.14326801e-09 3.72446724e-03 8.38673174e-01 1.57602385e-01
  2.39783091e-08 7.70854935e-10 3.68653996e-09 1.35931320e-08
  1.95664942e-08 3.42877593e-09]
 [1.56512982e-07 1.10896870e-01 8.91410932e-02 

34it [00:07,  5.67it/s]

tf.Tensor(
[[1.98147200e-06 1.55390706e-02 2.98536152e-01 6.85907960e-01
  1.24715689e-06 6.02711339e-07 5.39596044e-07 2.51993197e-06
  3.81202221e-06 6.10587313e-06]
 [1.71613778e-04 6.88681662e-01 1.54245272e-01 1.56283438e-01
  5.17478766e-05 7.45418874e-05 2.31096037e-05 1.80103554e-04
  1.00322861e-04 1.88174978e-04]
 [3.82331962e-08 9.85689044e-01 4.22711857e-03 1.00827515e-02
  2.18022844e-08 4.56547049e-08 4.12357348e-09 8.05047648e-07
  4.78818691e-08 2.14346613e-07]
 [4.75157958e-06 1.60574224e-02 2.12482631e-01 7.71410942e-01
  4.89048443e-06 1.35346659e-06 2.06547156e-06 8.86441194e-06
  1.62251672e-05 1.08460617e-05]
 [1.35757611e-04 1.50579348e-01 4.31462556e-01 4.15858209e-01
  3.96983931e-04 1.46772873e-04 3.47378780e-04 4.64292563e-04
  3.97879194e-04 2.10906146e-04]
 [4.84379325e-06 2.05830019e-02 1.75734192e-01 8.03632677e-01
  2.09972927e-06 5.86605893e-07 8.81148424e-07 3.71410601e-06
  2.20963138e-05 1.59377960e-05]
 [9.95641574e-04 1.24219790e-01 2.93999434e-01 

36it [00:07,  5.58it/s]

tf.Tensor(
[[1.72594431e-04 4.77761269e-01 1.99498639e-01 3.19935024e-01
  3.49364942e-04 2.52368016e-04 3.04253306e-04 7.89285928e-04
  3.89731198e-04 5.47467906e-04]
 [1.89886240e-08 9.54143703e-01 1.93087179e-02 2.65467670e-02
  1.87308871e-08 3.04056655e-08 1.16609469e-08 6.42358771e-07
  3.10142383e-08 1.16184054e-07]
 [4.81812185e-06 7.84491003e-01 8.00146833e-02 1.35326400e-01
  1.33467456e-05 9.41287180e-06 1.08877093e-05 7.96580425e-05
  1.69542400e-05 3.27201669e-05]
 [7.21706783e-07 8.66852343e-01 4.40640077e-02 8.90417248e-02
  2.50027620e-06 1.79591609e-06 2.47666071e-06 1.86682610e-05
  4.58919158e-06 1.12024791e-05]
 [3.55875216e-07 1.33728189e-02 2.88878202e-01 6.97742820e-01
  5.93654192e-07 5.94942762e-08 3.41090981e-07 6.72142335e-07
  3.00928332e-06 1.12891303e-06]
 [1.51096398e-08 3.26114893e-03 6.55639619e-02 9.31174755e-01
  1.50828261e-09 3.20532212e-10 1.98479358e-10 1.10302070e-08
  5.33738387e-08 2.43298217e-08]
 [1.67630124e-03 9.82882828e-02 4.77885991e-01 

38it [00:08,  5.52it/s]

tf.Tensor(
[[5.19380983e-09 6.73656678e-03 8.01689863e-01 1.91573396e-01
  6.99209153e-08 2.36247888e-09 7.31292982e-09 2.30855015e-08
  5.26752935e-08 1.64104641e-08]
 [9.83043446e-06 3.10247064e-01 1.69834614e-01 5.19815326e-01
  6.47422621e-06 4.96158555e-06 2.01371358e-06 3.63467261e-05
  1.92120297e-05 2.41005473e-05]
 [2.63571621e-07 2.70346124e-02 2.06329063e-01 7.66633987e-01
  9.93893821e-08 3.73023106e-08 4.05626217e-08 3.02551143e-07
  6.66188157e-07 9.44025771e-07]
 [2.07940989e-08 1.40643725e-02 6.12967908e-01 3.72967452e-01
  7.55911813e-08 5.39531531e-09 9.52377466e-09 6.79090633e-08
  1.25800469e-07 4.78476636e-08]
 [2.48990386e-06 2.20793337e-02 6.07379496e-01 3.70508492e-01
  6.62024377e-06 1.43126397e-06 1.98271778e-06 4.99647831e-06
  9.25112363e-06 5.86226906e-06]
 [1.92014719e-04 7.38508627e-02 3.32550079e-01 5.92087805e-01
  1.69088205e-04 1.16367104e-04 1.19385295e-04 2.63586029e-04
  3.12587770e-04 3.38189129e-04]
 [4.33931997e-13 9.56867993e-01 1.58213973e-02 

40it [00:08,  5.65it/s]

tf.Tensor(
[[2.97215195e-08 1.01355510e-02 8.09710085e-01 1.80153310e-01
  4.17511529e-07 1.94087466e-08 8.80000783e-08 1.48494834e-07
  2.52856580e-07 8.08435630e-08]
 [1.62321214e-08 7.55788805e-03 3.78239453e-01 6.14202380e-01
  1.79294588e-08 3.22929883e-09 5.05995601e-09 2.58250683e-08
  7.26501170e-08 9.75591163e-08]
 [1.42805138e-05 8.45752597e-01 8.06360543e-02 7.34935552e-02
  4.90381763e-06 1.41359506e-05 1.48817583e-06 5.34401806e-05
  7.83643554e-06 2.17040724e-05]
 [3.06607399e-04 1.59032151e-01 4.34424937e-01 4.02081817e-01
  8.24264251e-04 2.84075417e-04 9.63942846e-04 7.70040206e-04
  7.00633624e-04 6.11600233e-04]
 [3.66340780e-09 7.36875925e-03 4.83757228e-01 5.08873999e-01
  8.75060913e-09 1.43141332e-09 3.48318885e-09 1.16611982e-08
  1.83065261e-08 2.05746904e-08]
 [4.91107894e-05 7.97901332e-01 8.43601301e-02 1.17029712e-01
  4.36729133e-05 4.91697901e-05 2.83234313e-05 2.25253883e-04
  8.79825020e-05 2.25377444e-04]
 [4.25163686e-04 2.60399669e-01 2.97002405e-01 

42it [00:08,  5.63it/s]

tf.Tensor(
[[3.75055213e-04 5.93554825e-02 4.46021140e-01 4.90049362e-01
  5.78498701e-04 5.55306324e-04 3.63697705e-04 9.39751742e-04
  1.05782482e-03 7.03911763e-04]
 [3.20120222e-07 8.51618312e-03 1.71207339e-01 8.20274234e-01
  1.02538323e-07 4.83022653e-08 4.30399929e-08 3.60749056e-07
  6.02031321e-07 7.72954991e-07]
 [1.64579367e-04 7.22451508e-02 2.86088228e-01 6.40391707e-01
  1.18992873e-04 9.84971120e-05 7.53311833e-05 2.36890133e-04
  2.38929395e-04 3.41687643e-04]
 [5.01202710e-04 2.40126923e-01 3.42066020e-01 4.10925329e-01
  1.08450139e-03 5.69137570e-04 1.09639391e-03 1.26950582e-03
  1.22557010e-03 1.13544473e-03]
 [1.00210795e-09 9.96626616e-01 8.68063129e-04 2.50529801e-03
  5.84965687e-11 3.13473081e-10 2.99759154e-12 1.17028360e-08
  1.01850972e-10 5.30559374e-10]
 [4.58353679e-05 5.12426972e-01 1.60754442e-01 3.26100677e-01
  6.18269914e-05 5.47063464e-05 3.83131774e-05 2.36056178e-04
  1.08147229e-04 1.73077075e-04]
 [7.99175552e-07 1.19647700e-02 2.23645091e-01 

44it [00:09,  5.55it/s]

tf.Tensor(
[[1.39225953e-09 3.90603719e-03 8.37223768e-01 1.58870116e-01
  2.04961701e-08 6.99985236e-10 1.83551963e-09 8.21255774e-09
  1.47985855e-08 4.36240510e-09]
 [4.05984792e-05 4.92048144e-01 1.85562178e-01 3.21591139e-01
  9.23511907e-05 7.28063897e-05 7.09052110e-05 2.34533785e-04
  1.15298601e-04 1.72123051e-04]
 [5.28535398e-04 1.48614198e-01 4.39337671e-01 4.05762613e-01
  8.51735065e-04 7.37748574e-04 8.23243347e-04 1.23511709e-03
  1.07857829e-03 1.03049679e-03]
 [1.05691607e-08 3.73746804e-03 1.55865580e-01 8.40396881e-01
  2.80602963e-09 1.02182962e-09 8.37436009e-10 1.22869181e-08
  2.59644626e-08 6.00183299e-08]
 [6.92391353e-08 6.84649590e-03 9.05176774e-02 9.02635038e-01
  2.66887685e-08 5.11005727e-09 3.99126332e-09 8.12277605e-08
  2.98757101e-07 3.30421614e-07]
 [9.50741560e-06 2.91693211e-01 4.11739469e-01 2.96373367e-01
  4.14655442e-05 6.59605939e-06 3.54773729e-05 4.30825712e-05
  3.82962535e-05 1.95171797e-05]
 [2.87776925e-06 2.34643649e-02 6.76163793e-01 

46it [00:09,  5.59it/s]

tf.Tensor(
[[3.37190900e-11 9.72303212e-01 1.22197745e-02 1.54770436e-02
  1.54564417e-10 2.96447838e-10 1.84446791e-10 1.15941843e-08
  1.54471574e-10 7.27277016e-10]
 [4.44154903e-05 5.45925140e-01 2.04764158e-01 2.48903200e-01
  3.27344096e-05 3.98934535e-05 1.66887694e-05 1.37721625e-04
  4.82186224e-05 8.78863284e-05]
 [1.62033047e-07 1.39003918e-02 7.03616142e-01 2.82479048e-01
  1.32447019e-06 1.18590464e-07 6.32500303e-07 7.43918918e-07
  1.00259251e-06 4.17025433e-07]
 [2.87427842e-08 9.60552633e-01 1.25833228e-02 2.68597510e-02
  1.83279383e-07 9.15219118e-08 2.08023678e-07 2.23810821e-06
  3.58353162e-07 1.07758819e-06]
 [7.57617329e-08 8.11665878e-03 1.96777508e-01 7.95104980e-01
  4.71579575e-08 1.55333151e-08 2.50023984e-08 1.54700345e-07
  2.19528687e-07 2.56327979e-07]
 [7.52792752e-04 2.03347296e-01 2.52693832e-01 5.37487745e-01
  6.12676609e-04 5.83887973e-04 8.40128632e-04 9.83716222e-04
  1.08860526e-03 1.60930003e-03]
 [1.00235236e-04 6.55148923e-01 1.78153932e-01 

47it [00:10,  4.42it/s]

tf.Tensor(
[[3.03948644e-09 2.53802235e-03 1.25613436e-01 8.71848583e-01
  6.53843868e-10 1.59831162e-10 9.79009443e-11 3.47438189e-09
  8.33664426e-09 1.08007931e-08]
 [1.38306777e-05 6.77867532e-01 1.10775754e-01 2.10885882e-01
  4.21693003e-05 3.21654152e-05 4.47522034e-05 1.38128889e-04
  6.49115755e-05 1.34901100e-04]
 [2.70238638e-06 1.03250034e-01 1.80497989e-01 7.16226578e-01
  1.66303209e-06 5.04798834e-07 5.93621394e-07 6.55303938e-06
  6.94949267e-06 6.33929631e-06]
 [5.73183112e-10 9.98393118e-01 2.20087953e-04 1.38672197e-03
  1.22888880e-10 7.40046968e-10 1.54109115e-12 5.80151500e-08
  9.97474464e-11 3.11803111e-10]
 [1.61936498e-06 8.97600889e-01 3.86975110e-02 6.36444539e-02
  3.23857125e-06 2.27324927e-06 2.32474008e-06 2.53008311e-05
  6.15299950e-06 1.62206616e-05]
 [2.01389981e-11 9.83998656e-01 5.77664049e-03 1.02246990e-02
  7.00217037e-11 1.00665233e-10 4.37768606e-11 1.04189963e-08
  8.56256455e-11 4.86128582e-10]
 [1.72933505e-05 2.40146983e-02 2.03264475e-01 




Running epoch on classifier  1


0it [00:00, ?it/s]



1it [00:05,  5.85s/it]



2it [00:11,  5.83s/it]



3it [00:17,  5.66s/it]



4it [00:23,  5.88s/it]



5it [00:28,  5.75s/it]



6it [00:34,  5.77s/it]



7it [00:40,  5.80s/it]



8it [00:45,  5.67s/it]



9it [00:51,  5.71s/it]



10it [00:57,  5.61s/it]



11it [01:02,  5.41s/it]



12it [01:07,  5.40s/it]



13it [01:14,  5.76s/it]



14it [01:19,  5.65s/it]



15it [01:24,  5.57s/it]



16it [01:30,  5.65s/it]



17it [01:35,  5.46s/it]



18it [01:41,  5.70s/it]



19it [01:46,  5.37s/it]



20it [01:51,  5.39s/it]



21it [01:56,  5.28s/it]



22it [02:02,  5.44s/it]



23it [02:08,  5.42s/it]



24it [02:13,  5.41s/it]



25it [02:18,  5.42s/it]



26it [02:23,  5.17s/it]



27it [02:29,  5.36s/it]



28it [02:33,  5.12s/it]



29it [02:39,  5.22s/it]



30it [02:45,  5.38s/it]



31it [02:49,  5.15s/it]



32it [02:55,  5.24s/it]



33it [03:01,  5.40s/it]



34it [03:06,  5.30s/it]



35it [03:11,  5.33s/it]



36it [03:17,  5.60s/it]



37it [03:23,  5.65s/it]



38it [03:29,  5.71s/it]



39it [03:34,  5.61s/it]



40it [03:40,  5.56s/it]



41it [03:44,  5.28s/it]



42it [03:50,  5.31s/it]



43it [03:55,  5.35s/it]



44it [04:00,  5.36s/it]



45it [04:06,  5.39s/it]



46it [04:12,  5.50s/it]



47it [04:17,  5.48s/it]



48it [04:24,  5.82s/it]



49it [04:29,  5.70s/it]



50it [04:35,  5.72s/it]



51it [04:41,  5.76s/it]



52it [04:46,  5.65s/it]



53it [04:52,  5.59s/it]



54it [04:57,  5.41s/it]



55it [05:01,  5.17s/it]



56it [05:07,  5.36s/it]



57it [05:12,  5.26s/it]



58it [05:17,  5.08s/it]



59it [05:20,  5.43s/it]



Evaluate classifier on Task:  0  Epoch:  1


2it [00:00,  5.21it/s]

tf.Tensor(
[[9.82731119e-16 9.79930997e-01 5.04198484e-03 1.50270257e-02
  5.86804415e-16 1.51150097e-15 5.74855255e-17 8.38149081e-12
  2.05781759e-14 4.04287504e-14]
 [9.64183528e-07 8.51576682e-03 2.04379186e-01 7.87099183e-01
  2.46748726e-07 1.18734157e-07 9.45600860e-08 8.28206225e-07
  1.92367679e-06 1.59624028e-06]
 [7.52442375e-08 8.45462739e-01 7.19091818e-02 8.26227367e-02
  1.57610202e-07 1.40121529e-07 9.82832873e-08 3.82836197e-06
  3.19380064e-07 6.57532610e-07]
 [1.09329019e-04 3.94745953e-02 3.22281420e-01 6.36630833e-01
  1.79747891e-04 1.41550554e-04 8.91502641e-05 2.68890522e-04
  3.27639922e-04 4.96930210e-04]
 [4.98569261e-06 3.65411520e-01 2.13291958e-01 4.21206832e-01
  5.37224059e-06 5.35624349e-06 3.23124050e-06 2.95516093e-05
  1.67688668e-05 2.43727900e-05]
 [4.94334012e-13 9.90645289e-01 3.51569778e-03 5.83905168e-03
  3.85254383e-13 4.11616758e-13 1.61695201e-13 6.74268419e-10
  2.81734098e-12 4.98269394e-12]
 [8.34415914e-05 2.76611634e-02 2.68575191e-01 

4it [00:00,  5.53it/s]

tf.Tensor(
[[3.45666109e-07 3.83171509e-03 6.22675776e-01 3.73490065e-01
  2.67699676e-07 6.56745485e-08 4.06861105e-08 5.85310147e-07
  8.89097578e-07 2.49023117e-07]
 [2.49936071e-04 9.02827010e-02 3.95327657e-01 5.10526001e-01
  4.64782875e-04 3.82884231e-04 5.16109576e-04 7.32454646e-04
  7.07602361e-04 8.09868216e-04]
 [2.81586399e-08 1.14880851e-03 8.43826234e-01 1.55024767e-01
  3.51642448e-08 3.41192874e-09 1.10772058e-09 8.78678321e-08
  5.66554981e-08 9.00290953e-09]
 [1.75075991e-06 9.58404783e-03 1.47952899e-01 8.42454016e-01
  2.87065745e-07 9.41291347e-08 8.05843570e-08 9.74864861e-07
  3.16764203e-06 2.78629705e-06]
 [7.63781749e-09 9.00246918e-01 3.79406884e-02 6.18111938e-02
  1.23289938e-08 1.39293919e-08 5.00962427e-09 9.30793078e-07
  5.31866107e-08 1.17301759e-07]
 [5.82373445e-07 1.96876843e-02 1.27343208e-01 8.52965534e-01
  1.01890656e-07 4.95126002e-08 4.57831071e-08 3.41248438e-07
  1.07635265e-06 1.45894217e-06]
 [1.70726125e-05 4.30418588e-02 2.03011066e-01 

6it [00:01,  5.50it/s]

tf.Tensor(
[[3.53549840e-04 4.58969213e-02 4.74539608e-01 4.74563569e-01
  6.79943943e-04 5.67753217e-04 6.78077107e-04 1.06648426e-03
  9.96699673e-04 6.57422934e-04]
 [1.71727754e-11 9.82584059e-01 8.99718981e-03 8.41880497e-03
  1.30634698e-10 7.90793472e-11 6.20665394e-11 2.52803307e-08
  4.74916106e-10 8.94539665e-10]
 [3.30659944e-19 9.98512805e-01 6.36183482e-04 8.51053861e-04
  2.37982096e-18 2.24701463e-18 9.82389784e-19 9.00955810e-14
  9.58140840e-18 5.03483663e-17]
 [3.34106608e-06 3.13374221e-01 2.06975251e-01 4.79606211e-01
  2.01617513e-06 1.80735128e-06 1.30359513e-06 1.20117511e-05
  9.68461882e-06 1.41402143e-05]
 [3.01118871e-06 1.69861130e-02 1.22004472e-01 8.60993326e-01
  5.79206699e-07 2.61056442e-07 2.33585865e-07 1.82422161e-06
  4.71001977e-06 5.54530516e-06]
 [1.67107984e-07 1.87320902e-03 7.80316710e-01 2.17807695e-01
  4.15070446e-07 5.34402176e-08 4.69341543e-08 8.58986141e-07
  7.87452734e-07 6.69761420e-08]
 [2.23138952e-09 3.57548287e-03 4.41091955e-02 

8it [00:01,  5.51it/s]

tf.Tensor(
[[5.93609295e-09 8.73087943e-01 5.12373522e-02 7.56740794e-02
  9.36790112e-09 1.27727890e-08 3.29689565e-09 5.39259531e-07
  2.95670635e-08 3.32098118e-08]
 [1.03015327e-06 3.99231026e-03 6.66213393e-01 3.29781145e-01
  2.14894112e-06 6.94667449e-07 6.33292586e-07 3.80873144e-06
  3.75744298e-06 1.07422613e-06]
 [9.47079202e-19 9.99706209e-01 1.19952383e-04 1.73849592e-04
  2.04100468e-18 1.88083542e-18 1.53101112e-18 3.58997039e-13
  2.47671193e-17 6.25578044e-18]
 [1.69771488e-06 6.88637141e-03 4.06747580e-01 5.86351693e-01
  1.12613156e-06 6.81523829e-07 5.04093691e-07 2.44206240e-06
  4.60214005e-06 3.26742611e-06]
 [6.38830534e-05 1.60501301e-01 4.32742774e-01 4.05911446e-01
  6.64745094e-05 9.27589936e-05 2.97423412e-05 3.47849767e-04
  1.21298588e-04 1.22490354e-04]
 [1.99498504e-06 5.39474981e-03 5.47281146e-01 4.47308302e-01
  2.03851187e-06 5.98900272e-07 5.04970501e-07 3.27359635e-06
  5.53680229e-06 1.91983963e-06]
 [3.01205382e-05 1.80333294e-02 4.24345016e-01 

10it [00:01,  5.42it/s]

tf.Tensor(
[[2.95700005e-07 6.99944329e-03 1.29114851e-01 8.63884151e-01
  2.72813949e-08 1.08573168e-08 6.61838673e-09 1.22371716e-07
  5.16006651e-07 5.42824694e-07]
 [2.52120352e-12 9.88427699e-01 5.97663829e-03 5.59569709e-03
  8.36986730e-12 2.38148867e-12 4.71296006e-12 7.24579008e-10
  7.52333890e-11 3.38611139e-10]
 [1.54974643e-07 5.47990389e-03 1.82710603e-01 8.11808646e-01
  3.12015800e-08 5.13800247e-09 7.84338461e-09 8.10234084e-08
  3.92451113e-07 1.81363163e-07]
 [2.28939824e-16 9.97950375e-01 7.66894780e-04 1.28274271e-03
  1.50433740e-15 1.84581227e-15 1.65571105e-16 1.97021600e-11
  9.82471483e-15 2.17850742e-14]
 [1.75110719e-04 6.81850389e-02 5.50195992e-01 3.78956109e-01
  4.05565661e-04 3.32743861e-04 3.25794477e-04 5.87479852e-04
  5.18378511e-04 3.17722530e-04]
 [1.38927136e-09 9.49737310e-01 2.32910551e-02 2.69712750e-02
  4.02973210e-09 3.41094730e-09 1.78586501e-09 3.00625743e-07
  1.70197065e-08 3.46773135e-08]
 [4.95386848e-05 2.69371808e-01 3.09329122e-01 

12it [00:02,  5.52it/s]

tf.Tensor(
[[9.30075888e-15 9.98739302e-01 8.60675995e-04 4.00082616e-04
  2.99597214e-14 1.39294297e-14 2.22530107e-14 4.21865711e-12
  1.93515438e-13 5.80386867e-12]
 [1.04341552e-05 1.81186795e-02 2.54722625e-01 7.27067173e-01
  6.58461977e-06 6.31498597e-06 4.17396268e-06 1.50502519e-05
  2.45866886e-05 2.43591548e-05]
 [9.93870373e-08 2.68815714e-03 7.58801043e-01 2.38509387e-01
  2.81812760e-07 2.41438212e-08 4.58199807e-08 4.44561181e-07
  4.30256478e-07 6.14228810e-08]
 [8.97820840e-08 1.41103403e-03 7.98139274e-01 2.00448588e-01
  1.48872942e-07 3.59063463e-08 1.59734785e-08 5.25531732e-07
  2.57124441e-07 4.24894040e-08]
 [3.14440484e-07 5.33572724e-03 5.63537002e-01 4.31125283e-01
  1.90979165e-07 3.43037563e-08 3.96541004e-08 3.94437336e-07
  7.96062352e-07 1.86570801e-07]
 [1.36377037e-06 1.04058757e-02 7.59972110e-02 9.13587034e-01
  1.49951163e-07 5.37633937e-08 3.67375108e-08 7.17652370e-07
  2.75919865e-06 4.75685829e-06]
 [6.11398791e-05 4.84418385e-02 5.80471218e-01 

14it [00:02,  5.60it/s]

tf.Tensor(
[[1.85029023e-06 1.35799935e-02 1.31218910e-01 8.55190814e-01
  3.25687409e-07 1.61927773e-07 1.32710625e-07 1.07391963e-06
  2.99618864e-06 3.58656280e-06]
 [4.82515425e-05 1.47043914e-02 3.85292530e-01 5.99735022e-01
  2.45891188e-05 1.61388780e-05 1.17830896e-05 3.92341026e-05
  6.96169955e-05 5.84216323e-05]
 [7.59457444e-28 9.99990106e-01 4.14966053e-06 5.78774916e-06
  6.92373125e-27 4.08658761e-27 6.87999617e-28 1.69949982e-19
  1.27868313e-25 4.34309097e-25]
 [3.25611813e-06 1.17661711e-02 1.40436590e-01 8.47776175e-01
  7.14101191e-07 3.40642742e-07 2.26036605e-07 2.17006027e-06
  5.98943734e-06 8.35308128e-06]
 [1.32815767e-05 5.26030481e-01 1.46283165e-01 3.27565968e-01
  7.81608287e-06 5.10455584e-06 3.93885966e-06 4.08538363e-05
  2.66981733e-05 2.26313659e-05]
 [4.35487779e-09 7.80132832e-04 8.49439621e-01 1.49780229e-01
  1.16263488e-08 2.37058317e-10 3.47371187e-10 1.65291016e-08
  2.75024217e-08 7.94935895e-10]
 [6.22891503e-06 1.59012023e-02 2.02853188e-01 

16it [00:02,  5.69it/s]

tf.Tensor(
[[4.41657484e-08 9.35404241e-01 3.51910181e-02 2.94026360e-02
  8.32105655e-08 5.43418963e-08 4.85834413e-08 1.24810435e-06
  2.40199569e-07 4.49175332e-07]
 [2.50584017e-05 5.81378527e-02 2.05534860e-01 7.36096025e-01
  1.75449786e-05 1.03961065e-05 1.42239478e-05 3.73529765e-05
  5.35311883e-05 7.31685286e-05]
 [1.14349241e-03 1.03836022e-01 3.26158881e-01 5.59881628e-01
  1.05863065e-03 1.05382642e-03 1.04682427e-03 1.67082727e-03
  1.89464376e-03 2.25515105e-03]
 [1.08986205e-05 1.69431102e-02 4.38900441e-01 5.44057190e-01
  9.23382504e-06 7.43269584e-06 6.02681439e-06 2.38922039e-05
  2.69012580e-05 1.49399857e-05]
 [5.03049691e-09 8.69199540e-03 8.00270885e-02 9.11280930e-01
  2.35671122e-10 7.21704363e-11 5.91596078e-11 1.86977567e-09
  1.13995400e-08 1.44181218e-08]
 [1.11700629e-05 2.25841943e-02 1.47087783e-01 8.30225348e-01
  5.46137835e-06 2.46035461e-06 2.77715230e-06 1.25262095e-05
  2.73778205e-05 4.08493179e-05]
 [7.36808444e-12 9.80430126e-01 8.75675958e-03 

17it [00:03,  5.63it/s]

tf.Tensor(
[[9.34249713e-18 9.93761361e-01 2.82393536e-03 3.41469492e-03
  7.75869077e-17 1.31301846e-16 3.96455128e-17 8.80401600e-13
  2.18874915e-16 1.13188896e-15]
 [5.95715346e-06 1.05964923e-02 6.76828980e-01 3.12473983e-01
  9.77263244e-06 1.38862697e-05 4.13063617e-06 4.14208589e-05
  1.49035905e-05 1.04060109e-05]
 [4.82175274e-05 2.98030376e-02 2.53464282e-01 7.16338754e-01
  3.02222525e-05 2.37057993e-05 2.66523784e-05 6.44370739e-05
  9.46101063e-05 1.06148633e-04]
 [4.19708849e-05 2.40872018e-02 3.14070731e-01 6.61520123e-01
  3.31408410e-05 1.94121294e-05 2.87394905e-05 5.85216403e-05
  8.79905710e-05 5.21893999e-05]
 [9.60303481e-08 4.27416433e-03 7.79844373e-02 9.17741001e-01
  3.38043260e-09 1.50921831e-09 5.89804483e-10 2.37479707e-08
  1.33114668e-07 1.76761475e-07]
 [1.89186197e-19 9.99353468e-01 2.62332585e-04 3.84250452e-04
  8.90107654e-19 5.59713829e-19 2.78141678e-19 9.27557727e-14
  6.82440921e-18 2.03906407e-17]
 [7.51635469e-07 1.19422600e-01 2.16999382e-01 

19it [00:03,  5.52it/s]

tf.Tensor(
[[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]

21it [00:03,  5.50it/s]

tf.Tensor(
[[1.16533147e-06 7.19291763e-03 9.59711820e-02 8.96831334e-01
  1.14852526e-07 5.32189901e-08 3.31927410e-08 4.46930954e-07
  1.41333089e-06 1.25634028e-06]
 [2.40359128e-08 6.34418707e-03 7.60621056e-02 9.17593539e-01
  1.81544313e-09 4.51062493e-10 4.77350159e-10 8.61167582e-09
  6.05683255e-08 6.61508182e-08]
 [1.72116434e-05 1.53072774e-02 3.84174883e-01 6.00370288e-01
  1.43713378e-05 1.18981334e-05 6.13122620e-06 2.79906180e-05
  3.92908660e-05 3.05775866e-05]
 [1.28605998e-05 3.72963786e-01 2.69279420e-01 3.57513875e-01
  2.15898071e-05 1.18545668e-05 2.56418316e-05 6.35980396e-05
  4.08448723e-05 6.65092157e-05]
 [1.35165203e-08 9.36047971e-01 2.52859835e-02 3.86651307e-02
  1.00811306e-08 1.01891526e-08 4.63653027e-09 7.72840281e-07
  4.62588012e-08 5.84813939e-08]
 [3.30899985e-15 9.98187959e-01 9.91051551e-04 8.20984365e-04
  3.49561519e-14 5.69304142e-15 2.72685740e-14 1.71725984e-11
  5.08157373e-13 1.46123760e-12]
 [1.96054374e-04 2.01012701e-01 3.76833707e-01 

23it [00:04,  5.49it/s]

tf.Tensor(
[[1.49893680e-07 3.68861221e-02 8.28034505e-02 8.80309105e-01
  1.27892310e-08 4.43742643e-09 4.00797084e-09 8.21653572e-08
  4.41978557e-07 5.96480106e-07]
 [5.86957850e-14 9.91486907e-01 3.73829226e-03 4.77475720e-03
  2.05016166e-13 1.63474323e-13 8.03534959e-14 3.59871410e-10
  6.69322995e-13 1.03513053e-12]
 [1.00872938e-07 8.91610742e-01 6.95456043e-02 3.88251208e-02
  4.52852055e-06 7.32308308e-07 2.82085739e-06 3.58251850e-06
  1.59633862e-06 5.16463524e-06]
 [1.22597314e-07 7.63068557e-01 9.24672857e-02 1.44458368e-01
  1.73648729e-07 1.75762494e-07 2.21734808e-07 4.24913105e-06
  5.63959418e-07 3.14048805e-07]
 [7.81240033e-06 2.06134524e-02 2.18409017e-01 7.60930121e-01
  3.25253359e-06 1.60558363e-06 2.15919749e-06 6.18008744e-06
  1.58113253e-05 1.05437484e-05]
 [2.89484306e-05 1.50290295e-01 2.89915472e-01 5.59376657e-01
  3.25322944e-05 2.74206341e-05 3.87783803e-05 8.11309219e-05
  7.99674453e-05 1.28719039e-04]
 [9.60145782e-11 9.36827183e-01 2.22861748e-02 

25it [00:04,  5.55it/s]

tf.Tensor(
[[1.45633367e-05 1.85963623e-02 3.31757098e-01 6.49544060e-01
  9.16054796e-06 5.03651700e-06 5.98734687e-06 1.54439804e-05
  3.20156541e-05 2.02427691e-05]
 [1.69652117e-06 9.83365998e-03 1.90134972e-01 8.00022304e-01
  3.57209700e-07 1.83888432e-07 1.40191702e-07 1.18508910e-06
  2.79234632e-06 2.77025947e-06]
 [7.48095772e-08 5.07406890e-03 6.65618479e-02 9.28363681e-01
  2.95041591e-09 1.01305320e-09 4.77000661e-10 2.31241923e-08
  1.28299220e-07 1.62589600e-07]
 [2.51634114e-09 6.06539019e-04 8.67312491e-01 1.32081091e-01
  8.26382163e-09 2.21757654e-10 2.60493710e-10 1.50752282e-08
  1.21825190e-08 5.33720068e-10]
 [1.40067036e-19 9.99569714e-01 1.62034616e-04 2.68294272e-04
  5.08616712e-19 3.97659919e-19 1.10308696e-19 1.00320536e-13
  6.17495993e-18 1.59184391e-17]
 [7.95690608e-13 9.89084899e-01 5.19564422e-03 5.71947778e-03
  4.95016051e-12 2.89739933e-12 1.83139354e-12 2.80995849e-09
  1.99842972e-11 1.59432675e-11]
 [1.47618877e-07 8.43083918e-01 7.61877224e-02 

27it [00:04,  5.53it/s]

tf.Tensor(
[[9.12035532e-08 4.37589316e-03 6.43576607e-02 9.31266129e-01
  2.61662536e-09 1.17005772e-09 3.85490945e-10 2.31643025e-08
  1.19902651e-07 1.55885502e-07]
 [1.98533922e-16 9.97248828e-01 1.08393759e-03 1.66718266e-03
  1.19937451e-15 1.09689846e-15 4.68191638e-16 1.23486030e-11
  7.57840124e-15 3.60115820e-14]
 [1.88118180e-18 9.99455512e-01 2.24544026e-04 3.19970335e-04
  6.50451696e-18 3.26679887e-18 1.73942393e-18 4.75465424e-13
  5.93971862e-17 7.86668813e-17]
 [6.62845559e-04 1.30121976e-01 4.74447966e-01 3.87147933e-01
  1.22795778e-03 9.65330866e-04 1.22002955e-03 1.80510839e-03
  1.38822896e-03 1.01266091e-03]
 [2.18781682e-09 5.37288084e-04 8.81571054e-01 1.17891602e-01
  6.46389697e-09 1.59110974e-10 1.17528265e-10 1.01889412e-08
  1.05154614e-08 4.76526762e-10]
 [3.49548820e-04 5.53184450e-02 2.98427284e-01 6.43394053e-01
  2.89165706e-04 2.49580160e-04 2.65175506e-04 4.53153305e-04
  6.02653250e-04 6.50889648e-04]
 [1.43000536e-04 1.44660503e-01 3.44585091e-01 

29it [00:05,  5.33it/s]

tf.Tensor(
[[1.16109091e-14 9.97720182e-01 1.00796961e-03 1.27187790e-03
  4.92684710e-14 2.42508684e-14 1.34368157e-14 1.95728031e-10
  2.94523466e-13 3.12461970e-13]
 [1.52292676e-04 3.78528014e-02 4.92921948e-01 4.67554599e-01
  2.16002707e-04 2.07003264e-04 1.28959509e-04 3.33726028e-04
  3.58594814e-04 2.74173974e-04]
 [1.05972940e-05 4.60839644e-02 1.92750901e-01 7.61093199e-01
  4.48779747e-06 2.81820326e-06 3.43335751e-06 1.04588999e-05
  1.83508819e-05 2.17665820e-05]
 [1.05779058e-07 4.81838267e-03 3.98305953e-01 5.96875191e-01
  2.82909252e-08 3.69931574e-09 4.64830796e-09 5.47750787e-08
  2.77744533e-07 1.12766507e-07]
 [2.67249834e-06 1.43193081e-02 1.44119278e-01 8.41543138e-01
  6.43839314e-07 3.16490315e-07 2.73509613e-07 1.97462123e-06
  5.07780169e-06 7.35946696e-06]
 [4.42051241e-04 3.97898592e-02 3.61716002e-01 5.93167603e-01
  5.66196628e-04 4.92118183e-04 5.86931885e-04 9.72642563e-04
  1.08652434e-03 1.18006137e-03]
 [5.16572584e-07 3.01497569e-03 7.36019909e-01 

31it [00:05,  5.36it/s]

tf.Tensor(
[[1.92936483e-07 1.25908107e-02 1.20529167e-01 8.66878569e-01
  2.83464257e-08 1.40503618e-08 1.05151283e-08 1.30175124e-07
  4.50073458e-07 6.29845772e-07]
 [3.46066045e-05 1.33175785e-02 5.47278643e-01 4.39009279e-01
  5.58335123e-05 2.81215853e-05 3.68870205e-05 8.33490049e-05
  9.97733878e-05 5.59306136e-05]
 [2.92702430e-06 8.89657438e-03 4.58199978e-01 5.32886565e-01
  1.26814314e-06 2.53032340e-07 2.59350514e-07 2.27703936e-06
  6.57888631e-06 3.33249477e-06]
 [2.11310813e-09 8.70486557e-01 5.39026931e-02 7.56104439e-02
  5.58092861e-09 4.77879603e-09 2.11929585e-09 3.13056063e-07
  1.73482473e-08 2.33863293e-08]
 [9.36290975e-14 9.95550752e-01 1.99759845e-03 2.45168689e-03
  4.09111899e-13 3.05499116e-13 1.24288478e-13 7.95496391e-10
  2.27793804e-12 4.88370889e-12]
 [4.39976575e-04 5.21111824e-02 2.70028532e-01 6.74683869e-01
  3.56153207e-04 2.85903254e-04 2.79451750e-04 5.28135803e-04
  6.66240521e-04 6.20644889e-04]
 [7.26769400e-08 1.49710791e-03 8.02386880e-01 

32it [00:05,  4.85it/s]

tf.Tensor(
[[2.52025438e-06 1.11383805e-02 2.71644264e-01 7.17198670e-01
  9.18805426e-07 2.29790970e-07 4.58659144e-07 1.69009013e-06
  6.42322402e-06 6.42872101e-06]
 [1.55766047e-05 4.44576621e-01 2.22518668e-01 3.32598418e-01
  2.24246214e-05 1.88315953e-05 1.67335365e-05 9.81985941e-05
  5.55995903e-05 7.88198740e-05]
 [1.04452920e-06 1.35097848e-02 1.09781519e-01 8.76703620e-01
  1.52040670e-07 6.07965660e-08 5.83840283e-08 5.80198275e-07
  1.66264238e-06 1.54673364e-06]
 [9.08553993e-05 3.47203612e-02 2.21408322e-01 7.43157148e-01
  6.59076031e-05 4.34416615e-05 5.17538247e-05 1.22710408e-04
  1.69673222e-04 1.69818901e-04]
 [1.84321252e-05 2.30829105e-01 3.85049105e-01 3.83899093e-01
  2.08887868e-05 1.12007365e-05 1.61840999e-05 7.34973219e-05
  4.16053117e-05 4.09436834e-05]
 [3.95595816e-06 6.81963749e-03 6.55505836e-01 3.37630540e-01
  5.95653728e-06 2.56856811e-06 2.66796587e-06 1.44220076e-05
  1.05388353e-05 3.75346099e-06]
 [3.13598393e-05 1.87885035e-02 5.15069842e-01 

33it [00:06,  4.59it/s]

tf.Tensor(
[[2.03408581e-05 4.44819450e-01 3.30252379e-01 2.24342942e-01
  9.18496662e-05 5.39394860e-05 6.34663593e-05 1.84361517e-04
  7.05614657e-05 1.00625970e-04]
 [8.37912921e-06 7.38348246e-01 1.18960768e-01 1.42567173e-01
  6.69607653e-06 6.92135472e-06 5.66671588e-06 6.74165058e-05
  1.46481334e-05 1.41770543e-05]
 [2.64454338e-05 2.88928956e-01 3.73046070e-01 3.37655216e-01
  4.35686852e-05 2.89713989e-05 2.94326856e-05 1.34021044e-04
  6.91824171e-05 3.81685386e-05]
 [4.77497460e-15 9.96082127e-01 1.50394079e-03 2.41393014e-03
  1.89385945e-14 1.66618784e-14 3.51071906e-15 1.02290856e-10
  1.47500190e-13 1.88198117e-13]
 [3.86865831e-05 1.77848432e-02 4.67297196e-01 5.14577329e-01
  4.07066473e-05 1.39891217e-05 2.43031081e-05 5.39989778e-05
  1.08081673e-04 6.08765804e-05]
 [4.01060191e-07 6.04147464e-03 8.61846581e-02 9.07770872e-01
  3.53926772e-08 1.29926043e-08 7.25704874e-09 1.58454185e-07
  8.71225552e-07 1.47551305e-06]
 [1.05464834e-12 9.71801043e-01 1.07337935e-02 

34it [00:06,  4.41it/s]

tf.Tensor(
[[1.24371418e-05 1.67933386e-02 6.43101573e-01 3.39954287e-01
  2.93918802e-05 5.63523827e-06 1.59192441e-05 3.26776608e-05
  4.19618664e-05 1.27576777e-05]
 [8.12747821e-05 2.50163406e-01 2.28344426e-01 5.20613730e-01
  6.71810922e-05 9.08819566e-05 4.66996462e-05 1.87522266e-04
  1.85995901e-04 2.18914953e-04]
 [3.22886490e-06 6.98591163e-03 4.78558451e-01 5.14419794e-01
  3.76140611e-06 2.61749142e-06 1.97652321e-06 8.35260835e-06
  1.02825879e-05 5.67202460e-06]
 [1.15290723e-05 1.05926096e-02 1.86765641e-01 8.02553117e-01
  4.68507142e-06 2.24409632e-06 1.56788872e-06 9.85799943e-06
  2.69789562e-05 3.16616351e-05]
 [7.49779999e-17 9.98788059e-01 4.46009624e-04 7.65953155e-04
  4.62005598e-16 4.95557526e-16 9.63772815e-17 8.80249693e-12
  5.52975973e-15 2.00142603e-14]
 [1.05546633e-06 1.01960236e-02 1.59575403e-01 8.30222428e-01
  1.92548256e-07 7.92548605e-08 6.06299650e-08 6.21835568e-07
  1.98248290e-06 2.05425545e-06]
 [1.00153661e-03 6.79209009e-02 4.30285662e-01 

35it [00:06,  4.38it/s]

tf.Tensor(
[[2.23649366e-09 9.32665706e-01 2.86529157e-02 3.86809185e-02
  4.79019402e-09 4.12559498e-09 2.01016759e-09 4.37181939e-07
  1.58937361e-08 1.90950864e-08]
 [1.18938722e-11 9.82572675e-01 7.63086090e-03 9.79651976e-03
  4.87235460e-11 3.93574860e-11 1.72632828e-11 1.87890361e-08
  2.72364215e-10 4.12033796e-10]
 [3.35466038e-20 9.99621630e-01 1.41924276e-04 2.36450593e-04
  1.71183890e-19 1.48268855e-19 3.40509772e-20 3.95631268e-14
  1.91366229e-18 6.63669463e-18]
 [2.06210698e-05 2.39586812e-02 2.26989999e-01 7.48775959e-01
  2.27376258e-05 1.26208524e-05 1.97118843e-05 5.08352678e-05
  6.95164417e-05 7.94346488e-05]
 [4.87394125e-10 9.40296710e-01 2.45625395e-02 3.51405591e-02
  1.27522803e-09 1.16989152e-09 4.56102822e-10 1.78303637e-07
  5.03942310e-09 7.95786370e-09]
 [5.86246606e-04 1.14970274e-01 3.37583214e-01 5.41498840e-01
  5.89583069e-04 6.60867605e-04 7.43651530e-04 1.09533942e-03
  1.08246016e-03 1.18948729e-03]
 [2.13294861e-05 1.16186300e-02 4.28617209e-01 

36it [00:06,  4.44it/s]

tf.Tensor(
[[2.06088450e-16 9.98729408e-01 5.07839373e-04 7.62739626e-04
  1.06381536e-15 8.58885426e-16 2.14986690e-16 1.93270625e-11
  9.61194862e-15 1.83814417e-14]
 [5.01802377e-09 6.27184331e-01 9.74092856e-02 2.75406241e-01
  1.49092383e-09 2.21349739e-09 4.02567674e-10 9.78972636e-08
  2.48931098e-08 2.31987283e-08]
 [1.97632100e-24 9.99982476e-01 6.37019275e-06 1.11337931e-05
  2.35211217e-23 1.69796693e-23 2.31124097e-24 1.14828379e-16
  5.86538202e-22 2.43998772e-21]
 [7.96588608e-08 1.32143805e-02 9.73361954e-02 8.89448702e-01
  9.12604925e-09 3.76677978e-09 2.29744934e-09 4.28790656e-08
  1.85282005e-07 3.15870977e-07]
 [3.21978433e-07 2.24209856e-03 7.60497808e-01 2.37256423e-01
  5.77477863e-07 1.30040320e-07 8.15549868e-08 1.39396457e-06
  1.00385887e-06 2.00892799e-07]
 [2.82876368e-04 4.52368520e-02 2.85604775e-01 6.66532934e-01
  2.79000087e-04 2.78819643e-04 3.01384192e-04 4.82994597e-04
  5.45913819e-04 4.54474590e-04]
 [3.09997995e-04 5.48398383e-02 3.30240577e-01 

37it [00:07,  4.36it/s]

tf.Tensor(
[[1.64103938e-07 2.64688372e-03 7.08853781e-01 2.88497895e-01
  2.02598216e-07 3.27792478e-08 2.58014072e-08 5.11812686e-07
  5.45858313e-07 8.76601689e-08]
 [7.48501567e-04 6.68224841e-02 2.73566842e-01 6.53435290e-01
  6.79913734e-04 5.80904132e-04 5.64156799e-04 1.01147313e-03
  1.21192401e-03 1.37859117e-03]
 [2.54642782e-06 2.08415717e-01 1.74197882e-01 6.17362142e-01
  7.93909180e-07 6.60058276e-07 3.15992907e-07 4.68422286e-06
  6.08084429e-06 9.26082339e-06]
 [4.50607757e-07 2.93132686e-03 6.13695741e-01 3.83366942e-01
  7.73831061e-07 3.07650225e-07 2.06288533e-07 2.01574358e-06
  1.87234002e-06 3.69827632e-07]
 [7.19988373e-14 9.64877009e-01 1.31942928e-02 2.19287183e-02
  1.19566100e-13 1.96174335e-13 5.31311648e-14 2.02753495e-10
  6.47888861e-13 1.88003888e-12]
 [3.34476181e-06 1.29301734e-02 1.54170126e-01 8.32885265e-01
  4.37980134e-07 2.50537909e-07 1.90150587e-07 1.49109280e-06
  4.17442607e-06 4.53754956e-06]
 [7.06297271e-23 9.99948382e-01 1.79207673e-05 

38it [00:07,  4.19it/s]

tf.Tensor(
[[3.89565414e-13 9.88152802e-01 4.65677865e-03 7.19044823e-03
  1.02216391e-12 1.00778619e-12 3.12695779e-13 1.68418302e-09
  5.08805021e-12 9.81321795e-12]
 [1.26689708e-08 9.12853241e-01 3.88504528e-02 4.82947193e-02
  2.84284134e-08 2.22473791e-08 1.39090739e-08 1.38950884e-06
  8.75631727e-08 1.03482080e-07]
 [1.04426108e-04 2.07380101e-01 2.98571706e-01 4.92870569e-01
  1.16628682e-04 1.00089230e-04 1.39198892e-04 2.70540331e-04
  1.81338633e-04 2.65395152e-04]
 [2.65314092e-16 9.98268127e-01 7.43216195e-04 9.88640706e-04
  1.07821238e-15 5.95154783e-16 4.25750258e-16 9.07019859e-12
  7.24247983e-15 2.77742815e-14]
 [4.40530739e-20 9.99642968e-01 1.21588964e-04 2.35470216e-04
  2.33017399e-19 3.30230741e-19 1.07005473e-20 5.35242356e-14
  3.20724683e-18 4.69870517e-18]
 [2.93521690e-13 9.95394886e-01 2.16497830e-03 2.44007050e-03
  6.65750929e-13 4.39816288e-13 3.31439059e-13 2.54802540e-10
  5.85031979e-12 5.49708543e-11]
 [3.46371098e-06 5.00954362e-03 7.18952358e-01 

39it [00:07,  4.10it/s]

tf.Tensor(
[[1.05009254e-04 1.93387363e-02 5.86565256e-01 3.92724246e-01
  2.14899730e-04 1.04262239e-04 1.81919662e-04 2.75405793e-04
  3.11414013e-04 1.78817310e-04]
 [1.50435113e-07 3.64761823e-03 7.60363817e-01 2.35985026e-01
  6.79601442e-07 9.47749612e-08 1.54912314e-07 1.48917127e-06
  7.55982683e-07 1.13577990e-07]
 [9.99974986e-07 1.37568722e-02 1.38267994e-01 8.47968042e-01
  2.23716910e-07 1.18048249e-07 8.92650220e-08 7.22696143e-07
  2.08900997e-06 2.90346338e-06]
 [5.93666527e-07 5.83623303e-03 2.72967279e-01 7.21192896e-01
  2.05945497e-07 5.74494159e-08 7.07486834e-08 4.21631569e-07
  1.37491384e-06 8.22875847e-07]
 [8.97909467e-06 1.69909537e-01 2.08358005e-01 6.21645689e-01
  4.14528222e-06 3.70327280e-06 2.26719635e-06 1.51337026e-05
  2.01370640e-05 3.23099957e-05]
 [2.50524891e-07 2.73400545e-03 7.57830441e-01 2.39432931e-01
  5.28067346e-07 5.37206155e-08 6.27612096e-08 7.16346619e-07
  8.79458298e-07 1.40263282e-07]
 [1.69243211e-18 9.98889863e-01 4.09892295e-04 

40it [00:07,  3.95it/s]

tf.Tensor(
[[2.04093521e-05 2.48678438e-02 1.02120019e-01 8.72888863e-01
  4.08219694e-06 2.02987371e-06 1.52915175e-06 1.21088724e-05
  3.62757091e-05 4.68423386e-05]
 [1.09915090e-08 3.67584713e-02 1.06231652e-01 8.57009768e-01
  1.27728095e-09 5.01755582e-10 2.55497595e-10 1.06845111e-08
  3.78460818e-08 4.81223488e-08]
 [4.64293372e-13 9.92322385e-01 3.30499583e-03 4.37264191e-03
  2.44533972e-12 1.91979397e-12 6.85202871e-13 3.10563553e-09
  1.04181724e-11 1.84667819e-11]
 [3.09761487e-11 9.80643153e-01 8.08236282e-03 1.12744123e-02
  4.63671497e-11 4.87613318e-11 1.67011387e-11 1.52916861e-08
  2.56350691e-10 9.92277704e-10]
 [1.02403546e-05 9.79472604e-03 6.82346344e-01 3.07719678e-01
  1.69807954e-05 1.51846307e-05 3.64498555e-06 3.69266963e-05
  2.93088633e-05 2.69909597e-05]
 [9.51672155e-06 1.34722367e-02 2.36973077e-01 7.49491930e-01
  5.87203021e-06 2.87016269e-06 3.53269729e-06 9.84001963e-06
  2.12696104e-05 9.81899302e-06]
 [5.46578046e-19 9.99778569e-01 9.84473591e-05 

41it [00:08,  3.91it/s]

tf.Tensor(
[[2.33795930e-07 1.83925196e-03 8.16418111e-01 1.81740120e-01
  3.44284956e-07 1.21653983e-07 3.58111905e-08 1.03538002e-06
  5.29385090e-07 1.19803317e-07]
 [2.97384046e-04 7.04296157e-02 3.78748477e-01 5.47742248e-01
  3.63043131e-04 3.82432539e-04 2.20630856e-04 6.25763438e-04
  5.86149108e-04 6.04327652e-04]
 [7.24626880e-05 2.40464434e-02 4.32619572e-01 5.42824566e-01
  5.47362179e-05 2.75363564e-05 3.71866881e-05 9.18888472e-05
  1.40641365e-04 8.49085118e-05]
 [2.31788275e-04 4.54296954e-02 4.72074360e-01 4.79114771e-01
  4.87117301e-04 3.40319006e-04 4.53910761e-04 6.96401228e-04
  7.24703248e-04 4.46947262e-04]
 [5.44414193e-19 9.99111831e-01 3.45332344e-04 5.42840979e-04
  1.61819379e-18 1.23554521e-18 4.87527845e-19 1.57688492e-13
  1.61781804e-17 4.22969674e-17]
 [1.49251433e-09 9.14038956e-01 3.32748927e-02 5.26859052e-02
  2.48805176e-09 1.92629890e-09 1.04665754e-09 1.97252547e-07
  1.23979538e-08 2.54733727e-08]
 [1.16283798e-06 1.07264407e-02 9.63244662e-02 

44it [00:08,  4.83it/s]

tf.Tensor(
[[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]

46it [00:09,  5.24it/s]

tf.Tensor(
[[3.39166329e-23 9.99969125e-01 1.11570398e-05 1.97484987e-05
  2.83370172e-22 2.27288194e-22 1.37395715e-23 6.51671045e-16
  5.02017449e-21 8.54335159e-21]
 [1.87922359e-16 9.95850444e-01 1.53529586e-03 2.61431257e-03
  8.70470349e-16 1.16587150e-15 2.29500162e-16 1.23660995e-11
  5.92702115e-15 2.29343759e-14]
 [2.30464748e-16 9.99296188e-01 3.66856111e-04 3.36934201e-04
  3.42998898e-15 1.85717839e-15 1.35649320e-15 2.18519230e-11
  4.29170622e-14 1.00881129e-13]
 [1.93890433e-08 9.21015263e-01 3.37223895e-02 4.52604331e-02
  3.40427242e-08 2.85055641e-08 4.21194173e-08 1.67685096e-06
  8.76599202e-08 2.79288468e-08]
 [1.13474443e-05 2.15358455e-02 2.06484005e-01 7.71913290e-01
  3.04214063e-06 1.87840828e-06 1.91295180e-06 7.65601908e-06
  1.67356848e-05 2.43155155e-05]
 [4.52125846e-07 2.29440909e-03 7.62258172e-01 2.35440612e-01
  1.04637911e-06 3.16408062e-07 1.92579279e-07 2.53495477e-06
  1.93298069e-06 2.98818151e-07]
 [5.89817489e-07 9.87378042e-03 1.67304978e-01 

47it [00:09,  5.07it/s]


tf.Tensor(
[[1.90522940e-08 3.11942562e-03 5.55564798e-02 9.41324115e-01
  4.86936547e-10 1.29166677e-10 8.25006521e-11 4.24489466e-09
  3.22132543e-08 2.37565949e-08]
 [4.20032734e-07 2.66728364e-03 7.47880936e-01 2.49446794e-01
  8.68927259e-07 1.21994262e-07 1.42152359e-07 1.50521964e-06
  1.69312113e-06 2.35517703e-07]
 [7.91575960e-09 8.69308352e-01 5.94601296e-02 7.12304562e-02
  3.25555121e-08 1.37861349e-08 2.81659069e-08 6.73942509e-07
  1.09483771e-07 1.56263184e-07]
 [3.55363653e-07 1.11084960e-01 1.19985305e-01 7.68926382e-01
  6.31291641e-08 2.63228301e-08 1.55100306e-08 4.71846533e-07
  1.02673255e-06 1.44762384e-06]
 [9.30006960e-10 9.70217943e-01 1.54373562e-02 1.43445674e-02
  1.81940374e-09 1.28105815e-09 1.01249809e-09 1.16298352e-07
  6.07261352e-09 1.85124183e-08]
 [6.52766289e-07 1.08070578e-02 9.87908691e-02 8.90398324e-01
  7.97089186e-08 3.00101242e-08 2.12680291e-08 3.29204880e-07
  1.14581201e-06 1.58468663e-06]
 [3.77564993e-12 9.81559992e-01 7.74985319e-03 

58it [01:00,  1.05s/it]



Evaluate generator on Task:  0  Epoch:  0


0it [00:00, ?it/s]


InvalidArgumentError: ignored

In [None]:
''' Evaluation'''
for task, loader in enumerate(test_loader[:tasks_to_test]):
    print("Task: ", task)
    LOSS = []
    ACC = []
    for data, target in loader:
      logits = agent.classifier_model(data)
      pred = np.argmax(logits, axis=1)
      report = agent.eval(np.argmax(target, axis=1), pred)
      loss = tf.keras.losses.categorical_crossentropy(target, logits)
      #print(report)
      ACC.append(report)
      LOSS.append(loss)
    print("Mean loss: ", np.mean(LOSS))
    print("Mean accuracy: ", np.mean(ACC))
    print("\n")

# Training

In [None]:
agent = Agent(params)
for r in range(agent.params["n_runs"]):
  agent.state["run"] = r
  run(agent)

# Evaluation, testing

In [None]:
def evaluate(loader, first_n_tasks=None):
    for task, tr_loader in enumerate(loader):
        print("Task: ", task)
        data, target = tr_loader.batch(124)
        logits = agent.classifier_model(data)
        pred = np.argmax(logits, axis=1)
        report = agent.eval(np.argmax(target, axis=1), pred)
        loss = tf.keras.losses.categorical_crossentropy(target, logits)
        print(report)
        print("Mean loss: ", np.mean(loss))

In [None]:
print("Evaluation on training set:")
evaluate(train_loader)
print("Evaluation on test set:")
evaluate(test_loader)

# Utils for development

In [None]:
# Reload modules
importlib.reload(stable_diffusion)

<module 'stable_diffusion.stable_diffusion' from '/Users/laszlofreund/PycharmProjects/continual-learning-ait/stable_diffusion/stable_diffusion.py'>

In [None]:
# Garbage collection
gc.collect()

21547