<a href="https://colab.research.google.com/github/harshalDharpure/Multimodality_Hateful_Meme/blob/main/Dataset_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# PromptHate dataset code, modified for Hindi datasets

In [None]:
import os
import json
import pickle as pkl
import numpy as np
import torch
from tqdm import tqdm

class Multimodal_Data:
    """Prompt-based multimodal meme dataset (PromptHate-style), adapted for Hindi data.

    Each example's text is its image caption followed by the meme text and,
    optionally, entity / demographic annotations. ``__getitem__`` renders the
    query plus demonstrations sampled from the training split into a single
    masked-language-model prompt::

        <s> query_text It was <mask> . </s> demo_text It was <label_word> . </s> ...

    Args:
        opt: configuration object. Attributes read by this class include
            FEW_SHOT, NUM_SHOTS, NUM_LABELS, LENGTH, TOTAL_LENGTH, NUM_SAMPLE,
            ADD_ENT, ADD_DEM, FINE_GRIND, POS_WORD, NEG_WORD, DEM_SAMP,
            IMG_RATE, TEXT_RATE, SIM_RATE, CLIP_CLEAN, CAPTION_PATH, DATA,
            DATASET, PRETRAIN_DATA, IMG_VERSION and DEBUG.
        tokenizer: HuggingFace-style tokenizer providing ``tokenize``,
            ``convert_tokens_to_ids``, ``encode``, ``mask_token_id`` and
            ``pad_token_id``. The prompt uses ``<s>``/``</s>``/``<mask>``, so a
            RoBERTa-style vocabulary is presumably expected (TODO confirm).
        dataset: dataset name used only to locate the CLIP similarity pickle
            (``load_entries`` uses ``opt.DATASET`` instead).
        mode: split loaded as queries (default ``'train'``). Demonstrations
            always come from the ``'train'`` split.
        few_shot_index: few-shot iteration id; only logged.
    """

    def __init__(self, opt, tokenizer, dataset, mode='train', few_shot_index=0):
        super().__init__()
        self.opt = opt
        self.tokenizer = tokenizer
        self.mode = mode
        if self.opt.FEW_SHOT:
            self.few_shot_index = str(few_shot_index)
            self.num_shots = self.opt.NUM_SHOTS
            print('Few shot learning setting for Iteration:', self.few_shot_index)
            print('Number of shots:', self.num_shots)

        self.num_ans = self.opt.NUM_LABELS
        # Maximum number of tokens kept from a single example's text
        self.length = self.opt.LENGTH
        # Maximum length of the concatenated prompt (query + demonstrations)
        self.total_length = self.opt.TOTAL_LENGTH
        self.num_sample = self.opt.NUM_SAMPLE
        self.add_ent = self.opt.ADD_ENT
        self.add_dem = self.opt.ADD_DEM
        print('Adding entity information?', self.add_ent)
        print('Adding demographic information?', self.add_dem)

        # NOTE: FINE_GRIND currently uses the same binary label words as the
        # default setting; no fine-grained label mapping exists for this data.
        self.label_mapping_word = {0: self.opt.POS_WORD, 1: self.opt.NEG_WORD}

        self.label_mapping_id = {}
        for label, mapping_word in self.label_mapping_word.items():
            # Only the id of the first sub-token of each label word is kept.
            # The leading space mirrors how the word follows ' It was' in the
            # prompt built by process_prompt.
            first_piece = self.tokenizer.tokenize(' ' + mapping_word)[0]
            self.label_mapping_id[label] = self.tokenizer.convert_tokens_to_ids(first_piece)
            print('Mapping for label %d, word %s, index %d' % (label, mapping_word, self.label_mapping_id[label]))

        self.template = "*<s>**sent_0*.*_It_was*label_**</s>*"
        self.template_list = self.template.split('*')
        print('Template:', self.template)
        print('Template list:', self.template_list)

        self.special_token_mapping = {
            '<s>': self.tokenizer.convert_tokens_to_ids('<s>'),
            '<mask>': self.tokenizer.mask_token_id,
            '<pad>': self.tokenizer.pad_token_id,
            '</s>': self.tokenizer.convert_tokens_to_ids('</s>'),
        }

        if self.opt.DEM_SAMP:
            print('Using demonstration sampling strategy...')
            self.img_rate = self.opt.IMG_RATE
            self.text_rate = self.opt.TEXT_RATE
            self.samp_rate = self.opt.SIM_RATE
            print('Image rate for measuring CLIP similarity:', self.img_rate)
            print('Text rate for measuring CLIP similarity:', self.text_rate)
            print('Sampling from top:', self.samp_rate * 100.0, 'examples')
            self.clip_clean = self.opt.CLIP_CLEAN
            clip_path = os.path.join(self.opt.CAPTION_PATH, dataset, dataset + '_sim_scores.pkl')
            print('Clip feature path:', clip_path)
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with open(clip_path, 'rb') as f:
                self.clip_feature = pkl.load(f)

        self.support_examples = self.load_entries('train')
        print('Length of supporting examples:', len(self.support_examples))
        self.entries = self.load_entries(mode)
        if self.opt.DEBUG:
            self.entries = self.entries[:128]
        self.prepare_exp()
        print('The length of the dataset for:', mode, 'is:', len(self.entries))

    def load_entries(self, mode):
        """Load one split as a list of ``{'cap', 'label', 'img'}`` dicts.

        ``cap`` is the image caption with its last character dropped (assumed
        to be end punctuation), followed by the row's ``clean_sent`` and,
        when enabled, its ``entity`` and ``race`` fields, joined by ' . '.
        """
        path = os.path.join(self.opt.DATA, 'domain_splits', self.opt.DATASET + '_' + mode + '.json')
        # Binary mode lets json detect the UTF-8 encoding of the (Hindi) text.
        with open(path, 'rb') as f:
            data = json.load(f)
        cap_path = os.path.join(self.opt.CAPTION_PATH, self.opt.DATASET + '_' + self.opt.PRETRAIN_DATA,
                                self.opt.IMG_VERSION + '_captions.pkl')
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(cap_path, 'rb') as f:
            captions = pkl.load(f)

        entries = []
        for row in data:
            label = row['label']
            img = row['img']
            # Captions are keyed by the file name up to its first '.'.
            cap = captions[img.split('.')[0]][:-1]  # remove the punctuation at the end
            # Keep explicit ' . ' separators between the text segments.
            cap = cap + ' . ' + row['clean_sent'] + ' . '
            if self.add_ent:
                cap = cap + ' . ' + row['entity'] + ' . '
            if self.add_dem:
                cap = cap + ' . ' + row['race'] + ' . '
            entries.append({'cap': cap.strip(), 'label': label, 'img': img})
        return entries

    def enc(self, text):
        """Tokenize ``text`` into ids without adding special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def _candidate_indices(self, query_idx, support_indices):
        """Return the support indices usable as demonstrations for ``query_idx``.

        In train mode the support set is the query split itself, so the query
        is excluded to avoid using an example as its own demonstration.
        """
        if self.mode != "train":
            return list(support_indices)
        return [idx for idx in support_indices if idx != query_idx]

    def _similar_indices(self, query_idx, candidates):
        """Keep the candidates most CLIP-similar to the query, capped per label.

        Similarity is ``img_rate * image_sim + text_rate * text_sim``, read from
        ``self.clip_feature``. Each label keeps at most
        ``len(candidates) // NUM_LABELS * samp_rate`` of its highest-scoring
        candidates, returned in descending-similarity order.
        """
        clip_info_que = self.clip_feature[self.entries[query_idx]['img']]
        img_sims = clip_info_que['clean_img'] if self.clip_clean else clip_info_que['img']
        text_sims = clip_info_que['text']

        sim_score = []
        for support_idx in candidates:
            img = self.support_examples[support_idx]['img']
            total_sim = self.img_rate * img_sims[img] + self.text_rate * text_sims[img]
            sim_score.append((support_idx, total_sim))
        # Stable sort: ties keep candidate order.
        sim_score.sort(key=lambda pair: pair[1], reverse=True)

        num_valid = int(len(sim_score) // self.opt.NUM_LABELS * self.samp_rate)
        count_each_label = {label: 0 for label in range(self.opt.NUM_LABELS)}
        context_indices = []
        for support_idx, _ in sim_score:
            cur_label = self.support_examples[support_idx]['label']
            if count_each_label[cur_label] < num_valid:
                count_each_label[cur_label] += 1
                context_indices.append(support_idx)
        return context_indices

    def prepare_exp(self):
        """Build ``self.example_idx`` as (query_idx, context_indices, sample_idx) tuples.

        There is one tuple per query per sample, ordered by sample then query.
        Candidate demonstrations depend only on the query, so they are computed
        once per query and shared across samples. The lists are treated as
        read-only downstream.
        """
        support_indices = list(range(len(self.support_examples)))
        per_query_context = []
        for query_idx in tqdm(range(len(self.entries))):
            candidates = self._candidate_indices(query_idx, support_indices)
            if self.opt.DEM_SAMP:
                candidates = self._similar_indices(query_idx, candidates)
            per_query_context.append(candidates)

        self.example_idx = [
            (query_idx, per_query_context[query_idx], sample_idx)
            for sample_idx in range(self.num_sample)
            for query_idx in range(len(self.entries))
        ]

    def select_context(self, context_examples, max_demo_per_label=1):
        """Randomly pick up to ``max_demo_per_label`` demonstrations for each label.

        Examples are visited in a random order drawn from ``np.random``.
        Selection stops as soon as every label in ``range(opt.NUM_LABELS)`` has
        its quota. Regression tasks (a single numeric label) are not supported.

        Raises:
            AssertionError: if no demonstration could be selected.
        """
        if self.opt.DEBUG:
            print('Number of context examples available:', len(context_examples))
        counts = {label: 0 for label in range(self.opt.NUM_LABELS)}
        quota = len(counts) * max_demo_per_label
        selection = []
        for i in np.random.permutation(len(context_examples)):
            example = context_examples[i]
            label = example['label']
            if counts[label] < max_demo_per_label:
                selection.append(example)
                counts[label] += 1
                if len(selection) == quota:
                    break

        assert len(selection) > 0, 'no demonstration examples available'
        return selection

    def process_prompt(self, examples, first_sent_limit, other_sent_limit):
        """Render the query (``examples[0]``) and demonstrations into one prompt.

        The query gets ``<mask>`` in its label slot; each demonstration gets its
        label word. Each example's text is cut to ``first_sent_limit`` /
        ``other_sent_limit`` tokens (the query's limit counts the leading
        ``<s>``) before the prompt suffix is appended. The whole sequence is
        then padded or truncated to ``self.total_length``.

        Returns:
            dict with ``input_ids``, ``attention_mask`` (1 = real token,
            0 = padding), ``sent`` (the untokenized prompt) and ``mask_pos``
            (one-element list holding the position of ``<mask>``).

        Raises:
            ValueError: if truncation removed the ``<mask>`` token.
        """
        prompt_arch = ' It was targeting ' if self.opt.FINE_GRIND else ' It was '
        input_ids = []
        attention_mask = []
        concat_sent = ""
        for segment_id, ent in enumerate(examples):
            if segment_id == 0:
                new_tokens = [self.special_token_mapping['<s>']]
                length = first_sent_limit
                temp = prompt_arch + '<mask>' + ' . </s>'
            else:
                new_tokens = []
                length = other_sent_limit
                temp = prompt_arch + self.label_mapping_word[ent['label']] + ' . </s>'
            new_tokens += self.enc(' ' + ent['cap'])
            # Truncate only the example text, never the prompt suffix.
            new_tokens = new_tokens[:length]
            new_tokens += self.enc(temp)
            concat_sent += ' ' + ent['cap'] + temp

            input_ids += new_tokens
            attention_mask += [1] * len(new_tokens)

        pad_len = self.total_length - len(input_ids)
        if pad_len > 0:
            input_ids += [self.special_token_mapping['<pad>']] * pad_len
            attention_mask += [0] * pad_len
        input_ids = input_ids[:self.total_length]
        attention_mask = attention_mask[:self.total_length]

        mask_id = self.special_token_mapping['<mask>']
        if mask_id not in input_ids:
            raise ValueError('<mask> token was truncated away; increase TOTAL_LENGTH or reduce LENGTH')
        mask_pos = [input_ids.index(mask_id)]

        return {'input_ids': input_ids,
                'sent': '<s>' + concat_sent,
                'attention_mask': attention_mask,
                'mask_pos': mask_pos}

    def __getitem__(self, index):
        """Return the prompt features, one-hot target and label for one query."""
        query_idx, context_indices, _ = self.example_idx[index]
        entry = self.entries[query_idx]
        supports = self.select_context([self.support_examples[i] for i in context_indices])
        prompt_features = self.process_prompt([entry] + supports, self.length, self.length)

        label = torch.tensor(entry['label'])
        target = torch.zeros(self.num_ans, dtype=torch.float32)
        target[label] = 1.0

        return {
            'sent': prompt_features['sent'],
            'mask': torch.Tensor(prompt_features['attention_mask']),
            'img': entry['img'],
            'target': target,
            'cap_tokens': torch.Tensor(prompt_features['input_ids']),
            'mask_pos': torch.LongTensor(prompt_features['mask_pos']),
            'label': label,
        }

    def __len__(self):
        """Return the number of query entries.

        NOTE(review): ``example_idx`` holds ``NUM_SAMPLE`` times as many items;
        only the first ``len(entries)`` are reachable through this length.
        Confirm this is intended when NUM_SAMPLE > 1.
        """
        return len(self.entries)
