In [1]:
import torch
import torch.nn as nn
import numpy as np
import copy

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair

In [2]:
# !wandb login b248f05b86545578e213a3d77725b1793b6c237a

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [9]:
# pip install ml-collections

In [2]:
!nvidia-smi

Tue Jul 23 18:04:54 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-16GB           On  |   00000000:06:00.0 Off |                    0 |
| N/A   44C    P0             43W /  300W |       3MiB /  16384MiB |      0%   E. Process |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-16GB           On  |   00

In [6]:
# !pip uninstall numpy -y
# !pip uninstall matplotlib -y

In [5]:
# !pip install numpy
# !pip install matplotlib

In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1" # is need to train on 'hachiko'

from PIL import Image
import os
import warnings
warnings.filterwarnings("ignore")
from typing import Tuple
from typing import List
import random

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
# import torchvision.transforms as T
import torchvision.transforms.v2 as T
from torchvision.transforms import functional as F

from datasets import Dataset
from datasets import load_dataset

from transformers import ViTImageProcessor
from transformers import AutoImageProcessor
from transformers import TrainingArguments
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from transformers import Trainer

# import of custom functions
from validation_utils import get_compute_metrics
from data_utils import resample

torch.cuda.empty_cache()

In [5]:
print('Number CUDA Devices:', torch.cuda.device_count())
print ('Current cuda device: ', torch.cuda.current_device(), ' **May not correspond to nvidia-smi ID above, check visibility parameter')

Number CUDA Devices: 1
Current cuda device:  0  **May not correspond to nvidia-smi ID above, check visibility parameter


In [6]:
import wandb
# wandb.login('897cda038ea791f5f031be1adc101e476e229b31')

In [5]:
import math
from scipy import ndimage
from os.path import join as pjoin


ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
# logger = logging.getLogger(__name__)

ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"

def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)
    
def swish(x):
    return x * torch.sigmoid(x)

ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}

class Attention(nn.Module):
    
    def __init__(self, config, vis, coeff_max=0.25):
        super(Attention, self).__init__()

        self.coeff_max = coeff_max

        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)
        self.softmax2 = Softmax(dim=-2)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
    
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / self.sqrt_att_head_size

        debug_mode = False
        print_info = False


        if mask is not None:

            if debug_mode:
                print_info = True if (random.random() < 0.000001) else False
                x = random.random()
                if (x > 0.00005) and (x < 0.00007):
                    print_info = True
                else:
                    print_info = False
            else:
                print_info = False

            max_as = torch.max(attention_scores[:, :, 0, :], dim=2, keepdim=False)[0]
            # max_as = max_as.to(device='cuda')

            if print_info:
                print("mask before:", mask)
                print("attn scores before:", attention_scores[:, :, 0, :])

                print("attn scores max_min before:")
                print(max_as, torch.min(attention_scores[:, :, 0, :], dim=2, keepdim=False)[0])

                print(torch.topk(attention_scores[:, :, 0, :], 5, largest=True), torch.topk(attention_scores[:, :, 0, :], 5, largest=False))

            mask_626 = torch.zeros(mask.size(0), (mask.size(1) + 1)) #, dtype=torch.float64) # dtype=torch.double)
            mask_626 = mask_626.to(device='cuda')
            mask_626[:, 1:] = mask[:, :]
            mask_626[:, 0] = 0

            if print_info: print("mask626:", mask_626)
            
            # positive only, obj + (max * coeff):
            attention_scores[:, :, 0, :] = \
                torch.where( mask_626[:, None, :] < 0.5, \
                        torch.add( attention_scores[:, :, 0, :], \
                            torch.mul( max_as[:, :, None] , torch.tensor(self.coeff_max).cuda()) ), \
                        attention_scores[:, :, 0, :] #.float()
                            )

            if print_info:
                print("attn scores after:", attention_scores[:, :, 0, :])

                print("attn scores max_min after:")
                print(torch.max(attention_scores[:, :, 0, :]), torch.min(attention_scores[:, :, 0, :]))
          
                print(torch.topk(attention_scores[:, :, 0, :], 5, largest=True), torch.topk(attention_scores[:, :, 0, :], 5, largest=False))


        attention_probs = self.softmax(attention_scores)
        
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)

        return attention_output, weights, self.softmax2(attention_scores)[:,:,:,0]

