In [1]:
import cv2
import json
import numpy as np
import random
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

# set seed

In [2]:
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x7fd3b72807d0>

# dataset

In [3]:
class CustomDataset(Dataset):
    def __init__(self, json_path_list):
        self.annotations = []
        for json_path in json_path_list:
            with open(json_path, 'r') as f:
                annotations = json.load(f)
            self.annotations.extend(annotations['annotations'])
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        anno = self.annotations[idx]
        label = anno['label']
        label = np.eye(5)[label].astype(np.float32)
        
        # label smoothing
        label = label * .9 + 0.1 / 5 
        
        kpts_ = np.array(anno['kpts']).reshape(-1, 2)
        kpts_ = preprocessing(kpts_)
        
        return label, kpts_.astype(np.float32)
    
    
def preprocessing(kpts):
    kpts_ = deepcopy(kpts)
#     origin_point = kpts_[0].copy()
#     for i in range(21):
#         kpts_[i] -= origin_point
    kpts_ = (kpts_ - kpts_.min(axis=0)) / (kpts_.max(axis=0) - kpts_.min(axis=0))
    return kpts_.flatten()

In [4]:
ds = CustomDataset([f'./data/annotation_{i}.json' for i in range(5)])
len(ds)

2505

In [5]:
ds[0]

(array([0.91999996, 0.02      , 0.02      , 0.02      , 0.02      ],
       dtype=float32),
 array([0.21594635, 1.        , 0.443099  , 0.9287557 , 0.6469579 ,
        0.7714337 , 0.8361737 , 0.6478518 , 1.        , 0.57608795,
        0.49768278, 0.4669195 , 0.509404  , 0.2612543 , 0.503457  ,
        0.12828362, 0.49311274, 0.        , 0.34171033, 0.47398508,
        0.27542248, 0.46475977, 0.304567  , 0.6385949 , 0.35247087,
        0.6362982 , 0.19208448, 0.5076398 , 0.12309471, 0.5348637 ,
        0.18273547, 0.69105816, 0.24475259, 0.6733748 , 0.05024889,
        0.5485357 , 0.        , 0.564812  , 0.06049323, 0.6850807 ,
        0.11008564, 0.7059657 ], dtype=float32))

In [6]:
total = len(ds)
num_train = int(total * .8)
ds_train, ds_valid = torch.utils.data.random_split(ds, [num_train, total - num_train], generator=g)
print(f'train: {len(ds_train)} / valid: {len(ds_valid)}')

train: 2004 / valid: 501


In [7]:
labels, cnt = np.unique([ds_valid.dataset.annotations[i]['label'] for i in ds_valid.indices], return_counts=True)
print({l:c for l, c in zip(labels, cnt)})

{0: 106, 1: 105, 2: 116, 3: 84, 4: 90}


# model

In [8]:
m = nn.BatchNorm1d(100)
# Without Learnable Parameters
m = nn.BatchNorm1d(100, affine=False)
input = torch.randn(20, 100)
output = m(input)

In [9]:
output.shape

torch.Size([20, 100])

In [10]:
simple_model = nn.Sequential(
    nn.Linear(42, 20),
    nn.BatchNorm1d(20),
    nn.ReLU(True,),
    nn.Linear(20, 10),
    nn.BatchNorm1d(10),
    nn.ReLU(True,),
    nn.Linear(10, 5),
    nn.Sigmoid(),
)

In [11]:
test = simple_model(torch.randn(2, 42))
print('shape:', test.size())

shape: torch.Size([2, 5])


# training

In [12]:
EPOCHS = 60
BS = 64

In [13]:
train_loader = DataLoader(ds_train, BS, shuffle=True, drop_last=True, num_workers=8)
valid_loader = DataLoader(ds_valid, BS, shuffle=False, drop_last=False, num_workers=8)

In [14]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(simple_model.parameters(), lr=2e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-3, epochs=EPOCHS, steps_per_epoch=len(train_loader))

baseline: 0.872255489021956

label smoothing: 0.8343313373253493

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
simple_model = simple_model.to(device)

all_valid_pred = []
all_valid_label = []

