In [6]:
import os
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import transforms

import matplotlib.pyplot as plt
from PIL import Image
# Request deterministic cuDNN kernels.
# NOTE(review): no RNG seed (torch / numpy) is set in this cell, so runs are
# still not reproducible from this flag alone — consider seeding explicitly.
torch.backends.cudnn.deterministic = True

In [12]:
# Training hyperparameters.
BATCH_SIZE = 128  # mini-batch size used by both DataLoaders
NUM_EPOCHS = 50  # Trainer max_epochs and T_max of the cosine LR schedule

# Model configuration.
NUM_CLASSES = 9  # output classes passed to resnet18 (one CIFAR-10 class is deleted below)
ZDIM = 5  # value passed as z_dim to resnet18 when the models are built

In [13]:
# Per-channel normalisation statistics shared by the train and test pipelines.
_CHANNEL_MEAN = (0.4914, 0.4822, 0.4465)
_CHANNEL_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip, then tensor conversion and
# normalisation.
# NOTE(review): the training crop is 28 px while the test pipeline does not
# crop at all — confirm the train/test size difference is intended.
_train_steps = [
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform_train = transforms.Compose(_train_steps)

# Test pipeline: no augmentation, only tensor conversion and normalisation.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform_test = transforms.Compose(_test_steps)

def delete_class(dataset, to_delete):
    targets = np.array([np.where(np.array(dataset.targets) != to_delete)])[0][0]
    dataset.targets = np.array(dataset.targets)[targets]
    dataset.data = dataset.data[targets]
    
    return dataset

def mv_class(dataset, from_, to):
    """Relabel every sample of class ``from_`` as class ``to``, in place.

    Args:
        dataset: object exposing a ``targets`` sequence of labels.
        from_: label value to replace.
        to: label value written in its place.

    Returns:
        The same ``dataset`` object; ``targets`` is replaced by a NumPy array
        copy carrying the new labels.
    """
    relabelled = np.array(dataset.targets)
    hits = np.nonzero(relabelled == from_)[0]
    relabelled[hits] = to
    dataset.targets = relabelled

    return dataset


def prune(dataset, keep):
    """Truncate ``dataset`` in place to its leading ``keep`` fraction.

    Args:
        dataset: object exposing ``data`` (array with a ``shape``) and a
            sliceable ``targets``.
        keep: fraction of samples to retain; the resulting count is
            ``int(len * keep)``, i.e. rounded towards zero.

    Returns:
        The same ``dataset`` object with ``data`` and ``targets`` shortened.
    """
    n_total = dataset.data.shape[0]
    n_keep = int(n_total * keep)
    head = slice(None, n_keep)
    dataset.data = dataset.data[head, ...]
    dataset.targets = dataset.targets[head]

    return dataset

# Training split: download if needed, then drop class 0 and relabel class 9 as 0.
# Both helpers edit the dataset in place and hand the same object back.
train_dataset = datasets.CIFAR10(root='data', train=True,
                                 transform=transform_train, download=True)
train_dataset = mv_class(delete_class(train_dataset, 0), 9, 0)
# Uncomment to keep only the leading 5% of the training samples.
#prune(train_dataset, 0.05)

# Test split: same class surgery; no download is requested here.
test_dataset = datasets.CIFAR10(root='data', train=False,
                                transform=transform_test)
test_dataset = mv_class(delete_class(test_dataset, 0), 9, 0)

# Settings common to both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=20)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

Files already downloaded and verified


In [14]:
import pytorch_lightning as pl
from bottleneckresnet import resnet18
from sklearn.metrics import classification_report
from tqdm.autonotebook import tqdm

def get_prediction(x, model: "pl.LightningModule"):
    """Run ``model`` on ``x`` and return hard predictions, softmax scores and ``z``.

    The input is moved to the model's own device first, so callers do not need
    to know (or hard-code) where the model currently lives.

    Side effect: ``model.freeze()`` is called on every invocation. A caller
    that continues training afterwards must call ``model.unfreeze()`` itself.

    Args:
        x: input batch tensor.
        model: Lightning module whose forward pass returns a ``(scores, z)``
            pair, with ``scores`` indexed as ``(batch, class)``.

    Returns:
        Tuple ``(predicted_class, probabilities, z)`` where ``probabilities``
        is the softmax of the scores over dim 1, ``predicted_class`` is its
        argmax over dim 1, and ``z`` is the model's second output, returned
        unchanged (on the model's device).
    """
    model.freeze()
    # Match the model's device instead of relying on the caller to do so.
    pred, z = model(x.to(model.device))
    probabilities = torch.softmax(pred, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities, z

class ResNetFuzzy(pl.LightningModule):
    """ResNet-18 built with ``fuzzy=True`` and trained with BCE on one-hot targets.

    The wrapped network's forward output is unpacked everywhere in this class
    as a ``(scores, z)`` pair.
    NOTE(review): ``nn.BCELoss`` requires inputs in [0, 1]; presumably the
    fuzzy head already produces such values — confirm in ``bottleneckresnet``.
    """

    def __init__(self, zdim):
        """Create the backbone with ``NUM_CLASSES`` outputs and ``z_dim=zdim``."""
        super().__init__()
        self.model = resnet18(num_classes=NUM_CLASSES, z_dim=zdim, fuzzy=True)
        self.loss = nn.BCELoss()

    def forward(self, x):
        """Delegate to the wrapped network."""
        return self.model(x)

    def training_step(self, batch, batch_no):
        """Return the BCE loss between the class scores and one-hot labels."""
        x, y = batch
        scores, _ = self(x)
        # Vectorised one-hot encoding placed on the scores' device; this
        # replaces a per-sample Python loop that also hard-coded 'cuda'.
        one_hot = F.one_hot(y.long(), num_classes=scores.shape[1])
        one_hot = one_hot.to(device=scores.device, dtype=torch.float32)
        return self.loss(scores, one_hot)

    def on_train_epoch_end(self) -> None:
        """Print accuracy on the notebook-level ``test_loader`` after each epoch.

        ``get_prediction`` freezes the module, so it is unfrozen again before
        returning, even if evaluation is interrupted, so training can resume.
        """
        correct, total = 0, 0
        try:
            for x, y in tqdm(test_loader, total=len(test_loader)):
                preds, _, _ = get_prediction(x.to(self.device), self)
                correct += (preds.cpu() == y).sum().item()
                total += y.numel()
        finally:
            self.unfreeze()
        acc = correct / total if total else float('nan')
        print(f'acc={acc}')

    def configure_optimizers(self):
        """SGD with momentum and weight decay, cosine-annealed over ``NUM_EPOCHS``."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        return [optimizer], [scheduler]


class ResNetSimple(pl.LightningModule):
    """ResNet-18 built with ``fuzzy=False`` and trained with cross-entropy.

    The wrapped network's forward output is unpacked everywhere in this class
    as a ``(logits, z)`` pair.
    """

    def __init__(self, zdim):
        """Create the backbone with ``NUM_CLASSES`` outputs and ``z_dim=zdim``."""
        super().__init__()
        self.model = resnet18(num_classes=NUM_CLASSES, z_dim=zdim, fuzzy=False)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Delegate to the wrapped network."""
        return self.model(x)

    def training_step(self, batch, batch_no):
        """Return the cross-entropy loss between the logits and integer labels."""
        x, y = batch
        logits, _ = self(x)
        return self.loss(logits, y)

    def on_train_epoch_end(self) -> None:
        """Print accuracy on the notebook-level ``test_loader`` after each epoch.

        ``get_prediction`` freezes the module, so it is unfrozen again before
        returning, even if evaluation is interrupted, so training can resume.
        """
        correct, total = 0, 0
        try:
            for x, y in tqdm(test_loader, total=len(test_loader)):
                # Use the module's own device rather than a hard-coded 'cuda'.
                preds, _, _ = get_prediction(x.to(self.device), self)
                correct += (preds.cpu() == y).sum().item()
                total += y.numel()
        finally:
            self.unfreeze()
        acc = correct / total if total else float('nan')
        print(f'acc={acc}')

    def configure_optimizers(self):
        """SGD with momentum and weight decay, cosine-annealed over ``NUM_EPOCHS``."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
        return [optimizer], [scheduler]

In [15]:
# Train, evaluate and checkpoint the fuzzy model for each latent size in the
# sweep (currently just ZDIM itself). The loop variable is local so the
# module-level ZDIM constant is never rebound by running this cell.
for zdim in range(ZDIM, ZDIM-1, -1):
    fuzzy_model = ResNetFuzzy(zdim)
    trainer = pl.Trainer(
        devices=1,
        max_epochs=NUM_EPOCHS,
    )
    trainer.fit(fuzzy_model, train_loader)

    # Final pass over the test set; per-sample labels, predictions and z
    # vectors are kept in module-level lists.
    fuzzy_true_y, fuzzy_pred_y, fuzzy_z = [], [], []
    for x, y in tqdm(test_loader, total=len(test_loader)):
        fuzzy_true_y.extend(y)
        # Feed the batch on whichever device the model is on after fit.
        preds, probs, z = get_prediction(x.to(fuzzy_model.device), fuzzy_model)
        fuzzy_z.extend(z.cpu())
        fuzzy_pred_y.extend(preds.cpu())

    print(classification_report(fuzzy_true_y, fuzzy_pred_y, digits=3))
    print(f"ZDIM is {zdim}")

    trainer.save_checkpoint(f"resnet18-{zdim}-fuzzy_2.pt")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type    | Params
----------------------------------
0 | model | ResNet  | 11.2 M
1 | loss  | BCELoss | 0     
----------------------------------
11.2 M    Trainable params
55        Non-trainable params
11.2 M    Total params
44.687    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.2018888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.21344444444444444


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.28844444444444445


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.45744444444444443


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.4628888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.47888888888888886


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.6194444444444445


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.6253333333333333


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.5961111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.6472222222222223


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.7388888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.7476666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.7556666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.7244444444444444


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.7602222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8078888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.7032222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8258888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.794


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8133333333333334


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8278888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.7717777777777778


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8376666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8404444444444444


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8596666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8298888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8445555555555555


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8592222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8404444444444444


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8571111111111112


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8615555555555555


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8837777777777778


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8977777777777778


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8913333333333333


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.8975555555555556


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9036666666666666


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9044444444444445


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9146666666666666


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9221111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9225555555555556


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9213333333333333


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9286666666666666


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9314444444444444


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9325555555555556


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9372222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9355555555555556


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.938


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9391111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9397777777777778


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.9406666666666667


`Trainer.fit` stopped: `max_epochs=50` reached.


  0%|          | 0/71 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0      0.961     0.961     0.961      1000
           1      0.972     0.971     0.971      1000
           2      0.926     0.931     0.929      1000
           3      0.874     0.877     0.875      1000
           4      0.934     0.946     0.940      1000
           5      0.897     0.906     0.901      1000
           6      0.963     0.961     0.962      1000
           7      0.966     0.944     0.955      1000
           8      0.974     0.968     0.971      1000

    accuracy                          0.941      9000
   macro avg      0.941     0.941     0.941      9000
weighted avg      0.941     0.941     0.941      9000

ZDIM is 5


In [16]:
# Train, evaluate and checkpoint the plain model for each latent size in the
# sweep (currently just ZDIM itself). The loop variable is local so the
# module-level ZDIM constant is never rebound by running this cell.
for zdim in range(ZDIM, ZDIM-1, -1):
    simple_model = ResNetSimple(zdim)
    trainer = pl.Trainer(
        devices=1,
        max_epochs=NUM_EPOCHS,
    )

    trainer.fit(simple_model, train_loader)

    # Final pass over the test set; per-sample labels, predictions and z
    # vectors are kept in module-level lists.
    simple_true_y, simple_pred_y, simple_z = [], [], []
    for x, y in tqdm(test_loader, total=len(test_loader)):
        simple_true_y.extend(y)
        # Feed the batch on whichever device the model is on after fit.
        preds, probs, z = get_prediction(x.to(simple_model.device), simple_model)
        simple_z.extend(z.cpu())
        simple_pred_y.extend(preds.cpu())

    print(classification_report(simple_true_y, simple_pred_y, digits=3, output_dict=True)['accuracy'])
    print(f"ZDIM is {zdim}")

    trainer.save_checkpoint(f"resnet18-{zdim}-simple_2.pt")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 11.2 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.686    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.279


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.3851111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.43133333333333335


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.4643333333333333


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.4742222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.5055555555555555


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.5612222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.5002222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.5573333333333333


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.5825555555555556


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.6061111111111112


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.565


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6114444444444445


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.47633333333333333


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6155555555555555


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6194444444444445


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.6221111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.6178888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

acc=0.5945555555555555


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.5842222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.5988888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6091111111111112


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6456666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.662


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.602


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6491111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.623


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6375555555555555


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6407777777777778


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6661111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.683


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6825555555555556


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6318888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.6888888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.698


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.703


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.705


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7071111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7162222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7222222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7234444444444444


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7241111111111111


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7258888888888889


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7292222222222222


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7294444444444445


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7316666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7331111111111112


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7346666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7346666666666667


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


acc=0.7348888888888889


`Trainer.fit` stopped: `max_epochs=50` reached.


  0%|          | 0/71 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0.7347777777777778
ZDIM is 5
