# Implementation Guidelines of Sample Code (Pytorch)

    See the annotations at every markdown blocks correspoding to each code blocks, and also # TODO annotations. :D

# Usage guideline of Jupyter Notebook (If needed)

    Installation   : https://jupyter.org/install  
    User Document  : https://jupyter-notebook.readthedocs.io/en/latest/user-documentation.html

# Test Environment (Recommended)

    In test time, we will evaluate the given codes from you with the following version of libraries.  
    So, it is highly recommended to use those packages with specific version below.

    test environment : pytorch

### Packages
    python   : 3.8.17  
    torch    : 2.0.1   
    skimage  : 0.21.0  
    cv2      : 4.8.0

# Import libraries (Do not change!)

In [1]:
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import cv2
from torch.utils.data import DataLoader
from skimage import io
import pandas as pd
import matplotlib.pyplot as plt
import math
import copy
import time
import PIL
import pickle

# Split dataset (Do not change!)

### Notice 1
    This function do split your dataset of 1000 classes into 10 groups of 100 each.    
    So, it is needed to be implemented just once at first to split your dataset for continual learning.   
    *Again, you dont need to use this function in every tranining time if you already split your dataset into 10 groups.

    Notice the annotation codes below. (You can see this codes in 'main' block.)

```python
        parser = argparse.ArgumentParser()   
        # Change this as 'False' after dividing your datsaet into 10 groups.
        parser.add_argument('--div_data',   default = True)  
        args = parser.parse_args(args=[])  
```

### Notice 2
    We reshapes all the input data size into constant 128x128.   
    Until further notification, use this constant size. 

```python
        # Split input data.  
        for i in range(start, end):
            for img_idx in range(0, 130):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))  # resize image into 128 x 128 
                x_train.append(img)
```


In [2]:
def train_split(validation_num):
    # TODO : set dataset path
    # TODO : We recommends you to place your code and tranining dataset in the same location.
    
    dir = './Koh_Young_AI_data/'
    

    for div_idx in range(0, 10): # Div into 10 groups
        # Divide data 0-129 for training, 130-150 for validation.
        x_train = []
        x_valid = []
        y_train = []
        y_valid = []
        start   = 100*div_idx + 1
        end     = 100*div_idx + 100 + 1

        # Split input data.  
        for i in range(start, end):
            for img_idx in range(0, 150-validation_num):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))
                x_train.append(img)

            for img_idx in range(150-validation_num, 150):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))
                x_valid.append(img)

        # Split corresponding output label data.
        for folder_idx in range(start, end):
            for img_idx in range(0, 150-validation_num):
                y_train.append(np.array([folder_idx]))
            for img_idx in range(150-validation_num, 150):
                y_valid.append(np.array([folder_idx]))

        # Convert list to numpy 
        x_train = np.array(x_train)
        y_train = np.array(y_train)
        x_valid = np.array(x_valid)
        y_valid = np.array(y_valid)

        # TODO : Define train data and valid data directory path.
        # TODO : Recommends not to change these directory paths. 
        train_save_dir = 'train_data'
        valid_save_dir = 'valid_data'
        if not os.path.exists(train_save_dir):
            os.makedirs(train_save_dir)

        if not os.path.exists(valid_save_dir):
            os.makedirs(valid_save_dir)

        # TODO : Save train/valid data
        np.save(f'./train_data/x_data_{div_idx+1}', x_train)
        np.save(f'./train_data/y_data_{div_idx+1}', y_train)
        np.save(f'./valid_data/x_data_{div_idx+1}', x_valid)
        np.save(f'./valid_data/y_data_{div_idx+1}', y_valid)

        print(f" ===================== Done in {div_idx} ===================== ")

# Define Dataloader (Do not change!)

    You can define your own dataloader with API of torch.utils.data.Dataset.  
    This can usually help you to reduce computational burden when dealing with high dimensional data, such as images.  

    reference url : https://pytorch.org/tutorials/beginner/basics/data_tutorial.html


In [3]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x_data, y_data, device):
        self.x_data = x_data
        self.y_data = y_data
        self.device = device

    def __getitem__(self, idx):
        # .transpose(0, 2) : width x height x channel (0, 1, 2) ---> channel x width x height (2, 0, 1).
        # .squeeze(0) : add extra dimension at axis 0.
        x = torch.FloatTensor(self.x_data[idx]).transpose(0, 2)
        y = torch.LongTensor(self.y_data[idx]).squeeze(0)
        return x, y
        
    def __len__(self):
        return len(self.x_data)