for epoch in range(EPOCHS):
    simple_model.train()
    epoch_loss = 0.
    for label, input_ in tqdm(train_loader):
        label, input_ = label.to(device), input_.to(device)
        
        pred = simple_model(input_)
        loss = criterion(pred, label)
        
        epoch_loss += loss.detach().item() * label.size()[0]
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f'[epoch {epoch}] loss: {epoch_loss / len(ds_train)}, lr: {scheduler.get_lr()}')
    
    simple_model.eval()
    
    correct = 0
    for label, input_ in tqdm(valid_loader):
        label, input_ = label.to(device), input_.to(device)
        
        with torch.no_grad():
            pred = simple_model(input_)
            cls_pred = pred.argmax(dim=-1)
            cls_true = label.argmax(dim=-1)
            
            correct += (cls_pred == cls_true).sum().item()
            
            if epoch == EPOCHS - 1:
                all_valid_pred.extend(cls_pred.detach().cpu().numpy().tolist())
                all_valid_label.extend(cls_true.detach().cpu().numpy().tolist())
            
    print(f'--val acc: {correct / len(ds_valid)}')

100%|██████████| 31/31 [00:00<00:00, 50.34it/s]


[epoch 0] loss: 0.6894816419559563, lr: [9.463683876508033e-05]


100%|██████████| 8/8 [00:00<00:00, 45.52it/s]


--val acc: 0.21157684630738524


100%|██████████| 31/31 [00:00<00:00, 125.07it/s]


[epoch 1] loss: 0.68439268542383, lr: [0.00013810102787483098]


100%|██████████| 8/8 [00:00<00:00, 47.30it/s]


--val acc: 0.21157684630738524


100%|██████████| 31/31 [00:00<00:00, 124.47it/s]


[epoch 2] loss: 0.6770604527639058, lr: [0.00020906719581247307]


100%|██████████| 8/8 [00:00<00:00, 45.08it/s]


--val acc: 0.21157684630738524


100%|██████████| 31/31 [00:00<00:00, 127.28it/s]


[epoch 3] loss: 0.663419965260519, lr: [0.0003053713418324246]


100%|██████████| 8/8 [00:00<00:00, 43.23it/s]


--val acc: 0.21157684630738524


100%|██████████| 31/31 [00:00<00:00, 118.56it/s]


[epoch 4] loss: 0.6369798835404142, lr: [0.00042407682373113865]


100%|██████████| 8/8 [00:00<00:00, 43.67it/s]


--val acc: 0.21157684630738524


100%|██████████| 31/31 [00:00<00:00, 131.85it/s]


[epoch 5] loss: 0.581941439006143, lr: [0.0005615639060938596]


100%|██████████| 8/8 [00:00<00:00, 46.51it/s]


--val acc: 0.23952095808383234


100%|██████████| 31/31 [00:00<00:00, 121.96it/s]


[epoch 6] loss: 0.5117321157169913, lr: [0.000713640138385537]


100%|██████████| 8/8 [00:00<00:00, 42.65it/s]


--val acc: 0.20958083832335328


100%|██████████| 31/31 [00:00<00:00, 118.45it/s]


[epoch 7] loss: 0.48718815744518046, lr: [0.0008756681970810574]


100%|██████████| 8/8 [00:00<00:00, 48.09it/s]


--val acc: 0.3393213572854291


100%|██████████| 31/31 [00:00<00:00, 132.59it/s]


[epoch 8] loss: 0.4764466200046197, lr: [0.0010427072934917855]


100%|██████████| 8/8 [00:00<00:00, 45.02it/s]


--val acc: 0.28542914171656686


100%|██████████| 31/31 [00:00<00:00, 121.66it/s]


[epoch 9] loss: 0.45954243722790017, lr: [0.001209663835280994]


100%|██████████| 8/8 [00:00<00:00, 46.50it/s]


--val acc: 0.4171656686626746


100%|██████████| 31/31 [00:00<00:00, 132.40it/s]


[epoch 10] loss: 0.43578153598808245, lr: [0.0013714467474842179]


100%|██████████| 8/8 [00:00<00:00, 44.62it/s]


--val acc: 0.5029940119760479


100%|██████████| 31/31 [00:00<00:00, 124.86it/s]


[epoch 11] loss: 0.4092585702618201, lr: [0.001523122716766323]


100%|██████████| 8/8 [00:00<00:00, 45.10it/s]


--val acc: 0.6526946107784432


100%|██████████| 31/31 [00:00<00:00, 121.63it/s]


[epoch 12] loss: 0.38351643918279166, lr: [0.0016600666249878945]


100%|██████████| 8/8 [00:00<00:00, 43.56it/s]


--val acc: 0.7305389221556886


100%|██████████| 31/31 [00:00<00:00, 124.17it/s]


[epoch 13] loss: 0.35646161918868563, lr: [0.0017781025848478931]


