In [None]:
import glob
import os
import time
import warnings
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from sklearn.metrics import classification_report

from torch.backends import cudnn
from torch.backends import cuda
from torch.cuda import amp

import timm
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm import create_model

import wandb
from yacs.config import CfgNode as CN

import albumentations as A
from albumentations.pytorch import ToTensorV2
import shutil
import pickle

# Config

In [None]:
# Experiment configuration. Every tunable value of the notebook lives in this
# yacs CfgNode tree; sub-nodes must be created (``CN()``) before their fields
# are assigned.
cfg = CN()

cfg.PROJECT = 'AIdea fall competition'  # passed to wandb.init as the project name
cfg.NAME = 'ConvNeXt_base'  # wandb run name and prefix of the saved checkpoint files

cfg.SEED = 1224  # seeds the RNGs and the StratifiedKFold split
cfg.USE_WANDB = False  # enable wandb.init / wandb.watch / wandb.log
cfg.WANDB_CONT = False  # forwarded to wandb.init(resume=...)
cfg.WANDB_ARTIFACT = False  # not referenced in the visible part of this notebook

cfg.SAVE_DIR = 'ConvNeXt_base' # output directory; holds the 'checkpoint' and 'best_acc' sub-folders

cfg.DATA = CN()
cfg.DATA.DATASET = 'AIdea fall competition'
cfg.DATA.TRAIN_DIR = '../dataset/train' # training images, one sub-folder per class
cfg.DATA.PUBLIC_TEST_DIR = '../dataset/public_test' # public test images
cfg.DATA.PRIVATE_TEST_DIR = '../dataset/p_test' # private test images
cfg.DATA.PSEUDO_DIR = '../dataset/pseudo_data' # pseudo-labelled images (one sub-folder per class); None for no pseudo labels
cfg.DATA.TARGET_CSV = '../dataset/target_merge.csv' # CSV with 'filename', 'target_x', 'target_y' columns for train, public and private images
cfg.DATA.BATCH_SIZE = 8 # depends on GPU memory size
cfg.DATA.IMG_SIZE = 768  # side length images are resized to
cfg.DATA.PIN_MEMORY = True
cfg.DATA.NUM_WORKERS = 8 # depends on CPU core count
cfg.DATA.CROP = True  # square-crop each image around its target coordinate before the transform
cfg.KFOLD_NUM = 0  # index of the stratified fold used as the validation split

# Arguments forwarded to timm's Mixup.
cfg.AUG = CN()
cfg.AUG.MIXUP_ALPHA = 0.8
cfg.AUG.CUTMIX_ALPHA = 1.0
cfg.AUG.CUTMIX_MINMAX = None
cfg.AUG.MIXUP_PROB = 0.5
cfg.AUG.MIXUP_SWITCH_PROB = 0.5
cfg.AUG.MIXUP_MODE = 'batch'
cfg.AUG.LABEL_SMOOTHING = 0.1

cfg.MODEL = CN()
cfg.MODEL.BACKBONE = 'convnext_base_in22k'  # timm model name
cfg.MODEL.NUM_CLASSES = 33
cfg.MODEL.PRETRAINED = True
cfg.MODEL.USE_EMA = True  # keep an exponential moving average copy of the model

cfg.TRAIN = CN()
cfg.TRAIN.EPOCHS = 10 # pseudo train: 5
cfg.TRAIN.LR = 1e-4 # pseudo train: 5e-5
cfg.TRAIN.ACCUMULATION_STEPS = 8  # batches accumulated per optimizer step
cfg.TRAIN.AMP_ENABLE = True  # automatic mixed precision (autocast + GradScaler)

cfg.TRAIN.PRETRAINED = None # pseudo train: path of the best acc@1 checkpoint from the first training run
cfg.TRAIN.DEL_PRETRAINED = None  # when not None, load_pretrained drops state-dict keys and loads non-strictly
cfg.TRAIN.USE_CHECKPOINT = False  # resume from a file in '<SAVE_DIR>/checkpoint'
cfg.TRAIN.NEW_RUN = False  # when True, only the model weights are restored from the checkpoint

cfg.TRAIN.OPTIMIZER = CN()
cfg.TRAIN.OPTIMIZER.NAME = 'AdamW' #SGD #Adam
cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY = 1e-8
cfg.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) #AdamW only
#cfg.TRAIN.OPTIMIZER.MOMENTUM = 0.9 #SGD only
#cfg.TRAIN.OPTIMIZER.NESTEROV = True #SGD only

cfg.TRAIN.LR_SCHEDULER = CN()
cfg.TRAIN.LR_SCHEDULER.NAME = 'CosineAnnealingLR'
cfg.TRAIN.LR_SCHEDULER.T_MAX = 10 # pseudo train: 5
cfg.TRAIN.LR_SCHEDULER.ETA_MIN = 1e-6

cfg.VALID = CN()
cfg.VALID.MODEL_EMA = True  # validate / extract embeddings with the EMA weights instead of the raw model

# Initialize