class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        # EXPERIMENTAL. Overlapping patches:
        overlap = False
        if overlap: slide = 12 # 14

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])

            if overlap:
                n_patches = ((img_size[0] - patch_size[0]) // slide + 1) * ((img_size[1] - patch_size[1]) // slide + 1)
            else:
                n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

            self.hybrid = False

        if overlap:
            self.patch_embeddings = Conv2d(in_channels=in_channels,
                                        out_channels=config.hidden_size,
                                        kernel_size=patch_size,
                                        stride=(slide, slide) )                 
        else:
            self.patch_embeddings = Conv2d(in_channels=in_channels,
                                        out_channels=config.hidden_size,
                                        kernel_size=patch_size,
                                        stride=patch_size )

        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = self.patch_embeddings(x)
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)


        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

class Block(nn.Module):
    def __init__(self, config, vis, coeff_max):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis, coeff_max)

    def forward(self, x, mask=None):
        h = x
        x = self.attention_norm(x)
        x, weights, contribution = self.attn(x, mask)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights, contribution

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))

class Encoder(nn.Module):
    def __init__(self, config, vis, coeff_max):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        num_layers = config.transformer["num_layers"]

        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(num_layers):
            layer = Block(config, vis, coeff_max)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states, mask=None):
        attn_weights = []
        contributions = []
        tokens = [[] for i in range(hidden_states.shape[0])]

        for layer_block in self.layer:
            hidden_states, weights, contribution = layer_block(hidden_states, mask)

            if self.vis:
                attn_weights.append(weights)
                contributions.append(contribution)

        encoded = self.encoder_norm(hidden_states)

        return encoded, attn_weights

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis, coeff_max):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis, coeff_max)

    def forward(self, input_ids, mask=None):
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output, mask)

        return encoded, attn_weights

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=400, num_classes=200, smoothing_value=0, zero_head=False, vis=False, dataset='CUB', coeff_max=0.25, contr_loss=False, focal_loss=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.smoothing_value = smoothing_value
        self.classifier = config.classifier
        self.dataset=dataset

        self.contr_loss = contr_loss
        self.focal_loss = focal_loss

        self.transformer = Transformer(config, img_size, vis, coeff_max)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None, mask=None):
        x, attn_weights = self.transformer(x, mask)
        print("X shape", x.shape)
        logits = self.head(x[:, 0])
        print("Logits shape", logits.shape)

        if labels is not None:
            if self.smoothing_value == 0:
                loss_fct = CrossEntropyLoss()
            else:
                loss_fct = LabelSmoothing(self.smoothing_value)

            if self.focal_loss: # enforce another type of loss
                loss_fct = FocalLoss()

            ce_loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))

            if self.contr_loss:
                contrast_loss = con_loss(x[:, 0], labels.view(-1))
                loss = ce_loss + contrast_loss
            else:
                loss = ce_loss # FFVT

            return loss, logits
        else:
            return logits, attn_weights

    def load_from(self, weights):
        with torch.no_grad():
            if self.zero_head:
                nn.init.zeros_(self.head.weight)
                nn.init.zeros_(self.head.bias)
            else:
                self.head.weight.copy_(np2th(weights["head/kernel"]).t())
                self.head.bias.copy_(np2th(weights["head/bias"]).t())

            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)

                if self.classifier == "token":
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            for bname, block in self.transformer.encoder.named_children():
                if bname.startswith('ff') == False:
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)

import ml_collections

def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 4096
    config.transformer.num_heads = 16
    config.transformer.num_layers = 24
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
    
def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config

cfgs = {
    'l16': get_l16_config,
    'b16': get_b16_config
}


