In [1]:
# List the visible GPUs with their current memory use and running processes.
!nvidia-smi

Sun Aug  4 09:02:11 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-16GB           On  |   00000000:06:00.0 Off |                    0 |
| N/A   41C    P0             42W /  300W |       3MiB /  16384MiB |      0%   E. Process |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-16GB           On  |   00

In [5]:
# !pip uninstall numpy -y
# !pip uninstall torchvision -y

Found existing installation: torchvision 0.12.0
Uninstalling torchvision-0.12.0:
  Successfully uninstalled torchvision-0.12.0
[0m

In [6]:
# !pip install timm==0.5.4
# !pip install torchvision==0.12.0
# !pip install torchvision
# !pip install numpy

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1" # is need to train on 'hachiko'

import math
import time
import pandas as pd

import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.distributed import all_reduce, ReduceOp

import numpy as np
from datasets import Dataset
# from torch.utils.data import Dataset
from PIL import Image, ImageFilter, ImageOps

# from funcs import is_main, to_devices, print_msg

In [2]:
import wandb

In [3]:
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from functools import partial
from vits import archs
from ssit import SSiT

class SSITConfig(PretrainedConfig):
    """Configuration for ``SSITSegmentation`` (SSiT pre-training).

    Every field has the default this notebook trains with and can be
    overridden by keyword. Any other keyword arguments are forwarded to
    ``PretrainedConfig`` so the base-class attributes are initialised and
    ``from_pretrained`` / ``from_dict`` can restore saved values.
    """
    model_type = "ssit"

    def __init__(
        self,
        temperature=0.2,
        pool_mode='max',
        saliency_threshold=0.5,
        arch='ViT-S-p16',
        pretrained=True,
        input_size=384,
        mask_ratio=0.25,
        epochs=300,
        moco_m=0.99,
        ss=10,
        cl=1,
        **kwargs
    ):
        # Call the base initialiser first: some transformers releases assign
        # generation defaults (including ``temperature``) inside it, which would
        # otherwise overwrite the value set below.
        super().__init__(**kwargs)
        self.temperature = temperature  # passed to SSiT as its temperature ``T``
        self.pool_mode = pool_mode  # forwarded to SSiT
        self.saliency_threshold = saliency_threshold  # forwarded to SSiT
        self.arch = arch  # key into ``vits.archs`` selecting the encoder
        self.pretrained = pretrained  # forwarded to the encoder factory
        self.input_size = input_size  # encoder ``img_size``; must match the data transform
        self.mask_ratio = mask_ratio  # forwarded to the encoder factory
        # NOTE(review): not read by SSITSegmentation; its schedules use the
        # epoch counts passed to forward() instead.
        self.epochs = epochs
        self.moco_m = moco_m  # base MoCo momentum (start of the cosine schedule)
        self.ss = ss  # initial weight of ss_loss
        self.cl = cl  # weight of cl_loss
        
class SSITSegmentation(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``SSiT`` for use with ``Trainer``.

    ``forward`` returns a weighted sum of the two SSiT losses (``cl_loss`` and
    ``ss_loss``). The MoCo momentum and, when ``ss_decay`` is set, the
    ``ss_loss`` weight follow cosine schedules over the training epochs.
    """
    config_class = SSITConfig

    def __init__(self, config):
        super().__init__(config)

        self.moco_m = config.moco_m  # base MoCo momentum
        self.ss = config.ss  # initial weight of ss_loss
        self.cl = config.cl  # constant weight of cl_loss
        # When True, the ss_loss weight decays from ``ss`` to 0 over training.
        self.ss_decay = True
        # Last (epoch, epochs) seen by forward(); reused when forward() is called
        # without them. epochs == 0 means no schedule is known yet.
        self.epoch = 0
        self.epochs = 0

        encoder = partial(
            archs[config.arch],
            pretrained=config.pretrained,
            img_size=config.input_size,
            mask_ratio=config.mask_ratio,
        )

        self.model = SSiT(
            encoder,
            dim=256,
            mlp_dim=4096,
            T=config.temperature,
            pool_mode=config.pool_mode,
            saliency_threshold=config.saliency_threshold,
        )

    def adjust_moco_momentum(self, epoch, epochs):
        """Return the MoCo momentum for ``epoch`` out of ``epochs``.

        Cosine schedule rising from ``self.moco_m`` at epoch 0 to 1.0 at the
        last epoch. When ``epochs`` is 0/None (no schedule known yet) the base
        momentum is returned instead of dividing by zero.
        """
        if not epochs:
            return self.moco_m
        return 1. - 0.5 * (1. + math.cos(math.pi * epoch / epochs)) * (1. - self.moco_m)

    def adjust_lambda_ss(self, epoch, epochs):
        """Return the ss_loss weight for ``epoch`` out of ``epochs``.

        Cosine decay from ``self.ss`` at epoch 0 to 0 at the last epoch. When
        ``epochs`` is 0/None the initial weight ``self.ss`` is returned.
        """
        if not epochs:
            return self.ss
        return self.ss * 0.5 * (1. + math.cos(math.pi * epoch / epochs))

    def save_weights(self, save_path):
        """Save the inner SSiT ``state_dict`` to ``<save_path>/checkpoint.pt``.

        The directory is created if it does not exist yet.
        """
        os.makedirs(save_path, exist_ok=True)
        checkpoint = {
            'state_dict': self.model.state_dict()
        }
        torch.save(checkpoint, os.path.join(save_path, 'checkpoint.pt'))
        print('Saved checkpoint to {}'.format(save_path))

    def forward(self, X1, X2, M1, M2, epoch=None, epochs=None, return_loss=True):
        """Compute the weighted SSiT training loss for one batch.

        Args:
            X1, X2: the two augmented image views (student/teacher views built
                by the data transform).
            M1, M2: the corresponding saliency masks.
            epoch, epochs: current (possibly fractional) epoch and total number
                of epochs, used by the schedules. If ``epoch`` is None the values
                from the previous call are reused; otherwise they are stored.
            return_loss: unused here. NOTE(review): kept because transformers'
                ``can_return_loss`` looks for a ``return_loss=True`` parameter to
                decide that evaluation can compute a loss without labels.

        Returns:
            dict with ``"loss"`` (``cl * cl_loss + ss_weight * ss_loss``),
            ``"cl_loss"`` and ``"ss_loss"``.
        """
        if epoch is None:
            epoch, epochs = self.epoch, self.epochs
        else:
            self.epoch, self.epochs = epoch, epochs

        moco_m = self.adjust_moco_momentum(epoch, epochs)
        ss = self.adjust_lambda_ss(epoch, epochs) if self.ss_decay else self.ss

        cl_loss, ss_loss = self.model(X1, X2, M1, M2, moco_m)
        loss = self.cl * cl_loss + ss * ss_loss

        return {"loss": loss, "cl_loss": cl_loss, "ss_loss": ss_loss}
        

In [4]:
# Build the default SSiT configuration and wrap it in the HF-compatible model.
SSITSegConfig = SSITConfig()
model = SSITSegmentation(config=SSITSegConfig)

In [5]:
# FIXME: rewrite the paths and add the mask path.
#
# Legacy loader for a labelled train split, built from a CSV table. It is
# disabled by wrapping it in a string literal, so none of it runs; the next
# cell builds train_dataset / test_dataset instead.
# NOTE(review): as the cell's last expression, this string is echoed as the
# cell output when run — consider moving it to a file or deleting the cell.

"""
labelsTable = pd.read_csv('../../mnt/local/data/kalexu97/trainLabels.csv') # initial table

error_images = ['15337_left.jpeg', '40764_right.jpeg']

for error_image in error_images:
    error_image = error_image[:-5]
    labelsTable = labelsTable[labelsTable.image != error_image]

# add folder path 'mask_image'
root_dir = '../../mnt/local/data/kalexu97/processed_train'
mask_dir = '../../mnt/local/data/kalexu97/saliency_mask/'

labelsTable['image_path'] = labelsTable['image'].apply(lambda x: os.path.join(root_dir, x+'.jpeg'))
labelsTable['mask_image'] = labelsTable['image'].apply(lambda x: os.path.join(mask_dir, x+'.npy'))
labelsTable['label'] = labelsTable['level']
labelsTable = labelsTable.drop(columns=['image', 'level'], axis=1)

# dataset is spliated to trian and test previously, and is constant for every training process
test_dataset = pd.read_csv('../test_dataset.csv')
test_dataset['image'] = test_dataset['image_path'].apply(lambda x: x[33:])

for error_image in error_images:
    error_image = error_image
    test_dataset = test_dataset[test_dataset.image != error_image]
    
test_dataset['image_path'] = test_dataset['image'].apply(lambda x: os.path.join(root_dir, x))
test_dataset['mask_image'] = test_dataset['image'].apply(lambda x: os.path.join(mask_dir, x[:-5]+'.npy'))

# subtract the test_dataset from the full dataset to get the train_dataset
df = pd.concat([test_dataset, labelsTable])
df = df.reset_index(drop=True)
df_gpby = df.groupby(list(['image_path', 'label']))
idx = [x[0] for x in df_gpby.groups.values() if len(x) == 1]

train_dataset = df.reindex(idx).drop(columns=['Unnamed: 0'], axis=1)
"""

In [5]:
from os import listdir
from os.path import isfile, join
from sklearn.model_selection import train_test_split

# Build the (image file, mask file) table from the image folder and split it
# into train / evaluation subsets.
# NOTE(review): files are listed from `test/` but loaded from `processed_test/`
# (image_folder, next cell) — confirm both folders contain the same names.
path = "../../mnt/local/data/kalexu97/test"
# Sort the listing: os.listdir order is arbitrary, so the seeded split below
# would otherwise still vary between machines and runs.
onlyfiles = sorted(f for f in listdir(path) if isfile(join(path, f)))
# Images excluded from the dataset (NOTE(review): presumably unreadable or
# missing after preprocessing — confirm).
error_images = ['5371_left.jpeg', '26250_left.jpeg', '13227_left.jpeg']
image_names = [f for f in onlyfiles if f not in error_images]
# Each mask shares the image's stem with a .npy extension. splitext works for
# any image extension; the previous x[:-5] slice only worked for '.jpeg'.
mask_names = [os.path.splitext(x)[0] + '.npy' for x in image_names]

dataset = pd.DataFrame({'image_path': image_names, 'mask_image': mask_names})
# Fixed random_state so train/test membership is identical across kernel restarts.
train_dataset, test_dataset = train_test_split(dataset, test_size=0.1, random_state=42)

In [6]:
from data import TransformWithMask

# Folders holding the preprocessed images and their .npy saliency masks.
image_folder = '../../mnt/local/data/kalexu97/processed_test/'
mask_folder = '../../mnt/local/data/kalexu97/processed_mask_test/'

# Must agree with SSITConfig.input_size, which sets the encoder's img_size.
input_size = 384
# Per-channel normalisation statistics (EyePACS).
mean = [0.425753653049469, 0.29737451672554016, 0.21293757855892181]
std = [0.27670302987098694, 0.20240527391433716, 0.1686241775751114]
# Augmentation hyper-parameters handed to TransformWithMask.
data_aug = dict(
    brightness=0.4,
    contrast=0.4,
    saturation=0.2,
    hue=0.1,
    scale_stu=(0.08, 0.8),  # presumably the crop-scale range of the student view
    scale_tea=(0.8, 1.0),   # presumably the crop-scale range of the teacher view
    degrees=(-180, 180),
)

# Paired image/mask augmentation. func_transform calls it as
# transform(img, mask) and unpacks (img_stu, img_tea, mask_stu, mask_tea).
transform = TransformWithMask(input_size, mean, std, data_aug)

def pil_loader(img_path, folder=None):
    """Load an image file as an RGB PIL image.

    Args:
        img_path: file name relative to ``folder``.
        folder: directory to read from; defaults to the global ``image_folder``.
            ``os.path.join`` is used, so a trailing slash is optional.
    """
    path = os.path.join(image_folder if folder is None else folder, img_path)
    with open(path, 'rb') as f:
        img = Image.open(f)
        # Convert inside the `with`: Image.open is lazy, so the pixel data must
        # be read before the file handle is closed.
        return img.convert('RGB')

def npy_loader(mask_path, folder=None):
    """Load a mask array saved with ``np.save``.

    Args:
        mask_path: ``.npy`` file name relative to ``folder``.
        folder: directory to read from; defaults to the global ``mask_folder``.
            ``os.path.join`` is used, so a trailing slash is optional.

    Returns:
        The stored numpy array, unchanged.
    """
    path = os.path.join(mask_folder if folder is None else folder, mask_path)
    with open(path, 'rb') as f:
        return np.load(f)

def func_transform(examples):
    """Build the two augmented views (and their masks) for a batch of rows.

    Used through ``Dataset.with_transform``: ``examples`` is a column-oriented
    batch whose ``'image_path'`` entries are image file names (relative to
    ``image_folder``) and whose ``'mask_image'`` entries are ``.npy`` mask file
    names (relative to ``mask_folder``).

    Returns:
        dict of lists, one entry per example: ``'X1'`` / ``'X2'`` are the
        student / teacher image views and ``'M1'`` / ``'M2'`` the matching masks.

    Raises:
        ValueError: if the global ``transform`` is None (no views can be built).
    """
    if transform is None:
        raise ValueError('func_transform requires the global `transform` to be set')

    inputs = {'X1': [], 'X2': [], 'M1': [], 'M2': []}
    for img_path, mask_path in zip(examples['image_path'], examples['mask_image']):
        img = pil_loader(img_path)
        mask = npy_loader(mask_path)
        # Scale to 8-bit for PIL (assumes masks in [0, 1]). Clip first: casting
        # out-of-range floats to uint8 is undefined and can wrap 1.01 -> ~2.
        mask = Image.fromarray(np.uint8(np.clip(mask * 255, 0, 255)))
        img_stu, img_tea, mask_stu, mask_tea = transform(img, mask)

        inputs['X1'].append(img_stu)
        inputs['X2'].append(img_tea)
        inputs['M1'].append(mask_stu)
        inputs['M2'].append(mask_tea)

    return inputs

# Wrap the train/test frames as HF datasets; only file names are stored, and
# func_transform loads and augments them on access.
train_ds = Dataset.from_pandas(train_dataset, preserve_index=False)
test_ds = Dataset.from_pandas(test_dataset, preserve_index=False)

# Attach the on-the-fly preprocessing, then shuffle each split with a fixed
# seed (useful if the underlying frame happens to be sorted).
prepared_ds_train, prepared_ds_test = (
    ds.with_transform(func_transform).shuffle(seed=42)
    for ds in (train_ds, test_ds)
)


In [7]:
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

# subclass trainer
from transformers import Trainer
class CustomTrainer(Trainer):
    """Trainer that feeds the current training progress into the model.

    SSITSegmentation.forward needs ``epoch`` / ``epochs`` to schedule its MoCo
    momentum and ss-loss weight, so they are taken from the trainer state.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the model's combined loss, plus its output dict if requested.

        ``num_items_in_batch`` is accepted (and ignored) because newer
        transformers releases pass it to ``compute_loss``.
        """
        # Before train() starts, the trainer state presumably still holds its
        # defaults (epoch None, num_train_epochs 0); use the start of the
        # schedule then, instead of handing the model an epoch count of zero.
        epoch = self.state.epoch if self.state.epoch is not None else 0.0
        epochs = self.state.num_train_epochs or 1

        outputs = model(**inputs, epoch=epoch, epochs=epochs)
        loss = outputs["loss"]

        return (loss, outputs) if return_outputs else loss

In [8]:
# Collate function for the Trainer's DataLoader.
def collate_fn(batch):
    """Stack each sample's X1/X2/M1/M2 tensors along a new leading batch dimension.

    Any other keys in the samples are ignored.
    """
    keys = ('X1', 'X2', 'M1', 'M2')
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}

In [14]:
# val_dataset is also defined previously, so we only need to load its indices
# with open('test_indeces.npy', 'rb') as f:
#     sample_ids = np.load(f)
#     inv_sample_ids = np.load(f)

# sample_ids = np.random.choice(len(prepared_ds_test), size=1000, replace=False)
# inv_sample_ids = np.setdiff1d(np.arange(len(prepared_ds_test)), sample_ids)

# with open('test_indeces.npy', 'wb') as f:
#     np.save(f, sample_ids)
#     np.save(f, inv_sample_ids)

# val_ds = prepared_ds_test.select(sample_ids)
# test_ds = prepared_ds_test.select(inv_sample_ids)

In [10]:
# def compute_metrics(eval_pred):
#     print(eval_pred)
#     loss, cl_loss, ss_loss = eval_pred
#     result = {
#             # 'loss': loss,
#             'cl_loss': cl_loss,
#             'ss_loss': ss_loss
#             }

#     # print(result)
    
#     return result

In [9]:
from transformers import TrainingArguments
# Run name: used as the W&B run name and for the exported-model folder.
r_name = "SSIT384_unlabled_bs16_100ep_2"

# Effective batch per optimizer step on the single visible GPU:
# per_device_train_batch_size (16) x gradient_accumulation_steps (4) = 64.
training_args = TrainingArguments(
    output_dir="./SSIT",
    push_to_hub=False,
    # Optimisation.
    learning_rate=1e-3,
    warmup_ratio=0.02,
    num_train_epochs=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    # Evaluate, log and checkpoint every 40 steps; keep at most 3 checkpoints
    # and reload the best one when training ends.
    evaluation_strategy="steps",
    eval_steps=40,
    logging_steps=40,
    save_steps=40,
    save_total_limit=3,
    load_best_model_at_end=True,
    # Experiment tracking.
    report_to="wandb",
    run_name=r_name,
    # Must stay False: the raw columns (image_path, mask_image) are not
    # forward() arguments, but func_transform needs them.
    remove_unused_columns=False,
    dataloader_num_workers=16,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds_train,
    eval_dataset=prepared_ds_test,
    data_collator=collate_fn,
)



In [None]:
# Train from scratch. To resume, pass a checkpoint directory instead, e.g.
# trainer.train(resume_from_checkpoint="./SSIT/checkpoint-<step>").
train_results = trainer.train()
trainer.save_model()

# Print and persist the final training metrics, then the trainer state.
train_metrics = train_results.metrics
for report in (trainer.log_metrics, trainer.save_metrics):
    report("train", train_metrics)
trainer.save_state()

[34m[1mwandb[0m: Currently logged in as: [33malexu97[0m ([33malexu97-skoltech[0m). Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss
40,16.4589,10.359242
80,6.8352,5.107365
120,4.0456,3.380331
160,2.6947,2.550945
200,2.2718,2.361122
240,2.0971,2.173501
280,2.0408,2.11973
320,1.9361,2.051498
360,1.8736,1.956703
400,1.7766,1.887273


Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature': 0.2}
Non-default generation parameters: {'temperature

In [12]:
# Export the model weights and config in Hugging Face format; reload with
# SSITSegmentation.from_pretrained(f"saved_models/{r_name}").
# (`from_pt` is a from_pretrained() option; save_pretrained() silently ignored it.)
model.save_pretrained(f"saved_models/{r_name}")

Non-default generation parameters: {'temperature': 0.2}
