In [1]:
import torch
from torch import nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils, datasets, models

from PIL import Image
import pandas as pd
from datetime import datetime

from data_preporation import build_dataset

#### Build dataset

In [2]:
dataset = build_dataset("idao_dataset/train/")

# 80% is train, 20% is test
train, test = random_split(
    dataset,
    [8*len(dataset)//10, len(dataset)-8*len(dataset)//10]
    # torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train, batch_size=10, shuffle=True, num_workers=16)
test_loader = DataLoader(test, batch_size=10, shuffle=True, num_workers=16)

In [3]:
print(
    f"""\
dataset: {len(dataset)}
train: {len(train)}
test: {len(test)}
"""
)

dataset: 13404
train: 10723
test: 2681



#### Build the model

In [4]:
class Meow(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.convolution1 = nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1, out_channels=8,
                kernel_size=3, stride=1,
                padding=0, bias=False
            ),
            
            nn.BatchNorm2d(8),
            nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        
        self.convolution2 = nn.Sequential(
            torch.nn.Conv2d(
                in_channels=8, out_channels=16,
                kernel_size=3, stride=1,
                padding=0, bias=False
            ),
            nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        
        self.convolution3 = nn.Sequential(
            torch.nn.Conv2d(
                in_channels=16, out_channels=32,
                kernel_size=3, stride=1,
                padding=0, bias=False
            ),
            nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        
        self.convolution4 = nn.Sequential(
            torch.nn.Conv2d(
                in_channels=32, out_channels=64,
                kernel_size=3, stride=1,
                padding=0, bias=False
            ),
            nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        
        self.convolution5 = nn.Sequential(
            torch.nn.Conv2d(
                in_channels=64, out_channels=128,
                kernel_size=3, stride=1,
                padding=0, bias=False
            ),
            nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3),
            nn.BatchNorm2d(128)
        )
        
        
        self.ff = nn.Sequential(
            nn.Linear(64*6*6, 64),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(64, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 1)
        )
        
        
    def forward(self, x):
        x = self.convolution1(x)
        x = self.convolution2(x)
        x = self.convolution3(x)
        x = self.convolution4(x)
        
        x = x.view(-1, 64*6*6)
        x = self.ff(x)
        x = torch.sigmoid(x)
        
        return x
        
    
    def predict(self):
        pass

In [5]:
meow = Meow()
if torch.cuda.is_available():
    meow.cuda()

In [6]:
loss_function = nn.BCELoss()
optimizer = optim.Adam(meow.parameters(), lr=1e-3)

In [7]:
train_losses, test_losses = [], []
epoches = 1
printing_gap = 100

In [9]:
crepr = lambda tsr: " ".join([str(round(i[0], 5)) for i in tsr.tolist()])

In [10]:
%%time
morning = datetime.now()
for epoch in range(epoches):  # loop over the dataset multiple times
    print(f"Epoch: {epoch}")
    running_loss = 0
    meow.train()
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        images, labels = data
        labels = labels.unsqueeze(1)
        
        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()
        
        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = meow(images) # forward
        loss = loss_function(outputs.float(), labels.float())
        loss.backward() # backward
        optimizer.step() # optimize

        # print statistics
        running_loss += loss.item()
        if i % printing_gap == 0 and i != 0:
            print(f"sample: {i}, loss: {running_loss / (i+1e-3)}, now: {str(datetime.now()-morning)}, last: {crepr(outputs)}")

# Validate

rights = torch.tensor(0)

with torch.no_grad():
    for i, data in enumerate(test_loader, 0):
        images, labels = data
        labels = labels.unsqueeze(1)
        
        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()
        
        # zero the parameter gradients
        optimizer.zero_grad()
        outputs = meow(images) # forward
        
        outputs = outputs > 0.5
        outputs = outputs.float()
        
        rights += (outputs == labels).int().sum()
        
        if i % printing_gap == 0:
            print(f"sample: {i}, rights: {rights}, now: {str(datetime.now()-morning)}")

print("Accuracy: ", round(int(rights) / (len(test_loader)*10), 5))

Epoch: 0
sample: 100, loss: 0.5596102313985062, now: 0:00:43.602865, last: 0.23742 0.18219 0.62197 0.59299 0.63011 0.62747 0.50099 0.61645 0.51993 0.56204
sample: 200, loss: 0.3865673248446338, now: 0:01:23.812954, last: 0.01269 0.9509 0.0 0.99755 0.0 0.98713 2e-05 0.00041 0.0 0.0
sample: 300, loss: 0.2821921862087341, now: 0:02:03.969084, last: 0.98869 0.99544 0.99955 0.0 0.9853 0.9976 0.0 0.99913 0.99825 0.00214
sample: 400, loss: 0.22863476563032392, now: 0:02:45.259113, last: 0.97222 0.98271 0.0 0.99836 0.0 0.0 0.99256 0.0 0.94916 0.0
sample: 500, loss: 0.1935070155001434, now: 0:03:25.602689, last: 0.0 0.0 0.0 0.0 0.98536 0.0 0.0 0.96402 0.0 0.0
sample: 600, loss: 0.18025300128012822, now: 0:04:06.760652, last: 0.0 0.00124 0.96423 0.98413 4e-05 0.9891 0.99155 0.0 0.99404 0.0
sample: 700, loss: 0.1598261215233622, now: 0:04:48.014422, last: 0.01132 0.99278 0.0 0.0 0.98634 0.0 0.9998 0.99499 0.0 0.9878
sample: 800, loss: 0.14408442061284077, now: 0:05:28.690889, last: 0.98576 0.0044