In [None]:
# Select GPU 0 and cache the mixed-precision flag that autocast / GradScaler read later.
torch.cuda.set_device(0)
amp_enable = cfg.TRAIN.AMP_ENABLE

cudnn.enabled = True
cudnn.benchmark = True  # let cuDNN benchmark and pick convolution algorithms for the observed input shapes

# Permit TF32 math for CUDA matmuls and cuDNN convolutions.
cuda.matmul.allow_tf32 = True
cudnn.allow_tf32 = True

In [None]:
# Global random seed, taken from the experiment configuration.
seed = cfg.SEED

def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and asks cuDNN for deterministic
    algorithms.

    Args:
        seed: integer seed applied to all generators.
    """
    # Local import: the notebook's import cell does not import ``random``.
    import random

    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Previously missing: anything drawing from the stdlib generator was unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cudnn.benchmark is deliberately not touched here.
    torch.backends.cudnn.deterministic = True

# Apply the seed before any data splitting or model construction happens.
seed_everything(seed)

# Silence every Python warning for the rest of the notebook.
warnings.simplefilter("ignore")

In [None]:
# Start (or resume, when cfg.WANDB_CONT is set) the Weights & Biases run and
# attach the whole configuration tree to it.
if cfg.USE_WANDB:
    wandb.init(project=cfg.PROJECT, name=cfg.NAME, config=cfg, resume=cfg.WANDB_CONT)

# Load Data

In [None]:
# Class folders of the training set; the sorted copy defines the label indices.
folderlist = [
    'asparagus', 'bambooshoots', 'betel', 'broccoli', 'cauliflower',
    'chinesecabbage', 'chinesechives', 'custardapple', 'grape', 'greenhouse',
    'greenonion', 'kale', 'lemon', 'lettuce', 'litchi', 'longan', 'loofah',
    'mango', 'onion', 'others', 'papaya', 'passionfruit', 'pear',
    'pennisetum', 'redbeans', 'roseapple', 'sesbania', 'soybeans', 'sunhemp',
    'sweetpotato', 'taro', 'tea', 'waterbamboo',
]
class_list = sorted(folderlist)

# Per-image crop targets, looked up by file name inside the datasets.
target_df = pd.read_csv(cfg.DATA.TARGET_CSV)

# Collect every file below each class folder, in a stable (sorted) order.
train_images = sorted(
    path
    for folder in folderlist
    for path in glob.glob(os.path.join('{}/{}'.format(cfg.DATA.TRAIN_DIR, folder), "*"))
)
# The label of an image is the name of the folder that contains it.
train_labels = [os.path.dirname(path).rsplit('/', 1)[-1] for path in train_images]

print(f'Size of dataset: {len(train_images)}')

In [None]:
# Pseudo-labelled images, laid out like the training set (one folder per class).
# The list is always defined: it stays empty when cfg.DATA.PSEUDO_DIR is None,
# so code that concatenates it with the training list works either way and a
# stale list from an earlier run of this cell is never reused.
pseudo_images = []

if cfg.DATA.PSEUDO_DIR is not None:
    for folder in folderlist:
        pseudo_images += glob.glob(os.path.join('{}/{}'.format(cfg.DATA.PSEUDO_DIR, folder), "*"))

    # glob order is file-system dependent; sort so the dataset order (and hence
    # the seeded shuffling) is reproducible, as is done for the training images.
    pseudo_images = sorted(pseudo_images)

In [None]:
# Stratified 25-fold split of the labelled training images; one fold
# (cfg.KFOLD_NUM) is held out for validation.
sfolder = StratifiedKFold(n_splits=25, random_state=cfg.SEED, shuffle=True)

fold_pairs = list(sfolder.split(train_images, train_labels))
train_folds = [train_part for train_part, _ in fold_pairs]
val_folds = [val_part for _, val_part in fold_pairs]

valid_list = [train_images[i] for i in val_folds[cfg.KFOLD_NUM]]
# Pseudo-labelled images are only ever added to the training side.
train_list = [train_images[i] for i in train_folds[cfg.KFOLD_NUM]] + pseudo_images

print('Number of training Data: {}'.format(len(train_list)))
print('Number of validation Data: {}'.format(len(valid_list)))

# Data augmentation

In [None]:
# Training pipeline: random geometric augmentation, resize to the model input
# size, an occasional single cut-out hole, then normalisation and conversion to
# a tensor. The order of the steps matters.
train_transform =  A.Compose([
    # Random shift/scale up to 20% and rotation up to 20 degrees; border_mode=0
    # is presumably cv2.BORDER_CONSTANT - confirm against the installed version.
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20, border_mode=0, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Resize(cfg.DATA.IMG_SIZE, cfg.DATA.IMG_SIZE),
    # One hole of at most a quarter of the image side, applied to 25% of samples.
    A.Cutout(max_h_size=int(cfg.DATA.IMG_SIZE * 0.25), max_w_size=int(cfg.DATA.IMG_SIZE * 0.25), num_holes=1, p=0.25),
    A.Normalize(),  # albumentations default mean / std
    ToTensorV2(),
])

# Validation / inference pipeline: deterministic resize + normalise only.
valid_transform = A.Compose([
    A.Resize(cfg.DATA.IMG_SIZE, cfg.DATA.IMG_SIZE),
    A.Normalize(),
    ToTensorV2(),
])

In [None]:
# Batch-level Mixup / CutMix from timm. It is called on (images, integer labels)
# in the training loop and returns mixed images with soft targets, which is why
# training uses SoftTargetCrossEntropy.
mixup_fn = Mixup(
    mixup_alpha=cfg.AUG.MIXUP_ALPHA,
    cutmix_alpha=cfg.AUG.CUTMIX_ALPHA,  # CutMix strength
    cutmix_minmax=cfg.AUG.CUTMIX_MINMAX,
    prob=cfg.AUG.MIXUP_PROB,  # probability of applying either augmentation
    switch_prob=cfg.AUG.MIXUP_SWITCH_PROB,  # probability of choosing CutMix over Mixup
    mode=cfg.AUG.MIXUP_MODE,
    label_smoothing=cfg.AUG.LABEL_SMOOTHING,  # smoothing applied to the generated soft targets
    num_classes=cfg.MODEL.NUM_CLASSES
)

# Load Datasets

In [None]:
class AIdeaDataset(Dataset):
    """Labelled image dataset whose class is the name of each file's parent folder.

    Args:
        file_list: image paths; the parent directory name of every path must be
            an entry of ``class_list``.
        class_list: ordered class names; the returned label is the index of the
            folder name in this list.
        transform: callable invoked as ``transform(image=array)`` that returns a
            dict with an ``'image'`` entry (albumentations convention).
        crop: accepted for backward compatibility but not consulted; cropping is
            controlled by the global ``cfg.DATA.CROP``.
    """

    def __init__(self, file_list, class_list, transform, crop=False):
        self.file_list = file_list
        self.class_list = class_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    @staticmethod
    def _lookup_target(filename):
        """Return ``(target_x, target_y)`` of the single ``target_df`` row for ``filename``.

        The values are read directly instead of being round-tripped through
        ``Series.to_string``, which rounds to pandas' display precision.

        Raises:
            ValueError: if ``target_df`` holds zero or several rows for the file.
        """
        rows = target_df.loc[target_df['filename'] == filename]
        if len(rows) != 1:
            raise ValueError(
                'expected exactly one row in target_df for {!r}, found {}'.format(filename, len(rows))
            )
        return float(rows['target_x'].iloc[0]), float(rows['target_y'].iloc[0])

    @staticmethod
    def _square_crop_box(W, H, target_x, target_y):
        """Return ``(x_min, y_min, x_max, y_max)`` of a square crop with side ``min(W, H)``.

        The square starts centred on the longer axis, is shifted by the target
        offset (a fraction of that axis' length) and is clamped so that it stays
        inside the image.
        """
        if W > H:
            x_min = int(min(max(W / 2 - H / 2 + target_x * W, 0), W - H))
            return x_min, 0, x_min + H, H
        y_min = int(min(max(H / 2 - W / 2 + target_y * H, 0), H - W))
        return 0, y_min, W, y_min + W

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_index)`` for sample ``idx``."""
        path = self.file_list[idx]

        # The context manager closes the file handle right away; convert('RGB')
        # guarantees three channels for grayscale / palette / RGBA files.
        with Image.open(path) as handle:
            pil_img = handle.convert('RGB')
        W, H = pil_img.size

        img = np.array(pil_img)

        if cfg.DATA.CROP:
            target_x, target_y = self._lookup_target(os.path.split(path)[1])
            x_min, y_min, x_max, y_max = self._square_crop_box(W, H, target_x, target_y)
            img = A.Crop(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max)(image=img)['image']

        img_transformed = self.transform(image=img)['image']

        # The class name is the last component of the file's directory.
        label_name = os.path.split(path)[0].split('/')[-1]
        label = self.class_list.index(label_name)

        return img_transformed, label

In [None]:
# train_data: training files with augmentation.
# valid_data: held-out fold with the deterministic transform.
# query_data: the same files as train_data but without augmentation, used later
#             to extract embeddings of the training set.
train_data = AIdeaDataset(train_list, class_list, transform=train_transform)
valid_data = AIdeaDataset(valid_list, class_list, transform=valid_transform)
query_data = AIdeaDataset(train_list, class_list, transform=valid_transform)

In [None]:
def _build_loader(dataset, training):
    """Create a DataLoader with the shared settings from ``cfg.DATA``.

    Training loaders shuffle and drop the last incomplete batch; evaluation
    loaders keep the original order and every sample.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.DATA.BATCH_SIZE,
        shuffle=training,
        pin_memory=cfg.DATA.PIN_MEMORY,
        num_workers=cfg.DATA.NUM_WORKERS,
        drop_last=training,
    )


train_loader = _build_loader(train_data, training=True)
valid_loader = _build_loader(valid_data, training=False)
query_loader = _build_loader(query_data, training=False)

# Model

In [None]:
from torch.nn.parameter import Parameter

def gem(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions of ``x``.

    Values are clamped to at least ``eps``, raised to the power ``p``, averaged
    over the full spatial window and the result is raised to ``1 / p``.
    """
    full_window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, full_window)
    return pooled.pow(1. / p)

