In [1]:
import torch
from torch import nn
import wandb
from typing import Tuple
from tqdm import tqdm


# Weights & Biases project that training runs are reported to.
PRJ_NAME = 'intel_image_cls'

# Scene categories of the classification task; the length of this list
# determines the size of the classifier's output layer.
kls = ["buildings", "forest", "glacier", "mountain", "sea", "street"]


def accuracy(y_hat: torch.Tensor, y: torch.Tensor):
    """Count how many predictions in ``y_hat`` match the labels ``y``.

    Despite the name, this returns the *number* of correct predictions as a
    float, not a ratio; callers divide by the sample count themselves.

    Args:
        y_hat: Either already-decoded predictions, or a 2-D (or higher)
            tensor with more than one column, in which case the arg-max
            over dimension 1 is taken as the prediction.
        y: Ground-truth labels.

    Returns:
        Number of matching entries, as a Python float.
    """
    has_class_scores = y_hat.dim() > 1 and y_hat.shape[1] > 1
    predicted = y_hat.argmax(dim=1) if has_class_scores else y_hat
    # Cast to the label dtype so the element-wise comparison is like-for-like.
    hits = predicted.type(y.dtype) == y
    return float(hits.sum())


def train_epoch(net, train_data, loss, updater, device):
    """Run one optimisation pass over ``train_data``.

    Args:
        net: Model to train; switched to training mode here.
        train_data: Iterable yielding ``(X, y)`` batches.
        loss: Callable ``loss(y_hat, y)`` returning either per-sample losses
            or an already reduced scalar.
        updater: Optimiser exposing ``zero_grad()`` and ``step()``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample seen.
    """
    net.train()
    # A criterion configured with reduction='sum' already yields a batch total.
    loss_is_batch_total = getattr(loss, 'reduction', None) == 'sum'
    acc_sum, lsum, numel = 0, 0, 0
    for X, y in tqdm(train_data):
        X = X.to(device)
        y = y.to(device)
        updater.zero_grad()
        y_hat = net(X)
        l = loss(y_hat, y)
        l.mean().backward()
        updater.step()

        batch_size = y.numel()
        batch_loss = l.detach()
        if batch_loss.dim() == 0 and not loss_is_batch_total:
            # A reduced scalar is treated as the batch mean (the default of
            # nn.CrossEntropyLoss); weight it by the batch size so that the
            # final division by ``numel`` gives a true per-sample average
            # instead of a value shrunk by roughly the batch size.
            lsum += batch_loss * batch_size
        else:
            lsum += batch_loss.sum()
        numel += batch_size
        acc_sum += accuracy(y_hat, y)

    return (lsum / numel), (acc_sum / numel)

