In [1]:
import cv2
import json
import numpy as np
import random
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

# set seed

In [2]:
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x7faa86c797f0>

# dataset

In [3]:
class CustomDataset(Dataset):
    def __init__(self, json_path_list):
        self.annotations = []
        for json_path in json_path_list:
            with open(json_path, 'r') as f:
                annotations = json.load(f)
            self.annotations.extend(annotations['annotations'])
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        anno = self.annotations[idx]
        label = anno['label']
        label = np.eye(5)[label].astype(np.float32)
        
        # label smoothing
#         label = label * .9 + 0.1 / 5 
        
        kpts_ = np.array(anno['kpts']).reshape(-1, 2)
        kpts_ = preprocessing(kpts_)
        
        return label, kpts_.astype(np.float32)
    
    
def preprocessing(kpts):
    kpts_ = deepcopy(kpts)
    origin_point = kpts_[0]
    for i in range(21):
        kpts_[i] -= origin_point
    kpts_ /= kpts_.max(axis=0)
    return kpts_.flatten()

In [4]:
ds = CustomDataset([f'./data/annotation_{i}.json' for i in range(5)])
len(ds)

2505

In [5]:
ds[0]

(array([1., 0., 0., 0., 0.], dtype=float32),
 array([0.        , 0.        , 0.73772126, 1.        , 0.8337309 ,
        0.8687269 , 0.9228442 , 0.7656073 , 1.        , 0.70572585,
        0.7634281 , 0.61463314, 0.7689483 , 0.44302133, 0.7661475 ,
        0.33206752, 0.76127577, 0.22502469, 0.6899711 , 0.6205288 ,
        0.65875214, 0.612831  , 0.6724781 , 0.75788313, 0.6950389 ,
        0.75596666, 0.61950314, 0.6486111 , 0.58701164, 0.67132735,
        0.61510015, 0.80165964, 0.6443078 , 0.7869042 , 0.5527041 ,
        0.6827356 , 0.52903885, 0.6963169 , 0.5575288 , 0.7966719 ,
        0.5808849 , 0.81409883], dtype=float32))

In [6]:
total = len(ds)
num_train = int(total * .8)
ds_train, ds_valid = torch.utils.data.random_split(ds, [num_train, total - num_train], generator=g)
print(f'train: {len(ds_train)} / valid: {len(ds_valid)}')

train: 2004 / valid: 501


In [7]:
labels, cnt = np.unique([ds_valid.dataset.annotations[i]['label'] for i in ds_valid.indices], return_counts=True)
print({l:c for l, c in zip(labels, cnt)})

{0: 106, 1: 105, 2: 116, 3: 84, 4: 90}


# model

In [10]:
m = nn.BatchNorm1d(100)
# Without Learnable Parameters
m = nn.BatchNorm1d(100, affine=False)
input = torch.randn(20, 100)
output = m(input)

In [11]:
output.shape

torch.Size([20, 100])

In [8]:
simple_model = nn.Sequential(
    nn.Linear(42, 20),
    nn.BatchNorm1d(20),
    nn.ReLU(True,),
    nn.Linear(20, 10),
    nn.BatchNorm1d(10),
    nn.ReLU(True,),
    nn.Linear(10, 5),
    nn.Sigmoid(),
)

In [13]:
test = simple_model(torch.randn(2, 42))
print('shape:', test.size())

shape: torch.Size([2, 5])


# training

In [14]:
EPOCHS = 60
BS = 64

In [15]:
train_loader = DataLoader(ds_train, BS, shuffle=True, drop_last=True, num_workers=8)
valid_loader = DataLoader(ds_valid, BS, shuffle=False, drop_last=False, num_workers=8)

In [16]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(simple_model.parameters(), lr=2e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-3, epochs=EPOCHS, steps_per_epoch=len(train_loader))

baseline: 0.872255489021956

label smoothing: 0.8343313373253493

In [17]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
simple_model = simple_model.to(device)

all_valid_pred = []
all_valid_label = []

