<a href="https://colab.research.google.com/github/mobarakol/Project_Gen/blob/main/LoRA_PitVQA_Sentence_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Download code
!git clone https://github.com/HRL-Mike/PitVQA.git

#Download Dataset
!mkdir /content/PitVQA/datasets
%cd /content/PitVQA/datasets
!gdown --id 1FoAEY_u0PTAlrscjEifi2om15A83wL78

# Unzipping the VQA EndoVis18 Dataset
!unzip -q EndoVis-18-VQA.zip
%cd /content/PitVQA

Cloning into 'PitVQA'...
remote: Enumerating objects: 401, done.[K
remote: Counting objects: 100% (139/139), done.[K
remote: Compressing objects: 100% (138/138), done.[K
remote: Total 401 (delta 74), reused 0 (delta 0), pack-reused 262 (from 1)[K
Receiving objects: 100% (401/401), 14.44 MiB | 16.49 MiB/s, done.
Resolving deltas: 100% (199/199), done.
/content/PitVQA/datasets
Downloading...
From (original): https://drive.google.com/uc?id=1FoAEY_u0PTAlrscjEifi2om15A83wL78
From (redirected): https://drive.google.com/uc?id=1FoAEY_u0PTAlrscjEifi2om15A83wL78&confirm=t&uuid=b128829b-3e55-4f36-923d-59f08c857a1e
To: /content/PitVQA/datasets/EndoVis-18-VQA.zip
100% 2.71G/2.71G [01:00<00:00, 44.6MB/s]
/content/PitVQA


In [2]:
!pip install -q timm==0.9.12 fairscale==0.4.13 scikit-learn==1.3.2 -U evaluate bert_score

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/60.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/266.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m64.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m74.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for fairscale (pyproject.toml) 

### Dataloader

In [2]:
import os
import glob

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from pathlib import Path
from torchvision.transforms.functional import InterpolationMode

class EndoVis18VQAGPTGen(Dataset):
    """Sentence-level VQA samples read from per-frame ``question|answer`` text files.

    Every non-blank line of every matched text file becomes one sample. The image
    of a sample is ``<seq dir>/left_fr/<frame>.png`` where ``<seq dir>`` is the third
    ancestor directory of the QA file and ``<frame>`` is the part of the QA file name
    before its first underscore.

    Args:
        seq: iterable of sequence identifiers; each one is inserted (via ``str``)
            between ``folder_head`` and ``folder_tail`` to form a glob pattern.
        folder_head: path prefix placed in front of the sequence identifier.
        folder_tail: glob suffix appended after the sequence identifier.
    """

    def __init__(self, seq, folder_head, folder_tail):

        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),  # input image size
            transforms.ToTensor(),
        ])

        # files, question and answers
        filenames = []
        for curr_seq in seq:
            # glob returns matches in arbitrary order; sorting keeps the sample
            # index -> file mapping stable across runs and machines.
            filenames.extend(sorted(glob.glob(folder_head + str(curr_seq) + folder_tail)))

        self.vqas = []  # list of [qa_file_path, "question|answer"] pairs
        for file in filenames:
            for line in self._read_qa_lines(file):
                self.vqas.append([file, line])
        print('Total files: %d | Total question: %.d' % (len(filenames), len(self.vqas)))

        # Labels (stored for reference; not read by any method of this class)
        self.labels = ['kidney',
                'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation',
                'Tool_Manipulation', 'Cutting', 'Cauterization', 'Suction',
                'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing',
                'left-top', 'right-top', 'left-bottom', 'right-bottom']

    @staticmethod
    def _read_qa_lines(path):
        """Return the non-blank lines of ``path`` with their newline characters stripped."""
        # The context manager closes the handle even if reading raises.
        with open(path, "r") as file_data:
            return [line.strip("\n") for line in file_data if line != "\n"]

    def __len__(self):
        """Number of question/answer lines collected from all matched files."""
        return len(self.vqas)

    def __getitem__(self, idx):
        """Return ``(image_path, image_tensor, question, answer)`` for sample ``idx``.

        Raises:
            ValueError: if the stored line does not contain exactly one ``'|'``.
        """
        qa_file, qa_line = self.vqas[idx]
        qa_full_path = Path(qa_file)
        seq_path = qa_full_path.parents[2]
        file_name = qa_full_path.name  # works with both '/' and '\\' separators

        # img
        img_loc = os.path.join(seq_path, 'left_fr', file_name.split('_')[0] + '.png')
        # convert() returns a new in-memory image, so the file handle can be released here
        with Image.open(img_loc) as opened_image:
            raw_image = opened_image.convert('RGB')
        img = self.transform(raw_image)

        # question and answer
        question, answer = qa_line.split('|')

        return img_loc, img, question, answer

### Model

In [3]:
import torch
from torch import nn

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import ViTModel, BlipConfig, BlipTextModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class _LoRA_qkv(nn.Module):
    def __init__(self, w_qkv, w_a, w_b, lora_alpha, lora_dropout):
        super().__init__()
        self.w_qkv = w_qkv
        self.w_a = w_a
        self.w_b = w_b
        self.lora_alpha = lora_alpha
        self.dropout = nn.Dropout(lora_dropout)
        self.scaling = self.lora_alpha / self.w_a.weight.shape[0]  # alpha / r

        self.weight = self.w_qkv.weight  # load original weights

    def forward(self, x):
        return self.w_qkv(x) + self.scaling * self.dropout(self.w_b(self.w_a(x)))

class LoRAInitializer:
    """Freeze ``model.transformer`` and wrap each block's ``attn.c_attn`` with LoRA.

    Args:
        model: object exposing ``transformer.parameters()`` and ``transformer.h``
            (an iterable of blocks that each have ``attn.c_attn``).
        r: per-block LoRA ranks, indexed by block position; defaults to
            ``[14, 14, 12, 12, 10, 10, 8, 8, 8, 8, 8, 8]``.
        lora: stored on the instance (default ``['q', 'v']``); no method of this
            class reads it.
        lora_alpha: scaling numerator handed to every ``_LoRA_qkv``.
        lora_dropout: dropout probability handed to every ``_LoRA_qkv``.
    """

    def __init__(self, model, r=None, lora=None, lora_alpha=32, lora_dropout=0.1):
        self.model = model
        self.r = [14, 14, 12, 12, 10, 10, 8, 8, 8, 8, 8, 8] if r is None else r
        self.lora = ['q', 'v'] if lora is None else lora
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.w_As = []  # down-projections, one per wrapped block
        self.w_Bs = []  # up-projections, one per wrapped block

    def reset_parameters(self):
        """Re-initialise every LoRA pair: A ~ N(0, 0.02), B = 0."""
        for pair_idx in range(min(len(self.w_As), len(self.w_Bs))):
            down_proj = self.w_As[pair_idx]
            up_proj = self.w_Bs[pair_idx]
            nn.init.normal_(down_proj.weight, mean=0.0, std=0.02)
            # with B at zero the low-rank branch contributes nothing until it is trained
            nn.init.zeros_(up_proj.weight)

    def initialize_lora(self):
        """Freeze the transformer, attach LoRA branches in place and return the model."""
        # freeze transformer parameters
        for frozen_param in self.model.transformer.parameters():
            frozen_param.requires_grad = False

        # GPT2 uses a single c_attn for q, k, v
        for t_layer_i, blk in enumerate(self.model.transformer.h):
            rank = self.r[t_layer_i]
            w_qkv = blk.attn.c_attn
            weight_shape = w_qkv.weight.shape
            in_features = weight_shape[0]   # 768
            out_features = weight_shape[1]  # 2304

            w_a_linear = nn.Linear(in_features, rank, bias=False)
            w_b_linear = nn.Linear(rank, out_features, bias=False)
            self.w_As.append(w_a_linear)
            self.w_Bs.append(w_b_linear)
            blk.attn.c_attn = _LoRA_qkv(
                w_qkv,
                w_a_linear,
                w_b_linear,
                self.lora_alpha,
                self.lora_dropout,
            )

        self.reset_parameters()
        print("LoRA params initialized!")
        return self.model


class BLIPGPTVQAGen(nn.Module):
    """ViT image encoder + BLIP text encoder (cross-attending to the image) + LoRA GPT-2.

    The question (and, when given, the answer) token ids are encoded by the BLIP text
    encoder with the ViT patch embeddings as cross-attention input; the resulting
    hidden states are concatenated along the sequence axis and fed to GPT-2 as
    ``inputs_embeds``.

    Args:
        r: per-block LoRA ranks forwarded to ``LoRAInitializer``.
        lora: forwarded to ``LoRAInitializer``.
        lora_alpha: LoRA scaling numerator forwarded to ``LoRAInitializer``.
        lora_dropout: LoRA dropout forwarded to ``LoRAInitializer``.
    """

    def __init__(self, r=None, lora=None, lora_alpha=32, lora_dropout=0.1):
        super(BLIPGPTVQAGen, self).__init__()

        # gpt2 decoder
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt = LoRAInitializer(self.gpt, r=r, lora=lora, lora_alpha=lora_alpha,
                       lora_dropout=lora_dropout).initialize_lora()  # add lora

        # visual encoder
        model_name = "google/vit-base-patch16-224-in21k"
        self.visual_encoder = ViTModel.from_pretrained(model_name)

        # tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token  # GPT-2 has no pad token; reuse end-of-text

        # text encoder: constructed from the BLIP text *config* only, i.e. this call
        # does not load pretrained text-encoder weights
        config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base")
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)

        # replace the word-embedding table so its vocabulary matches the GPT-2 tokenizer;
        # the new table uses nn.Embedding's default random initialisation
        new_vocab_size = len(self.tokenizer)
        embedding_dim = self.text_encoder.embeddings.word_embeddings.embedding_dim
        self.text_encoder.embeddings.word_embeddings = nn.Embedding(new_vocab_size, embedding_dim)

    def _encode_text(self, text_inputs, image_embeds, image_atts):
        """Encode tokenised text with the BLIP text encoder, cross-attending to the image."""
        encoded = self.text_encoder(input_ids=text_inputs['input_ids'],
                                    attention_mask=text_inputs['attention_mask'],
                                    encoder_hidden_states=image_embeds,
                                    encoder_attention_mask=image_atts,
                                    return_dict=True)
        return encoded.last_hidden_state  # [bs, seq_len, hidden]

    def forward(self, image, question_inputs, answer_inputs=None):
        """Return GPT-2 logits over the concatenated question(+answer) sequence.

        Args:
            image: batch of images accepted by the ViT encoder.
            question_inputs: mapping with ``'input_ids'`` and ``'attention_mask'``.
            answer_inputs: optional mapping with the same two keys. When omitted,
                only the question sequence is fed to GPT-2.

        Returns:
            Logits of shape ``[bs, question_len (+ answer_len), vocab_size]``.
        """
        # Follow the module's own placement instead of a module-level `device` global.
        model_device = next(self.parameters()).device

        # visual encoder
        image = image.to(model_device)
        image_embeds = self.visual_encoder(image).last_hidden_state  # [bs, num_patches + 1, hidden]
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        # multimodal encoder: question first, then (optionally) the answer
        segments = [question_inputs]
        if answer_inputs is not None:
            segments.append(answer_inputs)

        segment_embeds = [self._encode_text(seg, image_embeds, image_atts) for seg in segments]
        segment_masks = [seg['attention_mask'] for seg in segments]

        inputs_embeds_qa = torch.cat(segment_embeds, dim=1)
        attention_mask_qa = torch.cat(segment_masks, dim=1)

        # text decoder
        # The padding mask must be passed as `attention_mask`. GPT-2 only consults
        # `encoder_attention_mask` together with cross-attention inputs, so passing
        # the mask under that name left padded positions unmasked.
        gpt_output = self.gpt(inputs_embeds=inputs_embeds_qa,
                              attention_mask=attention_mask_qa)
        return gpt_output.logits

