<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/VisualBertResMLP_EndoVis18_VQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Download Dataset

In [None]:
# Downloading the VQA EndoVis18 Dataset https://drive.google.com/file/d/1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN/view?usp=sharing
!gdown --id 1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN

# Unzipping the VQA EndoVis18 Dataset
!unzip -q EndoVis-18-VQA.zip

Downloading...
From (original): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN
From (redirected): https://drive.google.com/uc?id=1WGdztykX3nW6pi_BKp4rO8nA7ESNRfVN&confirm=t&uuid=e8e9f5f0-1281-45cb-a3e4-edb3560f3c15
To: /content/EndoVis-18-VQA.zip
100% 2.70G/2.70G [00:40<00:00, 67.2MB/s]


# VisualBERT VQA Classification

Different versions of VisualBERT adaptation:

Utils

In [None]:
import torch
import os
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_fscore_support

class AverageMeter(object):
    """
    Running statistics for a scalar metric: the latest value, the weighted
    sum, the number of samples seen and the resulting mean.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, shrink_factor):
    """
    Scale the learning rate of every parameter group by ``shrink_factor``.

    :param optimizer: optimizer whose ``param_groups`` are updated in place.
    :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
    """
    print("\nDECAYING learning rate.")

    for group in optimizer.param_groups:
        group.update(lr=group['lr'] * shrink_factor)

    # Only the first group's rate is reported.
    current_lr = optimizer.param_groups[0]['lr']
    print("The new learning rate is %f\n" % (current_lr,))


def save_clf_checkpoint(checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, Acc, final_args):
    """
    Saves model checkpoint.

    The target path is ``checkpoint_dir + 'Best.pth.tar'`` (plain string
    concatenation, so ``checkpoint_dir`` acts as a path prefix). The parent
    directory of that path is created when it does not exist yet.

    :param checkpoint_dir: path prefix the checkpoint file name is appended to.
    :param epoch: epoch number stored in the checkpoint.
    :param epochs_since_improvement: counter stored in the checkpoint.
    :param model: model object, stored as-is (whole object, not a state_dict).
    :param optimizer: optimizer object, stored as-is.
    :param Acc: accuracy value stored in the checkpoint.
    :param final_args: argument object stored in the checkpoint.
    """
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'Acc': Acc,
             'model': model,
             'optimizer': optimizer,
             'final_args': final_args}
    filename = checkpoint_dir + 'Best.pth.tar'

    # torch.save does not create missing directories; do it here so the first
    # checkpoint of a fresh run does not fail.
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(state, filename)

def calc_acc(y_true, y_pred):
    """Return the accuracy of ``y_pred`` against ``y_true`` via sklearn's ``accuracy_score``."""
    return accuracy_score(y_true, y_pred)

