In [1]:
%matplotlib inline 

In [2]:
from time import time
from glob import glob

from torchvision.io import read_image
from torchvision import datasets, transforms
from matplotlib import pyplot as plt

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

from tqdm import tqdm
import torchmetrics
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, ConcatDataset
plt.gray()

<Figure size 640x480 with 0 Axes>

In [3]:
batch_size = 64
input_size = 28
num_classes = 47

In [5]:
# transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])
datasets.ImageFolder
# data_loader = DataLoader(
#         datasets.EMNIST('data/emnist', train=True, download=True, transform=transform, split="bymerge"),
#         batch_size=batch_size, shuffle=True)
# data, label = data_loader.__iter__().__next__()
# label

tensor([39, 15, 46,  1, 23,  0,  1, 36, 25,  1, 18,  5, 24,  3,  9, 21,  7, 38,
         9,  8,  5,  2,  0,  8, 30,  2, 32, 35, 45, 22, 30,  8, 46, 23,  8, 38,
         9,  8, 27, 44,  4, 45, 29, 45, 45,  8, 18,  2,  4,  9, 32,  6, 22,  8,
        20,  4,  2, 39, 33,  9, 22,  8,  0, 34])

In [6]:
# plt.imshow(data[2, 0, :] * 255, interpolation='nearest')
counter = [0] * num_classes
for _, l in data_loader:
    for v in l.numpy():
        counter[v] += 1

In [9]:
counter = np.array(counter)
hist = counter[:47] / (counter.sum()) * 100
hist

array([4.96008207, 5.48821375, 4.91552186, 5.05565012, 4.82224629,
       4.48181198, 4.89302683, 5.16096124, 4.86064545, 4.85462767,
       0.91857086, 0.5550684 , 1.85734427, 0.65994968, 0.70565614,
       1.30356539, 0.36307262, 0.4437395 , 2.11095064, 0.81512239,
       0.71611561, 2.92019853, 1.6637724 , 1.18020094, 3.96370993,
       1.53997811, 0.37295897, 0.72313635, 3.36837973, 1.39927672,
       2.20479932, 1.08721193, 1.06070505, 0.80208387, 1.01614484,
       0.77600683, 1.43409387, 0.72786461, 1.45458297, 3.53286567,
       0.3632159 , 0.52913464, 1.24396073, 1.63970129, 0.42496977,
       2.0145229 , 2.61458136])

In [4]:
class LeNet(nn.Module):

  def __init__(self, num_classes=10, device="cuda"):
    super(LeNet, self).__init__()
    self.num_classes = num_classes
    self.device = device

    self.core = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),

      nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),

      nn.Flatten(start_dim=1),

      nn.Linear(16*4*4, 120),
      nn.ReLU(),
      nn.Linear(120, 84),
      nn.ReLU(),
      nn.Linear(84, self.num_classes),
      nn.Softmax(dim=1),
    )
    self.core = self.core.to(self.device)
    self.to(self.device)

  def forward(self, x):
    res = self.core(x)
    return res


class ANN(nn.Module):
  def __init__(self, num_classes=10, device="cuda"):
    super(ANN, self).__init__()
    self.num_classes = num_classes
    self.device = device
    self.core = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28, 300),
        nn.ReLU(),
        nn.Linear(300, 100),
        nn.ReLU(),
        nn.Linear(100, self.num_classes),
        nn.Softmax(dim=1),
    )

  def forward(self, x):
    return self.core(x)


class IDK(nn.Module):
  def __init__(self):
    super().__init__()
    self.network = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size = 3, padding = 1),
      nn.ReLU(),
      nn.Conv2d(32,64, kernel_size = 3, stride = 1, padding = 1),
      nn.ReLU(),
      nn.MaxPool2d(2,2),
  
      nn.Conv2d(64, 128, kernel_size = 3, stride = 1, padding = 1),
      nn.ReLU(),
      nn.Conv2d(128 ,128, kernel_size = 3, stride = 1, padding = 1),
      nn.ReLU(),
      nn.MaxPool2d(2,2),
      
      nn.Conv2d(128, 256, kernel_size = 3, stride = 1, padding = 1),
      nn.ReLU(),
      nn.Conv2d(256,256, kernel_size = 3, stride = 1, padding = 1),
      nn.ReLU(),
      nn.MaxPool2d(2,2),
      
      nn.Flatten(),
      nn.Linear(82944,1024),
      nn.ReLU(),
      nn.Linear(1024, 512),
      nn.ReLU(),
      nn.Linear(512,6)
    )
  
  def forward(self, xb):
    return self.network(xb)

In [32]:
IGNORED_PARAM_RESET = {"activation", "loss_function"}
def reset_modules(module, parents_modules_names=[]):
  for name, module in module.named_children():
    if name in IGNORED_PARAM_RESET:
      continue
    if isinstance(module, nn.ModuleList):
        reset_modules(module, parents_modules_names=[*parents_modules_names, name])
    elif isinstance(module, nn.Dropout):
      continue
    else:
      module.reset_parameters()

def score(prediction, target, num_classes=10, device="cuda"):
  accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes, average="macro").to(device)
  recall = torchmetrics.Recall("multiclass", num_classes=num_classes, average="macro").to(device)
  precision = torchmetrics.Precision("multiclass", num_classes=num_classes, average="macro").to(device)

  return accuracy(prediction, target), recall(prediction, target), precision(prediction, target)

def forward(model, x):
  return model(x)


