In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [2]:
# name2idx: icon name -> class index (iterated by get_data to build a batch).
with open('D:/dev/arknights_material_icon/derived/name2idx.json') as f:
    name2idx = json.loads(f.read())

# idx2name: the reverse lookup; its length defines the number of classes.
with open('D:/dev/arknights_material_icon/derived/idx2name.json') as f:
    idx2name = json.loads(f.read())

# Width of the classifier's output layer.
NUM_CLASS = len(idx2name)

In [3]:
def get_data():
    """Build one randomly augmented training batch with one sample per icon.

    For every ``(name, idx)`` entry of the module-level ``name2idx`` mapping,
    the icon PNG is read together with its alpha channel, resized to 128x128,
    randomly rescaled (zero-padded or centre-cropped to a side length drawn
    from [115, 144], then resized back to 128x128), mapped to the range
    [-1, 1], and composited over uniform random noise wherever the alpha
    value is not above 50.

    Returns:
        tuple: ``(images_np, labels_np)``. ``images_np`` is a float array of
        shape ``(N, 3, 128, 128)`` with the colour channels in the order
        OpenCV loaded them; ``labels_np`` holds the ``N`` class indices, where
        ``N == len(name2idx)``.

    Raises:
        FileNotFoundError: if an icon file cannot be read.
        ValueError: if an icon does not have exactly four channels, i.e. it
            carries no alpha channel to build the foreground mask from.
    """
    images = []
    labels = []
    for name, idx in name2idx.items():
        # Names mapped to the last class index live in the 'not_support' subfolder.
        if idx == NUM_CLASS - 1:
            name = 'not_support/' + name
        path = f'D:/dev/arknights_material_icon/{name}.png'
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising,
        # so fail here with the offending path instead of deep inside cv2.resize.
        if image is None:
            raise FileNotFoundError(f'cannot read icon image: {path}')
        # The last channel is used as the alpha mask below; without it the
        # mask would silently be taken from a colour channel.
        if image.ndim != 3 or image.shape[2] != 4:
            raise ValueError(
                f'expected a 4-channel image with alpha, got shape {image.shape}: {path}')
        image = cv2.resize(image, (128, 128))

        # Random scale jitter: pad (icon shrinks) or centre-crop (icon grows),
        # then bring the result back to 128x128.
        aug_size = np.random.randint(115, 145)
        if aug_size > 128:
            before = (aug_size - 128) // 2
            after = aug_size - 128 - before
            # Zero padding also zeroes alpha, so the border becomes background.
            image_aug = np.pad(image, ((before, after), (before, after), (0, 0)), 'constant')
        else:
            start = (128 - aug_size) // 2
            image_aug = image[start:start + aug_size, start:start + aug_size, :]
        image_aug = cv2.resize(image_aug, (128, 128))

        # Split off alpha as a boolean foreground mask and scale colours to [-1, 1].
        alpha_aug = image_aug[..., -1:] > 50
        image_aug = image_aug[..., :-1]
        image_aug = image_aug / 255 * 2 - 1

        # Replace the transparent background with uniform noise in [-1, 1].
        bg_noise = np.random.rand(*image_aug.shape) * 2 - 1
        image_aug = image_aug * alpha_aug + bg_noise * (1 - alpha_aug)

        images.append(image_aug)
        labels.append(idx)
    # (N, H, W, C) -> (N, C, H, W), the layout expected by nn.Conv2d.
    images_np = np.transpose(np.stack(images, 0), [0, 3, 1, 2])
    labels_np = np.array(labels)
    return images_np, labels_np

In [4]:
def convbr(in_c, out_c, kernel_size, stride):
    """Return a Conv2d -> BatchNorm2d -> ReLU stack.

    The convolution has no bias and is padded by ``kernel_size // 2``.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.

    Returns:
        nn.Sequential: the three layers in order.
    """
    conv = nn.Conv2d(
        in_c,
        out_c,
        kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_c)
    return nn.Sequential(conv, norm, nn.ReLU())