In [6]:
# need to define config file to have an ability to 
# load pretrained weights received from previos train stage
class SMViTConfig(PretrainedConfig):
    model_type = "sm-vit"

    def __init__(
        self,
        img_size: int = 512,
        num_classes: int = 5,
        smoothing_value: float = 0,
        zero_head: bool = True,
        vis: bool = False,
        coeff_max: float = False,
        focal_loss: bool = False,

        # config = ml_collections.ConfigDict()
        patches: dict = {'size': (16, 16)},
        hidden_size: int = 1024,
        # transformer = ml_collections.ConfigDict()
        mlp_dim: int = 4096,
        num_heads: int = 16,
        num_layers: int = 24,
        attention_dropout_rate: float = 0.0,
        dropout_rate: float = 0.1,
        classifier: str = 'token',
        # representation_size = None
        **kwargs
    ):
        self.img_size = img_size
        self.num_classes = num_classes
        self.smoothing_value = smoothing_value
        self.zero_head = zero_head
        self.vis = vis
        self.coeff_max = coeff_max
        self.focal_loss = focal_loss
        self.config = 'b16'
        super().__init__(**kwargs)

class SMViTClassification(PreTrainedModel):
    config_class = SMViTConfig

    def __init__(self, config, pretrained=False):
        super().__init__(config)

        cfg = cfgs[config.config]()

        if pretrained is False: # without pretrained weights
          print('Initialized with random weights:')
          self.model = VisionTransformer(
          img_size = config.img_size,
          num_classes = config.num_classes,
          smoothing_value = config.smoothing_value,
          zero_head = config.zero_head,
          vis = config.vis,
          coeff_max = config.coeff_max,
          focal_loss = config.focal_loss,
          config = cfg
          )

        else:
            
            self.model = VisionTransformer(
                img_size = config.img_size,
                num_classes = config.num_classes,
                smoothing_value = config.smoothing_value,
                zero_head = config.zero_head,
                vis = config.vis,
                coeff_max = config.coeff_max,
                focal_loss = config.focal_loss,
                config = cfg)
            print("Load weights")
            self.model.load_from(np.load("imagenet21k_ViT-B_16.npz"))

    # FIXME: need to add extraicting mask to original trainer
    def forward(self, pixel_values, labels=None, masks = None):
        # define function in transformers library maner
        # logits = self.model(pixel_values, mask = masks)
        if labels is not None:
            loss, logits = self.model(pixel_values, labels=labels, mask=masks)
            # loss = torch.nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        else:
            logits, attn_weights = self.model(pixel_values, labels=labels, mask=masks)
            return {"logits": logits, "attn_weights": attn_weights}

In [7]:
smvit_pretrained_config = SMViTConfig()
model = SMViTClassification(smvit_pretrained_config, pretrained=True)

# from transformers import ViTForImageClassification

# model = ViTForImageClassification.from_pretrained(
#     # 'google/vit-hybrid-base-bit-384',
#     'google/vit-base-patch16-384',
#     ignore_mismatched_sizes=True,
#     num_labels=5
# )

Load weights
load_pretrained: resized variant: torch.Size([1, 197, 768]) to torch.Size([1, 1025, 768])
load_pretrained: grid-size from 14 to 32


In [8]:
test_batch = torch.rand(4, 3, 512, 512)
output = model(test_batch)
print(output['logits'])

X shape torch.Size([4, 1025, 768])
Logits shape torch.Size([4, 5])
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], grad_fn=<AddmmBackward0>)


In [6]:
# from transformers import AutoImageProcessor, ViTForMaskedImageModeling

# model = ViTForMaskedImageModeling.from_pretrained('google/vit-base-patch16-384')
# # num_patches = (model.config.image_size // model.config.patch_size) ** 2
# print(model.config.image_size)
# print(model.config.patch_size)

Some weights of ViTForMaskedImageModeling were not initialized from the model checkpoint at google/vit-base-patch16-384 and are newly initialized: ['decoder.0.bias', 'decoder.0.weight', 'vit.embeddings.mask_token']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


384
16


In [40]:
labelsTable = pd.read_csv('../mnt/local/data/kalexu97/trainLabels.csv') # initial table
print(labelsTable.shape)
print(labelsTable.shape)

(35126, 2)
(35126, 2)


In [10]:
#FIXME: rewrite path and add mask path

# load dataset via csv table
labelsTable = pd.read_csv('../mnt/local/data/kalexu97/trainLabels.csv') # initial table

