# Vision Transformer [ViT] on Cifar 10 dataset

In [36]:
# define
IMAGE_SIZE = 32
BATCH_SIZE = 128

PATCH_SIZE = 4
HEADS_NUM = 8
BLOCK_SIZE = 8
EMBED_DIM = 256

EPOCHS = 80

In [37]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as T

transform_train = T.Compose([
    T.RandomCrop(IMAGE_SIZE, padding=4),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
    T.RandomAffine(degrees=25, translate=(0.1, 0.1)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    T.RandomErasing(p=0.5, scale=(0.05, 0.25), ratio=(0.3, 3.3), value=0)
])

transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

train_dataset = CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transform_train)

test_dataset = CIFAR10(
    root='data',
    train=False,
    download=True,
    transform=transform_test)

print(f"Train dataset size: {len(train_dataset)}, Test dataset size: {len(test_dataset)}")
# Train dataset size: 50000, Test dataset size: 10000

Train dataset size: 50000, Test dataset size: 10000


In [62]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True)

print(f"Number of training batches: {len(train_dataloader)}")
print(f"Number of test batches: {len(test_dataloader)}")
#Number of training batches: 391
#Number of test batches: 79

Number of training batches: 391
Number of test batches: 79


In [63]:
import torch
import torch.nn as nn

# Embedding Layer for Vision Transformer
class ViTEmbedding(nn.Module):
    def __init__(self, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, in_channels=3, embed_dim=EMBED_DIM):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2

        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                     # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)            # (B, N, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_tokens, x), dim=1)       # (B, N+1, D) Concatinate CLS token at the beginning
        x = x + self.pos_embed                      # Positional encoding
        return x # (B, N+1, D)


In [64]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=EMBED_DIM, num_heads=HEADS_NUM):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, return_attention=False): # takes in embeddings of shpape (B, N + 1, D)
        B, N, D = x.shape # (B, N+1, D)
        qkv = self.qkv_proj(x)                                     # (B, N+1, 3*D)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)  # (B, N+1, 3, H, D/H)
        qkv = qkv.permute(2, 0, 3, 1, 4)                           # (3, B, H, N+1, D/H)
        q, k, v = qkv[0], qkv[1], qkv[2]                           # Split Q, K, V

        attn_scores = torch.einsum("bhnd,bhmd->bhnm", q / (self.head_dim ** 0.5), k) # (B, H, N+1, N+1)
        attn_weights = torch.softmax(attn_scores, dim=-1) # (B, H, N+1, N+1)
        attn_output = torch.einsum("bhnm,bhmd->bhnd", attn_weights, v) # (B, H, N+1, D/H)
        attn_output = attn_output.transpose(1, 2).reshape(B, N, D)  # (B, H, N+1, D/H) --> (B, N+1, H, D/H) --> (B, N+1, D)

        if return_attention:
            return self.out_proj(attn_output), attn_weights

        return self.out_proj(attn_output)

In [65]:
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim=EMBED_DIM):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim) # Normalization layer for MSA

        self.norm2 = nn.LayerNorm(embed_dim) # Normalization layer for MLP

        self.MSA = MultiHeadAttention(embed_dim=embed_dim)

        self.MLP = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x):
        x = x + self.MSA(self.norm1(x))   # (B, N+1, D)
        x = x + self.MLP(self.norm2(x))   # (B, N+1, D)
        return x

In [66]:
class ViTTransformers(nn.Module):
    def __init__(self, num_classes=10,
                 block_num=BLOCK_SIZE,
                 img_size=IMAGE_SIZE,
                 patch_size=PATCH_SIZE,
                 in_channels=3,
                 embed_dim=EMBED_DIM):
        super().__init__()

        self.embedding = ViTEmbedding(img_size,
                                      patch_size,
                                      in_channels,
                                      embed_dim)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoder(embed_dim) for _ in range(block_num)
        ])
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)  # (B, N+1, D)
        for encoder in self.encoder_layers:
            x = encoder(x)
        cls_token = x[:, 0]  # Extract CLS token
        return self.classifier(cls_token)  # (B, num_classes)

In [67]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_0 = ViTTransformers().to(device)
test_input = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)  # Batch size of 32
output = model_0(test_input)
print(device)
print(f"Output shape: {output.shape}")  # Should be (32, 10) for CIFAR-10

cuda
Output shape: torch.Size([128, 10])


