# **匯入相關套件**

In [1]:
!pip install torch-ema
!pip install transformers
!pip install ptflops



In [2]:
# Pytorch related
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ExponentialLR

# TorchVision
from torchvision import transforms
from torchvision.datasets import ImageFolder

# 3-rd party lib for pytorch
## EMA 
from torch_ema import ExponentialMovingAverage

## Flops Calculation
from ptflops import get_model_complexity_info

## Hugging face API
from transformers import AutoFeatureExtractor, SwinForImageClassification   
from transformers import ConvNextFeatureExtractor, ConvNextForImageClassification
from transformers import ViTFeatureExtractor, ViTModel

# Other ML related lib that is helpful
## Numpy
import numpy as np
## sklearn
from sklearn.model_selection import train_test_split

# Utils lib
## tqdm
from tqdm import tqdm 

# csv
import pandas as pd

# OS 
import os
import shutil
import argparse

# #
# import models
# from data.dataset import OrchidDataSet
# from config import DefualtConfig
# from utils import get_confidence_score
# from utils import mixup_data, mixup_criterion
# from utils import rand_bbox
# from utils.self_supervised import get_pseudo_labels
# from optim.scheduler import GradualWarmupScheduler

In [3]:
# from google.colab import drive
# drive.mount('/content/drive')

# import os
# os.chdir('/content/drive/My Drive/AICUP2022 - OrchidClassifier') #切換該目錄
# os.listdir() #確認目錄內容

# **Configuration 設置檔**

In [4]:
class DefualtConfig(object):
    """Central collection of hyper-parameters and paths for this notebook.

    Every setting is a plain class attribute, so it can be read from the class
    itself or from an instance (``config = DefualtConfig()``).
    """

    # ==================================================================
    # Backbone
    # ------------------------------------------------------------------
    # Other model / checkpoint combinations that have been used:
    #   'STN_ViT'  + 'google/vit-base-patch16-224-in21k'
    #   'ConvNeXt' + 'facebook/convnext-base-224' or 'facebook/convnext-base-384'
    #   'Swin_ViT' + 'microsoft/swin-base-patch4-window7-224'
    model_name = 'Swin_ViT'
    pretrained_model = 'microsoft/swin-base-patch4-window12-384'

    # Square image side (pixels) used by the resize transforms and the FLOPs
    # estimate in main(); currently matches the 384 checkpoint above.
    resize = 384

    # ==================================================================
    # Checkpoints / output layer
    # ------------------------------------------------------------------
    model_path = 'model.pth'     # weights loaded by main() when load_model is True
    ema_path = 'ema.pth'         # presumably where EMA weights are stored — TODO confirm
    load_model = False
    num_classes = 219            # width of the final classifier layer

    # ==================================================================
    # Mix-based augmentation (see train())
    # ------------------------------------------------------------------
    do_MixUp = True              # allow MixUp
    do_cutMix = True             # allow CutMix
    beta = 1.0                   # CutMix ratio is drawn from Beta(beta, beta)
    mix_prob = 0.2               # per-batch probability of applying a mix

    # ==================================================================
    # Optimisation schedule
    # ------------------------------------------------------------------
    start_epoch = 0
    num_epochs = 105             # total number of epochs
    earlyStop_interval = 600     # presumably patience for early stopping — TODO confirm
    batch_size = 8
    lr = 5e-5
    lr_warmup_epoch = 5          # epochs of learning-rate warm-up
    cosine_tmax = 101            # T_max of the cosine-annealing LR scheduler

    # Semi-supervised training with pseudo-labels (only useful when extra
    # unlabeled data is available).
    do_semi = False
    semi_start_epoch = 40

    # ==================================================================
    # Hardware / data loading
    # ------------------------------------------------------------------
    use_gpu_index = 0            # GPU index; main() uses the CPU for -1
    num_workers = 6              # DataLoader worker processes

    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'

    train_valid_split = 0.2      # fraction of the labelled set held out for validation


# **Data processing 資料處理**

# **Dataset Declaration 資料集宣告**

In [5]:
class OrchidDataSet(ImageFolder):
    """``ImageFolder`` for the orchid images, plus a helper to fetch batches.

    Args:
        root: dataset directory in the layout ``ImageFolder`` expects
            (one sub-directory per class).
        transform_set: transform applied to every loaded image.
    """

    def __init__(self, root, transform_set):
        super().__init__(root=root, transform=transform_set)

    def getbatch(self, indices):
        """Load several samples at once (handy for visualisation).

        Args:
            indices: iterable of dataset indices.

        Returns:
            (images, labels): the transformed images stacked along a new
            leading dimension (so the transform must yield equally-shaped
            tensors) and a 1-D tensor with the matching integer labels.
        """
        samples = [self.__getitem__(idx) for idx in indices]
        stacked_images = torch.stack([img for img, _ in samples])
        label_tensor = torch.tensor([lbl for _, lbl in samples])
        return stacked_images, label_tensor


In [6]:
from torch.utils.data import Dataset


class TensorIntDataset(Dataset):
    """Map-style dataset pairing pre-computed samples with integer labels.

    ``get_pseudo_labels`` uses it to wrap the pseudo-labelled images.
    """

    def __init__(self, x, y):
        """
        Args:
            x: indexable collection of samples (e.g. a tensor whose first
                dimension is the sample axis). Stored as-is; no conversion.
            y: sequence of int labels, one per sample.

        Raises:
            ValueError: if ``x`` and ``y`` have different lengths.
        """
        if len(x) != len(y):
            raise ValueError(
                f'x has {len(x)} samples but y has {len(y)} labels')

        self.data = x
        self.target = y

        # Number of samples; len() works for tensors, arrays and plain lists.
        self.dim = len(self.data)

        print('Finished reading TensorInt Dataset ({} samples found)'
              .format(len(self.data)))

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.data)


# **Model Declaration 模型宣告**

## **Swin ViT**

In [7]:
class Swin_ViT(nn.Module):
    """Swin Transformer classifier built on a Hugging Face checkpoint.

    The pretrained ``SwinForImageClassification`` head is kept; its logits
    (assumed to have 1000 entries, matching ``nn.Linear(1000, ...)`` below)
    are mapped to ``config.num_classes`` by a linear layer.
    """

    def __init__(self, config: DefualtConfig):

        super(Swin_ViT, self).__init__()

        self.config = config
        self.num_labels = config.num_classes

        ###############################################
        # Pretrained Swin backbone + its preprocessing
        pretrained_model = config.pretrained_model
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model)
        self.model = SwinForImageClassification.from_pretrained(pretrained_model)

        # Classifier (NOTE: self.dropout is currently not applied in forward)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(1000, self.num_labels)

    def feature_extract(self, imgs):
        '''
        Run the Hugging Face feature extractor on a batch of image tensors.

        Args:
            imgs: tensor indexed as (batch, channels, H, W); presumably float
                images from ``transforms.ToTensor`` — TODO confirm the value
                range the extractor expects.

        Returns:
            The extractor's ``pixel_values`` stacked into one tensor, placed on
            the same device as ``imgs``.
        '''
        # The extractor works on a list of per-image numpy arrays.
        images = list(imgs.detach().cpu().numpy())
        pixel_values = self.feature_extractor(images)['pixel_values']

        # Return to the input's own device (not a hard-coded cuda index).
        return torch.tensor(np.stack(pixel_values, axis=0)).to(imgs.device)

    def forward(self, x, labels=None):
        '''
        Model forward function.

        Args:
            x: batch of image tensors (see ``feature_extract``).
            labels: ignored; accepted for interface compatibility.

        Returns:
            Logits of shape (batch, num_classes).
        '''
        x = self.feature_extract(x)

        # Swin-ViT
        outputs = self.model(pixel_values=x)
        logits = self.classifier(outputs.logits)

        return logits


