In [30]:
import wandb
# Start a Weights & Biases run in the "pytorch-ignite-example" project.
# NOTE(review): the training cell further down calls wandb.init again (that
# time with a config), so this early call looks redundant — confirm before
# removing it.
wandb.init(project="pytorch-ignite-example")

W&B Run: https://app.wandb.ai/evs/pytorch-ignite-example/runs/csabzg3f

In [37]:
import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F

from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import MNIST


from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
        
from ignite.metrics import Accuracy, Loss

from tqdm import tqdm


In [44]:
class VariationalAutoencoder(nn.Module):
    """Fully-connected variational autoencoder with a diagonal Gaussian latent.

    The encoder maps a flat input vector to the mean and log-variance of a
    Gaussian over the latent space; the decoder maps latent samples back to
    the input space and squashes them to [0, 1] with a sigmoid.

    Args:
        latent_features: Dimensionality of the latent space.
        num_samples: Number of latent samples drawn per input in ``forward``.
        num_features: Size of the flat input (and reconstruction) vector.
            Defaults to 784 (= 28 * 28, a flattened MNIST image). The original
            code read this from an undefined global, which raised NameError
            on a fresh kernel.
    """

    def __init__(self, latent_features, num_samples, num_features=784):
        super(VariationalAutoencoder, self).__init__()

        self.latent_features = latent_features
        self.num_samples = num_samples
        self.num_features = num_features

        # We encode the data onto the latent space using a stack of linear layers
        self.encoder = nn.Sequential(
            nn.Linear(in_features=num_features, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            # A Gaussian is fully characterised by its mean and variance,
            # hence 2 * latent_features outputs (mean half + log-variance half).
            nn.Linear(in_features=64, out_features=2 * self.latent_features)
        )

        # The latent code must be decoded into the original image
        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.latent_features, out_features=64),
            nn.LeakyReLU(),
            nn.Linear(in_features=64, out_features=128),
            nn.LeakyReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=num_features)
        )

    def forward(self, x):
        """Encode ``x``, sample latent codes and decode them.

        Args:
            x: Tensor of shape (batch_size, num_features).

        Returns:
            dict with keys:
                "x_hat":   reconstruction averaged over the latent samples,
                           shape (batch_size, num_features), values in [0, 1].
                "z":       latent samples, shape
                           (batch_size, num_samples, latent_features).
                "mu":      latent means, shape (batch_size, latent_features).
                "log_var": latent log-variances, same shape as "mu".
        """
        # Split encoder outputs into a mean and log-variance vector
        mu, log_var = torch.chunk(self.encoder(x), 2, dim=-1)

        # :- Reparametrisation trick
        # a sample from N(mu, sigma) is mu + sigma * epsilon
        # where epsilon ~ N(0, 1)
        #
        # The noise is created directly on mu's device/dtype, so the module
        # works on CPU or GPU without consulting a global ``cuda`` flag.
        # torch.randn returns a tensor that does not require grad, so no
        # gradient flows through the randomness and no no_grad() is needed.
        epsilon = torch.randn(
            mu.size(0), self.num_samples, self.latent_features,
            device=mu.device, dtype=mu.dtype,
        )

        sigma = torch.exp(log_var / 2)

        # We need to unsqueeze to turn
        # (batch_size, latent_dim) -> (batch_size, 1, latent_dim)
        # so that mu and sigma broadcast across the num_samples axis.
        z = mu.unsqueeze(1) + epsilon * sigma.unsqueeze(1)

        # Run through decoder
        x = self.decoder(z)

        # The original digits are on the scale [0, 1]
        x = torch.sigmoid(x)

        # Mean over samples
        x_hat = torch.mean(x, dim=1)

        return {
            "x_hat": x_hat,
            "z": z,
            "mu": mu,
            "log_var": log_var,
        }


