In [1]:
import cv2
import json
import numpy as np
import random
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

# set seed

In [2]:
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x7fa74ef367d0>

# dataset

In [3]:
class CustomDataset(Dataset):
    def __init__(self, json_path_list):
        self.annotations = []
        for json_path in json_path_list:
            with open(json_path, 'r') as f:
                annotations = json.load(f)
            self.annotations.extend(annotations['annotations'])
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        anno = self.annotations[idx]
        label = anno['label']
        label = np.eye(5)[label].astype(np.float32)
        
        # label smoothing
#         label = label * .9 + 0.1 / 5 
        
        kpts_ = np.array(anno['kpts']).reshape(-1, 2)
        kpts_ = preprocessing(kpts_)
        
        return label, kpts_.astype(np.float32)
    
    
def preprocessing(kpts):
    kpts_ = deepcopy(kpts)
    kpts_ = (kpts_ - kpts_.min(axis=0)) / (kpts_.max(axis=0) - kpts_.min(axis=0))
    return kpts_.flatten()

In [4]:
ds = CustomDataset([f'./data/annotation_{i}.json' for i in range(5)])
len(ds)

2505

In [5]:
ds[0]

(array([1., 0., 0., 0., 0.], dtype=float32),
 array([0.21594635, 1.        , 0.443099  , 0.9287557 , 0.6469579 ,
        0.7714337 , 0.8361737 , 0.6478518 , 1.        , 0.57608795,
        0.49768278, 0.4669195 , 0.509404  , 0.2612543 , 0.503457  ,
        0.12828362, 0.49311274, 0.        , 0.34171033, 0.47398508,
        0.27542248, 0.46475977, 0.304567  , 0.6385949 , 0.35247087,
        0.6362982 , 0.19208448, 0.5076398 , 0.12309471, 0.5348637 ,
        0.18273547, 0.69105816, 0.24475259, 0.6733748 , 0.05024889,
        0.5485357 , 0.        , 0.564812  , 0.06049323, 0.6850807 ,
        0.11008564, 0.7059657 ], dtype=float32))

In [6]:
total = len(ds)
num_train = int(total * .8)
ds_train, ds_valid = torch.utils.data.random_split(ds, [num_train, total - num_train], generator=g)
print(f'train: {len(ds_train)} / valid: {len(ds_valid)}')

train: 2004 / valid: 501


In [7]:
labels, cnt = np.unique([ds_valid.dataset.annotations[i]['label'] for i in ds_valid.indices], return_counts=True)
print({l:c for l, c in zip(labels, cnt)})

{0: 106, 1: 105, 2: 116, 3: 84, 4: 90}


# model

In [8]:
m = nn.BatchNorm1d(100)
# Without Learnable Parameters
m = nn.BatchNorm1d(100, affine=False)
input = torch.randn(20, 100)
output = m(input)

In [9]:
output.shape

torch.Size([20, 100])

In [10]:
simple_model = nn.Sequential(
    nn.Linear(42, 20),
    nn.BatchNorm1d(20),
    nn.ReLU(True,),
    nn.Linear(20, 10),
    nn.BatchNorm1d(10),
    nn.ReLU(True,),
    nn.Linear(10, 5),
    nn.Sigmoid(),
)

In [11]:
test = simple_model(torch.randn(2, 42))
print('shape:', test.size())

shape: torch.Size([2, 5])


# training

In [12]:
EPOCHS = 60
BS = 64

In [13]:
train_loader = DataLoader(ds_train, BS, shuffle=True, drop_last=True, num_workers=8)
valid_loader = DataLoader(ds_valid, BS, shuffle=False, drop_last=False, num_workers=8)

In [14]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(simple_model.parameters(), lr=2e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-3, epochs=EPOCHS, steps_per_epoch=len(train_loader))

baseline: 0.872255489021956

label smoothing: 0.8343313373253493

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
simple_model = simple_model.to(device)

all_valid_pred = []
all_valid_label = []