def calc_classwise_acc(y_true, y_pred):
    """
    Return one accuracy value per class: the confusion-matrix diagonal divided
    by the row sums of the matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    correct_per_class = cm.diagonal()
    total_per_class = cm.sum(axis=1)
    return correct_per_class / total_per_class

def calc_map(y_true, y_scores):
    """
    Return sklearn's ``average_precision_score`` with ``average=None``, i.e.
    without averaging across classes.
    """
    return average_precision_score(y_true, y_scores, average=None)

def calc_precision_recall_fscore(y_true, y_pred):
    """
    Return macro-averaged ``(precision, recall, fscore)``.

    ``zero_division=1`` is passed through to sklearn; the fourth value
    returned by ``precision_recall_fscore_support`` is dropped.
    """
    scores = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=1)
    return tuple(scores[:3])


def seed_everything(seed=27):
    '''
    Set random seed for reproducible experiments.

    Seeds Python's ``random`` module, NumPy's global generator and torch
    (CPU and all CUDA devices), and puts cuDNN into deterministic mode.

    Inputs: seed number
    '''
    # Local import: `random` is not among the module-level imports visible in this notebook.
    import random

    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
    # setting it here presumably only affects child processes - confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Training Model: VisualBERTResMLP

In [None]:
#############Dataloader###############

from torch.utils.data import Dataset
from PIL import Image
import os
import glob
import torchvision.transforms as transforms
from torchvision import models
from torch import nn
import os
import glob
import h5py
from PIL import Image

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor, AutoFeatureExtractor


class EndoVis18VQAGPTClassification(Dataset):
    '''
    EndoVis-18 VQA classification dataset.

    Each item is ``(image tensor, question string, label index)``, built from
    one ``question|answer`` line of an annotation text file.

    	seq: train_seq  = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    	     val_seq    = [1, 5, 16]
    	folder_head     = 'dataset/EndoVis-18-VQA/seq_'
    	folder_tail     = '/vqa/Classification/*.txt'

    NOTE(review): __getitem__ rebuilds the image path from the first two
    '/'-separated components of the annotation path, so it appears to expect a
    ``folder_head`` with exactly one leading directory (e.g.
    'EndoVis-18-VQA/seq_') - verify against the caller.
    '''
    def __init__(self, seq, folder_head, folder_tail, transform=None):
        # `transform` is accepted for interface compatibility only: images are
        # always preprocessed by the Hugging Face image processor below, and the
        # attribute stays None as it did before.
        self.transform = None
        self.image_processor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

        # Collect the annotation files of every requested sequence.
        filenames = []
        for curr_seq in seq:
            filenames.extend(glob.glob(folder_head + str(curr_seq) + folder_tail))

        # One [file, line] entry per non-empty annotation line.
        self.vqas = []
        for file in filenames:
            with open(file, "r") as file_data:
                lines = [line.strip("\n") for line in file_data if line != "\n"]
            self.vqas.extend([file, line] for line in lines)
        print('Total files: %d | Total question: %.d' %(len(filenames), len(self.vqas)))

        # Labels: the position in this list is the class index returned by __getitem__.
        self.labels = ['kidney', 'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation',
                        'Tool_Manipulation', 'Cutting', 'Cauterization', 'Suction',
                        'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing',
                        'left-top', 'right-top', 'left-bottom', 'right-bottom']

    def __len__(self):
        """Number of question/answer lines across all loaded files."""
        return len(self.vqas)

    def __getitem__(self, idx):
        """
        Return ``(img, question, label)`` for entry ``idx``.

        Raises ValueError (from ``list.index``) when the answer text is not in
        ``self.labels``.
        """
        file, qa_line = self.vqas[idx]
        loc = file.split('/')

        # Frame image: <loc[0]>/<loc[1]>/left_frames/<annotation name up to the first '_'>.png
        img_loc = os.path.join(loc[0], loc[1], 'left_frames', loc[-1].split('_')[0] + '.png')

        # Load eagerly inside the context manager so the file handle is released,
        # and force three channels for the image processor.
        with Image.open(img_loc) as raw_frame:
            frame = raw_frame.convert('RGB')
        img = self.image_processor(frame, return_tensors="pt")['pixel_values'][0]

        # Annotation line format: "<question>|<answer>"
        fields = qa_line.split('|')
        question = fields[0]
        label = self.labels.index(str(fields[1]))

        return img, question, label


#############Model Architecture###############
import torch
from torch import nn
from transformers import VisualBertModel, VisualBertConfig

# Module-level compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


import math
import torch
import torch.nn as nn

from transformers.utils import logging
from transformers import VisualBertPreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer

'''
visualBertModule Embedding module for text and visual
'''
class VisualBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

        # For Visual Features
        # Token type and position embedding for image features
        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        if config.special_visual_initialize:
            # Start the visual tables as copies of the text tables.
            self.visual_token_type_embeddings.weight.data = nn.Parameter(self.token_type_embeddings.weight.data.clone(), requires_grad=True)
            self.visual_position_embeddings.weight.data = nn.Parameter(self.position_embeddings.weight.data.clone(), requires_grad=True)

        # Projects visual features from visual_embedding_dim to hidden_size.
        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)


    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, visual_embeds=None, visual_token_type_ids=None, image_text_alignment=None):
        """
        Embed the text tokens and, when ``visual_embeds`` is given, append the
        embedded visual tokens along dim 1; then apply LayerNorm and dropout.

        Either ``input_ids`` or ``inputs_embeds`` must be provided. Missing
        ``position_ids`` default to 0..seq_length-1, missing ``token_type_ids``
        to zeros and missing ``visual_token_type_ids`` to ones.

        Raises ValueError when ``image_text_alignment`` covers fewer visual
        positions than ``visual_embeds`` has.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings

        # Absolute Position Embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings

        if visual_embeds is not None:
            if visual_token_type_ids is None:
                visual_token_type_ids = torch.ones(visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device)

            visual_embeds = self.visual_projection(visual_embeds)
            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)

            if image_text_alignment is not None:
                # image_text_alignment = Batch x image_length x alignment_number.
                # Each element denotes the position of the word corresponding to the image feature. -1 is the padding value.

                dtype = token_type_embeddings.dtype
                image_text_alignment_mask = (image_text_alignment != -1).long()
                # Get rid of the -1.
                image_text_alignment = image_text_alignment_mask * image_text_alignment

                # Batch x image_length x alignment length x dim
                visual_position_embeddings = self.position_embeddings(image_text_alignment)
                visual_position_embeddings *= image_text_alignment_mask.to(dtype=dtype).unsqueeze(-1)
                visual_position_embeddings = visual_position_embeddings.sum(2)

                # We want to average along the alignment_number dimension.
                image_text_alignment_mask = image_text_alignment_mask.to(dtype=dtype).sum(2)

                if (image_text_alignment_mask == 0).sum() != 0:
                    image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid divide by zero error
                    # No module-level `logger` exists in this file; obtain one from the
                    # `logging` helper imported from transformers.utils instead of
                    # raising NameError here.
                    logging.get_logger(__name__).warning("Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero error.")

                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)

                visual_position_ids = torch.zeros(*visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device)

                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.
                if visual_position_embeddings.size(1) != visual_embeds.size(1):
                    if visual_position_embeddings.size(1) < visual_embeds.size(1):
                        raise ValueError(
                            f"Visual position embeddings length: {visual_position_embeddings.size(1)} "
                            f"should be the same as `visual_embeds` length: {visual_embeds.size(1)}"
                        )
                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]

                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(visual_position_ids)
            else:
                # Without alignment information every visual token uses position id 0.
                visual_position_ids = torch.zeros(*visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device)
                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)

            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings
            # Shape trace kept from the original notebook.
            print(embeddings.shape, visual_embeddings.shape)
            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


