In [2]:
import torch
import torch.nn as nn
import numpy as np
import math
import copy

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair

In [2]:
# !wandb login b248f05b86545578e213a3d77725b1793b6c237a

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [9]:
# pip install ml-collections

In [2]:
!nvidia-smi

Sun Jul 14 21:06:53 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-16GB           On  |   00000000:06:00.0 Off |                    0 |
| N/A   46C    P0             62W /  300W |   16017MiB /  16384MiB |      0%   E. Process |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-16GB           On  |   00

In [6]:
# !pip uninstall numpy -y
# !pip uninstall matplotlib -y

In [5]:
# !pip install numpy
# !pip install matplotlib

In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1" # is need to train on 'hachiko'

from PIL import Image
import os
import warnings
warnings.filterwarnings("ignore")
from typing import Tuple
from typing import List
import random

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# import torchvision.transforms as T
import torchvision.transforms.v2 as T
from torchvision.transforms import functional as F

from datasets import Dataset
from datasets import load_dataset

from transformers import ViTImageProcessor
from transformers import AutoImageProcessor
from transformers import TrainingArguments
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from transformers import Trainer

# import of custom functions
from validation_utils import get_compute_metrics
from data_utils import resample

torch.cuda.empty_cache()

In [4]:
print('Number CUDA Devices:', torch.cuda.device_count())
print ('Current cuda device: ', torch.cuda.current_device(), ' **May not correspond to nvidia-smi ID above, check visibility parameter')

Number CUDA Devices: 1
Current cuda device:  0  **May not correspond to nvidia-smi ID above, check visibility parameter


In [5]:
import wandb
# wandb.login('897cda038ea791f5f031be1adc101e476e229b31')

In [14]:
from torch.nn.modules.utils import _pair
# from torchmetrics.functional import pairwise_cosine_similarity

ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu} #, "swish": swish}