for epoch in range(EPOCHS):
    simple_model.train()
    epoch_loss = 0.
    for label, input_ in tqdm(train_loader):
        label, input_ = label.to(device), input_.to(device)
        
        pred = simple_model(input_)
        loss = criterion(pred, label)
        
        epoch_loss += loss.detach().item() * label.size()[0]
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f'[epoch {epoch}] loss: {epoch_loss / len(ds_train)}, lr: {scheduler.get_lr()}')
    
    simple_model.eval()
    
    correct = 0
    for label, input_ in tqdm(valid_loader):
        label, input_ = label.to(device), input_.to(device)
        
        with torch.no_grad():
            pred = simple_model(input_)
            cls_pred = pred.argmax(dim=-1)
            cls_true = label.argmax(dim=-1)
            
            correct += (cls_pred == cls_true).sum().item()
            
            if epoch == EPOCHS - 1:
                all_valid_pred.extend(cls_pred.detach().cpu().numpy().tolist())
                all_valid_label.extend(cls_true.detach().cpu().numpy().tolist())
            
    print(f'--val acc: {correct / len(ds_valid)}')

100%|██████████| 31/31 [00:00<00:00, 46.49it/s]


[epoch 0] loss: 0.7056712951964723, lr: [9.463683876508033e-05]


100%|██████████| 8/8 [00:00<00:00, 34.82it/s]


--val acc: 0.22355289421157684


100%|██████████| 31/31 [00:00<00:00, 102.53it/s]


[epoch 1] loss: 0.6867960817561654, lr: [0.00013810102787483098]


100%|██████████| 8/8 [00:00<00:00, 37.26it/s]


--val acc: 0.35728542914171657


100%|██████████| 31/31 [00:00<00:00, 110.92it/s]


[epoch 2] loss: 0.6646177954302577, lr: [0.00020906719581247307]


100%|██████████| 8/8 [00:00<00:00, 37.79it/s]


--val acc: 0.44510978043912175


100%|██████████| 31/31 [00:00<00:00, 109.32it/s]


[epoch 3] loss: 0.6343441256981885, lr: [0.0003053713418324246]


100%|██████████| 8/8 [00:00<00:00, 38.66it/s]


--val acc: 0.5728542914171657


100%|██████████| 31/31 [00:00<00:00, 108.44it/s]


[epoch 4] loss: 0.6006010002242829, lr: [0.00042407682373113865]


100%|██████████| 8/8 [00:00<00:00, 38.46it/s]


--val acc: 0.7065868263473054


100%|██████████| 31/31 [00:00<00:00, 100.87it/s]


[epoch 5] loss: 0.565437787069294, lr: [0.0005615639060938596]


100%|██████████| 8/8 [00:00<00:00, 36.23it/s]


--val acc: 0.7065868263473054


100%|██████████| 31/31 [00:00<00:00, 105.58it/s]


[epoch 6] loss: 0.5275637822712729, lr: [0.000713640138385537]


100%|██████████| 8/8 [00:00<00:00, 37.57it/s]


--val acc: 0.7285429141716567


100%|██████████| 31/31 [00:00<00:00, 108.73it/s]


[epoch 7] loss: 0.4856011463020614, lr: [0.0008756681970810574]


100%|██████████| 8/8 [00:00<00:00, 39.51it/s]


--val acc: 0.7345309381237525


100%|██████████| 31/31 [00:00<00:00, 110.71it/s]


[epoch 8] loss: 0.439767643363176, lr: [0.0010427072934917855]


100%|██████████| 8/8 [00:00<00:00, 40.47it/s]


--val acc: 0.7524950099800399


100%|██████████| 31/31 [00:00<00:00, 103.37it/s]


[epoch 9] loss: 0.38745060557138894, lr: [0.001209663835280994]


100%|██████████| 8/8 [00:00<00:00, 37.57it/s]


--val acc: 0.7724550898203593


100%|██████████| 31/31 [00:00<00:00, 102.12it/s]


[epoch 10] loss: 0.3349252270605274, lr: [0.0013714467474842179]


100%|██████████| 8/8 [00:00<00:00, 39.74it/s]


--val acc: 0.7485029940119761


100%|██████████| 31/31 [00:00<00:00, 112.42it/s]