def test(model, test_dataset, num_classes, device):
  test_dataloader = DataLoader(test_dataset, batch_size=64)
  all_preds = []
  all_targets = []

  model.eval()
  with torch.no_grad():
    for X, y in tqdm(test_dataloader, leave=False):
      X = X.to(device)
      y = y.to(device)
      pred = forward(model, X)
      all_preds.extend(torch.sigmoid(pred))
      all_targets.extend(y)
    all_preds = torch.stack(all_preds)
    all_targets = torch.stack(all_targets)
  return score(all_preds, all_targets, num_classes, device)

def learn(train_dataset, test_dataset, get_model, num_classes=10, lr=0.0005, n_splits=3, n_epoch=5, batch_size=64, device="cuda"):
  kfolder = KFold(n_splits=n_splits, shuffle=True)
  splits = [None] * n_splits
  for i, (train_ids, label_ids) in enumerate(kfolder.split(train_dataset.data)):
    splits[i] = (train_ids, label_ids)

  scheduler_args = {"verbose": True, "min_lr":1e-9, "threshold": 20, "cooldown": 5, "patience": 20, "factor":0.25, "mode": "min"}
  metrics = [None] * n_splits
  loss_function = nn.CrossEntropyLoss()

  for fold, (train_ids, validation_ids) in enumerate(splits):

    train_subsampler = SubsetRandomSampler(train_ids)
    validation_subsampler = SubsetRandomSampler(validation_ids)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                                    sampler=train_subsampler)
    validation_loader = DataLoader(train_dataset, batch_size=128,
                                    sampler=validation_subsampler)
    model = get_model()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, **scheduler_args)
    model.to(device)

    total_loss = []
    total_validation_loss = []
    for e in range(n_epoch):
      model.train()
      loss = 0
      epoch_loss = 0
      epoch_start_time = time()
      for (X, y) in tqdm(train_loader, leave=False):
        optimizer.zero_grad()
        X = X.to(device)
        y = y.to(device)

        y_hat = forward(model, X)
        loss = loss_function(y_hat, y)
        loss.backward()

        optimizer.step()
        epoch_loss += loss.item()
      
      epoch_loss /= len(train_ids)
      total_loss.append(epoch_loss)
      scheduler.step(loss)

      all_preds = []
      all_targets = []
      validation_loss = 0
      model.eval()
      with torch.no_grad():
        for X, y in tqdm(validation_loader, leave=False):
          X = X.to(device)
          y = y.to(device)

          pred = model.forward(X)
          loss = loss_function(pred, y)
          validation_loss += loss.item()

          all_preds.extend(torch.sigmoid(pred) if isinstance(loss_function, nn.BCEWithLogitsLoss) else pred)
          all_targets.extend(y)
        validation_loss /= len(validation_ids)
        total_validation_loss.append(validation_loss)
      
      all_preds = torch.stack(all_preds)
      all_targets = torch.stack(all_targets)
      accuracy_value, recall_value, precision_value, *_ = score(all_preds, all_targets, num_classes, device)

      print(f"fold: {fold} | epoch: {e} @ {(time() - epoch_start_time):>0.1f}s | train -> loss: {(epoch_loss):>0.5f} | validation -> loss: {(validation_loss):>0.5f} | accuracy: {(100 * accuracy_value):>0.6f} | precision: {(100 * precision_value):>0.6f} | recall: {(100 * recall_value):>0.6f}")
    accuracy_value, recall_value, precision_value, *_ = test(model, test_dataset, num_classes, device)
    print(f"fold: {fold} on Test Set | accuracy: {(100 * accuracy_value):>0.6f} | precision: {(100 * precision_value):>0.6f} | recall: {(100 * recall_value):>0.6f}")
    metrics[fold] = (accuracy_value, recall_value, precision_value)
  return metrics

In [29]:
transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])
num_classes = 47
train_set = datasets.EMNIST('data/emnist', train=True, download=True, transform=transform, split="bymerge")
test_set = datasets.EMNIST('data/emnist', train=False, download=True, transform=transform, split="bymerge")
get_model = lambda: LeNet(num_classes)
torch.cuda.empty_cache()
metrics1 = learn(train_set, test_set, get_model, lr=0.0001, num_classes=num_classes, n_splits=5, n_epoch=15, batch_size=128)

                                                   

fold: 0 | epoch: 0 | train -> loss: 0.02735 | validation -> loss: 0.02702 | accuracy: 18.341988 | precision: 9.122836 | recall: 18.341988


                                                   

fold: 0 | epoch: 1 | train -> loss: 0.02697 | validation -> loss: 0.02696 | accuracy: 18.587402 | precision: 8.992884 | recall: 18.587402


                                                   

fold: 0 | epoch: 2 | train -> loss: 0.02668 | validation -> loss: 0.02654 | accuracy: 21.313921 | precision: 12.275155 | recall: 21.313921


                                                   

fold: 0 | epoch: 3 | train -> loss: 0.02650 | validation -> loss: 0.02651 | accuracy: 21.514383 | precision: 12.128469 | recall: 21.514383


                                                   

fold: 0 | epoch: 4 | train -> loss: 0.02648 | validation -> loss: 0.02650 | accuracy: 21.599560 | precision: 12.265759 | recall: 21.599560


                                                   

fold: 0 | epoch: 5 | train -> loss: 0.02641 | validation -> loss: 0.02639 | accuracy: 23.526196 | precision: 13.488587 | recall: 23.526196


                                                   