def load_train_data(class_num):
    # TODO : set 'class_path' with your train_data path.
    class_path  = f'./train_data/'
    x_data_path = class_path + 'x_data_' + str(class_num+1) + '.npy'
    y_data_path = class_path + 'y_data_' + str(class_num+1) + '.npy'
    x_data      = np.load(x_data_path, allow_pickle=True)
    y_data      = np.load(y_data_path, allow_pickle=True)
    return x_data, y_data

def load_valid_data(class_num):
    # TODO : set 'class_path' with your valid_data path.
    class_path  = f'./valid_data/'
    x_data_path = class_path + 'x_data_' + str(class_num+1) + '.npy'
    y_data_path = class_path + 'y_data_' + str(class_num+1) + '.npy'
    x_data      = np.load(x_data_path, allow_pickle=True)
    y_data      = np.load(y_data_path, allow_pickle=True)

    # return processed data. 
    return x_data, y_data

# Define tranining function (You can modify this part!)

    Set your model with train mode as 'model.train()'.   

    useful reference : https://wikidocs.net/195118

In [4]:
# Minimizer
class ASAM:
    def __init__(self, optimizer, model, rho=0.5, eta=0.01):
        self.optimizer = optimizer
        self.model = model
        self.rho = rho
        self.eta = eta
        self.state = dict()

    @torch.no_grad()
    def ascent_step(self):
        wgrads = []
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {}
            t_w = self.state[p].get("eps")
            if t_w is None:
                t_w = torch.clone(p).detach()
                self.state[p]["eps"] = t_w
            if 'weight' in n:
                t_w[...] = p[...]
                t_w.abs_().add_(self.eta)
                p.grad.mul_(t_w)
            wgrads.append(torch.norm(p.grad, p=2))
        wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            if p in self.state and "eps" in self.state[p]:
                t_w = self.state[p].get("eps")
            else:
                continue
            if 'weight' in n:
                p.grad.mul_(t_w)
            eps = t_w
            eps[...] = p.grad[...]
            eps.mul_(self.rho / wgrad_norm)
            p.add_(eps)
        self.optimizer.zero_grad()

    @torch.no_grad()
    def descent_step(self):
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            p.sub_(self.state[p]["eps"])
        self.optimizer.step()
        self.optimizer.zero_grad()


# feature kd loss function
def hcl(fstudent, fteacher):
    loss_all = 0.0
    fs = fstudent
    
    for i in range(len(fteacher)):
        ft = fteacher[i]
        n,c,h,w = fs.shape
        if n == 0:
            break
        loss = nn.functional.mse_loss(fs, ft, reduction='mean')
        cnt = 1.0
        tot = 1.0
        for l in [4, 2, 1]:
            if l >= h:
                continue
            tmpfs = nn.functional.adaptive_avg_pool2d(fs, (l, l))
            tmpft = nn.functional.adaptive_avg_pool2d(ft, (l, l))
            cnt /= 2.0
            loss += nn.functional.mse_loss(tmpfs, tmpft, reduction='mean') * cnt
            tot += cnt
        loss = loss / tot
        loss_all = loss_all + loss
    return loss_all


def clone_model(model, device): # convnets, old_weight, new_weight, n_classes, ntask
    new_model = type(model)(cfg["convnet"],
                            cfg=cfg,
                            nf=cfg["channel"],
                            device=device,
                            depth=cfg["depth"],
                            widen_factor=cfg["widen_factor"],
                            dropRate=cfg["dropRate"])  # model의 같은 타입의 새 인스턴스 생성
    new_model.convnets = copy.deepcopy(model.convnets)
    new_model.old_weight = nn.Parameter(model.old_weight.data.clone().detach(), requires_grad=False) if model.old_weight is not None else None
    new_model.new_weight = nn.Parameter(model.new_weight.data.clone().detach(), requires_grad=False)
    new_model.n_classes = model.n_classes
    new_model.ntask = model.ntask
    new_model.eval()
    new_model.to(device)  # 새 모델을 원래 모델과 같은 device로 옮김
    return new_model