class GeM(nn.Module):
    """Generalized-mean pooling layer wrapping :func:`gem`.

    Args:
        p: pooling exponent.
        eps: lower clamp applied to the input before the power is taken.
        p_trainable: when True, ``p`` is stored as a learnable one-element
            ``Parameter``; otherwise it is kept as the plain number passed in.
    """

    def __init__(self, p=3, eps=1e-6, p_trainable=False):
        super(GeM, self).__init__()
        if p_trainable:
            self.p = Parameter(torch.ones(1) * p)
        else:
            self.p = p
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        # ``p`` is a tensor only when it is trainable; a plain number has no
        # ``.data`` attribute, which used to make repr()/print() raise.
        if isinstance(self.p, torch.Tensor):
            p_value = self.p.data.tolist()[0]
        else:
            p_value = float(self.p)
        return '{}(p={:.4f}, eps={})'.format(self.__class__.__name__, p_value, self.eps)

In [None]:
class Net(nn.Module):
    """timm backbone with a trainable GeM pooling layer in its head.

    The backbone name, the number of classes and whether pretrained weights are
    downloaded are all taken from ``cfg.MODEL`` (the class count and pretrained
    flag used to be hard-coded to the same values as the current config).
    """

    def __init__(self):
        super(Net, self).__init__()

        self.n_classes = cfg.MODEL.NUM_CLASSES

        self.backbone = timm.create_model(cfg.MODEL.BACKBONE,
                                          pretrained=cfg.MODEL.PRETRAINED,
                                          num_classes=self.n_classes,
                                          drop_path_rate=0.2,
                                          head_init_scale=0.001)

        # Replace the head's pooling layer with GeM whose exponent is learned.
        self.backbone.head.global_pool = GeM(p_trainable=True)

    def forward(self, img):
        """Return ``(logits, x_emb)`` for a batch of images.

        ``x_emb`` is the output of ``forward_head(..., pre_logits=True)``,
        presumably the pooled features just before the classifier layer -
        confirm against the installed timm version.
        """
        x = self.backbone.forward_features(img)

        logits = self.backbone.forward_head(x)
        x_emb = self.backbone.forward_head(x, pre_logits=True)

        return logits, x_emb

