In [1]:
import numpy as np
import json
import pandas as pd
from tqdm import trange
import argparse
from modules.gem import GeM
# from modules.swin.swin_transformer import SwinTransformer
from modules.swin.build import build_model
from modules.swin.config import get_config
from modules.gswin import SwinFM
from utils.train_util import set_seed
from torch.utils.data import DataLoader
from datasets.dl import GeMData, GeMClass
from datasets.config import GeMConfig, SwinConfig
import torch
import os
from tqdm import tqdm
# import yaml
import gc
import torch.nn.functional as F
from matplotlib import pyplot as plt
from IPython.core.interactiveshell import InteractiveShell
from sklearn.metrics import roc_auc_score, accuracy_score
# InteractiveShell.ast_node_interactivity = "all"

In [2]:
class tmp_config:
    def __init__(self):
        self.cfg = './para/swin_large_patch4_window7_224_22k.yaml'
        self.opts = None
        self.batch_size = 32
        self.data_path = None
        self.zip = True
        self.cache_mode = 'part'
        self.pretrained = './para/swin_large_patch4_window7_224_22k.pth'
        self.resume = None
        self.accumulation_steps = None
        self.use_checkpoint = True
        self.amp_opt_level = 'O1'
        self.output = 'output'
        self.tag = None
        self.eval = True
        self.throughput = True
        self.local_rank=0
        
tc = tmp_config()

In [3]:
config = get_config(tc)
config

=> merge config from ./para/swin_large_patch4_window7_224_22k.yaml