for epoch in range(EPOCHS):
    simple_model.train()
    epoch_loss = 0.
    for label, input_ in tqdm(train_loader):
        label, input_ = label.to(device), input_.to(device)
        
        pred = simple_model(input_)
        loss = criterion(pred, label)
        
        epoch_loss += loss.detach().item() * label.size()[0]
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f'[epoch {epoch}] loss: {epoch_loss / len(ds_train)}, lr: {scheduler.get_lr()}')
    
    simple_model.eval()
    
    correct = 0
    for label, input_ in tqdm(valid_loader):
        label, input_ = label.to(device), input_.to(device)
        
        with torch.no_grad():
            pred = simple_model(input_)
            cls_pred = pred.argmax(dim=-1)
            cls_true = label.argmax(dim=-1)
            
            correct += (cls_pred == cls_true).sum().item()
            
            if epoch == EPOCHS - 1:
                all_valid_pred.extend(cls_pred.detach().cpu().numpy().tolist())
                all_valid_label.extend(cls_true.detach().cpu().numpy().tolist())
            
    print(f'--val acc: {correct / len(ds_valid)}')

100%|██████████| 31/31 [00:00<00:00, 53.86it/s]


[epoch 0] loss: 0.6424797594904186, lr: [9.463683876508033e-05]


100%|██████████| 8/8 [00:00<00:00, 43.03it/s]


--val acc: 0.19161676646706588


100%|██████████| 31/31 [00:00<00:00, 104.43it/s]


[epoch 1] loss: 0.6265416744940295, lr: [0.00013810102787483098]


100%|██████████| 8/8 [00:00<00:00, 46.14it/s]


--val acc: 0.09580838323353294


100%|██████████| 31/31 [00:00<00:00, 108.69it/s]


[epoch 2] loss: 0.6085849731506225, lr: [0.00020906719581247307]


100%|██████████| 8/8 [00:00<00:00, 47.49it/s]


--val acc: 0.1377245508982036


100%|██████████| 31/31 [00:00<00:00, 114.41it/s]


[epoch 3] loss: 0.5846265099957556, lr: [0.0003053713418324246]


100%|██████████| 8/8 [00:00<00:00, 44.21it/s]


--val acc: 0.27944111776447106


100%|██████████| 31/31 [00:00<00:00, 109.71it/s]


[epoch 4] loss: 0.5534284557411057, lr: [0.00042407682373113865]


100%|██████████| 8/8 [00:00<00:00, 47.04it/s]


--val acc: 0.3972055888223553


100%|██████████| 31/31 [00:00<00:00, 115.49it/s]


[epoch 5] loss: 0.5207606665864438, lr: [0.0005615639060938596]


100%|██████████| 8/8 [00:00<00:00, 43.05it/s]


--val acc: 0.5209580838323353


100%|██████████| 31/31 [00:00<00:00, 119.17it/s]


[epoch 6] loss: 0.48645770192860127, lr: [0.000713640138385537]


100%|██████████| 8/8 [00:00<00:00, 42.78it/s]


--val acc: 0.6087824351297405


100%|██████████| 31/31 [00:00<00:00, 116.55it/s]


[epoch 7] loss: 0.44484477985404924, lr: [0.0008756681970810574]


100%|██████████| 8/8 [00:00<00:00, 44.12it/s]


--val acc: 0.6187624750499002


100%|██████████| 31/31 [00:00<00:00, 114.56it/s]


[epoch 8] loss: 0.3951479232239866, lr: [0.0010427072934917855]


100%|██████████| 8/8 [00:00<00:00, 45.01it/s]


--val acc: 0.5768463073852296


100%|██████████| 31/31 [00:00<00:00, 120.79it/s]


[epoch 9] loss: 0.34314741393525205, lr: [0.001209663835280994]


100%|██████████| 8/8 [00:00<00:00, 45.17it/s]


--val acc: 0.562874251497006


100%|██████████| 31/31 [00:00<00:00, 116.81it/s]


[epoch 10] loss: 0.29136150706551983, lr: [0.0013714467474842179]


100%|██████████| 8/8 [00:00<00:00, 46.55it/s]


--val acc: 0.812375249500998


100%|██████████| 31/31 [00:00<00:00, 113.30it/s]


[epoch 11] loss: 0.2404968438748114, lr: [0.001523122716766323]


100%|██████████| 8/8 [00:00<00:00, 47.11it/s]


--val acc: 0.812375249500998


100%|██████████| 31/31 [00:00<00:00, 118.63it/s]


[epoch 12] loss: 0.1921576102098781, lr: [0.0016600666249878945]


100%|██████████| 8/8 [00:00<00:00, 43.59it/s]


--val acc: 0.7485029940119761


