## Usage

1. 해당 분류의 img path를 img_root에 넣는다.
2. 몇 개의 class만 확인하고 싶은 경우 classes에 확인 하고 싶은 소분류 class를 넣는다.
3. imgs에 너무 많은 image가 들어간 경우 random.choices(imgs, k=?)로 대충 sampling해서 돌린다.
4. batch size는 16이상으로 지정하는 것이 좋다.
5. epoch은 20~30만해도 될 것 같다.

In [1]:
import os
from os.path import splitext, join
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm

img_root = r'/opt/ml/data/stew/train'
# label_root = r'/opt/ml/data/labels/train'

# for img_path in os.listdir(img_root):
#     imgs.append(join(img_root, img_path))

classes = [
# '02011019',
# '02011038',
# '02011039',
# '04011001',
# '04011002',
# '04011003',
# '04011004',
# '04011005',
# '04011006',
# '04011007',
# '04011008',
# '04011010',
# '04011011',
# '04011012',
# '04011013',
# '04011014',
# '04011015',
# '04011016',
# '04012001',
# '04012002',
# '04012003',
# '04012004',
# '04012005',
# '04012006',
# '04012007',
# '04012008',
# '04012009',
# '04012010',
# '04012011',
# '04012012',
# '04012013',
# '04013002',
# '04013003',
# '04013004',
# '04013005',
# '04013006',
# '04013007',
# '04013008',
# '04013009',
# '04013010',
# '04013011',
# '04013012',
# '04013013',
# '04013014',
# '04013015',
# '04013017',
# '04013018',
# '04013019',
# '04013020',
# '04013021',
# '04013022',
# '04013023',
# '04013024',
# '04014001',
'04015001',
'04015002',
'04015003',
'04016001',
'04017001',
'04017002',
'04018001',
'04018002',
'04018003',
'04018004',
'04019001',
'04019002',
'04019003',
'04019004',
'04019005',
'04019006',
'04019007',
'04019008'
]

imgs = []
labels = []

# -- label data 2 dataframe & xywh 2 xyxy
df = pd.DataFrame([], columns=['class', 'mean1', 'mean2', 'mean3', 'std1', 'std2', 'std3'])

img_dirs = os.listdir(img_root)
for img_dir in tqdm(img_dirs):
    for img_file in os.listdir(join(img_root, img_dir)):
        filename = splitext(img_file)[0]

        small_label = filename.split('_')[2]
        if small_label not in classes:
            continue
        labels.append(int(small_label))
        imgs.append(np.array(Image.open(join(img_root, img_dir, img_file))))

100%|██████████| 72/72 [00:42<00:00,  1.68it/s]


In [2]:
print(len(imgs))
print(len(labels))
num_classes = len(np.unique(labels))
print(num_classes)

28725
28725
18


In [3]:
import torch
import random
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import *

class TripletDataset(Dataset):
    def __init__(self, img_paths, labels, train=True):
        self.img_paths = np.array(img_paths)
        self.labels = np.array(labels)
        self.transform = transforms.Compose([
            ToTensor(),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
        ])
        self.train = train

    def __getitem__(self, idx: torch.Tensor) -> torch.Tensor:
        anchor_img, anchor_label = self.img_paths[idx], self.labels[idx]
        if self.transform:
            anchor_img = self.transform(anchor_img)

        if self.train:
            positive_idx = np.where(self.labels == anchor_label)[0]
            positive_idx = [p_idx for p_idx in positive_idx if p_idx != idx] # anchor index 제외
            negative_idx = np.where(self.labels != anchor_label)[0]

            positive_img = self.img_paths[random.choice(positive_idx)]
            negative_img = self.img_paths[random.choice(negative_idx)]

            if self.transform:
                positive_img = self.transform(positive_img)
                negative_img = self.transform(negative_img)
            
            return anchor_img, positive_img, negative_img, anchor_label

        return anchor_img, anchor_label

    def __len__(self):
        return len(self.labels)

In [4]:
import torch
import torch.nn as nn


class TripletLoss(nn.Module):
    '''
    Compute normal triplet loss or soft margin triplet loss given triplets
    '''
    def __init__(self, margin = None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        if self.margin is None:  # use soft-margin
            self.Loss = nn.SoftMarginLoss()
        else:
            self.Loss = nn.TripletMarginLoss(margin = margin, p = 2)

    def forward(self, anchor, pos, neg):
        if self.margin is None:
            num_samples = anchor.shape[0]
            y = torch.ones((num_samples, 1)).view(-1)
            if anchor.is_cuda: y = y.cuda()
            ap_dist = torch.norm(anchor - pos, 2, dim = 1).view(-1)
            an_dist = torch.norm(anchor - neg, 2, dim = 1).view(-1)
            loss = self.Loss(an_dist - ap_dist, y)
        else:
            loss = self.Loss(anchor, pos, neg)

        return loss

In [5]:
from torch.utils.data import DataLoader
from efficientnet_pytorch import EfficientNet
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import multiprocessing
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

train_set = TripletDataset(imgs, labels)
dataloader = DataLoader(train_set,
                        batch_size=16,
                        num_workers=multiprocessing.cpu_count()//2,
                        shuffle=False,
                        pin_memory=use_cuda)
model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=num_classes)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-4)
criterion = TripletLoss(margin=2.0)
NUM_EPOCH = 30

Loaded pretrained weights for efficientnet-b0


In [6]:
model = model.to(device)

for epoch in range(NUM_EPOCH):
    losses = []
    
    model.train()
    for idx, (anchor_img, positive_img, negative_img, anchor_label) in enumerate(tqdm(dataloader)):
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)
        anchor_label = anchor_label.to(device)

        optimizer.zero_grad()

        a_embeds = model(anchor_img)
        p_embeds = model(positive_img)
        n_embeds = model(negative_img)

        loss = criterion(a_embeds, p_embeds, n_embeds)

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        
        losses.append(loss.item())
    print('loss: ', loss.item(), 'epoch: ', epoch)

100%|██████████| 1796/1796 [09:08<00:00,  3.27it/s]

loss:  0.0 epoch:  0





In [7]:
train_results = []
labels = []

model.eval()

with torch.no_grad():
    for img, _, _, label in tqdm(dataloader):
        img = img.to(device)
        label = label.to(device)
        train_results.append(model(img).detach().cpu().numpy())
        labels.append(label.detach().cpu().numpy())

train_results = np.concatenate(train_results)
labels = np.concatenate(labels)

100%|██████████| 1796/1796 [00:55<00:00, 32.53it/s]


In [8]:
from matplotlib import colors as mcolors
import random
colors = dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS)
color_key = list(colors.keys())

use_colors = [colors[color_key[i]] for i in range(num_classes)]

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = "vscode"

import numpy as np
data = []
for idx, label in enumerate(np.unique(labels)[:]):
    tmp = train_results[np.where(labels == label)[0]]
    data.append(go.Scatter3d(x=tmp[:, 1], y=tmp[:, 0], z=tmp[:, 2], mode='markers', marker=dict(size=2), name=str(label)))
fig = go.Figure(data=data)
fig.update_layout(
    autosize=False,
    width=1000,
    height=1000,
    )
fig.show()