In [68]:
# from torch.optim.lr_scheduler import MultiStepLR
# scheduler = MultiStepLR(optimizer, milestones=[10*i for i in range(1,3)], gamma=0.5)

In [10]:
!pip install torchinfo -q


[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [69]:
from torchinfo import summary

summary(model=model_0,
        input_size=(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ViTTransformers (ViTTransformers)        [128, 3, 32, 32]     [128, 10]            --                   True
├─ViTEmbedding (embedding)               [128, 3, 32, 32]     [128, 65, 256]       16,896               True
│    └─Conv2d (patch_embed)              [128, 3, 32, 32]     [128, 256, 8, 8]     12,544               True
├─ModuleList (encoder_layers)            --                   --                   --                   True
│    └─TransformerEncoder (0)            [128, 65, 256]       [128, 65, 256]       --                   True
│    │    └─LayerNorm (norm1)            [128, 65, 256]       [128, 65, 256]       512                  True
│    │    └─MultiHeadAttention (MSA)     [128, 65, 256]       [128, 65, 256]       263,168              True
│    │    └─LayerNorm (norm2)            [128, 65, 256]       [128, 65, 256]       512                  True
│    │    └─Se

In [74]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model_0.parameters(), lr=5e-2, weight_decay=0.05)

from transformers import get_cosine_schedule_with_warmup

total_steps = len(train_dataloader) * EPOCHS
warmup_steps = int(0.25 * total_steps)

scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
print(warmup_steps) #2346

7820


In [75]:
import torch
from tqdm import tqdm

def train(model, train_dataloader, test_dataloader, loss_fn, optimizer, epochs, device, scheduler=None, use_wandb=True):
    results = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }
    run = wandb.init(
        entity="hellcasterz-nit-warangal",
        # Set the wandb project where this run will be logged.
        project="Vision_Transformer",
        # Track hyperparameters and run metadata.
        config={
            "learning_rate": optimizer.param_groups[0]["lr"],
            "architecture": "ViT",
            "dataset": "CIFAR-100",
            "epochs": epochs,
        },
    )
    for epoch in range(epochs):
        model.train()
        train_loss, train_correct, total = 0, 0, 0

        loop = tqdm(train_dataloader, leave=False, desc=f"Epoch {epoch+1}/{epochs}")
        for X, y in loop:
            X, y = X.to(device), y.to(device)

            preds = model(X)
            loss = loss_fn(preds, y)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * X.size(0)
            _, predicted = preds.max(1)
            train_correct += predicted.eq(y).sum().item()
            total += y.size(0)

            loop.set_postfix(loss=loss.item())

        avg_train_loss = train_loss / total
        avg_train_acc = train_correct / total

        model.eval()
        test_loss, test_correct, test_total = 0, 0, 0
        with torch.no_grad():
            for X, y in test_dataloader:
                X, y = X.to(device), y.to(device)

                preds = model(X)
                loss = loss_fn(preds, y)

                test_loss += loss.item() * X.size(0)
                _, predicted = preds.max(1)
                test_correct += predicted.eq(y).sum().item()
                test_total += y.size(0)

        avg_test_loss = test_loss / test_total
        avg_test_acc = test_correct / test_total

        results["train_loss"].append(avg_train_loss)
        results["train_acc"].append(avg_train_acc)
        results["test_loss"].append(avg_test_loss)
        results["test_acc"].append(avg_test_acc)

        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f} "
            f"Test Loss: {avg_test_loss:.4f}, Test Acc: {avg_test_acc:.4f}"
        )

        if use_wandb:
            run.log({
                "Train Loss": avg_train_loss,
                "Train Acc": avg_train_acc,
                "Test Loss": avg_test_loss,
                "Test Acc": avg_test_acc,
                "Learning Rate": optimizer.param_groups[0]["lr"],
            })

        if scheduler is not None:
            scheduler.step()
    if use_wandb:
        print(run.summary())
        run.finish()
    return results


In [76]:
!pip install wandb
import wandb
wandb.login()









