In [1]:
!nvidia-smi

Failed to initialize NVML: Unknown Error


In [12]:
import os
# Expose only GPU index 1 to this process (needed when training on 'hachiko').
# This is set before `import torch` below so CUDA is initialised with that restriction.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import math
import time
import pandas as pd

import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.distributed import all_reduce, ReduceOp

import numpy as np
from datasets import Dataset
# from torch.utils.data import Dataset
from PIL import Image, ImageFilter, ImageOps

import wandb

from funcs import is_main, to_devices, print_msg

In [13]:
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from functools import partial
from funcs import load_checkpoint
from vits import archs
from vits import resize_pos_embed
from ssit import SSiT

def load_checkpoint(model, checkpoint_path, checkpoint_key, linear_key):
    """Initialise ``model`` with the encoder weights stored in a training checkpoint.

    Only entries of ``checkpoint['state_dict']`` that live under
    ``"<checkpoint_key>."`` are kept (with that prefix removed); entries under
    ``"<checkpoint_key>.<linear_key>"`` (the classification head) are skipped so
    the head of ``model`` keeps its fresh initialisation.

    Args:
        model: Module to load into. It must expose ``pos_embed`` and
            ``patch_embed.grid_size``; ``num_tokens`` is optional (defaults to 1).
        checkpoint_path: Path of the checkpoint file; it must contain a
            ``'state_dict'`` entry.
        checkpoint_key: Name of the sub-module inside the checkpoint whose
            weights are transferred (e.g. ``'base_encoder'``).
        linear_key: Name of the head layer inside that sub-module, excluded
            from the transfer.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` or the selected
            weights contain no ``'pos_embed'`` entry.
        AssertionError: If anything other than the head's weight and bias is
            missing after loading.
    """
    # NOTE(review): torch.load unpickles the file, so only open trusted checkpoints.
    # map_location='cpu' keeps loading independent of the device the checkpoint was
    # saved from (that device may not be visible in this process); load_state_dict
    # copies the values onto the model's own device afterwards.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    source_state = checkpoint['state_dict']

    # The trailing dot prevents matching sibling modules whose name merely starts
    # with `checkpoint_key`.
    encoder_prefix = '%s.' % checkpoint_key
    head_prefix = '%s.%s' % (checkpoint_key, linear_key)

    # Build a fresh dict rather than renaming in place, so a stripped key can never
    # collide with (and later be deleted as) an original key.
    state_dict = {
        k[len(encoder_prefix):]: v
        for k, v in source_state.items()
        if k.startswith(encoder_prefix) and not k.startswith(head_prefix)
    }

    # Adapt the checkpoint's position embedding to the model's token grid.
    pos_embed_w = state_dict['pos_embed']
    pos_embed_w = resize_pos_embed(
        pos_embed_w,
        model.pos_embed,
        getattr(model, 'num_tokens', 1),
        model.patch_embed.grid_size,
    )
    state_dict['pos_embed'] = pos_embed_w

    msg = model.load_state_dict(state_dict, strict=False)
    # The head is the only part allowed to be absent from the checkpoint.
    expected_missing = {"%s.weight" % linear_key, "%s.bias" % linear_key}
    assert set(msg.missing_keys) == expected_missing, \
        'Unexpected missing keys: {}'.format(sorted(msg.missing_keys))
    print_msg('Load weights form {}'.format(checkpoint_path))

