In [1]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn # as nn
from torch.utils.data import DataLoader
import os, yaml#, json
import editdistance
import socket, datetime#, getpass
import wandb as wb


In [2]:
import random
import transformers
from transformers import get_scheduler

from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering
from PIL import Image
# import models._model_utils as model_utils
from torch.utils.data import Dataset

In [3]:
def train_epoch(data_loader, model, optimizer, lr_scheduler, evaluator, logger, **kwargs):
    """Run one optimisation pass over ``data_loader`` and log per-batch metrics.

    For every batch: forward pass (also decoding predicted answers), backward pass,
    optimizer and LR-scheduler step, gradient reset; then accuracy / ANLS of the
    decoded answers are computed with ``evaluator`` and sent to ``logger.logger.log``.

    Args:
        data_loader: iterable of collated batch dicts; each must contain 'answers'.
        model: wrapper exposing ``.model`` (the torch module) and
            ``.forward(batch, return_pred_answer=True)`` returning a 3-tuple whose first
            two items are the model outputs (with ``.loss``) and the predicted answers.
        optimizer: stepped and zeroed once per batch.
        lr_scheduler: stepped once per batch.
        evaluator: provides ``get_metrics(gt_answers, pred_answers)``.
        logger: provides ``.logger.log``, ``.current_epoch`` and ``.len_dataset``.
        **kwargs: accepted for call-site compatibility; unused.
    """
    model.model.train()

    for step_in_epoch, batch in enumerate(tqdm(data_loader)):
        references = batch['answers']
        outputs, predictions, _ = model.forward(batch, return_pred_answer=True)

        # Standard update: backprop, step optimizer and schedule, then clear grads.
        outputs.loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        scores = evaluator.get_metrics(references, predictions)
        mean_accuracy = np.mean(scores['accuracy'])
        mean_anls = np.mean(scores['anls'])

        # Global step = epochs already seen * batches per epoch + batch index.
        global_step = logger.current_epoch * logger.len_dataset + step_in_epoch
        logger.logger.log(
            {
                'Train/Batch loss': outputs.loss.item(),
                'Train/Batch Accuracy': mean_accuracy,
                'Train/Batch ANLS': mean_anls,
                'lr': optimizer.param_groups[0]['lr'],
            },
            step=global_step,
        )


# def seed_worker(worker_id):
#     worker_seed = torch.initial_seed() % 2 ** 32
#     np.random.seed(worker_seed)
#     np.seed(worker_seed)

In [4]:
class IFDocVQA(Dataset):
    """InfographicVQA dataset read from an ``infographics_imdb_<split>.npy`` file.

    Each element of the loaded array is indexed as a record with the keys
    'question', 'question_id' and 'ocr_tokens'. Depending on the options it also
    needs 'image_name' and 'ocr_normalized_boxes'. Labelled splits may carry 'answers'.

    Args:
        imbd_dir: directory containing the ``.npy`` imdb files.
        images_dir: directory containing the images referenced by 'image_name'.
        split: split name used in the file name (e.g. 'train', 'val').
        kwargs: a plain dict of options (not ``**kwargs``):
            'use_images' (bool, default False) -> also return the PIL image.
            'get_raw_ocr_data' (bool, default False) -> also return OCR words/boxes.
    """

    def __init__(self, imbd_dir, images_dir, split, kwargs):
        imdb_path = os.path.join(imbd_dir, "infographics_imdb_{:s}.npy".format(split))
        # SECURITY: allow_pickle=True unpickles the file, which can execute arbitrary
        # code. Only load imdb files from a trusted source.
        self.imdb = np.load(imdb_path, allow_pickle=True)

        self.max_answers = 2  # kept for compatibility; not read inside this class
        self.images_dir = images_dir

        self.use_images = kwargs.get('use_images', False)
        self.get_raw_ocr_data = kwargs.get('get_raw_ocr_data', False)

    def __len__(self):
        return len(self.imdb)

    def __getitem__(self, idx):
        """Return one sample as a dict (see the keys built below)."""
        record = self.imdb[idx]
        question = record['question']
        words = [word.lower() for word in record['ocr_tokens']]
        context = ' '.join(words)

        # TODO: handle train and inference splits separately.
        if 'answers' in record:
            # dict.fromkeys de-duplicates while keeping first-seen order. Unlike set(),
            # the order does not depend on per-process string-hash randomisation, so the
            # random.choice over answer spans below is reproducible under a fixed seed.
            answers = list(dict.fromkeys(answer.lower() for answer in record['answers']))
        else:
            # Unlabelled split: placeholder answer so code expecting a non-empty list works.
            answers = ['0' * len(question)]

        start_idx, end_idx = self._get_start_end_idx(context, answers)

        sample_info = {
            'question_id': record['question_id'],
            'questions': question,
            'contexts': context,
            'answers': answers,
            'start_indxs': start_idx,
            'end_indxs': end_idx,
        }

        if self.use_images:
            image_name = os.path.join(self.images_dir, "{:s}".format(record['image_name']))
            sample_info['image_names'] = image_name
            sample_info['images'] = Image.open(image_name).convert("RGB")

        if self.get_raw_ocr_data:
            sample_info['words'] = words
            sample_info['boxes'] = np.array(list(record['ocr_normalized_boxes']))

        return sample_info

    def _get_start_end_idx(self, context, answers):
        """Character offsets ``(start, end)`` (end exclusive) of one answer in ``context``.

        Only the first occurrence of each answer is considered. When several answers
        occur, one of them is picked with ``random.choice``. When none occurs, ``(0, 0)``
        is returned.
        """
        answer_positions = []
        for answer in answers:
            start_idx = context.find(answer)
            if start_idx != -1:
                answer_positions.append([start_idx, start_idx + len(answer)])

        if answer_positions:
            start_idx, end_idx = random.choice(answer_positions)
        else:
            start_idx, end_idx = 0, 0  # no answer found in the context

        return start_idx, end_idx