class Embeddings(nn.Module):
    """Construct patch embeddings plus learned position embeddings.

    A Conv2d with kernel ``patch_size`` projects image patches to ``embed_dim``-wide
    tokens; a learned position embedding of shape ``(1, n_patches, embed_dim)`` is
    added and dropout is applied.

    Args:
        config: model config. Reads ``config.patches`` (``"size"`` or ``"grid"``),
            the embedding width from ``config.d_model`` (falling back to
            ``config.hidden_size``) and the dropout rate from
            ``config.transformer["dropout_rate"]`` (falling back to
            ``config.drop_prob``, then 0.0).
        img_size: input image size, int or ``(H, W)``.
        in_channels: number of input image channels.
        overlap: EXPERIMENTAL. When True, patches are extracted with stride ``slide``
            instead of ``patch_size`` so neighbouring patches overlap.
        slide: stride used when ``overlap`` is True.
    """

    def __init__(self, config, img_size, in_channels=3, overlap=False, slide=12):
        super(Embeddings, self).__init__()
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            # Hybrid-style config: derive the patch size from the requested grid.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            self.hybrid = False

        stride = _pair(slide) if overlap else patch_size

        # The token count must equal the Conv2d output grid exactly, otherwise adding
        # the position embeddings in forward() fails. For non-overlapping patches this
        # reduces to (H // p) * (W // p).
        n_patches = (((img_size[0] - patch_size[0]) // stride[0] + 1)
                     * ((img_size[1] - patch_size[1]) // stride[1] + 1))

        # Encoder/decoder layers consume d_model-wide tokens, so project to d_model
        # when it is configured; hidden_size is only a fallback.
        embed_dim = getattr(config, "d_model", None)
        if embed_dim is None:
            embed_dim = config.hidden_size

        transformer_cfg = getattr(config, "transformer", None)
        if transformer_cfg is not None and "dropout_rate" in transformer_cfg:
            dropout_rate = transformer_cfg["dropout_rate"]
        else:
            dropout_rate = getattr(config, "drop_prob", 0.0)

        self.patch_embeddings = nn.Conv2d(in_channels=in_channels,
                                          out_channels=embed_dim,
                                          kernel_size=patch_size,
                                          stride=stride)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Map images ``(B, C, H, W)`` to tokens ``(B, n_patches, embed_dim)``."""
        x = self.patch_embeddings(x)          # (B, embed_dim, H', W')
        x = x.flatten(2).transpose(-1, -2)    # (B, H' * W', embed_dim)
        embeddings = x + self.position_embeddings
        return self.dropout(embeddings)

class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d)) V``.

    Expects 4-D tensors laid out as ``[batch_size, head, length, d_tensor]`` and
    returns ``(attended_values, attention_weights)``.
    """

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        # `e` is accepted for API compatibility but is not used.
        _, _, _, depth = k.size()  # unpacking also enforces a 4-D key tensor

        # Similarity of every query with every key, scaled by sqrt(depth).
        logits = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(depth)

        # Positions where mask == 0 receive a large negative score before softmax.
        if mask is not None:
            logits = logits.masked_fill(mask == 0, -10000)

        weights = self.softmax(logits)
        return torch.matmul(weights, v), weights

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ScaleDotProductAttention.

    Projects q/k/v with separate linear layers, splits the model dimension into
    ``n_head`` heads, attends per head, merges the heads back and applies an output
    projection. ``d_model`` must be divisible by ``n_head`` (``split`` reshapes to
    ``d_model // n_head`` per head).
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` (each ``[batch, length, d_model]``)."""
        heads_q = self.split(self.w_q(q))
        heads_k = self.split(self.w_k(k))
        heads_v = self.split(self.w_v(v))

        attended, _weights = self.attention(heads_q, heads_k, heads_v, mask=mask)
        # TODO: expose _weights for attention-map visualisation.
        return self.w_concat(self.concat(attended))

    def split(self, tensor):
        """Reshape ``[batch, length, d_model]`` -> ``[batch, head, length, d_model // head]``."""
        batch_size, length, d_model = tensor.size()
        per_head = d_model // self.n_head
        return tensor.view(batch_size, length, self.n_head, per_head).permute(0, 2, 1, 3)

    def concat(self, tensor):
        """Inverse of ``split``: ``[batch, head, length, d]`` -> ``[batch, length, head * d]``."""
        batch_size, n_heads, length, per_head = tensor.size()
        return tensor.transpose(1, 2).reshape(batch_size, length, n_heads * per_head)

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Note: this class shadows ``torch.nn.LayerNorm`` imported at the top of the notebook.
    """

    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance over the last dimension.
        sigma2 = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return normalized * self.gamma + self.beta


class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position.

    ``Linear(d_model -> hidden) -> ReLU -> Dropout -> Linear(hidden -> d_model)``.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        expanded = self.relu(self.linear1(x))
        return self.linear2(self.dropout(expanded))

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Two sub-layers, each followed by dropout, a residual connection and LayerNorm:
    self-attention, then a position-wise feed-forward network. Reads
    ``config.d_model``, ``config.n_head``, ``config.ffn_hidden`` and ``config.drop_prob``.
    """

    def __init__(self, config):
        super(EncoderLayer, self).__init__()
        width, dropout_p = config.d_model, config.drop_prob
        # Keep this creation order: parameter initialisation draws from the RNG in order.
        self.attention = MultiHeadAttention(d_model=width, n_head=config.n_head)
        self.norm1 = LayerNorm(d_model=width)
        self.dropout1 = nn.Dropout(p=dropout_p)

        self.ffn = PositionwiseFeedForward(d_model=width, hidden=config.ffn_hidden, drop_prob=dropout_p)
        self.norm2 = LayerNorm(d_model=width)
        self.dropout2 = nn.Dropout(p=dropout_p)

    def forward(self, x, src_mask=None):
        # Sub-layer 1: self-attention + residual + norm.
        attended = self.attention(q=x, k=x, v=x, mask=src_mask)
        x = self.norm1(x + self.dropout1(attended))

        # Sub-layer 2: position-wise FFN + residual + norm.
        transformed = self.ffn(x)
        return self.norm2(x + self.dropout2(transformed))

class Encoder(nn.Module):
    """Patch-embedding stem followed by ``config.n_layers`` EncoderLayer blocks.

    Args:
        img_size: input image size passed to ``Embeddings``.
        config: model config shared with ``Embeddings`` and ``EncoderLayer``.
        in_channels: number of input image channels.
    """

    def __init__(self, img_size, config, in_channels=3):
        super().__init__()
        # Patch + position embeddings take the place of a text Transformer's token
        # embedding. The attribute name `embd` is kept as the state-dict key.
        self.embd = Embeddings(config, img_size, in_channels=in_channels)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.n_layers)])

    def forward(self, x, src_mask=None):
        """Embed images ``x`` and run them through every encoder layer.

        Args:
            x: batch of images fed to the patch embedding.
            src_mask: optional attention mask forwarded to every layer.
        """
        # Must use the registered attribute `embd` (a `self.emb` lookup raises AttributeError).
        x = self.embd(x)
        for layer in self.layers:
            x = layer(x, src_mask)
        return x

class DecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    Sub-layers, each followed by dropout, a residual connection and LayerNorm:
    self-attention over the decoder input, cross-attention to the encoder output
    (only when ``enc`` is not None), and a position-wise feed-forward network.
    """

    def __init__(self, config):
        super(DecoderLayer, self).__init__()
        width, dropout_p, heads = config.d_model, config.drop_prob, config.n_head
        # Keep this creation order: parameter initialisation draws from the RNG in order.
        self.self_attention = MultiHeadAttention(d_model=width, n_head=heads)
        self.norm1 = LayerNorm(d_model=width)
        self.dropout1 = nn.Dropout(p=dropout_p)

        self.enc_dec_attention = MultiHeadAttention(d_model=width, n_head=heads)
        self.norm2 = LayerNorm(d_model=width)
        self.dropout2 = nn.Dropout(p=dropout_p)

        self.ffn = PositionwiseFeedForward(d_model=width, hidden=config.ffn_hidden, drop_prob=dropout_p)
        self.norm3 = LayerNorm(d_model=width)
        self.dropout3 = nn.Dropout(p=dropout_p)

    def forward(self, dec, enc, trg_mask=None, src_mask=None):
        # Self-attention over the decoder input (optionally masked) + residual + norm.
        attended = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)
        x = self.norm1(dec + self.dropout1(attended))

        # Cross-attention to the encoder output, skipped when there is none.
        if enc is not None:
            crossed = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)
            x = self.norm2(x + self.dropout2(crossed))

        # Position-wise FFN + residual + norm.
        return self.norm3(x + self.dropout3(self.ffn(x)))

class Decoder(nn.Module):
    """Patch-embedding stem, ``config.n_layers`` DecoderLayer blocks and a linear head.

    The head maps every token from ``config.d_model`` to ``config.dec_voc_size`` outputs.

    Args:
        img_size: input image size passed to ``Embeddings``.
        config: model config shared with ``Embeddings`` and ``DecoderLayer``.
        in_channels: number of input image channels.
    """

    def __init__(self, img_size, config, in_channels=3):
        super().__init__()
        # The attribute name `embd` is kept as the state-dict key.
        self.embd = Embeddings(config, img_size, in_channels=in_channels)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layers)])
        # NOTE(review): intended to be equivalent to n_classes — confirm dec_voc_size.
        self.linear = nn.Linear(config.d_model, config.dec_voc_size)

    def forward(self, trg, src, trg_mask=None, src_mask=None):
        """Decode ``trg`` images against encoder output ``src`` and apply the head."""
        # Must use the registered attribute `embd` (a `self.emb` lookup raises AttributeError).
        trg = self.embd(trg)
        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)
        return self.linear(trg)

class LAVIT2(nn.Module):
    """Encoder-decoder model over image patches.

    ``forward(src, trg)`` encodes ``src``, decodes ``trg`` against the encoder
    output and returns ``(0, decoder_output)``.

    NOTE(review): ``num_classes`` is accepted but unused, and the first returned
    element is a constant placeholder ``0`` rather than a loss tensor. Callers that
    unpack ``(loss, logits)`` for training need a real loss here — TODO.
    """

    def __init__(self, config, img_size=512, num_classes=5):
        super(LAVIT2, self).__init__()
        self.encoder = Encoder(img_size, config)
        self.decoder = Decoder(img_size, config)

    def forward(self, src, trg, src_mask=None, trg_mask=None):
        memory = self.encoder(src, src_mask)
        placeholder_loss = 0
        return placeholder_loss, self.decoder(trg, memory, trg_mask, src_mask)



In [15]:
from transformers import PretrainedConfig
import ml_collections

def get_base_config():
    """Returns the 'b16' LAViT configuration (16x16 patches, Transformer-base sized layers).

    Fields read by the model code in this notebook:
      patches.size              -- patch size used by ``Embeddings``
      hidden_size               -- ``Embeddings`` output width; must equal ``d_model``
      transformer.dropout_rate  -- dropout used by ``Embeddings``
      d_model, n_head, ffn_hidden, drop_prob, n_layers -- encoder/decoder layers
      dec_voc_size              -- output width of the decoder head
    """
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})

    # Plain scalars: a trailing comma (``512,``) would silently make each a 1-tuple.
    config.d_model = 512
    config.n_head = 8               # name read by EncoderLayer / DecoderLayer
    config.n_heads = config.n_head  # legacy alias of the earlier field name
    config.drop_prob = 0.1
    config.ffn_hidden = 2048
    config.n_layers = 6
    config.dec_voc_size = 64

    # Embeddings projects patches to hidden_size while the layers consume d_model-wide
    # tokens, so both widths must agree.
    # NOTE(review): 512 chosen to match d_model; switch both to 768 if ViT-B width was intended.
    config.hidden_size = config.d_model
    # Embeddings reads its dropout from here; reuse the layer dropout rate.
    config.transformer = ml_collections.ConfigDict({'dropout_rate': config.drop_prob})

    return config

cfgs = {
    'b16': get_base_config
}

class LAViTConfig(PretrainedConfig):
    """Hugging Face config for ``LAViTClassification``.

    Args:
        img_size: input image size passed to the model.
        patches: patch specification; defaults to ``{'size': (16, 16)}``.
        num_classes: number of target classes read by ``LAViTClassification``.
        **kwargs: forwarded to ``PretrainedConfig``.
    """
    model_type = "la-vit-2"

    def __init__(
        self,
        img_size: int = 512,
        patches: dict = None,
        num_classes: int = 5,
        **kwargs
    ):
        self.img_size = img_size
        # Copy (or build) per instance so a default dict is never shared between configs.
        self.patches = dict(patches) if patches is not None else {'size': (16, 16)}
        self.num_classes = num_classes
        # Key into `cfgs` selecting the architecture hyper-parameters.
        self.config = 'b16'
        super().__init__(**kwargs)

class LAViTClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``LAVIT2`` for use with ``Trainer``.

    Args:
        config: an ``LAViTConfig``; ``config.config`` selects an entry of ``cfgs``.
        pretrained: loading pretrained weights is not implemented; only a falsy
            value is supported.

    Raises:
        NotImplementedError: if ``pretrained`` is truthy (otherwise ``self.model``
            would be left unset and fail later inside ``forward``).
    """
    config_class = LAViTConfig

    def __init__(self, config, pretrained=False):
        super().__init__(config)

        if pretrained:
            raise NotImplementedError('Pretrained weights are not supported for LAViTClassification yet.')

        cfg = cfgs[config.config]()
        print('Initialized with random weights:')
        self.model = LAVIT2(
            img_size=config.img_size,
            # Configs created without `num_classes` fall back to LAVIT2's own default.
            num_classes=getattr(config, 'num_classes', 5),
            config=cfg,
        )

    def forward(self, pixel_values, labels=None, masks=None):
        """Trainer-style forward returning a dict of outputs.

        NOTE(review): ``LAVIT2.forward`` has the signature ``(src, trg, src_mask, trg_mask)``
        and returns ``(0, output)``. It accepts neither ``labels=`` nor ``mask=``, so both
        calls below raise ``TypeError`` as written. The intended decoder target and loss
        are not visible in this notebook — TODO wire them up.
        """
        if labels is not None:
            loss, logits = self.model(pixel_values, labels=labels, mask=masks)
            return {"loss": loss, "logits": logits}
        logits, attn_weights = self.model(pixel_values, labels=labels, mask=masks)
        return {"logits": logits, "attn_weights": attn_weights}

In [16]:
config = LAViTConfig()
model = LAViTClassification(config, pretrained=False)

# from transformers import ViTForImageClassification

# model = ViTForImageClassification.from_pretrained(
#     # 'google/vit-hybrid-base-bit-384',
#     'google/vit-base-patch16-384',
#     ignore_mismatched_sizes=True,
#     num_labels=5
# )

Initialized with random weights:


In [None]:
# Smoke test: run LAVIT2 directly on random tensors.
# Uses its own names so it does not overwrite `model`, which the Trainer cell below trains.
smoke_cfg = cfgs[config.config]()

smoke_model = LAVIT2(img_size=512,
                     num_classes=5,
                     config=smoke_cfg)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
smoke_model = smoke_model.to(device)

test_input_src = torch.randn(6, 3, 512, 512, device=device)
test_input_trg = torch.randn(6, 3, 512, 512, device=device)
with torch.no_grad():  # forward-only check; no activations kept for backprop
    # LAVIT2 returns a (placeholder_loss, output) tuple.
    _, test_output = smoke_model(test_input_src, test_input_trg)
print("test output shape: ", test_output.shape)

In [6]:
# from transformers import AutoImageProcessor, ViTForMaskedImageModeling

# model = ViTForMaskedImageModeling.from_pretrained('google/vit-base-patch16-384')
# # num_patches = (model.config.image_size // model.config.patch_size) ** 2
# print(model.config.image_size)
# print(model.config.patch_size)

Some weights of ViTForMaskedImageModeling were not initialized from the model checkpoint at google/vit-base-patch16-384 and are newly initialized: ['decoder.0.bias', 'decoder.0.weight', 'vit.embeddings.mask_token']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


384
16


In [8]:
labelsTable = pd.read_csv('../mnt/local/data/kalexu97/trainLabels.csv') # initial table
print(labelsTable.shape)
print(labelsTable.shape)

(35126, 2)
(35126, 2)


In [17]:
#FIXME: rewrite path and add mask path

# load dataset via csv table (columns `image` and `level` are used below)
labelsTable = pd.read_csv('../mnt/local/data/kalexu97/trainLabels.csv') # initial table

# Images excluded from both tables (file names including the '.jpeg' extension).
# Commented-out candidates from an earlier list:
#   '15337_left.jpeg', '40551_left.jpeg', '20289_right.jpeg', '27991_right.jpeg',
#   '39477_right.jpeg', '40758_left.jpeg', '17768_left.jpeg'
error_images = ['15337_left.jpeg', '40764_right.jpeg']

# `labelsTable.image` holds names without the extension (it is appended below),
# so strip '.jpeg' before filtering.
for error_image in error_images:
    error_image = error_image[:-5]
    labelsTable = labelsTable[labelsTable.image != error_image]

# add folder path 'mask_image'
root_dir = '../mnt/local/data/kalexu97/processed_train'
mask_dir = '../mnt/local/data/kalexu97/saliency_mask/'

# Full image path, saliency-mask path (.npy) and label (copied from `level`) per row.
labelsTable['image_path'] = labelsTable['image'].apply(lambda x: os.path.join(root_dir, x+'.jpeg'))
labelsTable['mask_image'] = labelsTable['image'].apply(lambda x: os.path.join(mask_dir, x+'.npy'))
labelsTable['label'] = labelsTable['level']
labelsTable = labelsTable.drop(columns=['image', 'level'], axis=1)

# The dataset was split into train and test beforehand; the test split is loaded from disk
# so it is the same for every training run.
test_dataset = pd.read_csv('test_dataset.csv')
# NOTE(review): `x[33:]` assumes every stored image_path starts with the same 33-character
# directory prefix, leaving only the file name — confirm (os.path.basename would be sturdier).
test_dataset['image'] = test_dataset['image_path'].apply(lambda x: x[33:])

# Here `image` still includes '.jpeg', so no extension stripping is needed.
for error_image in error_images:
    error_image = error_image  # NOTE(review): no-op assignment
    test_dataset = test_dataset[test_dataset.image != error_image]

# Rebuild the paths against the local root_dir / mask_dir.
test_dataset['image_path'] = test_dataset['image'].apply(lambda x: os.path.join(root_dir, x))
test_dataset['mask_image'] = test_dataset['image'].apply(lambda x: os.path.join(mask_dir, x[:-5]+'.npy'))

# subtract the test_dataset from the full dataset to get the train_dataset:
# after concatenation, an (image_path, label) pair that occurs exactly once is not shared
# between the two tables, and those rows are kept.
# NOTE(review): a test row with no identical (image_path, label) row in labelsTable also
# occurs once and would therefore end up in train_dataset.
df = pd.concat([test_dataset, labelsTable])
df = df.reset_index(drop=True)
df_gpby = df.groupby(list(['image_path', 'label']))
idx = [x[0] for x in df_gpby.groups.values() if len(x) == 1]

# `Unnamed: 0` is presumably the index column saved into test_dataset.csv — verify.
train_dataset = df.reindex(idx).drop(columns=['Unnamed: 0'], axis=1)

In [67]:
# test_dataset.image_path.to_list()

In [68]:
# test_dataset.mask_image.to_list()

In [69]:
# test_dataset.image_path.to_list()

In [18]:
# RUS for major classes, ROS for minor classes
# number of items in each class is equal to 
#           ratio * len(most_minor_dataset) 

# oversampling just repeating minority class items
# enought times to be equal to major dataset in size
train_dataset = resample(train_dataset, ratio = 35)

0: length: 19460
1: length: 19460
2: length: 19460
3: length: 19460
4: length: 19460
N_added_rows:  26953
N_all_rows:  28099
Ratio of used rows:  0.9592156304494822


In [19]:
# define preprocessor
model_name_or_path = "./saved_models/MedViT320_tr35_stg1_8bs_lr2e-5_30ep"
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)

# Model input resolution (height == width). Single source of truth for the final
# Resize below and for the image processor.
size = 512

# Pre-augmentations, applied jointly to an image and its saliency mask in load_image.
_transforms_train = T.Compose([
    T.RandomHorizontalFlip(p = 0.5),
    T.RandomVerticalFlip(p = 0.5),
    T.RandomCrop(460, padding_mode='symmetric', pad_if_needed=True),
    T.Resize((size, size), interpolation=T.InterpolationMode.BICUBIC),
    # Optional extras tried earlier: T.TrivialAugmentWide(), Sharpness(), Blur()
])

tens2img = T.ToPILImage()
img2tens = T.ToTensor()

# for some models it is possible to change the input size between training stages
image_processor.size['height'] = size
image_processor.size['width'] = size

# Deterministic counterpart of `_transforms_train` for evaluation. It uses the same
# crop-then-resize geometry, without random flips or a random crop position.
_transforms_test = T.Compose([
    T.CenterCrop(460),
    T.Resize((size, size), interpolation=T.InterpolationMode.BICUBIC),
])


def load_image(path_image, mask_path, mode, top_frac=0.4, band_frac=0.55):
    """Load an image and its saliency mask, transform them jointly and build a patch mask.

    Args:
        path_image: path of the image file.
        mask_path: path of a ``.npy`` saliency map for the same image.
        mode: ``'train'`` applies the random ``_transforms_train``; any other value
            applies the deterministic ``_transforms_test``.
        top_frac: fraction of patch cells (ranked by saliency) that are always selected.
        band_frac: cells ranked between ``top_frac`` and ``band_frac`` are each
            selected at random with probability 0.5.

    Returns:
        ``[image, mask]``. ``image`` is the transformed PIL image; ``mask`` is a flat
        boolean tensor with one entry per 16x16 patch of the ``size x size`` input.
    """
    with Image.open(path_image) as pil_image:
        # copy() forces the pixels to be read, so the file handle can be closed here.
        image = pil_image.copy()
    orig_mask = torch.from_numpy(np.load(mask_path, mmap_mode='r'))

    # Previously both modes used the random train augmentations.
    transforms = _transforms_train if mode == 'train' else _transforms_test
    image, orig_mask = transforms(image, tens2img(orig_mask))
    orig_mask = img2tens(orig_mask)[0]

    # One mask cell per 16x16 patch of the model input; explicit (h, w) keeps it square.
    mask_size = int(size // 16)
    resize = T.Resize((mask_size, mask_size), interpolation=T.InterpolationMode.NEAREST)
    resized_mask = resize(orig_mask[None, :, :])[0]

    flat = resized_mask.flatten()
    n_cells = flat.numel()
    # The k-th largest saliency values serve as rank thresholds (k >= 1 keeps topk non-empty).
    top_threshold = torch.topk(flat, max(1, int(top_frac * n_cells))).values[-1]
    band_threshold = torch.topk(flat, max(1, int(band_frac * n_cells))).values[-1]

    # Cells strictly between the two thresholds are kept with probability 0.5.
    in_band = torch.logical_and(resized_mask > band_threshold, resized_mask < top_threshold)
    coin = torch.randint(low=0, high=2, size=in_band.shape).bool()
    random_band = torch.logical_and(in_band, coin)

    final_mask = torch.logical_or(resized_mask > top_threshold, random_band)
    return [image, torch.flatten(final_mask)]


def _preprocess_batch(examples, mode):
    """Shared body of func_transform / func_transform_test.

    Calls ``load_image`` on every (image_path, mask_image) pair, runs
    ``image_processor`` on the resulting images and attaches masks and labels.
    """
    images, masks = [], []
    for path_img, path_msk in zip(examples['image_path'], examples['mask_image']):
        image, mask = load_image(path_img, path_msk, mode)
        images.append(image)
        masks.append(mask)

    inputs = image_processor(images, return_tensors='pt')
    inputs['mask'] = masks
    inputs['label'] = examples['label']
    return inputs


def func_transform(examples):
    """Preprocess a batch of the train dataset (load_image in 'train' mode)."""
    return _preprocess_batch(examples, 'train')


def func_transform_test(examples):
    """Preprocess a batch of the test dataset (load_image in 'test' mode)."""
    return _preprocess_batch(examples, 'test')

# Wrap the pandas tables as Hugging Face datasets. `Dataset` here is `datasets.Dataset`,
# which shadows the torch `Dataset` imported earlier in the same cell.
SHUFFLE_SEED = 42

train_ds, test_ds = (Dataset.from_pandas(frame, preserve_index=False)
                     for frame in (train_dataset, test_dataset))

# Preprocessing is attached lazily via with_transform; a fixed-seed shuffle helps when
# the source tables are sorted.
prepared_ds_train = train_ds.with_transform(func_transform).shuffle(seed=SHUFFLE_SEED)
prepared_ds_test = test_ds.with_transform(func_transform_test).shuffle(seed=SHUFFLE_SEED)

In [29]:
# prepared_ds_train[[0, 1, 2]]

In [20]:
# Collate function passed to the Trainer.
def collate_fn(batch):
    """Stack a list of preprocessed examples into a Trainer batch.

    Returns ``pixel_values`` (stacked tensor), ``labels`` (tensor of the examples'
    ``label`` values) and ``masks``. ``masks`` is currently always ``None``: the
    per-example ``mask`` entries are dropped.
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])
    # To feed the saliency masks instead: torch.stack([example['mask'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels, 'masks': None}

In [21]:
# Split the (seed-shuffled) test table into a validation subset and the remaining test set.
# A seeded generator makes the validation set identical on every run. An unseeded
# np.random.choice would pick a different set each time, so eval metrics across runs
# would not be comparable.
# Alternative: load fixed indices from disk
# with open('test_indeces.npy', 'rb') as f:
#     sample_ids = np.load(f)
#     inv_sample_ids = np.load(f)
VAL_SPLIT_SEED = 42
VAL_SIZE = 1000

split_rng = np.random.default_rng(VAL_SPLIT_SEED)
sample_ids = split_rng.choice(len(prepared_ds_test), size=VAL_SIZE, replace=False)
inv_sample_ids = np.setdiff1d(np.arange(len(prepared_ds_test)), sample_ids)

val_ds = prepared_ds_test.select(sample_ids)
test_ds = prepared_ds_test.select(inv_sample_ids)

In [22]:
# run_name is used to log metadata in wandb for tracking
r_name = "LAVIT_test_4_addSelfAtt"

# define the function to compute metrics
compute_metrics = get_compute_metrics(r_name, 'EyE', save_cm=False)

# arguments for training
training_args = TrainingArguments(
    output_dir="./LAVIT",
    evaluation_strategy="steps",
    logging_steps=50,

    save_steps=50,
    eval_steps=50,
    save_total_limit=3,
    
    report_to="wandb",  # enable logging to W&B
    run_name=r_name,  # name of the W&B run (optional)
    
    remove_unused_columns=False,
    dataloader_num_workers = 16,
    # lr_scheduler_type = 'constant_with_warmup', # 'constant', 'cosine'
    
    learning_rate=2e-5,
    # label_smoothing_factor = 0.6,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    warmup_ratio=0.02,
    
    metric_for_best_model="kappa", # select the best model via metric kappa
    greater_is_better = True,
    load_best_model_at_end=True,
    
    push_to_hub=False
)

# define trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds_train,
    eval_dataset=val_ds,
)

In [14]:
# !ls ../mnt/local/data/kalexu97/train

In [15]:
# !ls ../mnt/local/data/kalexu97/processed_train

In [18]:
!ls ../mnt/local/data/kalexu97/processed_train -1 | wc -l

35124


In [19]:
!ls ../mnt/local/data/kalexu97/train -1 | wc -l

35126


In [None]:
# trainer.train("./MedViT-base/checkpoint-22800")
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

[34m[1mwandb[0m: Currently logged in as: [33malexu97[0m ([33malexu97-skoltech[0m). Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Accuracy,Kappa,F1,Roc Auc,Class 0,Class 1,Class 2,Class 3,Class 4
50,1.6099,1.618397,0.173,0.034482,0.201751,0.540412,0.324,0.451,0.68,0.973,0.918
100,1.609,1.591968,0.175,0.019548,0.192675,0.547696,0.321,0.927,0.489,0.952,0.661
150,1.6064,1.621138,0.127,0.039105,0.140831,0.558053,0.307,0.401,0.826,0.988,0.732
200,1.5997,1.660633,0.159,0.027889,0.226905,0.566659,0.35,0.92,0.775,0.493,0.78
250,1.6047,1.624267,0.211,0.0326,0.295826,0.576867,0.398,0.929,0.833,0.551,0.711
300,1.6073,1.64022,0.196,0.036465,0.250917,0.585922,0.358,0.787,0.646,0.915,0.686
350,1.615,1.581791,0.243,0.02739,0.310391,0.564061,0.409,0.552,0.841,0.914,0.77
400,1.6083,1.652287,0.188,0.034158,0.239979,0.549174,0.348,0.504,0.783,0.988,0.753
450,1.6153,1.764337,0.042,0.036879,0.011489,0.557529,0.256,0.633,0.848,0.677,0.67
500,1.6106,1.542696,0.326,0.043483,0.412015,0.557143,0.474,0.929,0.798,0.472,0.979


[[ 93 412 180  12  48]
 [  7  35  22   1   6]
 [ 14  85  42   1  10]
 [  0   8   3   0   1]
 [  3   8   5   1   3]]
[[ 91   1 376  33 244]
 [  4   0  41   1  25]
 [ 19   1  71   4  57]
 [  0   0   6   2   4]
 [  2   0   7   0  11]]
[[ 69 459  23   0 194]
 [  6  43   2   0  20]
 [ 11  95   4   0  42]
 [  0   9   0   0   3]
 [  0   8   1   0  11]]
[[125   8  77 380 155]
 [  8   0  11  37  15]
 [ 19   1  15  80  37]
 [  0   0   0   9   3]
 [  3   0   0   7  10]]
[[187   0  17 331 210]
 [ 15   0   3  32  21]
 [ 26   0   5  76  45]
 [  0   0   0   9   3]
 [  3   0   0   7  10]]
[[127 121 215  56 226]
 [  6  13  22   4  26]
 [ 18  30  44  12  48]
 [  0   3   3   2   4]
 [  0   1   6   3  10]]
[[198 320   5  56 166]
 [ 10  37   1   4  19]
 [ 32  78   0  12  30]
 [  1   7   1   1   2]
 [  1   9   0   3   7]]
[[130 372  69   0 174]
 [ 13  36   5   0  17]
 [ 17  79  14   0  42]
 [  1   7   2   0   2]
 [  6   3   3   0   8]]
[[  1 272   0 238 234]
 [  0  22   0  23  26]
 [  0  40   0  52  60]
 [ 

In [20]:
!ls ../mnt/local/data/kalexu97/saliency_mask -1 | wc -l

35124