class SSITClfConfig(PretrainedConfig):
    """Configuration of the SSiT image classifier.

    Args:
        arch: Key into ``archs`` selecting the backbone constructor.
        num_classes: Number of output classes passed to the backbone.
        input_size: Image side length passed to the backbone as ``img_size``.
        checkpoint_key: Prefix of the encoder weights inside the checkpoint's
            ``state_dict``.
        checkpoint: Path of the checkpoint to initialise from; a falsy value
            skips checkpoint loading.
        pretrained: Forwarded to the backbone constructor as ``pretrained``.
        **kwargs: Standard ``PretrainedConfig`` options, forwarded to the
            base class.
    """

    model_type = "ssit"

    def __init__(
        self,
        arch='ViT-S-p16',
        num_classes=5,
        input_size=224,
        checkpoint_key='base_encoder',
        checkpoint="saved_models/SSIT_unlabled_bs64_100ep/checkpoint.pt",
        pretrained=True,
        **kwargs
    ):
        self.arch = arch
        self.num_classes = num_classes
        self.input_size = input_size
        self.checkpoint_key = checkpoint_key
        self.checkpoint = checkpoint
        self.pretrained = pretrained

        # Custom configs must hand the remaining kwargs to the base class so the
        # common attributes are initialised and a saved config can be reloaded.
        super().__init__(**kwargs)
        
        
class SSITClassification(PreTrainedModel):
    """Classifier that wraps a backbone from ``archs`` behind the ``PreTrainedModel`` API.

    The backbone is selected by ``config.arch``. When ``config.checkpoint`` is
    set, its encoder weights are loaded through ``load_checkpoint``, leaving the
    ``'head'`` layer at its initial values.
    """

    config_class = SSITClfConfig

    def __init__(self, config):
        """Build the backbone described by ``config`` and optionally load a checkpoint."""
        super().__init__(config)

        build_backbone = archs[config.arch]
        self.model = build_backbone(
            num_classes=config.num_classes,
            pretrained=config.pretrained,
            img_size=config.input_size,
            feat_concat=True,
        )

        head_name = 'head'
        encoder_name = config.checkpoint_key

        if not config.checkpoint:
            print('No checkpoint provided. Training from scratch.')
            return

        load_checkpoint(self.model, config.checkpoint, encoder_name, head_name)

    def forward(self, pixel_values, labels=None):
        """Run the backbone and return a dict in the style of the transformers library.

        Args:
            pixel_values: Batch of preprocessed images fed to the backbone.
            labels: Optional targets; when given, a cross-entropy loss is added.

        Returns:
            ``{"logits": ...}``, plus a ``"loss"`` entry when ``labels`` is provided.
        """
        # The backbone returns a pair; only its first element (the logits) is used.
        logits, _ = self.model(pixel_values)

        if labels is None:
            return {"logits": logits}

        ce_loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": ce_loss, "logits": logits}
        

In [14]:
# Build the classifier from the default configuration. Because the default
# config names a checkpoint path, the constructor loads encoder weights from it.
SSITClConfig = SSITClfConfig()
model = SSITClassification(SSITClConfig)

  checkpoint = torch.load(checkpoint_path)


Load weights form saved_models/SSIT_unlabled_bs64_100ep/checkpoint.pt


In [15]:
# FIXME: the data locations below are hard-coded; move them to a configuration cell.
root_dir = '../../mnt/local/data/kalexu97/processed_train'
mask_dir = '../../mnt/local/data/kalexu97/saliency_mask/'

# Images excluded from every split, listed by file name.
error_images = ['15337_left.jpeg', '40764_right.jpeg']
error_stems = [os.path.splitext(name)[0] for name in error_images]

# Full label table; `image` holds the file name without extension, `level` the class.
labelsTable = pd.read_csv('../../mnt/local/data/kalexu97/trainLabels.csv')
# .copy() so the column assignments below act on an independent frame.
labelsTable = labelsTable[~labelsTable.image.isin(error_stems)].copy()

labelsTable['image_path'] = labelsTable['image'].apply(lambda x: os.path.join(root_dir, x + '.jpeg'))
labelsTable['mask_image'] = labelsTable['image'].apply(lambda x: os.path.join(mask_dir, x + '.npy'))
labelsTable['label'] = labelsTable['level']
labelsTable = labelsTable.drop(columns=['image', 'level'])

# The train/test split was fixed beforehand and is the same for every training run.
test_dataset = pd.read_csv('../test_dataset.csv')
# Take the file name from the stored path instead of slicing a fixed-length prefix,
# so the CSV keeps working if the directory recorded in it changes.
test_dataset['image'] = test_dataset['image_path'].apply(os.path.basename)
test_dataset = test_dataset[~test_dataset.image.isin(error_images)].copy()

