In [1]:
import cv2
import json
import numpy as np
import random
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

# set seed

In [2]:
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x7f6c2feca810>

# dataset

In [3]:
class CustomDataset(Dataset):
    def __init__(self, json_path_list):
        self.annotations = []
        for json_path in json_path_list:
            with open(json_path, 'r') as f:
                annotations = json.load(f)
            self.annotations.extend(annotations['annotations'])
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        anno = self.annotations[idx]
        label = anno['label']
        label = np.eye(5)[label].astype(np.float32)
        
        # label smoothing
#         label = label * .9 + 0.1 / 5 
        
        kpts_ = np.array(anno['kpts']).reshape(-1, 2)
        kpts_ = preprocessing(kpts_)
        
        return label, kpts_.astype(np.float32)
    
    
def preprocessing(kpts):
    kpts_ = deepcopy(kpts)
    kpts_ = (kpts_ - kpts_.min(axis=0)) / (kpts_.max(axis=0) - kpts_.min(axis=0))
    return kpts_.flatten()

In [4]:
ds = CustomDataset([f'./data/annotation_{i}.json' for i in range(5)])
len(ds)

2513

In [5]:
ds[0]

(array([1., 0., 0., 0., 0.], dtype=float32),
 array([0.21594635, 1.        , 0.443099  , 0.9287557 , 0.6469579 ,
        0.7714337 , 0.8361737 , 0.6478518 , 1.        , 0.57608795,
        0.49768278, 0.4669195 , 0.509404  , 0.2612543 , 0.503457  ,
        0.12828362, 0.49311274, 0.        , 0.34171033, 0.47398508,
        0.27542248, 0.46475977, 0.304567  , 0.6385949 , 0.35247087,
        0.6362982 , 0.19208448, 0.5076398 , 0.12309471, 0.5348637 ,
        0.18273547, 0.69105816, 0.24475259, 0.6733748 , 0.05024889,
        0.5485357 , 0.        , 0.564812  , 0.06049323, 0.6850807 ,
        0.11008564, 0.7059657 ], dtype=float32))

In [6]:
total = len(ds)
num_train = int(total * .8)
ds_train, ds_valid = torch.utils.data.random_split(ds, [num_train, total - num_train], generator=g)
print(f'train: {len(ds_train)} / valid: {len(ds_valid)}')

train: 2010 / valid: 503


In [7]:
labels, cnt = np.unique([ds_valid.dataset.annotations[i]['label'] for i in ds_valid.indices], return_counts=True)
print({l:c for l, c in zip(labels, cnt)})

{0: 110, 1: 98, 2: 100, 3: 100, 4: 95}


# model

In [8]:
m = nn.BatchNorm1d(100)
# Without Learnable Parameters
m = nn.BatchNorm1d(100, affine=False)
input = torch.randn(20, 100)
output = m(input)

In [9]:
output.shape

torch.Size([20, 100])

In [10]:
simple_model = nn.Sequential(
    nn.Linear(42, 20),
    nn.BatchNorm1d(20),
    nn.ReLU(True,),
    nn.Linear(20, 10),
    nn.BatchNorm1d(10),
    nn.ReLU(True,),
    nn.Linear(10, 5),
    nn.Sigmoid(),
)

In [11]:
test = simple_model(torch.randn(2, 42))
print('shape:', test.size())

shape: torch.Size([2, 5])


# training

In [13]:
EPOCHS = 40
BS = 64

In [14]:
train_loader = DataLoader(ds_train, BS, shuffle=True, drop_last=True, num_workers=8)
valid_loader = DataLoader(ds_valid, BS, shuffle=False, drop_last=False, num_workers=8)

In [15]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(simple_model.parameters(), lr=2e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-3, epochs=EPOCHS, steps_per_epoch=len(train_loader))

baseline: 0.872255489021956

label smoothing: 0.8343313373253493

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
simple_model = simple_model.to(device)

all_valid_pred = []
all_valid_label = []