### Main

In [None]:
import os
import torch
import argparse
import torch.utils.data
import numpy as np
import random

from torch import nn
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer

from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

def adjust_learning_rate(optimizer, shrink_factor):
    """Scale the learning rate of every parameter group by ``shrink_factor``.

    Prints a notice before the update and the first group's new rate after it.
    """
    print("\nDECAYING learning rate.")
    groups = optimizer.param_groups
    for group_idx in range(len(groups)):
        current_lr = groups[group_idx]['lr']
        groups[group_idx]['lr'] = current_lr * shrink_factor
    print("The new learning rate is %f\n" % (groups[0]['lr'],))

class AverageMeter(object):
    """Keep the latest value plus the running sum, count and weighted mean.

    Attributes:
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n``.
        count: accumulated ``n``.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the mean."""
        self.sum += val * n
        self.count += n
        self.val = val
        self.avg = self.sum / self.count

def train(args, train_dataloader, model, criterion, optimizer, epoch, tokenizer, device):
    """Run one training epoch and print the last-batch loss and the epoch mean.

    Args:
        args: object providing ``seq_length`` and ``epochs``.
        train_dataloader: yields ``(_, images, questions, answers)`` batches.
        model: called as ``model(image=..., question_inputs=..., answer_inputs=...)``
            and expected to return logits whose second axis covers the question
            positions followed by the answer positions.
        criterion: loss taking ``(flattened_logits, flattened_labels)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number (only used in the printed summary).
        tokenizer: callable tokenizer returning tensors with ``.to(device)``.
        device: device the batch is moved to.
    """
    model.train()
    total_loss = AverageMeter()
    # One integer length is used for both tokenisation and logit slicing.
    seq_length = int(args.seq_length)

    for _, images, questions, answers in tqdm(train_dataloader):
        question_inputs = tokenizer(questions, padding="max_length", max_length=seq_length,
                                    return_tensors="pt", truncation=True)
        answer_inputs = tokenizer(answers, padding="max_length", max_length=seq_length,
                                  return_tensors="pt", truncation=True)

        # get logits and labels
        logits = model(image=images.to(device), question_inputs=question_inputs.to(device), answer_inputs=answer_inputs.to(device))
        labels = answer_inputs['input_ids'].to(device)

        # keep only the answer positions (everything after the padded question)
        # NOTE(review): logits at answer position j are compared with the answer token at
        # the same position j (no one-step shift), and padded label positions are not
        # excluded from the loss - confirm both are intended.
        answer_logits = logits[:, seq_length:, :].contiguous()
        answer_labels = labels.contiguous()

        # compute loss
        loss = criterion(answer_logits.view(-1, answer_logits.size(-1)), answer_labels.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        total_loss.update(loss.item(), images.size(0))

    print("Epoch: {}/{} Loss: {:.6f} AVG_Loss: {:.6f}".format(epoch, args.epochs, total_loss.val, total_loss.avg))

def validate(args, val_loader, model, criterion, epoch, tokenizer, device):
    """Evaluate on ``val_loader``; print and return corpus BLEU-1..4.

    Predictions are the per-position argmax of the answer-position logits, decoded
    with ``skip_special_tokens=True`` and split on whitespace.

    Args:
        args: object providing ``seq_length`` and ``epochs``.
        val_loader: yields ``(_, images, questions, answers)`` batches.
        model: called as in ``train``; returns logits over question + answer positions.
        criterion: loss taking ``(flattened_logits, flattened_labels)``.
        epoch: current epoch number (only used in the printed summary).
        tokenizer: tokenizer used for encoding inputs and decoding predictions.
        device: device the batch is moved to.

    Returns:
        dict with keys ``"Bleu_1"`` .. ``"Bleu_4"``.
    """
    references = []
    hypotheses = []
    # One integer length is used for both tokenisation and logit slicing.
    seq_length = int(args.seq_length)

    model.eval()
    total_loss = AverageMeter()
    with torch.no_grad():
        for _, images, questions, answers in tqdm(val_loader):
            question_inputs = tokenizer(questions, padding="max_length", max_length=seq_length,
                                        return_tensors="pt", truncation=True)
            answer_inputs = tokenizer(answers, padding="max_length", max_length=seq_length,
                                      return_tensors="pt", truncation=True)

            # get logits and labels
            logits = model(image=images.to(device), question_inputs=question_inputs.to(device), answer_inputs=answer_inputs.to(device))
            labels = answer_inputs['input_ids'].to(device)

            # keep only the answer positions (everything after the padded question)
            # NOTE(review): labels are not shifted relative to the logits and the model
            # receives the reference answer as input here - confirm this is intended.
            answer_logits = logits[:, seq_length:, :].contiguous()
            answer_labels = labels.contiguous()

            # compute loss, weighted by batch size so a short last batch does not skew the mean
            loss = criterion(answer_logits.view(-1, answer_logits.size(-1)), answer_labels.view(-1))
            total_loss.update(loss.item(), images.size(0))

            # greedy per-position prediction
            predicted = answer_logits.argmax(dim=-1)

            # decode references and predictions
            reference_answers = tokenizer.batch_decode(labels, skip_special_tokens=True)
            predicted_answers = tokenizer.batch_decode(predicted, skip_special_tokens=True)
            for ref, hyp in zip(reference_answers, predicted_answers):
                references.append([ref.split()])  # corpus_bleu expects a list of references per sample
                hypotheses.append(hyp.split())

    # Calculate BLEU_1~4 with uniform n-gram weights that sum to exactly 1
    # (BLEU-3 uses 1/3 per order rather than the rounded 0.33).
    bleu_weights = {
        "Bleu_1": (1.0, 0.0, 0.0, 0.0),
        "Bleu_2": (0.5, 0.5, 0.0, 0.0),
        "Bleu_3": (1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0),
        "Bleu_4": (0.25, 0.25, 0.25, 0.25),
    }
    metrics = {}
    for metric_name, weights in bleu_weights.items():
        metrics[metric_name] = corpus_bleu(references, hypotheses, weights=weights)
    print(f"Epoch: {epoch}/{args.epochs} EVA LOSS: {total_loss.avg:.6f} "
          f"BLEU-1: {metrics['Bleu_1']:.6f} BLEU-2: {metrics['Bleu_2']:.6f} "
          f"BLEU-3: {metrics['Bleu_3']:.6f} BLEU-4: {metrics['Bleu_4']:.6f}")
    return metrics

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmarking mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class Args:
    """Plain container of run settings; every instance gets its own copy of the defaults."""

    def __init__(self):
        # The dict is rebuilt on each call so mutable values (vector_rank) are never shared.
        defaults = {
            # run / data settings
            'epochs': 20,
            'batch_size': 8,
            'workers': 8,
            'random_seed': 42,
            'seq_length': 32,
            'lr': 0.00002,
            # LoRA settings
            'vector_rank': [14, 14, 12, 12, 10, 10, 8, 8, 8, 8, 8, 8],
            'lora_alpha': 32,
            'lora_dropout': 0.1,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

# Training entry point: build datasets and model, then train for args.epochs epochs,
# validating after each one and checkpointing whenever BLEU-4 does not get worse.
if __name__ == '__main__':
    args = Args()
    os.makedirs('./checkpoints/', exist_ok=True)

    # seed before any dataset/model construction so initialisation is repeatable
    seed_everything(args.random_seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # bookkeeping for model selection and learning-rate decay
    start_epoch = 1
    best_epoch = [0]
    best_results = [0.0]
    epochs_since_improvement = 0

    # data location: sequence ids used for training and for validation
    train_seq = [2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]
    val_seq = [1, 5, 16]

    # the sequence id is inserted between these two strings to form a glob pattern
    folder_head = '/content/PitVQA/datasets/EndoVis-18-VQA/seq_'
    folder_tail = '/vqa/Sentence/*.txt'

    # dataloader
    # NOTE(review): num_workers is hard-coded to 2 here; args.workers is not read in this block.
    train_dataset = EndoVis18VQAGPTGen(train_seq, folder_head, folder_tail)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_dataset = EndoVis18VQAGPTGen(val_seq, folder_head, folder_tail)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # one LoRA rank per entry of args.vector_rank
    print(f'num of elements: {len(args.vector_rank)}')
    model = BLIPGPTVQAGen(r=args.vector_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
    # the optimizer is handed every model parameter, all with the same learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)  # same learning rate for LoRA weights and other weights

    model = model.to(device)
    # counts every parameter, whether or not it requires gradients
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('model params: ', pytorch_total_params)
    # NOTE(review): no ignore_index is configured, and the pad token set below is the
    # EOS token, so padded label positions appear to count toward the loss - confirm.
    criterion = nn.CrossEntropyLoss().to(device)

    # tokenizer used by train()/validate(); GPT-2 has no pad token, so EOS is reused
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    print('Start training.')
    for epoch in range(start_epoch, args.epochs+1):

        # multiply the learning rate by 0.8 after every 5 consecutive epochs without improvement
        if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0:
            adjust_learning_rate(optimizer, 0.8)

        # train
        train(args, train_dataloader=train_dataloader, model=model, criterion=criterion, optimizer=optimizer,
              epoch=epoch, tokenizer=tokenizer, device=device)

        # validation
        metrics = validate(args, val_loader=val_dataloader, model=model, criterion=criterion, epoch=epoch,
                           tokenizer=tokenizer, device=device)

        # model selection on BLEU-4; ties count as an improvement (>=) and overwrite the checkpoint
        if metrics["Bleu_4"] >= best_results[0]:
            epochs_since_improvement = 0
            best_results[0] = metrics["Bleu_4"]
            best_epoch[0] = epoch
            print(f'Best epoch: {epoch}, Best Bleu_4: {metrics["Bleu_4"]}')
            torch.save(model.state_dict(), 'checkpoints/model_best.pth')
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))

    print('End training.')

Total files: 1560 | Total question: 10574
Total files: 447 | Total question: 3216
num of elements: 12
LoRA params initialized!
model params:  363611136
Start training.


 86%|████████▌ | 1139/1322 [13:21<02:01,  1.51it/s]

In [None]:
import torch
import numpy as np
from nltk.tokenize import TreebankWordTokenizer
def treebank_tokenize(s):
    """Tokenize ``s`` with a freshly created NLTK ``TreebankWordTokenizer``."""
    word_tokenizer = TreebankWordTokenizer()
    return word_tokenizer.tokenize(s)
def generate_beam(
    model,
    tokenizer,
    beam_size: int = 5,
    generated=None,
    entry_length=65,
    temperature=1.0,
    stop_token: str = "<|endoftext|>",
):
    """Beam-search decode a continuation of the prompt embeddings ``generated``.

    Args:
        model: module exposing ``gpt`` (called with ``inputs_embeds`` and returning an
            object with ``.logits``). The token-embedding lookup is chosen from the
            optional ``model.model_type`` attribute (``"biogpt"`` or ``"gpt2"``); any
            other value, or a missing attribute, falls back to
            ``model.gpt.get_input_embeddings()``.
        tokenizer: provides ``encode`` (for the stop token) and ``decode``.
        beam_size: number of beams kept at each step.
        generated: prompt embeddings with a leading batch dimension of 1 (it is
            expanded to ``beam_size`` after the first step).
        entry_length: maximum number of tokens to generate.
        temperature: logits are divided by this value when it is positive.
        stop_token: text of the token that terminates a beam.

    Returns:
        Decoded beam texts, best length-normalised score first. An empty list when
        ``entry_length`` is not positive.

    Raises:
        ValueError: if ``generated`` is ``None``.
    """
    if generated is None:
        raise ValueError("generate_beam requires `generated` (the prompt embeddings)")

    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    # Not every wrapper defines `model_type`; treat a missing attribute like an unknown type.
    model_type = getattr(model, "model_type", None)
    tokens = None
    scores = None
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        for _ in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits

            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

            # log_softmax avoids the underflow-to-(-inf) that softmax().log() can hit
            logits = logits.log_softmax(-1)

            if scores is None:
                # first step: fan the single prompt out into `beam_size` beams
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                tokens = next_tokens
            else:
                # finished beams may only "emit" token 0 at zero cost, which keeps
                # their accumulated score unchanged
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                # pick the best (beam, token) pairs over the flattened beam x vocab grid
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1
                )
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            # embed the chosen tokens and append them to the running input
            if model_type == "biogpt":
                next_token_embed = model.gpt.biogpt.embed_tokens(
                    next_tokens.squeeze()
                ).view(generated.shape[0], 1, -1)
            elif model_type == "gpt2":
                next_token_embed = model.gpt.transformer.wte(
                    next_tokens.squeeze()
                ).view(generated.shape[0], 1, -1)
            else:
                next_token_embed = model.gpt.get_input_embeddings()(tokens[:, -1])
                next_token_embed = next_token_embed.squeeze().view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            # next_tokens is [beam, 1] here; mark beams that just produced the stop token
            is_stopped = is_stopped | next_tokens.eq(stop_token_index).squeeze(1)
            if is_stopped.all():
                break

    if tokens is None:
        # entry_length <= 0: nothing was generated
        return []

    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    output_texts = [
        tokenizer.decode(output[: int(length)])
        for output, length in zip(output_list, seq_lengths)
    ]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    return output_texts

from tqdm import tqdm
import torch
from sklearn.metrics import accuracy_score,roc_auc_score
# from utils import generate_beam
from nltk.translate.bleu_score import sentence_bleu
from transformers import GPT2Tokenizer
import pdb
from evaluate import load
import collections
from torch.cuda.amp import autocast
import os

def print_nearest_text_token(vis_token, model):
    """Print the vocabulary token whose embedding is closest to ``vis_token``.

    Distances are L2 norms between ``vis_token`` and each row of
    ``model.gpt.transformer.wte.weight``; the winning index is decoded with
    ``model.tokenizer``.
    """
    vocab_embeddings = model.gpt.transformer.wte.weight
    offsets = vocab_embeddings - vis_token
    nearest_token_idx = torch.norm(offsets, dim=1).argmin()
    nearest_text = model.tokenizer.decode([nearest_token_idx.item()])
    print(nearest_text)

def compute_f1(gold_toks, pred_toks):
    """Token-overlap F1 between a gold and a predicted token sequence.

    Overlap is counted as a multiset intersection. If either sequence is empty the
    result is 1 when the two compare equal and 0 otherwise; with no overlap it is 0.
    """
    gold_counts = collections.Counter(gold_toks)
    pred_counts = collections.Counter(pred_toks)
    num_same = sum((gold_counts & pred_counts).values())

    either_empty = len(gold_toks) == 0 or len(pred_toks) == 0
    if either_empty:
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0

    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    return (2 * precision * recall) / (precision + recall)

def eval_gpt_open_ended(model, dataset, args, print_vis_token_meaning=False):
    """Generate an answer for every dataset item and print aggregate text metrics.

    Reported metrics: BLEU-1, BERTScore F1, token F1, overall exact-match accuracy,
    and exact-match accuracy split into yes/no and open-ended questions.

    Interface this function relies on:
        - ``dataset[i]`` yields ``(prefix, labels, tokens, mask, q_len)``; the dataset
          also exposes ``answers`` and ``max_seqs_len``.
        - ``model`` exposes ``generate``, ``tokenizer``, ``model_type`` and (when
          ``print_vis_token_meaning`` is set) ``prefix_length``.
    NOTE(review): ``EndoVis18VQAGPTGen`` and ``BLIPGPTVQAGen`` as defined earlier in
    this notebook do not provide these members - confirm which dataset/model this
    cell is meant to be used with.

    Args:
        model: model to evaluate; moved to CUDA and put in eval mode.
        dataset: indexable dataset as described above.
        args: not read by this function.
        print_vis_token_meaning: when true, print the nearest vocabulary token for
            each prefix position of the generated embedding.
    """
    model.eval()
    model = model.cuda()

    n_items = len(dataset)
    if n_items == 0:
        # every metric below is an average over the dataset
        print("No samples to evaluate.")
        return

    bert_score = load("bertscore")
    tokenizer = GPT2Tokenizer.from_pretrained(model.model_type)
    bleu_avg1 = 0.
    bert_f1_sum = 0.
    f1_avg = 0.
    acc = 0.
    acc_oe = 0.
    acc_yn = 0.
    # counters start slightly above zero so an absent question type does not divide by zero
    c_oe = 1e-9
    c_yn = 1e-9
    with tqdm(total=n_items) as epoch_pbar:
        epoch_pbar.set_description("Testing")
        for item in range(n_items):
            prefix, labels, tokens, mask, q_len = dataset[item]
            prefix = prefix.type(torch.float32).cuda()
            tokens = tokens.type(torch.long).cuda()
            mask = mask.cuda()
            with autocast(dtype=torch.float16):
                with torch.no_grad():
                    embed = model.generate(prefix, labels, tokens, mask, q_len).view(1, tokens.size(0), -1)
                    if print_vis_token_meaning:
                        prefix_projections = embed[:, q_len:q_len + model.prefix_length, :]
                        for i in range(prefix_projections.size(1)):
                            print_nearest_text_token(prefix_projections[0, i], model)
                    # keep only the best-scoring beam
                    out_text = generate_beam(model, model.tokenizer, generated=embed, entry_length=dataset.max_seqs_len[1], temperature=1)[0]

            gold_answer = dataset.answers[item]
            gold_lower = gold_answer.lower()
            is_exact = out_text.lower() == gold_lower

            if is_exact:
                acc += 1
            if gold_lower == 'yes' or gold_lower == 'no':
                if is_exact:
                    acc_yn += 1
                c_yn += 1
            else:
                if is_exact:
                    acc_oe += 1
                c_oe += 1

            reference = [str(gold_answer)]
            candidate = [out_text]

            # sentence_bleu expects a list of tokenised references and a tokenised
            # hypothesis; passing raw strings makes it score individual characters.
            bleu_1 = sentence_bleu([reference[0].split()], candidate[0].split(), weights=(1, 0, 0, 0))

            a = bert_score.compute(references=reference, predictions=candidate, model_type='bert-base-uncased')
            bert_f1_sum += a['f1'][0]

            f1_avg += compute_f1(tokenizer.encode(reference[0]), tokenizer.encode(candidate[0]))
            bleu_avg1 += bleu_1

            # advance the progress bar (it was created with total=len(dataset))
            epoch_pbar.update(1)

    print('------------')
    print("BLEU {}".format(round(bleu_avg1/n_items,3)))
    print("BERTScore {}".format(round(bert_f1_sum/n_items,3)))
    print("F1 {}".format(round(f1_avg/n_items,3)))
    print("Accuracy {}".format(round(acc/n_items,3)))
    print("Accuracy YN{}".format(round(acc_yn/c_yn,3)))
    print("Accuracy OE{}".format(round(acc_oe/c_oe,3)))