100%|██████████| 31/31 [00:00<00:00, 119.20it/s]


[epoch 13] loss: 0.1526197214088516, lr: [0.0017781025848478931]


100%|██████████| 8/8 [00:00<00:00, 44.60it/s]


--val acc: 0.9920159680638723


100%|██████████| 31/31 [00:00<00:00, 115.61it/s]


[epoch 14] loss: 0.12159512642614856, lr: [0.001873631276944334]


100%|██████████| 8/8 [00:00<00:00, 45.52it/s]


--val acc: 0.7325349301397206


100%|██████████| 31/31 [00:00<00:00, 117.95it/s]


[epoch 15] loss: 0.0957515775086637, lr: [0.0019437397053112704]


100%|██████████| 8/8 [00:00<00:00, 44.90it/s]


--val acc: 0.8323353293413174


100%|██████████| 31/31 [00:00<00:00, 114.40it/s]


[epoch 16] loss: 0.07724306707134741, lr: [0.0019862900246110362]


100%|██████████| 8/8 [00:00<00:00, 44.10it/s]


--val acc: 0.9860279441117764


100%|██████████| 31/31 [00:00<00:00, 121.53it/s]


[epoch 17] loss: 0.06345497098511564, lr: [0.0019999970889756826]


100%|██████████| 8/8 [00:00<00:00, 46.69it/s]


--val acc: 0.9041916167664671


100%|██████████| 31/31 [00:00<00:00, 118.07it/s]


[epoch 18] loss: 0.05048778409253575, lr: [0.0019970205903154987]


100%|██████████| 8/8 [00:00<00:00, 46.46it/s]


--val acc: 0.7544910179640718


100%|██████████| 31/31 [00:00<00:00, 117.32it/s]


[epoch 19] loss: 0.044440038189916556, lr: [0.001988468370454814]


100%|██████████| 8/8 [00:00<00:00, 45.67it/s]


--val acc: 0.9001996007984032


100%|██████████| 31/31 [00:00<00:00, 115.89it/s]


[epoch 20] loss: 0.035689741016148094, lr: [0.0019743882568761923]


100%|██████████| 8/8 [00:00<00:00, 44.43it/s]


--val acc: 0.7664670658682635


100%|██████████| 31/31 [00:00<00:00, 116.55it/s]


[epoch 21] loss: 0.033431134954422057, lr: [0.0019548589912861884]


100%|██████████| 8/8 [00:00<00:00, 44.61it/s]


--val acc: 0.8043912175648703


100%|██████████| 31/31 [00:00<00:00, 115.03it/s]


[epoch 22] loss: 0.029249417627167082, lr: [0.0019299897892597869]


100%|██████████| 8/8 [00:00<00:00, 44.66it/s]


--val acc: 0.9880239520958084


100%|██████████| 31/31 [00:00<00:00, 118.96it/s]


[epoch 23] loss: 0.024448536232322036, lr: [0.0018999197294626046]


100%|██████████| 8/8 [00:00<00:00, 43.53it/s]


--val acc: 0.8622754491017964


100%|██████████| 31/31 [00:00<00:00, 116.35it/s]


[epoch 24] loss: 0.024285840774010754, lr: [0.0018648169758665749]


100%|██████████| 8/8 [00:00<00:00, 44.99it/s]


--val acc: 0.9740518962075848


100%|██████████| 31/31 [00:00<00:00, 115.17it/s]


[epoch 25] loss: 0.021858429748141124, lr: [0.001824877837308805]


100%|██████████| 8/8 [00:00<00:00, 42.69it/s]


--val acc: 0.9920159680638723


100%|██████████| 31/31 [00:00<00:00, 121.26it/s]


[epoch 26] loss: 0.02071501386022853, lr: [0.0017803256696529279]


100%|██████████| 8/8 [00:00<00:00, 41.25it/s]


--val acc: 0.8263473053892215


100%|██████████| 31/31 [00:00<00:00, 117.51it/s]


[epoch 27] loss: 0.01894404669245798, lr: [0.0017314096266925114]


100%|██████████| 8/8 [00:00<00:00, 43.13it/s]


--val acc: 0.7544910179640718


100%|██████████| 31/31 [00:00<00:00, 119.63it/s]


[epoch 28] loss: 0.02012219537280039, lr: [0.0016784032667819782]


100%|██████████| 8/8 [00:00<00:00, 43.62it/s]