class Model(nn.Module):
    """Fully convolutional classifier producing ``NUM_CLASS`` logits per image.

    Nine ``convbr`` stages (four of them stride 2) are followed by a bias-free
    3x3 convolution with ``NUM_CLASS`` output channels; the resulting score
    map is averaged over its spatial dimensions.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) for each 3x3 convbr stage.
        # Trailing numbers are the spatial size after the stage for a 128x128 input.
        stages = [
            (3, 16, 1),
            (16, 32, 2),     # 64
            (32, 32, 1),
            (32, 64, 2),     # 32
            (64, 64, 1),
            (64, 128, 2),    # 16
            (128, 128, 1),
            (128, 256, 2),   # 8
            (256, 256, 1),
        ]
        layers = [convbr(c_in, c_out, 3, s) for c_in, c_out, s in stages]
        # Classification head: one score map per class.
        layers.append(nn.Conv2d(256, NUM_CLASS, 3, 1, 1, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, NUM_CLASS) for image batch ``x``."""
        score_map = self.model(x)
        # Global average pooling over height and width.
        return score_map.mean((2, 3))
def compute_loss(x, label):
    """Compute cross-entropy loss and top-1 accuracy for a batch.

    Args:
        x: logits tensor of shape (batch, num_classes).
        label: tensor of target class indices, shape (batch,).

    Returns:
        tuple: ``(loss, prec)`` as scalar tensors, where ``loss`` is the mean
        cross-entropy and ``prec`` is the fraction of rows whose argmax equals
        the label.
    """
    loss = F.cross_entropy(x, label)
    predicted = x.argmax(1)
    hits = (predicted == label).float()
    return loss, hits.mean()

In [5]:
# Use the GPU when one is present, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for step in range(10000):
    # get_data() re-augments every icon, so each step trains on a fresh batch.
    images_aug_np, label_np = get_data()
    # NumPy yields float64 / int64 arrays; cast to the dtypes the model and loss expect.
    images_aug = torch.from_numpy(images_aug_np).float().to(device)
    label = torch.from_numpy(label_np).long().to(device)
    optim.zero_grad()
    score = model(images_aug)
    loss, prec = compute_loss(score, label)
    loss.backward()
    optim.step()
    # Log every one of the first ten steps, then every tenth step.
    if step < 10 or step % 10 == 0:
        print(step, loss.item(), prec.item())

0 4.410513877868652 0.010526316240429878
1 4.2713518142700195 0.13684211671352386
2 3.9859442710876465 0.13684211671352386
3 3.7597897052764893 0.1473684310913086
4 3.55029034614563 0.1473684310913086
5 3.3409135341644287 0.1473684310913086
6 3.1655161380767822 0.15789474546909332
7 3.010869026184082 0.1894736886024475
8 2.8437323570251465 0.2210526466369629
9 2.725564479827881 0.25263160467147827
10 2.591139316558838 0.2947368621826172
20 1.4505565166473389 0.6842105388641357
30 0.7725146412849426 0.8105263710021973
40 0.5047242045402527 0.8105263710021973
50 0.3780112862586975 0.8736842274665833
60 0.3318443298339844 0.8842105865478516
70 0.25141260027885437 0.936842143535614
80 0.15809616446495056 0.9789474010467529
90 0.1943139284849167 0.936842143535614
100 0.11455710232257843 0.9789474010467529
110 0.04278215765953064 1.0
120 0.016385745257139206 1.0
130 0.007537475321441889 1.0
140 0.0050498261116445065 1.0
150 0.003840336110442877 1.0
160 0.0030030100606381893 1.0
170 0.0024179

