In [None]:
from google.colab import files,drive
# Mount Google Drive so the MNIST image folders below /content/drive are readable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tensorflow as tf

# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
# All training below is plain PyTorch and only uses this `device`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


## Define the CNN model

In [None]:
class Net(nn.Module):
  """LeNet-style CNN that classifies single-channel digit images into 10 classes.

  Two conv + max-pool stages feed a three-layer fully connected head with
  dropout after the two hidden layers. ``forward`` returns per-class
  log-probabilities of shape ``(batch, 10)``.
  """

  def __init__(self):
    super().__init__()
    # 1x28x28 -> conv 5x5 -> 6x24x24 -> pool -> 6x12x12
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
    self.pool = nn.MaxPool2d(2, 2)   # shared 2x2 max-pool, stride 2
    # 6x12x12 -> conv 5x5 -> 16x8x8 -> pool -> 16x4x4
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5))
    # 16*4*4 flattened features: this hard-wires the input size to 28x28.
    self.fc1 = nn.Linear(16 * 4 * 4, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)     # output layer: one score per digit class
    # Heavy dropout: each hidden activation is zeroed with probability 0.7 in train mode.
    self.do1 = nn.Dropout(p=0.7, inplace=False)
    self.do2 = nn.Dropout(p=0.7, inplace=False)

  def forward(self, x):
    """Return log-probabilities of shape ``(batch, 10)``.

    Args:
      x: float tensor; with the layer sizes above it must be (batch, 1, 28, 28).
    """
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, self.num_flat_features(x))
    x = self.do1(F.relu(self.fc1(x)))
    x = self.do2(F.relu(self.fc2(x)))
    # Normalise over the class axis explicitly; calling log_softmax without
    # `dim` relies on deprecated implicit-dimension behaviour and warns.
    return F.log_softmax(self.fc3(x), dim=1)

  def num_flat_features(self, x):
    """Return the number of elements per sample, i.e. the product of all non-batch dims."""
    num_features = 1
    for s in x.size()[1:]:
      num_features *= s
    return num_features


# Build the model and move its parameters onto the selected device
# (nn.Module.to works in place, so `net` itself is updated).
net = Net()
net.to(device)
print(net)

## Define a custom `Dataset` class (overrides `__len__` and `__getitem__`)

In [None]:
# create customized dataset

import os
import glob
import numpy as np
from skimage import io

from torch.utils.data import Dataset, DataLoader

# One Dataset per digit folder, implementing the torch Dataset protocol (__len__ / __getitem__)

class MNISTDataset(Dataset):
  """Dataset over the ``*.jpg`` images in ONE digit folder.

  The folder's own name is the label, e.g. ``.../trainingset/7`` yields label 7
  for every image. Build one instance per digit and concatenate them.

  Args:
    dir: path to the digit folder (a trailing slash is accepted).
    transform: optional callable applied to each ``{'image', 'label'}`` sample dict.
    max_files: keep at most this many images (default 100); ``None`` keeps all.

  Raises:
    ValueError: if the folder name is not an integer.
  """

  def __init__(self, dir, transform=None, max_files=100):
    self.dir = dir
    self.transform = transform
    # Scan the folder once instead of on every __len__/__getitem__ call.
    # glob order is filesystem-dependent, so sort it to make both the kept
    # subset and the index -> file mapping reproducible. glob.escape guards
    # against glob metacharacters in the folder path itself.
    self.files = sorted(glob.glob(os.path.join(glob.escape(dir), '*.jpg')))[:max_files]
    # normpath strips a trailing separator so basename returns the folder name.
    self.label = int(os.path.basename(os.path.normpath(dir)).strip())

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
      idx = idx.tolist()

    # glob results already include self.dir; joining them onto self.dir again
    # would double a relative prefix ('d/d/x.jpg').
    image = io.imread(self.files[idx])
    instance = {'image': image, 'label': np.array(self.label)}

    if self.transform:
      instance = self.transform(instance)

    return instance


In [None]:
# Custom transforms applied to each dataset sample (a dict with 'image' and 'label')

from skimage import transform
from torchvision import transforms, utils

class Rescale(object):
  """Resize a sample's image.

  Args:
    output_size: an ``int`` scales the SHORTER side to that length and keeps the
      aspect ratio; a ``(height, width)`` tuple sets the size exactly.

  Called with a dict ``{'image': ndarray, 'label': ...}``; returns a new dict
  with the image resized via ``skimage.transform.resize`` and the label unchanged.
  """

  def __init__(self, output_size):
    assert isinstance(output_size, (int, tuple))
    self.output_size = output_size

  def __call__(self, sample):
    image, label = sample['image'], sample['label']

    # The spatial dims are the first two axes for both grayscale (H, W) and
    # channel-last (H, W, C) images; shape[-2:] would pick (W, C) for the latter.
    h, w = image.shape[:2]
    if isinstance(self.output_size, int):
      if h > w:
        new_h, new_w = self.output_size * h / w, self.output_size
      else:
        new_h, new_w = self.output_size, self.output_size * w / h
    else:
      new_h, new_w = self.output_size

    new_h, new_w = int(new_h), int(new_w)

    new_image = transform.resize(image, (new_h, new_w))

    return {'image': new_image, 'label': label}