error_images = ['15337_left.jpeg', '40764_right.jpeg']
    # '15337_left.jpeg',
                # '40551_left.jpeg',
                # '20289_right.jpeg',
                # '27991_right.jpeg',
                # '39477_right.jpeg',
                # '40758_left.jpeg',
                # '17768_left.jpeg']

for error_image in error_images:
    error_image = error_image[:-5]
    labelsTable = labelsTable[labelsTable.image != error_image]

# add folder path 'mask_image'
root_dir = '../mnt/local/data/kalexu97/processed_train'
mask_dir = '../mnt/local/data/kalexu97/saliency_mask/'

labelsTable['image_path'] = labelsTable['image'].apply(lambda x: os.path.join(root_dir, x+'.jpeg'))
labelsTable['mask_image'] = labelsTable['image'].apply(lambda x: os.path.join(mask_dir, x+'.npy'))
labelsTable['label'] = labelsTable['level']
labelsTable = labelsTable.drop(columns=['image', 'level'], axis=1)

# dataset is spliated to trian and test previously, and is constant for every training process
test_dataset = pd.read_csv('test_dataset.csv')
test_dataset['image'] = test_dataset['image_path'].apply(lambda x: x[33:])

for error_image in error_images:
    error_image = error_image
    test_dataset = test_dataset[test_dataset.image != error_image]
    
test_dataset['image_path'] = test_dataset['image'].apply(lambda x: os.path.join(root_dir, x))
test_dataset['mask_image'] = test_dataset['image'].apply(lambda x: os.path.join(mask_dir, x[:-5]+'.npy'))

# subtract the test_dataset from the full dataset to get the train_dataset
df = pd.concat([test_dataset, labelsTable])
df = df.reset_index(drop=True)
df_gpby = df.groupby(list(['image_path', 'label']))
idx = [x[0] for x in df_gpby.groups.values() if len(x) == 1]

train_dataset = df.reindex(idx).drop(columns=['Unnamed: 0'], axis=1)

In [67]:
# test_dataset.image_path.to_list()

In [68]:
# test_dataset.mask_image.to_list()

In [69]:
# test_dataset.image_path.to_list()

In [11]:
# RUS for major classes, ROS for minor classes
# number of items in each class is equal to 
#           ratio * len(most_minor_dataset) 

# oversampling just repeating minority class items
# enought times to be equal to major dataset in size
train_dataset = resample(train_dataset, ratio = 35)

0: length: 19460
1: length: 19460
2: length: 19460
3: length: 19460
4: length: 19460
N_added_rows:  26953
N_all_rows:  28099
Ratio of used rows:  0.9592156304494822


In [12]:
# define preprocessor
model_name_or_path = "./saved_models/MedViT320_tr35_stg1_8bs_lr2e-5_30ep"
image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)

size = 512

# Pre-Augmetations
_transforms_train = T.Compose([
    T.RandomHorizontalFlip(p = 0.5),
    T.RandomVerticalFlip(p = 0.5),
    T.RandomCrop(460, padding_mode='symmetric', pad_if_needed=True),
    T.Resize((512, 512), interpolation=T.InterpolationMode.BICUBIC),
    # T.TrivialAugmentWide(),
    # Sharpness(),
    # Blur()
])

tens2img = T.ToPILImage()
img2tens = T.ToTensor()

# for some models it is possible to change input size between training stage
image_processor.size['height'] = size
image_processor.size['width'] = size