for epoch in range(EPOCHS):
    simple_model.train()
    epoch_loss = 0.
    for label, input_ in tqdm(train_loader):
        label, input_ = label.to(device), input_.to(device)
        
        pred = simple_model(input_)
        loss = criterion(pred, label)
        
        epoch_loss += loss.detach().item() * label.size()[0]
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f'[epoch {epoch}] loss: {epoch_loss / len(ds_train)}, lr: {scheduler.get_lr()}')
    
    simple_model.eval()
    
    correct = 0
    for label, input_ in tqdm(valid_loader):
        label, input_ = label.to(device), input_.to(device)
        
        with torch.no_grad():
            pred = simple_model(input_)
            cls_pred = pred.argmax(dim=-1)
            cls_true = label.argmax(dim=-1)
            
            correct += (cls_pred == cls_true).sum().item()
            
            if epoch == EPOCHS - 1:
                all_valid_pred.extend(cls_pred.detach().cpu().numpy().tolist())
                all_valid_label.extend(cls_true.detach().cpu().numpy().tolist())
            
    print(f'--val acc: {correct / len(ds_valid)}')

100%|██████████| 31/31 [00:00<00:00, 61.00it/s]


[epoch 0] loss: 0.7094271133195108, lr: [0.00011288677002285622]


100%|██████████| 8/8 [00:00<00:00, 55.51it/s]


--val acc: 0.24055666003976142


100%|██████████| 31/31 [00:00<00:00, 134.50it/s]


[epoch 1] loss: 0.688012661151032, lr: [0.00020929387250280779]


100%|██████████| 8/8 [00:00<00:00, 55.15it/s]


--val acc: 0.3399602385685885


100%|██████████| 31/31 [00:00<00:00, 137.83it/s]


[epoch 2] loss: 0.6566177463057029, lr: [0.0003626160611735461]


100%|██████████| 8/8 [00:00<00:00, 56.84it/s]


--val acc: 0.4532803180914513


100%|██████████| 31/31 [00:00<00:00, 143.83it/s]


[epoch 3] loss: 0.6129031261994471, lr: [0.0005623486036221222]


100%|██████████| 8/8 [00:00<00:00, 59.80it/s]


--val acc: 0.5944333996023857


100%|██████████| 31/31 [00:00<00:00, 110.57it/s]


[epoch 4] loss: 0.5627877192710763, lr: [0.0007948070036202307]


100%|██████████| 8/8 [00:00<00:00, 57.51it/s]


--val acc: 0.7037773359840954


100%|██████████| 31/31 [00:00<00:00, 141.75it/s]


[epoch 5] loss: 0.5090325179977797, lr: [0.0010440645821249144]


100%|██████████| 8/8 [00:00<00:00, 53.00it/s]


--val acc: 0.7395626242544732


100%|██████████| 31/31 [00:00<00:00, 138.22it/s]


[epoch 6] loss: 0.4503758814797473, lr: [0.0012930436794263047]


100%|██████████| 8/8 [00:00<00:00, 56.10it/s]


--val acc: 0.805168986083499


100%|██████████| 31/31 [00:00<00:00, 136.29it/s]


[epoch 7] loss: 0.38573763928010096, lr: [0.0015246857157047156]


100%|██████████| 8/8 [00:00<00:00, 55.03it/s]


--val acc: 0.8946322067594433


100%|██████████| 31/31 [00:00<00:00, 134.95it/s]


[epoch 8] loss: 0.3185216362796613, lr: [0.0017231199443461585]


100%|██████████| 8/8 [00:00<00:00, 57.32it/s]


--val acc: 0.9562624254473161


100%|██████████| 31/31 [00:00<00:00, 140.22it/s]


[epoch 9] loss: 0.25434332795404085, lr: [0.0018747508219298132]


100%|██████████| 8/8 [00:00<00:00, 53.76it/s]


--val acc: 0.952286282306163


100%|██████████| 31/31 [00:00<00:00, 135.72it/s]


[epoch 10] loss: 0.1994437421732281, lr: [0.0019691894947068095]


100%|██████████| 8/8 [00:00<00:00, 56.79it/s]


--val acc: 0.9801192842942346


100%|██████████| 31/31 [00:00<00:00, 146.79it/s]


[epoch 11] loss: 0.1555576343441484, lr: [0.001999993450199258]


100%|██████████| 8/8 [00:00<00:00, 54.75it/s]


--val acc: 0.9662027833001988


100%|██████████| 31/31 [00:00<00:00, 139.21it/s]


[epoch 12] loss: 0.12022176832702029, lr: [0.0019933004907124047]


100%|██████████| 8/8 [00:00<00:00, 54.91it/s]