fold: 0 | epoch: 6 | train -> loss: 0.02626 | validation -> loss: 0.02616 | accuracy: 27.327808 | precision: 16.050409 | recall: 27.327808


                                                   

fold: 0 | epoch: 7 | train -> loss: 0.02610 | validation -> loss: 0.02610 | accuracy: 27.893269 | precision: 17.084431 | recall: 27.893269


                                                   

fold: 0 | epoch: 8 | train -> loss: 0.02603 | validation -> loss: 0.02604 | accuracy: 29.828989 | precision: 18.946533 | recall: 29.828989


                                                   

fold: 0 | epoch: 9 | train -> loss: 0.02600 | validation -> loss: 0.02601 | accuracy: 30.076653 | precision: 18.850233 | recall: 30.076653


                                                   

fold: 0 | epoch: 10 | train -> loss: 0.02599 | validation -> loss: 0.02602 | accuracy: 29.994738 | precision: 19.312304 | recall: 29.994738


                                                   

fold: 0 | epoch: 11 | train -> loss: 0.02590 | validation -> loss: 0.02583 | accuracy: 31.938194 | precision: 20.932442 | recall: 31.938194


                                                   

fold: 0 | epoch: 12 | train -> loss: 0.02580 | validation -> loss: 0.02582 | accuracy: 32.088028 | precision: 21.430876 | recall: 32.088028


                                                   

fold: 0 | epoch: 13 | train -> loss: 0.02579 | validation -> loss: 0.02581 | accuracy: 32.238354 | precision: 21.234137 | recall: 32.238354


                                                   

fold: 0 | epoch: 14 | train -> loss: 0.02578 | validation -> loss: 0.02580 | accuracy: 32.283741 | precision: 21.209789 | recall: 32.283741


                                                   

fold: 0 on Test Set | accuracy: 32.234604 | precision: 21.131992 | recall: 32.234604


                                                   

fold: 1 | epoch: 0 | train -> loss: 0.02802 | validation -> loss: 0.02763 | accuracy: 15.607301 | precision: 6.312460 | recall: 15.607301


                                                   

fold: 1 | epoch: 1 | train -> loss: 0.02714 | validation -> loss: 0.02699 | accuracy: 19.419592 | precision: 9.387092 | recall: 19.419592


                                                   

fold: 1 | epoch: 2 | train -> loss: 0.02674 | validation -> loss: 0.02649 | accuracy: 23.441086 | precision: 12.998218 | recall: 23.441086


                                                   

fold: 1 | epoch: 3 | train -> loss: 0.02645 | validation -> loss: 0.02645 | accuracy: 23.633272 | precision: 13.080868 | recall: 23.633272


                                                   

fold: 1 | epoch: 4 | train -> loss: 0.02638 | validation -> loss: 0.02632 | accuracy: 27.056700 | precision: 15.471964 | recall: 27.056700


                                                   

fold: 1 | epoch: 5 | train -> loss: 0.02626 | validation -> loss: 0.02626 | accuracy: 27.751087 | precision: 15.606232 | recall: 27.751087


                                                   

fold: 1 | epoch: 6 | train -> loss: 0.02625 | validation -> loss: 0.02625 | accuracy: 27.793594 | precision: 15.374073 | recall: 27.793594


                                                   

fold: 1 | epoch: 7 | train -> loss: 0.02616 | validation -> loss: 0.02613 | accuracy: 29.665110 | precision: 17.408503 | recall: 29.665110


                                                   

fold: 1 | epoch: 8 | train -> loss: 0.02603 | validation -> loss: 0.02603 | accuracy: 31.656343 | precision: 19.234150 | recall: 31.656343


                                                   

fold: 1 | epoch: 9 | train -> loss: 0.02601 | validation -> loss: 0.02601 | accuracy: 31.874895 | precision: 19.004425 | recall: 31.874895


                                                   

fold: 1 | epoch: 10 | train -> loss: 0.02601 | validation -> loss: 0.02598 | accuracy: 32.611675 | precision: 19.989004 | recall: 32.611675


                                                   

fold: 1 | epoch: 11 | train -> loss: 0.02596 | validation -> loss: 0.02596 | accuracy: 32.852890 | precision: 20.081728 | recall: 32.852890


                                                   

fold: 1 | epoch: 12 | train -> loss: 0.02594 | validation -> loss: 0.02594 | accuracy: 32.904476 | precision: 20.253962 | recall: 32.904476


                                                   

fold: 1 | epoch: 13 | train -> loss: 0.02585 | validation -> loss: 0.02576 | accuracy: 34.748085 | precision: 22.827801 | recall: 34.748085


                                                   

fold: 1 | epoch: 14 | train -> loss: 0.02574 | validation -> loss: 0.02575 | accuracy: 34.821667 | precision: 22.574673 | recall: 34.821667


                                                   

fold: 1 on Test Set | accuracy: 34.830250 | precision: 22.532236 | recall: 34.830250


                                                   

fold: 2 | epoch: 0 | train -> loss: 0.02750 | validation -> loss: 0.02716 | accuracy: 18.406815 | precision: 8.128314 | recall: 18.406815


                                                   

fold: 2 | epoch: 1 | train -> loss: 0.02691 | validation -> loss: 0.02675 | accuracy: 20.710331 | precision: 10.057410 | recall: 20.710331


                                                   

fold: 2 | epoch: 2 | train -> loss: 0.02671 | validation -> loss: 0.02673 | accuracy: 20.794176 | precision: 10.499530 | recall: 20.794176


                                                   