def load_image(path_image, mask_path, mode):
    """
    The function loads image from path and make Pre-Augmentation.
    """
    # print(path_image)
    top_per = 0.4
    image = Image.open(path_image)
    orig_mask = np.load(mask_path, mmap_mode='r')
    orig_mask = torch.from_numpy(orig_mask)

    image, orig_mask = _transforms_train(image, tens2img(orig_mask))
    orig_mask = img2tens(orig_mask)[0]

    # mask_size = int(orig_mask.shape[0] // 16)
    mask_size = int(size//16)

    transform = T.Resize(mask_size, interpolation=Image.NEAREST)
    resized_mask = transform(orig_mask[None, :, :])
    # bool_resized_mask = (resized_mask > 0.1)*1 #### CHECK: that it should not be inverse

    low_val_in_topl_p1 = torch.topk(resized_mask.flatten(), int(0.4*resized_mask.shape[1]**2)).values[-1]
    # low_val_in_topl_p2 = torch.topk(resized_mask.flatten(), int(0.55*resized_mask.shape[1]**2)).values[-1]
    
    # rand_region_bids = torch.logical_and(resized_mask[0]>low_val_in_topl_p2, resized_mask[0]<low_val_in_topl_p1)
    # bool_masked_pos = torch.randint(low=0, high=2, size=(rand_region_bids.shape)).bool()
    # rand_region_bids = torch.logical_and(rand_region_bids, bool_masked_pos)
    
    # final_mask = torch.logical_or(resized_mask[0]>low_val_in_topl_p1, rand_region_bids)
    final_mask = resized_mask[0]>low_val_in_topl_p1

    mask = torch.flatten(final_mask) 


    if mode == 'train':
        # image = _transforms_train(image)
        # FIXME: add trainsforms !!!
        return [image, mask]
        
    else:
        # image = _transforms_test(image)
        return [image, mask]


def func_transform(examples):
    """
    The function is used to preprocess train dataset.
    """
    # pre-augmentation and preprocessing
    transformed_inputs = [load_image(path_img, path_msk, 'train') for path_img, path_msk in zip(examples['image_path'], examples['mask_image'])]
    images = [item[0] for item in transformed_inputs]
    masks = [item[1] for item in transformed_inputs]

    # print(masks)
    # print(images)
    # inputs = image_processor([load_image(path, lb, 'train')
                                # for path, lb in zip(examples['image_path'], examples['label'])], return_tensors='pt')
    inputs = image_processor(images, return_tensors='pt')
    # print(inputs)
    # print(masks)
    inputs['mask'] = masks
    inputs['label'] = examples['label']

    return inputs

def func_transform_test(examples):
    """
    The function is used to preprocess test dataset.
    """
    # pre-augmentation and preprocessing
    transformed_inputs = [load_image(path_img, path_msk, 'test') for path_img, path_msk in zip(examples['image_path'], examples['mask_image'])]
    images = [item[0] for item in transformed_inputs]
    masks = [item[1] for item in transformed_inputs]
    
    inputs = image_processor(images, return_tensors='pt')
    inputs['mask'] = masks
    inputs['label'] = examples['label']

    return inputs

# to dataset
train_ds = Dataset.from_pandas(train_dataset, preserve_index=False)
test_ds = Dataset.from_pandas(test_dataset, preserve_index=False)

# apply preprocessing
prepared_ds_train = train_ds.with_transform(func_transform)
prepared_ds_test = test_ds.with_transform(func_transform_test)

# for sorted datasets shuffling can be usefull
prepared_ds_train = prepared_ds_train.shuffle(seed=42)
prepared_ds_test = prepared_ds_test.shuffle(seed=42)

In [29]:
# prepared_ds_train[[0, 1, 2]]

In [13]:
# Define function to define collate function
def collate_fn(batch):
    # print([x['mask'] for x in batch])
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch]),
        'masks': torch.stack([x['mask'] for x in batch]),
        # 'masks': None
    }

In [14]:
# val_dataset is alse defined previously, so we just need to load its indexes
# with open('test_indeces.npy', 'rb') as f:
#     sample_ids = np.load(f)
#     inv_sample_ids = np.load(f)

sample_ids = np.random.choice(len(prepared_ds_test), size=1000, replace=False)
inv_sample_ids = np.setdiff1d(np.arange(len(prepared_ds_test)), sample_ids)

val_ds = prepared_ds_test.select(sample_ids)
test_ds = prepared_ds_test.select(inv_sample_ids)

In [15]:
# run_name is used to log metadata in wandb for tracking
r_name = "SM512Pr_Mask04_bs8_10ep"

# define the function to compute metrics
compute_metrics = get_compute_metrics(r_name, 'EyE', save_cm=False)