--val acc: 0.9284294234592445


100%|██████████| 31/31 [00:00<00:00, 130.66it/s]


[epoch 13] loss: 0.09706383083590228, lr: [0.001974116251530795]


100%|██████████| 8/8 [00:00<00:00, 55.31it/s]


--val acc: 0.9880715705765407


100%|██████████| 31/31 [00:00<00:00, 137.16it/s]


[epoch 14] loss: 0.07927090848856304, lr: [0.001942681985593092]


100%|██████████| 8/8 [00:00<00:00, 57.36it/s]


--val acc: 0.9940357852882704


100%|██████████| 31/31 [00:00<00:00, 136.89it/s]


[epoch 15] loss: 0.06536436934969318, lr: [0.001899392997032049]


100%|██████████| 8/8 [00:00<00:00, 56.56it/s]


--val acc: 0.9681908548707754


100%|██████████| 31/31 [00:00<00:00, 142.00it/s]


[epoch 16] loss: 0.05886310997293956, lr: [0.0018447936699956763]


100%|██████████| 8/8 [00:00<00:00, 58.26it/s]


--val acc: 0.9840954274353877


100%|██████████| 31/31 [00:00<00:00, 140.69it/s]


[epoch 17] loss: 0.04699057477030588, lr: [0.0017795706227007246]


100%|██████████| 8/8 [00:00<00:00, 57.74it/s]


--val acc: 0.9920477137176938


100%|██████████| 31/31 [00:00<00:00, 141.09it/s]


[epoch 18] loss: 0.04517376244957767, lr: [0.0017045440728102209]


100%|██████████| 8/8 [00:00<00:00, 57.74it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 142.54it/s]


[epoch 19] loss: 0.039143166613222945, lr: [0.001620657522720457]


100%|██████████| 8/8 [00:00<00:00, 55.63it/s]


--val acc: 0.974155069582505


100%|██████████| 31/31 [00:00<00:00, 141.13it/s]


[epoch 20] loss: 0.03454575633528221, lr: [0.001528965894470921]


100%|██████████| 8/8 [00:00<00:00, 55.26it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 143.03it/s]


[epoch 21] loss: 0.03214210538721796, lr: [0.0014306222634875734]


100%|██████████| 8/8 [00:00<00:00, 56.21it/s]


--val acc: 0.9960238568588469


100%|██████████| 31/31 [00:00<00:00, 139.30it/s]


[epoch 22] loss: 0.03122433471442455, lr: [0.0013268633579903333]


100%|██████████| 8/8 [00:00<00:00, 56.75it/s]


--val acc: 0.9880715705765407


100%|██████████| 31/31 [00:00<00:00, 137.53it/s]


[epoch 23] loss: 0.02572015185854328, lr: [0.0012189940064181476]


100%|██████████| 8/8 [00:00<00:00, 56.65it/s]


--val acc: 0.9940357852882704


100%|██████████| 31/31 [00:00<00:00, 137.60it/s]


[epoch 24] loss: 0.024997190634409586, lr: [0.0011083707284542928]


100%|██████████| 8/8 [00:00<00:00, 54.78it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 140.27it/s]


[epoch 25] loss: 0.02243902982763983, lr: [0.0009963846760042848]


100%|██████████| 8/8 [00:00<00:00, 58.96it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 140.25it/s]


[epoch 26] loss: 0.021705632126746487, lr: [0.0008844441386535036]


100%|██████████| 8/8 [00:00<00:00, 57.08it/s]


--val acc: 0.9960238568588469


100%|██████████| 31/31 [00:00<00:00, 139.54it/s]


[epoch 27] loss: 0.018904429199683725, lr: [0.0007739568336085474]


100%|██████████| 8/8 [00:00<00:00, 55.77it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 139.76it/s]


[epoch 28] loss: 0.018403756084726816, lr: [0.000666312202836585]


100%|██████████| 8/8 [00:00<00:00, 57.51it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 142.21it/s]


[epoch 29] loss: 0.018035340961532212, lr: [0.0005628639400264434]


100%|██████████| 8/8 [00:00<00:00, 58.16it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 141.74it/s]


[epoch 30] loss: 0.017969393522585208, lr: [0.0004649129671050396]


100%|██████████| 8/8 [00:00<00:00, 57.66it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 140.72it/s]