fold: 2 | epoch: 3 | train -> loss: 0.02638 | validation -> loss: 0.02631 | accuracy: 22.975605 | precision: 12.948995 | recall: 22.975605


                                                   

fold: 2 | epoch: 4 | train -> loss: 0.02620 | validation -> loss: 0.02612 | accuracy: 27.187634 | precision: 16.483290 | recall: 27.187634


                                                   

fold: 2 | epoch: 5 | train -> loss: 0.02603 | validation -> loss: 0.02600 | accuracy: 28.182903 | precision: 17.868036 | recall: 28.182903


                                                   

fold: 2 | epoch: 6 | train -> loss: 0.02599 | validation -> loss: 0.02599 | accuracy: 28.294840 | precision: 18.154154 | recall: 28.294840


                                                   

fold: 2 | epoch: 7 | train -> loss: 0.02592 | validation -> loss: 0.02586 | accuracy: 30.218098 | precision: 19.427280 | recall: 30.218098


                                                   

fold: 2 | epoch: 8 | train -> loss: 0.02580 | validation -> loss: 0.02574 | accuracy: 32.277996 | precision: 21.781616 | recall: 32.277996


                                                   

fold: 2 | epoch: 9 | train -> loss: 0.02572 | validation -> loss: 0.02572 | accuracy: 32.439133 | precision: 21.428812 | recall: 32.439133


                                                   

fold: 2 | epoch: 10 | train -> loss: 0.02571 | validation -> loss: 0.02572 | accuracy: 32.492004 | precision: 21.874613 | recall: 32.492004


                                                   

fold: 2 | epoch: 11 | train -> loss: 0.02570 | validation -> loss: 0.02571 | accuracy: 32.423065 | precision: 22.017782 | recall: 32.423065


                                                   

fold: 2 | epoch: 12 | train -> loss: 0.02569 | validation -> loss: 0.02571 | accuracy: 32.659462 | precision: 21.653976 | recall: 32.659462


                                                   

fold: 2 | epoch: 13 | train -> loss: 0.02569 | validation -> loss: 0.02570 | accuracy: 32.572643 | precision: 21.953648 | recall: 32.572643


                                                   

fold: 2 | epoch: 14 | train -> loss: 0.02568 | validation -> loss: 0.02569 | accuracy: 32.658787 | precision: 21.712940 | recall: 32.658787


                                                   

fold: 2 on Test Set | accuracy: 32.624840 | precision: 21.582094 | recall: 32.624840


                                                   

fold: 3 | epoch: 0 | train -> loss: 0.02808 | validation -> loss: 0.02786 | accuracy: 14.412414 | precision: 4.972749 | recall: 14.412414


                                                   

fold: 3 | epoch: 1 | train -> loss: 0.02755 | validation -> loss: 0.02745 | accuracy: 16.622438 | precision: 6.532790 | recall: 16.622438


                                                   

fold: 3 | epoch: 2 | train -> loss: 0.02706 | validation -> loss: 0.02651 | accuracy: 22.682133 | precision: 12.973679 | recall: 22.682133


                                                   

fold: 3 | epoch: 3 | train -> loss: 0.02646 | validation -> loss: 0.02644 | accuracy: 22.999903 | precision: 13.246221 | recall: 22.999903


                                                   

fold: 3 | epoch: 4 | train -> loss: 0.02632 | validation -> loss: 0.02622 | accuracy: 27.045954 | precision: 15.865781 | recall: 27.045954


                                                   

fold: 3 | epoch: 5 | train -> loss: 0.02619 | validation -> loss: 0.02619 | accuracy: 27.332306 | precision: 16.009150 | recall: 27.332306


                                                   

fold: 3 | epoch: 6 | train -> loss: 0.02617 | validation -> loss: 0.02619 | accuracy: 27.414089 | precision: 16.053265 | recall: 27.414089


                                                   

fold: 3 | epoch: 7 | train -> loss: 0.02616 | validation -> loss: 0.02617 | accuracy: 27.511417 | precision: 16.029608 | recall: 27.511417


                                                   

fold: 3 | epoch: 8 | train -> loss: 0.02615 | validation -> loss: 0.02617 | accuracy: 27.501827 | precision: 16.109535 | recall: 27.501827


                                                   

fold: 3 | epoch: 9 | train -> loss: 0.02602 | validation -> loss: 0.02599 | accuracy: 29.483908 | precision: 17.974833 | recall: 29.483908


                                                   

fold: 3 | epoch: 10 | train -> loss: 0.02596 | validation -> loss: 0.02598 | accuracy: 29.607046 | precision: 18.014488 | recall: 29.607046


                                                   

fold: 3 | epoch: 11 | train -> loss: 0.02595 | validation -> loss: 0.02597 | accuracy: 29.527071 | precision: 18.303104 | recall: 29.527071


                                                   

fold: 3 | epoch: 12 | train -> loss: 0.02594 | validation -> loss: 0.02596 | accuracy: 29.627129 | precision: 18.200047 | recall: 29.627129


                                                   

fold: 3 | epoch: 13 | train -> loss: 0.02593 | validation -> loss: 0.02596 | accuracy: 29.734716 | precision: 17.925053 | recall: 29.734716


                                                   

fold: 3 | epoch: 14 | train -> loss: 0.02592 | validation -> loss: 0.02586 | accuracy: 31.542686 | precision: 19.384331 | recall: 31.542686


                                                   

