In [1]:
#import relavant libraries / dependencies
import numpy as np # 
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
import copy
from torch import optim
import tqdm
from PIL import Image

class CatDogDataset(Dataset):
  def __init__(self, directory, transform = None):
    super().__init__()
    self.directory = directory
    self.transform = transform
    self.files = os.listdir(directory)
  def __len__(self):
    return len(self.files)
  def __getitem__(self, idx):
    filename = self.files[idx]
    if 'dog' in filename.lower():
      label = 1
    else:
      label = 0
    img = Image.open(os.path.join(self.directory, filename))
    if self.transform is not None:
      img = self.transform(img)
    return img, label
    

train_transforms = transforms.Compose([transforms.Resize(64),
                                       transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(48),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])

test_transforms = transforms.Compose([transforms.Resize(64),
                                      transforms.CenterCrop(48),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

data_dir = 'splitted'
# Pass transforms in here, then run the next cell to see how the transforms look
train_data = CatDogDataset(data_dir + '/train', transform=train_transforms)
test_data = CatDogDataset(data_dir + '/valid', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding = 1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding = 1)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding = 1)
        self.pool3 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(6 * 6 * 128, 256)
        self.fc2 = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        x = self.pool1(self.conv1(inputs))
        x = self.pool2(self.conv2(x))
        x = self.pool3(self.conv3(x))

        x = x.view(-1, 6 * 6 * 128)
        x = self.fc1(x)
        x = self.fc2(x)

        x = self.sigmoid(x)#
        return x


model = SimpleModel()

optimizer = optim.Adam(model.parameters(), lr=0.003)
criterion = nn.BCELoss()

epochs = 2
itr = 1
p_itr = 500
model.train()
total_loss = 0
loss_list = []
acc_list = []
for epoch in range(epochs):
    for samples, labels in trainloader:
        optimizer.zero_grad()
        output = model(samples)
        labels = labels.float().view(-1, 1)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
        
        if itr%p_itr == 0:
            pred = torch.round(output) 
            label = labels.long()
            correct = pred.eq(labels)
            acc = torch.mean(correct.float())
            print('[Epoch {}/{}] Iteration {} -> Train Loss: {:.4f}, Accuracy: {:.3f}'.format(epoch+1, epochs, itr, total_loss/p_itr, acc))
            loss_list.append(total_loss/p_itr)
            acc_list.append(acc)
            total_loss = 0
            
        itr += 1

plt.plot(loss_list, label='loss')
plt.plot(acc_list, label='accuracy')
plt.legend()
plt.title('training loss and accuracy')
plt.show()
        

FileNotFoundError: [Errno 2] No such file or directory: 'splitted/train'