### Importing Libraries

In [1]:
import numpy as np
import pandas as pd

import torch.nn as nn 
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from itertools import product

### Creating Pytorch Dataset

In [2]:
rootdir = "flowers"
flower_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
flower_dataset = torchvision.datasets.ImageFolder(rootdir, transform=flower_transform)

### Define Parameters

In [3]:
batch_size = 8
valid_split = 0.2
shuffle_dataset = True
random_seed = 42
flower_dataset.classes

['daisy', 'dandelion', 'flowers', 'rose', 'sunflower', 'tulip']

### Define Trainloader to create batches of input

In [4]:
dataset_size = len(flower_dataset)
indices = list(range(dataset_size))
split = int(np.floor(valid_split*dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    
train_indices, val_indices = indices[split:],indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(flower_dataset, batch_size=batch_size, sampler = train_sampler)
valid_loader = DataLoader(flower_dataset, batch_size=batch_size, sampler = valid_sampler)

### Display Image Data

In [5]:
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt

train_batch, label_train = next(iter(train_loader))
valid_batch, label_valid = next(iter(valid_loader))

def img_plotter(batch,rows=8,cols=8):
    fig,axs = plt.subplots(nrows=rows,ncols=cols,figsize=(30,18))
    for i in range(rows):
        for j in range(cols):
            axs[i,j].imshow(batch[rows*j+i].permute(1,2,0))

In [6]:
# img_plotter(train_batch, rows=2, cols=4)

In [7]:
# img_plotter(valid_batch,rows=2,cols=4)

In [8]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [23]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=7, padding=3)
        self.conv3 = nn.Conv2d(in_channels=9, out_channels=12, kernel_size=7, padding=3)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=15, kernel_size=5, padding=2)
        self.conv5 = nn.Conv2d(in_channels=15, out_channels=18, kernel_size=5, padding=2)
        self.conv6 = nn.Conv2d(in_channels=18, out_channels=24, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=24*7*7, out_features=100)
        self.fc2 = nn.Linear(in_features=100, out_features=25)
        self.out = nn.Linear(in_features=25, out_features=6)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), kernel_size=3, stride=3)  # Layer 1
        print(x.shape)
        x = F.max_pool2d(F.relu(self.conv4(F.relu(self.conv3(x)))), kernel_size=3, stride=3)  # Layer 2
        print(x.shape)
        x = F.max_pool2d(F.relu(self.conv6(F.relu(self.conv5(x)))), kernel_size=2, stride=2)  # Layer 3
        print(x.shape)
        x = x.reshape(-1, 24*7*7)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.out(x)
        
        return x

In [24]:
parameters = dict(
batch_size = [100],
lr = [0.01],
shuffle = [True]
)
param_values = [v for v in parameters.values()]

for batch_size, lr, shuffle in product(*param_values):
    network = Network().to('cuda')
#     for n,p in network.named_parameters(): 
#         print(p.device)
    train_loader = DataLoader(flower_dataset, batch_size=batch_size, sampler = train_sampler)
    optimizer = optim.Adam(network.parameters(), lr=lr)

    comment = f'FlowerNet batch_size={batch_size} lr={lr} shuffle={shuffle}'
    tb = SummaryWriter(comment=comment)

    for epoch in range(5):
        total_loss=0
        total_correct=0

        for batch in train_loader:
            images = batch[0].to('cuda')
            labels = batch[1].to('cuda')
            preds = network(images)

            loss = F.cross_entropy(preds, labels)
            optimizer.zero_grad()

            loss.backward()
            optimizer.step()

            total_loss+=loss.item() * batch_size
            total_correct+=get_num_correct(preds, labels)

        tb.add_scalar('Loss', total_loss, epoch)
        tb.add_scalar('Number Correct', total_correct, epoch)
        tb.add_scalar('Accuracy', total_correct/dataset_size, epoch)

        tb.add_histogram('conv1.bias', network.conv1.bias, epoch)
        tb.add_histogram('conv1.weight', network.conv1.weight, epoch)
        tb.add_histogram('conv1.weight.grad', network.conv1.weight.grad, epoch)

        print("epoch:", epoch, "total_correct:", total_correct, "loss:", total_loss, "Accuracy:", total_correct/dataset_size)

    tb.close()

torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([10

torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([10

torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([10

torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([100, 15, 14, 14])
torch.Size([100, 24, 7, 7])
torch.Size([100, 9, 42, 42])
torch.Size([10