In [5]:
def build_dataset(config, split):
    """Instantiate the InfographicVQA dataset for ``split``.

    Raw OCR words/boxes and images are always requested, because the LayoutLMv3
    wrapper's forward pass reads ``batch['words']``, ``batch['boxes']`` and
    ``batch['images']``.
    """
    dataset_options = {'get_raw_ocr_data': True, 'use_images': True}
    return IFDocVQA(config['imdb_dir'], config['images_dir'], split, dataset_options)

In [6]:
class Logger:
    """Weights & Biases logging helper for the training loop.

    Creates a W&B run on construction. ``current_epoch`` and ``len_dataset`` are set by
    the caller so steps can be computed as ``current_epoch * len_dataset + batch_idx``.
    """

    def __init__(self, config):
        """Start a W&B run described by ``config`` and print the logged hyper-parameters.

        Required keys: 'save_dir', 'model_name', 'dataset_name', 'model_weights',
        'batch_size', 'lr', 'seed'. Optional keys: 'visual_module', 'max_sequence_length',
        'max_pages', 'page_tokens', 'wandb_project'.
        """
        self.log_folder = config['save_dir']

        experiment_date = datetime.datetime.now().strftime('%Y.%m.%d_%H.%M.%S')
        self.experiment_name = "{:s}__{:}".format(config['model_name'], experiment_date)

        # Map known hostnames to friendly tags; unknown hosts are tagged by hostname.
        hostname = socket.gethostname()
        machine_dict = {'cvc117': 'Local', 'cudahpc16': 'DAG', 'cudahpc25': 'DAG-A40'}
        machine = machine_dict.get(hostname, hostname)

        dataset = config['dataset_name']
        visual_encoder = config.get('visual_module', {}).get('model', '-').upper()
        document_pages = config.get('max_pages', None)
        page_tokens = config.get('page_tokens', None)

        tags = [config['model_name'], dataset, machine]
        # Separate name so the ``config`` argument is not shadowed.
        run_config = {'Model': config['model_name'], 'Weights': config['model_weights'], 'Dataset': dataset,
                      'Visual Encoder': visual_encoder,
                      'Batch size': config['batch_size'], 'Max. Seq. Length': config.get('max_sequence_length', '-'),
                      'lr': config['lr'], 'seed': config['seed']}

        if document_pages:
            run_config['Max Pages'] = document_pages

        if page_tokens:
            run_config['PAGE tokens'] = page_tokens

        # The project can be overridden via config['wandb_project']; the default keeps the
        # original project so existing runs stay grouped.
        # NOTE(review): the default project name is unprofessional; consider renaming it.
        project = config.get('wandb_project', "Hyunyoung in the house motherfuckers~")
        self.logger = wb.init(project=project, name=self.experiment_name, dir=self.log_folder, tags=tags, config=run_config)
        self._print_config(run_config)

        self.current_epoch = 0
        self.len_dataset = 0

    def _print_config(self, config):
        """Print the run config as ``Model: Weights {key: value, ...}``."""
        print("{}: {} \n{{".format(config['Model'], config['Weights']))
        for k, v in config.items():
            if k != 'Model' and k != 'Weights':
                print("\t{:}: {:}".format(k, v))
        print("}\n")

    def log_model_parameters(self, model):
        """Record total/trainable parameter counts (in millions) in the run config and print them."""
        total_params = sum(p.numel() for p in model.model.parameters())
        trainable_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad)

        self.logger.config.update({
            'Model Params': int(total_params / 1e6),  # In millions
            'Model Trainable Params': int(trainable_params / 1e6)  # In millions
        })

        # Guard against a parameter-less model, which would otherwise divide by zero.
        trainable_pct = trainable_params / total_params * 100 if total_params else 0.0
        print("Model parameters: {:d} - Trainable: {:d} ({:2.2f}%)".format(
            total_params, trainable_params, trainable_pct))

    def log_val_metrics(self, accuracy, anls, ret_prec, update_best=False):
        """Log validation metrics at step ``(current_epoch + 1) * len_dataset``.

        When ``update_best`` is True, also stores the best accuracy/epoch in the run config.
        """
        str_msg = "Epoch {:d}: Accuracy {:2.2f}     ANLS {:2.4f}    Retrieval precision: {:2.2f}%".format(self.current_epoch, accuracy*100, anls, ret_prec*100)
        self.logger.log({
            'Val/Epoch Accuracy': accuracy,
            'Val/Epoch ANLS': anls,
            'Val/Epoch Ret. Prec': ret_prec,
        }, step=self.current_epoch*self.len_dataset + self.len_dataset)

        if update_best:
            str_msg += "\tBest Accuracy!"
            self.logger.config.update({
                "Best Accuracy": accuracy,
                "Best epoch": self.current_epoch
            }, allow_val_change=True)

        print(str_msg)



