In [1]:
import torch 
import torch.nn as nn

In [2]:
import torchvision
import torchvision.transforms as transforms
from PIL import Image

Image.MAX_IMAGE_PIXELS = None
def get_data_loader(data_dir= "data", batch_size=1, train = True):
    """
    Define the way we compose the batch dataset including the augmentation for increasing the number of data
    and return the augmented batch-dataset
    :param data_dir: root directory where the either train or test dataset is
    :param batch_size: size of the batch
    :param train: true if current phase is training, else false
    :return: augmented batch dataset
    """

    # define how we augment the data for composing the batch-dataset in train and test step
    transform = {
        'train': transforms.Compose([
            transforms.Resize([128, 128]), # Resizing the image as the VGG only take 224 x 244 as input size
            transforms.RandomHorizontalFlip(), # Flip the data horizontally
            #TODO if it is needed, add the random crop
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5), std=(0.5))
        ]),
        'test': transforms.Compose([
            transforms.Resize([128, 128]),
            transforms.RandomHorizontalFlip(),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
        ])
    }

    # ImageFloder with root directory and defined transformation methods for batch as well as data augmentation
    data = torchvision.datasets.ImageFolder(root=data_dir, transform=transform['train'] if train else 'test')
    print(len(data))
    train_size = int(0.9* len(data))
    test_size = len(data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])
    data_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    return data_loader, test_data_loader

In [3]:
data_loader, test_data_loader = get_data_loader()
print(len(data_loader))

264
237


In [4]:
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    def __init__(self, in_feature_size = 118, batch_size = 1, devices = 1):
        
        super(Discriminator, self).__init__()
        self.batch_size = batch_size
        self.devices = devices
        self.conv1 = nn.Conv2d(1, 64, 3)
        self.bn1 = nn.BatchNorm2d(64, affine=False)
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.bn2 = nn.BatchNorm2d(128, affine=False)
        self.conv3 = nn.Conv2d(128, 256, 3)
        self.bn3 = nn.BatchNorm2d(256, affine=False)
        self.conv4 = nn.Conv2d(256, 512, 3)
        self.bn4 = nn.BatchNorm2d(512, affine=False)
        self.conv5 = nn.Conv2d(512, 1, 3)
        self.fc1 = nn.Linear(in_feature_size * in_feature_size , 64)
        self.fc2 = nn.Linear(64, 8)
        self.fc3 = nn.Linear(8, 2)
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.conv5(x)
        x = x.view(self.batch_size , -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

device = "cuda"
net = Discriminator()
net = nn.DataParallel(net, device_ids=[0, 1, 2])
net.to(device)


DataParallel(
  (module): Discriminator(
    (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (bn4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (conv5): Conv2d(512, 1, kernel_size=(3, 3), stride=(1, 1))
    (fc1): Linear(in_features=13924, out_features=64, bias=True)
    (fc2): Linear(in_features=64, out_features=8, bias=True)
    (fc3): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [5]:
import wandb
wandb.init()
epochs = 50
lr = 1e-3
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
data_len = len(data_loader)
wandb.watch(net, log_freq=100)

for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(data_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        try:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if (i + 1 )% 2 == 0:    # print every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}/{data_len:5d}] loss: {running_loss :.8f}')
                wandb.log({"loss": running_loss/5})
                running_loss = 0.0
        except:
            print("Asdf")
        
    torch.save(net.state_dict(), "model-v5.ckpt")
print('Finished Training')

[34m[1mwandb[0m: Currently logged in as: [33mgradai[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[1,     2/   10] loss: 1.34688342
[1,     4/   10] loss: 1.28042191
[1,     6/   10] loss: 1.23096341
[1,     8/   10] loss: 1.29556310
Asdf
[2,     2/   10] loss: 1.21664989
[2,     4/   10] loss: 1.40906978
[2,     6/   10] loss: 1.21670103
[2,     8/   10] loss: 1.06725046
Asdf
[3,     2/   10] loss: 0.95000488
[3,     4/   10] loss: 0.89417273
[3,     6/   10] loss: 0.77991596
[3,     8/   10] loss: 0.56906950
Asdf
[4,     2/   10] loss: 0.75509337
[4,     4/   10] loss: 0.88606182
[4,     6/   10] loss: 0.95910832
[4,     8/   10] loss: 0.59957427
Asdf
[5,     2/   10] loss: 0.63418394
[5,     4/   10] loss: 0.89843658
[5,     6/   10] loss: 0.95187071
[5,     8/   10] loss: 1.02480689
Asdf
[6,     2/   10] loss: 1.15355724
[6,     4/   10] loss: 0.59867199
[6,     6/   10] loss: 0.77441412
[6,     8/   10] loss: 0.82494447
Asdf
[7,     2/   10] loss: 0.91472188
[7,     4/   10] loss: 0.90381186
[7,     6/   10] loss: 0.62808198
[7,     8/   10] loss: 0.80861896
Asdf
[8,     2/   

In [5]:
net.load_state_dict(torch.load("model-v5.ckpt"))
net.eval()
correct = 0
total = 0
for i, data in enumerate(test_data_loader, 0):
        print(i)
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients

        # forward + backward + optimize
        outputs = net(inputs)
        outputs = torch.nn.functional.softmax(outputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total: .8f} %')

0
1
2
3
4
5
6
7
8
9
10
11
12


  outputs = torch.nn.functional.softmax(outputs)


13
14
15
16
17
18
19
20
21
22
23
24
25
26
Accuracy of the network on the 10000 test images:  100.00000000 %