100%|██████████| 8/8 [00:00<00:00, 45.57it/s]


--val acc: 0.7624750499001997


100%|██████████| 31/31 [00:00<00:00, 125.19it/s]


[epoch 14] loss: 0.3266128960721745, lr: [0.001873631276944334]


100%|██████████| 8/8 [00:00<00:00, 45.51it/s]


--val acc: 0.7784431137724551


100%|██████████| 31/31 [00:00<00:00, 125.88it/s]


[epoch 15] loss: 0.3005433501359707, lr: [0.0019437397053112704]


100%|██████████| 8/8 [00:00<00:00, 46.80it/s]


--val acc: 0.7944111776447106


100%|██████████| 31/31 [00:00<00:00, 125.64it/s]


[epoch 16] loss: 0.28625380445621207, lr: [0.0019862900246110362]


100%|██████████| 8/8 [00:00<00:00, 45.71it/s]


--val acc: 0.8083832335329342


100%|██████████| 31/31 [00:00<00:00, 128.31it/s]


[epoch 17] loss: 0.278536624299314, lr: [0.0019999970889756826]


100%|██████████| 8/8 [00:00<00:00, 45.49it/s]


--val acc: 0.8063872255489022


100%|██████████| 31/31 [00:00<00:00, 129.36it/s]


[epoch 18] loss: 0.27161041562428734, lr: [0.0019970205903154987]


100%|██████████| 8/8 [00:00<00:00, 44.37it/s]


--val acc: 0.8183632734530938


100%|██████████| 31/31 [00:00<00:00, 120.51it/s]


[epoch 19] loss: 0.2666293166116802, lr: [0.001988468370454814]


100%|██████████| 8/8 [00:00<00:00, 43.67it/s]


--val acc: 0.7784431137724551


100%|██████████| 31/31 [00:00<00:00, 130.82it/s]


[epoch 20] loss: 0.261663734793901, lr: [0.0019743882568761923]


100%|██████████| 8/8 [00:00<00:00, 43.98it/s]


--val acc: 0.812375249500998


100%|██████████| 31/31 [00:00<00:00, 128.99it/s]


[epoch 21] loss: 0.2570558484204991, lr: [0.0019548589912861884]


100%|██████████| 8/8 [00:00<00:00, 49.64it/s]


--val acc: 0.8702594810379242


100%|██████████| 31/31 [00:00<00:00, 128.66it/s]


[epoch 22] loss: 0.25311145239961363, lr: [0.0019299897892597869]


100%|██████████| 8/8 [00:00<00:00, 44.48it/s]


--val acc: 0.8922155688622755


100%|██████████| 31/31 [00:00<00:00, 126.24it/s]


[epoch 23] loss: 0.24915134692620375, lr: [0.0018999197294626046]


100%|██████████| 8/8 [00:00<00:00, 48.01it/s]


--val acc: 0.8862275449101796


100%|██████████| 31/31 [00:00<00:00, 132.28it/s]


[epoch 24] loss: 0.24603850494125884, lr: [0.0018648169758665749]


100%|██████████| 8/8 [00:00<00:00, 45.49it/s]


--val acc: 0.8942115768463074


100%|██████████| 31/31 [00:00<00:00, 117.55it/s]


[epoch 25] loss: 0.2437073979787008, lr: [0.001824877837308805]


100%|██████████| 8/8 [00:00<00:00, 45.65it/s]


--val acc: 0.8922155688622755


100%|██████████| 31/31 [00:00<00:00, 124.28it/s]


[epoch 26] loss: 0.24129995566880155, lr: [0.0017803256696529279]


100%|██████████| 8/8 [00:00<00:00, 40.27it/s]


--val acc: 0.8882235528942116


100%|██████████| 31/31 [00:00<00:00, 114.82it/s]


[epoch 27] loss: 0.23948174132082514, lr: [0.0017314096266925114]


100%|██████████| 8/8 [00:00<00:00, 44.29it/s]


--val acc: 0.9021956087824351


100%|██████████| 31/31 [00:00<00:00, 126.26it/s]


[epoch 28] loss: 0.23713299661815285, lr: [0.0016784032667819782]


100%|██████████| 8/8 [00:00<00:00, 43.22it/s]


--val acc: 0.906187624750499


100%|██████████| 31/31 [00:00<00:00, 122.16it/s]


[epoch 29] loss: 0.23454962090817755, lr: [0.0016216030229873227]


