### *Module Loading*

In [None]:
import os
import shutil
import sys

import cv2
from IPython import display

### *External Module Loading*

In [None]:
# Folder holding the project's helper modules, one level above the notebook's
# working directory. It is built with os.path.join so the separator follows the
# host OS instead of being hard-wired to the Windows backslash (on Windows the
# resulting string is unchanged).
external_modules_path = os.path.join('..', 'nn_likelihood_modules')
# Re-running this cell must not stack duplicate entries on sys.path.
if external_modules_path not in sys.path:
    sys.path.append(external_modules_path)

In [None]:
from basic_network_structure import *
from common_imports import *
from common_use_functions import *
from constant import *
from defined_data_structure import *
from defined_network_structure import *
from experim_neural_network import *
from experim_preparation import *
from experim_ResNet import *
from ResNet import *
from experim_swintransformer import *
from tiny_imagenet_data_prep import *
from pytorch_swintransformer import *
from pytorch_swintransformer_modified import *

### *GPU verification*

In [None]:
# Select CUDA when PyTorch can reach it, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Count the visible GPUs and report the first one's name ("CPU" when there are none)
nb_gpu = torch.cuda.device_count()
print("CPU" if nb_gpu <= 0 else torch.cuda.get_device_name(0))

### *Working directory*

In [None]:
# Absolute path of the directory the notebook is executed from
current_path = os.path.abspath(os.curdir)

### *Load configurations and data*

In [None]:
"""
All the parameters in this part should be configured
"""
# Experiment path: everything the notebook writes goes below this folder
experim_path = current_path

# File extensions
json_ext = '.json'
np_ext = '.npy'
csv_ext = '.csv'

# Root of the reformatted Tiny ImageNet folder.
# The TINY_IMAGENET_PATH environment variable, when set, overrides the
# machine-specific default below, so the notebook can run on another machine
# without editing this cell.
tiny_imagenet_path = os.environ.get(
    'TINY_IMAGENET_PATH',
    'D:\\Doctorat\\research\\Tiny_imagenet_experim\\tiny-imagenet-200-reform\\',
)
tiny_imagenet_val_path = path_join(tiny_imagenet_path, 'reform_val')
tiny_imagenet_train_path = path_join(tiny_imagenet_path, 'reform_train')

# Save paths
model_save_path = path_join(experim_path, 'experim_models_swin')

# Model naming params (combined into the network name handed to the training call)
model_name_prefix = 'tiny_imagenet'
model_name = 'swin'

# Tested sets name
train_set_name = 'train'
test_set_name = 'test'
valid_set_name = 'valid'
input_extension = 'X'
label_extension = 'Y'

# Resized image size
resized_image_size = 256

# Dataset general information.
# Tiny ImageNet has 200 classes; this matches the num_classes=200 used when
# the Swin model is created further down (the value was previously 10).
data_set_infos = {
    'nb_classes' : 200
}

# Batch size for the dataloader creation
torch_batch_size = 2

In [None]:
# Make sure the folder that will receive the saved models exists before training starts.
# create_directory is a project helper (presumably a no-op when the folder already exists — verify).
create_directory(model_save_path)

### *Tiny ImageNet preparation*

In [None]:
# Load the Tiny ImageNet dataset. The official test split has no labels, so it is not used:
# the validation folder is loaded in its place and serves as the test set in the rest of the notebook.
tiny_imagenet_train_dataset, tiny_imagenet_test_dataset = get_tiny_imagenet_dataset(tiny_imagenet_train_path, tiny_imagenet_val_path)

### *Swin Transformer initialization*

In [None]:
# Build the reference model through get_swin_v2_b with pretrained=True and 1000 output
# classes; it is only used as a source of weights for the model created below.
# NOTE(review): presumably ImageNet-1k weights — confirm in get_swin_v2_b.
pretrained_swin_model = get_swin_v2_b(pretrained=True, num_classes=1000)
# Snapshot of the pretrained parameters, keyed by parameter name
pretrained_swin_state_dict = pretrained_swin_model.state_dict()

In [None]:
# Create the Swin Transformer that will be trained: the project's modified variant,
# built for 200 output classes with hidden_size=1024.
# NOTE(review): the exact role of hidden_size is defined in get_swin_v2_b_modified — check there.
swin = get_swin_v2_b_modified(num_classes=200, hidden_size=1024)

In [None]:
## Seed the new model with the pretrained weights
# Keep every pretrained entry except those whose name contains 'head'
pretrained_weight_dict = dict(
    (name, tensor)
    for name, tensor in pretrained_swin_state_dict.items()
    if 'head' not in name
)
# Start from the model's own state dict and overwrite the entries kept above
swin_state_dict = swin.state_dict()
swin_state_dict.update(pretrained_weight_dict)
# Push the merged weights back into the model
swin.load_state_dict(swin_state_dict)

In [None]:
# Show the model: as the last expression of the cell, its repr is rendered by the notebook
swin

### *SwinTransformer training*

In [None]:
# Optimisation settings for the training run
nb_epochs = 20
lr = 5e-6
# Identifiers of the optimizer and of the loss function
optim_type, criterion_type = 'adamw', 'cross_entropy'

In [None]:
# Loss driving the weight updates: mean-reduced cross-entropy with label smoothing
train_criterion = nn.CrossEntropyLoss(
    label_smoothing=0.1,
    reduction='mean',
)
# Loss used for evaluation, requested from the project helper with mean_reduction=False
eval_criterion = get_criterion(
    criterion_type,
    mean_reduction=False,
)
# AdamW over every parameter of the Swin model
optimizer = optim.AdamW(
    swin.parameters(),
    weight_decay=0.05,
    lr=lr,
)

In [None]:
# Options shared by both loaders: configured batch size, loading in the main process
loader_options = dict(batch_size=torch_batch_size, num_workers=0)
# Training batches are reshuffled, evaluation batches keep the dataset order
train_loader = create_loader_from_torch_dataset(tiny_imagenet_train_dataset, shuffle=True, **loader_options)
test_loader = create_loader_from_torch_dataset(tiny_imagenet_test_dataset, shuffle=False, **loader_options)

In [None]:
# Run the training and keep whatever history the project helper returns.
# NOTE(review): 32 // torch_batch_size is handed over as the fifth positional argument;
# it looks like a gradient-accumulation count aiming at an effective batch of 32 — confirm
# against train_network_without_valid_cosine_annealing. It evaluates to 0 as soon as
# torch_batch_size exceeds 32.
train_hist = train_network_without_valid_cosine_annealing(
    swin,
    nb_epochs,
    train_loader,
    test_loader,
    32 // torch_batch_size,
    optimizer,
    train_criterion,
    eval_criterion,
    pth=model_save_path,
    net_name=join_string([model_name_prefix, model_name]),
    lr_scheduler=True,
)

In [None]:
# Evaluation on the test set.
# The set label is taken from test_set_name (configuration cell) instead of repeating the
# 'test' literal, so the label passed here always follows the configuration.
accuracy_eval(swin, test_loader, set_name=test_set_name)