def init_weight(m):
    """Apply Xavier-normal initialisation to the weight of a conv/linear layer.

    Intended for ``nn.Module.apply``. Only modules whose type is exactly
    ``nn.Conv2d`` or ``nn.Linear`` are touched (subclasses are skipped); every
    initialised module is echoed to stdout.
    """
    if type(m) not in (nn.Conv2d, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    print(f'initialize weight : {m}')

def _fit_and_log(run, net, train_data, val_data, device, config):
    """Train ``net``, log per-epoch metrics to ``run``, then save and upload its weights."""
    lr, num_epochs = config['lr'], config['epochs']
    net = net.to(device)
    net.apply(init_weight)
    model_name = net.__class__.__name__
    print(f'network : {model_name}')

    loss = nn.CrossEntropyLoss().to(device)
    updater = torch.optim.SGD(net.parameters(), lr)
    for epoch in range(num_epochs):
        tloss, tacc = train_epoch(net, train_data, loss, updater, device)
        val_loss, val_acc = evaluate(net, val_data, loss, device)
        # Convert to plain floats so the log and the printed summary do not
        # carry tensor objects (the losses come back as tensors).
        metric = {
            'train_loss': float(tloss),
            'train_accuracy': float(tacc),
            'validation_loss': float(val_loss),
            'validation_accuracy': float(val_acc)
        }
        run.log(metric)
        print(f'Ep {epoch} : {metric}')

    weight_file = f'{model_name}.pt'
    torch.save(net.state_dict(), weight_file)
    trained_model = wandb.Artifact(model_name, type='model', description=str(net))
    trained_model.add_file(weight_file)
    run.log_artifact(trained_model)


def train(net: nn.Module, train_data, val_data, device, config):
    """Train ``net`` with SGD and cross-entropy, reporting to Weights & Biases.

    The network is moved to ``device``, re-initialised with ``init_weight``
    and trained for ``config['epochs']`` epochs at learning rate
    ``config['lr']``. After training, the state dict is written to
    ``<ClassName>.pt`` and logged as a ``model`` artifact.

    If a W&B run is already active (for example when called from a sweep
    function that opened one), that run is reused and left open for the
    caller. Otherwise a new run is created and finished here. Opening a
    second run inside an active one could end the caller's run prematurely.

    Args:
        net: Model to train.
        train_data: Iterable of ``(X, y)`` training batches.
        val_data: Iterable of ``(X, y)`` validation batches.
        device: Device used for the model and the batches.
        config: Mapping providing ``'lr'`` and ``'epochs'``.
    """
    active_run = wandb.run
    if active_run is not None:
        _fit_and_log(active_run, net, train_data, val_data, device, config)
        return
    with wandb.init(project=PRJ_NAME, job_type='training') as run:
        _fit_and_log(run, net, train_data, val_data, device, config)

    
def evaluate(net, val_data, loss, device):
    """Compute mean loss and accuracy of ``net`` over ``val_data``.

    The model is switched to evaluation mode and gradients are disabled.

    Args:
        net: Model to evaluate.
        val_data: Iterable yielding ``(X, y)`` batches.
        loss: Callable ``loss(y_hat, y)`` returning either per-sample losses
            or an already reduced scalar.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample seen.
    """
    net.eval()
    # A criterion configured with reduction='sum' already yields a batch total.
    loss_is_batch_total = getattr(loss, 'reduction', None) == 'sum'
    with torch.no_grad():
        acc_sum, lsum, numel = 0, 0, 0
        for X, y in val_data:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            batch_size = y.numel()
            batch_loss = loss(y_hat, y)
            if batch_loss.dim() == 0 and not loss_is_batch_total:
                # A reduced scalar is treated as the batch mean (the default
                # of nn.CrossEntropyLoss); weight it by the batch size so the
                # final division by ``numel`` is a true per-sample average.
                lsum += batch_loss * batch_size
            else:
                lsum += batch_loss.sum()
            acc_sum += accuracy(y_hat, y)
            numel += batch_size
    return (lsum / numel), (acc_sum / numel)


def try_gpu():
    """Return the first CUDA device if any is visible, otherwise the CPU device."""
    if torch.cuda.device_count() > 0:
        return torch.device('cuda:0')
    return torch.device('cpu')


from torchvision.transforms import functional as F
from PIL.Image import Image

class Padding(object):
    """Transform that pads a PIL image to a fixed ``(height, width)``.

    The missing rows/columns are split between the two sides; when the
    difference is odd, the extra pixel goes to the bottom/right.

    Note: for an image larger than the target, the computed padding values
    are negative and are handed to ``F.pad`` unchanged.
    """

    def __init__(self, size: Tuple):
        """
        Args:
            size: Target size as ``(height, width)``.
        """
        super().__init__()
        self.size = size

    def __call__(self, X: Image):
        """Return ``X`` padded to the target size (``X`` itself if it already matches)."""
        # self.size is (height, width) whereas PIL reports (width, height).
        target_h, target_w = self.size
        img_w, img_h = X.size
        if (img_h, img_w) == (target_h, target_w):
            return X

        top = (target_h - img_h) // 2
        bottom = target_h - top - img_h
        left = (target_w - img_w) // 2
        right = target_w - left - img_w
        # torchvision expects a 4-element padding as [left, top, right, bottom].
        return F.pad(X, [left, top, right, bottom])


## Define Model
> Mainly based on the idea of the Inception block from GoogLeNet, while adopting the concept of DenseNet, which has shown better resilience to vanishing gradients.

In [2]:
## Simple network test

import torch
from torch import nn

class InceptionBlock(nn.Module):
    """Parallel convolution branches concatenated with the block input.

    Branches, all preserving the spatial size:
      * one ``k x k`` convolution per entry of ``kernels`` (padding ``k // 2``),
      * a 3x3 max-pool (stride 1) followed by a 1x1 convolution,
      * a plain 1x1 convolution,
      * the untouched input (dense-style skip connection).

    The output therefore has
    ``(len(kernels) + 2) * output_channels + input_channels`` channels.

    Args:
        input_channels: Number of channels of the input.
        output_channels: Number of channels produced by each branch.
        kernels: Odd kernel sizes of the convolution branches.

    Raises:
        ValueError: If any kernel size is even.
    """

    def __init__(self, input_channels, output_channels, kernels=(3, 5)):
        super().__init__()
        # Copy so the block never aliases the caller's (or the default) sequence.
        kernels = list(kernels)
        for k in kernels:
            if k % 2 == 0:
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError('kernel size should be odd number')
        # Misspelled attribute name kept so code reading it keeps working.
        self.kerenls = kernels
        # Sub-modules are written to ``_modules`` directly because the branch
        # names contain '.', which ``add_module`` rejects. The names are part
        # of the state-dict keys and must not change.
        for idx, k in enumerate(kernels):
            self._modules[f'Conv Ch.{idx}_K_{k}X{k}'] = nn.Conv2d(
                input_channels, output_channels, kernel_size=k, padding=(k // 2))
        self._modules['MxPool_Kx3'] = nn.Sequential(
            nn.MaxPool2d(3, padding=1, stride=1),
            nn.Conv2d(input_channels, output_channels, kernel_size=1))
        self._modules['Conv1X1'] = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, X):
        """Apply every branch to ``X`` and concatenate the results along dim 1."""
        branches = [
            self._modules[f'Conv Ch.{idx}_K_{k}X{k}'](X)
            for idx, k in enumerate(self.kerenls)
        ]
        branches.append(self._modules['MxPool_Kx3'](X))
        branches.append(self._modules['Conv1X1'](X))
        branches.append(X)
        return torch.cat(branches, dim=1)
            

class DenseInception_V0(nn.Sequential):
    """Three Inception stages followed by a three-layer classifier head.

    The shape comments assume 150x150 RGB inputs, which is the size the
    final linear layer's input width (200 * 9 * 9) is built for.
    """

    def __init__(self):
        base = 8
        stage1_out = base * 4 + 3      # 4 branches of 8 channels + 3 input channels
        stage2_in = base * 2
        stage2_out = base * 2 * 5      # 4 branches of 16 channels + 16 input channels
        stage3_in = base * 5
        stage3_out = base * 5 * 5      # 4 branches of 40 channels + 40 input channels

        # Layers are created in execution order, one after another.
        layers = [
            InceptionBlock(3, base),
            nn.BatchNorm2d(stage1_out),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),             ## (n, 35, 75, 75)
            nn.Conv2d(stage1_out, stage2_in, kernel_size=1),   ## (n, 16, 75, 75)
            InceptionBlock(stage2_in, stage2_in),              ## (n, 80, 75, 75)
            nn.BatchNorm2d(stage2_out),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=3),             ## (n, 80, 25, 25)
            nn.Conv2d(stage2_out, stage3_in, kernel_size=1),   ## (n, 40, 25, 25)
            InceptionBlock(stage3_in, stage3_in),              ## (n, 200, 25, 25)
            nn.BatchNorm2d(stage3_out),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=3, padding=1),            ## (n, 200, 9, 9)
            nn.Flatten(),
            nn.Linear(stage3_out * 9 * 9, 2048),
            nn.ReLU(),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, len(kls)),
        ]
        super().__init__(*layers)