In [None]:
# Build the network on the current GPU. When enabled, an exponential moving
# average copy (decay 0.999) is kept on the CPU; it is only defined when
# cfg.MODEL.USE_EMA is True.
model = Net()
model.cuda()
if cfg.MODEL.USE_EMA:
    model_ema = timm.utils.ModelEmaV2(model, decay=0.999, device='cpu')

# Update function

In [None]:
# loss function
criterion_STCE = SoftTargetCrossEntropy()  # training loss on the soft targets produced by mixup_fn
criterion_CE = nn.CrossEntropyLoss()  # validation loss on integer labels

# optimizer
# Names are compared with ``==``: ``is`` tests object identity, which only
# matches string literals by accident of interning.
if cfg.TRAIN.OPTIMIZER.NAME == 'Adam':
    optimizer = optim.Adam(
        model.parameters(),
        lr=cfg.TRAIN.LR,
        weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY,
    )
elif cfg.TRAIN.OPTIMIZER.NAME == 'SGD':
    # MOMENTUM / NESTEROV are optional config keys (CfgNode is a dict); fall
    # back to the values suggested in the config cell when they are absent.
    optimizer = optim.SGD(
        model.parameters(),
        lr=cfg.TRAIN.LR,
        nesterov=cfg.TRAIN.OPTIMIZER.get('NESTEROV', True),
        momentum=cfg.TRAIN.OPTIMIZER.get('MOMENTUM', 0.9),
        weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY,
    )
elif cfg.TRAIN.OPTIMIZER.NAME == 'AdamW':
    optimizer = optim.AdamW(
        model.parameters(),
        lr=cfg.TRAIN.LR,
        weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY,
        betas=cfg.TRAIN.OPTIMIZER.BETAS,
    )
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError('Unknown optimizer: {!r} (expected Adam, SGD or AdamW)'.format(cfg.TRAIN.OPTIMIZER.NAME))

# scheduler
if cfg.TRAIN.LR_SCHEDULER.NAME == 'CosineAnnealingLR':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=cfg.TRAIN.LR_SCHEDULER.T_MAX,
        eta_min=cfg.TRAIN.LR_SCHEDULER.ETA_MIN,
    )
elif cfg.TRAIN.LR_SCHEDULER.NAME == 'ReduceLROnPlateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
    )
else:
    raise ValueError(
        'Unknown LR scheduler: {!r} (expected CosineAnnealingLR or ReduceLROnPlateau)'.format(
            cfg.TRAIN.LR_SCHEDULER.NAME
        )
    )