fold: 3 on Test Set | accuracy: 31.506607 | precision: 19.492065 | recall: 31.506607


                                                   

fold: 4 | epoch: 0 | train -> loss: 0.02759 | validation -> loss: 0.02717 | accuracy: 17.987080 | precision: 8.526007 | recall: 17.987080


                                                   

fold: 4 | epoch: 1 | train -> loss: 0.02707 | validation -> loss: 0.02703 | accuracy: 18.756433 | precision: 9.002375 | recall: 18.756433


                                                   

fold: 4 | epoch: 2 | train -> loss: 0.02677 | validation -> loss: 0.02661 | accuracy: 21.285067 | precision: 11.624337 | recall: 21.285067


                                                   

fold: 4 | epoch: 3 | train -> loss: 0.02627 | validation -> loss: 0.02606 | accuracy: 26.991356 | precision: 16.443027 | recall: 26.991356


                                                   

fold: 4 | epoch: 4 | train -> loss: 0.02600 | validation -> loss: 0.02591 | accuracy: 29.301399 | precision: 18.351898 | recall: 29.301399


                                                   

fold: 4 | epoch: 5 | train -> loss: 0.02591 | validation -> loss: 0.02588 | accuracy: 29.486652 | precision: 18.400372 | recall: 29.486652


                                                   

fold: 4 | epoch: 6 | train -> loss: 0.02586 | validation -> loss: 0.02575 | accuracy: 31.735485 | precision: 20.722639 | recall: 31.735485


                                                   

fold: 4 | epoch: 7 | train -> loss: 0.02574 | validation -> loss: 0.02573 | accuracy: 31.965664 | precision: 21.421337 | recall: 31.965664


                                                   

fold: 4 | epoch: 8 | train -> loss: 0.02573 | validation -> loss: 0.02571 | accuracy: 32.120365 | precision: 20.888824 | recall: 32.120365


                                                   

fold: 4 | epoch: 9 | train -> loss: 0.02572 | validation -> loss: 0.02573 | accuracy: 32.121719 | precision: 21.172560 | recall: 32.121719


                                                   

fold: 4 | epoch: 10 | train -> loss: 0.02571 | validation -> loss: 0.02571 | accuracy: 32.155289 | precision: 21.410254 | recall: 32.155289


                                                   

fold: 4 | epoch: 11 | train -> loss: 0.02570 | validation -> loss: 0.02570 | accuracy: 32.317917 | precision: 21.300310 | recall: 32.317917


                                                   

fold: 4 | epoch: 12 | train -> loss: 0.02570 | validation -> loss: 0.02570 | accuracy: 32.239120 | precision: 21.506432 | recall: 32.239120


                                                   

fold: 4 | epoch: 13 | train -> loss: 0.02569 | validation -> loss: 0.02569 | accuracy: 32.352371 | precision: 21.273664 | recall: 32.352371


                                                   

fold: 4 | epoch: 14 | train -> loss: 0.02569 | validation -> loss: 0.02568 | accuracy: 32.381840 | precision: 21.217213 | recall: 32.381840


                                                   

fold: 4 on Test Set | accuracy: 32.419716 | precision: 21.245306 | recall: 32.419716


In [39]:
torch.cuda.empty_cache() # 203.9s 
transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])
num_classes = 47
# train_set = datasets.EMNIST('data/emnist', train=True, download=True, transform=transform, split="bymerge")
test_set = datasets.EMNIST('data/emnist', train=False, download=True, transform=transform, split="bymerge")
get_model = lambda: ANN(num_classes)
metrics = learn(train_set, test_set, get_model, lr=0.0001, num_classes=num_classes, n_splits=5, n_epoch=15, batch_size=128)

                                                   

fold: 0 | epoch: 0 @ 208.0s | train -> loss: 0.02678 | validation -> loss: 0.02642 | accuracy: 22.714008 | precision: 13.549365 | recall: 22.714008


                                                   

fold: 0 | epoch: 1 @ 207.3s | train -> loss: 0.02634 | validation -> loss: 0.02630 | accuracy: 24.691383 | precision: 14.776370 | recall: 24.691383


                                                   

fold: 0 | epoch: 2 @ 230.1s | train -> loss: 0.02605 | validation -> loss: 0.02595 | accuracy: 30.613312 | precision: 18.663040 | recall: 30.613312


                                                   

fold: 0 | epoch: 3 @ 238.1s | train -> loss: 0.02579 | validation -> loss: 0.02574 | accuracy: 36.713905 | precision: 23.077812 | recall: 36.713905


                                                   

fold: 0 | epoch: 4 @ 236.7s | train -> loss: 0.02558 | validation -> loss: 0.02544 | accuracy: 40.911674 | precision: 27.811420 | recall: 40.911674


                                                   

fold: 0 | epoch: 5 @ 232.0s | train -> loss: 0.02537 | validation -> loss: 0.02536 | accuracy: 41.860451 | precision: 28.400242 | recall: 41.860451


                                                   

fold: 0 | epoch: 6 @ 211.6s | train -> loss: 0.02529 | validation -> loss: 0.02528 | accuracy: 43.928158 | precision: 29.716471 | recall: 43.928158


                                                   

fold: 0 | epoch: 7 @ 214.5s | train -> loss: 0.02524 | validation -> loss: 0.02524 | accuracy: 44.469810 | precision: 30.327583 | recall: 44.469810


                                                   

fold: 0 | epoch: 8 @ 215.3s | train -> loss: 0.02521 | validation -> loss: 0.02522 | accuracy: 45.463711 | precision: 31.265331 | recall: 45.463711


                                                   