# Re-root image and mask paths onto the local directories.
test_dataset['image_path'] = test_dataset['image'].apply(lambda x: os.path.join(root_dir, x))
test_dataset['mask_image'] = test_dataset['image'].apply(
    lambda x: os.path.join(mask_dir, os.path.splitext(x)[0] + '.npy')
)

# Subtract the test split from the full table: after stacking both frames, an
# (image_path, label) pair that occurs exactly once belongs to one frame only.
# NOTE(review): this assumes every test row also appears in the label table with
# the same label; otherwise test rows would end up in the train split — confirm.
df = pd.concat([test_dataset, labelsTable])
df = df.reset_index(drop=True)
df_gpby = df.groupby(['image_path', 'label'])
idx = [x[0] for x in df_gpby.groups.values() if len(x) == 1]

# 'Unnamed: 0' comes from the test CSV; it is dropped from the train frame.
train_dataset = df.reindex(idx).drop(columns=['Unnamed: 0'])

In [16]:
from data_utils import resample
# Resample the training frame with the project helper `resample` (its summary is
# printed in the cell output).
# NOTE: `train_dataset` is rebound to the result, so running this cell twice
# resamples already-resampled data; re-run the cell that builds `train_dataset` first.
train_dataset = resample(train_dataset, ratio = 35)

0: length: 19460
1: length: 19460
2: length: 19460
3: length: 19460
4: length: 19460
N_added_rows:  26953
N_all_rows:  28099
Ratio of used rows:  0.9592156304494822


In [17]:
from torchvision import transforms

def data_transforms(input_size):
    """Create the preprocessing pipelines for square images of side ``input_size``.

    Returns:
        A ``(train_preprocess, test_preprocess)`` pair of ``transforms.Compose``
        objects. The train pipeline applies random augmentations and then the
        resize / to-tensor / normalize steps; the test pipeline applies only
        those final three steps.
    """
    # Per-channel statistics of the EyePACS dataset.
    eyepacs_mean = [0.425753653049469, 0.29737451672554016, 0.21293757855892181]
    eyepacs_std = [0.27670302987098694, 0.20240527391433716, 0.1686241775751114]

    target_size = (input_size, input_size)

    # Deterministic steps, shared by both pipelines.
    tensor_steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(eyepacs_mean, eyepacs_std),
    ]

    # Random augmentations, used for training only.
    random_steps = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomResizedCrop(size=target_size, scale=(0.87, 1.15), ratio=(0.7, 1.3)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
        transforms.RandomRotation(degrees=(-180, 180)),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]

    train_pipeline = transforms.Compose(random_steps + tensor_steps)
    test_pipeline = transforms.Compose(tensor_steps)
    return train_pipeline, test_pipeline

# Module-level pipelines for 224x224 inputs; `load_image` reads these two globals.
train_transforms, test_transforms = data_transforms(224)

def load_image(path_image, label, mode):
    """Load one image from disk and preprocess it into a tensor.

    Args:
        path_image: Path of the image file.
        label: Unused; kept so existing callers do not need to change.
        mode: ``'train'`` applies ``train_transforms``; any other value applies
            ``test_transforms``.

    Returns:
        The output of the selected preprocessing pipeline.
    """
    # The context manager closes the file handle as soon as the pixels are read
    # (Image.open is lazy). convert('RGB') forces that read and guarantees three
    # channels, matching the three-value mean/std used by the pipelines.
    with Image.open(path_image) as raw_image:
        image = raw_image.convert('RGB')

    preprocess = train_transforms if mode == 'train' else test_transforms
    return preprocess(image)
        
