In [1]:
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

from sklearn.metrics import accuracy_score

import PIL

In [19]:
IMAGE_SIZE = 100

EPOCHES = 20
BATCH_SIZE = 512
LEARNING_RATE = 0.0001
FILTER = 3

PATH_TRAIN = 'Train_data'
PATH_TESTS = 'Tests_data'

In [21]:
class CNN(nn.Module):
    def __init__(self, filter, kernel_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, filter * 4, kernel_size=kernel_size, padding=(kernel_size // 2)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(filter * 4, filter * 8, kernel_size=kernel_size, padding=(kernel_size // 2)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(filter * 8, filter * 16, kernel_size=kernel_size, padding=(kernel_size // 2)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(filter * 16, filter * 32, kernel_size=kernel_size, padding=(kernel_size // 2)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            
            nn.Flatten(), 
            nn.Linear(1152, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 131))
        
    def forward(self, xb):
        return self.network(xb)

In [22]:
def random_split_ratio(dataset, test_size=.2, random_state=None):
    second_part = int(len(dataset) * test_size)
    first_part = int(len(dataset)) - second_part

    if random_state:
        first_split, second_split = random_split(dataset, lengths=[first_part, second_part], 
                                                 generator=torch.Generator().manual_seed(random_state))
    else:
        first_split, second_split = random_split(dataset, lengths=[first_part, second_part])

    return first_split, second_split

In [23]:
def verify_image(fp):
    try:
        PIL.Image.open(fp).verify()
        return True
    except:
        return False

In [24]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f'CNN training on {device} type')

mps


In [25]:
model = CNN(8, 3)

In [28]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=120),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor()
])

In [29]:
dataset = ImageFolder(root=PATH_TRAIN, transform=transform, is_valid_file=verify_image)

train_dataset, valid_dataset = random_split_ratio(dataset, random_state=42)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)


tests_dataset = ImageFolder(root=PATH_TESTS, transform=transform, is_valid_file=verify_image)
tests_loader = DataLoader(tests_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f'Train data length: {len(train_loader.dataset)}\n'
      f'Valid data length: {len(valid_loader.dataset)}\n'
      f'Tests data length: {len(tests_loader.dataset)}')


Train data length: 54154
Valid data length: 13538
Tests data length: 22688


In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
def train_network(model, data_loader, device, ):
    
    running_loss = 0
    
    bar = tqdm(data_loader, total=len(train_loader.dataset)/train_loader.batch_size)
    
    for x_batch, y_batch in bar:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward() 
        optimizer.step()
        optimizer.zero_grad()

        running_loss += (loss.item() * x_batch.size(0))

    return model, running_loss

In [12]:
for epoch in range(EPOCHES):

    model.to(device)
    
    train_running_loss = 0
    valid_running_loss = 0
    
    bar = tqdm(train_loader, total=len(train_loader.dataset)/train_loader.batch_size)
    
    for x_batch, y_batch in bar:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward() 
        optimizer.step()
        optimizer.zero_grad()

        train_running_loss += (loss.item() * x_batch.size(0))
    
    with torch.no_grad():
        
        model.eval()

        valid_predict = []  
        valid_targets = []  
        
        vbar = tqdm(valid_loader, total=len(valid_loader.dataset)/valid_loader.batch_size)
        
        for x_batch, y_batch in vbar:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch)
            
            valid_running_loss += (loss.item() * x_batch.size(0))

            valid_predict.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            valid_targets.extend(y_batch.cpu().numpy())
            
        model.train()
        
    validation_accuracy = accuracy_score(valid_targets, valid_predict)

    print(f'Epoch: {epoch + 1} / {EPOCHES}\n'
          f'Avarage training loss: {(train_running_loss / len(train_loader.dataset)):.3f}\n'
          f'Avarage validation loss: {(valid_running_loss / len(valid_loader.dataset)):.3f}\n' 
          f'Validation accuracy: {(accuracy_score(valid_targets, valid_predict)):.3f}')

  full_bar = Bar(frac,
100%|████████████████████████████████| 106/105.76953125 [00:59<00:00,  1.79it/s]
27it [00:12,  2.24it/s]                                                         


Epoch: 1 / 20
Avarage training loss: 0.701
Avarage validation loss: 0.577
Validation accuracy: 0.822


  full_bar = Bar(frac,
100%|████████████████████████████████| 106/105.76953125 [00:58<00:00,  1.80it/s]
27it [00:11,  2.25it/s]                                                         


Epoch: 2 / 20
Avarage training loss: 0.670
Avarage validation loss: 0.517
Validation accuracy: 0.844


  full_bar = Bar(frac,
100%|████████████████████████████████| 106/105.76953125 [00:59<00:00,  1.79it/s]
27it [00:11,  2.26it/s]                                                         


Epoch: 3 / 20
Avarage training loss: 0.643
Avarage validation loss: 0.488
Validation accuracy: 0.851


  full_bar = Bar(frac,
100%|████████████████████████████████| 106/105.76953125 [00:58<00:00,  1.80it/s]
27it [00:12,  2.23it/s]                                                         


Epoch: 4 / 20
Avarage training loss: 0.612
Avarage validation loss: 0.481
Validation accuracy: 0.852


  full_bar = Bar(frac,
100%|████████████████████████████████| 106/105.76953125 [00:59<00:00,  1.79it/s]
27it [00:12,  2.25it/s]                                                         


Epoch: 5 / 20
Avarage training loss: 0.580
Avarage validation loss: 0.447
Validation accuracy: 0.864


  full_bar = Bar(frac,
100%|████████████████████████████████| 106/105.76953125 [00:58<00:00,  1.80it/s]
27it [00:11,  2.26it/s]                                                         


Epoch: 6 / 20
Avarage training loss: 0.548
Avarage validation loss: 0.419
Validation accuracy: 0.871


 35%|███████████▌                     | 37/105.76953125 [00:20<00:39,  1.76it/s]


KeyboardInterrupt: 