In [42]:
def get_data_loaders(train_batch_size, val_batch_size):
    """Build the MNIST training and validation data loaders.

    Both splits are stored under the current directory and share one
    transform: conversion to a tensor followed by normalisation with
    mean 0.1307 and standard deviation 0.3081.

    Args:
        train_batch_size: Batch size of the (shuffled) training loader.
        val_batch_size: Batch size of the (unshuffled) validation loader.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    preprocess = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
    ])

    # Only the training split triggers a download.
    train_set = MNIST(download=True, root=".", transform=preprocess, train=True)
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)

    # The validation split is read from the same root without downloading.
    val_set = MNIST(download=False, root=".", transform=preprocess, train=False)
    val_loader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False)

    return train_loader, val_loader


def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, model=None):
    """Train a classifier on MNIST with Ignite and log metrics to W&B.

    Args:
        train_batch_size: Batch size for the training loader.
        val_batch_size: Batch size for the validation loader.
        epochs: Number of epochs to train for.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_interval: Log the training loss every this many iterations.
        model: Optional ``nn.Module`` to train. It is trained with
            ``F.nll_loss``, so it is expected to emit log-probabilities.
            When omitted, ``Net()`` is instantiated; ``Net`` is not defined in
            this function and must exist elsewhere in the notebook.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if model is None:
        model = Net()
    # Move the model explicitly, and before the optimizer is built: recent
    # Ignite versions use ``device`` only to move batches, not the model.
    model.to(device)
    wandb.watch(model)

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    iterations_per_epoch = len(train_loader)
    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=iterations_per_epoch,
        desc=desc.format(0)
    )

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        # The ``every`` filter counts iterations globally across epochs, so
        # advancing by a fixed ``log_interval`` overshoots the bar whenever the
        # epoch length is not a multiple of it. Advance to the true position
        # within the current epoch instead.
        iteration_in_epoch = (engine.state.iteration - 1) % iterations_per_epoch + 1
        pbar.desc = desc.format(engine.state.output)
        pbar.update(iteration_in_epoch - pbar.n)
        wandb.log({"train loss": engine.state.output})

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll)
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        # Reset the bar for the next epoch.
        pbar.n = pbar.last_print_n = 0
        # Log both validation metrics in one call so they share a W&B step.
        wandb.log({"validation loss": avg_nll,
                   "validation accuracy": avg_accuracy})

    try:
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        # Always release the progress bar, even if training raises.
        pbar.close()

In [43]:
# Train the model

# Default hyperparameters; they are handed to Weights & Biases as the run config.
hyperparameter_defaults = {
    "batch_size": 256,
    "val_batch_size": 100,
    "epochs": 10,
    "lr": 0.001,
    "momentum": 0.3,
    "log_interval": 10,
}


# Start the W&B run so metrics are recorded, then read the hyperparameters
# back from wandb.config and hand them to the training loop.
wandb.init(project="pytorch-ignite-example", config=hyperparameter_defaults)
config = wandb.config
run(
    train_batch_size=config.batch_size,
    val_batch_size=config.val_batch_size,
    epochs=config.epochs,
    lr=config.lr,
    momentum=config.momentum,
    log_interval=config.log_interval,
)