# gradscaler
# With enabled=False the scaler's scale/step/update calls are pass-throughs, so
# the training loop is the same with or without AMP.
scaler = amp.GradScaler(enabled=amp_enable)

# Load model

In [None]:
def load_pretrained(model, pretrained_dir=None):
    """Load model weights from a checkpoint file written by this notebook.

    Args:
        model: module that receives the weights.
        pretrained_dir: path of the checkpoint file; its ``'model_state_dict'``
            entry is loaded.

    When ``cfg.TRAIN.DEL_PRETRAINED`` is not None, every state-dict key that
    contains that substring is dropped and the load is non-strict; otherwise
    the load is strict.
    """
    pretrained = torch.load(pretrained_dir, map_location='cpu')
    state_dict = pretrained['model_state_dict']

    # The pattern lives under cfg.TRAIN; the previous lookup used
    # cfg.MODEL.DEL_PRETRAINED, which does not exist in the config.
    del_pattern = cfg.TRAIN.DEL_PRETRAINED
    if del_pattern is not None:
        for key in [name for name in state_dict.keys() if del_pattern in name]:
            del state_dict[key]
    strict = del_pattern is None

    # NOTE(review): only ``model`` is updated here; an EMA copy created before
    # this call keeps its old weights - confirm that this is intended.
    model.load_state_dict(state_dict, strict=strict)
    print('Loaded pretrained: {}\n'.format(pretrained_dir))

def load_checkpoint(model, optimizer=None, scheduler=None, specific_checkpoint=None):
    """Restore training state from ``<cfg.SAVE_DIR>/checkpoint``.

    Args:
        model: module whose weights are restored.
        optimizer: optional optimizer to restore.
        scheduler: optional LR scheduler to restore.
        specific_checkpoint: file name inside the checkpoint directory; when
            None the most recently created file is used.

    Returns:
        ``(last_epochs, best_acc)`` read from the checkpoint, or ``(0, 0.0)``
        when nothing was loaded or ``cfg.TRAIN.NEW_RUN`` is set.

    The optimizer, scheduler, GradScaler and EMA states are restored on a
    best-effort basis: a missing object or checkpoint entry is skipped, and a
    failing ``load_state_dict`` is reported instead of being silently ignored.
    """
    last_epochs = 0
    best_acc = 0.0
    checkpoint_dir = '{}/checkpoint'.format(cfg.SAVE_DIR)

    checkpoint_list = [os.path.join(checkpoint_dir, name) for name in os.listdir(checkpoint_dir)]

    if not checkpoint_list or not cfg.TRAIN.USE_CHECKPOINT:
        return last_epochs, best_acc

    if specific_checkpoint is not None:
        checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
    else:
        checkpoint_path = max(checkpoint_list, key=os.path.getctime)

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    if not cfg.TRAIN.NEW_RUN:
        last_epochs = checkpoint['epoch']
        best_acc = checkpoint['best_acc']

        # ``scaler`` and ``model_ema`` are notebook-level globals that may not
        # exist (e.g. EMA disabled), hence the lookup through globals().
        optional_states = (
            ('optimizer', optimizer, 'optimizer_state_dict'),
            ('scheduler', scheduler, 'scheduler_state_dict'),
            ('scaler', globals().get('scaler'), 'scaler_state_dict'),
            ('model_ema', globals().get('model_ema'), 'model_ema_state_dict'),
        )
        for label, target, key in optional_states:
            if target is None or key not in checkpoint:
                continue
            try:
                target.load_state_dict(checkpoint[key])
            except Exception as err:  # keep best-effort behaviour, but say what failed
                print('WARNING: could not restore {} state from checkpoint: {}'.format(label, err))

    print('Loaded checkpoint: {}\n'.format(checkpoint_path))

    return last_epochs, best_acc

def load_data(model, optimizer=None, scheduler=None, specific_checkpoint=None):
    """Create the output folders and load initial weights / training state.

    ``cfg.TRAIN.PRETRAINED`` takes precedence: when set, only model weights are
    loaded from that file. Otherwise, when ``cfg.TRAIN.USE_CHECKPOINT`` is set,
    training state is restored from the checkpoint directory.

    Args:
        model: module that receives the weights.
        optimizer: optional optimizer forwarded to ``load_checkpoint``.
        scheduler: optional scheduler forwarded to ``load_checkpoint``.
        specific_checkpoint: optional checkpoint file name forwarded to
            ``load_checkpoint`` (it used to be dropped here).

    Returns:
        ``(last_epochs, best_acc)``; ``(0, 0.0)`` unless a checkpoint was restored.
    """
    last_epochs = 0
    best_acc = 0.0

    # Create each folder independently: with a shared try/except, an already
    # existing 'checkpoint' folder prevented 'best_acc' from being created.
    for sub_dir in ('checkpoint', 'best_acc'):
        out_dir = '{}/{}'.format(cfg.SAVE_DIR, sub_dir)
        try:
            os.makedirs(out_dir, exist_ok=True)
        except OSError as err:
            print('WARNING: could not create {}: {}'.format(out_dir, err))

    if cfg.TRAIN.PRETRAINED is not None:
        load_pretrained(model, pretrained_dir=cfg.TRAIN.PRETRAINED)
    elif cfg.TRAIN.USE_CHECKPOINT:
        last_epochs, best_acc = load_checkpoint(model, optimizer, scheduler, specific_checkpoint)

    return last_epochs, best_acc