In [7]:
class Evaluator:
    """Per-sample accuracy and ANLS scoring for predicted answer strings.

    Both the predictions and the ground-truth answers are normalised with
    ``_preprocess_str`` before comparison (lower-cased unless ``case_sensitive``,
    then stripped). ANLS uses ``editdistance.eval`` and a threshold of 0.5.
    """

    def __init__(self, case_sensitive=False):

        self.case_sensitive = case_sensitive
        self.get_edit_distance = editdistance.eval
        self.anls_threshold = 0.5

        self.total_accuracies = []
        self.total_anls = []

        self.best_accuracy = 0
        self.best_epoch = 0

    def get_metrics(self, gt_answers, preds, answer_types=None, update_global_metrics=True):
        """Score each prediction against its list of acceptable answers.

        Args:
            gt_answers: one list of acceptable answer strings per sample.
            preds: one predicted answer string per sample (None is treated as '').
            answer_types: optional per-sample type; 'not-answerable' expects an empty /
                'none' / 'NA' prediction. Defaults to 'string' for all samples.
            update_global_metrics: accepted for backward compatibility; unused.

        Returns:
            ``{'accuracy': [...], 'anls': [...]}`` with one value per prediction.
        """
        if answer_types is None:
            answer_types = ['string'] * len(gt_answers)

        batch_accuracy = []
        batch_anls = []
        for idx in range(len(preds)):
            gt = [self._preprocess_str(gt_elm) for gt_elm in gt_answers[idx]]
            pred = self._preprocess_str(preds[idx])

            batch_accuracy.append(self._calculate_accuracy(gt, pred, answer_types[idx]))
            batch_anls.append(self._calculate_anls(gt, pred, answer_types[idx]))

        return {'accuracy': batch_accuracy, 'anls': batch_anls}

    def get_retrieval_metric(self, gt_answer_page, pred_answer_page):
        """1 where the predicted page equals the ground-truth page, else 0."""
        retrieval_precision = [1 if gt == pred else 0 for gt, pred in zip(gt_answer_page, pred_answer_page)]
        return retrieval_precision

    def update_global_metrics(self, accuracy, anls, current_epoch):
        """Track the best accuracy seen so far; return True if ``accuracy`` is a new best."""
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_epoch = current_epoch
            return True

        return False

    def _preprocess_str(self, string):
        """Normalise an answer string: None -> '', optional lower-casing, then strip."""
        if string is None:
            return ''
        if not self.case_sensitive:
            string = string.lower()

        return string.strip()

    def _is_no_answer(self, pred):
        """True if ``pred`` is one of the "no answer" sentinels ('', 'none', 'NA').

        The sentinels go through the same normalisation as ``pred``, so 'NA' still matches
        when comparisons are case-insensitive.
        """
        return pred in {self._preprocess_str(s) for s in ('', 'none', 'NA')}

    def _calculate_accuracy(self, gt, pred, answer_type):
        """Exact-match accuracy (1 or 0) of ``pred`` against the answers in ``gt``."""
        if answer_type == 'not-answerable':
            return 1 if self._is_no_answer(pred) else 0

        if pred == 'none':
            return 0

        return 1 if pred in gt else 0

    def _calculate_anls(self, gt, pred, answer_type):
        """ANLS of ``pred``: best normalised Levenshtein similarity, zeroed below threshold."""
        # Check not-answerable first so an empty prediction scores the same here as in
        # _calculate_accuracy.
        if answer_type == 'not-answerable':
            return 1 if self._is_no_answer(pred) else 0

        if len(pred) == 0 or pred == 'none' or not gt:
            return 0

        answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt]
        max_similarity = max(answers_similarity)

        return max_similarity if max_similarity >= self.anls_threshold else 0