fold: 0 | epoch: 9 @ 209.4s | train -> loss: 0.02516 | validation -> loss: 0.02518 | accuracy: 46.214897 | precision: 31.832987 | recall: 46.214897


                                                   

fold: 0 | epoch: 10 @ 216.8s | train -> loss: 0.02514 | validation -> loss: 0.02516 | accuracy: 46.430813 | precision: 32.046337 | recall: 46.430813


                                                   

fold: 0 | epoch: 11 @ 203.1s | train -> loss: 0.02513 | validation -> loss: 0.02515 | accuracy: 46.579929 | precision: 31.970221 | recall: 46.579929


                                                   

fold: 0 | epoch: 12 @ 213.4s | train -> loss: 0.02511 | validation -> loss: 0.02514 | accuracy: 46.978901 | precision: 32.063545 | recall: 46.978901


                                                   

fold: 0 | epoch: 13 @ 210.5s | train -> loss: 0.02510 | validation -> loss: 0.02514 | accuracy: 46.943275 | precision: 32.024693 | recall: 46.943275


                                                   

fold: 0 | epoch: 14 @ 211.8s | train -> loss: 0.02509 | validation -> loss: 0.02513 | accuracy: 46.886600 | precision: 32.391060 | recall: 46.886600


                                                   

fold: 0 on Test Set | accuracy: 46.902596 | precision: 32.389278 | recall: 46.902596


                                                   

fold: 1 | epoch: 0 @ 217.0s | train -> loss: 0.02702 | validation -> loss: 0.02658 | accuracy: 24.061451 | precision: 13.620013 | recall: 24.061451


                                                   

fold: 1 | epoch: 1 @ 210.4s | train -> loss: 0.02645 | validation -> loss: 0.02638 | accuracy: 26.499315 | precision: 15.255926 | recall: 26.499315


                                                   

fold: 1 | epoch: 2 @ 215.0s | train -> loss: 0.02600 | validation -> loss: 0.02574 | accuracy: 32.055916 | precision: 21.265087 | recall: 32.055916


                                                   

fold: 1 | epoch: 3 @ 214.8s | train -> loss: 0.02559 | validation -> loss: 0.02554 | accuracy: 34.563793 | precision: 23.819593 | recall: 34.563793


                                                   

fold: 1 | epoch: 4 @ 211.4s | train -> loss: 0.02544 | validation -> loss: 0.02532 | accuracy: 39.222355 | precision: 29.140354 | recall: 39.222355


                                                   

fold: 1 | epoch: 5 @ 212.0s | train -> loss: 0.02525 | validation -> loss: 0.02522 | accuracy: 41.458923 | precision: 31.202591 | recall: 41.458923


                                                   

fold: 1 | epoch: 6 @ 213.1s | train -> loss: 0.02518 | validation -> loss: 0.02518 | accuracy: 42.130249 | precision: 31.497129 | recall: 42.130249


                                                   

fold: 1 | epoch: 7 @ 214.8s | train -> loss: 0.02514 | validation -> loss: 0.02514 | accuracy: 43.386513 | precision: 32.488430 | recall: 43.386513


                                                   

fold: 1 | epoch: 8 @ 226.7s | train -> loss: 0.02510 | validation -> loss: 0.02509 | accuracy: 45.104958 | precision: 34.330864 | recall: 45.104958


                                                   

fold: 1 | epoch: 9 @ 226.6s | train -> loss: 0.02502 | validation -> loss: 0.02499 | accuracy: 47.526611 | precision: 36.278828 | recall: 47.526611


                                                   

fold: 1 | epoch: 10 @ 217.9s | train -> loss: 0.02496 | validation -> loss: 0.02497 | accuracy: 47.920219 | precision: 36.355801 | recall: 47.920219


                                                   

fold: 1 | epoch: 11 @ 210.3s | train -> loss: 0.02494 | validation -> loss: 0.02495 | accuracy: 48.111282 | precision: 36.568054 | recall: 48.111282


                                                   

fold: 1 | epoch: 12 @ 219.1s | train -> loss: 0.02488 | validation -> loss: 0.02487 | accuracy: 49.615742 | precision: 38.831013 | recall: 49.615742


                                                   

fold: 1 | epoch: 13 @ 213.2s | train -> loss: 0.02484 | validation -> loss: 0.02486 | accuracy: 50.024700 | precision: 39.274654 | recall: 50.024700


                                                   

fold: 1 | epoch: 14 @ 209.0s | train -> loss: 0.02482 | validation -> loss: 0.02484 | accuracy: 50.330673 | precision: 38.951546 | recall: 50.330673


                                                   

fold: 1 on Test Set | accuracy: 50.294029 | precision: 38.854324 | recall: 50.294029


                                                   

fold: 2 | epoch: 0 @ 215.1s | train -> loss: 0.02702 | validation -> loss: 0.02663 | accuracy: 21.975485 | precision: 11.837502 | recall: 21.975485


                                                   

fold: 2 | epoch: 1 @ 210.6s | train -> loss: 0.02656 | validation -> loss: 0.02650 | accuracy: 24.023680 | precision: 12.906257 | recall: 24.023680


                                                   

fold: 2 | epoch: 2 @ 210.0s | train -> loss: 0.02632 | validation -> loss: 0.02618 | accuracy: 29.421600 | precision: 17.112799 | recall: 29.421600


                                                   