# Train

In [None]:
def _train_one_epoch(model, optimizer, epoch):
    """Run one pass over ``train_loader`` (mixup, AMP, gradient accumulation); return the mean batch loss."""
    gradient_accumulation = cfg.TRAIN.ACCUMULATION_STEPS
    num_batches = len(train_loader)
    epoch_loss = 0

    model.train()

    for idx, (datas, labels) in enumerate(tqdm(train_loader)):
        datas = datas.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        # Mixup / CutMix turns the integer labels into soft targets.
        datas, labels = mixup_fn(datas, labels)

        with amp.autocast(enabled=amp_enable):
            output, _ = model(datas)
            # Scale so that the accumulated gradient is an average over the window.
            loss = criterion_STCE(output, labels) / gradient_accumulation

        scaler.scale(loss).backward()

        if (idx + 1) % gradient_accumulation == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            if cfg.MODEL.USE_EMA:
                model_ema.update(model)

        batch_loss = loss.item() * gradient_accumulation  # undo the accumulation scaling
        epoch_loss += batch_loss / num_batches

        #wandb
        if cfg.USE_WANDB:
            wandb.log({
                'iteration': idx + epoch * num_batches,
                'train loss': batch_loss
            })

    return epoch_loss


def _validate(model):
    """Evaluate on ``valid_loader``; return ``(mean_loss, [acc@1..acc@5], labels, top1_predictions)``."""
    use_ema = cfg.VALID.MODEL_EMA
    if use_ema:
        # The EMA copy lives on the CPU between validations.
        model_ema.cuda()
        model_ema.eval()
        eval_model = model_ema.module
    else:
        model.eval()
        eval_model = model

    correct_at_k = torch.zeros(5, dtype=torch.long)
    loss_sum = 0.0
    num_samples = 0
    pred_list = []
    label_list = []

    try:
        with torch.no_grad():
            for datas, labels in tqdm(valid_loader):
                datas = datas.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

                with amp.autocast(enabled=amp_enable):
                    val_output, _ = eval_model(datas)
                    val_loss = criterion_CE(val_output, labels)

                batch_size = labels.size(0)
                maxk = torch.topk(val_output, 5)[1]
                # hits[i, j] is True when the (j+1)-th ranked class of sample i is
                # its label; the cumulative column sums give the top-1..top-5 counts.
                hits = maxk == labels.view(-1, 1)
                correct_at_k += hits.sum(dim=0).cumsum(dim=0).cpu()

                # Weight by batch size so a smaller last batch is not over-weighted.
                loss_sum += val_loss.item() * batch_size
                num_samples += batch_size

                label_list.extend(labels.tolist())
                pred_list.extend(maxk[:, 0].tolist())
    finally:
        # Always move the EMA copy back, even when validation is interrupted.
        if use_ema:
            model_ema.cpu()

    denom = max(num_samples, 1)
    top_k_acc = [count / denom for count in correct_at_k.tolist()]
    return loss_sum / denom, top_k_acc, label_list, pred_list