In [8]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Note: setting PYTHONHASHSEED here only affects Python processes started afterwards.
    The running interpreter's string-hash seed was fixed at start-up.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and choose algorithms non-deterministically,
    # which would defeat deterministic=True above.
    torch.backends.cudnn.benchmark = False

In [9]:
def Info_docvqa_collate_fn(batch):
    """Transpose a list of sample dicts into a dict of lists (one list per key).

    Keys are taken from the first sample; every sample must provide them.
    """
    field_names = batch[0].keys()
    return {name: [sample[name] for sample in batch] for name in field_names}

In [10]:
def build_optimizer(model, length_train_loader, config):
    """Create the AdamW optimizer and a linear warm-up/decay LR schedule.

    Args:
        model: wrapper exposing ``.model`` (the torch module to optimise).
        length_train_loader: number of batches per epoch.
        config: reads 'lr', 'train_epochs' and 'warmup_iterations'.

    Returns:
        ``(optimizer, lr_scheduler)``; the schedule spans ``train_epochs * length_train_loader`` steps.
    """
    # transformers.AdamW is deprecated and removed in recent transformers releases.
    # torch.optim.AdamW with eps=1e-6 and weight_decay=0.0 uses the same defaults it had.
    optimizer = torch.optim.AdamW(model.model.parameters(), lr=float(config['lr']), eps=1e-6, weight_decay=0.0)
    num_training_steps = config['train_epochs'] * length_train_loader
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=config['warmup_iterations'], num_training_steps=num_training_steps
    )

    return optimizer, lr_scheduler