'''
VisualBertEncoder SelfAttention sub-sub-module
'''
class VisualBertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with query/key/value projections."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        n_heads = config.num_attention_heads

        indivisible = hidden_size % n_heads != 0
        if indivisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = n_heads
        self.attention_head_size = int(hidden_size / n_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Projection layers are created in query, key, value order.
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and move heads to dim 1."""
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        """
        Return ``(context,)`` or ``(context, attention_probs)`` when
        ``output_attentions`` is truthy.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores: query-key dot products scaled by sqrt(head size).
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)

        # The mask, when given, is added to the scores before the softmax.
        if attention_mask is not None:
            scores = scores + attention_mask

        # Normalise to probabilities, then apply dropout to the probabilities
        # (this drops whole tokens to attend to, as in the original Transformer).
        probs = self.dropout(scores.softmax(dim=-1))

        # Optionally mask heads.
        if head_mask is not None:
            probs = probs * head_mask

        # Weighted sum of values, then merge the heads back into one dim.
        context = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        if output_attentions:
            return (context, probs)
        return (context,)


'''
VisualBertEncoder SelfAttention output sub-sub-module
'''
class VisualBertSelfOutput(nn.Module):
    """Dense projection + dropout, followed by a residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, add ``input_tensor`` as the residual, normalise."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)


'''
VisualBertEncoder Attention sub-module
'''
class VisualBertAttention(nn.Module):
    """Self-attention followed by its output projection, with head-pruning support."""

    def __init__(self, config):
        super().__init__()
        self.self = VisualBertSelfAttention(config)
        self.output = VisualBertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projection layers (no-op for an empty collection)."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads)

        # Prune the three input projections, then the matching input columns of the output projection.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping consistent with the reduced head count.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        """Return ``(attention_output,)`` plus the attention probabilities when they were requested."""
        context, *extras = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        attention_output = self.output(context, hidden_states)
        return (attention_output, *extras)


'''
Cross-channel sub-layer module
'''
class Mlp(nn.Module):
    """Two hidden_size -> hidden_size linear layers with a GELU in between."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.fc1 = nn.Linear(width, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        """Apply fc1, GELU and fc2 along the last dimension."""
        return self.fc2(self.act(self.fc1(x)))


'''
cross-token sub-layer module
'''
class ResMLP_BLocks(nn.Module):
    """
    ResMLP block: a linear layer applied across the token dimension, then an
    MLP applied across the channel dimension, each followed by a residual add
    and LayerNorm.
    """

    def __init__(self, config, token_size):
        super().__init__()
        norm_eps = config.layer_norm_eps
        self.linear_patches = nn.Linear(token_size, token_size)  # mixes tokens
        self.layerNorm1 = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.mlp_channels = Mlp(config)                           # mixes channels
        self.layerNorm2 = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(self, x):
        """``x`` must have ``token_size`` entries along dim 1 for the token-mixing linear layer."""
        # Swap dims 1 and 2 so the linear layer acts across tokens, then swap back.
        token_mixed = self.linear_patches(x.transpose(1, 2)).transpose(1, 2)
        x = self.layerNorm1(x + token_mixed)

        channel_mixed = self.mlp_channels(x)
        return self.layerNorm2(x + channel_mixed)


'''
VisualBertModel layer module
'''
class VisualBertResMLPLayer(nn.Module):
    """One encoder layer: self-attention followed by a ResMLP block."""

    def __init__(self, config, token_size):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = VisualBertAttention(config)
        self.output = ResMLP_BLocks(config, token_size)

    def feed_forward_chunk(self, attention_output):
        """Run the ResMLP block on (a chunk of) the attention output."""
        return self.output(attention_output)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        """Return ``(layer_output,)`` plus the attention probabilities when they were requested."""
        attn_results = self.attention(hidden_states, attention_mask, head_mask, output_attentions=output_attentions)

        # The ResMLP block is routed through the transformers chunking helper.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attn_results[0],
        )

        # Anything after the first element is the optional attention probabilities.
        return (layer_output,) + attn_results[1:]


'''
VisualBertModel VisualBertEncoder module
'''
class VisualBertResMLPEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` VisualBertResMLPLayer modules."""

    def __init__(self, config, token_size):
        super().__init__()
        self.config = config
        layers = [VisualBertResMLPLayer(config, token_size) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(layers)
        self.gradient_checkpointing = False

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True):
        """
        Run every layer in order.

        Returns a ``BaseModelOutput`` when ``return_dict`` is truthy, otherwise
        a tuple holding the final hidden state and whichever of the
        hidden-state / attention collections were requested.
        """
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None

        def bind_layer(module):
            # Bind the layer now so a later checkpoint recomputation calls the right module.
            def run(*inputs):
                return module(*inputs, output_attentions)
            return run

        for index, layer_module in enumerate(self.layer):
            if hidden_trace is not None:
                hidden_trace.append(hidden_states)

            layer_head_mask = None if head_mask is None else head_mask[index]

            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(bind_layer(layer_module), hidden_states, attention_mask, layer_head_mask)
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            if attention_trace is not None:
                attention_trace.append(layer_outputs[1])

        # The output of the last layer closes the hidden-state collection.
        if hidden_trace is not None:
            hidden_trace.append(hidden_states)

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_self_attentions = None if attention_trace is None else tuple(attention_trace)

        if return_dict:
            return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions)

        candidates = [hidden_states, all_hidden_states, all_self_attentions]
        return tuple(item for item in candidates if item is not None)