def train(model, optimizer, scheduler):
    """Train ``model`` for ``cfg.TRAIN.EPOCHS`` epochs, validating and checkpointing after each.

    Args:
        model: network returning ``(logits, embedding)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per epoch.

    After every epoch a checkpoint is written to ``<cfg.SAVE_DIR>/checkpoint``;
    when the validation top-1 accuracy improves, the same checkpoint is also
    written to ``<cfg.SAVE_DIR>/best_acc``. Uses the notebook-level loaders,
    losses, ``scaler``, ``mixup_fn`` and (optionally) ``model_ema``.
    """
    since = time.time()

    save_dir = cfg.SAVE_DIR

    last_epochs, best_acc = load_data(model, optimizer, scheduler)
    # Checkpoints written by earlier versions stored the accuracy as a tensor.
    best_acc = float(best_acc)

    if cfg.USE_WANDB:
        wandb.watch(model, criterion=criterion_STCE, log='all', log_freq=500)

    print('Training start from {} epoch\n'.format(last_epochs + 1))

    for epoch in range(last_epochs, cfg.TRAIN.EPOCHS):
        print('train start - epoch : {} - lr : {:.4e}\n'.format(epoch + 1, optimizer.param_groups[0]['lr']))

        #train
        epoch_loss = _train_one_epoch(model, optimizer, epoch)

        print('train end - epoch : {} - loss : {:.4f} - lr : {:.4e}\n'.format(epoch + 1, epoch_loss, optimizer.param_groups[0]['lr']))

        #evaluate
        epoch_val_loss, top_k_acc, label_list, pred_list = _validate(model)
        epoch_val_top1, epoch_val_top2, epoch_val_top3, epoch_val_top4, epoch_val_top5 = top_k_acc

        if cfg.USE_WANDB:
            wandb.log({
                'epoch': epoch + 1,
                'epoch train loss': epoch_loss,
                'learning rate': optimizer.param_groups[0]['lr'],
                'epoch valid loss': epoch_val_loss,
                'epoch valid acc@1': epoch_val_top1,
                'epoch valid acc@2': epoch_val_top2,
                'epoch valid acc@3': epoch_val_top3,
                'epoch valid acc@4': epoch_val_top4,
                'epoch valid acc@5': epoch_val_top5,
            })

        print(f'validation - epoch : {epoch + 1} - loss : {epoch_val_loss:.4f} - top1: {epoch_val_top1:.4f} - top2: {epoch_val_top2:.4f} - top3 : {epoch_val_top3:.4f} - top4: {epoch_val_top4:.4f} - top5: {epoch_val_top5:.4f}\n')
        # Passing ``labels`` keeps the report aligned with class_list even when a
        # class happens to be absent from the validation fold.
        print(classification_report(label_list, pred_list, labels=list(range(len(class_list))),
                                    target_names=class_list, digits=4))

        #lr schedule
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            # This scheduler expects the monitored metric, not the epoch number.
            scheduler.step(epoch_val_loss)
        else:
            scheduler.step(epoch + 1)

        # Update best_acc before building the checkpoint so that the stored
        # value is current and a resumed run does not start from a stale best.
        is_best = epoch_val_top1 > best_acc
        if is_best:
            best_acc = epoch_val_top1

        checkpoint = {
            'epoch': epoch + 1,
            'loss': epoch_loss,
            'val_loss': epoch_val_loss,
            'val_acc' : epoch_val_top1,
            'best_acc' : best_acc,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }

        if cfg.TRAIN.AMP_ENABLE:
            checkpoint['scaler_state_dict'] = scaler.state_dict()
        if cfg.MODEL.USE_EMA:
            checkpoint['model_ema_state_dict'] = model_ema.state_dict()

        if is_best:
            torch.save(checkpoint, '{}/best_acc/{}_acc_{:.2f}_{}.pth'.format(save_dir, cfg.NAME, best_acc*100, epoch + 1))
            print('Best accuracy model saved - epoch: {} - best_acc: {}\n'.format(epoch + 1, best_acc))

        torch.save(checkpoint, '{}/checkpoint/{}_ckpt_{}.pth'.format(save_dir, cfg.NAME, epoch + 1))

    #print final result
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(time_elapsed // 3600,
        (time_elapsed % 3600) // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

# Start Training

In [None]:
# Run the full training loop; checkpoints are written below cfg.SAVE_DIR.
train(model, optimizer, scheduler)

# Get training and validation embeddings (for the XGBoost ensemble)

In [None]:
def get_emb(model, data_loader, specific_checkpoint=None):
    """Extract embeddings and labels for every sample of ``data_loader``.

    Args:
        model: network returning ``(logits, embedding)``; used directly unless
            ``cfg.VALID.MODEL_EMA`` selects the global EMA copy.
        data_loader: loader yielding ``(images, labels)`` batches.
        specific_checkpoint: forwarded to ``load_data``.

    Returns:
        ``(embeddings, labels)`` as CPU tensors concatenated over all batches.

    NOTE(review): ``load_data`` runs first and, depending on ``cfg.TRAIN``, may
    reload pretrained or checkpoint weights into ``model`` - confirm that this
    is wanted right after training.
    """
    load_data(model, specific_checkpoint=specific_checkpoint)

    use_ema = cfg.VALID.MODEL_EMA
    if use_ema:
        model_ema.cuda()
        model_ema.eval()
        feature_model = model_ema.module
    else:
        model.eval()
        feature_model = model

    emb_list = []
    label_list = []

    try:
        with torch.no_grad():
            for datas, labels in tqdm(data_loader):
                datas = datas.cuda(non_blocking=True)

                with amp.autocast(enabled=amp_enable):
                    _, embs = feature_model(datas)

                emb_list.append(embs.cpu())
                # Labels are only collected, so they never need to visit the GPU.
                label_list.append(labels.cpu())
    finally:
        # Always move the EMA copy back, even when extraction is interrupted.
        if use_ema:
            model_ema.cpu()

    return torch.cat(emb_list), torch.cat(label_list)

In [None]:
# Embeddings of the training files (query_loader: train_list with the
# deterministic validation transform) and of the held-out validation fold.
train_emb_list, train_label_list = get_emb(model, query_loader, specific_checkpoint=None)
valid_emb_list, valid_label_list = get_emb(model, valid_loader, specific_checkpoint=None)

In [None]:
# Save the embedding arrays for the later ensemble step.
# Pickled layout: [train embeddings, train labels, valid embeddings, valid labels],
# each as a NumPy array.
l = [train_emb_list.numpy(), train_label_list.numpy(), valid_emb_list.numpy(), valid_label_list.numpy()]
with open("train_emb_convnext_base", "wb") as fp:
    pickle.dump(l, fp)

# Get test embeddings (for the XGBoost ensemble)

In [None]:
# All public and private test files, merged into one sorted list.
test_list = glob.glob(os.path.join(cfg.DATA.PUBLIC_TEST_DIR, "*")) + glob.glob(os.path.join(cfg.DATA.PRIVATE_TEST_DIR, "*"))
test_list = sorted(test_list)

print(len(test_list))

class AIdeaTestDataset(Dataset):
    """Unlabelled image dataset that yields ``(transformed_image, file_name)``.

    Args:
        file_list: image paths.
        transform: callable invoked as ``transform(image=array)`` that returns a
            dict with an ``'image'`` entry (albumentations convention).

    Cropping is controlled by the global ``cfg.DATA.CROP`` and uses the
    per-file targets stored in the global ``target_df``.
    """

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    @staticmethod
    def _lookup_target(filename):
        """Return ``(target_x, target_y)`` of the single ``target_df`` row for ``filename``.

        The values are read directly instead of being round-tripped through
        ``Series.to_string``, which rounds to pandas' display precision.

        Raises:
            ValueError: if ``target_df`` holds zero or several rows for the file.
        """
        rows = target_df.loc[target_df['filename'] == filename]
        if len(rows) != 1:
            raise ValueError(
                'expected exactly one row in target_df for {!r}, found {}'.format(filename, len(rows))
            )
        return float(rows['target_x'].iloc[0]), float(rows['target_y'].iloc[0])

    @staticmethod
    def _square_crop_box(W, H, target_x, target_y):
        """Return ``(x_min, y_min, x_max, y_max)`` of a square crop with side ``min(W, H)``.

        The square starts centred on the longer axis, is shifted by the target
        offset (a fraction of that axis' length) and is clamped so that it stays
        inside the image.
        """
        if W > H:
            x_min = int(min(max(W / 2 - H / 2 + target_x * W, 0), W - H))
            return x_min, 0, x_min + H, H
        y_min = int(min(max(H / 2 - W / 2 + target_y * H, 0), H - W))
        return 0, y_min, W, y_min + W

    def __getitem__(self, idx):
        """Return ``(transformed_image, file_name)`` for sample ``idx``."""
        path = self.file_list[idx]
        file_name = os.path.split(path)[1]

        # The context manager closes the file handle right away; convert('RGB')
        # guarantees three channels for grayscale / palette / RGBA files.
        with Image.open(path) as handle:
            pil_img = handle.convert('RGB')
        W, H = pil_img.size

        img = np.array(pil_img)

        if cfg.DATA.CROP:
            target_x, target_y = self._lookup_target(file_name)
            x_min, y_min, x_max, y_max = self._square_crop_box(W, H, target_x, target_y)
            img = A.Crop(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max)(image=img)['image']

        img_transformed = self.transform(image=img)['image']

        return img_transformed, file_name

# Test images go through the deterministic validation transform.
test_data = AIdeaTestDataset(test_list, transform=valid_transform)

# Order is preserved and no sample is dropped, so embeddings line up with
# the file names returned by the dataset.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=cfg.DATA.BATCH_SIZE,
    shuffle=False,
    pin_memory=cfg.DATA.PIN_MEMORY,
    num_workers=cfg.DATA.NUM_WORKERS,
    drop_last=False
)

