In [1]:
%load_ext autoreload
%autoreload 2

In [24]:
from data_utils import WoodDataset
from torch.utils.data import DataLoader
import torchvision.models as models
from torch import nn
import torch
from torch.optim import lr_scheduler
import torch.optim as optim

from datetime import datetime
import pandas as pd
from sklearn.metrics import classification_report

from os import path
from tqdm.notebook import tqdm
import joblib

In [3]:
train_val_test_split_params = {
    'test_size': 0.25,
    'valid_size': 0.25,
    'random_state': 42,
    'stratify': 'target'
}
train_dataset = WoodDataset(
    img_dir='data', 
    is_test=False, 
    task_type='classification', 
    dataset_role='train', 
    train_val_test_split_params=train_val_test_split_params
)
valid_dataset = WoodDataset(
    img_dir='data', 
    is_test=False, 
    task_type='classification', 
    dataset_role='valid', 
    train_val_test_split_params=train_val_test_split_params
)
test_dataset = WoodDataset(
    img_dir='data', 
    is_test=False, 
    task_type='classification', 
    dataset_role='test', 
    train_val_test_split_params=train_val_test_split_params
)
submission_dataset = WoodDataset(
    img_dir='data', 
    is_test=True, 
    task_type='classification'
)

In [4]:
batch_size = 32

train_preset = {'batch_size': batch_size, 'shuffle': True, 'num_workers': 0, 'drop_last': True}
test_preset = {'batch_size': 1, 'shuffle': False, 'num_workers': 0, 'drop_last': False}

train_dataloader = DataLoader(train_dataset, **train_preset)
valid_dataloader = DataLoader(valid_dataset, **train_preset)
test_dataloader = DataLoader(test_dataset, **test_preset)
submission_dataloader = DataLoader(submission_dataset, **test_preset)

In [5]:
def layers_freeze(model):
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = False
        layers_freeze(child)
        
        
def layers_unfreeze(model):
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = True
        layers_unfreeze(child)

In [6]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f028410b650>

In [7]:
model = models.mobilenet_v2(pretrained=True)
layers_freeze(model)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 3)

In [8]:
lr = 1e-3

loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.StepLR(optimizer, 2, gamma=0.1, last_epoch=-1)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_epochs = 30

train_epoch_losses = []
valid_epoch_losses = []
es_counter = 0
for epoch in range(n_epochs):
    if epoch == 6:
        layers_unfreeze(model)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, 5, gamma=0.1, last_epoch=-1)
    with tqdm(train_dataloader, unit="batch") as tqdm_train_dataloader:
        model.train()
        train_loss = 0
        train_cnt = 0
        for batch in tqdm_train_dataloader:
            tqdm_train_dataloader.set_description(f"train Epoch {epoch}")
            model.zero_grad()
            images = batch['image'].to(device)
            targets = batch['target'].to(device)
            bath_shape = images.shape[0]
            train_cnt += bath_shape
            
            output = model(images)
            batch_loss = loss(output, targets)
            batch_loss.backward()
            optimizer.step()
            batch_loss_val = batch_loss.detach().item()
            train_loss += batch_loss_val
            tqdm_train_dataloader.set_postfix(
                batch_loss=batch_loss_val / bath_shape,
                epoch_loss=train_loss / train_cnt)
    with torch.no_grad():
        model.eval()
        with tqdm(valid_dataloader, unit="batch") as tqdm_valid_dataloader:
            valid_loss = 0
            valid_cnt = 0
            for batch in tqdm_valid_dataloader:
                tqdm_valid_dataloader.set_description(f"valid Epoch {epoch}")
                images = batch['image'].to(device)
                targets = batch['target'].to(device)
                bath_shape = images.shape[0]
                valid_cnt += bath_shape
            
                output = model(images)
                batch_loss = loss(output, targets)
                
                batch_loss_val = batch_loss.detach().item()
                valid_loss += batch_loss_val
                
                tqdm_valid_dataloader.set_postfix(
                batch_loss=batch_loss_val / bath_shape,
                epoch_loss=valid_loss / valid_cnt)
    train_epoch_losses.append(train_loss / train_cnt)
    valid_epoch_losses.append(valid_loss / valid_cnt)
    now_time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    torch.save(model.state_dict(),
               path.join(
                   'data',
                   'nn_chpt',
                   f'model_{epoch}_{now_time_str}_{round(train_epoch_losses[-1],5)}_{round(valid_epoch_losses[-1],5)}'))
    if (len(valid_epoch_losses) > 1) and (valid_epoch_losses[-1] >= min(valid_epoch_losses[:-1])):
        es_counter += 1
    elif (len(valid_epoch_losses) > 1) and (valid_epoch_losses[-1] < min(valid_epoch_losses[:-1])):
        es_counter = 0
    if es_counter == 5:
        break

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/8 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

In [23]:
with torch.no_grad():
    model.eval()
    with tqdm(test_dataloader, unit="batch") as tqdm_test_dataloader:
        valid_loss = 0
        valid_cnt = 0
        pred = []
        true = []
        img_path = []
        for batch in tqdm_test_dataloader:
            images = batch['image'].to(device)
            targets = batch['target'].to(device)
            output = model(images)
            pred.extend(torch.argmax(nn.functional.softmax(output, dim=1), dim=1).detach().numpy().tolist())
            true.extend(targets.detach().numpy().tolist())
            img_path.extend(batch['img_path'])
test_df = pd.DataFrame({'true': true, 'pred': pred, 'img_path': img_path})

  0%|          | 0/134 [00:00<?, ?batch/s]

In [25]:
print(classification_report(test_df.true, test_df.pred))

              precision    recall  f1-score   support

           0       0.95      0.92      0.94        64
           1       0.96      0.94      0.95        53
           2       0.75      0.88      0.81        17

    accuracy                           0.93       134
   macro avg       0.89      0.92      0.90       134
weighted avg       0.93      0.93      0.93       134



In [26]:
with torch.no_grad():
    model.eval()
    with tqdm(submission_dataloader, unit="batch") as tqdm_submission_dataloader:
        valid_loss = 0
        valid_cnt = 0
        pred = []
        img_path = []
        for batch in tqdm_submission_dataloader:
            images = batch['image'].to(device)
            output = model(images)
            pred.extend(torch.argmax(nn.functional.softmax(output, dim=1), dim=1).detach().numpy().tolist())
            img_path.extend(batch['img_path'])
test_df = pd.DataFrame({'pred': pred, 'img_path': img_path})

  0%|          | 0/249 [00:00<?, ?batch/s]

In [37]:
backward_mapping = {0: 1, 1: 3, 2: 0}

test_df['id'] = test_df.img_path.map(lambda x: path.split(x)[-1].split('.')[0])
test_df['class'] = test_df.pred.map(lambda x: backward_mapping[x])

In [40]:
test_df[['id', 'class']].to_csv(path.join('data', 'mn_cls_9.csv'), index=False)