'''
VisualBertModel VisualBertPooler module
'''
class VisualBertPooler(nn.Module):
    """Pool a sequence by passing its first token through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # "Pooling" here means taking the hidden state of the first token only.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))


class VisualBertResMLPModel(VisualBertPreTrainedModel):
    """
    VisualBert ResMLP model
    Incorporates
        (a) VisualBert for self-attention between visual and text token
        (b) ResMLP to enforce interaction among all visual and text tokens.

    ``token_size`` is forwarded to the ResMLP blocks, whose token-mixing
    linear layer is sized by it, so the sequence fed to the encoder must have
    that many tokens.
    """

    def __init__(self, config, token_size = 26, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = VisualBertEmbeddings(config)
        self.encoder = VisualBertResMLPEncoder(config, token_size)

        self.pooler = VisualBertPooler(config) if add_pooling_layer else None

        self.bypass_transformer = config.bypass_transformer

        if self.bypass_transformer:
            # `VisualBertLayer` is neither defined nor imported in this file; take the
            # stock Hugging Face layer so enabling this option no longer raises NameError.
            # NOTE(review): confirm the stock layer (not VisualBertResMLPLayer) is intended.
            from transformers.models.visual_bert.modeling_visual_bert import VisualBertLayer
            self.additional_layer = VisualBertLayer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding table."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                visual_embeds=None, visual_attention_mask=None, visual_token_type_ids=None, image_text_alignment=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        """
        Embed text (and optional visual) tokens, run the encoder and pool.

        Returns a ``BaseModelOutputWithPooling`` when ``return_dict`` resolves to
        True, otherwise ``(sequence_output, pooled_output, *extra encoder outputs)``.

        Raises ValueError when both or neither of ``input_ids`` /
        ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if visual_embeds is not None:
            visual_input_shape = visual_embeds.size()[:-1]

        # Missing masks default to "attend to everything".
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        if visual_embeds is not None and visual_attention_mask is None:
            visual_attention_mask = torch.ones(visual_input_shape, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if visual_embeds is not None:
            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(combined_attention_mask, [batch_size, input_shape + visual_input_shape], device)

        else:
            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, [batch_size, input_shape], device)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds,
                                            visual_embeds=visual_embeds, visual_token_type_ids=visual_token_type_ids, image_text_alignment=image_text_alignment)
        print('embedding_output:', embedding_output.shape)

        if self.bypass_transformer and visual_embeds is not None:
            print('enter bypass_transformer')
            # seq_length is the text length whether the text came in as ids or embeddings.
            text_length = seq_length
            text_embedding_output = embedding_output[:, :text_length, :]
            visual_embedding_output = embedding_output[:, text_length:, :]

            # Keep only the text columns of the mask. Slicing the query dimension (instead
            # of indexing it with an integer) works for the broadcastable size-1 dimension.
            text_extended_attention_mask = extended_attention_mask[:, :, :text_length, :text_length]
            print('text_embedding_output:', text_embedding_output.shape, 'visual_embedding_output:', visual_embedding_output.shape)
            # Same variable name as the other branch so the return statements below work for both.
            encoder_outputs = self.encoder(text_embedding_output, attention_mask=text_extended_attention_mask, output_attentions=output_attentions,
                                            output_hidden_states=output_hidden_states, return_dict=return_dict)
            sequence_output = encoder_outputs[0]
            print('ResMLP out:', sequence_output.shape)

            concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1)
            # The additional layer returns a tuple; its first element is the hidden states.
            sequence_output = self.additional_layer(concatenated_input, extended_attention_mask)[0]
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        else:
            print('else bypass_transformer')
            encoder_outputs = self.encoder(embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask,
                                            output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
            sequence_output = encoder_outputs[0]
            print('ResMLP out:', sequence_output.shape)

            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions)