In [11]:
def evaluate(data_loader, model, evaluator, **kwargs):
    """Run ``model`` over ``data_loader`` without gradients and score its predictions.

    Keyword args:
        return_scores_by_sample (bool, default False): return per-sample score lists and
            a dict keyed by question_id. Otherwise return dataset-averaged scalars.
        return_answers (bool, default False): also collect every predicted answer. The
            alias ``return_pred_answers`` (the name ``train`` passes) is honoured too.

    Returns:
        ``(accuracy, anls, ret_prec, pred_answers, scores_by_samples)``.
        - accuracy / anls / ret_prec: lists in per-sample mode, means otherwise.
          Retrieval precision is not computed, so it is ``[]`` or ``0.0``.
        - scores_by_samples: ``{}``-like dict in per-sample mode, ``[]`` otherwise.
    """
    return_scores_by_sample = kwargs.get('return_scores_by_sample', False)
    return_answers = kwargs.get('return_answers', kwargs.get('return_pred_answers', False))

    if return_scores_by_sample:
        scores_by_samples = {}
        total_accuracies = []
        total_anls = []
        total_ret_prec = []
    else:
        total_accuracies = 0
        total_anls = 0
        total_ret_prec = 0

    all_pred_answers = []
    model.model.eval()

    for batch in tqdm(data_loader):
        with torch.no_grad():
            _, pred_answers, answer_conf = model.forward(batch, return_pred_answer=True)

        metric = evaluator.get_metrics(batch['answers'], pred_answers, batch.get('answer_type', None))

        if return_scores_by_sample:
            image_names = batch.get('image_names')  # absent when the dataset has no images
            for sample_idx, question_id in enumerate(batch['question_id']):
                # TODO: at inference time, drop every field here that needs ground-truth answers.
                scores_by_samples[question_id] = {
                    'accuracy': metric['accuracy'][sample_idx],
                    'anls': metric['anls'][sample_idx],
                    'pred_answer': pred_answers[sample_idx],
                    'pred_answer_conf': answer_conf[sample_idx],
                    'image_names': image_names[sample_idx] if image_names is not None else None,
                    'question': batch['questions'][sample_idx],
                }
            total_accuracies.extend(metric['accuracy'])
            total_anls.extend(metric['anls'])
        else:
            total_accuracies += sum(metric['accuracy'])
            total_anls += sum(metric['anls'])

        if return_answers:
            all_pred_answers.extend(pred_answers)

    if not return_scores_by_sample:
        num_samples = max(len(data_loader.dataset), 1)  # avoid dividing by zero on an empty set
        total_accuracies = total_accuracies / num_samples
        total_anls = total_anls / num_samples
        total_ret_prec = total_ret_prec / num_samples
        scores_by_samples = []

    return total_accuracies, total_anls, total_ret_prec, all_pred_answers, scores_by_samples

In [12]:
def save_yaml(path, data):
    """Serialise ``data`` as YAML into ``path`` (file is created or truncated)."""
    with open(path, 'w+') as stream:
        yaml.dump(data, stream=stream)

In [13]:
def save_model(model, epoch, update_best=False, **kwargs):
    """Save the best checkpoint: weights, tokenizer/processor and experiment config.

    Writes nothing unless ``update_best`` is True. Then it writes to
    ``<save_dir>/checkpoints/<model_name>_<dataset_name>/best.ckpt``.
    ``epoch`` is accepted for call-site compatibility; per-epoch checkpoints are disabled.

    Args:
        model: wrapper exposing ``.model`` (with ``save_pretrained``) and optionally
            ``.tokenizer`` or ``.processor``.
        kwargs: experiment config; reads 'save_dir', 'model_name' and 'dataset_name', and
            is dumped to ``experiment_config.yml``.
    """
    if not update_best:
        return

    save_dir = os.path.join(kwargs['save_dir'], 'checkpoints', "{:s}_{:s}".format(kwargs['model_name'].lower(), kwargs['dataset_name'].lower()))
    best_dir = os.path.join(save_dir, "best.ckpt")
    os.makedirs(best_dir, exist_ok=True)

    # Prefer a tokenizer; fall back to a processor; the model may have neither.
    tokenizer = getattr(model, 'tokenizer', None)
    if tokenizer is None:
        tokenizer = getattr(model, 'processor', None)

    model.model.save_pretrained(best_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(best_dir)
    save_yaml(os.path.join(best_dir, "experiment_config.yml"), kwargs)

In [14]:
def train(model, **kwargs):
    """Full training loop.

    Runs an optional initial validation, then ``train_epochs`` epochs of training plus
    validation. A checkpoint is saved whenever validation accuracy improves.

    All settings come from ``kwargs`` (the experiment config), never from notebook
    globals. Reads 'train_epochs', 'batch_size', 'seed', 'eval_start' (optional) and
    whatever the helper functions read.
    """
    epochs = kwargs['train_epochs']
    batch_size = kwargs['batch_size']
    seed_everything(kwargs['seed'])

    evaluator = Evaluator(case_sensitive=False)
    logger = Logger(config=kwargs)
    logger.log_model_parameters(model)

    # Use the config passed in, not the notebook-global ``config``: relying on the global
    # silently breaks when train() is called with a different configuration.
    train_dataset = build_dataset(kwargs, 'train')
    val_dataset = build_dataset(kwargs, 'val')
    print('done')

    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=Info_docvqa_collate_fn)
    val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=Info_docvqa_collate_fn)

    logger.len_dataset = len(train_data_loader)
    optimizer, lr_scheduler = build_optimizer(model, length_train_loader=len(train_data_loader), config=kwargs)

    def _validate(epoch):
        # Evaluate on the validation split, update the best-so-far tracker and log.
        accuracy, anls, ret_prec, _, _ = evaluate(val_data_loader, model, evaluator, return_scores_by_sample=False, return_answers=False, **kwargs)
        is_updated = evaluator.update_global_metrics(accuracy, anls, epoch)
        logger.log_val_metrics(accuracy, anls, ret_prec, update_best=is_updated)
        return is_updated

    if kwargs.get('eval_start', False):
        logger.current_epoch = -1
        _validate(-1)

    for epoch_ix in range(epochs):
        logger.current_epoch = epoch_ix
        train_epoch(train_data_loader, model, optimizer, lr_scheduler, evaluator, logger, **kwargs)
        is_updated = _validate(epoch_ix)
        save_model(model, epoch_ix, update_best=is_updated, **kwargs)


