## Reft + VLBart experiments

Does ReFT work for vision-language models? Let's find out on the VQA task.

I replicated the following code mostly from src/multitask.py.

### 1.1 Reft Model Replica for VLBart

In [1]:
from trainer_base import TrainerBase
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import collections
from pathlib import Path
from packaging import version

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import logging
import shutil
from pprint import pprint, pformat
from copy import deepcopy

from param import parse_args


import vqa
import gqa
import nlvr
import vcr
import caption
import mmt
import refcoco

import multitask_data

from utils import LossMeter, set_global_logging_level
from dist_utils import reduce_dict
import wandb

from vis_encoder import get_vis_encoder
from transformers.models.t5.modeling_t5 import T5LayerNorm
import modeling_t5
import modeling_bart
from clip.model import VisualAdapter
from ddp_fix import ddp_forward

from adapters import AdapterController, MetaLayersAdapterController

# Directory two levels above the current working directory
# (presumably the project root when the notebook runs from a nested folder — TODO confirm).
proj_dir = os.path.dirname(os.path.dirname(os.getcwd()))


# Mixed-precision backend flags, consulted by the Trainer below.
_use_native_amp = False
_use_apex = False

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex.
# torch.cuda.amp only exists from 1.6 on; older versions fall back to NVIDIA
# Apex, but only when Apex is actually installed.
if version.parse(torch.__version__) < version.parse("1.6"):
    from transformers.file_utils import is_apex_available
    if is_apex_available():
        from apex import amp
        # Enable the Apex path only once `amp` has been imported; otherwise
        # later `amp.initialize` / `amp.scale_loss` calls would hit a NameError.
        _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

In [2]:
import math
import random
from dataclasses import dataclass

import transformers

from transformers.models.bart.modeling_bart import (
    BartLearnedPositionalEmbedding,
    BartEncoderLayer,
    BartPretrainedModel,
    BartConfig,
    ACT2FN,
    shift_tokens_right, _make_causal_mask, _expand_mask
)

from my_transformers.modeling_bart import BartModel, BartForConditionalGeneration, BartDecoder, BartEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import copy

from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import logging
from transformers import BeamScorer, BeamSearchScorer

from adapters import (
    AdapterLayer, 
    AdapterController,
    OutputParallelAdapterLayer, 
    TaskEmbeddingController,
    AdapterLayersHyperNetController,
    AdapterLayersOneHyperNetController,
    MetaLayersAdapterController
)

from adapters.hypercomplex.layers import PHMLinear

from prompt import (
    PromptController,
)



# NOTE: `logging` at this point is `transformers.utils.logging` (imported in
# this cell), which rebinds the stdlib `logging` imported in the first cell.
logger = logging.get_logger(name=__name__)
from modeling_bart import VLBart

Here we replace the `VLBart` class from the VL-Adapter/DoRA codebase with `VLBartReft`, which wraps the model in a PyVene intervenable.

In [3]:
class VLBartReft(VLBart):
    """`VLBart` whose inner seq2seq model (`self.model`) is wrapped by pyreft.

    The intervenable is built from `config.reft_config` (a pyreft `ReftConfig`
    that the trainer attaches to the config). Its intervention parameters are
    re-registered on this module so they show up in `named_parameters()` and
    therefore reach the optimizer.
    """
    def __init__(self, config: BartConfig):
        super().__init__(config)
        from pyreft import get_reft_model
        # Wrap only the seq2seq model; `self.lm_head` stays outside the
        # intervenable and is applied in `forward`.
        self.intervenable = get_reft_model(self.model, config.reft_config)
        # print("Reft parameters:", self.intervenable.interventions)
        # self.intervenable.unfreeze_intervention_parameters()
        self.intervenable.print_trainable_parameters()
        # print("INTERVENABLE:", self.intervenable.model)

        # Unfreeze the PyVene intervention parameters and register them on this
        # module. `register_parameter` rejects names containing ".", so dots are
        # replaced with "#" (the optimizer's `no_decay` list matches the
        # resulting "#unit#pos" substring).
        for k, v in self.intervenable.unfreeze_intervention_parameters().items():
            n = k.replace(".", "#")
            print(n)
            self.register_parameter(n, v)
    
    def forward(
        self,
        input_ids=None,
        attention_mask=None,

        vis_inputs=None,
        vis_attention_mask=None,

        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        task=None,

        reduce_loss=False,
        intervention_locations = None,


        **kwargs,
    ):
        """Run the seq2seq model (optionally through ReFT) and the LM head.

        Arguments mirror `VLBart.forward`, plus:
            reduce_loss: if True, the cross-entropy is averaged over non-ignored
                (!= -100) label positions. If False (the default), it uses
                reduction='none' and returns one loss value per flattened label
                position, so callers can mask or weight it themselves.
            intervention_locations: when not None, the forward pass goes through
                `self.intervenable` with these unit locations. Otherwise the plain
                `self.model` is called. Dims 0 and 1 are swapped before use
                (presumably (batch, num_interventions, positions) ->
                (num_interventions, batch, positions) — TODO confirm).

        Returns:
            A `Seq2SeqLMOutput` when `return_dict` is truthy. Otherwise a tuple
            `(lm_logits, *outputs[1:])`, prefixed with the loss when `labels`
            is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Teacher forcing: derive decoder inputs from the labels when not provided.
        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
        
        if intervention_locations is not None:
            # print("Intervention locs not None")
            # Pyvene forward pass
            # Work on a detached copy with the first two dims swapped.
            intervention_locations = intervention_locations.clone().detach().permute(1, 0, 2)
            # The intervenable returns a pair; only the second element is used
            # (presumably (original_output, intervened_output) — TODO confirm).
            # `return_dict` inside the base dict controls the wrapped model's
            # output type; the outer `return_dict=False` is for pyvene itself.
            # NOTE(review): `labels` is handed to pyvene, presumably forwarded to
            # the wrapped model; the loss below is recomputed from the logits anyway.
            _, outputs = self.intervenable(
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "vis_inputs": vis_inputs,
                    "vis_attention_mask": vis_attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                    "decoder_attention_mask": decoder_attention_mask,
                    "encoder_outputs": encoder_outputs,
                    "past_key_values": past_key_values,
                    "inputs_embeds": inputs_embeds,
                    "decoder_inputs_embeds": decoder_inputs_embeds,
                    "output_attentions": output_attentions,
                    "output_hidden_states": output_hidden_states,
                    "task": task,
                    "return_dict": return_dict,
                },
                unit_locations={"sources->base": (
                    None,
                    intervention_locations
                )},
                labels=labels,
                return_dict=False,
                subspaces=None,
                use_cache=use_cache,
            )
        else:
            # print("Intervention locs None")
            # Plain forward pass without interventions.
            outputs = self.model(
                input_ids,
                attention_mask=attention_mask,

                vis_inputs=vis_inputs,
                vis_attention_mask=vis_attention_mask,

                decoder_input_ids=decoder_input_ids,
                encoder_outputs=encoder_outputs,
                decoder_attention_mask=decoder_attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                task=task,
            )

        # print("Outputs:", outputs)
        # outputs[0] is the decoder's last hidden state; project it to the vocabulary.
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        if self.output_adapter is not None:
            lm_logits = lm_logits + self.output_adapter(outputs[0])
        
        masked_lm_loss = None
        # print("LOGITS:", lm_logits)
        # print("LABELS", labels)
        if labels is not None:
            # loss_fct = CrossEntropyLoss()
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                # Per-position loss; callers (e.g. VQA train_step) reshape and mask it.
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
            
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # if masked_lm_loss is not None and len(masked_lm_loss) > 1:
        #     masked_lm_loss = masked_lm_loss[0]
        # print("LOSS 0:", masked_lm_loss)

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )



This class is the same as `VLBartVQA` in vqa_model.py, except that it inherits from `VLBartReft` and uses PyVene for generation.

In [4]:
class VLBartVQA(VLBartReft):
    """VQA model on top of `VLBartReft`.

    Same as `VLBartVQA` in vqa_model.py, except that every forward pass is
    routed through the ReFT intervenable (via `intervention_locations`) and
    generation uses `self.intervenable.generate`.

    NOTE: `VLBartMultiTask` calls `train_step` / `test_step` unbound with its
    own instance as `self`, and that instance is not a `VLBartVQA`. These
    methods therefore must not call extra helper methods defined only on this
    class. `answer_head` / `bce_loss` are only created here, so the classifier
    path presumably only works on a real `VLBartVQA` instance — TODO confirm.
    """

    def __init__(self, config, num_answers=None, label2ans=None):
        super().__init__(config)

        if config.classifier:
            # Answer classifier applied to the last decoder hidden state.
            self.answer_head = nn.Sequential(
                nn.Linear(config.d_model, config.d_model * 2),
                nn.GELU(),
                nn.LayerNorm(config.d_model * 2),
                nn.Linear(config.d_model * 2, num_answers)
            )

        self.num_answers = num_answers
        self.label2ans = label2ans
        self.bce_loss = nn.BCEWithLogitsLoss()

    def train_step(self, batch, **kwargs):
        """Compute the training loss for one VQA batch.

        Args:
            batch: collated batch dict. After `self.vis_forward(batch, device)`
                it must provide 'task', 'vis_feats', 'input_ids', 'boxes' and
                'intervention_locations'. It also needs 'targets' in classifier
                mode, or 'target_ids' and 'scores' in generative mode.
            **kwargs: accepted so that `VLBartMultiTask.train_step(batch, **kwargs)`
                can forward its keyword arguments; currently unused.

        Returns:
            dict with a scalar 'loss' tensor.
        """
        device = next(self.parameters()).device

        batch = self.vis_forward(batch, device)
        task = batch["task"]

        vis_feats = batch['vis_feats'].to(device)
        input_ids = batch['input_ids'].to(device)
        vis_pos = batch['boxes'].to(device)
        intervention_locations = batch['intervention_locations'].to(device)

        if self.config.classifier:
            B = len(input_ids)

            # Two-token decoder prompt: [decoder_start, bos] for every sample.
            decoder_input_ids = torch.tensor(
                [self.config.decoder_start_token_id, self.config.bos_token_id],
                dtype=torch.long, device=device).unsqueeze(0).expand(B, 2)

            output = self(
                input_ids=input_ids,
                vis_inputs=(vis_feats, vis_pos),
                decoder_input_ids=decoder_input_ids,
                output_hidden_states=True,
                return_dict=True,
                task=task,
                intervention_locations=intervention_locations
            )

            target = batch['targets'].to(device)

            last_layer_hidden_state = output.decoder_hidden_states[-1]
            last_hidden_state = last_layer_hidden_state.view(B, -1, self.config.d_model)[:, -1]

            # [B, num_answers]
            logit = self.answer_head(last_hidden_state)

            loss = self.bce_loss(logit, target)

        else:
            lm_labels = batch["target_ids"].to(device)

            # reduce_loss defaults to False, so the loss is per label position.
            output = self(
                input_ids=input_ids,
                vis_inputs=(vis_feats, vis_pos),
                labels=lm_labels,
                return_dict=True,
                task=task,
                intervention_locations=intervention_locations
            )
            assert 'loss' in output

            lm_mask = (lm_labels != -100).float()
            B, L = lm_labels.size()

            loss = output['loss']

            # Mask padding (-100) positions, average per sample over valid tokens.
            loss = loss.view(B, L) * lm_mask

            loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)  # B

            # Weight each sample by its answer score, then average over the batch.
            loss = loss * batch['scores'].to(device=device)

            loss = loss.mean()

        result = {
            'loss': loss
        }

        return result

    @torch.no_grad()
    def test_step(self, batch, **kwargs):
        """Predict answers for one VQA batch (switches the model to eval mode).

        Args:
            batch: same keys as `train_step` needs for its inputs (labels unused).
            **kwargs: merged into the generation inputs passed to
                `self.intervenable.generate` (generative mode only).

        Returns:
            dict with 'pred_ans' (list of answer strings). In generative mode it
            also holds 'token_ids' (the generated ids).
        """
        self.eval()
        device = next(self.parameters()).device

        batch = self.vis_forward(batch, device)

        vis_feats = batch['vis_feats'].to(device)
        input_ids = batch['input_ids'].to(device)
        vis_pos = batch['boxes'].to(device)
        task = batch["task"]
        intervention_locations = batch['intervention_locations'].to(device)

        result = {}
        if self.config.classifier:
            B = len(input_ids)

            decoder_input_ids = torch.tensor(
                [self.config.decoder_start_token_id, self.config.bos_token_id],
                dtype=torch.long, device=device).unsqueeze(0).expand(B, 2)

            output = self(
                input_ids=input_ids,
                vis_inputs=(vis_feats, vis_pos),
                decoder_input_ids=decoder_input_ids,
                output_hidden_states=True,
                return_dict=True,
                task=task,
                intervention_locations=intervention_locations
            )

            last_layer_hidden_state = output.decoder_hidden_states[-1]
            last_hidden_state = last_layer_hidden_state.view(B, -1, self.config.d_model)[:, -1]

            # [B, num_answers]
            logit = self.answer_head(last_hidden_state)

            # Only the argmax index is needed; the max score is discarded.
            _, pred_ans_id = logit.max(1)
            pred_ans_id = pred_ans_id.cpu().numpy()
            pred_ans = [self.label2ans[ans_id] for ans_id in pred_ans_id]

            result['pred_ans'] = pred_ans

        else:
            # Same dim swap as VLBartReft.forward applies to intervention_locations.
            generation_args = {
                "base": {
                    "input_ids": input_ids,
                    "vis_inputs": (vis_feats, vis_pos),
                    "task": task,
                    **kwargs
                },
                "unit_locations": {"sources->base": (None,
                    intervention_locations.permute(1, 0, 2))},
                "intervene_on_prompt": True,
                "eos_token_id": self.tokenizer.eos_token_id,
                "early_stopping": True,
                "model": self,
            }
            # TODO: temperature, top_p, top_k
            _, output = self.intervenable.generate(**generation_args)
            generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)
            result['token_ids'] = output
            result['pred_ans'] = generated_sents

        return result



### 1.2 Multitask VLBart Model

In [5]:
class VLBartMultiTask(VLBartReft):
    """ReFT-wrapped VLBart that dispatches each step by `batch['task']`.

    Only 'vqa' is wired up. Its step functions are borrowed from `VLBartVQA`
    and called unbound with this instance as `self`. Any other task raises
    `ValueError` instead of silently returning None, which previously only
    failed later with an opaque error such as `'NoneType' object is not
    subscriptable`.
    """

    def __init__(self, config):
        super().__init__(config)

    def train_step(self, batch, **kwargs):
        """Dispatch a training step to the implementation for `batch['task']`."""
        task = batch['task']
        if task == 'vqa':
            return VLBartVQA.train_step(self, batch, **kwargs)
        raise ValueError(f"VLBartMultiTask.train_step: unsupported task {task!r}")

    def valid_step(self, batch, **kwargs):
        """Dispatch a validation step to the implementation for `batch['task']`."""
        task = batch['task']
        if task == 'vqa':
            # NOTE(review): `valid_step` is not defined on VLBartVQA in this
            # notebook; it presumably resolves through VLBart — TODO confirm.
            return VLBartVQA.valid_step(self, batch, **kwargs)
        raise ValueError(f"VLBartMultiTask.valid_step: unsupported task {task!r}")

    def test_step(self, batch, **kwargs):
        """Dispatch a test/prediction step to the implementation for `batch['task']`."""
        task = batch['task']
        if task == 'vqa':
            return VLBartVQA.test_step(self, batch, **kwargs)
        raise ValueError(f"VLBartMultiTask.test_step: unsupported task {task!r}")

### 1.3 Multitask VLBart Trainer

### 1.3.1 Reft Specific Trainer

Here we create a `ReftConfig` from which ReFT builds its interventions. We also set the weight decay of the ReFT parameters to 0.

In [6]:
from pyreft import ReftConfig, LoreftIntervention, TaskType

class ReftTrainer(TrainerBase):
    """`TrainerBase` with ReFT-specific config and optimizer construction.

    * `create_config` attaches a pyreft `ReftConfig` as `config.reft_config`,
      which `VLBartReft` uses to build its intervenable.
    * `create_optimizer_and_scheduler` puts the ReFT intervention parameters
      (registered under '#'-separated names) in the zero-weight-decay group.
    """

    def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True):
        super().__init__(
            args,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            train=train)

    @staticmethod
    def _loreft_representations(layers, rank, embed_dim, dropout):
        """Return one pyreft representation dict (a LoReFT on `block_output`) per entry of `layers`."""
        return [{
            "layer": l, "component": "block_output",
            "low_rank_dimension": rank,
            "intervention": LoreftIntervention(
                embed_dim=embed_dim, low_rank_dimension=rank,
                dropout=dropout, dtype=torch.float32, act_fn=None, device="cuda",
                add_bias=True
            )
        } for l in layers]

    def create_reft_config(self, config):
        """Build the pyreft `ReftConfig` that describes every LoReFT intervention.

        Reads these args:
            * `layers`: "all" or a ';'-separated list of layer indices.
            * `positions` and `share_weights`: see the duplication comment below.
            * `reft_rank` / `reft_dropout`: text interventions.
            * `reft_image_rank` / `reft_image_dropout`: image interventions.
            * `mid_dim`: the intervention `embed_dim`.

        A rank of -1 disables the corresponding group of interventions.
        """
        args = self.args
        layers = args.layers
        # ReFT layers - right now only "all" works properly
        # TODO: properly process "layers" when it is not "all"
        if layers != "all":
            layers = [int(l) for l in layers.split(";")]
        else:
            # TODO: verify config's hidden layers field
            layers = list(range(config.num_hidden_layers))
        # A '+' in positions (presumably two position groups, as in pyreft's
        # "f1+l1") with unshared weights needs one intervention per group, so
        # the layer list is duplicated.
        if '+' in self.args.positions and not args.share_weights:
            layers += layers

        image_rank = args.reft_image_rank
        text_rank = args.reft_rank
        embed_dim = args.mid_dim

        representations = []
        # Text interventions
        if text_rank != -1:
            representations += self._loreft_representations(
                layers, text_rank, embed_dim, args.reft_dropout)
        # Image interventions
        if image_rank != -1:
            representations += self._loreft_representations(
                layers, image_rank, embed_dim, args.reft_image_dropout)
        reft_config = ReftConfig(representations=representations)
        print(reft_config)
        return reft_config

    def create_config(self):
        """Create the base model config and attach `reft_config` to it."""
        config = super().create_config()
        setattr(config, "reft_config", self.create_reft_config(config))
        return config

    @staticmethod
    def _split_weight_decay(named_params, no_decay, weight_decay, lr=None):
        """Split (name, param) pairs into a decayed group and a zero-decay group.

        A parameter goes to the zero-decay group when any substring in
        `no_decay` occurs in its name. When `lr` is given, it is set on both groups.
        """
        named_params = list(named_params)
        decay_group = {
            "params": [p for n, p in named_params if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        }
        zero_decay_group = {
            "params": [p for n, p in named_params if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        }
        if lr is not None:
            decay_group["lr"] = lr
            zero_decay_group["lr"] = lr
        return [decay_group, zero_decay_group]

    def create_optimizer_and_scheduler(self):
        """Build the optimizer and a linear warmup/decay LR scheduler.

        Returns:
            (optim, lr_scheduler)
        """
        if self.verbose:
            print('Building Optimizer')

        from transformers.optimization import AdamW, get_linear_schedule_with_warmup

        # Added "#unit#pos" to `no_decay` to keep ReFT intervention's weight decay to 0.
        # Bias and "LayerNorm.weight" parameters also get no weight decay.
        # NOTE(review): HF BART names its norms `*_layer_norm` / `layernorm_embedding`,
        # which do not contain "LayerNorm.weight" — confirm those should be decayed.
        no_decay = ["bias", "LayerNorm.weight", "#unit#pos"]

        if 'adamw' in self.args.optim:

            if self.args.use_separate_optimizer_for_visual:

                # transformer's parameters (everything outside the vision encoder)
                non_vis_params = [
                    (n, p) for n, p in self.model.named_parameters() if "vis_encoder" not in n
                ]
                optimizer_grouped_parameters = self._split_weight_decay(
                    non_vis_params, no_decay, self.args.weight_decay, lr=self.args.lr)

                visn_model = self.model.vis_encoder
                if self.args.use_adam_for_visual:
                    vis_optimizer_grouped_parameters = self._split_weight_decay(
                        visn_model.named_parameters(), no_decay,
                        self.args.vis_weight_decay, lr=self.args.vis_lr)
                    optim = AdamW(
                        optimizer_grouped_parameters + vis_optimizer_grouped_parameters,
                        lr=self.args.lr,
                        # betas=(0.9, 0.98),
                        eps=self.args.adam_eps
                    )
                else:
                    # FusedOptimizer was used here without ever being imported
                    # (guaranteed NameError). Assumes it is defined in trainer_base,
                    # as in VL-Adapter — TODO confirm.
                    from trainer_base import FusedOptimizer

                    optim = AdamW(
                        optimizer_grouped_parameters, lr=self.args.lr, eps=self.args.adam_eps
                    )
                    vis_optim = torch.optim.SGD(
                        visn_model.parameters(),
                        self.args.vis_lr,
                        momentum=0,
                        weight_decay=self.args.vis_weight_decay
                    )

                    optim = FusedOptimizer([optim, vis_optim])

            else:
                optimizer_grouped_parameters = self._split_weight_decay(
                    self.model.named_parameters(), no_decay, self.args.weight_decay)
                optim = AdamW(optimizer_grouped_parameters,
                              lr=self.args.lr, eps=self.args.adam_eps)

        else:
            optimizer_grouped_parameters = self._split_weight_decay(
                self.model.named_parameters(), no_decay, self.args.weight_decay)
            optim = self.args.optimizer(optimizer_grouped_parameters, self.args.lr)

        batch_per_epoch = len(self.train_loader)
        t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs
        warmup_ratio = self.args.warmup_ratio
        warmup_iters = int(t_total * warmup_ratio)
        if self.verbose:
            print("Batch per epoch: %d" % batch_per_epoch)
            print("Total Iters: %d" % t_total)
            print('Warmup ratio:', warmup_ratio)
            print("Warm up Iters: %d" % warmup_iters)

        lr_scheduler = get_linear_schedule_with_warmup(optim, warmup_iters, t_total)

        return optim, lr_scheduler




### 1.3.2 Reft images trainer

The only changes in this class compared to `trainer_base.py` are moving the intervenable model to CUDA and unfreezing the ReFT parameters.

In [7]:
class Trainer(ReftTrainer):
    def __init__(self, args, train_loader=None, val_loader=None, test_loader=None, train=True):
        """Build tokenizer, ReFT-wrapped VLBart model, optimizer and (optionally) DDP.

        The setup matches the base multitask Trainer, with two additions: the
        ReFT intervenable is moved to the model's device, and its intervention
        parameters are unfrozen after the usual freeze/unfreeze pass.

        Raises:
            ValueError: if `args.backbone` is not a BART backbone, the only one
                wired up here. Previously this failed later with an
                UnboundLocalError on `model_class`.
        """
        super().__init__(
            args,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            train=train)

        if not self.verbose:
            set_global_logging_level(logging.ERROR, ["transformers"])

        model_kwargs = {}
        if 'bart' in args.backbone:
            model_class = VLBartMultiTask
        else:
            raise ValueError(
                f"Unsupported backbone {args.backbone!r}: only BART backbones are supported")

        config = self.create_config()
        self.tokenizer = self.create_tokenizer()

        if 'bart' in self.args.tokenizer:
            num_added_toks = 0
            if config.use_vis_order_embedding:
                additional_special_tokens = [f'<extra_id_{i}>' for i in range(100-1, -1, -1)] + \
                        [f'<vis_extra_id_{i}>' for i in range(100-1, -1, -1)]
                special_tokens_dict = {'additional_special_tokens': additional_special_tokens}
                num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict)

                config.default_obj_order_ids = self.tokenizer.convert_tokens_to_ids([f'<vis_extra_id_{i}>' for i in range(100)])

        self.model = self.create_model(model_class, config, **model_kwargs)

        # Grow the embedding matrix to cover any special tokens added above.
        if 't5' in self.args.tokenizer:
            self.model.resize_token_embeddings(self.tokenizer.vocab_size)
        elif 'bart' in self.args.tokenizer:
            self.model.resize_token_embeddings(self.model.model.shared.num_embeddings + num_added_toks)

        self.model.tokenizer = self.tokenizer
        if 't5' in self.args.tokenizer or 'bart' in self.args.tokenizer:
            self.model.true_id = self.tokenizer('true', add_special_tokens=False).input_ids[0]
            self.model.false_id = self.tokenizer('false', add_special_tokens=False).input_ids[0]

        if self.include_vis_encoder:
            import ast

            # train vision encoder end-to-end
            vis_encoder_type = self.args.feature_type.split("_")[-1]
            # literal_eval parses the same tuple/list literal as eval() did,
            # without executing arbitrary code from the command line.
            image_size = ast.literal_eval(self.args.image_size)[0]

            if self.args.use_vis_adapter:
                self.vis_encoder = get_vis_encoder(
                    backbone=vis_encoder_type,
                    image_size=image_size,
                    adapter_type=self.args.vis_adapter_type,
                    reduction_factor=self.args.vis_reduction_factor,
                    use_bn=not self.args.remove_bn_vis_adapter,
                )
            else:
                self.vis_encoder = get_vis_encoder(
                    backbone=vis_encoder_type,
                    image_size=image_size,
                    adapter_type=None,
                )

            print("include vision encoder")
            self.model.vis_encoder = self.vis_encoder
            print(self.model)
        # Load Checkpoint
        self.start_epoch = None
        if args.load is not None:
            ckpt_path = args.load
            self.load_checkpoint(ckpt_path)
        if self.args.from_scratch:
            self.init_weights()

        # GPU Options
        print(f'Model Launching at GPU {self.args.gpu}')
        if self.verbose:
            from time import time
            start = time()
        self.model = self.model.to(args.gpu)

        # ReFT-specific: put the intervenable on the same device as the wrapped
        # model, then unfreeze the intervention parameters after the generic
        # freeze/unfreeze pass (which would otherwise leave them frozen).
        self.model.intervenable.set_device(self.model.model.device)

        self.freeze_whole_model() # freeze whole parameters first
        self.unfreeze_parameters() # unfreeze selected parameters
        self.model.intervenable.unfreeze_intervention_parameters()
        self.percent_updated_parameters = self.print_trainable_params_percentage(self.model)

        # Optimizer
        if train:
            self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler()

            if self.args.fp16 and _use_native_amp:
                self.scaler = torch.cuda.amp.GradScaler()
            elif _use_apex:
                self.model, self.optim = amp.initialize(
                    self.model, self.optim, opt_level='O1', verbosity=self.verbose)

        if args.multiGPU:
            if args.distributed:
                self.model = DDP(self.model, device_ids=[args.gpu],
                                 find_unused_parameters=True
                                 )
        if self.verbose:
            print(f'It took {time() - start:.1f}s')

    def train(self):
        if self.verbose:
            vqa_loss_meter = LossMeter()
            refcoco_loss_meter = LossMeter()
            # best_eval_loss = 9595.
            quesid2ans = {}
            best_vqa_valid = 0.
            best_vqa_epoch = 0

            wandb.init(project=self.args.project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)
            wandb.log(
                {"percent of updated parameters (%)": self.percent_updated_parameters}
            )

            src_dir = os.path.dirname(os.getcwd())
            base_path = os.path.dirname(src_dir)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            self.partial_eval()

            if self.args.distributed:
                self.train_loader.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=250)

            epoch_results = {
                'loss': 0.,

            }

            task_counter = {
                'vqa': 0,
            }

            # vqa
            quesid2ans = {}
            train_acc = 0.
            # train_acc_steps = int(len(self.train_loader) * 0.05)
            # last_acc_step = 0

            for step_i, batch in enumerate(self.train_loader):

                # print(f'GPU{self.args.gpu} inside training loop')
                # print(batch)
                task = batch['task']
                # if self.verbose:
                #     print('task', task)
                task_counter[task] += 1

                batch['log_train_accuracy'] = self.args.log_train_accuracy

                # self.optim.zero_grad()
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = ddp_forward(self.model, batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = ddp_forward(self.model, batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                if self.args.track_z:
                    reg_loss = 0
                    layer_num = 0
                    for name, sub_module in self.model.named_modules():
                        if isinstance(sub_module, (AdapterController)):
                            reg_loss += ((sub_module.adapters[task].z) ** 2).mean()
                            layer_num += 1

                        if isinstance(sub_module, (MetaLayersAdapterController)):
                            reg_loss += ((sub_module.z) ** 2).mean()
                            layer_num += 1

                    reg_loss = reg_loss / layer_num

                    loss = loss + self.args.lambda_z * reg_loss

                    # wandb.log(
                    #     {"Reg loss": reg_loss.item()},
                    #     step=global_step
                    # )

                # print(f'GPU{self.args.gpu} after loss')

                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # print(f'GPU{self.args.gpu} after backward')

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(
                            self.optim), self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                if self.lr_scheduler:
                    if version.parse(torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr


                # self.train_step_post_hook(result)

                if self.args.log_train_accuracy and task == 'refcoco':
                    correct = results['correct']
                    n_correct += sum(correct)
                    n_total += len(correct)

                if self.verbose:
                    if task == 'vqa':
                        vqa_loss_meter.update(loss.item())

                    desc_str = f'Epoch {epoch} | LR {lr:.6f}'

                    desc_str += f" |"
                    if 'vqa' in self.args.tasks:
                        desc_str += f" VQA {task_counter['vqa']}"
                    if len(vqa_loss_meter) > 0:
                        desc_str += f' | VQA Loss {vqa_loss_meter.val:4f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()
                # self.save("Epoch%02d" % (epoch + 1))

            if self.args.log_train_accuracy:
                train_score_dict = {
                    'n_correct': n_correct,
                    'n_total': n_total
                }
                train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

            if self.verbose:
                # Validation
                log_str = ''
                wandb_log_dict = {}

                if 'vqa' in self.args.tasks:
                    # VQA
                    vqa_val_loader = self.val_loader['vqa']
                    score_dict = self.vqa_evaluate(vqa_val_loader)
                    valid_score = score_dict['topk_score'] * 100.
                    valid_score_raw = score_dict['overall']
                    if valid_score_raw > best_vqa_valid or epoch == 0:
                        best_vqa_valid = valid_score_raw
                        best_vqa_epoch = epoch
                        # self.save("VQA_BEST")
                    log_str += f"VQA"
                    log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (epoch, valid_score_raw, valid_score)
                    log_str += "\nEpoch %d: Best Raw %0.2f\n" % (best_vqa_epoch, best_vqa_valid)
                    wandb_log_dict['VQA/Valid/score'] = valid_score
                    wandb_log_dict['VQA/Valid/raw_score'] = score_dict['overall']
                
                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        # Test Set
        if self.verbose:
            self.save("LAST")

            log_str = ''
            wandb_log_dict = {}

            if 'vqa' in self.args.tasks:
                # VQA
                vqa_test_loader = self.test_loader['vqa']
                evaluator = vqa_test_loader.evaluator
                dump_path = os.path.join(self.args.output, 'karpathy_test_predict.json')
                quesid2ans = self.vqa_predict(vqa_test_loader, dump_path)
                wandb.save(dump_path, base_path=self.args.output)

                acc_dict_all = evaluator.evaluate_raw(quesid2ans)
                acc_dict_answerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=True)
                acc_dict_unanswerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=False)

                wandb_log_dict['VQA/Test/overall'] = acc_dict_all['overall']
                wandb_log_dict['VQA/Test/topk_optimal'] = acc_dict_answerable['overall']
                wandb_log_dict['VQA/Test/topk_not_optimal'] = acc_dict_unanswerable['overall']

                if self.test_loader.get("vqa_submit", None):
                    vqa_submit_test_loader = self.test_loader['vqa_submit']
                    dump_path = os.path.join(self.args.output, 'vqa_submit.json')
                    self.vqa_predict(vqa_submit_test_loader, dump_path=dump_path)
                    wandb.save(dump_path, base_path=self.args.output)

            print(log_str)
            wandb.log(wandb_log_dict, step=self.args.epochs)

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            exit()

    def vqa_predict(self, loader, dump_path=None):
        """Generate one answer per question in ``loader`` using a single beam.

        Puts the model in eval mode and disables gradients. Under DDP the
        ``test_step`` of the wrapped ``module`` is used.

        Args:
            loader: Iterable of batches; each batch provides ``question_ids``.
            dump_path: If given, predictions are written there through
                ``loader.evaluator.dump_result``.

        Returns:
            Dict mapping question id -> predicted answer string.
        """
        self.model.eval()
        gen_kwargs = {'num_beams': 1}
        quesid2ans = {}
        with torch.no_grad():
            for batch in tqdm(loader, ncols=150, desc="VQA Validation"):
                runner = self.model.module if self.args.distributed else self.model
                outputs = runner.test_step(batch, **gen_kwargs)
                # Later duplicates of a question id overwrite earlier ones.
                quesid2ans.update(zip(batch['question_ids'], outputs['pred_ans']))

            if dump_path is not None:
                loader.evaluator.dump_result(quesid2ans, dump_path)
        return quesid2ans

    def vqa_evaluate(self, loader, dump_path=None):
        """Predict answers for ``loader`` and score them with its evaluator.

        Args:
            loader: Loader carrying an ``evaluator`` attribute.
            dump_path: Optional path forwarded to ``vqa_predict``.

        Returns:
            The dict returned by ``evaluator.evaluate_raw`` (the training loop
            reads its ``"overall"`` key), extended with ``"topk_score"`` from
            ``evaluator.evaluate``.
        """
        scorer = loader.evaluator
        predictions = self.vqa_predict(loader, dump_path)
        scores = scorer.evaluate_raw(predictions)
        scores['topk_score'] = scorer.evaluate(predictions)
        return scores

### 1.4 Reft Data Specifics (intervention locations)

This part is fairly complex and mirrors the ReFT implementation for text. If text and image tokens are not distinguished, use `get_intervention_locations()`. If there are separate text and image interventions, use `get_image_intervention_locations()`. The other two getters (`get_all_intervention_locations()` and `get_image_only_intervention_locations()`) are not thoroughly tested. The number of interventions passed to these functions must also match the number of interventions defined in the `ReftConfig` in section 1.3.1. We only use `pad_mode = "first"`; `pad_mode = "last"` has not been tested.

In [8]:
# Label value used when padding a "labels" field in `compute_intervention`.
# -100 is the default `ignore_index` of torch.nn.CrossEntropyLoss, so presumably
# these positions are excluded from the loss — confirm against the model's loss.
IGNORE_INDEX = -100
from transformers import DataCollatorForSeq2Seq
import torch

def parse_positions(positions: str):
    """Parse a position spec such as ``"f3+l3"``, ``"f5"`` or ``"l2"``.

    Returns:
        ``(first_n, last_n)``: how many leading and trailing positions to
        intervene on. A missing part counts as 0.
    """
    if "+" in positions:
        parts = positions.split("+")
        return int(parts[0].strip("f")), int(parts[1].strip("l"))
    if "f" in positions:
        return int(positions.strip("f")), 0
    if "l" in positions:
        return 0, int(positions.strip("l"))
    return 0, 0


def get_intervention_locations(**kwargs):
    """
    This function generates text-style intervention locations for one example.

    The counts come from ``positions`` (e.g. ``"f3+l3"``) or from explicit
    ``first_n`` / ``last_n``. Each count is clamped to half of the sequence
    length (``last_position + last_offset``). Positions lost to clamping are
    filled with a pad position: ``-1`` for ``pad_mode="first"``, otherwise the
    sequence length.

    Returns:
        A list of ``num_interventions`` position lists.

        * With shared weights, or when the clamped first or last count is 0,
          every intervention gets the same combined list.
        * Otherwise the first half of the interventions gets the leading
          positions and the second half the trailing positions. The two lists
          are padded to equal length.

    For your customized dataset, you want to create your own function.
    """
    share_weights = kwargs.get("share_weights", False)
    end = kwargs["last_position"]
    if "positions" in kwargs:
        want_first, want_last = parse_positions(kwargs["positions"])
    else:
        want_first, want_last = kwargs["first_n"], kwargs["last_n"]
    num_interventions = kwargs["num_interventions"]
    pad_mode = kwargs.get("pad_mode", "first")
    end += kwargs.get("last_offset", 0)

    # The leading and trailing windows may each cover at most half the sequence.
    half = end // 2
    first_n = min(half, want_first)
    last_n = min(half, want_last)
    pad_position = -1 if pad_mode == "first" else end

    head = list(range(first_n))
    tail = list(range(end - last_n, end))

    if share_weights or 0 in (first_n, last_n):
        pad_amount = (want_first - first_n) + (want_last - last_n)
        combined = head + tail + [pad_position] * pad_amount
        return [combined] * num_interventions

    head += [pad_position] * (want_first - first_n)
    tail += [pad_position] * (want_last - last_n)
    # The requested counts may differ, so pad both lists to a common width.
    width = max(len(head), len(tail))
    head += [pad_position] * (width - len(head))
    tail += [pad_position] * (width - len(tail))
    per_side = num_interventions // 2
    return [head] * per_side + [tail] * per_side

def get_all_intervention_locations(**kwargs):
    """Intervene on every position of the (offset) sequence.

    ``positions`` looks like ``"all<N>"``. The list ``0..end-1`` (where
    ``end = last_position + last_offset``) is padded with the pad position up
    to ``N`` entries. It is never truncated if ``end > N``.

    Returns:
        ``num_interventions`` copies of that list.
    """
    total = int(kwargs["positions"].strip("all"))
    pad_mode = kwargs.get("pad_mode", "first")
    end = kwargs["last_position"] + kwargs.get("last_offset", 0)
    pad_position = -1 if pad_mode == "first" else end
    locations = list(range(end)) + [pad_position] * (total - end)
    return [locations] * kwargs["num_interventions"]

def get_image_only_intervention_locations(**kwargs):
    """
    This function generates intervention locations on image tokens only.

    Image tokens follow the text: they start at ``last_position`` and span
    ``last_offset`` tokens per image. Tasks containing "nlvr" have two images.

    * With shared weights, or when either image count is 0, every intervention
      gets the same list.
    * Otherwise the first half of the interventions gets the leading image
      positions and the second half the trailing image positions. The two lists
      are padded with the pad position to a common length, so they can be
      stacked into one tensor even when the leading and trailing image counts
      differ.

    For your customized dataset, you want to create your own function.
    """
    # parse kwargs
    share_weights = kwargs["share_weights"] if "share_weights" in kwargs else False
    last_text_position = kwargs["last_position"]
    assert "image_positions" in kwargs, "Image positions must be provided"
    first_image_n, last_image_n = parse_positions(kwargs["image_positions"])

    num_interventions = kwargs["num_interventions"]
    # Number of image tokens per image (callers pass n_boxes here).
    image_offset = kwargs["last_offset"] if "last_offset" in kwargs else 0

    pad_mode = kwargs["pad_mode"] if "pad_mode" in kwargs else "first"
    pad_position = -1 if pad_mode == "first" else last_text_position + image_offset
    if pad_mode != "first" and "nlvr" in kwargs["tasks"]:
        # Two images: the sequence ends one image span later.
        pad_position = last_text_position + 2 * image_offset

    if share_weights or (first_image_n == 0 or last_image_n == 0):
        image_position_list = [i for i in range(last_text_position, last_text_position + first_image_n)] + \
            [i for i in range(last_text_position + image_offset - last_image_n, last_text_position + image_offset)]
        if "nlvr" in kwargs["tasks"]:
            image_position_list += [i for i in range(last_text_position + image_offset, last_text_position + image_offset + first_image_n)] + \
            [i for i in range(last_text_position + 2 * image_offset - last_image_n, last_text_position + 2 * image_offset)]
        intervention_locations = [image_position_list]* num_interventions
    else:
        left_image_intervention_locations = [i for i in range(last_text_position, last_text_position + first_image_n)]
        right_image_intervention_locations = [i for i in range(last_text_position + image_offset - last_image_n, last_text_position + image_offset)]
        if "nlvr" in kwargs["tasks"]:
            left_image_intervention_locations += [i for i in range(last_text_position + image_offset, last_text_position + image_offset + first_image_n)]
            right_image_intervention_locations += [i for i in range(last_text_position + 2 * image_offset - last_image_n, last_text_position + 2 * image_offset)]
        # Equalize lengths. This is a no-op when first_image_n == last_image_n.
        # Otherwise the nested lists would be ragged and presumably fail to be
        # tensorized at collation time.
        width = max(len(left_image_intervention_locations), len(right_image_intervention_locations))
        left_image_intervention_locations += [pad_position] * (width - len(left_image_intervention_locations))
        right_image_intervention_locations += [pad_position] * (width - len(right_image_intervention_locations))
        intervention_locations = \
            [left_image_intervention_locations]*(num_interventions//2) + \
            [right_image_intervention_locations]*(num_interventions//2)
    return intervention_locations



def get_image_intervention_locations(**kwargs):
    """
    This function generates separate intervention locations for text and image tokens.

    Text positions come from ``positions`` and image positions from
    ``image_positions`` (both like ``"f3+l3"``). Image tokens follow the text:
    they start at ``last_position`` and span ``last_offset`` tokens per image.
    Tasks containing "nlvr" have two images.

    Returns ``num_interventions`` position lists:

    * Shared weights, or first-or-last-only specs for both modalities: the
      first half are text lists and the second half are image lists.
    * Otherwise: quarters of left-text, right-text, left-image and right-image
      lists.

    All returned lists are padded with the pad position to a common length,
    so the result can be stacked into a single tensor.

    For your customized dataset, you want to create your own function.
    """
    # parse kwargs
    share_weights = kwargs["share_weights"] if "share_weights" in kwargs else False
    last_text_position = kwargs["last_position"]
    assert "image_positions" in kwargs, "Image positions must be provided"
    assert "positions" in kwargs, "Text positions must be provided"
    first_n, last_n = parse_positions(kwargs["positions"])
    first_image_n, last_image_n = parse_positions(kwargs["image_positions"])

    num_interventions = kwargs["num_interventions"]
    # `last_offset` is the length of the images (n_boxes).
    # Image tokens are concatenated to the end of the text tokens, i.e. after `last_position` tokens.
    # The true last position of the input is `last_position + last_offset`
    image_offset = kwargs["last_offset"] if "last_offset" in kwargs else 0

    pad_mode = kwargs["pad_mode"] if "pad_mode" in kwargs else "first"
    pad_position = -1 if pad_mode == "first" else last_text_position + image_offset
    if pad_mode != "first" and "nlvr" in kwargs["tasks"]:
        pad_position = last_text_position + 2 * image_offset

    if share_weights or ((first_n == 0 or last_n == 0) and (first_image_n == 0 or last_image_n == 0)):
        position_list = [i for i in range(first_n)] + \
            [i for i in range(last_text_position - last_n, last_text_position)]
        image_position_list = [i for i in range(last_text_position, last_text_position + first_image_n)] + \
            [i for i in range(last_text_position + image_offset - last_image_n, last_text_position + image_offset)]
        # There are 2 images in nlvr, so performing special treatment
        # For this notebook however, we only use vqa
        if "nlvr" in kwargs["tasks"]:
            image_position_list += [i for i in range(last_text_position + image_offset, last_text_position + image_offset + first_image_n)] + \
            [i for i in range(last_text_position + 2 * image_offset - last_image_n, last_text_position + 2 * image_offset)]
        text_len = len(position_list)
        image_len = len(image_position_list)
        if text_len > image_len:
            image_position_list += [pad_position for _ in range(text_len-image_len)]
        else:
            position_list += [pad_position for _ in range(image_len-text_len)]
        intervention_locations = [position_list]*(num_interventions//2) + \
            [image_position_list]*(num_interventions//2)
    else:
        assert first_n == last_n, "For now, we only support same first and last positions"
        left_intervention_locations = [i for i in range(first_n)]
        right_intervention_locations = [i for i in range(last_text_position - last_n, last_text_position)]
        left_image_intervention_locations = [i for i in range(last_text_position, last_text_position + first_image_n)]
        right_image_intervention_locations = [i for i in range(last_text_position + image_offset - last_image_n, last_text_position + image_offset)]
        if "nlvr" in kwargs["tasks"]:
            left_image_intervention_locations += [i for i in range(last_text_position + image_offset, last_text_position + image_offset + first_image_n)]
            right_image_intervention_locations += [i for i in range(last_text_position + 2 * image_offset - last_image_n, last_text_position + 2 * image_offset)]

        # Pad all four lists to one common width. Comparing only the two *left*
        # lists (as before) left the output ragged whenever
        # first_image_n != last_image_n. Results are unchanged in the
        # equal-length case.
        groups = (
            left_intervention_locations,
            right_intervention_locations,
            left_image_intervention_locations,
            right_image_intervention_locations,
        )
        width = max(len(group) for group in groups)
        for group in groups:
            group.extend([pad_position] * (width - len(group)))

        intervention_locations = [left_intervention_locations]*(num_interventions//4) + \
            [right_intervention_locations]*(num_interventions//4) + \
            [left_image_intervention_locations]*(num_interventions//4) + \
            [right_image_intervention_locations]*(num_interventions//4)
    return intervention_locations

    
def compute_intervention(
    id: int, 
    result: dict, 
    tokenizer,
    fields_to_pad = (),
    fields_to_mask = (),
    **kwargs):
    """Attach ReFT intervention locations, the example id and masks to ``result``.

    Args:
        id: Example index, stored as ``result["id"]``.
        result: Per-example dict. It is mutated in place and returned.
        tokenizer: Supplies ``pad_token_id`` for padding and masks.
        fields_to_pad: Keys of ``result`` that receive one extra pad token.
            The token is prepended for ``pad_mode == "first"`` and appended for
            ``"last"``; "labels" is padded with ``IGNORE_INDEX``. ``None``
            disables padding entirely.
        fields_to_mask: Keys whose non-pad entries form a mask. A single key
            yields ``result["attention_mask"]``; several keys yield one
            ``result[f"{key}_mask"]`` each.
        **kwargs: Forwarded to the location getter. Must contain ``pad_mode``.

    Returns:
        The same ``result`` dict.
    """
    pad_mode = kwargs["pad_mode"]
    # Pick the location getter that matches which position specs were given.
    if "positions" in kwargs and "all" in kwargs["positions"]:
        intervention_locations = get_all_intervention_locations(**kwargs)
    elif "image_positions" in kwargs and "positions" in kwargs:
        intervention_locations = get_image_intervention_locations(**kwargs)
    elif "image_positions" in kwargs:
        intervention_locations = get_image_only_intervention_locations(**kwargs)
    else:
        intervention_locations = get_intervention_locations(**kwargs)
    result["intervention_locations"] = intervention_locations
    result["id"] = id

    # add a single padding token BEFORE (pad_mode "first") or AFTER ("last") each field
    if fields_to_pad is not None:
        if pad_mode == "first":
            for field in fields_to_pad:
                if field not in result:
                    continue
                pad_value = IGNORE_INDEX if field == "labels" else tokenizer.pad_token_id
                result[field] = torch.cat((torch.tensor([pad_value]), result[field]))
            # Every position moves right by one. The getters' pad position -1
            # thus becomes 0, the prepended pad token.
            # NOTE(review): the shift also runs if none of fields_to_pad was present.
            result["intervention_locations"] = (torch.IntTensor(result["intervention_locations"]) + 1).tolist()
            # assumes the base dataset provides "input_length" — confirm
            result["input_length"] += 1
        elif pad_mode == "last":
            for field in fields_to_pad:
                if field not in result:
                    continue
                pad_value = IGNORE_INDEX if field == "labels" else tokenizer.pad_token_id
                result[field] = torch.cat((result[field], torch.tensor([pad_value])))
            result["input_length"] += 1

    # attention masks
    if len(fields_to_mask) == 1:
        result["attention_mask"] = (result[fields_to_mask[0]] != tokenizer.pad_token_id).int()
    else:
        for field in fields_to_mask:
            result[f"{field}_mask"] = (result[field] != tokenizer.pad_token_id).int()

    # does not handle subspaces for now
    return result

def reft_post_process(
    out_dict,
    tokenizer,
    idx: int, 
    last_position: int, 
    args = None,
    pad_mode = "none",
    fields_to_pad = (),
    fields_to_mask = ()
):
    """Add ReFT fields (instruction text, id, intervention locations, masks) to one example.

    Args:
        out_dict: Per-example dict from the base dataset. It must contain
            ``input_ids`` and is mutated in place.
        tokenizer: Used to decode ``input_ids`` and for ``pad_token_id``.
        idx: Example index, stored as ``out_dict["id"]``.
        last_position: End of the text tokens, forwarded to the location getters.
        args: Parsed run arguments. Reads ``reft_rank``, ``reft_image_rank``,
            ``positions``, ``image_positions``, ``share_weights``, ``layers``
            (``";"``-separated ints), ``n_boxes`` and ``prompt``.
        pad_mode: ``"first"`` prepends one pad token and ``"last"`` appends one.
            Any other value pads nothing. Only ``"first"`` has been tested.
        fields_to_pad: Forwarded to ``compute_intervention``.
        fields_to_mask: Forwarded to ``compute_intervention``.

    Returns:
        The updated example dict.
    """
    # Decoded before compute_intervention adds any padding.
    out_dict["instruction"] = tokenizer.decode(
        out_dict["input_ids"], skip_special_tokens=True)
    kwargs = {}
    if args is not None:
        # A rank of -1 disables interventions on that modality.
        if args.reft_rank != -1:
            kwargs["positions"] = args.positions
        if args.reft_image_rank != -1:
            kwargs["image_positions"] = args.image_positions
        kwargs["share_weights"] = args.share_weights
        layers = [int(layer) for layer in args.layers.split(";")]
        # One intervention per layer with shared weights, otherwise two per layer.
        kwargs["num_interventions"] = len(layers) if args.share_weights else 2 * len(layers)
        # Double interventions if creating separate interventions for texts and images
        if args.reft_image_rank != -1 and args.reft_rank != -1:
            kwargs["num_interventions"] *= 2
        # `n_boxes` is the seq length of the image embeddings
        kwargs["last_offset"] = args.n_boxes
        # Only tested `first`
        kwargs["pad_mode"] = pad_mode
        kwargs["last_position"] = last_position
        kwargs["tasks"] = args.prompt

    return compute_intervention(
        idx,
        out_dict,
        tokenizer,
        fields_to_pad,
        fields_to_mask,
        **kwargs)

def keep_intervention_locations(datum):
    """Project an example onto the keys needed to collate intervention locations."""
    wanted = ("input_ids", "intervention_locations", "attention_mask")
    return {key: datum[key] for key in wanted}


def reft_supplemental_data_collator(batch, tokenizer):
    """Collate the ReFT-specific fields that the base VQA collator does not handle.

    Args:
        batch: List of per-example dicts.
        tokenizer: Passed to ``DataCollatorForSeq2Seq``.

    Returns:
        Dict with:

        * ``"id"``: numpy array of the ids of examples that have one.
        * ``"instruction"``: list of the instruction strings of examples that
          have one.
        * ``"intervention_locations"``: only when the first example has that
          field. It is tensorized by ``DataCollatorForSeq2Seq``, presumably to
          shape (batch, num_interventions, n_positions) — TODO confirm.
    """
    ids = [entry['id'] for entry in batch if 'id' in entry]
    instructions = [entry['instruction'] for entry in batch if 'instruction' in entry]

    batch_entry = {}
    batch_entry['id'] = np.array(ids)
    batch_entry['instruction'] = instructions

    # Only build the collator when there is something to collate. Previously
    # keep_intervention_locations ran unconditionally and raised KeyError on
    # batches without intervention locations, so this check was never reached.
    if batch and "intervention_locations" in batch[0]:
        intervention_loc_collate_fn = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=None,
            label_pad_token_id=-100,
            padding="longest"
        )
        intervene_batch = [keep_intervention_locations(item) for item in batch]
        intervene_batch_entry = intervention_loc_collate_fn(intervene_batch)
        batch_entry["intervention_locations"] = intervene_batch_entry["intervention_locations"]
    return batch_entry


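The dataset below subclasses `VQAFineTuneDataset` from vqa_clip_data.py. It keeps the original items and batches and only adds the ReFT fields (intervention locations, example ids, and decoded instructions).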

In [9]:
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Dataset, Sampler
import vqa_clip_data as vqa_data

class ReftVQAFineTuneDataset(vqa_data.VQAFineTuneDataset):
    """``VQAFineTuneDataset`` that also emits ReFT intervention fields per example and batch."""

    def __init__(self, split='train', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train'):
        super().__init__(split, raw_dataset, rank, topk, verbose, args, mode)

    def __getitem__(self, idx):
        """Return the base example plus ``instruction``, ``id``, ``intervention_locations``
        and ``attention_mask``, with one pad token prepended to ``input_ids``."""
        out_dict = super().__getitem__(idx)

        # NOTE(review): the "- 1" presumably keeps the final special token
        # (e.g. </s>) out of the text positions — confirm.
        last_position = len(out_dict['input_ids']) - 1
        # reft_post_process sets out_dict["instruction"] itself from the same
        # un-padded input_ids, so it is not decoded here a second time.
        return reft_post_process(
            out_dict,
            self.tokenizer,
            idx,
            last_position,
            self.args,
            pad_mode="first",
            fields_to_pad=["input_ids"],
            fields_to_mask=["input_ids"]
        )

    def collate_fn(self, batch):
        """Collate with the base VQA collator, then add the ReFT fields on top."""
        batch_entry = super().collate_fn(batch)
        extra_batch = reft_supplemental_data_collator(batch, self.tokenizer)
        for k, v in extra_batch.items():
            batch_entry[k] = v
        return batch_entry



In [10]:
def get_loader(args, split='karpathy_train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0, topk=-1):
    """Build a VQA ``DataLoader`` over ``ReftVQAFineTuneDataset``.

    Training loaders shuffle unless a ``DistributedSampler`` is used.
    Evaluation loaders never shuffle and keep the last partial batch. On GPU 0
    the loader also gets a ``VQAEvaluator``. Every loader is tagged with
    ``task = 'vqa'``.
    """
    verbose = gpu == 0

    raw_dataset = vqa_data.VQADataset(split, verbose)

    dataset = ReftVQAFineTuneDataset(
        split,
        raw_dataset=raw_dataset,
        rank=gpu,
        topk=topk,
        verbose=verbose,
        args=args,
        mode=mode)

    sampler = DistributedSampler(dataset) if distributed else None

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=workers,
        pin_memory=True,
        sampler=sampler,
        collate_fn=dataset.collate_fn,
    )
    if mode == 'train':
        loader_kwargs['shuffle'] = sampler is None
    else:
        loader_kwargs['shuffle'] = None if sampler is not None else False
        loader_kwargs['drop_last'] = False
    loader = DataLoader(dataset, **loader_kwargs)

    if verbose:
        loader.evaluator = vqa_data.VQAEvaluator(raw_dataset)

    loader.task = 'vqa'
    return loader



### 1.5 Main Worker, Params, and Run

The main worker here is the same as in multitask.py, restricted to the VQA task.

In [11]:
def main_worker(gpu, args):
    """Build the VQA loaders for one process and run training.

    Args:
        gpu: Device index. It is also used as the process rank.
        args: Parsed run arguments. Updated in place (``gpu``, ``rank``,
            ``feat_dim``).

    Only GPU 0 builds validation/test loaders; other ranks pass ``None``.
    """
    args.gpu = args.rank = gpu
    print(f'Process Launching at GPU {gpu}')

    if args.distributed:
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend='nccl')

    print(f"args.feature_type {args.feature_type}")
    # Visual feature width for each CLIP backbone.
    args.feat_dim = {
        "RN50": 2048,
        "RN101": 2048,
        "RN50x4": 2560,
        "ViT": 768,
    }[args.feature_type]
    # Kept from multitask.py; get_loader uses the module-level vqa_data.
    import vqa_clip_data as vqa_data

    vqa_args = deepcopy(args)
    vqa_args.max_text_length = 20
    vqa_args.prompt = "vqa: " if args.use_tasks_prompts else ""

    def build_vqa_loader(split, mode, distributed, workers, topk):
        # Every VQA loader shares vqa_args and its batch size.
        return get_loader(
            vqa_args,
            split=split, mode=mode, batch_size=vqa_args.batch_size,
            distributed=distributed, gpu=args.gpu,
            workers=workers,
            topk=topk,
        )

    train_loaders = []
    if args.epochs > 0 and 'vqa' in args.tasks:
        print(f'Building VQA train loader at GPU {gpu}')
        train_loaders.append(build_vqa_loader(
            'karpathy_train', 'train', args.distributed, args.num_workers, args.train_topk))

    train_loader = multitask_data.MultitaskLoader(
        train_loaders,
        sampling=args.multitask_sampling,
        verbose=gpu == 0)

    val_loader = None
    test_loader = None
    if gpu == 0:
        eval_workers = 4
        val_loader = {}
        test_loader = {}

        if args.epochs > 0 and 'vqa' in args.tasks:
            print(f'Building VQA val loader at GPU {gpu}')
            val_loader['vqa'] = build_vqa_loader(
                'karpathy_val', 'val', False, eval_workers, args.valid_topk)

        if 'vqa' in args.tasks:
            print(f'Building VQA test loader at GPU {gpu}')
            test_loader['vqa'] = build_vqa_loader(
                'karpathy_test', 'val', False, eval_workers, args.valid_topk)
            if args.testing:
                test_loader['vqa_submit'] = build_vqa_loader(
                    'test_4', 'val', False, eval_workers, args.valid_topk)

    trainer = Trainer(args, train_loader, val_loader, test_loader, train=True)
    trainer.train()

In [12]:
# Let cuDNN benchmark and pick the fastest convolution algorithms for fixed input sizes.
cudnn.benchmark = True
# NOTE(review): the meaning of the positional `False` is defined in param.py — confirm.
args = parse_args(False)
# World size = number of GPUs visible to this process.
ngpus_per_node = torch.cuda.device_count()
args.world_size = ngpus_per_node

We also set some ReFT-specific parameters (`reft_*` ranks and dropouts, `positions`, `image_positions`).

In [13]:
# --- Distributed setup: a single NCCL process (see the env vars at the bottom) ---
args.distributed = True
args.nproc_per_node = 1
# NOTE(review): duplicates MASTER_PORT below; this attribute appears unused in this notebook.
args.master_port = 26464
args.multiGPU = True
# --- Optimization ---
args.optim = "adamw"
args.warmup_ratio = 0.1
args.clip_grad_norm = 5
args.weight_decay = 0.01
args.lr = 1e-3
args.epochs = 20
args.num_workers = 4
args.backbone = "facebook/bart-base"
args.output = "snap/VLBart_dora_reft/test/"
args.num_beams = 5
# Makes main_worker prepend the "vqa: " prompt to questions.
args.use_tasks_prompts = True
# Passed as `topk` to get_loader; the run log shows "Use only 100 data" per split.
args.train_topk = 100
args.valid_topk = 100
# main_worker uses batch_size for the train, val and test loaders alike.
args.batch_size = 100
# NOTE(review): valid_batch_size is not read by main_worker in this notebook.
args.valid_batch_size = 100
# args.use_dora = True
args.unfreeze_bias = True
args.unfreeze_layer_norms = True
# args.lora_settings = True
# args.lora_dim = 128
args.tasks = "vqa"
args.dropout = 0.00
# --- ReFT: a rank of -1 disables that modality; both set => separate text and image interventions ---
args.reft_dropout = 0.00
args.reft_image_dropout = 0.00
args.reft_rank = 4
args.reft_image_rank = 64
# "f3+l3" = first 3 and last 3 positions (text tokens / image tokens respectively).
args.positions = "f3+l3"
args.image_positions = "f3+l3"

# --- Visual features ---
args.feature = "RN101"
# Number of image tokens appended after the text; used as `last_offset` for intervention locations.
args.n_boxes = 36
args.downsample = True
args.image_size = "(224,224)"
args.project_name = "Test"
args.run_name = "tune+lr1e-3"
# Passed to main_worker as the GPU index / rank.
args.local_rank = 0
# Selects feat_dim in main_worker (RN101 -> 2048).
args.feature_type = "RN101"
# Read by dist.init_process_group(backend='nccl') via its default env:// init.
# NOTE(review): WORLD_SIZE is fixed to 1 even though args.world_size is the device count.
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '26464'

Try running the cell below yourself!

In [None]:
# Driver cell: derive a default run name on the main rank, then train on GPU `args.local_rank`.
# cudnn.benchmark = True
# args = parse_args(False)
ngpus_per_node = torch.cuda.device_count()
args.world_size = ngpus_per_node
if args.local_rank in [0, -1]:
    print(args)

    # Optional suffixes: the last three path components of a loaded checkpoint, and args.comment.
    comments = []
    if args.load is not None:
        ckpt_str = "_".join(args.load.split('/')[-3:])
        comments.append(ckpt_str)
    if args.comment != '':
        comments.append(args.comment)
    comment = '_'.join(comments)

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M')
    run_name = f'{current_time}_GPU{args.world_size}'
    if len(comments) > 0:
        run_name += f'_{comment}'

    # Only used when no run name was configured (the config cell sets one).
    if args.run_name == "":
        args.run_name = run_name

# if args.distributed:
main_worker(args.local_rank, args)


Configurations
{'RefCOCO_BUTD': False,
 'RefCOCO_GT': False,
 'adam_beta1': 0.9,
 'adam_beta2': 0.999,
 'adam_eps': 1e-06,
 'add_adapter_cross_attn': True,
 'add_layer_norm_after_adapter': False,
 'add_layer_norm_before_adapter': False,
 'additional_visual_embedding_layers': 0,
 'answer_normalize': False,
 'backbone': 'facebook/bart-base',
 'batch_size': 100,
 'caption_cocoonly': True,
 'caption_only': False,
 'classifier': False,
 'clip_grad_norm': 5,
 'cls_task': 'tinyimagenet',
 'coco_only': False,
 'comment': '',
 'decoder_prompt_len': 0,
 'deepspeed': None,
 'distributed': True,
 'do_lower_case': False,
 'dora_simple': False,
 'downsample': True,
 'dropout': 0.0,
 'dry': False,
 'efficient_unique_hyper_net': False,
 'encoder_prompt_len': 0,
 'epochs': 20,
 'expand_vis_embedding': False,
 'factorized_phm': True,
 'feat_dim': 2048,
 'feature': 'RN101',
 'feature_type': 'RN101',
 'fp16': False,
 'freeze_bn_statistics': False,
 'freeze_ln_statistics': False,
 'from_scratch': False,
 '



Load 26729 data from split(s) karpathy_val.
# Answers: 3129
Data sources:  ['karpathy_val']
Loaded 26729 data from karpathy_val
Use only 100 data
# all sentences: 100
Building VQA test loader at GPU 0
Load 26280 data from split(s) karpathy_test.
# Answers: 3129
Data sources:  ['karpathy_test']
Loaded 26280 data from karpathy_test
Use only 100 data
# all sentences: 100
IntervenableConfig
{
    "model_type": "None",
    "representations": [
        {
            "layer": 0,
            "component": "block_output",
            "unit": "pos",
            "max_number_of_units": 1,
            "low_rank_dimension": 4,
            "intervention_type": null,
            "intervention": "PLACEHOLDER",
            "subspace_partition": null,
            "group_key": null,
            "intervention_link_key": null,
            "moe_key": null,
            "source_representation": null,
            "hidden_source_representation": null
        },
        {
            "layer": 1,
            "compo



It took 0.8s


[34m[1mwandb[0m: Currently logged in as: [33mpeterzw494[0m ([33mpeterwz[0m). Use [1m`wandb login --relogin`[0m to force relogin




# epoch_tasks: 1


  unit_locations = torch.tensor(
  unit_locations = torch.tensor(
Epoch 0 | LR 0.000500 | VQA 1 | VQA Loss 6.820937: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.29s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.63s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 5859.85it/s]


VQA
Epoch 0: Valid Raw 0.00 Topk 0.00
Epoch 0: Best Raw 0.00

# epoch_tasks: 1


Epoch 1 | LR 0.001000 | VQA 1 | VQA Loss 6.954564: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.51s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.14s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 5824.53it/s]


VQA
Epoch 1: Valid Raw 0.00 Topk 0.00
Epoch 0: Best Raw 0.00

# epoch_tasks: 1


Epoch 2 | LR 0.000944 | VQA 1 | VQA Loss 6.612218: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.71s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.88s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 5870.92it/s]


VQA
Epoch 2: Valid Raw 0.00 Topk 0.00
Epoch 0: Best Raw 0.00

# epoch_tasks: 1


Epoch 3 | LR 0.000889 | VQA 1 | VQA Loss 6.074577: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.97s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.86s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 5982.89it/s]


VQA
Epoch 3: Valid Raw 0.00 Topk 0.00
Epoch 0: Best Raw 0.00

# epoch_tasks: 1


Epoch 4 | LR 0.000833 | VQA 1 | VQA Loss 5.486867: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.14s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.61s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6645.07it/s]


VQA
Epoch 4: Valid Raw 20.90 Topk 20.90
Epoch 4: Best Raw 20.90

# epoch_tasks: 1


Epoch 5 | LR 0.000778 | VQA 1 | VQA Loss 4.943801: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.29s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.08s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6739.03it/s]


VQA
Epoch 5: Valid Raw 1.30 Topk 1.30
Epoch 4: Best Raw 20.90

# epoch_tasks: 1


Epoch 6 | LR 0.000722 | VQA 1 | VQA Loss 4.496159: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.26s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.01s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6471.79it/s]


VQA
Epoch 6: Valid Raw 26.40 Topk 26.40
Epoch 6: Best Raw 26.40

# epoch_tasks: 1


Epoch 7 | LR 0.000667 | VQA 1 | VQA Loss 4.126684: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.21s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.32s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6863.31it/s]


VQA
Epoch 7: Valid Raw 20.00 Topk 14.10
Epoch 6: Best Raw 26.40

# epoch_tasks: 1


Epoch 8 | LR 0.000611 | VQA 1 | VQA Loss 3.899832: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.21s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.88s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6687.13it/s]


VQA
Epoch 8: Valid Raw 30.20 Topk 30.20
Epoch 8: Best Raw 30.20

# epoch_tasks: 1


Epoch 9 | LR 0.000556 | VQA 1 | VQA Loss 3.663690: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.07s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.15s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 4848.01it/s]


VQA
Epoch 9: Valid Raw 30.20 Topk 30.20
Epoch 8: Best Raw 30.20

# epoch_tasks: 1


Epoch 10 | LR 0.000500 | VQA 1 | VQA Loss 3.461259: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.41s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.94s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6794.38it/s]


VQA
Epoch 10: Valid Raw 30.20 Topk 30.20
Epoch 8: Best Raw 30.20

# epoch_tasks: 1


Epoch 11 | LR 0.000444 | VQA 1 | VQA Loss 3.288274: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.12s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.11s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6639.08it/s]


VQA
Epoch 11: Valid Raw 27.30 Topk 27.30
Epoch 8: Best Raw 30.20

# epoch_tasks: 1


Epoch 12 | LR 0.000389 | VQA 1 | VQA Loss 3.133632: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.21s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.23s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6421.26it/s]


VQA
Epoch 12: Valid Raw 25.90 Topk 25.90
Epoch 8: Best Raw 30.20

# epoch_tasks: 1


Epoch 13 | LR 0.000333 | VQA 1 | VQA Loss 2.995314: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.41s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.26s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6600.11it/s]


VQA
Epoch 13: Valid Raw 30.90 Topk 30.90
Epoch 13: Best Raw 30.90

# epoch_tasks: 1


Epoch 14 | LR 0.000278 | VQA 1 | VQA Loss 2.870601: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.46s/it]
VQA Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.14s/it]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 6531.15it/s]


VQA
Epoch 14: Valid Raw 30.20 Topk 30.20
Epoch 13: Best Raw 30.90

# epoch_tasks: 1


  0%|                                                                                                                                                                                                                               | 0/1 [00:00<?, ?it/s]