'''
VisualBertResMLP Classification Model
'''
class VisualBertResMLPClassification(nn.Module):
    """
    VQA classifier: ResNet-18 image features + VisualBertResMLP encoder + linear head.

    :param vocab_size: tokenizer vocabulary size passed to VisualBertConfig.
    :param layers: number of encoder layers.
    :param n_heads: number of attention heads.
    :param num_class: number of output classes of the linear head.
    :param token_size: sequence length passed to the ResMLP token-mixing layers.
    """
    def __init__(self, vocab_size, layers, n_heads, num_class = 10, token_size = 26):
        super(VisualBertResMLPClassification, self).__init__()
        VBconfig = VisualBertConfig(vocab_size= vocab_size, visual_embedding_dim = 512, num_hidden_layers = layers, num_attention_heads = n_heads, hidden_size = 2048)
        self.VisualBertResMLPEncoder = VisualBertResMLPModel(VBconfig, token_size = token_size)
        self.classifier = nn.Linear(VBconfig.hidden_size, num_class)

        # Explicit weights enum instead of the deprecated boolean `weights=True`
        # (which torchvision maps to these same ImageNet weights).
        self.img_feature_extractor = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Replace the classification layer with a pass-through so the network
        # outputs its pooled features (fed to the encoder as visual_embedding_dim=512).
        self.img_feature_extractor.fc = nn.Identity()


    def forward(self, inputs, visual_embeds):
        """
        :param inputs: mapping with 'input_ids', 'token_type_ids' and
            'attention_mask' tensors (e.g. tokenizer output). It is not modified.
        :param visual_embeds: image batch fed to the ResNet feature extractor;
            assumed to already be on the model's device - TODO confirm with caller.
        :return: class logits from the linear head.
        """
        # One visual token per image: add a token dimension at position 1.
        visual_embeds = self.img_feature_extractor(visual_embeds).unsqueeze(1)

        # Use the device the features actually live on rather than the global `device`.
        target_device = visual_embeds.device
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long, device=target_device)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float, device=target_device)
        print('visual_embeds:', visual_embeds.shape, visual_attention_mask.shape)

        # Work on a copy so the caller's mapping is not polluted with visual entries.
        encoder_inputs = dict(inputs)
        for text_key in ('input_ids', 'token_type_ids', 'attention_mask'):
            encoder_inputs[text_key] = encoder_inputs[text_key].to(target_device)

        # append visual features to text
        encoder_inputs.update({
                        "visual_embeds": visual_embeds,
                        "visual_token_type_ids": visual_token_type_ids,
                        "visual_attention_mask": visual_attention_mask,
                        "output_attentions": True
                        })

        print('visual_embeds:', visual_embeds.shape, 'input_ids:', encoder_inputs['input_ids'].shape)

        # Encoder output
        outputs = self.VisualBertResMLPEncoder(**encoder_inputs)

        # Classification layer
        outputs = self.classifier(outputs['pooler_output'])
        return outputs