ITERATION - loss: 0.00:   0%|          | 0/235 [00:00<?, ?it/s][A[A

ITERATION - loss: 2.32:   4%|▍         | 10/235 [00:00<00:16, 13.24it/s][A[A

ITERATION - loss: 2.31:   9%|▊         | 20/235 [00:01<00:16, 13.18it/s][A[A

ITERATION - loss: 2.31:  13%|█▎        | 30/235 [00:02<00:16, 12.44it/s][A[A

ITERATION - loss: 2.29:  17%|█▋        | 40/235 [00:03<00:15, 12.43it/s][A[A

ITERATION - loss: 2.32:  21%|██▏       | 50/235 [00:03<00:14, 13.11it/s][A[A

ITERATION - loss: 2.32:  26%|██▌       | 60/235 [00:04<00:12, 13.60it/s][A[A

ITERATION - loss: 2.33:  30%|██▉       | 70/235 [00:05<00:12, 13.10it/s][A[A

ITERATION - loss: 2.32:  34%|███▍      | 80/235 [00:06<00:13, 11.78it/s][A[A

ITERATION - loss: 2.29:  38%|███▊      | 90/235 [00:07<00:12, 11.59it/s][A[A

ITERATION - loss: 2.29:  43%|████▎     | 100/235 [00:08<00:12, 11.03it/s][A[A

ITERATION - loss: 2.32:  47%|████▋     | 110/235 [00:09<00:11, 10.92it/s][A[A

ITERATION - loss: 2.30:  51%|█████     | 120/

Training Results - Epoch: 1  Avg accuracy: 0.17 Avg loss: 2.28


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [05:25<00:21, 10.77it/s]  
ITERATION - loss: 2.30:   0%|          | 0/235 [02:34<00:15, 15.30it/s][A

ITERATION - loss: 2.30:  98%|█████████▊| 230/235 [00:31<00:00, 14.57it/s][A[A

Validation Results - Epoch: 1  Avg accuracy: 0.16 Avg loss: 2.28




ITERATION - loss: 2.29:   4%|▍         | 10/235 [00:32<01:45,  2.13it/s] [A[A

ITERATION - loss: 2.29:   9%|▊         | 20/235 [00:32<01:15,  2.85it/s][A[A

ITERATION - loss: 2.28:  13%|█▎        | 30/235 [00:33<00:54,  3.75it/s][A[A

ITERATION - loss: 2.26:  17%|█▋        | 40/235 [00:34<00:40,  4.82it/s][A[A

ITERATION - loss: 2.28:  21%|██▏       | 50/235 [00:34<00:30,  5.98it/s][A[A

ITERATION - loss: 2.28:  26%|██▌       | 60/235 [00:35<00:24,  7.23it/s][A[A

ITERATION - loss: 2.29:  30%|██▉       | 70/235 [00:36<00:19,  8.38it/s][A[A

ITERATION - loss: 2.27:  34%|███▍      | 80/235 [00:37<00:16,  9.63it/s][A[A

ITERATION - loss: 2.30:  38%|███▊      | 90/235 [00:37<00:13, 10.80it/s][A[A

ITERATION - loss: 2.27:  43%|████▎     | 100/235 [00:38<00:11, 11.37it/s][A[A

ITERATION - loss: 2.27:  47%|████▋     | 110/235 [00:39<00:10, 12.15it/s][A[A

ITERATION - loss: 2.29:  51%|█████     | 120/235 [00:39<00:09, 12.66it/s][A[A

ITERATION - loss: 2.28:  55%|█████

Training Results - Epoch: 2  Avg accuracy: 0.28 Avg loss: 2.24


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [05:54<00:21, 10.77it/s]
ITERATION - loss: 2.30:   0%|          | 0/235 [03:02<00:15, 15.30it/s][A

ITERATION - loss: 2.25: 240it [00:59, 14.86it/s][A[A

Validation Results - Epoch: 2  Avg accuracy: 0.27 Avg loss: 2.24




ITERATION - loss: 2.26:   4%|▍         | 10/235 [01:00<01:33,  2.40it/s][A[A

ITERATION - loss: 2.26:   9%|▊         | 20/235 [01:01<01:07,  3.19it/s][A[A

ITERATION - loss: 2.25:  13%|█▎        | 30/235 [01:02<00:49,  4.17it/s][A[A

ITERATION - loss: 2.26:  17%|█▋        | 40/235 [01:02<00:36,  5.29it/s][A[A

ITERATION - loss: 2.25:  21%|██▏       | 50/235 [01:03<00:28,  6.48it/s][A[A

ITERATION - loss: 2.23:  26%|██▌       | 60/235 [01:04<00:22,  7.80it/s][A[A

ITERATION - loss: 2.24:  30%|██▉       | 70/235 [01:04<00:18,  8.97it/s][A[A

ITERATION - loss: 2.22:  34%|███▍      | 80/235 [01:05<00:15, 10.28it/s][A[A

ITERATION - loss: 2.25:  38%|███▊      | 90/235 [01:06<00:12, 11.45it/s][A[A

ITERATION - loss: 2.21:  43%|████▎     | 100/235 [01:06<00:10, 12.37it/s][A[A

ITERATION - loss: 2.22:  47%|████▋     | 110/235 [01:07<00:09, 13.02it/s][A[A

ITERATION - loss: 2.23:  51%|█████     | 120/235 [01:08<00:08, 13.35it/s][A[A

ITERATION - loss: 2.23:  55%|█████▌

Training Results - Epoch: 3  Avg accuracy: 0.40 Avg loss: 2.18




                                                                       s][A[A
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [06:20<00:21, 10.77it/s]  
ITERATION - loss: 2.30:   0%|          | 0/235 [03:29<00:15, 15.30it/s][A

ITERATION - loss: 2.22:  98%|█████████▊| 230/235 [01:26<00:00, 15.52it/s][A[A

Validation Results - Epoch: 3  Avg accuracy: 0.40 Avg loss: 2.18




ITERATION - loss: 2.21:   4%|▍         | 10/235 [01:27<01:28,  2.54it/s] [A[A

ITERATION - loss: 2.22:   9%|▊         | 20/235 [01:27<01:03,  3.39it/s][A[A

ITERATION - loss: 2.21:  13%|█▎        | 30/235 [01:28<00:46,  4.43it/s][A[A

ITERATION - loss: 2.20:  17%|█▋        | 40/235 [01:28<00:34,  5.63it/s][A[A

ITERATION - loss: 2.20:  21%|██▏       | 50/235 [01:29<00:26,  6.97it/s][A[A

ITERATION - loss: 2.16:  26%|██▌       | 60/235 [01:30<00:21,  8.33it/s][A[A

ITERATION - loss: 2.22:  30%|██▉       | 70/235 [01:30<00:16,  9.72it/s][A[A

ITERATION - loss: 2.21:  34%|███▍      | 80/235 [01:31<00:14, 10.67it/s][A[A

ITERATION - loss: 2.19:  38%|███▊      | 90/235 [01:32<00:12, 11.75it/s][A[A

ITERATION - loss: 2.18:  43%|████▎     | 100/235 [01:32<00:10, 12.65it/s][A[A

ITERATION - loss: 2.20:  47%|████▋     | 110/235 [01:33<00:09, 13.20it/s][A[A

ITERATION - loss: 2.20:  51%|█████     | 120/235 [01:34<00:08, 13.76it/s][A[A

ITERATION - loss: 2.18:  55%|█████

Training Results - Epoch: 4  Avg accuracy: 0.52 Avg loss: 2.08


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [06:47<00:21, 10.77it/s]
ITERATION - loss: 2.30:   0%|          | 0/235 [03:55<00:15, 15.30it/s][A

ITERATION - loss: 2.13: 240it [01:52, 15.67it/s][A[A

Validation Results - Epoch: 4  Avg accuracy: 0.52 Avg loss: 2.08




ITERATION - loss: 2.11:   4%|▍         | 10/235 [01:53<01:28,  2.55it/s][A[A

ITERATION - loss: 2.13:   9%|▊         | 20/235 [01:54<01:03,  3.40it/s][A[A

ITERATION - loss: 2.15:  13%|█▎        | 30/235 [01:54<00:46,  4.45it/s][A[A

ITERATION - loss: 2.11:  17%|█▋        | 40/235 [01:55<00:34,  5.66it/s][A[A

ITERATION - loss: 2.12:  21%|██▏       | 50/235 [01:56<00:26,  6.99it/s][A[A

ITERATION - loss: 2.10:  26%|██▌       | 60/235 [01:56<00:20,  8.36it/s][A[A

ITERATION - loss: 2.11:  30%|██▉       | 70/235 [01:57<00:16,  9.71it/s][A[A

ITERATION - loss: 2.10:  34%|███▍      | 80/235 [01:58<00:14, 10.88it/s][A[A

ITERATION - loss: 2.10:  38%|███▊      | 90/235 [01:58<00:12, 11.86it/s][A[A

ITERATION - loss: 2.06:  43%|████▎     | 100/235 [01:59<00:10, 12.83it/s][A[A

ITERATION - loss: 2.10:  47%|████▋     | 110/235 [01:59<00:09, 13.56it/s][A[A

ITERATION - loss: 2.08:  51%|█████     | 120/235 [02:00<00:08, 14.17it/s][A[A

ITERATION - loss: 2.07:  55%|█████▌

Training Results - Epoch: 5  Avg accuracy: 0.60 Avg loss: 1.91


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [07:13<00:21, 10.77it/s]  
ITERATION - loss: 2.30:   0%|          | 0/235 [04:21<00:15, 15.30it/s][A

ITERATION - loss: 2.03:  98%|█████████▊| 230/235 [02:18<00:00, 15.56it/s][A[A

Validation Results - Epoch: 5  Avg accuracy: 0.61 Avg loss: 1.90




ITERATION - loss: 1.99:   4%|▍         | 10/235 [02:19<01:28,  2.55it/s] [A[A

ITERATION - loss: 2.03:   9%|▊         | 20/235 [02:19<01:03,  3.41it/s][A[A

ITERATION - loss: 1.99:  13%|█▎        | 30/235 [02:20<00:46,  4.45it/s][A[A

ITERATION - loss: 1.99:  17%|█▋        | 40/235 [02:21<00:34,  5.64it/s][A[A

ITERATION - loss: 1.94:  21%|██▏       | 50/235 [02:21<00:26,  6.97it/s][A[A

ITERATION - loss: 1.98:  26%|██▌       | 60/235 [02:22<00:20,  8.38it/s][A[A

ITERATION - loss: 1.91:  30%|██▉       | 70/235 [02:23<00:16,  9.76it/s][A[A

ITERATION - loss: 1.96:  34%|███▍      | 80/235 [02:23<00:14, 10.94it/s][A[A

ITERATION - loss: 1.96:  38%|███▊      | 90/235 [02:24<00:12, 12.04it/s][A[A

ITERATION - loss: 1.91:  43%|████▎     | 100/235 [02:25<00:10, 12.84it/s][A[A

ITERATION - loss: 1.94:  47%|████▋     | 110/235 [02:25<00:09, 13.50it/s][A[A

ITERATION - loss: 1.94:  51%|█████     | 120/235 [02:26<00:08, 14.02it/s][A[A

ITERATION - loss: 1.93:  55%|█████

Training Results - Epoch: 6  Avg accuracy: 0.67 Avg loss: 1.65


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [07:39<00:21, 10.77it/s]
ITERATION - loss: 2.30:   0%|          | 0/235 [04:47<00:15, 15.30it/s][A

ITERATION - loss: 1.77: 240it [02:45, 15.75it/s][A[A

Validation Results - Epoch: 6  Avg accuracy: 0.68 Avg loss: 1.64




ITERATION - loss: 1.77:   0%|          | 0/235 [02:45<00:14, 15.75it/s][A[A

ITERATION - loss: 1.75:   4%|▍         | 10/235 [02:45<01:28,  2.54it/s][A[A

ITERATION - loss: 1.84:   9%|▊         | 20/235 [02:46<01:03,  3.39it/s][A[A

ITERATION - loss: 1.77:  13%|█▎        | 30/235 [02:47<00:46,  4.42it/s][A[A

ITERATION - loss: 1.84:  17%|█▋        | 40/235 [02:47<00:34,  5.64it/s][A[A

ITERATION - loss: 1.76:  21%|██▏       | 50/235 [02:48<00:26,  6.98it/s][A[A

ITERATION - loss: 1.74:  26%|██▌       | 60/235 [02:48<00:20,  8.36it/s][A[A

ITERATION - loss: 1.81:  30%|██▉       | 70/235 [02:49<00:16,  9.74it/s][A[A

ITERATION - loss: 1.81:  34%|███▍      | 80/235 [02:50<00:14, 10.86it/s][A[A

ITERATION - loss: 1.74:  38%|███▊      | 90/235 [02:50<00:12, 11.95it/s][A[A

ITERATION - loss: 1.85:  43%|████▎     | 100/235 [02:51<00:10, 12.70it/s][A[A

ITERATION - loss: 1.73:  47%|████▋     | 110/235 [02:52<00:09, 13.49it/s][A[A

ITERATION - loss: 1.73:  51%|█████   

Training Results - Epoch: 7  Avg accuracy: 0.72 Avg loss: 1.38


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [08:05<00:21, 10.77it/s]  
ITERATION - loss: 2.30:   0%|          | 0/235 [05:13<00:15, 15.30it/s][A

ITERATION - loss: 1.75:  98%|█████████▊| 230/235 [03:11<00:00, 15.46it/s][A[A

Validation Results - Epoch: 7  Avg accuracy: 0.73 Avg loss: 1.37




ITERATION - loss: 1.62:   4%|▍         | 10/235 [03:11<01:29,  2.52it/s] [A[A

ITERATION - loss: 1.69:   9%|▊         | 20/235 [03:12<01:03,  3.37it/s][A[A

ITERATION - loss: 1.60:  13%|█▎        | 30/235 [03:13<00:46,  4.40it/s][A[A

ITERATION - loss: 1.64:  17%|█▋        | 40/235 [03:13<00:34,  5.61it/s][A[A

ITERATION - loss: 1.60:  21%|██▏       | 50/235 [03:14<00:26,  6.95it/s][A[A

ITERATION - loss: 1.55:  26%|██▌       | 60/235 [03:14<00:21,  8.31it/s][A[A

ITERATION - loss: 1.51:  30%|██▉       | 70/235 [03:15<00:17,  9.66it/s][A[A

ITERATION - loss: 1.60:  34%|███▍      | 80/235 [03:16<00:14, 10.82it/s][A[A

ITERATION - loss: 1.60:  38%|███▊      | 90/235 [03:16<00:12, 11.91it/s][A[A

ITERATION - loss: 1.65:  43%|████▎     | 100/235 [03:17<00:10, 12.76it/s][A[A

ITERATION - loss: 1.59:  47%|████▋     | 110/235 [03:18<00:09, 13.45it/s][A[A

ITERATION - loss: 1.52:  51%|█████     | 120/235 [03:18<00:08, 14.01it/s][A[A

ITERATION - loss: 1.53:  55%|█████

Training Results - Epoch: 8  Avg accuracy: 0.76 Avg loss: 1.14


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [08:32<00:21, 10.77it/s]
ITERATION - loss: 2.30:   0%|          | 0/235 [05:40<00:15, 15.30it/s][A

ITERATION - loss: 1.64: 240it [03:37, 15.39it/s][A[A

Validation Results - Epoch: 8  Avg accuracy: 0.77 Avg loss: 1.12




ITERATION - loss: 1.43:   4%|▍         | 10/235 [03:38<01:29,  2.51it/s][A[A

ITERATION - loss: 1.42:   9%|▊         | 20/235 [03:39<01:04,  3.35it/s][A[A

ITERATION - loss: 1.41:  13%|█▎        | 30/235 [03:39<00:46,  4.36it/s][A[A

ITERATION - loss: 1.48:  17%|█▋        | 40/235 [03:40<00:35,  5.55it/s][A[A

ITERATION - loss: 1.44:  21%|██▏       | 50/235 [03:41<00:27,  6.79it/s][A[A

ITERATION - loss: 1.38:  26%|██▌       | 60/235 [03:41<00:21,  8.14it/s][A[A

ITERATION - loss: 1.46:  30%|██▉       | 70/235 [03:42<00:17,  9.42it/s][A[A

ITERATION - loss: 1.38:  34%|███▍      | 80/235 [03:43<00:14, 10.67it/s][A[A

ITERATION - loss: 1.36:  38%|███▊      | 90/235 [03:43<00:12, 11.79it/s][A[A

ITERATION - loss: 1.40:  43%|████▎     | 100/235 [03:44<00:10, 12.51it/s][A[A

ITERATION - loss: 1.37:  47%|████▋     | 110/235 [03:45<00:09, 13.30it/s][A[A

ITERATION - loss: 1.25:  51%|█████     | 120/235 [03:45<00:08, 13.94it/s][A[A

ITERATION - loss: 1.35:  55%|█████▌

Training Results - Epoch: 9  Avg accuracy: 0.79 Avg loss: 0.95


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [08:58<00:21, 10.77it/s]  
ITERATION - loss: 2.30:   0%|          | 0/235 [06:06<00:15, 15.30it/s][A

ITERATION - loss: 1.23:  98%|█████████▊| 230/235 [04:04<00:00, 15.29it/s][A[A

Validation Results - Epoch: 9  Avg accuracy: 0.80 Avg loss: 0.93




ITERATION - loss: 1.34:   4%|▍         | 10/235 [04:04<01:29,  2.52it/s] [A[A

ITERATION - loss: 1.32:   9%|▊         | 20/235 [04:05<01:03,  3.37it/s][A[A

ITERATION - loss: 1.31:  13%|█▎        | 30/235 [04:05<00:46,  4.41it/s][A[A

ITERATION - loss: 1.35:  17%|█▋        | 40/235 [04:06<00:34,  5.60it/s][A[A

ITERATION - loss: 1.31:  21%|██▏       | 50/235 [04:07<00:26,  6.94it/s][A[A

ITERATION - loss: 1.17:  26%|██▌       | 60/235 [04:07<00:20,  8.34it/s][A[A

ITERATION - loss: 1.25:  30%|██▉       | 70/235 [04:08<00:16,  9.71it/s][A[A

ITERATION - loss: 1.30:  34%|███▍      | 80/235 [04:09<00:14, 10.97it/s][A[A

ITERATION - loss: 1.20:  38%|███▊      | 90/235 [04:09<00:11, 12.08it/s][A[A

ITERATION - loss: 1.20:  43%|████▎     | 100/235 [04:10<00:10, 12.98it/s][A[A

ITERATION - loss: 1.23:  47%|████▋     | 110/235 [04:10<00:09, 13.57it/s][A[A

ITERATION - loss: 1.18:  51%|█████     | 120/235 [04:11<00:08, 14.06it/s][A[A

ITERATION - loss: 1.27:  55%|█████

Training Results - Epoch: 10  Avg accuracy: 0.81 Avg loss: 0.81


                                                                       
[A                                                                    

ITERATION - loss: 2.29:   0%|          | 0/235 [09:24<00:21, 10.77it/s]
ITERATION - loss: 2.30:   0%|          | 0/235 [06:32<00:15, 15.30it/s][A

ITERATION - loss: 1.11: 240it [04:30, 15.74it/s][A[A

                                                [A[A

Validation Results - Epoch: 10  Avg accuracy: 0.82 Avg loss: 0.79
