# Vision Transformer from scratch

In [2]:
! pip install --quiet "setuptools==59.5.0" "pytorch-lightning>=1.4" "matplotlib" "torch>=1.8" "ipython[notebook]" "torchmetrics>=0.7" "torchvision" "seaborn"

[K     |████████████████████████████████| 952 kB 5.2 MB/s 
[K     |████████████████████████████████| 798 kB 38.4 MB/s 
[K     |████████████████████████████████| 529 kB 33.9 MB/s 
[K     |████████████████████████████████| 87 kB 5.8 MB/s 
[K     |████████████████████████████████| 1.6 MB 41.8 MB/s 
[?25h  Building wheel for fire (setup.py) ... [?25l[?25hdone


In [3]:
import torch
import pytorch_lightning as pl
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchmetrics
import numpy as np
import matplotlib.pyplot as plt
import os
import math

In [4]:
#--------------------------------
# Device configuration
#--------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device: %s'%device)

#--------------------------------
# Data configuration
#--------------------------------
# Batch size actually used by every DataLoader below (the loaders previously
# hard-coded 64 and ignored this variable).
batch_size = 64
# Directory that receives the downloaded MNIST files.
data_root = '.'
# Seed for the train/validation split so that the split is reproducible.
split_seed = 42
# Never request more worker processes than the machine has CPUs.
num_workers = min(8, os.cpu_count() or 1)

# Convert images to tensors and standardise them with the commonly used
# MNIST mean / standard deviation.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

# Split the official training set into 50,000 training and 10,000 validation samples.
mnist_full = torchvision.datasets.MNIST(data_root, train=True, download=True, transform=transform)
mnist_train, mnist_val = torch.utils.data.random_split(
    mnist_full, [50000, 10000],
    generator=torch.Generator().manual_seed(split_seed))

mnist_test = torchvision.datasets.MNIST(data_root, train=False, download=True, transform=transform)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=batch_size, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, num_workers=num_workers)

Using device: cpu
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



  cpuset_checked))


In [5]:
class PositionalEncoding(nn.Module):
    """
    Fixed (non-learned) sinusoidal positional encoding.

    Row ``p`` of the table holds ``sin(p / 10000**(2i / model_dim))`` in the even
    columns and ``cos(p / 10000**(2i / model_dim))`` in the odd columns.
    """
    def __init__(self, model_dim, max_len, device):
        """
        model_dim: dimension of model (number of columns of the table)
        max_len: max sequence length (number of rows of the table)
        device: device on which the table is initially created
        """
        super(PositionalEncoding, self).__init__()

        # (max_len, 1) column of positions 0 .. max_len-1
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)

        # even feature indices 0, 2, 4, ... (this is the "2i" of the formula)
        _2i = torch.arange(0, model_dim, step=2, device=device).float()

        # (max_len, ceil(model_dim / 2)) matrix of sin/cos arguments
        angles = pos / (10000 ** (_2i / model_dim))

        encoding = torch.zeros(max_len, model_dim, device=device)
        encoding[:, 0::2] = torch.sin(angles)
        # For an odd model_dim there is one odd column fewer than even columns,
        # so the cosine block is trimmed to the number of odd columns.
        encoding[:, 1::2] = torch.cos(angles[:, :model_dim // 2])

        # Registered as a buffer so the table follows the module across
        # .to()/.cuda() calls; non-persistent so state_dict keys do not change.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """
        x: tensor of shape (batch_size, seq_len, model_dim); only seq_len is used.

        Returns the first seq_len rows of the table, shape (seq_len, model_dim).
        Raises ValueError if seq_len exceeds the max_len the table was built for.
        """
        _, seq_len, _ = x.size()
        if seq_len > self.encoding.size(0):
            raise ValueError(
                "sequence length %d exceeds max_len %d of the positional encoding"
                % (seq_len, self.encoding.size(0)))

        return self.encoding[:seq_len, :]
  

In [6]:
class LinearProjection(nn.Module):
    """
    Split an image into non-overlapping square patches, linearly embed every
    flattened patch and prepend a learnable class token.
    """
    def __init__(self, patch_size, hidden_dim, in_channels=1):
        """
        patch_size: side length of each square patch
        hidden_dim: dimension of linear projection
        in_channels: number of image channels (default 1, i.e. grayscale)
        """
        super().__init__()
        self.patch_size = patch_size
        # each flattened patch has in_channels * patch_size * patch_size values
        self.linear_emb = nn.Linear(in_channels * patch_size * patch_size, hidden_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

    def split_up_patches(self, x, patch_size):
        """
        x: image batch of shape (batch_size, channels, height, width)
        patch_size: side length of each square patch

        Returns a tensor of shape (batch_size, n_patches, channels * patch_size**2).
        Rows/columns that do not fill a whole patch at the bottom/right are dropped.
        """
        # stride == kernel size, so the patches tile the image without gaps or overlap
        patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)
        # unfold returns (batch, features, n_patches); tokens go on dim 1
        return patches.transpose(1, 2)

    def forward(self, x):
        """
        x: image batch of shape (batch_size, channels, height, width)

        Returns a tensor of shape (batch_size, n_patches + 1, hidden_dim) whose
        first token is the class token.
        """
        # split image into fixed-size patches and flatten them
        x = self.split_up_patches(x, self.patch_size)

        # linear embedding
        x = self.linear_emb(x) # batch_size, seq, hidden_dim

        # concatenate class token in front of the patch tokens
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), 1) # batch_size, seq+1, hidden_dim
        return x

In [7]:
class Attention(nn.Module):
    """
    Single-head scaled dot-product self-attention with learned query/key/value
    projections.
    """
    def __init__(self, hidden_dim):
        """
        hidden_dim: size of the last input dimension; also the size of Q, K and V
        """
        super(Attention, self).__init__()
        self.hidden_dim = hidden_dim

        self.W_Q = nn.Linear(hidden_dim, hidden_dim)
        self.W_K = nn.Linear(hidden_dim, hidden_dim)
        self.W_V = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        x: tensor of shape (..., seq_len, hidden_dim)

        Returns softmax(Q K^T / sqrt(hidden_dim)) V with the same shape as x.
        """
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # Swap only the last two dims of K. `.T` on tensors with more than two
        # dims is deprecated, and transpose(-2, -1) works for any number of
        # leading batch dimensions.
        score = self.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.hidden_dim))
        attention = score @ V

        return attention

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention.

    The feature dimension is split into ``nr_heads`` slices of size ``head_dim``;
    attention is computed on every slice, the results are concatenated along the
    feature dimension and passed through a final linear layer.

    NOTE: a single ``Attention`` module (``self.attn``) is applied to every slice,
    so all heads share the same W_Q / W_K / W_V weights.
    """
    def __init__(self, hidden_dim, nr_heads):
        """
        hidden_dim: dimension of the input features
        nr_heads: number of heads; must be positive and divide hidden_dim evenly

        Raises ValueError when hidden_dim cannot be split evenly into nr_heads.
        """
        super(MultiHeadAttention, self).__init__()
        # Fail early: with a remainder the head split and the final projection
        # (which expects hidden_dim features) cannot both be consistent.
        if nr_heads <= 0 or hidden_dim % nr_heads != 0:
            raise ValueError(
                "hidden_dim (%s) must be divisible by nr_heads (%s)" % (hidden_dim, nr_heads))
        self.hidden_dim = hidden_dim
        self.nr_heads = nr_heads
        self.head_dim = hidden_dim // nr_heads
        # concatenated head outputs of the most recent forward pass (for inspection)
        self.attention = torch.Tensor([])
        self.W_concat = nn.Linear(hidden_dim, hidden_dim)
        self.attn = Attention(hidden_dim=self.head_dim)

    def forward(self, x):
        """
        x: tensor of shape (batch_size, seq_len, hidden_dim)

        Returns a tensor of shape (batch_size, seq_len, hidden_dim).
        """
        batch_size, _, _ = x.shape

        # divide x into 'nr_heads' splits: (batch, heads, seq, head_dim)
        x = x.reshape(batch_size, -1, self.nr_heads, self.head_dim).permute(0, 2, 1, 3)

        # compute attention for each head slice, then join them along the feature
        # dim with one cat (no CPU-only empty accumulator, so this works on any device)
        heads = [self.attn(x[:, i]) for i in range(self.nr_heads)]
        merged = torch.cat(heads, dim=2)

        # keep a detached copy for inspection without holding on to the autograd graph
        self.attention = merged.detach()

        return self.W_concat(merged)

In [8]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with a learnable scale (gamma)
    and shift (beta).
    """
    def __init__(self, model_dim, eps=1e-12):
        """
        model_dim: dimension of the input for normalization
        eps: epsilon for handling zero denominator (std+eps)
        """
        super(LayerNorm, self).__init__()
        # Keep these as real nn.Parameter objects. Calling .to(device) on a
        # Parameter returns a plain tensor whenever the device differs, which
        # would leave gamma/beta unregistered and untrained. The module is moved
        # to the right device as a whole via model.to(...) / the trainer.
        self.gamma = nn.Parameter(torch.ones(model_dim))
        self.beta = nn.Parameter(torch.zeros(model_dim))
        self.eps = eps

    def forward(self, x):
        """
        x: tensor whose last dimension has size model_dim

        Returns gamma * (x - mean) / (std + eps) + beta, with mean and std taken
        over the last dimension (torch's default std estimator is used).
        """
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)

        out = (x - mean) / (std + self.eps)
        out = self.gamma * out + self.beta
        return out

class FFN(nn.Module):
    """
    Position-wise feed-forward network: linear -> GELU -> dropout -> linear -> dropout.
    """
    def __init__(self, model_dim, hidden_dim, drop_prob=0.1):
        """
        model_dim: dimension of the input
        hidden_dim: dimension of hidden layer
        drop_prob: probability of dropout
        """
        super(FFN, self).__init__()

        self.linear1 = nn.Linear(model_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(hidden_dim, model_dim)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        """
        x: tensor whose last dimension has size model_dim

        Returns a tensor of the same shape as x.
        """
        # Call the layers directly. Building an nn.Sequential here and assigning
        # it to self would register a new submodule on every call and add
        # duplicate "ffn.*" keys to the state_dict after the first forward pass.
        x = self.dropout(self.gelu(self.linear1(x)))
        return self.dropout(self.linear2(x))

class Transformer_module(nn.Module):
    """
    One transformer encoder block: normalised self-attention with a residual
    connection, then a normalised feed-forward network with a residual
    connection, followed by a final layer normalisation.
    """
    def __init__(self, hidden_dim, ffn_hidden, nr_heads):
        """
        hidden_dim: dimension of the input
        ffn_hidden: dimension of hidden layer
        nr_heads: number of heads for multihead attention
        """
        super().__init__()
        self.hidden_dim, self.ffn_hidden, self.nr_heads = hidden_dim, ffn_hidden, nr_heads

        # sub-layers, created in the order in which forward() applies them
        self.ln1 = LayerNorm(hidden_dim)
        self.msa = MultiHeadAttention(hidden_dim=hidden_dim, nr_heads=nr_heads)
        self.ln2 = LayerNorm(hidden_dim)
        self.ffn = FFN(model_dim=hidden_dim, hidden_dim=ffn_hidden)
        self.ln3 = LayerNorm(hidden_dim)

    def forward(self, x):
        """Apply the encoder block to x and return a tensor of the same shape."""
        # self-attention on the normalised input, added back onto the input
        after_attention = self.msa(self.ln1(x)) + x

        # feed-forward network on the normalised attention result, added back,
        # and a last normalisation of the sum
        return self.ln3(after_attention + self.ffn(self.ln2(after_attention)))

In [9]:
class Transformer(pl.LightningModule):
    """
    Vision transformer classifier: patch embedding with a class token, sinusoidal
    positional encoding, a stack of transformer encoder blocks and a
    classification head that reads the class token.
    """
    def __init__(self, hidden_dim = 16, num_class = 10, nr_layers = 3, nr_heads = 3, patch_size = 4, image_size = 28):
        """
        hidden_dim: dimension of linear projection
        num_class: number of classes
        nr_layers: number of transformer layers
        nr_heads: number of heads for multihead attention; must divide hidden_dim
            (the default pair hidden_dim=16 / nr_heads=3 does not, so pass
            compatible values explicitly)
        patch_size: size of the patch for patch embedding
        image_size: side length of the (square) input images, used to size the
            positional-encoding table (default 28, the MNIST image size)

        Raises ValueError when hidden_dim is not divisible by nr_heads.
        """
        super().__init__()
        # Such a configuration cannot run a forward pass: the head split drops
        # features while the attention output projection still expects hidden_dim.
        if nr_heads <= 0 or hidden_dim % nr_heads != 0:
            raise ValueError(
                "hidden_dim (%s) must be divisible by nr_heads (%s)" % (hidden_dim, nr_heads))

        self.num_class = num_class
        self.nr_layers = nr_layers
        self.nr_heads = nr_heads
        self.learning_rate = 0.01 #1e-2
        self.patch_size = patch_size
        self.image_size = image_size

        self.train_acc = torchmetrics.classification.MulticlassAccuracy(num_classes = num_class, average='weighted')
        self.val_acc = torchmetrics.classification.MulticlassAccuracy(num_classes = num_class, average='weighted')
        self.test_acc = torchmetrics.classification.MulticlassAccuracy(num_classes = num_class, average='weighted')

        self.hidden_dim = hidden_dim

        # for processing input (patch emb, linear prj, pos_enc)
        self.linear_prj = LinearProjection(self.patch_size, self.hidden_dim)
        # The table needs one row per token: at most (image_size // patch_size)**2
        # patches plus the class token. (It was previously sized by hidden_dim,
        # which is unrelated to the sequence length.)
        max_tokens = (image_size // patch_size) ** 2 + 1
        self.pos_enc = PositionalEncoding(self.hidden_dim, max_tokens, device)

        # Transformer blocks
        self.transformer_list = nn.ModuleList(
            [Transformer_module(hidden_dim, hidden_dim, nr_heads) for _ in range(nr_layers)])

        # classification head
        self.mlp = nn.Sequential(LayerNorm(hidden_dim), nn.Linear(hidden_dim,num_class))


    def forward(self, x):
        """
        x: image batch of shape (batch_size, channels, height, width)

        Returns class logits of shape (batch_size, num_class).
        """
        # patch embedding
        x = self.linear_prj(x)

        # add positional encoding (moved to the input's device in case the table
        # was created on a different one)
        x = x + self.pos_enc(x).to(x.device)

        # perform transformer encoder
        for transformer in self.transformer_list:
            x = transformer(x)

        # for classification: only the class token (index 0) is used
        return self.mlp(x[:,0])

    def _predict_batch(self, batch):
        """Run the model on an (images, labels) batch; return (logits, predicted labels, labels)."""
        images, labels = batch
        outputs = self(images)
        return outputs, torch.argmax(outputs, 1), labels

    def training_step(self, train_batch, batch_idx):
        """Compute and log the cross-entropy loss and the training accuracy for one batch."""
        outputs, pred_labels, labels = self._predict_batch(train_batch)

        # same computation as torch.nn.CrossEntropyLoss() with default arguments,
        # without building a new loss module on every step
        loss = nn.functional.cross_entropy(outputs, labels)
        self.log('train_loss', loss)

        self.train_acc(pred_labels, labels)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Update and log the validation accuracy for one batch."""
        _, pred_labels, labels = self._predict_batch(val_batch)

        self.val_acc(pred_labels, labels)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Update and log the test accuracy for one batch."""
        _, pred_labels, labels = self._predict_batch(batch)

        self.test_acc(pred_labels, labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr = self.learning_rate."""
        opt = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
        return opt

In [10]:
# Build a throw-away model so the notebook displays its module tree.
Transformer(hidden_dim=256, num_class=10, nr_layers=3, nr_heads=4, patch_size=16)

Transformer(
  (train_acc): MulticlassAccuracy()
  (val_acc): MulticlassAccuracy()
  (test_acc): MulticlassAccuracy()
  (linear_prj): LinearProjection(
    (linear_emb): Linear(in_features=256, out_features=256, bias=True)
  )
  (pos_enc): PositionalEncoding()
  (transformer_list): ModuleList(
    (0): Transformer_module(
      (ln1): LayerNorm()
      (msa): MultiHeadAttention(
        (W_concat): Linear(in_features=256, out_features=256, bias=True)
        (attn): Attention(
          (W_Q): Linear(in_features=64, out_features=64, bias=True)
          (W_K): Linear(in_features=64, out_features=64, bias=True)
          (W_V): Linear(in_features=64, out_features=64, bias=True)
          (softmax): Softmax(dim=-1)
        )
      )
      (ln2): LayerNorm()
      (ffn): FFN(
        (linear1): Linear(in_features=256, out_features=256, bias=True)
        (gelu): GELU(approximate=none)
        (linear2): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1,

In [11]:
# First experiment: a trainer limited to 100 epochs on a single device, and a
# model with hidden size 256, 3 layers, 4 heads and 16x16 patches.
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=100,
)
model = Transformer(hidden_dim=256, num_class=10, nr_layers=3, nr_heads=4, patch_size=16)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [12]:
# Train on the training split; validation runs on the validation split.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name             | Type               | Params
--------------------------------------------------------
0 | train_acc        | MulticlassAccuracy | 0     
1 | val_acc          | MulticlassAccuracy | 0     
2 | test_acc         | MulticlassAccuracy | 0     
3 | linear_prj       | LinearProjection   | 66.0 K
4 | pos_enc          | PositionalEncoding | 0     
5 | transformer_list | ModuleList         | 634 K 
6 | mlp              | Sequential         | 3.1 K 
--------------------------------------------------------
703 K     Trainable params
0         Non-trainable params
703 K     Total params
2.813     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  cpuset_checked))


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


In [13]:
# Evaluate the trained model on the MNIST test loader.
trainer.test(model, dataloaders=test_loader)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             0.879800021648407
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.879800021648407}]

In [15]:
# Second experiment: same training setup as before, but with hidden size 16.
trainer2 = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=100,
)
model2 = Transformer(hidden_dim=16, num_class=10, nr_layers=3, nr_heads=4, patch_size=16)
trainer2.fit(model2, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer2.test(model2, dataloaders=test_loader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name             | Type               | Params
--------------------------------------------------------
0 | train_acc        | MulticlassAccuracy | 0     
1 | val_acc          | MulticlassAccuracy | 0     
2 | test_acc         | MulticlassAccuracy | 0     
3 | linear_prj       | LinearProjection   | 4.1 K 
4 | pos_enc          | PositionalEncoding | 0     
5 | transformer_list | ModuleList         | 2.9 K 
6 | mlp              | Sequential         | 202   
--------------------------------------------------------
7.2 K     Trainable params
0         Non-trainable params
7.2 K     Total params
0.029     Total estimated mode

Sanity Checking: 0it [00:00, ?it/s]

  cpuset_checked))


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8906000852584839
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.8906000852584839}]