[epoch 11] loss: 0.2848939686240312, lr: [0.001523122716766323]


100%|██████████| 8/8 [00:00<00:00, 39.48it/s]


--val acc: 0.8303393213572854


100%|██████████| 31/31 [00:00<00:00, 105.92it/s]


[epoch 12] loss: 0.240218991528966, lr: [0.0016600666249878945]


100%|██████████| 8/8 [00:00<00:00, 39.58it/s]


--val acc: 0.9600798403193613


100%|██████████| 31/31 [00:00<00:00, 113.09it/s]


[epoch 13] loss: 0.20039719212316942, lr: [0.0017781025848478931]


100%|██████████| 8/8 [00:00<00:00, 40.59it/s]


--val acc: 0.9520958083832335


100%|██████████| 31/31 [00:00<00:00, 108.98it/s]


[epoch 14] loss: 0.1619695555902051, lr: [0.001873631276944334]


100%|██████████| 8/8 [00:00<00:00, 39.12it/s]


--val acc: 0.9840319361277445


100%|██████████| 31/31 [00:00<00:00, 112.33it/s]


[epoch 15] loss: 0.13108008826326228, lr: [0.0019437397053112704]


100%|██████████| 8/8 [00:00<00:00, 38.45it/s]


--val acc: 0.9101796407185628


100%|██████████| 31/31 [00:00<00:00, 106.37it/s]


[epoch 16] loss: 0.10249905267399466, lr: [0.0019862900246110362]


100%|██████████| 8/8 [00:00<00:00, 38.38it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 110.61it/s]


[epoch 17] loss: 0.08248029378597846, lr: [0.0019999970889756826]


100%|██████████| 8/8 [00:00<00:00, 40.62it/s]


--val acc: 0.9960079840319361


100%|██████████| 31/31 [00:00<00:00, 108.13it/s]


[epoch 18] loss: 0.06701416074634788, lr: [0.0019970205903154987]


100%|██████████| 8/8 [00:00<00:00, 39.57it/s]


--val acc: 0.9740518962075848


100%|██████████| 31/31 [00:00<00:00, 104.82it/s]


[epoch 19] loss: 0.05924954826008536, lr: [0.001988468370454814]


100%|██████████| 8/8 [00:00<00:00, 41.82it/s]


--val acc: 0.9920159680638723


100%|██████████| 31/31 [00:00<00:00, 107.10it/s]


[epoch 20] loss: 0.0488782943841702, lr: [0.0019743882568761923]


100%|██████████| 8/8 [00:00<00:00, 38.69it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 115.16it/s]


[epoch 21] loss: 0.043557777257260684, lr: [0.0019548589912861884]


100%|██████████| 8/8 [00:00<00:00, 39.46it/s]


--val acc: 0.9960079840319361


100%|██████████| 31/31 [00:00<00:00, 106.62it/s]


[epoch 22] loss: 0.03556790394697361, lr: [0.0019299897892597869]


100%|██████████| 8/8 [00:00<00:00, 42.32it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 97.85it/s]


[epoch 23] loss: 0.0326500421988512, lr: [0.0018999197294626046]


100%|██████████| 8/8 [00:00<00:00, 40.44it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 102.82it/s]


[epoch 24] loss: 0.0301766108253045, lr: [0.0018648169758665749]


100%|██████████| 8/8 [00:00<00:00, 38.95it/s]


--val acc: 0.9960079840319361


100%|██████████| 31/31 [00:00<00:00, 109.60it/s]


[epoch 25] loss: 0.026675023123651682, lr: [0.001824877837308805]


100%|██████████| 8/8 [00:00<00:00, 41.08it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 111.62it/s]


[epoch 26] loss: 0.024678267285733402, lr: [0.0017803256696529279]


100%|██████████| 8/8 [00:00<00:00, 42.15it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 108.33it/s]


[epoch 27] loss: 0.020895775861130977, lr: [0.0017314096266925114]


100%|██████████| 8/8 [00:00<00:00, 39.71it/s]


--val acc: 0.9920159680638723


100%|██████████| 31/31 [00:00<00:00, 108.03it/s]