In [15]:
def get_extractive_confidence(outputs):
    """Confidence of the arg-max answer span for each sample of an extractive QA output.

    For sample ``b`` the confidence is ``P_start[b, s] * P_end[b, e]``. Both probabilities
    come from a softmax over dim 1 of ``outputs.start_logits`` / ``outputs.end_logits``.
    ``s`` and ``e`` are the arg-max positions of the logits. This equals the ``[s, e]``
    entry of the outer product of the two distributions, without building the
    seq_len x seq_len matrix.

    Returns:
        list of Python floats, one per sample.
    """
    start_idxs = torch.argmax(outputs.start_logits, dim=1).cpu()
    end_idxs = torch.argmax(outputs.end_logits, dim=1).cpu()

    start_probs = outputs.start_logits.softmax(dim=1).detach().cpu()
    end_probs = outputs.end_logits.softmax(dim=1).detach().cpu()

    rows = torch.arange(start_probs.size(0))
    return (start_probs[rows, start_idxs] * end_probs[rows, end_idxs]).tolist()

In [16]:
class LayoutLMv3_hy:
    """Extractive-QA wrapper around ``LayoutLMv3ForQuestionAnswering`` and its processor.

    Exposes ``.model`` (the HF module) and ``.processor``. Implements
    ``forward(batch, return_pred_answer)``, which returns ``(outputs, pred_answers, confidences)``.
    """

    def __init__(self, config):
        self.batch_size = config['batch_size']
        self.apply_ocr = config['apply_ocr']
        # With apply_ocr=False, forward() feeds the dataset's OCR words and boxes to the processor.
        self.processor = LayoutLMv3Processor.from_pretrained(config['model_weights'], apply_ocr=config['apply_ocr'])
        self.model = LayoutLMv3ForQuestionAnswering.from_pretrained(config['model_weights'])
        # Target position used when no answer span is found in the encoded sequence.
        # NOTE(review): assumes the HF QA head clamps out-of-range positions to seq_len
        # and ignores them in the loss — confirm for the installed transformers version.
        self.ignore_index = 9999

    def forward(self, batch, return_pred_answer=False):
        """Encode ``batch``, run the model with span targets, optionally decode answers.

        Returns:
            ``(outputs, pred_answers, answ_confidence)``. The last two are None unless
            ``return_pred_answer`` is True.
        """
        question = batch['questions']
        context = batch['contexts']
        answers = batch['answers']
        images = batch['images']

        if self.apply_ocr:
            # The processor runs OCR itself, but the question is still the text input.
            encoding = self.processor(images, question, return_tensors="pt", padding=True, truncation=True)
        else:
            # Dataset boxes are normalised to [0, 1]; scale to the 0-1000 integer range.
            boxes = [(bbox * 1000).astype(int) for bbox in batch['boxes']]
            encoding = self.processor(images, question, batch["words"], boxes=boxes, return_tensors="pt", padding=True, truncation=True)
        encoding = encoding.to(self.model.device)

        start_pos, end_pos = self.get_start_end_idx(encoding, context, answers)
        outputs = self.model(**encoding, start_positions=start_pos, end_positions=end_pos)

        pred_answers, answ_confidence = None, None
        if return_pred_answer:
            pred_answers, answ_confidence = self.get_answer_from_model_output(encoding.input_ids, outputs)

        return outputs, pred_answers, answ_confidence

    def get_concat_v_multi_resize(self, im_list, resample=Image.BICUBIC):
        """Stack images vertically on one canvas.

        Every image is first resized to the smallest width (keeping its aspect ratio).
        Then each is stretched to the tallest original height.
        """
        min_width = min(im.width for im in im_list)
        im_list_resize = [im.resize((min_width, int(im.height * min_width / im.width)), resample=resample) for im in im_list]

        # Fix equal height for all images (breaks the aspect ratio). The width stays
        # min_width; the old code passed the height as the width here.
        max_height = max(im.height for im in im_list)
        im_list_resize = [im.resize((min_width, max_height), resample=resample) for im in im_list_resize]

        total_height = sum(im.height for im in im_list_resize)
        dst = Image.new('RGB', (min_width, total_height))
        pos_y = 0
        for im in im_list_resize:
            dst.paste(im, (0, pos_y))
            pos_y += im.height
        return dst

    def get_start_end_idx(self, encoding, context, answers):
        """Token-level (start, end) targets (end inclusive) for each sample.

        For every answer, all exact token-id matches in the encoded sequence are
        collected and one is picked at random. Samples without any match get
        ``self.ignore_index`` for both positions. ``context`` is unused.
        """
        special_ids = set(self.processor.tokenizer.all_special_ids)
        pos_idx = []
        for batch_idx in range(len(encoding.input_ids)):
            input_ids = encoding.input_ids[batch_idx].tolist()
            answer_pos = []
            for answer in answers[batch_idx]:
                # Encode the answer as a single pre-split "word" with a dummy box.
                encoded_answer = [token for token in self.processor.tokenizer.encode([answer], boxes=[0, 0, 0, 0]) if token not in special_ids]
                answer_len = len(encoded_answer)
                if answer_len == 0:
                    continue  # an empty answer would "match" everywhere with end < start

                for token_pos in range(len(input_ids)):
                    if input_ids[token_pos: token_pos + answer_len] == encoded_answer:
                        answer_pos.append([token_pos, token_pos + answer_len - 1])

            if answer_pos:
                pos_idx.append(random.choice(answer_pos))  # pick a random correct span for variability
            else:
                pos_idx.append([self.ignore_index, self.ignore_index])

        start_idxs = torch.LongTensor([idx[0] for idx in pos_idx]).to(self.model.device)
        end_idxs = torch.LongTensor([idx[1] for idx in pos_idx]).to(self.model.device)

        return start_idxs, end_idxs

    def get_answer_from_model_output(self, input_tokens, outputs):
        """Decode the arg-max span of each sample and compute its confidence.

        The confidence is ``P(start) * P(end)``, with softmax over the sequence dimension.

        Returns:
            ``(predicted_answers, answ_confidence)``: lists of str and float.
        """
        predicted_start_idxs = torch.argmax(outputs.start_logits, dim=1)
        predicted_end_idxs = torch.argmax(outputs.end_logits, dim=1)

        predicted_answers = [
            self.processor.tokenizer.decode(
                input_tokens[batch_idx][predicted_start_idxs[batch_idx]: predicted_end_idxs[batch_idx] + 1],
                skip_special_tokens=True,
            ).strip()
            for batch_idx in range(len(input_tokens))
        ]

        start_probs = outputs.start_logits.softmax(dim=1).detach().cpu()
        end_probs = outputs.end_logits.softmax(dim=1).detach().cpu()
        rows = torch.arange(start_probs.size(0))
        answ_confidence = (start_probs[rows, predicted_start_idxs.cpu()] * end_probs[rows, predicted_end_idxs.cpu()]).tolist()

        return predicted_answers, answ_confidence

