In [None]:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

import torch
from torch.utils.data import Dataset, DataLoader
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

In [None]:
class CustomDataset(Dataset):
    """Image dataset read from a class-per-folder directory tree.

    Files are read from
    ``<base_dir>/<name>/<Train|Test>/<class_idx>/Image<NNNNN>.png`` where
    ``class_idx`` runs over ``0..n_classes-1`` and ``NNNNN`` (zero-padded to
    5 digits) over ``0..n_imgs_per_class-1``. Counts are fixed per dataset
    name and split: 'Cells' has 4 classes with 500 train / 200 test images
    each; any other name is treated as 10 classes with 1000 train / 100 test
    images each.

    Each item is a dict ``{'image': FloatTensor, 'label': int}``: the image
    gets a leading channel axis and is divided by 255.
    NOTE(review): this presumably expects 2-D, 8-bit grayscale PNGs (the
    unsqueeze adds a single channel axis, /255 assumes uint8) -- confirm.

    Args:
        name: dataset folder name under ``base_dir`` (e.g. 'Cells').
        phase: 'train' or 'test'.
        base_dir: directory that contains the dataset folders. Defaults to the
            original author's machine-specific path; override it on other
            machines.

    Raises:
        ValueError: if ``phase`` is not 'train' or 'test'.
    """

    def __init__(self, name, phase,
                 base_dir='/home/cicconet/Development/MachineLearning'):
        # Fail fast: previously an unknown phase left root_path and
        # n_imgs_per_class unset and only crashed later in __len__/__getitem__.
        if phase not in ('train', 'test'):
            raise ValueError("phase must be 'train' or 'test', got %r" % (phase,))
        is_cells = name == 'Cells'
        self.n_classes = 4 if is_cells else 10
        if phase == 'train':
            self.n_imgs_per_class = 500 if is_cells else 1000
            split_dir = 'Train'
        else:
            self.n_imgs_per_class = 200 if is_cells else 100
            split_dir = 'Test'
        self.root_path = '%s/%s/%s' % (base_dir, name, split_dir)

    def __len__(self):
        return self.n_classes * self.n_imgs_per_class

    def _locate(self, idx):
        """Map a flat dataset index to ``(class_idx, image_idx)``.

        Accepts Python, numpy or 0-d torch integers. Raises IndexError for
        indices outside ``[0, len(self))`` so bad indices (including negative
        ones) never turn into nonexistent file names, and so plain iteration
        over the dataset terminates.
        """
        idx = int(idx)
        if not 0 <= idx < len(self):
            raise IndexError('index %d out of range for dataset of size %d' % (idx, len(self)))
        return divmod(idx, self.n_imgs_per_class)

    def __getitem__(self, idx):
        folder_idx, img_idx = self._locate(idx)
        img_path = self.root_path + '/{}/Image{:05d}.png'.format(folder_idx, img_idx)

        image = io.imread(img_path)
        # add a leading channel axis and scale to [0, 1] (assuming 8-bit input)
        image = torch.unsqueeze(torch.from_numpy(image), 0).float() / 255
        return {'image': image, 'label': folder_idx}

In [None]:
ds_name = 'Cells'
train_dataset, test_dataset = CustomDataset(ds_name, 'train'), CustomDataset(ds_name, 'test')

# Spot-check 10 random training samples: index, tensor shape, dtype,
# max intensity (expected <= 1 if the PNGs are 8-bit) and class label.
rand_i = np.random.randint(len(train_dataset), size=10)
for sample_idx in rand_i:
    sample = train_dataset[sample_idx]
    image, label = sample['image'], sample['label']
    print(sample_idx, image.shape, image.dtype, image.max(), label)

In [None]:
BATCH_SIZE = 4
NUM_WORKERS = 4

# Only the training split is shuffled: evaluation order does not change the
# accuracy, and a fixed test order makes runs reproducible and easier to debug.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# Show the first image of each of the first N_PREVIEW training batches,
# titled with its label, and print its max intensity.
N_PREVIEW = 5
# squeeze=False keeps `axes` a 2-D array even when N_PREVIEW == 1
fig, axes = plt.subplots(1, N_PREVIEW, squeeze=False)
# zip stops after N_PREVIEW batches (or earlier if the loader is shorter)
for idx, (ax, batch) in enumerate(zip(axes[0], train_loader)):
    imgs, lbls = batch['image'], batch['label']
    first_img = imgs[0].numpy().squeeze()
    print(idx, first_img.max(), lbls[0])
    ax.imshow(first_img, cmap='gray')  # single-channel data: avoid the default colour map
    ax.set_title(str(lbls[0].item()))
    ax.axis('off')
fig.tight_layout()  # once, after all panels exist

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """LeNet-style CNN: two conv(5x5)+ReLU+maxpool(2x2) stages, then three linear layers.

    Adapted from the PyTorch CIFAR-10 tutorial linked at the top of the
    notebook. fc1 expects 16 feature maps of 4x4, so the input's spatial size
    must reduce to 4x4 after the two conv/pool stages (e.g. 28x28 images).

    Args:
        n_classes: number of output logits. Defaults to 10 (previous behaviour);
            pass the dataset's class count (e.g. 4 for 'Cells') to match it.
        in_channels: number of input image channels (default 1).
    """

    def __init__(self, n_classes=10, in_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)  # n chan in, n chan out, kernel size
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, n_classes)

    def forward(self, x):
        """Map a batched (N, in_channels, H, W) tensor to (N, n_classes) raw logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 16*4*4), this never silently
        # merges samples when H, W are not the expected size: fc1 raises instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
# Sanity check: push one training batch through a fresh network and print the
# output shape. Iterators returned by DataLoader have no `.next()` method in
# current PyTorch, so use the builtin next().
batch = next(iter(train_loader))

imgs, lbls = batch['image'], batch['label']
print(imgs.dtype, imgs.shape, lbls)

net = Net()
with torch.no_grad():  # shape check only; no autograd graph needed
    print(net(imgs).shape)

In [None]:
LEARNING_RATE = 0.001
MOMENTUM = 0.9

# Cross-entropy on the raw logits returned by Net (no softmax in forward),
# minimised with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

In [None]:
N_EPOCHS = 20
LOG_EVERY = 500  # batches between loss reports

# No-op for the current Net (no dropout/BatchNorm), but keeps this cell
# correct if such layers are added.
net.train()
for epoch in range(N_EPOCHS):
    running_loss = 0.0  # reset so a partial window never leaks across epochs
    for i, batch in enumerate(train_loader):
        imgs, lbls = batch['image'], batch['label']

        optimizer.zero_grad()
        pred = net(imgs)
        loss = criterion(pred, lbls)
        loss.backward()
        optimizer.step()

        # Report the mean loss over the last LOG_EVERY batches. The previous
        # 0.5/0.5 EMA effectively reflected only the last ~2 batches.
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('epoch', epoch + 1, 'batch', i + 1, 'loss', running_loss / LOG_EVERY)
            running_loss = 0.0

In [None]:
# Top-1 accuracy on the test split.
net.eval()  # no-op for the current Net, but correct if dropout/BatchNorm are added
correct = 0
total = 0
with torch.no_grad():  # inference only: skip building autograd graphs
    for batch in test_loader:
        imgs, lbls = batch['image'], batch['label']
        predicted = net(imgs).argmax(dim=1)  # index of the largest logit per sample
        total += len(lbls)
        correct += (predicted == lbls).sum().item()

assert total > 0, 'test_loader yielded no samples'
print('test accuracy', correct / total)