In [None]:
# future
from __future__ import print_function

In [None]:
# standard library
import os

# third party
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets
from torchvision import transforms

In [None]:
# let's prepare parameters
class Args():
    """Hyper-parameters and run options for this notebook.

    Every option has a default. For quick experiments pass keyword overrides instead of
    editing the class, e.g. ``Args(epochs=3, dry_run=True)``. Unknown option names raise
    ``TypeError`` so that typos are not silently ignored.
    """

    def __init__(self, **overrides):
        super(Args, self).__init__()
        self.batch_size = 64          # training mini-batch size
        self.epochs = 14              # number of passes over the training data
        self.lr = 1.0                 # Adadelta learning rate
        self.gamma = 0.7              # StepLR factor applied to the learning rate each epoch
        self.no_cuda = False          # request CPU-only training (only effective where the device cell reads it)
        self.dry_run = False          # stop each epoch's training after the first logged batch
        self.seed = 42                # passed to torch.manual_seed
        self.log_interval = 100       # number of batches between training log lines
        self.save_model = True        # save the trained state_dict to mnist_cnn.pt
        self.test_batch_size = 1000   # evaluation mini-batch size

        # only allow overriding options that exist, so a typo such as `epoch=3` fails loudly
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError("unknown option(s): {}".format(", ".join(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)


args = Args()

# check it
args.test_batch_size

In [None]:
# Use CUDA only when it is both requested (args.no_cuda is False) and actually available,
# so the notebook also runs on CPU-only machines instead of failing on the first .to(device).
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")
device


## datasets

In [None]:
# ToTensor scales pixel values to [0, 1]; Normalize then standardizes them with
# mean 0.1307 and std 0.3081 (the values commonly used for MNIST).
mnist_steps = [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
transform = transforms.Compose(mnist_steps)

# training split; download=True fetches the files into ../data on first use
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
# test split; presumably already on disk after the download triggered above
dataset2 = datasets.MNIST("../data", train=False, transform=transform)


In [None]:
# number of samples in the training and test splits
len(dataset1), len(dataset2)

In [None]:
# add some other params for dataloaders
# Shuffle the training set on every device (previously this only happened when CUDA was
# used); the test set needs no shuffling. On CUDA, additionally load batches in a worker
# process and use pinned host memory to speed up host-to-GPU copies.

train_kwargs = {"batch_size": args.batch_size, "shuffle": True}
test_kwargs = {"batch_size": args.test_batch_size}
if use_cuda:
    cuda_kwargs = {"num_workers": 1, "pin_memory": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

train_kwargs


In [None]:
# prepare data loaders: batched iterators over the two splits, configured by the
# keyword options assembled in the previous cell

train_loader = torch.utils.data.DataLoader(dataset=dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=dataset2, **test_kwargs)

## architecture

In [None]:
# architecture

class Net(nn.Module):
    """Small CNN digit classifier.

    conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool -> dropout(0.25)
    -> flatten -> linear(9216->128) -> ReLU -> dropout(0.5) -> linear(128->10) -> log_softmax

    ``fc1``'s 9216 inputs equal 64 * 12 * 12: the feature map of a single-channel 28x28
    input after two unpadded 3x3 convolutions (28 -> 26 -> 24) and 2x2 pooling (24 -> 12).
    The layer sizes therefore assume 1x28x28 images.
    """

    def __init__(self):
        super(Net, self).__init__()
        # attribute names (and creation order) determine the state_dict keys used by the
        # saved checkpoint, so keep them stable
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities of shape (batch, 10)."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)


In [None]:
# Network on the selected device, an Adadelta optimizer over its parameters, and a
# scheduler that multiplies the learning rate by args.gamma on every scheduler.step()
# (the training loop below steps it once per epoch).
model = Net().to(device)
optimizer = optim.Adadelta(params=model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=args.gamma)


## training loop

In [None]:
# training loop
# Each epoch: one optimisation pass over train_loader, a full evaluation on test_loader,
# then one learning-rate decay step. (The same logic is packaged as the `train` / `test`
# functions further down this notebook.)

for epoch in range(1, args.epochs + 1):

    # --- training ---
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        # the model outputs log-probabilities, which is what nll_loss expects
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval != 0:
            continue
        print(
            "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item(),
            )
        )
        # dry run: stop the epoch after the first logged batch
        if args.dry_run:
            break

    # --- validation ---
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # summed (not averaged) per batch so the division below gives a per-sample mean
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # predicted class = index of the largest log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_test = len(test_loader.dataset)
    test_loss /= n_test

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, n_test, 100.0 * correct / n_test
        )
    )

    # decay the learning rate (StepLR with step_size=1 -> once per epoch)
    scheduler.step()
    
    
    

In [None]:
# save model
# Only the learned parameters (state_dict) are written; they can be restored with
# Net().load_state_dict(torch.load("mnist_cnn.pt")).
checkpoint_path = "mnist_cnn.pt"
if args.save_model:
    torch.save(model.state_dict(), checkpoint_path)
    

In [None]:
# TODO: load the saved weights (mnist_cnn.pt) into a fresh Net and use it for inference

## How to train a model on your own data

In [None]:
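# Data: MNIST in PNG format from the repository below, laid out as
# <root>/training/<digit>/<file>.png (the digit folder name is the label).
# https://github.com/myleott/mnist_png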

In [None]:
# Local copy of walk_files (was: from torchaudio.datasets.utils import walk_files),
# fixed so that remove_suffix strips the suffix that actually matched.

from typing import Any, Iterable, List, Optional, Tuple, Union

def walk_files(root: str,
               suffix: Union[str, Tuple[str, ...]],
               prefix: bool = False,
               remove_suffix: bool = False) -> Iterable[str]:
    """List recursively all files ending with a suffix at a given root.

    Sub-directories and files are visited in alphabetical order, so the output order
    is deterministic.

    Args:
        root (str): Path to directory whose folders need to be listed (``~`` is expanded).
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            Matching uses the Python "str.endswith" method.
        prefix (bool, optional): If true, prepends the full path to each result, otherwise
            only returns the name of the files found (Default: ``False``)
        remove_suffix (bool, optional): If true, removes the matched suffix from each result,
            otherwise will return the result as found (Default: ``False``).

    Yields:
        str: the matching file names (or full paths when ``prefix`` is true).
    """

    root = os.path.expanduser(root)
    # Normalise to a tuple so the *matched* suffix can be stripped. The previous code
    # stripped len(suffix) characters, which for a tuple is the number of suffixes.
    suffixes = (suffix,) if isinstance(suffix, str) else tuple(suffix)

    for dirpath, dirs, files in os.walk(root):
        # `dirs` is the list used in os.walk function and by sorting it in-place here, we change the
        # behavior of os.walk to traverse sub directory alphabetically
        # see also
        # https://stackoverflow.com/questions/6670029/can-i-force-python3s-os-walk-to-visit-directories-in-alphabetical-order-how#comment71993866_6670926
        dirs.sort()
        files.sort()
        for f in files:
            # when several suffixes match (e.g. '.tar.gz' and '.gz') the longest wins
            matched = max((s for s in suffixes if f.endswith(s)), key=len, default=None)
            if matched is None:
                continue

            # an empty match must not be stripped: f[:-0] would return '' instead of f
            if remove_suffix and matched:
                f = f[: -len(matched)]

            if prefix:
                f = os.path.join(dirpath, f)

            yield f


In [None]:
import os

# Root of the PNG copy of MNIST. Set the MNIST_PNG_ROOT environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
MNIST_PNG_ROOT = os.environ.get("MNIST_PNG_ROOT", "/disk2/data/mnist_png/mnist_png")

walker = walk_files(
    os.path.join(MNIST_PNG_ROOT, "training"),
    suffix=".png",  # match the extension itself, not any name ending in "png"
    prefix=True,
    remove_suffix=False,
)
_walker = list(walker)
# fail here with a clear message rather than later inside the DataLoader
assert _walker, "no .png files found under {!r}".format(os.path.join(MNIST_PNG_ROOT, "training"))

In [None]:
# _walker

In [None]:
from torch.utils.data import Dataset

In [None]:
# check datasets1: load and transform the first sample once, then show (image shape, label)
mnist_img, mnist_label = dataset1[0]
mnist_img.shape, mnist_label

In [None]:
from PIL import Image

# transform = transforms.Compose(
#     [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
# )

class MyDataset(Dataset):
    """Map-style dataset over image files laid out as ``.../<label>/<name>.png``.

    The integer label of each sample is the name of the file's parent directory,
    e.g. ``.../training/7/123.png`` -> ``7``.

    Args:
        data_list: sequence of image file paths (e.g. ``walk_files(..., prefix=True)``).
        transform: optional callable applied to the grayscale PIL image. When omitted,
            ``ToTensor()`` followed by ``Normalize((0.1307,), (0.3081,))`` is applied,
            the same preprocessing used for the torchvision MNIST datasets in this notebook.
    """

    def __init__(self, data_list, transform=None):
        super(MyDataset, self).__init__()
        self.data_list = data_list
        self.toTensor = transforms.ToTensor()
        self.normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.transform = transform

    def __getitem__(self, index):
        path = self.data_list[index]

        # the parent directory name is the label; os.path keeps this portable
        # (splitting on "/" breaks on Windows-style paths)
        label = int(os.path.basename(os.path.dirname(path)))

        # `with` closes the file handle. convert("L") forces one channel to match the
        # model's 1-channel input, and it also loads the pixels before the file closes.
        with Image.open(path) as raw:
            img = raw.convert("L")

        if self.transform is not None:
            return self.transform(img), label
        return self.normalize(self.toTensor(img)), label

    def __len__(self):
        return len(self.data_list)


In [None]:
# one sample per PNG path collected by walk_files
my_dataset = MyDataset(data_list=_walker)

In [None]:
# number of PNG files found; compare with len(dataset1) for the torchvision training split
len(my_dataset)

In [None]:
# load and transform the first sample once; the shape is expected to match dataset1's sample
my_img, my_label = my_dataset[0]
my_img.shape, my_label

In [None]:
# batch the PNG dataset with the same loader options as the torchvision training split
my_train_loader = torch.utils.data.DataLoader(dataset=my_dataset, **train_kwargs)

In [None]:
# train block
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Args:
        args: options object; reads ``log_interval`` (batches between log lines) and
            ``dry_run`` (stop after the first logged batch).
        model: module returning log-probabilities, since the loss is ``F.nll_loss``.
        device: device the batches are moved to.
        train_loader: iterable of ``(inputs, labels)`` batches with a ``.dataset`` attribute.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.
    """
    model.train()

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % args.log_interval != 0:
            continue

        seen = step * len(inputs)
        progress = 100.0 * step / len(train_loader)
        print(
            "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, seen, len(train_loader.dataset), progress, loss.item()
            )
        )
        if args.dry_run:
            break
                
                

In [None]:
# test block
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )



In [None]:
# training loop on MY DATA
# Start from a freshly initialised network with its own optimizer and scheduler. Reusing
# `model` / `optimizer` / `scheduler` from above would continue from the already-trained
# weights, with a learning rate that StepLR has already decayed args.epochs times.
# Evaluation still uses the torchvision MNIST test split (test_loader).
my_model = Net().to(device)
my_optimizer = optim.Adadelta(my_model.parameters(), lr=args.lr)
my_scheduler = StepLR(my_optimizer, step_size=1, gamma=args.gamma)

for epoch in range(1, args.epochs + 1):
    train(args, my_model, device, my_train_loader, my_optimizer, epoch)
    test(my_model, device, test_loader)
    my_scheduler.step()