fold: 2 | epoch: 3 @ 213.2s | train -> loss: 0.02602 | validation -> loss: 0.02595 | accuracy: 34.312881 | precision: 21.490847 | recall: 34.312881


                                                   

fold: 2 | epoch: 4 @ 213.5s | train -> loss: 0.02586 | validation -> loss: 0.02581 | accuracy: 37.038651 | precision: 24.040462 | recall: 37.038651


                                                   

fold: 2 | epoch: 5 @ 215.2s | train -> loss: 0.02572 | validation -> loss: 0.02569 | accuracy: 39.553658 | precision: 25.006548 | recall: 39.553658


                                                   

fold: 2 | epoch: 6 @ 210.4s | train -> loss: 0.02565 | validation -> loss: 0.02566 | accuracy: 39.882236 | precision: 25.481499 | recall: 39.882236


                                                   

fold: 2 | epoch: 7 @ 209.8s | train -> loss: 0.02562 | validation -> loss: 0.02564 | accuracy: 40.189491 | precision: 25.397305 | recall: 40.189491


                                                   

fold: 2 | epoch: 8 @ 208.5s | train -> loss: 0.02560 | validation -> loss: 0.02563 | accuracy: 40.242256 | precision: 25.644955 | recall: 40.242256


                                                   

fold: 2 | epoch: 9 @ 210.8s | train -> loss: 0.02559 | validation -> loss: 0.02562 | accuracy: 40.405998 | precision: 25.297853 | recall: 40.405998


                                                   

fold: 2 | epoch: 10 @ 216.7s | train -> loss: 0.02557 | validation -> loss: 0.02559 | accuracy: 41.432545 | precision: 26.223480 | recall: 41.432545


                                                   

fold: 2 | epoch: 11 @ 216.0s | train -> loss: 0.02555 | validation -> loss: 0.02558 | accuracy: 41.913269 | precision: 27.510130 | recall: 41.913269


                                                   

fold: 2 | epoch: 12 @ 215.8s | train -> loss: 0.02547 | validation -> loss: 0.02550 | accuracy: 43.472664 | precision: 28.539837 | recall: 43.472664


                                                   

fold: 2 | epoch: 13 @ 215.1s | train -> loss: 0.02544 | validation -> loss: 0.02548 | accuracy: 44.016735 | precision: 28.778435 | recall: 44.016735


                                                   

fold: 2 | epoch: 14 @ 214.5s | train -> loss: 0.02543 | validation -> loss: 0.02547 | accuracy: 44.042442 | precision: 28.189171 | recall: 44.042442


                                                   

fold: 2 on Test Set | accuracy: 44.119900 | precision: 28.257048 | recall: 44.119900


                                                   

fold: 3 | epoch: 0 @ 215.9s | train -> loss: 0.02694 | validation -> loss: 0.02661 | accuracy: 22.108498 | precision: 12.065256 | recall: 22.108498


                                                   

fold: 3 | epoch: 1 @ 209.4s | train -> loss: 0.02654 | validation -> loss: 0.02649 | accuracy: 24.008297 | precision: 13.027051 | recall: 24.008297


                                                   

fold: 3 | epoch: 2 @ 208.8s | train -> loss: 0.02639 | validation -> loss: 0.02625 | accuracy: 29.307064 | precision: 16.892931 | recall: 29.307064


                                                   

fold: 3 | epoch: 3 @ 208.1s | train -> loss: 0.02611 | validation -> loss: 0.02605 | accuracy: 31.635437 | precision: 18.952728 | recall: 31.635437


                                                   

fold: 3 | epoch: 4 @ 217.9s | train -> loss: 0.02581 | validation -> loss: 0.02563 | accuracy: 38.059231 | precision: 26.357922 | recall: 38.059231


                                                   

fold: 3 | epoch: 5 @ 235.9s | train -> loss: 0.02560 | validation -> loss: 0.02557 | accuracy: 38.721703 | precision: 26.652611 | recall: 38.721703


                                                   

fold: 3 | epoch: 6 @ 230.9s | train -> loss: 0.02556 | validation -> loss: 0.02554 | accuracy: 39.060905 | precision: 26.832935 | recall: 39.060905


                                                   

fold: 3 | epoch: 7 @ 226.5s | train -> loss: 0.02549 | validation -> loss: 0.02547 | accuracy: 41.113163 | precision: 28.738707 | recall: 41.113163


                                                   

fold: 3 | epoch: 8 @ 209.3s | train -> loss: 0.02545 | validation -> loss: 0.02544 | accuracy: 41.703636 | precision: 28.880219 | recall: 41.703636


                                                   

fold: 3 | epoch: 9 @ 210.3s | train -> loss: 0.02543 | validation -> loss: 0.02542 | accuracy: 41.888695 | precision: 28.727308 | recall: 41.888695


                                                   

fold: 3 | epoch: 10 @ 209.3s | train -> loss: 0.02540 | validation -> loss: 0.02537 | accuracy: 43.738998 | precision: 30.244049 | recall: 43.738998


                                                   

fold: 3 | epoch: 11 @ 227.8s | train -> loss: 0.02530 | validation -> loss: 0.02524 | accuracy: 45.850487 | precision: 32.094646 | recall: 45.850487


                                                   

fold: 3 | epoch: 12 @ 233.4s | train -> loss: 0.02521 | validation -> loss: 0.02522 | accuracy: 46.187489 | precision: 32.137215 | recall: 46.187489


                                                   