[epoch 28] loss: 0.02094625352980372, lr: [0.0016784032667819782]


100%|██████████| 8/8 [00:00<00:00, 40.78it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 106.52it/s]


[epoch 29] loss: 0.01849262312262834, lr: [0.0016216030229873227]


100%|██████████| 8/8 [00:00<00:00, 40.71it/s]


--val acc: 0.9960079840319361


100%|██████████| 31/31 [00:00<00:00, 113.15it/s]


[epoch 30] loss: 0.015054369014418291, lr: [0.0015613265453121616]


100%|██████████| 8/8 [00:00<00:00, 41.94it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 107.11it/s]


[epoch 31] loss: 0.013047389879436074, lr: [0.0014979109242700627]


100%|██████████| 8/8 [00:00<00:00, 41.86it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 101.54it/s]


[epoch 32] loss: 0.01501890799837436, lr: [0.0014317108057376557]


100%|██████████| 8/8 [00:00<00:00, 42.47it/s]


--val acc: 0.9860279441117764


100%|██████████| 31/31 [00:00<00:00, 108.55it/s]


[epoch 33] loss: 0.016018921327210233, lr: [0.0013630964076310345]


100%|██████████| 8/8 [00:00<00:00, 40.21it/s]


--val acc: 0.9920159680638723


100%|██████████| 31/31 [00:00<00:00, 112.41it/s]


[epoch 34] loss: 0.013742260798484741, lr: [0.001292451449496993]


100%|██████████| 8/8 [00:00<00:00, 40.36it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 104.14it/s]


[epoch 35] loss: 0.012339821921851107, lr: [0.0012201710065976711]


100%|██████████| 8/8 [00:00<00:00, 39.82it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 106.75it/s]


[epoch 36] loss: 0.01186272354718454, lr: [0.0011466593004894302]


100%|██████████| 8/8 [00:00<00:00, 39.78it/s]


--val acc: 0.9920159680638723


100%|██████████| 31/31 [00:00<00:00, 109.15it/s]


[epoch 37] loss: 0.010275754891707749, lr: [0.0010723274384519426]


100%|██████████| 8/8 [00:00<00:00, 39.41it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 108.10it/s]


[epoch 38] loss: 0.010964557930024084, lr: [0.0009975911144095228]


100%|██████████| 8/8 [00:00<00:00, 40.50it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 109.21it/s]


[epoch 39] loss: 0.01097549600693994, lr: [0.0009228682842020823]


100%|██████████| 8/8 [00:00<00:00, 38.96it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 109.10it/s]


[epoch 40] loss: 0.00901201963722111, lr: [0.0008485768282065336]


100%|██████████| 8/8 [00:00<00:00, 39.60it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 102.07it/s]


[epoch 41] loss: 0.00881183663468637, lr: [0.000775132214380214]


100%|██████████| 8/8 [00:00<00:00, 40.38it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 108.19it/s]


[epoch 42] loss: 0.008491786074138688, lr: [0.0007029451747955407]


100%|██████████| 8/8 [00:00<00:00, 41.12it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 101.60it/s]


[epoch 43] loss: 0.008955258139950073, lr: [0.0006324194086596515]


100%|██████████| 8/8 [00:00<00:00, 42.33it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 106.19it/s]


[epoch 44] loss: 0.008516111535702399, lr: [0.0005639493246646838]


100%|██████████| 8/8 [00:00<00:00, 38.62it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 108.49it/s]


[epoch 45] loss: 0.008188486114114582, lr: [0.0004979178352943804]


100%|██████████| 8/8 [00:00<00:00, 40.16it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 109.41it/s]


[epoch 46] loss: 0.008031576202658122, lr: [0.0004346942154221573]


100%|██████████| 8/8 [00:00<00:00, 41.54it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 106.67it/s]


[epoch 47] loss: 0.007773446168252332, lr: [0.0003746320371762205]


100%|██████████| 8/8 [00:00<00:00, 39.90it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 106.63it/s]


[epoch 48] loss: 0.007679773453704849, lr: [0.00031806719262080115]


100%|██████████| 8/8 [00:00<00:00, 40.14it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 105.78it/s]