In [None]:
def test_emb(model, test_loader, specific_checkpoint):
    """Extract embeddings for every sample of an unlabelled loader.

    Args:
        model: network returning ``(logits, embedding)``; used directly unless
            ``cfg.VALID.MODEL_EMA`` selects the global EMA copy.
        test_loader: loader yielding ``(images, file_names)`` batches.
        specific_checkpoint: forwarded to ``load_data``.

    Returns:
        ``(embeddings, file_names)``: a CPU tensor concatenated over all batches
        and the list of file names in the same order.
    """
    load_data(model, specific_checkpoint=specific_checkpoint)

    use_ema = cfg.VALID.MODEL_EMA
    if use_ema:
        model_ema.cuda()
        model_ema.eval()
        feature_model = model_ema.module
    else:
        model.eval()
        feature_model = model

    test_emb_list = []
    test_file_list = []

    try:
        with torch.no_grad():
            for datas, file_name in tqdm(test_loader):
                datas = datas.cuda(non_blocking=True)

                with amp.autocast(enabled=amp_enable):
                    _, emb = feature_model(datas)

                test_emb_list.append(emb.cpu())
                test_file_list += file_name
    finally:
        # Always move the EMA copy back, even when extraction is interrupted.
        if use_ema:
            model_ema.cpu()

    return torch.cat(test_emb_list), test_file_list

test_emb_list, test_file_list = test_emb(model, test_loader, specific_checkpoint=None)

In [None]:
# Save the test embeddings for the later ensemble step.
# Pickled layout: [test embeddings as a NumPy array, list of file names].
l = [test_emb_list.numpy(), test_file_list]
with open("test_emb_convnext_base", "wb") as fp:
    pickle.dump(l, fp)