In [17]:
# Experiment configuration; passed to train() as **kwargs.
args_dict = {
    "dataset": "infographics",
    "eval_start": True,
    "no_eval_start": False,
    "page_retrieval": None,
    "max_sequence_length": None,
    "seed": 42,
    "save_dir": "saving_dir/",
    "apply_ocr": False,
    "data_parallel": False,
    "no_data_parallel": False,
    "model_name": "hy",
    "model_weights": "microsoft/layoutlmv3-base",
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # Training parameters
    "lr": 1e-4,
    "batch_size": 20,  # single definition (a duplicate "batch_size": None key was dropped)
    "train_epochs": 10,
    "warmup_iterations": 5,
    "dataset_name": "infographicVQA",
    # Alternative data location: ./task3/imdb and ./task3/images
    "imdb_dir": "./Task3_test/imdb",
    "images_dir": "./Task3_test/images",
}


In [18]:
# Normalise the experiment config, then build the model on the target device.
config = args_dict
model_name = config['model_name'].lower()
if 'save_dir' in config:
    if not config['save_dir'].endswith('/'):
        config['save_dir'] = config['save_dir'] + '/'
    os.makedirs(config['save_dir'], exist_ok=True)

model = LayoutLMv3_hy(config)

if config['device'] == 'cuda' and config['data_parallel'] and torch.cuda.device_count() > 1:
    if hasattr(model, 'parallelize'):
        model.parallelize()
    else:
        # The wrapper has no parallelize(); run on a single device instead of raising AttributeError.
        print("data_parallel requested but the model wrapper has no parallelize(); using a single device.")