def _transform_batch(examples, mode):
    """Load and preprocess every image of a batch with the pipeline chosen by ``mode``.

    Args:
        examples: Batch mapping with ``'image_path'`` and ``'label'`` sequences.
        mode: Passed to ``load_image`` (``'train'`` or ``'test'``).

    Returns:
        Dict with ``'pixel_values'`` (list of preprocessed images) and ``'label'``
        (the labels of the batch, unchanged).
    """
    images = [
        load_image(path, lb, mode)
        for path, lb in zip(examples['image_path'], examples['label'])
    ]
    return {'pixel_values': images, 'label': examples['label']}


def func_transform(examples):
    """
    The function is used to preprocess train dataset.
    """
    return _transform_batch(examples, 'train')


def func_transform_test(examples):
    """
    The function is used to preprocess test dataset.
    """
    return _transform_batch(examples, 'test')

# Wrap the pandas splits as `datasets.Dataset` objects (the pandas index is not kept).
train_ds = Dataset.from_pandas(train_dataset, preserve_index=False)
test_ds = Dataset.from_pandas(test_dataset, preserve_index=False)

# Attach the on-the-fly preprocessing, then shuffle with a fixed seed
# (useful because the source frames may be sorted).
prepared_ds_train = train_ds.with_transform(func_transform).shuffle(seed=42)
prepared_ds_test = test_ds.with_transform(func_transform_test).shuffle(seed=42)

In [18]:
# The validation subset was chosen once; its row indices (and the complementary
# test indices) are stored as two consecutive arrays in the same file.
with open('test_indeces.npy', 'rb') as indices_file:
    sample_ids = np.load(indices_file)
    inv_sample_ids = np.load(indices_file)

# How the index file was originally produced:
# sample_ids = np.random.choice(len(prepared_ds_test), size=1000, replace=False)
# inv_sample_ids = np.setdiff1d(np.arange(len(prepared_ds_test)), sample_ids)
# with open('test_indeces.npy', 'wb') as f:
#     np.save(f, sample_ids)
#     np.save(f, inv_sample_ids)

# Split the shuffled test data into validation rows and the remaining test rows.
val_ds = prepared_ds_test.select(sample_ids)
test_ds = prepared_ds_test.select(inv_sample_ids)

In [19]:
from validation_utils import collate_fn, get_compute_metrics
from transformers import TrainingArguments
from transformers import Trainer

# Name under which the run is tracked in W&B (also reused for the metrics helper).
r_name = "SSIT224_trainedOnUnlabled"

# Metric callback handed to the Trainer.
compute_metrics = get_compute_metrics(r_name, 'EyE', save_cm=False)

# Training configuration.
training_args = TrainingArguments(
    output_dir="./SSiT-base",

    # optimisation: 64 samples x 4 accumulation steps per optimizer update and device
    learning_rate=2e-5,
    warmup_ratio=0.02,
    num_train_epochs=15,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=4,

    # evaluate, log and checkpoint every 50 steps; keep at most 3 checkpoints
    evaluation_strategy="steps",
    eval_steps=50,
    logging_steps=50,
    save_steps=50,
    save_total_limit=3,

    # model selection: reload the checkpoint with the highest kappa when training ends
    metric_for_best_model="kappa",
    greater_is_better=True,
    load_best_model_at_end=True,

    # data loading: keep all columns, since the transform needs `image_path`
    remove_unused_columns=False,
    dataloader_num_workers=16,

    # tracking
    report_to="wandb",
    run_name=r_name,
    push_to_hub=False,
)

# Trainer wiring: model, data, collation and metrics.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds_train,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)



In [20]:
# Run training from the start (no resume checkpoint is passed).
# To resume instead, pass a checkpoint directory: trainer.train("<checkpoint dir>")
train_results = trainer.train()
# Persist the model, then log and save the training metrics and the trainer state.
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