## **ConvNext**

In [8]:
class ConvNeXt(nn.Module):
    """ConvNeXt classifier built on a Hugging Face checkpoint.

    The pretrained ``ConvNextForImageClassification`` head is kept; its logits
    (assumed to have 1000 entries, matching ``nn.Linear(1000, ...)`` below)
    are mapped to ``config.num_classes`` by a linear layer.
    """

    def __init__(self, config: DefualtConfig):

        super(ConvNeXt, self).__init__()

        self.config = config
        self.num_labels = config.num_classes

        ###############################################
        # Pretrained ConvNeXt backbone + its preprocessing
        pretrained_model = config.pretrained_model
        self.feature_extractor = ConvNextFeatureExtractor.from_pretrained(pretrained_model)
        self.model = ConvNextForImageClassification.from_pretrained(pretrained_model)

        # Classifier (NOTE: self.dropout is currently not applied in forward)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(1000, self.num_labels)

    def feature_extract(self, imgs):
        '''
        Run the Hugging Face feature extractor on a batch of image tensors.

        Args:
            imgs: tensor indexed as (batch, channels, H, W); presumably float
                images from ``transforms.ToTensor`` — TODO confirm the value
                range the extractor expects.

        Returns:
            The extractor's ``pixel_values`` stacked into one tensor, placed on
            the same device as ``imgs``.
        '''
        # The extractor works on a list of per-image numpy arrays.
        images = list(imgs.detach().cpu().numpy())
        pixel_values = self.feature_extractor(images)['pixel_values']

        # Return to the input's own device (not a hard-coded cuda index).
        return torch.tensor(np.stack(pixel_values, axis=0)).to(imgs.device)

    def forward(self, x, labels=None):
        '''
        Model forward function.

        Args:
            x: batch of image tensors (see ``feature_extract``).
            labels: ignored; accepted for interface compatibility.

        Returns:
            Logits of shape (batch, num_classes).
        '''
        x = self.feature_extract(x)

        # ConvNeXt backbone
        outputs = self.model(pixel_values=x)
        logits = self.classifier(outputs.logits)

        return logits

## **STN-ViT**

In [9]:
class STN_ViT(nn.Module):
    """Spatial Transformer Network in front of a pretrained ViT encoder.

    The STN predicts an affine transform that re-samples the preprocessed
    image; the ViT's first-token embedding (``last_hidden_state[:, 0]``) is then
    passed through dropout and a linear classifier.
    """

    def __init__(self, config: DefualtConfig):

        super(STN_ViT, self).__init__()

        self.config = config
        self.num_labels = config.num_classes

        ###############################################
        # Spatial transformer localization-network.
        # Sized for 224x224 inputs: the spatial size goes
        # 224 -> 218 -> 212 -> 106 -> 100 -> 50 -> 46 -> 23 -> 19 -> 9,
        # which gives the 32 * 9 * 9 features expected by fc_loc.
        self.localization = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=7),
            nn.Conv2d(32, 32, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),

            nn.Conv2d(32, 32, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),

            nn.Conv2d(32, 32, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),

            nn.Conv2d(32, 32, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 2 x 3 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(32 * 9 * 9, 90),
            nn.ReLU(True),
            nn.Linear(90, 3 * 2)
        )
        # Start from the identity transform: zero weights, identity bias.
        self.fc_loc[-1].weight.data.zero_()
        self.fc_loc[-1].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        ###############################################
        # ViT encoder + classifier
        pretrained_model = config.pretrained_model
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(pretrained_model)
        self.vit = ViTModel.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, self.num_labels)

    def feature_extract(self, imgs):
        '''
        Run the Hugging Face feature extractor on a batch of image tensors.

        Args:
            imgs: tensor indexed as (batch, channels, H, W); presumably float
                images from ``transforms.ToTensor`` — TODO confirm the value
                range the extractor expects.

        Returns:
            The extractor's ``pixel_values`` stacked into one tensor, placed on
            the same device as ``imgs``.
        '''
        # The extractor works on a list of per-image numpy arrays.
        images = list(imgs.detach().cpu().numpy())
        pixel_values = self.feature_extractor(images)['pixel_values']

        # Return to the input's own device (not a hard-coded cuda index).
        return torch.tensor(np.stack(pixel_values, axis=0)).to(imgs.device)

    def stn(self, x):
        '''
        Spatial transformer forward function: warp ``x`` with a predicted
        affine transform.
        '''
        xs = self.localization(x)
        # flatten(1) keeps the batch dimension, so an unexpected input size
        # fails loudly in fc_loc instead of silently changing the batch size.
        xs = xs.flatten(1)
        theta = self.fc_loc(xs).view(-1, 2, 3)

        # align_corners=False is PyTorch's default; passing it explicitly
        # keeps the same result and avoids the "default changed" warning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x, labels=None):
        '''
        Model forward function.

        Args:
            x: batch of image tensors (see ``feature_extract``).
            labels: ignored; accepted for interface compatibility.

        Returns:
            Logits of shape (batch, num_classes).
        '''
        x = self.feature_extract(x)

        # Spatial Transformer
        x = self.stn(x)

        # ViT: classify the first-token embedding
        outputs = self.vit(pixel_values=x)
        cls_embedding = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(cls_embedding)

        return logits


# **Learning rate scheduler 學習率調整**