100%|██████████| 8/8 [00:00<00:00, 45.16it/s]


--val acc: 0.9101796407185628


100%|██████████| 31/31 [00:00<00:00, 126.52it/s]


[epoch 30] loss: 0.2316515279149343, lr: [0.0015613265453121616]


100%|██████████| 8/8 [00:00<00:00, 43.46it/s]


--val acc: 0.9101796407185628


100%|██████████| 31/31 [00:00<00:00, 120.45it/s]


[epoch 31] loss: 0.22953037159171646, lr: [0.0014979109242700627]


100%|██████████| 8/8 [00:00<00:00, 45.98it/s]


--val acc: 0.9121756487025948


100%|██████████| 31/31 [00:00<00:00, 119.00it/s]


[epoch 32] loss: 0.22717345117808815, lr: [0.0014317108057376557]


100%|██████████| 8/8 [00:00<00:00, 45.53it/s]


--val acc: 0.9121756487025948


100%|██████████| 31/31 [00:00<00:00, 119.81it/s]


[epoch 33] loss: 0.2255994516932322, lr: [0.0013630964076310345]


100%|██████████| 8/8 [00:00<00:00, 45.73it/s]


--val acc: 0.9181636726546906


100%|██████████| 31/31 [00:00<00:00, 127.03it/s]


[epoch 34] loss: 0.22279949054984513, lr: [0.001292451449496993]


100%|██████████| 8/8 [00:00<00:00, 43.51it/s]


--val acc: 0.9141716566866267


100%|██████████| 31/31 [00:00<00:00, 128.21it/s]


[epoch 35] loss: 0.22066681161374152, lr: [0.0012201710065976711]


100%|██████████| 8/8 [00:00<00:00, 42.92it/s]


--val acc: 0.9201596806387226


100%|██████████| 31/31 [00:00<00:00, 126.83it/s]


[epoch 36] loss: 0.21905661961751546, lr: [0.0011466593004894302]


100%|██████████| 8/8 [00:00<00:00, 46.02it/s]


--val acc: 0.9321357285429142


100%|██████████| 31/31 [00:00<00:00, 116.06it/s]


[epoch 37] loss: 0.216508431348972, lr: [0.0010723274384519426]


100%|██████████| 8/8 [00:00<00:00, 47.20it/s]


--val acc: 0.9301397205588823


100%|██████████| 31/31 [00:00<00:00, 129.23it/s]


[epoch 38] loss: 0.21511101627540208, lr: [0.0009975911144095228]


100%|██████████| 8/8 [00:00<00:00, 44.52it/s]


--val acc: 0.9341317365269461


100%|██████████| 31/31 [00:00<00:00, 134.41it/s]


[epoch 39] loss: 0.2133302902747057, lr: [0.0009228682842020823]


100%|██████████| 8/8 [00:00<00:00, 44.82it/s]


--val acc: 0.9321357285429142


100%|██████████| 31/31 [00:00<00:00, 119.05it/s]


[epoch 40] loss: 0.21166707703215396, lr: [0.0008485768282065336]


100%|██████████| 8/8 [00:00<00:00, 44.95it/s]


--val acc: 0.9401197604790419


100%|██████████| 31/31 [00:00<00:00, 132.80it/s]


[epoch 41] loss: 0.21055200047597675, lr: [0.000775132214380214]


100%|██████████| 8/8 [00:00<00:00, 46.49it/s]


--val acc: 0.936127744510978


100%|██████████| 31/31 [00:00<00:00, 125.11it/s]


[epoch 42] loss: 0.2090388267578003, lr: [0.0007029451747955407]


100%|██████████| 8/8 [00:00<00:00, 44.90it/s]


--val acc: 0.9341317365269461


100%|██████████| 31/31 [00:00<00:00, 126.92it/s]


[epoch 43] loss: 0.20787220800708153, lr: [0.0006324194086596515]


100%|██████████| 8/8 [00:00<00:00, 49.61it/s]


--val acc: 0.936127744510978


100%|██████████| 31/31 [00:00<00:00, 128.51it/s]


[epoch 44] loss: 0.20623848205079098, lr: [0.0005639493246646838]


100%|██████████| 8/8 [00:00<00:00, 43.28it/s]


--val acc: 0.93812375249501


100%|██████████| 31/31 [00:00<00:00, 122.19it/s]


[epoch 45] loss: 0.20572065878771023, lr: [0.0004979178352943804]


100%|██████████| 8/8 [00:00<00:00, 46.38it/s]