#############Training Script###############
import os
import sys
import argparse
from torch import nn
from torch import optim
import torch.utils.data
import torch.nn.functional as F
from torch.utils.data  import DataLoader
from transformers import BertTokenizer
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

def get_arg():
    """Build the argument parser and return the parsed arguments.

    When running under a Jupyter kernel (``ipykernel`` is loaded) an empty
    argument list is parsed so the kernel's own ``sys.argv`` is ignored and
    every option takes its default; otherwise the process command line is used.

    :return: ``argparse.Namespace`` with the model and training options.
    """

    def _str2bool(value):
        # argparse hands over raw strings; without this every non-empty value
        # (including the literal "False") would be truthy.
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', '1', 'yes', 'y'):
            return True
        if lowered in ('false', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean (True/False), got %r' % (value,))

    parser = argparse.ArgumentParser(description='VisualQuestionAnswerClassification')

    # VB Model parameters
    parser.add_argument('--n_heads',        type=int,   default=8,                                  help='Multi-head attention.')
    parser.add_argument('--encoder_layers', type=int,   default=6,                                  help='the number of layers of encoder in Transformer.')

    # Training parameters
    parser.add_argument('--epochs',         type=int,   default=2,                                 help='number of epochs to train for (if early stopping is not triggered).') #80, 26
    parser.add_argument('--batch_size',     type=int,   default=64,                                 help='batch_size')
    parser.add_argument('--workers',        type=int,   default=1,                                  help='for data-loading; right now, only 1 works with h5pys.')
    parser.add_argument('--lr',             type=float, default=0.00001,                           help='0.000005, 0.00001, 0.000005')
    parser.add_argument('--checkpoint_dir', default= 'checkpoints/VB_RN18',            help='med_vqa_c/m18/c80/m18_vid/c80_vid') #clf_v1_2_1x1/med_vqa_c3
    # type=int so values given on the command line are integers like the defaults.
    parser.add_argument('--question_len',   type=int,   default= 25,                                help='25')
    parser.add_argument('--num_class',      type=int,   default= 2,                                 help='number of answer classes')
    parser.add_argument('--validate',       type=_str2bool, default=False,                          help='When only validation required False/True')

    if 'ipykernel' in sys.modules:
        args = parser.parse_args([])
    else:
        args = parser.parse_args()
    return args

def train(args, train_dataloader, model, criterion, optimizer, epoch, tokenizer, device):
    """Run one training epoch and report its metrics.

    :param args: parsed options; ``args.question_len`` is the fixed token length.
    :param train_dataloader: iterable yielding ``(imgs, questions, labels)`` batches.
    :param model: network called as ``model(inputs, imgs)``.
    :param criterion: loss called as ``criterion(outputs, labels)``.
    :param optimizer: optimizer stepped once per batch.
    :param epoch: epoch number, used only in the printed summary.
    :param tokenizer: callable turning a list of question strings into model inputs.
    :param device: device the images and labels are moved to.
    :return: accuracy over the whole epoch as computed by ``calc_acc``.
    :raises ValueError: if the dataloader yields no batches.
    """

    model.train()
    total_loss = 0.0
    # Collect per-batch results and concatenate once at the end instead of
    # re-concatenating the growing tensors on every iteration.
    true_chunks = []
    pred_chunks = []

    for imgs, q, labels in train_dataloader:
        # truncation=True keeps every sequence at exactly question_len tokens;
        # padding alone leaves longer questions ragged and tensor creation fails.
        inputs = tokenizer(list(q), padding="max_length", truncation=True,
                           max_length=args.question_len, return_tensors="pt")
        imgs, labels = imgs.to(device), labels.to(device)

        outputs = model(inputs, imgs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # softmax is monotonic, so the predicted class is the argmax of the logits.
        pred_chunks.append(outputs.detach().argmax(dim=1).cpu())
        true_chunks.append(labels.detach().cpu())

    if not true_chunks:
        raise ValueError('train_dataloader yielded no batches')

    label_true = torch.cat(true_chunks, 0)
    label_pred = torch.cat(pred_chunks, 0)

    # loss and acc
    acc = calc_acc(label_true, label_pred)
    # Class-wise accuracy is still evaluated as before, but it is not reported.
    c_acc = calc_classwise_acc(label_true, label_pred)
    precision, recall, fscore = calc_precision_recall_fscore(label_true, label_pred)
    # NOTE: the printed loss is the sum of the per-batch losses, not their mean.
    print('Train: epoch: %d loss: %.6f | Acc: %.6f | Precision: %.6f | Recall: %.6f | FScore: %.6f' %(epoch, total_loss, acc, precision, recall, fscore))
    return acc


def validate(args, val_loader, model, criterion, epoch, tokenizer, device, save_output = False):
    """Evaluate the model on a validation loader without updating weights.

    :param args: parsed options; ``args.question_len`` is the fixed token length.
    :param val_loader: iterable yielding ``(imgs, questions, labels)`` batches.
    :param model: network called as ``model(inputs, imgs)``.
    :param criterion: loss called as ``criterion(outputs, labels)``.
    :param epoch: epoch number, used only in the printed summary.
    :param tokenizer: callable turning a list of question strings into model inputs.
    :param device: device the images and labels are moved to.
    :param save_output: accepted for interface compatibility; currently unused.
    :return: tuple ``(acc, c_acc, precision, recall, fscore)``; ``c_acc`` is
        always ``0.0`` because class-wise accuracy is not computed here.
    :raises ValueError: if the loader yields no batches.
    """

    model.eval()
    total_loss = 0.0
    # Collect per-batch results and concatenate once at the end.
    true_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for imgs, q, labels in val_loader:
            # truncation=True keeps every sequence at exactly question_len tokens;
            # padding alone leaves longer questions ragged and tensor creation fails.
            inputs = tokenizer(list(q), padding="max_length", truncation=True,
                               max_length=args.question_len, return_tensors="pt")
            imgs, labels = imgs.to(device), labels.to(device)

            # model forward pass
            outputs = model(inputs, imgs)

            # loss
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # softmax is monotonic, so the predicted class is the argmax of the logits.
            pred_chunks.append(outputs.argmax(dim=1).cpu())
            true_chunks.append(labels.cpu())

    if not true_chunks:
        raise ValueError('val_loader yielded no batches')

    label_true = torch.cat(true_chunks, 0)
    label_pred = torch.cat(pred_chunks, 0)

    acc = calc_acc(label_true, label_pred)
    c_acc = 0.0
    precision, recall, fscore = calc_precision_recall_fscore(label_true, label_pred)
    # NOTE: the printed loss is the sum of the per-batch losses, not their mean.
    print('Test: epoch: %d loss: %.6f | Acc: %.6f | Precision: %.6f | Recall: %.6f | FScore: %.6f' %(epoch, total_loss, acc, precision, recall, fscore))

    return (acc, c_acc, precision, recall, fscore)

if __name__ == '__main__':
    # Entry point: build the datasets and model, then train and validate each
    # epoch, saving a checkpoint whenever validation accuracy does not drop.
    args = get_arg()
    # The checkpoint directory comes from --checkpoint_dir (same default as
    # before) instead of being overwritten with a hard-coded path here.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    seed_everything()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    start_epoch = 1
    best_epoch = [0]
    best_results = [0.0]
    epochs_since_improvement = 0
    final_args = { "n_heads": args.n_heads, "encoder_layers": args.encoder_layers}
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]
    # Overrides the parser default for this dataset.
    # presumably the number of answer classes in EndoVis18 classification - TODO confirm
    args.num_class = 18

    folder_head = 'EndoVis-18-VQA/seq_'
    folder_tail = '/vqa/Classification/*.txt'

    train_dataset = EndoVis18VQAGPTClassification(train_seq, folder_head, folder_tail)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size= args.batch_size, shuffle=True, num_workers=2)
    val_dataset = EndoVis18VQAGPTClassification(val_seq, folder_head, folder_tail)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size= args.batch_size, shuffle=False, num_workers=2)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # model = VisualBertClassification(vocab_size=len(tokenizer), layers=args.encoder_layers, n_heads=args.n_heads, num_class = args.num_class)
    # model = VisualBertClassification_V3(vocab_size=len(tokenizer), layers=args.encoder_layers, n_heads=args.n_heads, num_class = args.num_class)
    model = VisualBertResMLPClassification(vocab_size=len(tokenizer), layers=args.encoder_layers, n_heads=args.n_heads, num_class = args.num_class)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Epochs are numbered from start_epoch (1), so the upper bound must be
    # epochs + 1 to actually train for `args.epochs` epochs.
    for epoch in range(start_epoch, args.epochs + 1):

        # Decay the learning rate after every 5 consecutive epochs without improvement.
        if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0:
            adjust_learning_rate(optimizer, 0.8)

        train_acc = train(args, train_dataloader=train_dataloader, model = model, criterion=criterion, optimizer=optimizer, epoch=epoch, tokenizer = tokenizer, device = device)
        test_acc, test_c_acc, test_precision, test_recall, test_fscore = validate(args, val_loader=val_dataloader, model = model, criterion=criterion, epoch=epoch, tokenizer = tokenizer, device = device)

        # A tie with the best accuracy so far also counts as an improvement.
        if test_acc >= best_results[0]:
            print('Best Epoch:', epoch)
            epochs_since_improvement = 0
            best_results[0] = test_acc
            best_epoch[0] = epoch
            save_clf_checkpoint(args.checkpoint_dir, epoch, epochs_since_improvement, model, optimizer, best_results[0], final_args)
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))