model.model.to(config['device'])

Some weights of LayoutLMv3ForQuestionAnswering were not initialized from the model checkpoint at microsoft/layoutlmv3-base and are newly initialized: ['qa_outputs.dense.bias', 'qa_outputs.dense.weight', 'qa_outputs.out_proj.bias', 'qa_outputs.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LayoutLMv3ForQuestionAnswering(
  (layoutlmv3): LayoutLMv3Model(
    (embeddings): LayoutLMv3TextEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (x_position_embeddings): Embedding(1024, 128)
      (y_position_embeddings): Embedding(1024, 128)
      (h_position_embeddings): Embedding(1024, 128)
      (w_position_embeddings): Embedding(1024, 128)
    )
    (patch_embed): LayoutLMv3PatchEmbeddings(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (encoder): LayoutLMv3Encoder(


In [19]:
# Launch training; every setting comes from the `config` dict built in the setup cell.
train(model, **config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbigchoi3449[0m ([33mlevel2-cv-10-detection[0m). Use [1m`wandb login --relogin`[0m to force relogin


hy: microsoft/layoutlmv3-base 
{
	Dataset: infographicVQA
	Visual Encoder: -
	Batch size: 20
	Max. Seq. Length: None
	lr: 0.0001
	seed: 42
}

Model parameters: 125919106 - Trainable: 125919106 (100.00%)




done


  return torch.tensor(value)
100%|██████████| 2/2 [00:06<00:00,  3.18s/it]


Epoch -1: Accuracy 0.00     ANLS 0.0000    Retrieval precision: 0.00%


100%|██████████| 8/8 [01:04<00:00,  8.10s/it]
100%|██████████| 2/2 [00:04<00:00,  2.33s/it]


Epoch 0: Accuracy 0.00     ANLS 0.0000    Retrieval precision: 0.00%


100%|██████████| 8/8 [01:06<00:00,  8.27s/it]
100%|██████████| 2/2 [00:04<00:00,  2.30s/it]


Epoch 1: Accuracy 0.00     ANLS 0.0000    Retrieval precision: 0.00%


100%|██████████| 8/8 [01:05<00:00,  8.20s/it]
100%|██████████| 2/2 [00:04<00:00,  2.27s/it]


Epoch 2: Accuracy 0.00     ANLS 0.0000    Retrieval precision: 0.00%


100%|██████████| 8/8 [01:05<00:00,  8.14s/it]
100%|██████████| 2/2 [00:04<00:00,  2.25s/it]


Epoch 3: Accuracy 0.00     ANLS 0.0000    Retrieval precision: 0.00%


 75%|███████▌  | 6/8 [01:03<00:21, 10.54s/it]


KeyboardInterrupt: 