# arguments for training
training_args = TrainingArguments(
    output_dir="./SMViT-withMask-04rnd055",
    evaluation_strategy="steps",
    logging_steps=50,

    save_steps=50,
    eval_steps=50,
    save_total_limit=3,
    
    report_to="wandb",  # enable logging to W&B
    run_name=r_name,  # name of the W&B run (optional)
    
    remove_unused_columns=False,
    dataloader_num_workers = 16,
    # lr_scheduler_type = 'constant_with_warmup', # 'constant', 'cosine'
    
    learning_rate=2e-5,
    # label_smoothing_factor = 0.6,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    warmup_ratio=0.02,
    
    metric_for_best_model="kappa", # select the best model via metric kappa
    greater_is_better = True,
    load_best_model_at_end=True,
    
    push_to_hub=False
)

# define trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds_train,
    eval_dataset=val_ds,
)

In [14]:
# !ls ../mnt/local/data/kalexu97/train

In [15]:
# !ls ../mnt/local/data/kalexu97/processed_train

In [18]:
!ls ../mnt/local/data/kalexu97/processed_train -1 | wc -l

35124


In [19]:
!ls ../mnt/local/data/kalexu97/train -1 | wc -l

35126


In [None]:
# trainer.train("./MedViT-base/checkpoint-22800")
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

[34m[1mwandb[0m: Currently logged in as: [33malexu97[0m ([33malexu97-skoltech[0m). Use [1m`wandb login --relogin`[0m to force relogin


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Accuracy,Kappa,F1,Roc Auc,Class 0,Class 1,Class 2,Class 3,Class 4
50,1.6094,1.609355,0.158,0.024705,0.074782,0.634044,0.281,0.933,0.25,0.968,0.884
100,1.6084,1.605865,0.11,0.188696,0.06731,0.669936,0.276,0.253,0.81,0.969,0.912
150,1.6025,1.596006,0.086,0.292862,0.019616,0.698309,0.269,0.215,0.837,0.955,0.896
200,1.5891,1.575572,0.089,0.29325,0.028336,0.6993,0.27,0.234,0.845,0.957,0.872
250,1.5648,1.545631,0.086,0.311681,0.020143,0.718975,0.269,0.226,0.849,0.945,0.883
300,1.5347,1.534902,0.075,0.265629,0.013247,0.691915,0.269,0.327,0.845,0.935,0.774
350,1.497,1.482901,0.083,0.293065,0.016013,0.709456,0.269,0.257,0.85,0.957,0.833
400,1.4732,1.46875,0.075,0.291671,0.013485,0.69013,0.269,0.283,0.847,0.959,0.792
450,1.4407,1.482553,0.071,0.229803,0.012647,0.697875,0.269,0.34,0.843,0.941,0.749
500,1.4183,1.360159,0.081,0.395734,0.015134,0.71461,0.269,0.207,0.85,0.946,0.89


[[ 17   0 633   0  81]
 [  0   0  60   0   7]
 [  5   0 131   0  14]
 [  0   0  28   0   4]
 [  0   0  10   0  10]]
[[  8 616  52   0  55]
 [  0  59   6   0   2]
 [  1 104  28   0  17]
 [  0  15   8   2   7]
 [  0   4   2   1  13]]
[[  0 663  11   4  53]
 [  0  65   1   0   1]
 [  0 107   2   9  32]
 [  0  12   3   1  16]
 [  0   1   0   1  18]]
[[  1 651   9   8  62]
 [  0  62   1   1   3]
 [  0  99   7   2  42]
 [  0  10   2   0  20]
 [  0   1   0   0  19]]
[[  0 662   3  13  53]
 [  0  64   0   1   2]
 [  0  96   3   9  42]
 [  0  12   1   0  19]
 [  0   1   0   0  19]]
[[  0 593   3  21 114]
 [  0  55   0   1  11]
 [  0  65   0  11  74]
 [  0   3   2   1  26]
 [  0   0   0   1  19]]
[[  0 642   0   9  80]
 [  0  62   0   0   5]
 [  0  87   0   4  59]
 [  0   8   0   2  22]
 [  0   1   0   0  19]]
[[  0 624   2   7  98]
 [  0  54   1   1  11]
 [  0  74   0   2  74]
 [  0   6   0   1  25]
 [  0   0   0   0  20]]
[[  0 568   7  21 135]
 [  0  51   0   2  14]
 [  0  73   0   4  73]
 [ 

In [20]:
!ls ../mnt/local/data/kalexu97/saliency_mask -1 | wc -l

35124