--val acc: 0.9880239520958084


100%|██████████| 31/31 [00:00<00:00, 120.67it/s]


[epoch 29] loss: 0.015399014729701592, lr: [0.0016216030229873227]


100%|██████████| 8/8 [00:00<00:00, 47.39it/s]


--val acc: 0.9740518962075848


100%|██████████| 31/31 [00:00<00:00, 119.20it/s]


[epoch 30] loss: 0.014803384116547788, lr: [0.0015613265453121616]


100%|██████████| 8/8 [00:00<00:00, 45.21it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 117.76it/s]


[epoch 31] loss: 0.013804505476932565, lr: [0.0014979109242700627]


100%|██████████| 8/8 [00:00<00:00, 46.79it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 109.54it/s]


[epoch 32] loss: 0.012414205440147194, lr: [0.0014317108057376557]


100%|██████████| 8/8 [00:00<00:00, 47.37it/s]


--val acc: 0.7764471057884231


100%|██████████| 31/31 [00:00<00:00, 115.64it/s]


[epoch 33] loss: 0.01331638809509144, lr: [0.0013630964076310345]


100%|██████████| 8/8 [00:00<00:00, 45.36it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 114.50it/s]


[epoch 34] loss: 0.011496795731747222, lr: [0.001292451449496993]


100%|██████████| 8/8 [00:00<00:00, 45.03it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 122.18it/s]


[epoch 35] loss: 0.011354021906495808, lr: [0.0012201710065976711]


100%|██████████| 8/8 [00:00<00:00, 45.97it/s]


--val acc: 0.8423153692614771


100%|██████████| 31/31 [00:00<00:00, 108.64it/s]


[epoch 36] loss: 0.009811937496690694, lr: [0.0011466593004894302]


100%|██████████| 8/8 [00:00<00:00, 43.43it/s]


--val acc: 0.9900199600798403


100%|██████████| 31/31 [00:00<00:00, 120.47it/s]


[epoch 37] loss: 0.00976536077951005, lr: [0.0010723274384519426]


100%|██████████| 8/8 [00:00<00:00, 44.95it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 110.62it/s]


[epoch 38] loss: 0.009628876359638815, lr: [0.0009975911144095228]


100%|██████████| 8/8 [00:00<00:00, 44.87it/s]


--val acc: 0.8642714570858283


100%|██████████| 31/31 [00:00<00:00, 115.30it/s]


[epoch 39] loss: 0.007932956079523006, lr: [0.0009228682842020823]


100%|██████████| 8/8 [00:00<00:00, 45.28it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 110.82it/s]


[epoch 40] loss: 0.008783913420108026, lr: [0.0008485768282065336]


100%|██████████| 8/8 [00:00<00:00, 45.35it/s]


--val acc: 0.9960079840319361


100%|██████████| 31/31 [00:00<00:00, 110.99it/s]


[epoch 41] loss: 0.008713963785452281, lr: [0.000775132214380214]


100%|██████████| 8/8 [00:00<00:00, 43.34it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 119.67it/s]


[epoch 42] loss: 0.007424138768942295, lr: [0.0007029451747955407]


100%|██████████| 8/8 [00:00<00:00, 44.65it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 117.73it/s]


[epoch 43] loss: 0.0074828604529955664, lr: [0.0006324194086596515]


100%|██████████| 8/8 [00:00<00:00, 43.95it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 120.19it/s]


[epoch 44] loss: 0.00678955925676875, lr: [0.0005639493246646838]


100%|██████████| 8/8 [00:00<00:00, 42.72it/s]


--val acc: 0.9920159680638723


100%|██████████| 31/31 [00:00<00:00, 118.87it/s]


[epoch 45] loss: 0.006943658633741314, lr: [0.0004979178352943804]


100%|██████████| 8/8 [00:00<00:00, 42.37it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 121.37it/s]


[epoch 46] loss: 0.006648941802050539, lr: [0.0004346942154221573]


100%|██████████| 8/8 [00:00<00:00, 43.91it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 123.83it/s]


[epoch 47] loss: 0.006623165648497507, lr: [0.0003746320371762205]


100%|██████████| 8/8 [00:00<00:00, 44.59it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 119.45it/s]


[epoch 48] loss: 0.0067759889834179375, lr: [0.00031806719262080115]


100%|██████████| 8/8 [00:00<00:00, 42.66it/s]