def train_model(model, x_train, optimizer, num_epochs, train_data_loader, criterion, x_valid, valid_data_loader):
    n_task = model.ntask
    model.add_classes()
    model.task_size = 100
    if n_task == 0:     # task 0만 batchnorm2d training. 그 이후 task는 batchnorm2d frozen.
        model.train()
    else:
        model.eval()
    model.convnets[-1].to(device).train()
    if n_task >= 1:
        for i in range(n_task):
            model.convnets[i].eval()
            
    scheduling = cfg['scheduling'] if n_task == 0 else list(np.array(cfg['scheduling']) + n_task*2)
    num_epochs = num_epochs + n_task*2
    temperature = (cfg['max_temperature'] - cfg['temperature_step'] * n_task) if (cfg['max_temperature'] - cfg['temperature_step'] * n_task) >= cfg['min_temperature'] else cfg['min_temperature']
    
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg['lr']) if optimizer == "adamw" else optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg['lr'])
    minimizer = ASAM(optimizer, model, rho=cfg['rho'], eta=cfg['eta'])
    scheduler = optim.lr_scheduler.MultiStepLR(minimizer.optimizer, scheduling, gamma=cfg['lr_decay']) if cfg['scheduler'] == "multistep" else None

    print("Train on {}->{}.".format(n_task*100, (n_task+1)*100 - 1))
    
    best_model_wts = None
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        loss = 0.0
        acc = 0
        for i, (inputs, targets) in enumerate(train_data_loader, start=1):
            if n_task == 0: # task 0만 batchnorm2d training.
                model.train()
            inputs, targets = inputs.to(device), targets.to(device)

            # ascent step
            output = model(inputs, test=False)
            _, preds = torch.max(output['logit'], 1)
            _loss = criterion(output['logit'], targets)
            
            if cfg['old_model']:
                with torch.no_grad():
                    old_output = cfg['old_model'](inputs, test=False)
                    old_output_feature = [feature.detach() for feature in old_output['feature']]
                distill_loss = nn.functional.kl_div(nn.functional.log_softmax(output['logit'] / temperature, dim=1), nn.functional.softmax(old_output['logit'].detach() / temperature, dim=1), reduction='batchmean', log_target=False)
                feature_kd_loss = hcl(output['feature'][-1], old_output_feature)
                _loss = cfg['lambda_ent'] * _loss + cfg['lambda_distill'] * distill_loss + cfg['lambda_feat'] * feature_kd_loss
                
            _loss.backward()
            minimizer.ascent_step()
            
            # descent step
            output = model(inputs, test=False)
            _, preds = torch.max(output['logit'], 1)
            _loss = criterion(output['logit'], targets)
            
            if cfg['old_model']:
                with torch.no_grad():
                    old_output = cfg['old_model'](inputs, test=False)
                    old_output_feature = [feature.detach() for feature in old_output['feature']]
                distill_loss = nn.functional.kl_div(nn.functional.log_softmax(output['logit'] / temperature, dim=1), nn.functional.softmax(old_output['logit'].detach() / temperature, dim=1), reduction='batchmean', log_target=False)
                feature_kd_loss = hcl(output['feature'][-1], old_output_feature)
                _loss = cfg['lambda_ent'] * _loss + cfg['lambda_distill'] * distill_loss + cfg['lambda_feat'] * feature_kd_loss
            
            _loss.backward()
            minimizer.descent_step()

            loss += _loss.item()
            acc += torch.sum(preds.detach().cpu() == (targets.data).detach().cpu())

        if scheduler:
            scheduler.step()
        print(f" # - EPOCHS {epoch + 1} / {num_epochs} | Train Loss {loss/len(train_data_loader)} | Accuracy : {acc/len(x_train)} - #")
        
        # validation
        model.eval()
        
        with torch.no_grad():
            acc_val = 0
            acc_old = 0
            old_nums = 0
            acc_new = 0
            new_nums = 0
            val_loss = 0.0
            for x, y in valid_data_loader:
                old_y_idx = y < (100*n_task)
                new_y_idx = y >= (100*n_task)
                output = model(x.data.to(device), test=False)
                _, preds = torch.max(output['logit'], 1)
                preds_old = preds[old_y_idx]
                preds_new = preds[new_y_idx]
                
                cost  = criterion(output['logit'], y.to(device))
                val_loss += cost.item()
                acc_val += torch.sum(preds.detach().cpu() == (y.data).detach().cpu())
                acc_old += torch.sum(preds_old.detach().cpu() == (y[old_y_idx].data).detach().cpu())
                acc_new += torch.sum(preds_new.detach().cpu() == (y[new_y_idx].data).detach().cpu())
                old_nums += torch.sum(old_y_idx)
                new_nums += torch.sum(new_y_idx)
            print(f" # - Val Loss {val_loss/len(valid_data_loader)} | Accuracy : {acc_val / len(x_valid)} | Old Accuracy : {acc_old / old_nums} | New Accuracy : {acc_new / new_nums} - #")
        
        if (val_loss/len(valid_data_loader)) < best_val_loss:
            best_val_loss = val_loss / len(valid_data_loader)
            best_model_wts = copy.deepcopy(model.state_dict())
    
    model.load_state_dict(best_model_wts)

    with torch.no_grad():  # Autograd 연산을 추적하지 않도록 설정
        cfg['old_model'] = clone_model(model.to('cpu'), device)
        model.to(device)
    
    return model, loss/len(train_data_loader), acc/len(x_train)

