In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torchvision.transforms import ToPILImage

from dataloader import UTKFaceDS, MNISTDS, parse_utkface, parse_mnist
from model import ResNetModel

  from .collection import imread_collection_wrapper


## Dataset

### UTKFace

In [2]:
# UTKFace wrapper. `data.dataset` is the indexable collection of samples
# (dicts with 'image' and 'label') previewed in the next cell.
# NOTE(review): model_res / test_size semantics live in dataloader.UTKFaceDS
# (presumably the resize target and the held-out fraction) — confirm there.
data = UTKFaceDS(
    dataset_folder_name="../data/UTKFace/Images",
    parse_method=parse_utkface,
    model_res=(224, 224),
    batch_size=64,
    test_size=0.3,
)
face_dataset = data.dataset

In [None]:
# Preview up to three consecutive UTKFace samples starting at index `start`:
# print index, image array shape and label, then render the image.
to_pil_img = ToPILImage()
start = 18
# Bounded range replaces "loop to the end and break at start+2"; it stops
# early if the dataset has fewer than three samples past `start`.
for i in range(start, min(start + 3, len(face_dataset))):
    sample = face_dataset[i]
    label = sample['label']
    print(i, np.array(sample['image']).shape, label)
    display(to_pil_img(sample['image']))

### MNIST

In [2]:
# MNIST wrapper, built with the same model_res / batch_size / test_size as
# the UTKFace one above.
# NOTE: this rebinds `data`. When the notebook is run top to bottom, the
# training and prediction cells below (`resnet.fit(data.dataloaders)`,
# `resnet.predict(data.dataloaders)`) therefore use MNIST. That matches the
# `in_ch=1, out_f=10` model in the Model section.
data = MNISTDS(
    dataset_folder_name="../data/MNIST/Images",
    parse_method=parse_mnist,
    model_res=(224, 224),
    batch_size=64,
    test_size=0.3,
)
mnist_dataset = data.dataset
# For ad-hoc inspection, `mnist_dataset.df` can be displayed in its own cell.

In [None]:
# Preview up to three consecutive MNIST samples from index `start`:
# print index, image array shape and label, then render each image.
to_pil_img = ToPILImage()
start = 30000
# At most three indices, clipped to the dataset length (same stopping
# behaviour as the original early `break`).
for i in range(start, min(start + 3, len(mnist_dataset))):
    sample = mnist_dataset[i]
    label = sample['label']
    print(i, np.array(sample['image']).shape, label)
    display(to_pil_img(sample['image']))

## Model

In [3]:
# ResNet configured for the MNIST data above: the rendered model shows
# conv1 taking 1 input channel, and out_f=10 presumably sets the output size.
# NOTE(review): use_gpu=True presumably requires CUDA — confirm how
# model.ResNetModel behaves on a CPU-only machine.
resnet = ResNetModel(
    in_ch=1,
    out_f=10,
    n_epochs=3,
    use_gpu=True,
)

In [6]:
# Display the wrapped network's architecture (torch module repr, shown below).
resnet.model

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Train on `data.dataloaders`. When run top to bottom, `data` is the MNIST
# wrapper. `fit` returns three equal-length series consumed by the plots
# below: x positions (batch-epoch), accuracy values and loss values.
(
    batch_epochs,
    accuracies,
    losses,
) = resnet.fit(data.dataloaders)

Training Epochs:   0%|                                                   | 0/3 [00:00<?, ?it/s]

In [None]:
# Accuracy plot
plt.plot(batch_epochs, accuracies, 'r', label='Accuracy')
plt.title('Accuracy')
plt.xlabel('Batch-Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()

In [None]:
# Loss plot
plt.plot(batch_epochs, losses, 'b', label='Loss')
plt.title('Loss')
plt.xlabel('Batch-Epoch')
plt.ylabel('Loss (%)')
plt.legend()
plt.show()

In [None]:
# Run inference with the trained model over `data.dataloaders`.
# FIXME: the author reports this call fails at runtime ("And so it dies here
# too :(") after printing "Batch[0]"; the failure inside predict still needs
# investigating.
resnet.predict(data.dataloaders)

Batch[0]
