## ViT training

In [1]:
import os
# Restrict the process to GPU index 1. This is set before `torch` is imported
# so that the restriction is already in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="1" # needed to train on 'hachiko' (presumably the name of the training machine)

from PIL import Image
import os
import warnings
# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings("ignore")
from typing import Tuple
from typing import List
import random

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.transforms import functional as F

# NOTE: this import rebinds the name `Dataset`, shadowing the
# `torch.utils.data.Dataset` imported above. From here on `Dataset` refers to
# `datasets.Dataset`, which is what `Dataset.from_pandas(...)` relies on.
from datasets import Dataset
from datasets import load_dataset

from transformers import ViTImageProcessor
from transformers import AutoImageProcessor
from transformers import TrainingArguments
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from transformers import Trainer
from transformers import ViTForImageClassification
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification

# project-local helpers
from validation_utils import collate_fn, get_compute_metrics
from data_utils import resample

# Release cached GPU memory held by PyTorch's allocator before the run starts.
torch.cuda.empty_cache()

## Define a model

In [None]:
# Checkpoint the training starts from (pretrained ViT weights, see the hub name below)
model_name_or_path = 'google/vit-base-patch16-224-in21k'

# Load the pretrained weights with a 5-class classification head;
# weights whose shapes do not match the checkpoint are tolerated instead of raising
model = ViTForImageClassification.from_pretrained(
    model_name_or_path, num_labels=5, ignore_mismatched_sizes=True
)

## Load dataset

In [6]:
# Initial label table: one row per image, with the class stored in the `level` column
labelsTable = pd.read_csv('../mnt/local/data/kalexu97/trainLabels.csv')

# Replace the bare image name by the full file path and expose the class as `label`
root_dir = '../mnt/local/data/kalexu97/train'
labelsTable['image_path'] = [
    os.path.join(root_dir, image_name + '.jpeg') for image_name in labelsTable['image']
]
labelsTable['label'] = labelsTable['level']
labelsTable = labelsTable.drop(columns=['image', 'level'])

# The train/test split was made beforehand and stays constant across training runs
test_dataset = pd.read_csv('test_dataset.csv')

# Subtract the test rows from the full table: after stacking both tables, every
# (image_path, label) pair that occurs exactly once is kept as a train row
df = pd.concat([test_dataset, labelsTable]).reset_index(drop=True)
df_gpby = df.groupby(['image_path', 'label'])
idx = [rows[0] for rows in df_gpby.groups.values() if len(rows) == 1]

# 'Unnamed: 0' comes from test_dataset.csv (presumably its saved index column)
train_dataset = df.reindex(idx).drop(columns=['Unnamed: 0'])

### Resampling

In [7]:
# Class balancing: random under-sampling (RUS) for the majority classes and
# random over-sampling (ROS) for the minority classes.
# Every class ends up with
#           ratio * len(smallest_class_dataset)
# items; over-sampling simply repeats minority-class items often enough
# to reach that size.
# NOTE: the cell overwrites `train_dataset`, so running it twice resamples
# the already resampled table.
train_dataset = resample(train_dataset, ratio=35)

0: length: 19460
1: length: 19460
2: length: 19460
3: length: 19460
4: length: 19460
N_added_rows:  26953
N_all_rows:  28100
Ratio of used rows:  0.9591814946619217


### Pre-Augmentation, Preprocessing, Post-Augmentation

In [None]:
from data_utils import CenterCrop, Spot, RandomSharpen, Blur, Halo, Hole

# Preprocessor that belongs to the chosen checkpoint
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)

# For some models the input size can be changed between training stages
for side in ('height', 'width'):
    image_processor.size[side] = 224

# Pre-Augmentations: applied to the loaded image before the image processor runs
_train_pre_steps = [
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    CenterCrop(),
    T.Resize((260, 260), interpolation=T.InterpolationMode.BICUBIC),
    T.RandomCrop(224, padding_mode='symmetric', pad_if_needed=True),
    T.TrivialAugmentWide(),
    # currently disabled:
    # Sharpness(),
    # Blur()
]
_transforms_train = T.Compose(_train_pre_steps)

# Post-Augmentations: applied to the tensors produced by the image processor
_train_post_steps = [
    Spot(size=224),
    Halo(size=224),
    Hole(size=224),
]
post_transforms_train = T.Compose(_train_post_steps)

# Pre-Augmentations for test_dataset: only the crop, no random steps
_transforms_test = T.Compose([CenterCrop()])


def load_image(path_image, label, mode):
    """
    Load an image from disk and apply the matching Pre-Augmentation pipeline.

    Args:
        path_image: path of the image file to open.
        label: label of the image. It is not used here; the parameter is kept
            so that existing callers keep working.
        mode: 'train' selects `_transforms_train`; any other value selects
            `_transforms_test`.

    Returns:
        The image as returned by the selected transform pipeline.
    """
    # `Image.open` is lazy: decode inside a context manager so the underlying
    # file is closed right away, and normalise to RGB so that grayscale, CMYK
    # or RGBA files also yield the three channels the pipeline works with.
    with Image.open(path_image) as raw_image:
        image = raw_image.convert('RGB')

    pipeline = _transforms_train if mode == 'train' else _transforms_test
    return pipeline(image)
        

def func_transform(examples):
    """
    Batch transform for the train dataset.

    Loads and pre-augments every image listed in `examples['image_path']`,
    runs the image processor on the batch, applies the post-augmentations to
    each processed image and attaches the labels.
    """
    # pre-augmentation (inside load_image) followed by preprocessing
    augmented_images = []
    for img_path, img_label in zip(examples['image_path'], examples['label']):
        augmented_images.append(load_image(img_path, img_label, 'train'))
    batch = image_processor(augmented_images, return_tensors='pt')

    # post-augmentation, one processed image at a time
    batch['pixel_values'] = [post_transforms_train(pixels) for pixels in batch['pixel_values']]
    batch['label'] = examples['label']

    return batch

def func_transform_test(examples):
    """
    Batch transform for the test dataset.

    Loads every image listed in `examples['image_path']` with the test
    pre-augmentation, runs the image processor on the batch and attaches the
    labels. No post-augmentation is applied.
    """
    # pre-augmentation (inside load_image) followed by preprocessing
    loaded_images = []
    for img_path, img_label in zip(examples['image_path'], examples['label']):
        loaded_images.append(load_image(img_path, img_label, 'test'))
    batch = image_processor(loaded_images, return_tensors='pt')

    batch['label'] = examples['label']
    return batch

# Wrap the pandas tables into `Dataset` objects (the pandas index is not carried over)
train_ds = Dataset.from_pandas(train_dataset, preserve_index=False)
test_ds = Dataset.from_pandas(test_dataset, preserve_index=False)

# Attach the preprocessing transforms, then shuffle with a fixed seed
# (shuffling can be useful when the tables are sorted)
prepared_ds_train = train_ds.with_transform(func_transform).shuffle(seed=42)
prepared_ds_test = test_ds.with_transform(func_transform_test).shuffle(seed=42)

### Split test_dataset into val_dataset and test_dataset

In [None]:
# val_dataset was also defined beforehand, so only its indices have to be loaded.
# The file holds two arrays stored one after the other: the rows used for
# validation first, then the rows used for testing.
with open('test_indeces.npy', 'rb') as index_file:
    sample_ids = np.load(index_file)
    inv_sample_ids = np.load(index_file)

# Both splits are selections from the shuffled, preprocessed test dataset
val_ds = prepared_ds_test.select(sample_ids)
test_ds = prepared_ds_test.select(inv_sample_ids)

## Define trainer

In [13]:
# Run name under which the metadata is logged in wandb for tracking
r_name = "vit384(16)_rot_4"

# Metric function handed to the trainer
compute_metrics = get_compute_metrics(r_name, 'EyE', save_cm=False)

# Training configuration
training_args = TrainingArguments(
    output_dir="./ViT-base",

    # logging to W&B
    report_to="wandb",
    run_name=r_name,  # name of the W&B run (optional)
    logging_steps=50,

    # evaluate and checkpoint every 50 steps, keeping at most 3 checkpoints
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=50,
    save_total_limit=3,

    # data loading; columns are kept because the dataset transforms read them
    remove_unused_columns=False,
    dataloader_num_workers=16,

    # optimisation
    learning_rate=1e-5,
    warmup_ratio=0.02,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    # optional settings, currently disabled:
    # lr_scheduler_type = 'constant_with_warmup', # 'constant', 'cosine'
    # label_smoothing_factor = 0.6,

    # the best model is selected by the kappa metric (higher is better)
    # and restored when training finishes
    load_best_model_at_end=True,
    metric_for_best_model="kappa",
    greater_is_better=True,

    push_to_hub=False,
)

# Trainer: trains on the prepared train dataset and evaluates on val_ds
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds_train,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

## Train !

In [2]:
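# In some cases we can continue training from a checkpoint with adjusted settings.
# The hyperparameters live on `trainer.args`, not on the trainer itself:

# trainer.args.num_train_epochs = trainer.args.num_train_epochs + 5
# trainer.args.learning_rate = 1e-5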

In [None]:
# To resume from a checkpoint, pass its path instead:
# trainer.train("./MedViT-base/checkpoint-22800")
train_results = trainer.train()

# Save the model, then log and store the training metrics, then the trainer state
trainer.save_model()
train_metrics = train_results.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()

## Validate on test_dataset

In [None]:
# Evaluate on the test split, then log and store the metrics under the "eval" name
metrics = trainer.evaluate(eval_dataset=test_ds)
for report in (trainer.log_metrics, trainer.save_metrics):
    report("eval", metrics)

## Save model weights and preprocessor configs

In [None]:
# Directory that receives both the model weights and the preprocessor config
save_dir = f"./saved_models/{r_name}"

# `from_pt` is a loading option of `from_pretrained`; `save_pretrained` does
# not use it, so the argument is not passed here.
model.save_pretrained(save_dir)
image_processor.save_pretrained(save_dir)