[epoch 49] loss: 0.007412609270589794, lr: [0.0002653160153114834]


100%|██████████| 8/8 [00:00<00:00, 39.32it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 111.70it/s]


[epoch 50] loss: 0.007335054333100538, lr: [0.00021667351122964338]


100%|██████████| 8/8 [00:00<00:00, 41.43it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 106.93it/s]


[epoch 51] loss: 0.0065137056741647855, lr: [0.00017241170898933824]


100%|██████████| 8/8 [00:00<00:00, 40.08it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 109.44it/s]


[epoch 52] loss: 0.006849295112068306, lr: [0.00013277813854294793]


100%|██████████| 8/8 [00:00<00:00, 40.80it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 106.51it/s]


[epoch 53] loss: 0.007761714433481593, lr: [9.799444689327728e-05]


100%|██████████| 8/8 [00:00<00:00, 40.25it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 111.74it/s]


[epoch 54] loss: 0.007002502099839513, lr: [6.82551585536054e-05]


100%|██████████| 8/8 [00:00<00:00, 40.81it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 113.58it/s]


[epoch 55] loss: 0.007269744164572505, lr: [4.372658768770261e-05]


100%|██████████| 8/8 [00:00<00:00, 39.40it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 109.18it/s]


[epoch 56] loss: 0.0066706498344977225, lr: [2.454590801356266e-05]


100%|██████████| 8/8 [00:00<00:00, 40.46it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 108.34it/s]


[epoch 57] loss: 0.006940548945805746, lr: [1.0820385672328905e-05]


100%|██████████| 8/8 [00:00<00:00, 40.49it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 110.19it/s]


[epoch 58] loss: 0.0062647277485586685, lr: [2.6267793525220675e-06]


100%|██████████| 8/8 [00:00<00:00, 42.23it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 104.25it/s]


[epoch 59] loss: 0.006663305807613327, lr: [1.0911024317473058e-08]


100%|██████████| 8/8 [00:00<00:00, 41.31it/s]

--val acc: 1.0





In [16]:
from sklearn.metrics import confusion_matrix

In [17]:
confusion_matrix(all_valid_label, all_valid_pred)

array([[106,   0,   0,   0,   0],
       [  0, 105,   0,   0,   0],
       [  0,   0, 116,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])

### base

acc: 0.9620758483033932


```python
array([[100,   4,   2,   0,   0],
       [  0, 102,   3,   0,   0],
       [  0,  10, 106,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### label smoothing / 0.1

acc: 0.9481037924151696

```python
array([[101,   4,   1,   0,   0],
       [  1,  99,   5,   0,   0],
       [  0,  15, 101,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### batchnorm

acc: 1.0

```python
array([[106,   0,   0,   0,   0],
       [  0, 105,   0,   0,   0],
       [  0,   0, 116,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

In [18]:
torch.onnx.export(simple_model, input_[:1], "simple_model.onnx", verbose=True, input_names=['input'], output_names=['output'])

graph(%input : Float(1, 42, strides=[42, 1], requires_grad=0, device=cuda:0),
      %0.weight : Float(20, 42, strides=[42, 1], requires_grad=1, device=cuda:0),
      %0.bias : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.weight : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.bias : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.running_mean : Float(20, strides=[1], requires_grad=0, device=cuda:0),
      %1.running_var : Float(20, strides=[1], requires_grad=0, device=cuda:0),
      %3.weight : Float(10, 20, strides=[20, 1], requires_grad=1, device=cuda:0),
      %3.bias : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.weight : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.bias : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.running_mean : Float(10, strides=[1], requires_grad=0, device=cuda:0),
      %4.running_var : Float(10, strides=[1], requires_grad=0, device=cuda:0),
    

In [19]:
import onnxruntime as ort
from tqdm import tqdm
import numpy as np

ort_session = ort.InferenceSession("simple_model.onnx")

for _ in tqdm(range(100)):
    outputs = ort_session.run(
        ['output'],
        {"input": np.random.randn(1, 42).astype(np.float32)},
    )
    

100%|██████████| 100/100 [00:00<00:00, 34617.89it/s]