The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Total files: 1560 | Total question: 9014
Total files: 447 | Total question: 2769




visual_embeds: torch.Size([64, 1, 512]) torch.Size([64, 1])
visual_embeds: torch.Size([64, 1, 512]) input_ids: torch.Size([64, 25])
torch.Size([64, 25, 2048]) torch.Size([64, 1, 2048])
embedding_output: torch.Size([64, 26, 2048])
else bypass_transformer
ResMLP out: torch.Size([64, 26, 2048])
visual_embeds: torch.Size([64, 1, 512]) torch.Size([64, 1])
visual_embeds: torch.Size([64, 1, 512]) input_ids: torch.Size([64, 25])
torch.Size([64, 25, 2048]) torch.Size([64, 1, 2048])
embedding_output: torch.Size([64, 26, 2048])
else bypass_transformer
ResMLP out: torch.Size([64, 26, 2048])
visual_embeds: torch.Size([64, 1, 512]) torch.Size([64, 1])
visual_embeds: torch.Size([64, 1, 512]) input_ids: torch.Size([64, 25])
torch.Size([64, 25, 2048]) torch.Size([64, 1, 2048])
embedding_output: torch.Size([64, 26, 2048])
else bypass_transformer
ResMLP out: torch.Size([64, 26, 2048])
visual_embeds: torch.Size([64, 1, 512]) torch.Size([64, 1])
visual_embeds: torch.Size([64, 1, 512]) input_ids: torch.Size