--val acc: 0.9960079840319361


100%|██████████| 31/31 [00:00<00:00, 115.02it/s]


[epoch 49] loss: 0.0068168443833996435, lr: [0.0002653160153114834]


100%|██████████| 8/8 [00:00<00:00, 46.67it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 116.36it/s]


[epoch 50] loss: 0.006946605747331402, lr: [0.00021667351122964338]


100%|██████████| 8/8 [00:00<00:00, 46.04it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 115.60it/s]


[epoch 51] loss: 0.006208702683924677, lr: [0.00017241170898933824]


100%|██████████| 8/8 [00:00<00:00, 46.75it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 114.89it/s]


[epoch 52] loss: 0.006625692399140603, lr: [0.00013277813854294793]


100%|██████████| 8/8 [00:00<00:00, 42.52it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 115.61it/s]


[epoch 53] loss: 0.006256418492265804, lr: [9.799444689327728e-05]


100%|██████████| 8/8 [00:00<00:00, 42.75it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 117.08it/s]


[epoch 54] loss: 0.006575963930217568, lr: [6.82551585536054e-05]


100%|██████████| 8/8 [00:00<00:00, 40.92it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 113.97it/s]


[epoch 55] loss: 0.006724094477539767, lr: [4.372658768770261e-05]


100%|██████████| 8/8 [00:00<00:00, 41.97it/s]


--val acc: 0.998003992015968


100%|██████████| 31/31 [00:00<00:00, 116.99it/s]


[epoch 56] loss: 0.006495536451865575, lr: [2.454590801356266e-05]


100%|██████████| 8/8 [00:00<00:00, 44.14it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 114.17it/s]


[epoch 57] loss: 0.0065240481061611824, lr: [1.0820385672328905e-05]


100%|██████████| 8/8 [00:00<00:00, 44.95it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 118.77it/s]


[epoch 58] loss: 0.005875594735502483, lr: [2.6267793525220675e-06]


100%|██████████| 8/8 [00:00<00:00, 44.38it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 117.08it/s]


[epoch 59] loss: 0.006377458393930675, lr: [1.0911024317473058e-08]


100%|██████████| 8/8 [00:00<00:00, 43.76it/s]

--val acc: 1.0





In [18]:
from sklearn.metrics import confusion_matrix

In [19]:
confusion_matrix(all_valid_label, all_valid_pred)

array([[106,   0,   0,   0,   0],
       [  0, 105,   0,   0,   0],
       [  0,   0, 116,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])

### base

acc: 0.872255489021956


```python
array([[ 98,   8,   0,   0,   0],
       [ 14,  57,  34,   0,   0],
       [  0,   8, 108,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### label smoothing / 0.1

acc: 0.8343313373253493

```python
array([[ 97,   8,   1,   0,   0],
       [ 19,  39,  47,   0,   0],
       [  2,   6, 108,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### batchnorm

acc: 1.0

```python
array([[106,   0,   0,   0,   0],
       [  0, 105,   0,   0,   0],
       [  0,   0, 116,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

In [21]:
torch.onnx.export(simple_model, input_[:1], "simple_model.onnx", verbose=True, input_names=['input'], output_names=['output'])

graph(%input : Float(1, 42, strides=[42, 1], requires_grad=0, device=cuda:0),
      %0.weight : Float(20, 42, strides=[42, 1], requires_grad=1, device=cuda:0),
      %0.bias : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.weight : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.bias : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.running_mean : Float(20, strides=[1], requires_grad=0, device=cuda:0),
      %1.running_var : Float(20, strides=[1], requires_grad=0, device=cuda:0),
      %3.weight : Float(10, 20, strides=[20, 1], requires_grad=1, device=cuda:0),
      %3.bias : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.weight : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.bias : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.running_mean : Float(10, strides=[1], requires_grad=0, device=cuda:0),
      %4.running_var : Float(10, strides=[1], requires_grad=0, device=cuda:0),
    

In [3]:
import onnxruntime as ort
from tqdm import tqdm
import numpy as np

ort_session = ort.InferenceSession("simple_model.onnx")

for _ in tqdm(range(100)):
    outputs = ort_session.run(
        ['output'],
        {"input": np.random.randn(1, 42).astype(np.float32)},
    )
    

100%|██████████| 100/100 [00:00<00:00, 56043.61it/s]