class ToTensor(object):
  """Turn a sample's numpy image and label into torch tensors.

  The 2-D image (H, W) is reshaped to (1, H, W) so it has the single channel
  axis that Conv2d expects; the label array is converted as-is.
  """

  def __call__(self, sample):
    img = sample['image']
    rows, cols = img.shape[0], img.shape[1]
    chw_image = img.reshape((1, rows, cols))
    return {
        'image': torch.from_numpy(chw_image),
        'label': torch.from_numpy(sample['label']),
    }

## Load each MNIST digit folder and concatenate them into one dataset

In [None]:
# Build one dataset per digit folder and concatenate them
from keras.datasets import mnist
from torch.utils.data import random_split
from torchvision import transforms,utils

batch_size = 32
# Root folder with one sub-folder per digit ('0' ... '9'); MNISTDataset uses
# each sub-folder's name as the label.
DATA_ROOT = '/content/drive/My Drive/MNIST/trainingset'

# The transform pipeline holds no per-folder state, so one instance is shared.
sample_transform = transforms.Compose([Rescale(28), ToTensor()])

list_datasets = [
    MNISTDataset(os.path.join(DATA_ROOT, str(digit)), transform=sample_transform)
    for digit in range(10)
]

dataset = torch.utils.data.ConcatDataset(list_datasets)
print(len(dataset))
# Fail fast with a clear message (e.g. Drive not mounted or wrong path) instead
# of an obscure DataLoader error in the next cell.
assert len(dataset) > 0, 'no *.jpg images found under %s' % DATA_ROOT

## Split into train/validation sets and create the DataLoaders


In [None]:
# 70/30 train/validation split. A seeded generator makes the split reproducible
# across kernel restarts (without it random_split draws from the global RNG).
SPLIT_SEED = 42
train_size = int(len(dataset) * 0.7)
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Sample order does not affect the validation metrics, so keep it deterministic.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

## Training and validation loop

In [None]:
#training

epochs = 10
lr = 1e-3
num_classes = 10  # digits 0-9; must match Net's output layer (fc3)
log_every = 10    # report the mean training loss every `log_every` batches

# parameters() takes a boolean `recurse` flag, not a count; the default is what we want.
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-5)
# Net.forward already returns log-probabilities (log_softmax), so NLLLoss is the
# matching criterion; CrossEntropyLoss would apply a second, redundant log_softmax.
criterion = nn.NLLLoss()

for epoch in range(epochs):
  # The validation pass at the end of each epoch switches the model to eval
  # mode (dropout off); switch back so dropout is active in every training epoch.
  net.train()
  # Reset per epoch so the first report only averages this epoch's batches.
  running_loss = 0.0

  for batch_idx, batch in enumerate(train_dataloader):
    inputs = batch['image'].to(device, dtype=torch.float)
    targets = batch['label'].to(device, dtype=torch.long)

    optimizer.zero_grad()
    predicted_output = net(inputs)
    loss = criterion(predicted_output, targets)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if (batch_idx + 1) % log_every == 0:
      print('epoch %d batch %d, training_loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / log_every))
      running_loss = 0.0

  #validation
  net.eval()

  # Per-digit counts of correct predictions and of validation samples.
  correct = torch.zeros(num_classes, dtype=torch.long, device=device)
  total = torch.zeros(num_classes, dtype=torch.long, device=device)

  with torch.no_grad():
    for batch in val_dataloader:
      images = batch['image'].to(device, dtype=torch.float)
      labels = batch['label'].to(device, dtype=torch.long)
      predicted_labels = net(images).argmax(dim=1)
      hits = predicted_labels == labels
      # Vectorized per-class tally instead of a Python loop over every sample.
      total += torch.bincount(labels, minlength=num_classes)
      correct += torch.bincount(labels[hits], minlength=num_classes)

  for i in range(num_classes):
    n_samples = total[i].item()
    if n_samples == 0:
      # A digit absent from the validation split would otherwise divide by zero.
      print('\t Validation accuracy for digit %d: n/a (no samples)' % i)
    else:
      print('\t Validation accuracy for digit %d: %.2f' % (i, 100 * correct[i].item() / n_samples))