In [10]:
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up (increase) the learning rate of an optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if
            multiplier > 1.0. If multiplier = 1.0, lr starts from 0 and ends
            up at base_lr.
        total_epoch: the target learning rate is reached at total_epoch.
        after_scheduler: scheduler used once warm-up is over
            (e.g. ReduceLROnPlateau, CosineAnnealingLR).

    Raises:
        ValueError: if multiplier < 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError(
                'multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def _warmup_lrs(self, epoch):
        """Warm-up learning rates for ``epoch`` (0 <= epoch <= total_epoch).

        Single source of truth for the warm-up formula, shared by the normal
        and the ReduceLROnPlateau code paths.
        """
        if self.multiplier == 1.0:
            return [base_lr * (float(epoch) / self.total_epoch) for base_lr in self.base_lrs]
        return [base_lr * ((self.multiplier - 1.) * epoch / self.total_epoch + 1.)
                for base_lr in self.base_lrs]

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand the (multiplied) target lr over to the follow-up scheduler once.
                    self.after_scheduler.base_lrs = [
                        base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return self._warmup_lrs(self.last_epoch)

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of epoch, whereas others are
        # called at the beginning; never sit at warm-up step 0.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            # Same formula as get_lr(); previously multiplier == 1 skipped the
            # warm-up entirely on this path.
            warmup_lr = self._warmup_lrs(self.last_epoch)
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
            self._last_lr = warmup_lr
        else:
            # `epoch` is always set at this point (see above).
            self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if isinstance(self.after_scheduler, ReduceLROnPlateau):
            self.step_ReduceLROnPlateau(metrics, epoch)
            return None

        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
            self._last_lr = self.after_scheduler.get_last_lr()
            return None

        return super(GradualWarmupScheduler, self).step(epoch)


# **Utils**

## **Semi-supervised Learning 半監督式學習**

In [11]:
import gc

###################################################################################

# Module-level configuration read by the helper functions defined below
# (get_pseudo_labels, get_train_valid_ds, get_loader, train read settings
# such as config.batch_size from it). main() creates its own local instance.
config = DefualtConfig()


def get_pseudo_labels(model, *datasets, threshold=0.75):
    '''
    Pseudo-label unlabeled images whose top-class confidence >= threshold.

    Only the first dataset in ``datasets`` is used (also labelling the test
    set was judged too slow). You are NOT allowed to use any models trained
    on external data for pseudo-labeling.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        *datasets: datasets yielding (image, anything) pairs.
        threshold: minimum softmax confidence for an image to be kept.

    Returns:
        A ``TensorIntDataset`` of (image, pseudo_label) pairs, or None if no
        image reaches the threshold. The model is left in train mode.
    '''
    # Run where the model lives so cuda:<use_gpu_index> models also work.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = datasets[0]
    data_loader = DataLoader(
        dataset, batch_size=config.batch_size, shuffle=False)

    model.eval()

    kept_imgs, kept_labels = [], []

    for img, _ in tqdm(data_loader):
        with torch.no_grad():
            logits = model(img.to(device))

        probs = torch.softmax(logits, dim=-1)
        max_confidence, pseudo_label = torch.max(probs, 1)

        # The mask is computed on `device`; move it to the CPU so it can
        # index the CPU batch `img`.
        mask = (max_confidence >= threshold).cpu()
        kept_imgs.append(img[mask])
        kept_labels.extend(int(label) for label in pseudo_label.cpu()[mask])

    # Turn off eval mode.
    model.train()

    cnt = len(kept_labels)
    if cnt == 0:
        return None

    print(f"[ {cnt} Unlabeled Images append into train_set ]")

    # Concatenate once (not per batch); the images keep their real size
    # instead of being reshaped to a hard-coded 224x224.
    to_train_x = torch.cat(kept_imgs, 0)

    print(to_train_x.shape)
    print(len(kept_labels))

    res_dataset = TensorIntDataset(to_train_x, kept_labels)

    # free the per-batch buffers, or memory grows across calls
    del kept_imgs, to_train_x
    gc.collect()

    return res_dataset


## **Mix-based Augmentation**

In [12]:
##########################################################################################
# mixup
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Return mixed inputs, pairs of targets, and lambda (MixUp).

    Args:
        x: batch tensor whose first dimension is the batch.
        y: labels indexed along the same first dimension (same device as x).
        alpha: Beta(alpha, alpha) parameter; alpha <= 0 disables mixing (lam = 1).
        use_cuda: when True, the permutation is created on ``x.device``
            (works for any GPU index and for CPU tensors); when False it is
            created on the CPU.

    Returns:
        (mixed_x, y_a, y_b, lam) with mixed_x = lam * x + (1 - lam) * x[perm].
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # Previously `.cuda()` always used GPU 0, breaking other GPU indices.
    index_device = x.device if use_cuda else torch.device('cpu')
    index = torch.randperm(batch_size, device=index_device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''Loss for a mixed batch: lam * L(pred, y_a) + (1 - lam) * L(pred, y_b).'''
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second

##########################################################################################
# cutmix
def cutmix(batch, alpha):
    '''Apply CutMix to a collated ``(data, targets)`` batch.

    A random box taken from a shuffled copy of the batch is pasted into every
    image. ``data`` is modified in place.

    Args:
        batch: tuple (data, targets); data is indexed as (N, C, H, W).
        alpha: Beta(alpha, alpha) parameter for the box size.

    Returns:
        (data, (targets, shuffled_targets, lam)) where lam is the exact
        fraction of each image's pixels that were NOT replaced.
    '''
    data, targets = batch

    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]

    # The box is clipped at the borders and rounded, so the sampled lam no
    # longer matches the pasted area; recompute it from the real box.
    lam = 1.0 - (x1 - x0) * (y1 - y0) / float(image_w * image_h)
    targets = (targets, shuffled_targets, lam)

    return data, targets

class CutMixCollator:
    '''DataLoader ``collate_fn`` that applies CutMix to every batch.

    Args:
        alpha: Beta(alpha, alpha) parameter forwarded to ``cutmix``.
    '''

    def __init__(self, alpha):
        self.alpha = alpha

    def __call__(self, batch):
        # Collate the samples as the default DataLoader would, then mix.
        collated = torch.utils.data.dataloader.default_collate(batch)
        return cutmix(collated, self.alpha)


class CutMixCriterion:
    '''Cross-entropy loss for batches produced by ``cutmix`` / ``CutMixCollator``.

    Args:
        reduction: forwarded to ``nn.CrossEntropyLoss``.
        label_smoothing: forwarded to ``nn.CrossEntropyLoss``. It used to be
            ignored (always 0.1); the default value is unchanged.
    '''

    def __init__(self, reduction, label_smoothing=0.1):
        self.criterion = nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing)

    def __call__(self, preds, targets):
        '''``targets`` is the ``(targets1, targets2, lam)`` tuple built by cutmix.'''
        targets1, targets2, lam = targets
        return lam * self.criterion(
            preds, targets1) + (1 - lam) * self.criterion(preds, targets2)

##########################################################################################

def rand_bbox(size, lam):
    '''Sample a random CutMix box covering roughly (1 - lam) of an image.

    Follows the reference CutMix code: ``W`` is read from ``size[2]`` and
    ``H`` from ``size[3]``, so the returned x-range indexes dimension 2 and
    the y-range dimension 3 of an (N, C, dim2, dim3) tensor.

    Args:
        size: tensor size (N, C, dim2, dim3).
        lam: fraction of the image to keep, in [0, 1].

    Returns:
        (bbx1, bby1, bbx2, bby2), clipped to the image bounds.
    '''
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was removed in NumPy 1.24; builtin int() truncates the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

    
##########################################################################################


## **Confidence 模型預測信心水準**

In [13]:
def get_confidence_score(model, loader, topN=5, use_gpu_index='-1', batch_size=32, outpu_file_path='./prediction-Confidence.csv'):
    '''
    Write the top-N predicted classes and their probabilities for every
    sample in ``loader`` to a CSV file.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of (imgs, labels) batches.
        topN: number of highest-probability classes written per sample.
        use_gpu_index: GPU index; -1 (int or the string '-1') selects the
            CPU. The CPU is also used when CUDA is unavailable.
        batch_size: unused; kept for backward compatibility. Row indices are
            now counted from the real batch sizes.
        outpu_file_path: destination CSV path.

    Each row is ``index, topN, ground_truth, label1, prob1, ..., labelN, probN``.
    The model is left in eval mode.
    '''
    model.eval()

    # The default is the *string* '-1', so compare numerically.
    if int(use_gpu_index) != -1 and torch.cuda.is_available():
        device = torch.device(f'cuda:{use_gpu_index}')
    else:
        device = torch.device('cpu')

    header = ['validset_index', 'topN', 'Ground_truth']
    header += [f'{rank}, prob{rank}' for rank in range(1, topN + 1)]

    img_idx = 0
    with torch.no_grad(), open(outpu_file_path, 'w') as file:
        file.write(', '.join(header) + '\n')

        for imgs, labels in loader:
            logits = model(imgs.to(device))
            probs = F.softmax(logits, dim=-1)
            top_values, top_labels = probs.topk(topN, dim=-1)

            for row in range(probs.size(0)):
                fields = [f'{img_idx}, {topN}, {labels[row].item()}']
                fields += [f'{top_labels[row, k].item()}, {top_values[row, k].item()}'
                           for k in range(topN)]
                file.write(', '.join(fields) + '\n')
                img_idx += 1


# **訓練**

## **建立 Dataset, Dataloader**

In [14]:
def get_train_valid_ds(ds):
    '''Split ``ds`` into stratified train / validation ``Subset``s.

    Every class appears in both parts (stratified on ``ds.targets``); the
    split is reproducible (``random_state=42``) and the validation subset
    keeps the original sample order.

    Returns:
        (ds_train, ds_valid)
    '''
    all_indices = list(range(len(ds)))
    train_idx, valid_idx = train_test_split(
        all_indices,
        test_size=config.train_valid_split,
        stratify=ds.targets,
        random_state=42,
    )

    ds_train = torch.utils.data.Subset(ds, train_idx)
    # Validation images are not shuffled: keep them in dataset order.
    ds_valid = torch.utils.data.Subset(ds, sorted(valid_idx))

    return ds_train, ds_valid

def get_loader(ds):
    '''Build train / validation DataLoaders from one labelled dataset.

    The split is stratified on ``ds.targets`` so every class appears in both
    parts. It uses ``random_state=42`` (the same seed as
    ``get_train_valid_ds``), so the split is reproducible across runs. This
    matters when resuming from a checkpoint: otherwise images already trained
    on could land in the validation set.

    Returns:
        (train_loader, valid_loader); the train loader samples the training
        indices in random order, the validation loader keeps dataset order.
    '''
    valid_split = config.train_valid_split

    indices = list(range(len(ds)))
    train_indices, valid_indices = train_test_split(
        indices, test_size=valid_split, stratify=ds.targets, random_state=42)

    # Do not shuffle the validation set; keep the images in order.
    valid_indices.sort()
    ds_valid = torch.utils.data.Subset(ds, valid_indices)

    # Random order over the training indices only.
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)

    train_loader = DataLoader(
        ds, batch_size=config.batch_size, sampler=train_sampler, num_workers=config.num_workers, pin_memory=True)
    valid_loader = DataLoader(
        ds_valid, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True)

    return train_loader, valid_loader

## **宣告訓練/驗證流程**

In [15]:
def _mix_batch(imgs, labels, device):
    '''Optionally apply CutMix or MixUp to one batch, as configured in ``config``.

    Args:
        imgs: image batch, already on ``device``, indexed as (N, C, dim2, dim3).
        labels: label batch (any device).
        device: device the model runs on.

    Returns:
        (imgs, labels, mix): ``labels`` is moved to ``device``; ``mix`` is None
        when no mixing was applied, else ``(target_a, target_b, lam)`` for a
        lam-weighted two-target loss.
    '''
    labels = labels.to(device)

    # Random draws in the same order as before: mix or not, then which method.
    do_mix = bool(np.random.rand() < config.mix_prob)
    r_mix_method = np.random.rand()
    if config.do_cutMix and config.do_MixUp:
        mix_method = 'cutmix' if r_mix_method < 0.5 else 'mixup'
    elif config.do_MixUp:
        mix_method = 'mixup'
    else:
        mix_method = 'cutmix'

    if not do_mix:
        return imgs, labels, None

    if mix_method == 'cutmix' and config.do_cutMix and config.beta > 0:
        lam = np.random.beta(config.beta, config.beta)
        # Permutation on the batch's device so it can index imgs and labels
        # (a `.cuda()` index used to crash when indexing CPU labels).
        rand_index = torch.randperm(imgs.size(0), device=imgs.device)
        target_a, target_b = labels, labels[rand_index]
        bbx1, bby1, bbx2, bby2 = rand_bbox(imgs.size(), lam)
        imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match the pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (imgs.size()[-1] * imgs.size()[-2]))
        return imgs, labels, (target_a, target_b, lam)

    if mix_method == 'mixup' and config.do_MixUp:
        imgs, target_a, target_b, lam = mixup_data(
            imgs, labels, alpha=0.2, use_cuda=imgs.is_cuda)
        return imgs, labels, (target_a, target_b, lam)

    # Mixing was requested but is disabled for the chosen method (e.g. beta <= 0).
    return imgs, labels, None


def train(model, train_loader, criterion, optimizer, ema):
    '''Run one training epoch with optional CutMix / MixUp and an EMA update.

    Args:
        model: classifier returning logits; moved to its device beforehand.
        train_loader: iterable of (imgs, labels) batches.
        criterion: loss taking (logits, int labels).
        optimizer: optimizer over the model's parameters.
        ema: object with ``update()`` called after every optimizer step.

    Returns:
        (mean accuracy, mean loss) averaged over batches. When a batch is
        mixed, accuracy is measured against its un-shuffled labels.
    '''
    # ---------- Training ----------
    model.train()

    # Use the model's own device instead of an (undefined) global `device`.
    device = next(model.parameters()).device

    losses, accs = [], []

    for imgs, labels in tqdm(train_loader):
        imgs = imgs.to(device)
        imgs, labels, mix = _mix_batch(imgs, labels, device)

        logits = model(imgs)

        # No softmax needed: cross-entropy applies it internally.
        if mix is None:
            loss = criterion(logits, labels)
        else:
            target_a, target_b, lam = mix
            loss = mixup_criterion(criterion, logits, target_a, target_b, lam)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update()

        acc = (logits.argmax(dim=-1) == labels).float().mean()
        accs.append(acc.item())
        losses.append(loss.item())

    return np.mean(accs), np.mean(losses)

def valid(model, valid_loader, criterion, ema=None):
    '''Evaluate the model on the validation loader.

    Args:
        model: classifier returning logits.
        valid_loader: iterable of (imgs, labels) batches.
        criterion: loss taking (logits, int labels).
        ema: optional EMA object; when given, evaluation runs inside
            ``ema.average_parameters()``.

    Returns:
        (mean accuracy, mean loss) averaged over batches. The model is left
        in eval mode.
    '''
    from contextlib import nullcontext

    # ---------- Validation ----------
    # Eval mode disables dropout etc.
    model.eval()

    # Use the model's own device instead of an (undefined) global `device`.
    device = next(model.parameters()).device

    accs, losses = [], []

    # One loop for both cases instead of two copies.
    averaging = ema.average_parameters() if ema is not None else nullcontext()
    with averaging, torch.no_grad():
        for imgs, labels in tqdm(valid_loader):
            logits = model(imgs.to(device))
            labels = labels.to(device)

            losses.append(criterion(logits, labels).item())
            accs.append((logits.argmax(dim=-1) == labels).float().mean().item())

    return np.mean(accs), np.mean(losses)

## **開始訓練**

In [16]:


###################################################################################
def _build_train_transform(resize):
    """Return the training augmentation pipeline producing ``resize``-shaped tensors.

    ``resize`` is a ``(height, width)`` pair; the random crop uses ``resize[0]``.
    """
    augmentations = [
        transforms.RandomResizedCrop(resize[0]),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
    ]
    return transforms.Compose([
        # Apply the image-level augmentations above in a random order.
        transforms.RandomOrder(augmentations),
        # Force a fixed output shape regardless of which augmentations ran.
        transforms.Resize(resize),
        transforms.ToTensor(),
        # RandomErasing operates on tensors, so it must come after ToTensor().
        transforms.RandomErasing(),
    ])


def main(logdir, config=None):
    """Train the configured model and write its artefacts into ``logdir``.

    Each epoch trains once, then evaluates the validation split twice: with the
    raw weights and with the EMA weights. Files written to ``logdir``:

    * ``{model_path}`` / ``{ema_path}`` and ``prediction-Confidence-best.csv``
      whenever the raw-weight validation accuracy improves;
    * ``ema_{model_path}`` / ``ema_{ema_path}`` and
      ``prediction-Confidence-best-ema.csv`` whenever the EMA accuracy improves;
    * ``last_{model_path}`` / ``last_{ema_path}`` and
      ``last-prediction-Confidence.csv`` after training;
    * TensorBoard event files.

    Training stops early once the EMA accuracy has not improved for
    ``config.earlyStop_interval`` consecutive epochs.

    Args:
        logdir: existing output directory.
        config: configuration object. Defaults to ``DefualtConfig()``, i.e. the
            ``DefualtConfig`` class currently defined in the kernel.

    Returns:
        dict with per-epoch ``train_acc``, ``train_loss``, ``valid_acc`` and
        ``valid_loss`` lists.

    Raises:
        ValueError: if ``config.model_name`` is not a supported architecture.
    """
    if config is None:
        config = DefualtConfig()
    # use_gpu_index == -1 forces the CPU; otherwise use that GPU when CUDA exists.
    if config.use_gpu_index != -1 and torch.cuda.is_available():
        device = torch.device(f'cuda:{config.use_gpu_index}')
    else:
        device = torch.device('cpu')

    # Step 1 : build (and optionally restore) the model
    print(config.model_name)
    if config.model_name == 'STN_ViT':
        model = STN_ViT(config)
    elif config.model_name == 'Swin_ViT':
        model = Swin_ViT(config)
    elif config.model_name == 'ConvNeXt':
        model = ConvNeXt(config)
    else:
        raise ValueError(f'Unsupported model_name: {config.model_name!r}')

    if config.load_model:
        # Load onto the CPU first; model.to(device) below moves it to the target device.
        model.load_state_dict(torch.load(config.model_path, map_location='cpu'))
    model.to(device)

    # Metrics : FLOPs, Params
    resize = (config.resize, config.resize)
    macs, params = get_model_complexity_info(model, (3, resize[0], resize[1]), as_strings=True, print_per_layer_stat=False, verbose=False)
    print('pthflops : ')
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

    # Step 2 : prepare logging writer
    writer = SummaryWriter(log_dir=logdir)

    # Step 3 : datasets with training-time augmentation
    transform_set = _build_train_transform(resize)
    ds = OrchidDataSet(config.trainset_path, transform_set=transform_set)
    ds_unlabeled = None
    if config.do_semi:
        ds_unlabeled = OrchidDataSet(config.unlabeledset_path, transform_set=transform_set)

    # Counter class imbalance: weight(c) = 1 - frequency(c), so rarer classes
    # weigh more in the label-smoothed cross-entropy.
    class_weights = [1 - ds.targets.count(c) / len(ds) for c in range(config.num_classes)]
    class_weights = torch.FloatTensor(class_weights).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-8)

    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
    if config.load_model:
        # load_state_dict() updates `ema` in place and returns None, so its
        # result must not be assigned back to `ema`.
        ema.load_state_dict(torch.load(config.ema_path, map_location=device))

    # Warm up for lr_warmup_epoch epochs, then hand over to cosine annealing.
    scheduler_steplr = CosineAnnealingLR(optimizer, T_max=config.cosine_tmax)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=config.lr_warmup_epoch, after_scheduler=scheduler_steplr)

    # Step 4 : train / validation split and loaders.
    # NOTE(review): ds_valid is split from `ds`, which carries the random training
    # augmentations above — confirm get_train_valid_ds gives validation a
    # deterministic transform.
    ds_train, ds_valid = get_train_valid_ds(ds)
    collator = torch.utils.data.dataloader.default_collate
    # Neither loader depends on the epoch, so both are built once.
    base_train_loader = DataLoader(ds_train, batch_size=config.batch_size, shuffle=True, collate_fn=collator, num_workers=config.num_workers, pin_memory=True)
    valid_loader = DataLoader(ds_valid, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True)

    # Step 5 : training loop
    history = {'train_acc': [], 'train_loss': [], 'valid_acc': [], 'valid_loss': []}
    best_epoch, best_acc = 0, 0
    best_epoch_ema, best_acc_ema = 0, 0
    nonImprove_epochs = 0

    # this zero gradient update is needed to avoid a warning message, issue #8.
    optimizer.zero_grad()
    optimizer.step()

    for epoch in range(config.start_epoch, config.start_epoch + config.num_epochs):

        print('=' * 150)

        scheduler_warmup.step(epoch + 1)
        writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)
        print(f'Epoch {epoch}, LR = {optimizer.param_groups[0]["lr"]}')

        train_loader = base_train_loader
        if epoch >= config.semi_start_epoch and config.do_semi:
            # Obtain pseudo-labels for unlabeled data using the current model.
            print(f"[ Train | Start pseudo labeling]")
            pseudo_set = get_pseudo_labels(model, ds_unlabeled)
            if pseudo_set is not None:
                # This epoch trains on the labelled split plus the pseudo-labelled samples.
                concat_dataset = ConcatDataset([ds_train, pseudo_set])
                train_loader = DataLoader(concat_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collator, num_workers=config.num_workers, pin_memory=True)

        train_acc, train_loss = train(model, train_loader, criterion, optimizer, ema)
        print(f"[ Train | {epoch + 1:03d}/{config.num_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

        valid_acc, valid_loss = valid(model, valid_loader, criterion, None)
        print(f"[ Valid | {epoch + 1:03d}/{config.num_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")

        valid_acc_ema, valid_loss_ema = valid(model, valid_loader, criterion, ema)
        print(f"[ Valid | {epoch + 1:03d}/{config.num_epochs:03d} ] loss = {valid_loss_ema:.5f}, acc = {valid_acc_ema:.5f} (EMA)")

        # Append the training statistics into history
        history['train_acc'].append(train_acc)
        history['valid_acc'].append(valid_acc)
        history['train_loss'].append(train_loss)
        history['valid_loss'].append(valid_loss)

        # Tensorboard Visualization
        writer.add_scalar("Train/train_acc", train_acc, epoch)
        writer.add_scalar("Valid/valid_acc", valid_acc, epoch)
        writer.add_scalar("Valid/valid_acc_ema", valid_acc_ema, epoch)
        writer.add_scalar("Train/train_loss", train_loss, epoch)
        writer.add_scalar("Valid/valid_loss", valid_loss, epoch)
        writer.add_scalar("Valid/valid_loss_ema", valid_loss_ema, epoch)

        # Best checkpoint according to the raw weights.
        if valid_acc > best_acc:
            best_acc = valid_acc
            best_epoch = epoch
            torch.save(model.state_dict(), f'{logdir}/{config.model_path}')
            torch.save(ema.state_dict(), f'{logdir}/{config.ema_path}')
            get_confidence_score(model, loader=valid_loader, use_gpu_index=config.use_gpu_index, batch_size=config.batch_size, outpu_file_path=f'{logdir}/prediction-Confidence-best.csv')
            print(f'Saving model with acc {valid_acc:.4f} and loss {valid_loss:.4f}')

        # Best checkpoint according to the EMA weights; this also drives early stopping.
        if valid_acc_ema > best_acc_ema:
            best_acc_ema = valid_acc_ema
            best_epoch_ema = epoch
            torch.save(model.state_dict(), f'{logdir}/ema_{config.model_path}')
            torch.save(ema.state_dict(), f'{logdir}/ema_{config.ema_path}')
            # Score with the EMA weights, i.e. the weights that achieved this accuracy.
            with ema.average_parameters():
                get_confidence_score(model, loader=valid_loader, use_gpu_index=config.use_gpu_index, batch_size=config.batch_size, outpu_file_path=f'{logdir}/prediction-Confidence-best-ema.csv')
            print(f'Saving model with acc {valid_acc_ema:.4f} and loss {valid_loss_ema:.4f} (EMA)')
            nonImprove_epochs = 0
        else:
            nonImprove_epochs += 1

        # Stop once the EMA accuracy has stalled for earlyStop_interval epochs.
        if nonImprove_epochs >= config.earlyStop_interval:
            break

    torch.save(model.state_dict(), f'{logdir}/last_{config.model_path}')
    torch.save(ema.state_dict(), f'{logdir}/last_{config.ema_path}')
    print(f'Best epoch: {best_epoch} with acc {best_acc:.4f}')
    print(f'Best epoch: {best_epoch_ema} with acc {best_acc_ema:.4f} (EMA)')

    writer.flush()
    writer.close()

    # Step 6 : Explanation & Visualization
    get_confidence_score(model, loader=valid_loader, use_gpu_index=config.use_gpu_index, batch_size=config.batch_size, outpu_file_path=f'{logdir}/last-prediction-Confidence.csv')

    return history

###################################################################################
###################################################################################

In [17]:
###################################################################################

def start_training(logdir):
    """Create a fresh output directory ``logdir`` and train a model into it via ``main``.

    Raises:
        FileExistsError: if ``logdir`` already exists as a directory, so an earlier
            run's checkpoints and logs are never overwritten.
    """
    # An explicit check (instead of `assert`) still runs under `python -O`.
    if os.path.isdir(logdir):
        raise FileExistsError("Already has a folder with the same name")
    # makedirs also creates missing parent directories for nested log paths.
    os.makedirs(logdir)

    main(logdir)

## **Swin Transformer**

### **model_swin**

In [None]:
# main() instantiates whichever DefualtConfig class is defined in the kernel when it runs.
start_training(logdir='model_swin')

### **model_swin_mixs_tmax101**

In [None]:
class DefualtConfig(object):
    """Settings for the Swin Transformer training run ``model_swin_mixs_tmax101``.

    Every option is a class attribute, so ``DefualtConfig()`` takes no arguments.
    """

    # ------------------------------------------------------------------ model
    model_name = 'Swin_ViT'
    pretrained_model = 'microsoft/swin-base-patch4-window12-384'
    resize = 384                  # images are resized to resize x resize
    num_classes = 219

    # ------------------------------------------------------------ checkpoints
    model_path = 'model.pth'
    ema_path = 'ema.pth'
    load_model = False

    # ------------------------------------------------- mix-based augmentation
    # Both MixUp and CutMix are enabled for this run.
    do_MixUp = True
    do_cutMix = True
    beta = 1.0                    # CutMix beta
    mix_prob = 0.2                # probability of using mix-based augmentation

    # --------------------------------------------------------------- training
    start_epoch = 0
    num_epochs = 105
    earlyStop_interval = 600      # exceeds num_epochs, so early stopping cannot fire
    batch_size = 8
    lr = 5e-5
    lr_warmup_epoch = 5
    cosine_tmax = 101             # T_max of the cosine-annealing schedule

    # Semi-supervised pseudo-labelling (off).
    do_semi = False
    semi_start_epoch = 40

    # --------------------------------------------------------------- hardware
    use_gpu_index = 0             # -1 selects the CPU
    num_workers = 6

    # ------------------------------------------------------------------- data
    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'
    train_valid_split = 0.2       # fraction held out as the validation set

    def __init__(self) -> None:
        """No per-instance state; every setting lives on the class."""

    def parse(self, kwargs):
        """Print the user-config header; ``kwargs`` is accepted but not applied."""
        print('User config : ')


In [None]:
# main() instantiates the DefualtConfig class defined in the cell above when it runs.
start_training(logdir='model_swin_mixs_tmax101')

## **ConvNeXt**

### **model_convnext_384_tmax101**

In [35]:
class DefualtConfig(object):
    """Settings for the ConvNeXt training run ``model_convnext_384_tmax101``.

    Every option is a class attribute, so ``DefualtConfig()`` takes no arguments.
    """

    # ------------------------------------------------------------------ model
    model_name = 'ConvNeXt'
    pretrained_model = 'facebook/convnext-base-384'
    resize = 384                  # images are resized to resize x resize
    num_classes = 219

    # ------------------------------------------------------------ checkpoints
    model_path = 'model.pth'
    ema_path = 'ema.pth'
    load_model = False

    # ------------------------------------------------- mix-based augmentation
    # Only CutMix is enabled for this run; MixUp is off.
    do_MixUp = False
    do_cutMix = True
    beta = 1.0                    # CutMix beta
    mix_prob = 0.2                # probability of using mix-based augmentation

    # --------------------------------------------------------------- training
    start_epoch = 0
    num_epochs = 105
    earlyStop_interval = 600      # exceeds num_epochs, so early stopping cannot fire
    batch_size = 8
    lr = 5e-5
    lr_warmup_epoch = 5
    cosine_tmax = 101             # T_max of the cosine-annealing schedule

    # Semi-supervised pseudo-labelling (off).
    do_semi = False
    semi_start_epoch = 40

    # --------------------------------------------------------------- hardware
    use_gpu_index = 0             # -1 selects the CPU
    num_workers = 6

    # ------------------------------------------------------------------- data
    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'
    train_valid_split = 0.2       # fraction held out as the validation set

    def __init__(self) -> None:
        """No per-instance state; every setting lives on the class."""

    def parse(self, kwargs):
        """Print the user-config header; ``kwargs`` is accepted but not applied."""
        print('User config : ')


In [None]:
# main() instantiates the DefualtConfig class defined in the cell above when it runs.
start_training(logdir='model_convnext_384_tmax101')

## **STN-ViT**

### **model_stnvit_mixs**

In [None]:
class DefualtConfig(object):
    """Settings for the STN-ViT training run ``model_stnvit_mixs``.

    The model fields match the test-time config for ``model_stnvit_mixs``
    (``STN_ViT`` on ``google/vit-base-patch16-224-in21k``). A Swin config here
    would save Swin weights that the STN_ViT test config cannot load.
    Every option is a class attribute, so ``DefualtConfig()`` takes no arguments.
    """

    ###################################################################
    # Model
    model_name = 'STN_ViT'
    pretrained_model = 'google/vit-base-patch16-224-in21k'

    resize = 384                  # images are resized to resize x resize

    ###################################################################

    model_path = 'model.pth'
    ema_path = 'ema.pth'
    load_model = False
    num_classes = 219

    # Both MixUp and CutMix are enabled for this run.
    do_MixUp = True

    do_cutMix = True
    beta = 1.0                    # CutMix beta

    mix_prob = 0.2                # probability of using mix-based augmentation

    ###################################################################
    # Training
    start_epoch = 0
    num_epochs = 105
    earlyStop_interval = 600      # exceeds num_epochs, so early stopping cannot fire

    do_semi = False
    semi_start_epoch = 40

    batch_size = 8
    lr = 5e-5
    lr_warmup_epoch = 5
    cosine_tmax = 101             # T_max of the cosine-annealing schedule

    ###################################################################
    # GPU Settings
    use_gpu_index = 0             # -1 selects the CPU

    ###################################################################
    # DataLoader
    num_workers = 6

    # Dataset
    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'

    train_valid_split = 0.2       # fraction held out as the validation set

    ###################################################################

    def __init__(self) -> None:
        """No per-instance state; every setting lives on the class."""

    def parse(self, kwargs):
        """Print the user-config header; ``kwargs`` is accepted but not applied."""
        print('User config : ')


In [None]:
# main() instantiates the DefualtConfig class defined in the cell above when it runs.
start_training(logdir='model_stnvit_mixs')

# **資料預測**

In [18]:
###################################################################################

# Module-level defaults for the prediction cells. test() builds its own local
# config and device; rebinding the global `device` here only affects code that
# reads the global name.
config, device = None, torch.device('cpu')

# Test-time batch size, independent of the training batch size in DefualtConfig.
BATCH_SIZE = 16

def get_fileName(ds):
    """Return the file path of every sample in ``ds``, in dataset order.

    Each entry of ``ds.imgs`` is indexed with ``[0]`` to get its path
    (presumably ``(path, label)`` pairs as in ``ImageFolder`` — verify).
    """
    return [sample[0] for sample in ds.imgs]


def test(output_file_path='predictions.csv', model_name='model', config=None):
    """Predict the test set with a saved checkpoint and write submission CSVs.

    Loads ``./saved/{model_name}/{config.model_path}`` and
    ``./saved/{model_name}/{config.ema_path}``, then predicts every image under
    ``config.testset_path`` using the EMA weights. Two files are written:

    * ``output_file_path`` — ``filename,category`` rows in dataset order;
    * ``rearrange_<file name>`` in the same directory — the predictions merged
      onto ``./submission_template.csv`` (template row order).

    Args:
        output_file_path: where to write the raw predictions.
        model_name: sub-directory of ``./saved`` holding the checkpoints.
        config: configuration object. Defaults to ``DefualtConfig()``, i.e. the
            ``DefualtConfig`` class currently defined in the kernel.

    Raises:
        ValueError: if ``config.model_name`` is not a supported architecture.
    """
    if config is None:
        config = DefualtConfig()

    # Map predicted indices back to class names via the training folder's layout.
    ds = OrchidDataSet(config.trainset_path, transform_set=None)
    idx_to_class = {idx: name for name, idx in ds.class_to_idx.items()}

    # Step 1 : Model Define & Load
    print(config.model_name)
    if config.model_name == 'STN_ViT':
        model = STN_ViT(config)
    elif config.model_name == 'Swin_ViT':
        model = Swin_ViT(config)
    elif config.model_name == 'ConvNeXt':
        model = ConvNeXt(config)
    else:
        raise ValueError(f'Unsupported model_name: {config.model_name!r}')

    # use_gpu_index == -1 forces the CPU; otherwise use that GPU when CUDA exists.
    if config.use_gpu_index != -1 and torch.cuda.is_available():
        device = torch.device(f'cuda:{config.use_gpu_index}')
    else:
        device = torch.device('cpu')
    # Only `.to(device)`: an extra `.cuda()` would move the model to the current
    # default GPU, which differs from `device` when use_gpu_index is not 0 and
    # ignores use_gpu_index == -1.
    model = model.to(device)
    model.load_state_dict(torch.load(f'./saved/{model_name}/{config.model_path}', map_location=device))

    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
    ema.load_state_dict(torch.load(f'./saved/{model_name}/{config.ema_path}', map_location=device))

    # Step 2 : DataSet & DataLoader (no augmentation at test time)
    resize = (config.resize, config.resize)
    transform_set = transforms.Compose([
        # Resize the image into a fixed shape
        transforms.Resize(resize),
        transforms.ToTensor(),
    ])

    ds_test = OrchidDataSet(config.testset_path, transform_set=transform_set)
    # shuffle=False keeps predictions aligned with get_fileName(ds_test).
    test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=config.num_workers)

    # Step 3 : Make prediction via the EMA weights, in eval mode, without gradients.
    model.eval()
    predictions = []
    with ema.average_parameters(), torch.no_grad():
        for imgs, _ in tqdm(test_loader):
            logits = model(imgs.to(device))
            predictions.extend(logits.argmax(dim=-1).tolist())

    imgs_file_names = get_fileName(ds_test)

    # Step 4 : Save predictions into the file.
    with open(output_file_path, "w") as f:
        f.write("filename,category\n")
        for file_path, pred in zip(imgs_file_names, predictions):
            # Normalise Windows separators so the bare file name is extracted on
            # both Windows and POSIX paths (the template is keyed by file name).
            file_name = os.path.basename(file_path.replace("\\", "/"))
            f.write(f"{file_name},{idx_to_class[pred]}\n")

    # Step 5 : Rearrange the predictions into the submission template's row order.
    df_predict = pd.read_csv(output_file_path)
    df_sample = pd.read_csv('./submission_template.csv')
    df_arrange = pd.merge(df_sample, df_predict, how='inner', on=['filename'])
    # NOTE(review): assumes the template's own label column is named
    # 'category_sample' — confirm against submission_template.csv.
    df_arrange = df_arrange.drop('category_sample', axis=1)
    # Prefix only the file name so output paths with a directory still work.
    out_dir, out_name = os.path.split(output_file_path)
    df_arrange.to_csv(os.path.join(out_dir, f'rearrange_{out_name}'), index=False)


###################################################################################
###################################################################################

In [19]:
def start_testing(model):
    """Predict the test set with the checkpoints saved under ``./saved/<model>``.

    Raw predictions go to ``prediction_<model>.csv``. ``test()`` instantiates
    whichever ``DefualtConfig`` class is currently defined, so run the matching
    config cell first.
    """
    predictions_csv = f'prediction_{model}.csv'
    test(output_file_path=predictions_csv, model_name=model)

### **model_stnvit_mixs**

In [20]:
class DefualtConfig(object):
    """Test-time settings for the STN-ViT checkpoint ``model_stnvit_mixs``.

    Every option is a class attribute, so ``DefualtConfig()`` takes no arguments.
    """

    # ------------------------------------------------------------------ model
    model_name = 'STN_ViT'
    pretrained_model = 'google/vit-base-patch16-224-in21k'
    resize = 384                  # images are resized to resize x resize
    num_classes = 219

    # ------------------------------------------------------------ checkpoints
    model_path = 'model.pth'
    ema_path = 'ema.pth'
    load_model = False

    # ------------------------------------------------- mix-based augmentation
    do_MixUp = True
    do_cutMix = True
    beta = 1.0                    # CutMix beta
    mix_prob = 0.2                # probability of using mix-based augmentation

    # --------------------------------------------------------------- training
    start_epoch = 0
    num_epochs = 105
    earlyStop_interval = 600
    batch_size = 8
    lr = 5e-5
    lr_warmup_epoch = 5
    cosine_tmax = 101             # T_max of the cosine-annealing schedule

    # Semi-supervised pseudo-labelling (off).
    do_semi = False
    semi_start_epoch = 40

    # --------------------------------------------------------------- hardware
    use_gpu_index = 0             # -1 selects the CPU
    num_workers = 6

    # ------------------------------------------------------------------- data
    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'
    train_valid_split = 0.2       # fraction held out as the validation set

    def __init__(self) -> None:
        """No per-instance state; every setting lives on the class."""

    def parse(self, kwargs):
        """Print the user-config header; ``kwargs`` is accepted but not applied."""
        print('User config : ')


In [None]:
# test() instantiates the DefualtConfig class defined in the cell above when it runs.
start_testing(model='model_stnvit_mixs')

### **model_convnext_384_tmax101**

In [None]:
class DefualtConfig(object):
    """Test-time settings for the ConvNeXt checkpoint ``model_convnext_384_tmax101``.

    Every option is a class attribute, so ``DefualtConfig()`` takes no arguments.
    """

    # ------------------------------------------------------------------ model
    model_name = 'ConvNeXt'
    pretrained_model = 'facebook/convnext-base-384'
    resize = 384                  # images are resized to resize x resize
    num_classes = 219

    # ------------------------------------------------------------ checkpoints
    model_path = 'model.pth'
    ema_path = 'ema.pth'
    load_model = False

    # ------------------------------------------------- mix-based augmentation
    # Only CutMix is enabled; MixUp is off.
    do_MixUp = False
    do_cutMix = True
    beta = 1.0                    # CutMix beta
    mix_prob = 0.2                # probability of using mix-based augmentation

    # --------------------------------------------------------------- training
    start_epoch = 0
    num_epochs = 105
    earlyStop_interval = 600
    batch_size = 8
    lr = 5e-5
    lr_warmup_epoch = 5
    cosine_tmax = 101             # T_max of the cosine-annealing schedule

    # Semi-supervised pseudo-labelling (off).
    do_semi = False
    semi_start_epoch = 40

    # --------------------------------------------------------------- hardware
    use_gpu_index = 0             # -1 selects the CPU
    num_workers = 6

    # ------------------------------------------------------------------- data
    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'
    train_valid_split = 0.2       # fraction held out as the validation set

    def __init__(self) -> None:
        """No per-instance state; every setting lives on the class."""

    def parse(self, kwargs):
        """Print the user-config header; ``kwargs`` is accepted but not applied."""
        print('User config : ')


In [None]:
# test() instantiates the DefualtConfig class defined in the cell above when it runs.
start_testing(model='model_convnext_384_tmax101')

### **model_swin_mixs_tmax101**

In [None]:
class DefualtConfig(object):
    """Test-time settings for the Swin checkpoint ``model_swin_mixs_tmax101``.

    Every option is a class attribute, so ``DefualtConfig()`` takes no arguments.
    """

    # ------------------------------------------------------------------ model
    model_name = 'Swin_ViT'
    pretrained_model = 'microsoft/swin-base-patch4-window12-384'
    resize = 384                  # images are resized to resize x resize
    num_classes = 219

    # ------------------------------------------------------------ checkpoints
    model_path = 'model.pth'
    ema_path = 'ema.pth'
    load_model = False

    # ------------------------------------------------- mix-based augmentation
    do_MixUp = True               # MixUp
    do_cutMix = True              # CutMix
    beta = 1.0                    # CutMix beta
    mix_prob = 0.2                # probability of using mix-based augmentation

    # --------------------------------------------------------------- training
    start_epoch = 0
    num_epochs = 105              # total epochs
    earlyStop_interval = 600
    batch_size = 8
    lr = 5e-5
    lr_warmup_epoch = 5           # warm-up epochs
    cosine_tmax = 101             # T_max of the cosine-annealing schedule

    # Semi-supervised training on extra unlabeled data (off).
    do_semi = False
    semi_start_epoch = 40

    # --------------------------------------------------------------- hardware
    use_gpu_index = 0             # GPU index; -1 selects the CPU
    num_workers = 6               # DataLoader workers

    # ------------------------------------------------------------------- data
    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'
    train_valid_split = 0.2       # fraction held out as the validation set


In [None]:
# test() instantiates the DefualtConfig class defined in the cell above when it runs.
start_testing(model='model_swin_mixs_tmax101')

### **model_swin**

In [None]:
class DefualtConfig(object):
    """Test-time settings for the 224px Swin checkpoint ``model_swin``.

    Every option is a class attribute, so ``DefualtConfig()`` takes no arguments.
    """

    # ------------------------------------------------------------------ model
    model_name = 'Swin_ViT'
    pretrained_model = 'microsoft/swin-base-patch4-window7-224'
    resize = 224                  # images are resized to resize x resize
    num_classes = 219

    # ------------------------------------------------------------ checkpoints
    model_path = 'model.pth'
    ema_path = 'ema.pth'
    load_model = False

    # ------------------------------------------------- mix-based augmentation
    # Only CutMix is enabled; MixUp is off.
    do_MixUp = False
    do_cutMix = True
    beta = 1.0                    # CutMix beta
    mix_prob = 0.2                # probability of using mix-based augmentation

    # --------------------------------------------------------------- training
    start_epoch = 0
    num_epochs = 105
    earlyStop_interval = 600
    batch_size = 16
    lr = 5e-5
    lr_warmup_epoch = 5
    cosine_tmax = 20              # T_max of the cosine-annealing schedule

    # Semi-supervised pseudo-labelling (off).
    do_semi = False
    semi_start_epoch = 40

    # --------------------------------------------------------------- hardware
    use_gpu_index = 0             # -1 selects the CPU
    num_workers = 6

    # ------------------------------------------------------------------- data
    trainset_path = './data/dataset/train'
    unlabeledset_path = './data/dataset/unlabeled'
    testset_path = './data/dataset/test'
    train_valid_split = 0.2       # fraction held out as the validation set

    def __init__(self) -> None:
        """No per-instance state; every setting lives on the class."""

    def parse(self, kwargs):
        """Print the user-config header; ``kwargs`` is accepted but not applied."""
        print('User config : ')


In [None]:
# test() instantiates the DefualtConfig class defined in the cell above when it runs.
start_testing(model='model_swin')

# **Ensemble 模型集成**

In [22]:
# Majority-vote ensemble over the top-1 predictions of several trained models.
# Each model casts one vote per image; the label with the most votes wins and
# ties are broken by a fixed model priority order.
from collections import Counter

import numpy as np
import pandas as pd

# Models whose rearranged prediction files are combined. ``priority`` refers to
# positions in this list, so keep the two in sync.
register_models = ['model_stnvit_mixs', 'model_swin', 'model_swin_mixs_tmax101', 'model_convnext_384_tmax101']

# Tie-break order as indices into ``register_models``:
# 2 -> model_swin_mixs_tmax101, 1 -> model_swin,
# 3 -> model_convnext_384_tmax101, 0 -> model_stnvit_mixs.
priority = [2, 1, 3, 0]


def vote_label(votes, priority):
    """Return the majority label among ``votes``, breaking ties by ``priority``.

    Args:
        votes: Sequence of predicted labels, one per model, in the same order
            as ``register_models``.
        priority: Model indices (positions in ``votes``) in tie-break order.

    Returns:
        The label with the most votes. When several labels share the highest
        count, the vote of the first model in ``priority`` whose prediction is
        one of the tied labels. If no model listed in ``priority`` voted for a
        tied label, the tied label that appears first in ``votes``.

    Raises:
        ValueError: If ``votes`` is empty.
    """
    if len(votes) == 0:
        raise ValueError('votes must contain at least one prediction')

    counts = Counter(votes)  # preserves first-seen order of labels
    best = max(counts.values())
    tied = [label for label, n in counts.items() if n == best]
    if len(tied) == 1:
        return tied[0]

    # Walk models from highest to lowest priority and stop at the first one
    # that backs one of the tied labels.
    for model_idx in priority:
        if votes[model_idx] in tied:
            return votes[model_idx]
    return tied[0]


if __name__ == '__main__':
    # Guarded so vote_label can be imported without touching the CSV files;
    # inside a notebook kernel ``__name__`` is '__main__', so this always runs.
    from tqdm import tqdm

    # NOTE(review): column 1 of each rearranged prediction file is taken as the
    # model's top-1 label — confirm against the code that writes these files.
    confidence_scores = [pd.read_csv(f'./rearrange_prediction_{model_name}.csv')
                         for model_name in register_models]
    n_rows = len(confidence_scores[0])
    assert all(len(df) == n_rows for df in confidence_scores), \
        'prediction files must all have the same number of rows'

    # Extract the label columns once (n_rows x n_models); per-cell .iloc
    # lookups inside the loop are slow.
    votes_per_image = np.column_stack([df.iloc[:, 1].to_numpy() for df in confidence_scores])

    ensemble_labels = [vote_label(votes, priority) for votes in tqdm(votes_per_image)]

    df_sample = pd.read_csv('./submission_template.csv')
    assert len(df_sample) >= n_rows, 'submission template has fewer rows than the predictions'
    df_sample.iloc[:n_rows, 1] = ensemble_labels
    df_sample.to_csv('./ensemble_predict.csv', index=False)

100%|██████████| 81710/81710 [00:24<00:00, 3386.85it/s]