In [4]:
# training 
from torch import nn
import torch
from torchvision import datasets
from torchvision import transforms as T
from torch.utils import data
from kaggle_secrets import UserSecretsClient

# Location of the dataset on Kaggle and the W&B project used by this cell.
INPUT_DATA_BASE = '/kaggle/input/intel-image-classification/'
project_name = 'intel_image_cls'

# Fetch the W&B API key from Kaggle's secret store rather than hard-coding it,
# then authenticate and open a run in the project.
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb")
wandb.login(key=wandb_api)
# NOTE(review): this run is opened before the sweep agent starts its own runs;
# confirm that a standalone run is intended here.
wandb.init(project=project_name, entity="dwidlee")

# Default hyper-parameters.
# NOTE(review): run_train reads its settings from wandb.config, so these
# look unused in this cell; verify before relying on them.
lr = 0.001
epochs = 10
batch = 8

def run_train():
    """Sweep entry point: build the data pipeline and train one model.

    Opens a W&B run, reads the batch size from ``wandb.config`` (the same
    config object is handed to ``train`` for ``lr`` and ``epochs``), builds
    the train/validation loaders and trains a fresh ``DenseInception_V0``.

    ``train`` already writes the weights to ``<ClassName>.pt`` and logs them
    as a ``model`` artifact, so that step is not repeated here. Repeating it
    would upload the same file twice, possibly on a run that has already
    been finished.
    """
    with wandb.init(project=PRJ_NAME, job_type='training'):
        device = try_gpu()
        config = wandb.config
        print(config)

        # Training images are padded to 150x150 and lightly augmented.
        train_preprocess = torch.jit.script_if_tracing(T.Compose([
            Padding((150, 150)),
            T.RandomAffine(degrees=(-10,10), translate=(0.1, 0.1), scale=(0.9,1.1)),
            T.RandomHorizontalFlip(),
            T.ToTensor()
        ]))

        # Validation images are only padded, never augmented.
        val_preprocess = torch.jit.script_if_tracing(T.Compose([
            Padding((150, 150)),
            T.ToTensor()
        ]))

        # NOTE(review): these roots are relative to the working directory and
        # do not use INPUT_DATA_BASE; confirm where the data actually lives.
        validate_set = datasets.ImageFolder(root='data/seg_test', transform=val_preprocess)
        train_set = datasets.ImageFolder(root='data/seg_train', transform=train_preprocess)
        train_data = data.DataLoader(train_set, batch_size=config['batch'], shuffle=True, num_workers=5)
        # Evaluation aggregates over the whole set, so sample order is
        # irrelevant and shuffling is unnecessary.
        val_data = data.DataLoader(validate_set, batch_size=config['batch'], shuffle=False, num_workers=5)
        net = DenseInception_V0()
        train(net, train_data, val_data, device, config)


# Attach an agent to the existing sweep and let it execute `run_train`
# for at most `count` runs.
sweep_id = '3gfg64vc'
count = 5
wandb.agent(
    sweep_id,
    function=run_train,
    project=project_name,
    count=count,
)