In [None]:
!pip install ppim

In [None]:
from ppim import deit_b_distilled_384
import paddle
import numpy as np
import paddle.nn as nn
from paddle.io import DataLoader
from paddle.vision.datasets import Cifar100
import paddle.vision.transforms as T
import paddle.nn.functional as F
import matplotlib.pyplot as plt
from utils import show_imgs
import os

In [None]:
BATCH_SIZE = 24
DROP_RATIO = 0.3
LEARNING_RATE = 3e-4
EPOCH_NUM = 100
ALPHA = 5
PARAM_PATH = "./deit.pdparams"

In [None]:
deit = deit_b_distilled_384(pretrained=True)
train_transforms = T.Compose([
    T.Resize(384, interpolation='bicubic'),
    T.CenterCrop(384),
    T.RandomHorizontalFlip(0.5),
    T.RandomRotation(15),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    np.array
])
val_transforms = T.Compose([
    T.Resize(384, interpolation='bicubic'),
    T.CenterCrop(384),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    np.array
])

In [None]:
train_dataset = Cifar100(mode="train", transform=train_transforms, backend="pil")
val_dataset = Cifar100(mode="test", transform=val_transforms, backend="pil")
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
data, label = next(train_dataloader())
show_imgs([data[0], data[1],data[2]])

In [None]:
def reader(dataset, mode = "train"):
    if mode == "train":
        parti = np.arange(0, 1000)
    else:
        parti = np.arange(2000, 2100)
    for k in range(len(parti) // BATCH_SIZE):
        idxs = np.random.choice(parti, BATCH_SIZE)
        data = []
        labels = []
        for i in range(BATCH_SIZE):
            img, label = dataset[idxs[i]]
            data.append(img)
            labels.append(label)
        data = paddle.to_tensor(data)
        labels = paddle.to_tensor(labels)
        yield data, labels

In [None]:
class ViT(nn.Layer):
    def __init__(self, deit):
        super(ViT, self).__init__()
        self.deit = deit
        self.nets = nn.Sequential(
            self.deit,
            nn.Linear(1000, 512),
            nn.Tanh(),
            nn.Dropout(DROP_RATIO),
            nn.Linear(512, 100),
            nn.Dropout(DROP_RATIO)
        )
    
    def forward(self, x):
        return self.nets(x)

x = paddle.rand([1,3,384,384])
model = ViT(deit)
if os.path.exists(PARAM_PATH):
    state_dict = paddle.load(PARAM_PATH)
    model.load_dict(state_dict)
    print("load params")
y = model(x)
print(y.shape)

load params
[1, 100]


In [7]:
'''
开局使用AdamW加速收敛
当准确率达到一定程度后切换Momentum调优
'''
# optimizer = paddle.optimizer.AdamW(learning_rate=LEARNING_RATE, parameters=model.parameters())
scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=LEARNING_RATE,decay_steps=20,end_lr=LEARNING_RATE/10)
optimizer = paddle.optimizer.Momentum(learning_rate=scheduler,parameters=model.parameters(),weight_decay=1e-2)

# 交叉熵损失
ce_loss = paddle.nn.CrossEntropyLoss()
accuracy = paddle.metric.Accuracy()
# 最优准确率
max_score = 0.
# 最优准确率对应轮数
ex_epoch = 0
for epoch in range(EPOCH_NUM):
    model.train()
    for i, (data, label) in enumerate(train_dataloader()):
        summary = []
        # 前向传播两次
        label_hat_A = model(data)
        label_hat_B = model(data)
        # cross entropy loss
        CE_loss = ce_loss(label_hat_A, label) + ce_loss(label_hat_B, label)
        # KL divergence loss
        KL_loss = 0.5 * (F.kl_div(F.softmax(label_hat_A, axis=-1), F.softmax(label_hat_B, axis=-1)) + \
                         F.kl_div(F.softmax(label_hat_B, axis=-1), F.softmax(label_hat_A, axis=-1)))
        # 损失加权求和
        loss = CE_loss + ALPHA * KL_loss
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 清除梯度
        optimizer.clear_gradients()
        if i % 30 == 0:
            print("[train]epoch:%d,i:%d,loss:%f" % (epoch, i, loss))

    model.eval()
    with paddle.no_grad():
        for j, (eval_data, eval_label) in enumerate(val_dataloader()):
            summary = []
            eval_label_hat = model(eval_data)
            eval_indexs = eval_label_hat.argmax(-1)
            eval_loss = ce_loss(eval_label_hat, eval_label)
            correct = accuracy.compute(eval_label_hat, eval_label)
            accuracy.update(correct)
            acc = accuracy.accumulate()
            summary.append(acc)
            accuracy.reset()

    print("[eval]epoch:%d,loss:%f,acc:%f" % (epoch, eval_loss, sum(summary) / len(summary)))
    if sum(summary) / len(summary) >= max_score:
        max_score = sum(summary) / len(summary)
        ex_epoch = epoch
        paddle.save(model.state_dict(), "./deit_.pdparams")
        print("[eval]saved params")
    print("[eval]ex_epoch:%d,best acc:%f" % (ex_epoch, max_score))

[train]epoch:0,i:0,loss:2.511044
[train]epoch:0,i:30,loss:2.172033
[train]epoch:0,i:60,loss:1.895260
[train]epoch:0,i:90,loss:2.801440
[train]epoch:0,i:120,loss:1.813140
[train]epoch:0,i:150,loss:1.573447
[train]epoch:0,i:180,loss:2.255599
[train]epoch:0,i:210,loss:2.508465
[train]epoch:0,i:240,loss:2.730235
[train]epoch:0,i:270,loss:2.229047
[train]epoch:0,i:300,loss:3.265573
[train]epoch:0,i:330,loss:2.896493
[train]epoch:0,i:360,loss:2.090678
[train]epoch:0,i:390,loss:2.520741
[train]epoch:0,i:420,loss:1.966374
[train]epoch:0,i:450,loss:1.975958
[train]epoch:0,i:480,loss:2.574533
[train]epoch:0,i:510,loss:2.763706
[train]epoch:0,i:540,loss:2.167204
[train]epoch:0,i:570,loss:1.991791
[train]epoch:0,i:600,loss:2.480460
[train]epoch:0,i:630,loss:2.614878
[train]epoch:0,i:660,loss:2.481817
[train]epoch:0,i:690,loss:2.692974
[train]epoch:0,i:720,loss:2.776292
[train]epoch:0,i:750,loss:3.166075
[train]epoch:0,i:780,loss:1.514883
[train]epoch:0,i:810,loss:2.378160
[train]epoch:0,i:840,loss

KeyboardInterrupt: 