[epoch 31] loss: 0.016999854051058565, lr: [0.00037369107438933876]


100%|██████████| 8/8 [00:00<00:00, 57.02it/s]


--val acc: 0.9980119284294234


100%|██████████| 31/31 [00:00<00:00, 136.85it/s]


[epoch 32] loss: 0.016724655491795706, lr: [0.0002903454301084161]


100%|██████████| 8/8 [00:00<00:00, 57.74it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 138.10it/s]


[epoch 33] loss: 0.017337346788662583, lr: [0.00021592415409737317]


100%|██████████| 8/8 [00:00<00:00, 58.85it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 141.28it/s]


[epoch 34] loss: 0.015983070396072236, lr: [0.00015136313708227894]


100%|██████████| 8/8 [00:00<00:00, 59.02it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 137.26it/s]


[epoch 35] loss: 0.01579858909791975, lr: [9.747427131127321e-05]


100%|██████████| 8/8 [00:00<00:00, 57.83it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 139.30it/s]


[epoch 36] loss: 0.016537746386741523, lr: [5.493524053847441e-05]


100%|██████████| 8/8 [00:00<00:00, 55.74it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 140.15it/s]


[epoch 37] loss: 0.015507089765510748, lr: [2.4280997757570817e-05]


100%|██████████| 8/8 [00:00<00:00, 55.64it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 141.43it/s]


[epoch 38] loss: 0.015662447314950364, lr: [5.897037857538436e-06]


100%|██████████| 8/8 [00:00<00:00, 57.17it/s]


--val acc: 1.0


100%|██████████| 31/31 [00:00<00:00, 138.32it/s]


[epoch 39] loss: 0.016058651665549966, lr: [1.4549800742174329e-08]


100%|██████████| 8/8 [00:00<00:00, 54.81it/s]

--val acc: 1.0





In [17]:
from sklearn.metrics import confusion_matrix

In [18]:
confusion_matrix(all_valid_label, all_valid_pred)

array([[110,   0,   0,   0,   0],
       [  0,  98,   0,   0,   0],
       [  0,   0, 100,   0,   0],
       [  0,   0,   0, 100,   0],
       [  0,   0,   0,   0,  95]])

### base

acc: 0.9620758483033932


```python
array([[100,   4,   2,   0,   0],
       [  0, 102,   3,   0,   0],
       [  0,  10, 106,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### label smoothing / 0.1

acc: 0.9481037924151696

```python
array([[101,   4,   1,   0,   0],
       [  1,  99,   5,   0,   0],
       [  0,  15, 101,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

### batchnorm

acc: 1.0

```python
array([[106,   0,   0,   0,   0],
       [  0, 105,   0,   0,   0],
       [  0,   0, 116,   0,   0],
       [  0,   0,   0,  84,   0],
       [  0,   0,   0,   0,  90]])
```

In [19]:
torch.onnx.export(simple_model, input_[:1], "simple_model.onnx", verbose=True, input_names=['input'], output_names=['output'])

graph(%input : Float(1, 42, strides=[42, 1], requires_grad=0, device=cuda:0),
      %0.weight : Float(20, 42, strides=[42, 1], requires_grad=1, device=cuda:0),
      %0.bias : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.weight : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.bias : Float(20, strides=[1], requires_grad=1, device=cuda:0),
      %1.running_mean : Float(20, strides=[1], requires_grad=0, device=cuda:0),
      %1.running_var : Float(20, strides=[1], requires_grad=0, device=cuda:0),
      %3.weight : Float(10, 20, strides=[20, 1], requires_grad=1, device=cuda:0),
      %3.bias : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.weight : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.bias : Float(10, strides=[1], requires_grad=1, device=cuda:0),
      %4.running_mean : Float(10, strides=[1], requires_grad=0, device=cuda:0),
      %4.running_var : Float(10, strides=[1], requires_grad=0, device=cuda:0),
    

In [20]:
import onnxruntime as ort
from tqdm import tqdm
import numpy as np

ort_session = ort.InferenceSession("simple_model.onnx")

for _ in tqdm(range(100)):
    outputs = ort_session.run(
        ['output'],
        {"input": np.random.randn(1, 42).astype(np.float32)},
    )
    

100%|██████████| 100/100 [00:00<00:00, 45669.69it/s]