[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip










True

In [77]:
results =train(model=model_0,
      train_dataloader=train_dataloader,
      test_dataloader=test_dataloader,
      loss_fn=loss_fn,
      optimizer=optimizer,
      epochs=EPOCHS,
      device=device,
      scheduler=scheduler
      )  # Reduced epochs for quicker testing

0,1
Epoch,▁▂▄▅▇█
Learning Rate,▁▂▄▅▇█
Test Acc,▁▄▅▆██
Test Loss,█▄▃▃▁▁
Train Acc,▁▁▄▆▇█
Train Loss,█▅▄▃▂▁

0,1
Epoch,6.0
Learning Rate,0.0
Test Acc,0.2185
Test Loss,2.08379
Train Acc,0.2206
Train Loss,2.10803


                                                                        

Epoch [1/80] Train Loss: 2.0927, Train Acc: 0.2230 Test Loss: 2.0793, Test Acc: 0.2416


                                                                        

Epoch [2/80] Train Loss: 2.0767, Train Acc: 0.2333 Test Loss: 2.0315, Test Acc: 0.2434


                                                                        

Epoch [3/80] Train Loss: 2.0142, Train Acc: 0.2582 Test Loss: 1.9751, Test Acc: 0.2832


                                                                        

Epoch [4/80] Train Loss: 1.9667, Train Acc: 0.2779 Test Loss: 1.9013, Test Acc: 0.3086


                                                                        

Epoch [5/80] Train Loss: 1.9314, Train Acc: 0.2939 Test Loss: 1.8594, Test Acc: 0.3309


                                                                        

Epoch [6/80] Train Loss: 1.8816, Train Acc: 0.3129 Test Loss: 1.7451, Test Acc: 0.3699


                                                                        

Epoch [7/80] Train Loss: 1.8158, Train Acc: 0.3378 Test Loss: 1.6878, Test Acc: 0.3893


                                                                        

Epoch [8/80] Train Loss: 1.7570, Train Acc: 0.3562 Test Loss: 1.6124, Test Acc: 0.4152


                                                                        

Epoch [9/80] Train Loss: 1.7008, Train Acc: 0.3803 Test Loss: 1.5320, Test Acc: 0.4508


                                                                         

Epoch [10/80] Train Loss: 1.6599, Train Acc: 0.3960 Test Loss: 1.5304, Test Acc: 0.4461


                                                                         

Epoch [11/80] Train Loss: 1.6267, Train Acc: 0.4086 Test Loss: 1.4520, Test Acc: 0.4764


                                                                         

Epoch [12/80] Train Loss: 1.5871, Train Acc: 0.4237 Test Loss: 1.4116, Test Acc: 0.4943


                                                                         

Epoch [13/80] Train Loss: 1.5609, Train Acc: 0.4363 Test Loss: 1.3890, Test Acc: 0.5001


                                                                         

Epoch [14/80] Train Loss: 1.5332, Train Acc: 0.4433 Test Loss: 1.3700, Test Acc: 0.5085


                                                                         

Epoch [15/80] Train Loss: 1.5002, Train Acc: 0.4562 Test Loss: 1.3478, Test Acc: 0.5236


                                                                         

Epoch [16/80] Train Loss: 1.4790, Train Acc: 0.4655 Test Loss: 1.2957, Test Acc: 0.5321


                                                                         

Epoch [17/80] Train Loss: 1.4522, Train Acc: 0.4769 Test Loss: 1.3082, Test Acc: 0.5248


                                                                         

Epoch [18/80] Train Loss: 1.4225, Train Acc: 0.4853 Test Loss: 1.2580, Test Acc: 0.5448


                                                                         

Epoch [19/80] Train Loss: 1.4005, Train Acc: 0.4960 Test Loss: 1.2399, Test Acc: 0.5572


                                                                         

Epoch [20/80] Train Loss: 1.3751, Train Acc: 0.5034 Test Loss: 1.2255, Test Acc: 0.5559


                                                                         

Epoch [21/80] Train Loss: 1.3524, Train Acc: 0.5138 Test Loss: 1.1992, Test Acc: 0.5690


                                                                         

Epoch [22/80] Train Loss: 1.3315, Train Acc: 0.5210 Test Loss: 1.2313, Test Acc: 0.5550


                                                                         

Epoch [23/80] Train Loss: 1.3168, Train Acc: 0.5258 Test Loss: 1.2099, Test Acc: 0.5593


                                                                         

Epoch [24/80] Train Loss: 1.2967, Train Acc: 0.5336 Test Loss: 1.1687, Test Acc: 0.5783


                                                                         

Epoch [25/80] Train Loss: 1.2781, Train Acc: 0.5393 Test Loss: 1.1502, Test Acc: 0.5852


                                                                          

Epoch [26/80] Train Loss: 1.2588, Train Acc: 0.5502 Test Loss: 1.1257, Test Acc: 0.5947


                                                                          

Epoch [27/80] Train Loss: 1.2461, Train Acc: 0.5530 Test Loss: 1.1244, Test Acc: 0.5932


                                                                          

Epoch [28/80] Train Loss: 1.2220, Train Acc: 0.5617 Test Loss: 1.0931, Test Acc: 0.6039


                                                                          

Epoch [29/80] Train Loss: 1.2027, Train Acc: 0.5696 Test Loss: 1.0811, Test Acc: 0.6105


                                                                          

Epoch [30/80] Train Loss: 1.1918, Train Acc: 0.5714 Test Loss: 1.0608, Test Acc: 0.6206


                                                                          

Epoch [31/80] Train Loss: 1.1745, Train Acc: 0.5769 Test Loss: 1.0401, Test Acc: 0.6246


                                                                          

Epoch [32/80] Train Loss: 1.1591, Train Acc: 0.5860 Test Loss: 1.0551, Test Acc: 0.6194


                                                                          

Epoch [33/80] Train Loss: 1.1399, Train Acc: 0.5941 Test Loss: 1.0272, Test Acc: 0.6328


                                                                          

Epoch [34/80] Train Loss: 1.1205, Train Acc: 0.5983 Test Loss: 1.0321, Test Acc: 0.6271


                                                                          

Epoch [35/80] Train Loss: 1.1085, Train Acc: 0.6037 Test Loss: 0.9857, Test Acc: 0.6441


                                                                          

Epoch [36/80] Train Loss: 1.0939, Train Acc: 0.6105 Test Loss: 0.9915, Test Acc: 0.6462


                                                                          

Epoch [37/80] Train Loss: 1.0760, Train Acc: 0.6155 Test Loss: 0.9619, Test Acc: 0.6588


                                                                          

Epoch [38/80] Train Loss: 1.0696, Train Acc: 0.6164 Test Loss: 0.9720, Test Acc: 0.6565


                                                                          

Epoch [39/80] Train Loss: 1.0528, Train Acc: 0.6242 Test Loss: 0.9753, Test Acc: 0.6545


                                                                          

Epoch [40/80] Train Loss: 1.0388, Train Acc: 0.6294 Test Loss: 0.9354, Test Acc: 0.6680


                                                                          

Epoch [41/80] Train Loss: 1.0197, Train Acc: 0.6356 Test Loss: 0.9475, Test Acc: 0.6577


                                                                          

Epoch [42/80] Train Loss: 1.0114, Train Acc: 0.6385 Test Loss: 0.9129, Test Acc: 0.6718


                                                                          

Epoch [43/80] Train Loss: 0.9957, Train Acc: 0.6432 Test Loss: 0.9394, Test Acc: 0.6714


                                                                          

Epoch [44/80] Train Loss: 0.9871, Train Acc: 0.6464 Test Loss: 0.9410, Test Acc: 0.6639


                                                                          

Epoch [45/80] Train Loss: 0.9704, Train Acc: 0.6537 Test Loss: 0.9289, Test Acc: 0.6686


                                                                          

Epoch [46/80] Train Loss: 0.9680, Train Acc: 0.6568 Test Loss: 0.9188, Test Acc: 0.6720


                                                                          

Epoch [47/80] Train Loss: 0.9492, Train Acc: 0.6604 Test Loss: 0.8890, Test Acc: 0.6860


                                                                          

Epoch [48/80] Train Loss: 0.9335, Train Acc: 0.6672 Test Loss: 0.9080, Test Acc: 0.6828


                                                                          

Epoch [49/80] Train Loss: 0.9218, Train Acc: 0.6721 Test Loss: 0.8801, Test Acc: 0.6873


                                                                          

Epoch [50/80] Train Loss: 0.9164, Train Acc: 0.6745 Test Loss: 0.8548, Test Acc: 0.6933


                                                                          

Epoch [51/80] Train Loss: 0.8970, Train Acc: 0.6775 Test Loss: 0.8605, Test Acc: 0.6967


                                                                          

Epoch [52/80] Train Loss: 0.8917, Train Acc: 0.6832 Test Loss: 0.8750, Test Acc: 0.6967


                                                                          

Epoch [53/80] Train Loss: 0.8798, Train Acc: 0.6863 Test Loss: 0.8323, Test Acc: 0.6999


                                                                          

Epoch [54/80] Train Loss: 0.8733, Train Acc: 0.6873 Test Loss: 0.8162, Test Acc: 0.7123


                                                                          

Epoch [55/80] Train Loss: 0.8629, Train Acc: 0.6942 Test Loss: 0.8079, Test Acc: 0.7138


                                                                          

Epoch [56/80] Train Loss: 0.8493, Train Acc: 0.6993 Test Loss: 0.8201, Test Acc: 0.7129


                                                                          

Epoch [57/80] Train Loss: 0.8418, Train Acc: 0.6958 Test Loss: 0.8434, Test Acc: 0.7071


                                                                          

Epoch [58/80] Train Loss: 0.8302, Train Acc: 0.7047 Test Loss: 0.8362, Test Acc: 0.7099


                                                                          

Epoch [59/80] Train Loss: 0.8179, Train Acc: 0.7063 Test Loss: 0.8481, Test Acc: 0.7022


                                                                          

Epoch [60/80] Train Loss: 0.8124, Train Acc: 0.7094 Test Loss: 0.8358, Test Acc: 0.7103


                                                                          

Epoch [61/80] Train Loss: 0.8044, Train Acc: 0.7138 Test Loss: 0.7860, Test Acc: 0.7271


                                                                          

Epoch [62/80] Train Loss: 0.7887, Train Acc: 0.7206 Test Loss: 0.8302, Test Acc: 0.7088


                                                                          

Epoch [63/80] Train Loss: 0.7828, Train Acc: 0.7193 Test Loss: 0.7627, Test Acc: 0.7342


                                                                          

Epoch [64/80] Train Loss: 0.7703, Train Acc: 0.7263 Test Loss: 0.7889, Test Acc: 0.7216


                                                                          

Epoch [65/80] Train Loss: 0.7590, Train Acc: 0.7301 Test Loss: 0.7641, Test Acc: 0.7302


                                                                          

Epoch [66/80] Train Loss: 0.7529, Train Acc: 0.7321 Test Loss: 0.7568, Test Acc: 0.7335


                                                                          

Epoch [67/80] Train Loss: 0.7432, Train Acc: 0.7351 Test Loss: 0.7956, Test Acc: 0.7240


                                                                          

Epoch [68/80] Train Loss: 0.7327, Train Acc: 0.7380 Test Loss: 0.7640, Test Acc: 0.7359


                                                                          

Epoch [69/80] Train Loss: 0.7249, Train Acc: 0.7430 Test Loss: 0.7576, Test Acc: 0.7373


                                                                          

Epoch [70/80] Train Loss: 0.7234, Train Acc: 0.7399 Test Loss: 0.7564, Test Acc: 0.7394


                                                                          

Epoch [71/80] Train Loss: 0.7158, Train Acc: 0.7438 Test Loss: 0.7773, Test Acc: 0.7272


                                                                          

Epoch [72/80] Train Loss: 0.7029, Train Acc: 0.7478 Test Loss: 0.7298, Test Acc: 0.7456


                                                                          

Epoch [73/80] Train Loss: 0.6967, Train Acc: 0.7515 Test Loss: 0.7704, Test Acc: 0.7347


                                                                          

Epoch [74/80] Train Loss: 0.6840, Train Acc: 0.7547 Test Loss: 0.7605, Test Acc: 0.7338


                                                                          

Epoch [75/80] Train Loss: 0.6746, Train Acc: 0.7604 Test Loss: 0.7661, Test Acc: 0.7344


                                                                          

Epoch [76/80] Train Loss: 0.6709, Train Acc: 0.7607 Test Loss: 0.7423, Test Acc: 0.7442


                                                                          

Epoch [77/80] Train Loss: 0.6659, Train Acc: 0.7627 Test Loss: 0.7303, Test Acc: 0.7475


                                                                          

Epoch [78/80] Train Loss: 0.6543, Train Acc: 0.7657 Test Loss: 0.7442, Test Acc: 0.7427


                                                                          

Epoch [79/80] Train Loss: 0.6423, Train Acc: 0.7699 Test Loss: 0.7553, Test Acc: 0.7419


                                                                          

Epoch [80/80] Train Loss: 0.6421, Train Acc: 0.7729 Test Loss: 0.7330, Test Acc: 0.7500


TypeError: 'Summary' object is not callable