2500 1.9675806470331736e-05 1.0
2510 1.9234104911447503e-05 1.0
2520 1.9776192857534625e-05 1.0
2530 1.8471166185918264e-05 1.0
2540 1.9384684492251836e-05 1.0
2550 1.77483816514723e-05 1.0
2560 2.0167701222817414e-05 1.0
2570 1.8471166185918264e-05 1.0
2580 1.898313894344028e-05 1.0
2590 1.840089498728048e-05 1.0
2600 1.8461127183400095e-05 1.0
2610 1.77483816514723e-05 1.0
2620 1.7888922229758464e-05 1.0
2630 1.8531398382037878e-05 1.0
2640 1.841093398979865e-05 1.0
2650 1.772830364643596e-05 1.0
2660 1.7637956261751242e-05 1.0
2670 1.6985441106953658e-05 1.0
2680 1.686497671471443e-05 1.0
2690 1.7517491869512014e-05 1.0
2700 1.6694319128873758e-05 1.0
2710 1.619238537386991e-05 1.0
2720 1.7336795281153172e-05 1.0
2730 1.632288876862731e-05 1.0
2740 1.5800877008587122e-05 1.0
2750 1.6363042959710583e-05 1.0
2760 1.5590065231663175e-05 1.0
2770 1.7477335859439336e-05 1.0
2780 1.4797009498579428e-05 1.0
2790 1.6061883798101917e-05 1.0
2800 1.4897396795277018e-05 1.0
2810 1.594141940586

5120 3.051757857974735e-06 1.0
5130 2.991525661855121e-06 1.0
5140 3.413150125197717e-06 1.0
5150 2.871061724363244e-06 1.0
5160 3.051757857974735e-06 1.0
5170 3.1722220228402875e-06 1.0
5180 2.9011775950493757e-06 1.0
5190 2.9614097911689896e-06 1.0
5200 2.7505975594976917e-06 1.0
5210 2.9814871140843024e-06 1.0
5220 2.881100272134063e-06 1.0
5230 2.871061724363244e-06 1.0
5240 2.871061724363244e-06 1.0
5250 2.670288040462765e-06 1.0
5260 2.891139047278557e-06 1.0
5270 3.0015644369996153e-06 1.0
5280 2.9011775950493757e-06 1.0
5290 2.680326815607259e-06 1.0
5300 2.7405587843531976e-06 1.0
5310 2.8007909804728115e-06 1.0
5320 2.5699014258862007e-06 1.0
5330 2.650210717547452e-06 1.0
5340 2.680326815607259e-06 1.0
5350 2.7405587843531976e-06 1.0
5360 2.6100560717168264e-06 1.0
5370 2.81082952824363e-06 1.0
5380 2.5799399736570194e-06 1.0
5390 2.3691277419857215e-06 1.0
5400 2.5899787488015136e-06 1.0
5410 2.489591906851274e-06 1.0
5420 2.489591906851274e-06 1.0
5430 2.509669229766587e-0

7740 6.223979767128185e-07 1.0
7750 6.926687206032511e-07 1.0
7760 6.023205969540868e-07 1.0
7770 7.227847618196392e-07 1.0
7780 6.926687206032511e-07 1.0
7790 5.119725301483413e-07 1.0
7800 5.32049909907073e-07 1.0
7810 7.027073820609075e-07 1.0
7820 5.119725301483413e-07 1.0
7830 4.818564889319532e-07 1.0
7840 7.729781259513402e-07 1.0
7850 4.818564889319532e-07 1.0
7860 5.922819354964304e-07 1.0
7870 5.621658942800423e-07 1.0
7880 4.918951503896096e-07 1.0
7890 4.3166312480025226e-07 1.0
7900 4.6177913759493094e-07 1.0
7910 5.220111916059977e-07 1.0
7920 4.918951503896096e-07 1.0
7930 3.4131502957279736e-07 1.0
7940 4.718178274742968e-07 1.0
7950 3.915083937044983e-07 1.0
7960 5.019338686906849e-07 1.0
7970 5.722046125811175e-07 1.0
7980 5.019338686906849e-07 1.0
7990 4.6177913759493094e-07 1.0
8000 5.220111916059977e-07 1.0
8010 5.32049909907073e-07 1.0
8020 5.922819354964304e-07 1.0
8030 6.223979767128185e-07 1.0
8040 5.32049909907073e-07 1.0
8050 5.119725301483413e-07 1.0
8060 5.

In [7]:
# Persist only the state_dict (parameters and buffers), not the module object itself.
torch.save(
    obj=model.state_dict(),
    f='D:/dev/arknights_material_icon/derived/model.bin',
)