--val acc: 0.9401197604790419


100%|██████████| 31/31 [00:00<00:00, 124.34it/s]


[epoch 46] loss: 0.2050641149341941, lr: [0.0004346942154221573]


100%|██████████| 8/8 [00:00<00:00, 45.01it/s]


--val acc: 0.9441117764471058


100%|██████████| 31/31 [00:00<00:00, 125.97it/s]


[epoch 47] loss: 0.20436975199305368, lr: [0.0003746320371762205]


100%|██████████| 8/8 [00:00<00:00, 46.24it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 125.04it/s]


[epoch 48] loss: 0.20359871201886387, lr: [0.00031806719262080115]


100%|██████████| 8/8 [00:00<00:00, 43.22it/s]


--val acc: 0.9441117764471058


100%|██████████| 31/31 [00:00<00:00, 124.80it/s]


[epoch 49] loss: 0.20308341951427344, lr: [0.0002653160153114834]


100%|██████████| 8/8 [00:00<00:00, 44.59it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 129.40it/s]


[epoch 50] loss: 0.20261712464505802, lr: [0.00021667351122964338]


100%|██████████| 8/8 [00:00<00:00, 46.61it/s]


--val acc: 0.9441117764471058


100%|██████████| 31/31 [00:00<00:00, 122.37it/s]


[epoch 51] loss: 0.20213637808839718, lr: [0.00017241170898933824]


100%|██████████| 8/8 [00:00<00:00, 45.76it/s]


--val acc: 0.9461077844311377


100%|██████████| 31/31 [00:00<00:00, 127.27it/s]


[epoch 52] loss: 0.2018530425911178, lr: [0.00013277813854294793]


100%|██████████| 8/8 [00:00<00:00, 46.61it/s]


--val acc: 0.9461077844311377


100%|██████████| 31/31 [00:00<00:00, 126.22it/s]


[epoch 53] loss: 0.20158510151023637, lr: [9.799444689327728e-05]


100%|██████████| 8/8 [00:00<00:00, 44.44it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 122.19it/s]


[epoch 54] loss: 0.20131941565020595, lr: [6.82551585536054e-05]


100%|██████████| 8/8 [00:00<00:00, 44.03it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 127.39it/s]


[epoch 55] loss: 0.20153347365632504, lr: [4.372658768770261e-05]


100%|██████████| 8/8 [00:00<00:00, 44.98it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 120.64it/s]


[epoch 56] loss: 0.20131207226279252, lr: [2.454590801356266e-05]


100%|██████████| 8/8 [00:00<00:00, 43.93it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 120.53it/s]


[epoch 57] loss: 0.20133521076210006, lr: [1.0820385672328905e-05]


100%|██████████| 8/8 [00:00<00:00, 45.44it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 126.47it/s]


[epoch 58] loss: 0.20116524115769924, lr: [2.6267793525220675e-06]


100%|██████████| 8/8 [00:00<00:00, 46.94it/s]


--val acc: 0.9481037924151696


100%|██████████| 31/31 [00:00<00:00, 118.14it/s]


[epoch 59] loss: 0.200939824719153, lr: [1.0911024317473058e-08]


100%|██████████| 8/8 [00:00<00:00, 44.25it/s]

--val acc: 0.9481037924151696





In [16]:
from sklearn.metrics import confusion_matrix

In [17]:
confusion_matrix(all_valid_label, all_valid_pred)

array([[101,   4,   1,   0,   0],
       [  1,  99,   5,   0,   0],
       [  0,  15, 101,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])

### base

acc: 0.872255489021956


```python
array([[ 98,   8,   0,   0,   0],
       [ 14,  57,  34,   0,   0],
       [  0,   8, 108,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### label smoothing / 0.1

acc: 0.8343313373253493

```python
array([[ 97,   8,   1,   0,   0],
       [ 19,  39,  47,   0,   0],
       [  2,   6, 108,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### batchnorm

acc: 1.0

```python
array([[106,   0,   0,   0,   0],
       [  0, 105,   0,   0,   0],
       [  0,   0, 116,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

In [18]:
import onnxruntime as ort
from tqdm import tqdm
import numpy as np

ort_session = ort.InferenceSession("simple_model.onnx")

for _ in tqdm(range(100)):
    outputs = ort_session.run(
        ['output'],
        {"input": np.random.randn(1, 42).astype(np.float32)},
    )
    

100%|██████████| 100/100 [00:00<00:00, 32246.51it/s]