KeyboardInterrupt: 

In [None]:
# Build a VisualBERT configuration from the training options so its resolved
# fields can be inspected; the bare last expression displays it.
VBconfig = VisualBertConfig(
    vocab_size=len(tokenizer),
    visual_embedding_dim=512,
    num_hidden_layers=args.encoder_layers,
    num_attention_heads=args.n_heads,
    hidden_size=2048,
)
VBconfig

VisualBertConfig {
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "bypass_transformer": false,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "visual_bert",
  "num_attention_heads": 8,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "special_visual_initialize": true,
  "transformers_version": "4.44.2",
  "type_vocab_size": 2,
  "visual_embedding_dim": 512,
  "vocab_size": 30522
}

In [None]:
# Show the visual embedding width stored in the configuration.
VBconfig.visual_embedding_dim

512

In [None]:
# Show the configuration's bypass_transformer flag.
VBconfig.bypass_transformer

False

In [None]:
import requests
from io import BytesIO

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
# Download the sample image with a timeout and fail loudly on HTTP errors,
# rather than handing an error page or a stalled stream to the image decoder.
response = requests.get(url, timeout=30)
response.raise_for_status()
# Decode from the fully downloaded bytes; `response.raw` bypasses content
# decoding and leaves the connection open.
img = Image.open(BytesIO(response.content))

In [None]:
# Run the sample image through `image_processor` and display the shape of the
# first entry under 'pixel_values'.
out = image_processor(img)
out['pixel_values'][0].shape

(3, 224, 224)