[34m[1mwandb[0m: Currently logged in as: [33malexu97[0m ([33malexu97-skoltech[0m). Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Accuracy,Kappa,F1,Roc Auc,Class 0,Class 1,Class 2,Class 3,Class 4
50,1.6516,1.645779,0.207,0.066853,0.24891,0.600711,0.346,0.655,0.689,0.812,0.912
100,1.4585,1.347893,0.403,0.26456,0.468733,0.713188,0.515,0.686,0.778,0.927,0.9
150,1.2056,1.124268,0.46,0.446735,0.525947,0.787955,0.559,0.658,0.793,0.949,0.961
200,1.1085,1.277099,0.269,0.400774,0.313577,0.792989,0.394,0.449,0.798,0.938,0.959
250,1.0581,1.101521,0.457,0.492722,0.527496,0.808773,0.564,0.613,0.822,0.952,0.963
300,1.0269,1.110859,0.513,0.496247,0.573356,0.808596,0.603,0.71,0.795,0.946,0.972
350,1.0094,1.184819,0.449,0.456598,0.511453,0.806842,0.545,0.676,0.762,0.952,0.963
400,0.974,1.053655,0.508,0.53765,0.57565,0.825481,0.6,0.656,0.817,0.965,0.978
450,0.9434,1.024454,0.548,0.550481,0.610301,0.817446,0.641,0.699,0.817,0.963,0.976
500,0.9167,1.034855,0.564,0.556581,0.621581,0.818083,0.66,0.729,0.82,0.944,0.975


[[129 247 184 135  42]
 [ 16  32  12  11   7]
 [ 22  43  31  25  21]
 [  3   7   0   8   4]
 [  5   2   4   3   7]]
[[315 233  98  35  56]
 [ 24  34   9   5   6]
 [ 33  32  28  18  31]
 [  6   3   1  10   2]
 [  0   2   0   3  16]]
[[354 267  92  10  14]
 [ 25  35  12   4   2]
 [ 31  30  44  24  13]
 [  2   2   2  13   3]
 [  0   0   3   4  14]]
[[153 473  80  15  16]
 [ 10  49  13   4   2]
 [ 12  47  38  32  13]
 [  0   2   2  15   3]
 [  0   0   3   4  14]]
[[352 305  58   7  15]
 [ 25  38  10   3   2]
 [ 26  40  37  28  11]
 [  0   2   2  15   3]
 [  0   0   3   3  15]]
[[388 222 104  13  10]
 [ 25  34  12   5   2]
 [ 23  23  58  29   9]
 [  0   1   2  18   1]
 [  0   0   3   3  15]]
[[317 255 144   8  13]
 [ 21  36  12   7   2]
 [ 14  26  64  24  14]
 [  0   1   2  16   3]
 [  0   0   2   3  16]]
[[379 267  82   5   4]
 [ 23  40  11   2   2]
 [ 19  37  63  14   9]
 [  0   2   8  11   1]
 [  0   0   3   3  15]]
[[428 217  81   5   6]
 [ 30  30  14   2   2]
 [ 19  34  62  18   9]
 [ 

In [21]:
# Final evaluation on `test_ds`, i.e. the test rows that were not selected for validation.
metrics = trainer.evaluate(test_ds)
# The results are logged and saved under the "eval" split name.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

[[3588  514  337   11   15]
 [ 275   89   51    4    2]
 [ 277   96  437   46   19]
 [  11    9   45   57   11]
 [   7    0   35   15   74]]
***** eval metrics *****
  epoch                   =    14.9901
  eval_accuracy           =     0.7046
  eval_class_0            =     0.7598
  eval_class_1            =     0.8422
  eval_class_2            =     0.8496
  eval_class_3            =     0.9748
  eval_class_4            =     0.9827
  eval_f1                 =     0.7213
  eval_kappa              =     0.6114
  eval_loss               =     0.8388
  eval_roc_auc            =     0.8115
  eval_runtime            = 0:01:08.11
  eval_samples_per_second =      88.45
  eval_steps_per_second   =      1.395


In [23]:
# Export the trained model in Hugging Face format under ./saved_models/<run name>.
# `from_pt` is an option of `from_pretrained` (loading), not of `save_pretrained`,
# where it was only forwarded as an unused hub kwarg, so it is no longer passed.
model.save_pretrained(f"./saved_models/{r_name}")
# No image processor object exists in this notebook, so only the model is saved.