CfgNode({'BASE': [''], 'DATA': CfgNode({'BATCH_SIZE': 32, 'DATA_PATH': '', 'DATASET': 'imagenet22K', 'IMG_SIZE': 224, 'INTERPOLATION': 'bicubic', 'ZIP_MODE': True, 'CACHE_MODE': 'part', 'PIN_MEMORY': True, 'NUM_WORKERS': 8}), 'MODEL': CfgNode({'TYPE': 'swin', 'NAME': 'swin_large_patch4_window7_224_22k', 'PRETRAINED': './para/swin_large_patch4_window7_224_22k.pth', 'RESUME': '', 'NUM_CLASSES': 1000, 'DROP_RATE': 0.0, 'DROP_PATH_RATE': 0.2, 'LABEL_SMOOTHING': 0.1, 'SWIN': CfgNode({'PATCH_SIZE': 4, 'IN_CHANS': 3, 'EMBED_DIM': 192, 'DEPTHS': [2, 2, 18, 2], 'NUM_HEADS': [6, 12, 24, 48], 'WINDOW_SIZE': 7, 'MLP_RATIO': 4.0, 'QKV_BIAS': True, 'QK_SCALE': None, 'APE': False, 'PATCH_NORM': True}), 'SWIN_MLP': CfgNode({'PATCH_SIZE': 4, 'IN_CHANS': 3, 'EMBED_DIM': 96, 'DEPTHS': [2, 2, 6, 2], 'NUM_HEADS': [3, 6, 12, 24], 'WINDOW_SIZE': 7, 'MLP_RATIO': 4.0, 'APE': False, 'PATCH_NORM': True})}), 'TRAIN': CfgNode({'START_EPOCH': 0, 'EPOCHS': 90, 'WARMUP_EPOCHS': 5, 'WEIGHT_DECAY': 0.05, 'BASE_LR': 0.0

In [4]:
st = build_model(config)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
def load_pretrained(config, model):
    print(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    state_dict = checkpoint['model']

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete relative_coords_table since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]

    # bicubic interpolate relative_position_bias_table if not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            print(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # bicubic interpolate relative_position_bias_table if not match
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
    for k in absolute_pos_embed_keys:
        # dpe
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = model.state_dict()[k]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        if C1 != C1:
            print(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized

    # check classifier, if not match, then re-init classifier to zero
    head_bias_pretrained = state_dict['head.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.head.bias.shape[0]
    if (Nc1 != Nc2):
        if Nc1 == 21841 and Nc2 == 1000:
            print("loading ImageNet-22K weight to ImageNet-1K ......")
            map22kto1k_path = f'para/map22kto1k.txt'
            with open(map22kto1k_path) as f:
                map22kto1k = f.readlines()
            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
        else:
            torch.nn.init.constant_(model.head.bias, 0.)
            torch.nn.init.constant_(model.head.weight, 0.)
            del state_dict['head.weight']
            del state_dict['head.bias']
            print(f"Error in loading classifier head, re-init classifier head to 0")

    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)

    print(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")

    del checkpoint

In [6]:
load_pretrained(config, st)

loading ImageNet-22K weight to ImageNet-1K ......
_IncompatibleKeys(missing_keys=['layers.0.blocks.0.attn.relative_position_index', 'layers.0.blocks.1.attn_mask', 'layers.0.blocks.1.attn.relative_position_index', 'layers.1.blocks.0.attn.relative_position_index', 'layers.1.blocks.1.attn_mask', 'layers.1.blocks.1.attn.relative_position_index', 'layers.2.blocks.0.attn.relative_position_index', 'layers.2.blocks.1.attn_mask', 'layers.2.blocks.1.attn.relative_position_index', 'layers.2.blocks.2.attn.relative_position_index', 'layers.2.blocks.3.attn_mask', 'layers.2.blocks.3.attn.relative_position_index', 'layers.2.blocks.4.attn.relative_position_index', 'layers.2.blocks.5.attn_mask', 'layers.2.blocks.5.attn.relative_position_index', 'layers.2.blocks.6.attn.relative_position_index', 'layers.2.blocks.7.attn_mask', 'layers.2.blocks.7.attn.relative_position_index', 'layers.2.blocks.8.attn.relative_position_index', 'layers.2.blocks.9.attn_mask', 'layers.2.blocks.9.attn.relative_position_index', '

In [2]:
mi = GeMConfig('cifar100')
mi.set_arch('swin')
model = GeM(mi)
# model = SwinFM()

=> merge config from ./swin_para/swin_large_patch4_window7_224_22k.yaml


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


loading ImageNet-22K weight to ImageNet-1K ......
_IncompatibleKeys(missing_keys=['layers.0.blocks.0.attn.relative_position_index', 'layers.0.blocks.1.attn_mask', 'layers.0.blocks.1.attn.relative_position_index', 'layers.1.blocks.0.attn.relative_position_index', 'layers.1.blocks.1.attn_mask', 'layers.1.blocks.1.attn.relative_position_index', 'layers.2.blocks.0.attn.relative_position_index', 'layers.2.blocks.1.attn_mask', 'layers.2.blocks.1.attn.relative_position_index', 'layers.2.blocks.2.attn.relative_position_index', 'layers.2.blocks.3.attn_mask', 'layers.2.blocks.3.attn.relative_position_index', 'layers.2.blocks.4.attn.relative_position_index', 'layers.2.blocks.5.attn_mask', 'layers.2.blocks.5.attn.relative_position_index', 'layers.2.blocks.6.attn.relative_position_index', 'layers.2.blocks.7.attn_mask', 'layers.2.blocks.7.attn.relative_position_index', 'layers.2.blocks.8.attn.relative_position_index', 'layers.2.blocks.9.attn_mask', 'layers.2.blocks.9.attn.relative_position_index', '

In [3]:
# pic_matrix = torch.ByteTensor(np.load("data/imageset_small.npy"))
pic_matrix = torch.ByteTensor(np.random.randint(low=0, high=255, size=(256, 3, 224, 224)))
dataset_img = torch.LongTensor(np.random.randint(low=0, high=256, size=(1280, 1)))
dataset_label = torch.LongTensor(np.random.randint(low=0, high=100, size=(1280, 1)))

In [4]:
train_dataset = GeMClass(pic_matrix, torch.cat([dataset_img, dataset_label], dim=-1))
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# input_label = torch.zeros((32), dtype=torch.long).to(0)
input_label = torch.randint(low=0, high=100, size=(32, 1)).reshape(-1)

In [5]:
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-4)

In [7]:
for epoch in range(2):
    model.train()
    model.zero_grad()
    index = 0
    steps_one_epoch = len(train_data_loader)
    enum_dataloader = tqdm(train_data_loader, total=steps_one_epoch, desc="EP-{} train".format(epoch))
    loss_list = []
    for data in enum_dataloader:
    #     if index >= steps_one_epoch:
    #         break

        model_in = data[:, :-1] / 255.0
#         pred = model.st(model_in.reshape(-1, 3, 224, 224))
#         print(pred.size())
        pred = model.predict_class(model_in, 224, scale=1, encoder='gem')
        loss = F.cross_entropy(pred, data[:, -1])
        loss_list.append(loss)

        loss.backward()
        optimizer.step()
        model.zero_grad()

        enum_dataloader.set_description("EP-{} train loss: {}".format(epoch, loss))
        enum_dataloader.refresh()
        index += 1
    
    print('epoch {} end'.format(epoch))


EP-0 train:   0%|                                                                                | 0/40 [00:00<?, ?it/s]

torch.Size([32, 1536, 7, 7])


EP-0 train loss: nan:   2%|█▌                                                            | 1/40 [00:52<34:00, 52.33s/it]


KeyboardInterrupt: 