fold: 3 | epoch: 13 @ 237.0s | train -> loss: 0.02520 | validation -> loss: 0.02520 | accuracy: 46.319874 | precision: 32.241077 | recall: 46.319874


                                                   

fold: 3 | epoch: 14 @ 225.6s | train -> loss: 0.02518 | validation -> loss: 0.02518 | accuracy: 47.784245 | precision: 33.115532 | recall: 47.784245


                                                   

fold: 3 on Test Set | accuracy: 47.775288 | precision: 33.082180 | recall: 47.775288


                                                   

fold: 4 | epoch: 0 @ 236.7s | train -> loss: 0.02700 | validation -> loss: 0.02673 | accuracy: 20.119282 | precision: 10.850564 | recall: 20.119282


                                                   

fold: 4 | epoch: 1 @ 233.9s | train -> loss: 0.02637 | validation -> loss: 0.02613 | accuracy: 27.261076 | precision: 16.182947 | recall: 27.261076


                                                   

fold: 4 | epoch: 2 @ 215.7s | train -> loss: 0.02607 | validation -> loss: 0.02605 | accuracy: 27.941441 | precision: 16.480515 | recall: 27.941441


                                                   

fold: 4 | epoch: 3 @ 206.3s | train -> loss: 0.02596 | validation -> loss: 0.02589 | accuracy: 30.078354 | precision: 18.537308 | recall: 30.078354


                                                   

fold: 4 | epoch: 4 @ 211.3s | train -> loss: 0.02575 | validation -> loss: 0.02556 | accuracy: 36.531509 | precision: 24.742508 | recall: 36.531509


                                                   

fold: 4 | epoch: 5 @ 212.0s | train -> loss: 0.02552 | validation -> loss: 0.02551 | accuracy: 37.354641 | precision: 24.954288 | recall: 37.354641


                                                   

fold: 4 | epoch: 6 @ 231.3s | train -> loss: 0.02545 | validation -> loss: 0.02542 | accuracy: 39.375614 | precision: 26.490952 | recall: 39.375614


                                                   

fold: 4 | epoch: 7 @ 232.1s | train -> loss: 0.02538 | validation -> loss: 0.02539 | accuracy: 39.579880 | precision: 27.159859 | recall: 39.579880


                                                   

fold: 4 | epoch: 8 @ 237.7s | train -> loss: 0.02536 | validation -> loss: 0.02537 | accuracy: 40.078854 | precision: 26.673702 | recall: 40.078854


                                                   

fold: 4 | epoch: 9 @ 208.8s | train -> loss: 0.02534 | validation -> loss: 0.02535 | accuracy: 40.221718 | precision: 27.136772 | recall: 40.221718


                                                   

fold: 4 | epoch: 10 @ 211.9s | train -> loss: 0.02533 | validation -> loss: 0.02535 | accuracy: 40.029297 | precision: 27.341557 | recall: 40.029297


                                                   

fold: 4 | epoch: 11 @ 207.4s | train -> loss: 0.02532 | validation -> loss: 0.02533 | accuracy: 40.395840 | precision: 26.991066 | recall: 40.395840


                                                   

fold: 4 | epoch: 12 @ 211.2s | train -> loss: 0.02531 | validation -> loss: 0.02533 | accuracy: 40.437752 | precision: 27.299591 | recall: 40.437752


                                                   

fold: 4 | epoch: 13 @ 214.6s | train -> loss: 0.02530 | validation -> loss: 0.02532 | accuracy: 40.651974 | precision: 27.162842 | recall: 40.651974


                                                   

fold: 4 | epoch: 14 @ 233.5s | train -> loss: 0.02529 | validation -> loss: 0.02532 | accuracy: 40.548527 | precision: 27.195677 | recall: 40.548527


                                                   

fold: 4 on Test Set | accuracy: 40.543793 | precision: 27.072071 | recall: 40.543793


In [87]:
class CustomImages(Dataset):
    
    def __init__(self, root, transform=lambda x: x):
        self.transform = transform
        if len(root) == 0:
            raise ValueError()
        targets = []
        data = []
        for folder in tqdm(glob(root + "*")):
            label = int(folder.split("\\")[-1])
            files = glob(folder + "\\*")
            ims = [None] * len(files)
            for i, f in tqdm(enumerate(files), leave=False, desc=f"Label '{label}'"):
                ims[i] = read_image(f)[0]
            targets.extend([label]*len(files))
            data.extend(ims)

        self.targets = torch.tensor(targets)
        self.data = torch.stack(data)
        self.data = self.data.reshape(-1, 1, *self.data.shape[1:]).to(torch.float)
    
    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        return self.transform(self.data[index]), self.targets[index]

In [88]:
torch.cuda.empty_cache() # 414
transform = transforms.Compose([transforms.Resize((input_size, input_size), antialias=None), transforms.Normalize(mean=(0.5,), std=(0.5,))])
transform_test = transforms.Compose([transforms.Resize((input_size, input_size), antialias=None), transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])
train_set = CustomImages("data\\augmented\\acgan\\", transform=transform)
print("loaded train set")
test_set = datasets.EMNIST('data/emnist', train=False, download=True, transform=transform_test, split="bymerge")
print("loaded test set")
get_model = lambda: LeNet(num_classes)

metrics_cgan = learn(train_set, test_set, get_model, lr=0.0001, num_classes=num_classes, n_splits=5, n_epoch=15, batch_size=128)

 62%|██████▏   | 29/47 [03:08<04:04, 13.59s/it]