# Define validation function (Do not change!)

    And eval mode as 'model.eval()' or 'model.train(False)'.

In [5]:
def validation(model, x_valid, valid_data_loader, criterion):
    """
    model             : your customized model 
    x_vallid          : input data for validation
    valid_data_loader : dataloder of valid dataset 
    """
    
    model.eval() # Set eval mode
    
    acc = 0
    for x, y in valid_data_loader:
        out = model(x.data.to(device))
        _, preds = torch.max(out, 1)
        cost  = criterion(out, y.to(device))
        acc += torch.sum(preds.detach().cpu() == (y.data).detach().cpu())
    print(f" # - ValidCost {cost} | Accuracy : {acc / len(x_valid)} - #")

    # Return Accuracy 
    return acc/len(x_valid)

# Define your model and hyperparameter (You can modify this part!)

    Here is the pivotal part of your competition.
    We gives a simple CNN model, for example. 
    Go make your own model!          

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Constants
EPSILON = 1e-8

# CFG
cfg = {
    "device": "0",
    "optimizer": "adamw", # adam, adamw
    "minimizer": "asam",
    "rho": 1.0,
    "eta": 0.01,
    "batch_size": 128,
    "task_max": 10,
    "lr_min": 1e-4,
    "lr": 1e-2,
    "weight_decay": 5e-4,
    "epochs": 40,
    "scheduling": [20, 25, 30, 35],
    "scheduler": "multistep",  # "multistep", "cosine"
    "eta_min": 5e-5,
    "lr_decay": 0.5,
    # Network
    "convnet": "resnet34", # resnet18, resnet34
    "channel": 64,
    "use_bias": False,
    "last_relu": False,
    "dropRate": 0.0,
    # wrn
    "depth": 28,
    "widen_factor": 10,
    
    "train_head": "softmax",
    "infer_head": "softmax",
    "min_temperature": 2.5,
    "max_temperature": 4,
    "temperature_step": 0.2,
    "distillation": True,
    "start_class": 0,
    "start_task": 0,

    "trial": 1,
    "seed": 1993,
    "old_model": None,
    "lambda_ent": 1.0,
    "lambda_distill": 1.0,
    "lambda_feat": 1.0
}

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, inplanes, planes, stride=1, downsample=None, remove_last_relu=False, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.remove_last_relu = remove_last_relu
        self.dropout = nn.Dropout(dropRate)
        
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        if not self.remove_last_relu:
            out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 nf=64,
                 zero_init_residual=True,
                 remove_last_relu=False,
                 dropRate=0.0):
        super(ResNet, self).__init__()
        self.remove_last_relu = remove_last_relu
        self.inplanes = nf

        self.conv1 = nn.Sequential(
                    nn.Conv2d(3, nf, kernel_size=7, stride=2, padding=3, bias=False),   # (nf, 64, 64)
                    nn.BatchNorm2d(nf),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)                    # (nf, 32, 32)
        )
        
        self.layer1 = self._make_layer(block, 1 * nf, layers[0], dropRate=dropRate)
        self.layer2 = self._make_layer(block, 2 * nf, layers[1], stride=2, dropRate=dropRate)
        self.layer3 = self._make_layer(block, 4 * nf, layers[2], stride=2, dropRate=dropRate)
        self.layer4 = self._make_layer(block, 8 * nf, layers[3], stride=2, remove_last_relu=remove_last_relu, dropRate=dropRate)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.out_dim = 8 * nf * block.expansion
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
    

    def _make_layer(self, block, planes, blocks, remove_last_relu=False, stride=1, dropRate=0.0):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                        conv1x1(self.inplanes, planes * block.expansion, stride),
                        nn.BatchNorm2d(planes * block.expansion)
            )
        
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, dropRate=dropRate))
        self.inplanes = planes * block.expansion
        if remove_last_relu:
            for i in range(1, blocks - 1):
                layers.append(block(self.inplanes, planes, dropRate=dropRate))
            layers.append(block(self.inplanes, planes, remove_last_relu=True, dropRate=dropRate))
        else:
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes, dropRate=dropRate))
        
        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        x = self.conv1(x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x = self.avgpool(x4)
        x = x.view(x.size(0), -1)
        return x, x4 # feature


# ResNet18
def resnet18(**kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


# ResNet34
def resnet34(**kwargs):
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def get_convnet(convnet_type, nf, remove_last_relu, depth, widen_factor, dropRate):
    if convnet_type == 'resnet18':
        return resnet18(nf=nf, remove_last_relu=remove_last_relu, dropRate=dropRate)
    elif convnet_type == 'resnet34':
        return resnet34(nf=nf, remove_last_relu=remove_last_relu, dropRate=dropRate)

class BasicNet(nn.Module):
    def __init__(self,
                 convnet_type,
                 cfg,
                 nf=64,
                 init="kaiming",
                 device=None,
                 depth=28,
                 widen_factor=10,
                 dropRate=0.3
                 ):
        super(BasicNet, self).__init__()
        self.nf = nf
        self.depth = depth
        self.widen_factor = widen_factor
        self.dropRate = dropRate
        self.convnet_type = convnet_type
        self.start_class = cfg['start_class']
        self.init = init
        self.remove_last_relu = False
        
        self.convnets = nn.ModuleList()
        self.convnets.append(
            get_convnet(convnet_type,
                        nf=nf,
                        remove_last_relu=self.remove_last_relu,
                        depth=self.depth,
                        widen_factor=self.widen_factor,
                        dropRate=self.dropRate
                        )
        )
        self.out_dim = self.convnets[0].out_dim
        
        self.old_weight = None
        self.new_weight = None
        
        self.n_classes = 1000
        self.ntask = 0
        self.device = device
        self.task_size = 100
        
        self.post_processor = None
        
        self.to(self.device)
    
    def forward(self, x, test=True):
        # rescaling
        x = x / 255.0
        
        if self.new_weight is None:
            raise Exception("Add some classes before training.")
        
        logits = []
        features = []
        for convnet in self.convnets:
            logit, feature = convnet(x)
            logits.append(logit)
            features.append(feature)
        logits = torch.cat(logits, dim=1)
        
        combined_weight = torch.cat([self.old_weight, self.new_weight], dim=1) if self.old_weight is not None else self.new_weight
        logits = torch.nn.functional.linear(logits, combined_weight)
        if test:
            return logits
        else:
            return {'feature': features, 'logit': logits}
    
    @property
    def features_dim(self):
        if self.der:
            return self.out_dim * len(self.convnets)
        else:
            return self.out_dim
    
    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self
    
    def copy(self):
        return copy.deepcopy(self)
    
    def add_classes(self):
        self.ntask += 1
        self._add_fc()
        # freezing
        for i in range(self.ntask-1):
            for p in self.convnets[i].parameters():
                p.requires_grad = False

    def _add_fc(self):
        if self.ntask > 1:
            new_clf = get_convnet(self.convnet_type,
                                nf=self.nf,
                                remove_last_relu=self.remove_last_relu,
                                depth=self.depth,
                                widen_factor=self.widen_factor,
                                dropRate=self.dropRate).to(self.device)
            new_clf.load_state_dict(self.convnets[-1].state_dict())
            self.convnets.append(new_clf)

            if self.ntask == 1:
                self.old_weight = None
            elif self.old_weight is None:
                self.old_weight = nn.Parameter(self.new_weight.detach().clone(), requires_grad=False)
            elif self.ntask > 2:
                self.old_weight = nn.Parameter(torch.cat([self.old_weight.detach().clone(), self.new_weight.detach().clone()], dim=1), requires_grad=False)
            
        self.new_weight = nn.Parameter(torch.Tensor(self.n_classes, self.out_dim).normal_(0, 0.01).to(self.device), requires_grad=True)
        


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# TODO : Define your model
model = BasicNet(cfg["convnet"],
            cfg=cfg,
            nf=cfg["channel"],
            device=device,
            depth=cfg["depth"],
            widen_factor=cfg["widen_factor"],
            dropRate=cfg["dropRate"]).to(device)

# TODO : Set your hyperparameters
batch_size        = cfg['batch_size']
learning_rate     = cfg['lr']
num_epochs        = cfg['epochs']
random_seed       = cfg['seed']
validation_num    = 20 # for 150 images for class, the number for validation data
criterion = nn.CrossEntropyLoss() # Define criterion.
optimizer = cfg['optimizer']

# Incremental Learning. (Do not change!)

### WARNING:
    The training and validation datasets each SHOULD BE prepared properly beforehand.  
    If not, the submitted code from you will be immediately rejected.

In [7]:

"""  
--div_data  : split your data or not.   
"""
parser = argparse.ArgumentParser()  
parser.add_argument('--div_data',   default = False)  # Change this with 'False' after dividing your datsaet into 10 groups.
args = parser.parse_args(args=[])  

# TODO : Saving tranined model in this location. Don't change this path. 
save_dir = './result'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# TODO : Seed
random.seed(random_seed)
torch.manual_seed(random_seed)

# TODO : Split dataset according to argument '--div_data'
if args.div_data == True:
    train_split(validation_num)
else:
    pass


""" 
1. training      : train each 100 classes sequentailly with respect to 1000 output class. 
    trainining class === 1-100 -> 101-200 -> 201-300 -> 301-400 -> ... -> 901-1000
    
2. validation    : validate each trained model.
    validation class === 1-100 -> 1-200 -> 1-300 -> ... -> 1-1000
    
3. model save    : saves each trained model.                
"""

for div_idx in range(10):

    # TODO : Load your train and validation data
    x_train, y_train = load_train_data(div_idx)
    x_valid, y_valid = load_valid_data(div_idx)

    """
        in case of tranining 1  -100 classes, validate on 1-100 classes
        in case of tranining 101-200 classes, validate on 1-200 classes
        in case of tranining 201-300 classes, validate on 1-300 classes
        and so on...            
    """
    
    if div_idx == 0:
        x_val_tmp = x_valid
        y_val_tmp = y_valid
    else:
        x_val_tmp = np.concatenate((x_val_tmp, x_valid), axis = 0)
        y_val_tmp = np.concatenate((y_val_tmp, y_valid), axis = 0)
        x_valid   = x_val_tmp
        y_valid   = y_val_tmp

    # TODO : let the label starts from 0 to match the output index of model prediction. (Currently the label starts from 1.)
    y_train = y_train - 1
    y_valid = y_valid - 1

    # TODO : Define dataset and dataloader
    train_dataset     = CustomDataset(x_train, y_train, device)
    valid_dataset     = CustomDataset(x_valid, y_valid, device)
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    valid_data_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=False)

    # TODO : train and validate
    trained_model, loss_train, acc_train = train_model(model, x_train, optimizer, num_epochs, train_data_loader, criterion, x_valid, valid_data_loader)
    acc_valid                    = validation(trained_model, x_valid, valid_data_loader, criterion)

    if div_idx == 9:
        MODEL_SAVE_FOLDER_PATH = './model_save/'
        if not os.path.exists(MODEL_SAVE_FOLDER_PATH):
            os.mkdir(MODEL_SAVE_FOLDER_PATH)        
        model_path = MODEL_SAVE_FOLDER_PATH + 'continual_model.pt'
        # TODO : save trained model in 'save_model_path'
        torch.save(trained_model.state_dict(), model_path)

    print(f'{str(div_idx)} Iteration Done.')

Train on 0->99.
 # - EPOCHS 1 / 40 | Train Loss 3.2564272155948712 | Accuracy : 0.0835384652018547 - #
 # - Val Loss 2.886882320046425 | Accuracy : 0.14149999618530273 | Old Accuracy : nan | New Accuracy : 0.14149999618530273 - #
 # - EPOCHS 2 / 40 | Train Loss 2.331234077612559 | Accuracy : 0.19238461554050446 - #
 # - Val Loss 2.085022084414959 | Accuracy : 0.2644999921321869 | Old Accuracy : nan | New Accuracy : 0.2644999921321869 - #
 # - EPOCHS 3 / 40 | Train Loss 1.6366192838724922 | Accuracy : 0.3922307789325714 - #
 # - Val Loss 1.226191632449627 | Accuracy : 0.5335000157356262 | Old Accuracy : nan | New Accuracy : 0.5335000157356262 - #
 # - EPOCHS 4 / 40 | Train Loss 0.955980556268318 | Accuracy : 0.6359999775886536 - #
 # - Val Loss 1.0402108877897263 | Accuracy : 0.7149999737739563 | Old Accuracy : nan | New Accuracy : 0.7149999737739563 - #
 # - EPOCHS 5 / 40 | Train Loss 0.5055654455049365 | Accuracy : 0.8186923265457153 - #
 # - Val Loss 0.2967154085636139 | Accuracy : 0