In [3]:
clue = '/home/t-yaqingwang/Projects/Few-shot-Learning/data/clue'

In [2]:
ust = '/home/t-yaqingwang/Projects/Few-shot-Learning/data/ust'

In [None]:
task = mpqa

In [3]:
file_path = '/home/t-yaqingwang/Projects/Few-shot-Learning/data/clue/mpqa/10-1/train.csv'

In [6]:
"""Dataset utils for different data settings for GLUE."""

import os
import copy
import logging
import torch
import numpy as np
import time
from filelock import FileLock
import json
import itertools
import random
import transformers
from transformers.data.processors.utils import InputFeatures
from transformers import DataProcessor, InputExample
from transformers.data.processors.glue import *
from transformers.data.metrics import glue_compute_metrics
import dataclasses
from dataclasses import dataclass, asdict
from typing import List, Optional, Union
from sentence_transformers import SentenceTransformer, util
from copy import deepcopy
import pandas as pd
import logging

logger = logging.getLogger(__name__)

class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")


    def get_un_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "un_train.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            text_b = line[4]
            label = line[0]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_large_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train_large.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train_large.tsv")), "train_large")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "un_train.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]
    def get_mappings(self):
        return {'contradiction': 'No', 'entailment': 'Yes', 'neutral': 'Maybe'}

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[8].lower()
            text_b = line[9].lower()
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")


class CLUE_MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_clue_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_large_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train_large.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train_large.tsv")), "train_large")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "un_train.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
      
        return self._create_clue_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def get_mappings(self):
        return {'contradiction': 'No', 'entailment': 'Yes', 'neutral': 'Maybe'}

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])

            text_a = line[8].lower()
            text_b = line[9].lower()
          
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def _create_clue_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""

        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[-3].lower()
            text_b = line[-2].lower()
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples



class CLUE_MnliMismatchedProcessor(CLUE_MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def _create_clue_test_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""

        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1].lower()
            text_b = line[2].lower()
            label = line[4]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")



class CLUE_MnliClueProcessor(CLUE_MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def _create_clue_test_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""

        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1].lower()
            text_b = line[2].lower()
            label = line[4]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_clue_test_examples(self._read_tsv(os.path.join(data_dir, "test_clue.tsv")), "test_mismatched")


class SnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        test_mode = set_type == "test"
        text_index = 3
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = line[text_index]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        text_index = 0
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[text_index]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples



class CLUE_Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_CLUE_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_CLUE_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        text_index = 0
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[text_index]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

    def _create_CLUE_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):

            guid = "%s-%s" % (set_type, i)
            text_a = line[-2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class CLUE_Sst2ClueProcessor(CLUE_Sst2Processor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def _create_CLUE_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = line[-2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_CLUE_examples(self._read_tsv(os.path.join(data_dir, "test_clue.tsv")), "test_clue")




class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return [None]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        test_mode = set_type == "test"
        q1_index = 3
        q2_index = 4
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            try:
                text_a = line[q1_index]
                text_b = line[q2_index]
                label = line[5]
            except IndexError:
                continue
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class CLUE_QqpProcessor(QqpProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )


    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_clue.tsv")), "test-clue")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        test_mode = set_type == "test"
        q1_index = 3
        q2_index = 4
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            try:
                text_a = line[q1_index]
                text_b = line[q2_index]
                label = line[5]
            except IndexError:
                continue
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question"].numpy().decode("utf-8"),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class CLUE_RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples



class CLUE_RteClueProcessor(CLUE_RteProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_clue.tsv")), "test_clue")


class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "un_train.tsv")), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

class TextClassificationProcessor(DataProcessor):
    """
    Data processor for text classification datasets (mr, sst-5, subj, trec, cr, mpqa).
    """

    def __init__(self, task_name):
        self.task_name = task_name 

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )
  
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(pd.read_csv(os.path.join(data_dir, "train.csv"), header=None).values.tolist(), "train")

    def get_un_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(pd.read_csv(os.path.join(data_dir, "un_train.csv"), header=None).values.tolist(), "un_train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(pd.read_csv(os.path.join(data_dir, "dev.csv"), header=None).values.tolist(), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(pd.read_csv(os.path.join(data_dir, "test.csv"), header=None).values.tolist(), "test")

    def get_labels(self):
        """See base class."""
        if self.task_name == "mr":
            return list(range(2))
        elif self.task_name == "sst-5":
            return list(range(5))
        elif self.task_name == "subj":
            return list(range(2))
        elif self.task_name == "trec":
            return list(range(6))
        elif self.task_name == "cr":
            return list(range(2))
        elif self.task_name == "mpqa":
            return list(range(2))
        else:
            raise Exception("task_name not supported.")
        
    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            if self.task_name == "ag_news":
                examples.append(InputExample(guid=guid, text_a=line[1] + '. ' + line[2], short_text=line[1] + ".", label=line[0]))
            elif self.task_name == "yelp_review_full":
                examples.append(InputExample(guid=guid, text_a=line[1], short_text=line[1], label=line[0]))
            elif self.task_name == "yahoo_answers":
                text = line[1]
                if not pd.isna(line[2]):
                    text += ' ' + line[2]
                if not pd.isna(line[3]):
                    text += ' ' + line[3]
                examples.append(InputExample(guid=guid, text_a=text, short_text=line[1], label=line[0])) 
            elif self.task_name in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
                examples.append(InputExample(guid=guid, text_a=line[1], label=line[0]))
            else:
                raise Exception("Task_name not supported.")

        return examples


class CLUE_TextClassificationProcessor(TextClassificationProcessor):
    """
    Data processor for text classification datasets (mr, sst-5, subj, trec, cr, mpqa).
    """


    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(pd.read_csv(os.path.join(data_dir, "test_clue.csv"), header=None).values.tolist(),
                                     "test-clue")


def text_classification_metrics(task_name, preds, labels):
    return {"acc": (preds == labels).mean()}

# Add your task to the following mappings

processors_mapping = {
    "cola": ColaProcessor(),
    "mnli": MnliProcessor(),
    "mnli-mm": MnliMismatchedProcessor(),
    "mrpc": MrpcProcessor(),
    "sst-2": Sst2Processor(),
    "sts-b": StsbProcessor(),
    "qqp": QqpProcessor(),
    "qnli": QnliProcessor(),
    "rte": RteProcessor(),
    "wnli": WnliProcessor(),
    "snli": SnliProcessor(),
    "mr": TextClassificationProcessor("mr"),
    "sst-5": TextClassificationProcessor("sst-5"),
    "subj": TextClassificationProcessor("subj"),
    "trec": TextClassificationProcessor("trec"),
    "cr": TextClassificationProcessor("cr"),
    "mpqa": TextClassificationProcessor("mpqa")
}

CLUE_processors_mapping = {
    "cola": ColaProcessor(),
    "mnli": CLUE_MnliProcessor(),
    "mnli-clue": CLUE_MnliClueProcessor(),
    "mnli-mm": CLUE_MnliMismatchedProcessor(),
    "mrpc": MrpcProcessor(),
    "sst-2": CLUE_Sst2Processor(),
    "sst-2-clue": CLUE_Sst2ClueProcessor(),
    "sts-b": StsbProcessor(),
    "qqp": QqpProcessor(),
    "qqp-clue": CLUE_QqpProcessor(),
    "qnli": QnliProcessor(),
    "rte": CLUE_RteProcessor(),
    "rte-clue": CLUE_RteClueProcessor(),
    "wnli": WnliProcessor(),
    "snli": SnliProcessor(),
    "mr": TextClassificationProcessor("mr"),
    "mr-clue":  CLUE_TextClassificationProcessor("mr"),
    "sst-5": TextClassificationProcessor("sst-5"),
    "subj": TextClassificationProcessor("subj"),
    "subj-clue": CLUE_TextClassificationProcessor("subj"),
    "trec": TextClassificationProcessor("trec"),
    "cr": TextClassificationProcessor("cr"),
    "mpqa": TextClassificationProcessor("mpqa"),
   "mpqa-clue": CLUE_TextClassificationProcessor("mpqa")
}

num_labels_mapping = {
    "cola": 2,
    "mnli": 3,
    "mrpc": 2,
    "sst-2": 2,
    "sts-b": 1,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
    "snli": 3,
    "mr": 2,
    "sst-5": 5,
    "subj": 2,
    "trec": 6,
    "cr": 2,
    "mpqa": 2
}

output_modes_mapping = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mnli-mm-clue": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sst-2-clue": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qqp-clue": "classification",
    "qnli": "classification",
    "rte": "classification",
     "rte-clue": "classification",
    "wnli": "classification",
    "snli": "classification",
    "mr": "classification",
    "mr-clue": "classification",
    "sst-5": "classification",
    "subj": "classification",
    "subj-clue": "classification",
    "trec": "classification",
    "cr": "classification",
    "mpqa": "classification",
    "mpqa-clue": "classification"
}

# Return a function that takes (task_name, preds, labels) as inputs
compute_metrics_mapping = {
    "cola": glue_compute_metrics,
    "mnli": glue_compute_metrics,
    "mnli-clue": text_classification_metrics,
    "mnli-mm": glue_compute_metrics,
    "mrpc": glue_compute_metrics,
    "sst-2": glue_compute_metrics,
    "sst-2-clue": text_classification_metrics,
    "sts-b": glue_compute_metrics,
    "qqp": glue_compute_metrics,
    "qqp-clue": text_classification_metrics,
    "qnli": glue_compute_metrics,
    "rte": glue_compute_metrics,
    "rte-clue": text_classification_metrics,
    "wnli": glue_compute_metrics,
    "snli": text_classification_metrics,
    "mr": text_classification_metrics,
    "mr-clue": text_classification_metrics,
    "sst-5": text_classification_metrics,
    "subj": text_classification_metrics,
    "subj-clue": text_classification_metrics,
    "trec": text_classification_metrics,
    "cr": text_classification_metrics,
    "mpqa": text_classification_metrics,
    "mpqa-clue": text_classification_metrics,
}



# For regression task only: median
median_mapping = {
    "sts-b": 2.5
}

bound_mapping = {
    "sts-b": (0, 5)
}


In [14]:
from pathlib import Path

for  task_name in [ 'QQP', 'RTE']:
    for k in [ 30]:
        for seed in [1,2,3,4,5]:
            
            
            output_dir = clue + '/'+task_name+'/'+str(100)+'-'+str(seed)
            #os.mkdir(output_dir)
            Path(output_dir).mkdir(parents=True, exist_ok=True)
#             try:
#                 os.remove(os.path.join(output_dir, 'test-clue.tsv'))
#             except:
#                 continue

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['train', 'un_train', 'test', 'test_clue']:
                if mode == 'train':
                    examples = []
                    for k in [30, 35]:
                        if k == 35:
                            seed = [100, 13, 21, 42, 87][seed-1]
                            data_dir = '/home/t-yaqingwang/Projects/Few-shot-Learning/data/k-shot' + '/'+task_name+'/'+str(k)+'-'+str(seed)
                        else:
                            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
                        examples += processor.get_train_examples(data_dir)
                        support_example += examples
                        import pdb
                        pdb.set_trace()
                elif mode == 'un_train':
                    examples = processor.get_un_train_examples(data_dir)
                elif mode == 'test':
                    examples = processor.get_test_examples(data_dir)
                elif mode == 'test_clue':
                    examples = CLUE_processors_mapping[(task_name+'-clue').lower()].get_test_examples(data_dir)
                    
                
                #output_file = open(os.path.join(output_dir, mode+'.tsv'), 'w')
                import pdb
                pdb.set_trace()
                for example in examples:
                    text_a = example.text_a
                    text_b = example.text_b
                    label = example.label
#                     output_file.write(text_a+'\t')
#                     if text_b is not None:
#                         output_file.write(text_b+'\t')
#                     output_file.write(label+'\n')
#                 output_file.close()

                    
            
        
  
        

> [0;32m<ipython-input-14-e6378c7facb9>[0m(20)[0;36m<module>[0;34m()[0m
[0;32m     18 [0;31m                [0;32mif[0m [0mmode[0m [0;34m==[0m [0;34m'train'[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m                    [0mexamples[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 20 [0;31m                    [0;32mfor[0m [0mk[0m [0;32min[0m [0;34m[[0m[0;36m30[0m[0;34m,[0m [0;36m35[0m[0;34m][0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m                        [0;32mif[0m [0mk[0m [0;34m==[0m [0;36m35[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m                            [0mseed[0m [0;34m=[0m [0;34m[[0m[0;36m100[0m[0;34m,[0m [0;36m13[0m[0;34m,[0m [0;36m21[0m[0;34m,[0m [0;36m42[0m[0;34m,[0m [0;36m87[0m[0;34m][0m[0;34m[[0m[0mseed[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> r
> [0;32m<i

BdbQuit: 

In [44]:
from pathlib import Path

for  task_name in [ 'MNLI']:
    for k in [10, 20, 30]:
        for seed in [1,2,3,4,5]:
            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
            output_dir = ust+ '/'+task_name+'/'+str(k)+'-'+str(seed)
            Path(output_dir).mkdir(parents=True, exist_ok=True)

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['train', 'un_train', 'test_matched', 'test_mismatched', 'test_clue']:
                if mode == 'train':
                    examples = processor.get_train_examples(data_dir)
                elif mode == 'un_train':
                    examples = processor.get_un_train_examples(data_dir)
                elif mode == 'test_machted':
                    examples = processor.get_test_examples(data_dir)
                elif mode == 'test_mismatched':
                    examples = CLUE_processors_mapping['mnli-mm'].get_test_examples(data_dir)
                elif mode == 'test_clue':
                    examples = CLUE_processors_mapping[(task_name+'-clue').lower()].get_test_examples(data_dir)
                    
                output_file = open(os.path.join(output_dir, mode+'.tsv'), 'w')
                for example in examples:
                    text_a = example.text_a
                    text_b = example.text_b
                    label = example.label
                    output_file.write(text_a+'\t')
                    if text_b is not None:
                        output_file.write(text_b+'\t')
                    output_file.write(label+'\n')
                output_file.close()

                    
            
        
  
        

In [36]:
from pathlib import Path

for  task_name in [ 'subj', 'mpqa']:
    for k in [10, 20, 30]:
        for seed in [1,2,3,4,5]:
            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
            output_dir = ust+ '/'+task_name+'/'+str(k)+'-'+str(seed)
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            try:
                os.remove(os.path.join(output_dir, 'test-clue.csv'))
            except:
                continue

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['train', 'un_train', 'test', 'test_clue']:
                if mode == 'train':
                    examples = processor.get_train_examples(data_dir)
                elif mode == 'un_train':
                    examples = processor.get_un_train_examples(data_dir)
                elif mode == 'test':
                    examples = processor.get_test_examples(data_dir)
                elif mode == 'test_clue':
                    examples = CLUE_processors_mapping[(task_name+'-clue').lower()].get_test_examples(data_dir)
                output_file = open(os.path.join(output_dir, mode+'.csv'), 'w')
                for example in examples:
                    text_a = example.text_a
                    text_b = example.text_b
                    label = example.label
                    try:
                        output_file.write(text_a+'\t')
                    except:
                        continue
                        print(example)
                    if text_b is not None:
                        output_file.write(text_b+'\t')
                    output_file.write(str(label)+'\n')
                output_file.close()

                    
            
        
  
        

In [12]:
from pathlib import Path

for  task_name in [ 'MNLI']
    for k in [10, 20, 30]:
        for seed in [1,2,3,4,5]:
            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
            output_dir = ust+ '/'+task_name+'/'+str(k)+'-'+str(seed)
            Path(output_dir).mkdir(parents=True, exist_ok=True)

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['train', 'un_train', 'test', 'test-clue']:
                if mode == 'train':
                    examples = processor.get_train_examples(data_dir)
                elif mode == 'un_train':
                    examples = processor.get_un_train_examples(data_dir)
                output_file = open(os.path.join(output_dir, mode+'.tsv'), 'w')
                for example in examples:
                    text_a = example.text_a
                    text_b = example.text_b
                    label = example.label
                    output_file.write(text_a+'\t')
                    if text_b is not None:
                        output_file.write(text_b+'\t')
                    output_file.write(label+'\n')
                output_file.close()

                    
            
        
  
        

In [14]:
for example in train_examples:
    print(example.text_b is None)

True
True
True
True
True
True
True
True
True
True


In [21]:
from pathlib import Path

for  task_name in [ 'QQP', 'RTE']:
    for k in [ 30]:
        for seed in [1,2,3,4,5]:
            
            
            output_dir = clue + '/'+task_name+'/'+str(100)+'-'+str(seed)
            #os.mkdir(output_dir)
            Path(output_dir).mkdir(parents=True, exist_ok=True)
#             try:
#                 os.remove(os.path.join(output_dir, 'test-clue.tsv'))
#             except:
#                 continue

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['train', 'un_train', 'test', 'test_clue']:
                lines = []
                if mode == 'train':
                    examples = []
                    for k in [30, 35]:
                        if k == 35:
                            new_seed = [100, 13, 21, 42, 87][seed-1]
                            new_data_dir = '/home/t-yaqingwang/Projects/Few-shot-Learning/data/k-shot' + '/'+task_name+'/'+str(k)+'-'+str(new_seed)
                            f = open(os.path.join(new_data_dir, 'train.tsv'))
                        else:
                            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
                            
                            f = open(os.path.join(data_dir, 'train.tsv'))
                        if k ==35:
                
                            lines += f.readlines()[1:]
                        else:
                            lines += f.readlines()

                elif mode == 'un_train':
                    
                    f = open(os.path.join(data_dir, 'un_train.tsv'))
                    lines += f.readlines()
                elif mode == 'test':
                    f = open(os.path.join(data_dir, 'test.tsv'))
                    lines += f.readlines()
                elif mode == 'test_clue':
                    f = open(os.path.join(data_dir, 'test_clue.tsv'))
                    lines += f.readlines()
                    
                
                output_file = open(os.path.join(output_dir, mode+'.tsv'), 'w')
               
                for line in lines:
                    output_file.write(line)
                output_file.close()
                  
#                     output_file.write(text_a+'\t')
#                     if text_b is not None:
#                         output_file.write(text_b+'\t')
#                     output_file.write(label+'\n')
#                 output_file.close()

                    
            
        
  
        

In [24]:
from pathlib import Path

for  task_name in [ 'MNLI']:
    for k in [30]:
        for seed in [1,2,3,4,5]:
            output_dir = clue + '/'+task_name+'/'+str(100)+'-'+str(seed)
            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
            #os.mkdir(output_dir)
            Path(output_dir).mkdir(parents=True, exist_ok=True)

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['un_train', 'test_matched', 'test_mismatched', 'test_clue']:
                
                lines = []

                f = open(os.path.join(data_dir, mode+'.tsv'))
                lines = f.readlines()

                    
                output_file = open(os.path.join(output_dir, mode+'.tsv'), 'w')
                for line in lines:
                    
                    output_file.write(line)
                output_file.close()

                    
            
        
  
        

In [None]:
def process_lines(lines):
    

In [None]:
from pathlib import Path
K_shot = 23

for  task_name in [ 'MNLI']:
    for k in [ 30]:
        for seed in [1,2,3,4,5]:
            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
            output_dir = ust+ '/'+task_name+'/'+str(k)+'-'+str(seed)
            Path(output_dir).mkdir(parents=True, exist_ok=True)

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['train']:
                lines = []
                if mode == 'train':
                    examples = []
                    
                    for k in [30, K_shot]:
                        if k == K_shot:
                            new_seed = [100, 13, 21, 42, 87][seed-1]
                            new_data_dir = '/home/t-yaqingwang/Projects/Few-shot-Learning/data/k-shot' + '/'+task_name+'/'+str(k)+'-'+str(new_seed)
                            f = open(os.path.join(new_data_dir, 'train.tsv'))
                        else:
                            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
                            
                            f = open(os.path.join(data_dir, mode+'.tsv'))
                        if k == K_shot:
                
                            lines += f.readlines()[1:]
                            import pdb
                            pdb.set_trace()
                    
                        else:
                            lines += f.readlines()

                else:
                    f = open(os.path.join(data_dir, mode+'.tsv'))
                    lines = f.readlines()
                    
                    
                output_file = open(os.path.join(output_dir, mode+'.tsv'), 'w')
                for line in lines:
                        output_file.write(line)
                output_file.close()

                    
            
        
  
        

In [17]:
def match(subset, original_set, pairwise=True):
    guid_dict = {}
    
    if isinstance(original_set, dict):
        
        set_dict = original_set
        
    else:
        set_dict = {}
    
        for i, example in enumerate(original_set):
            text_a = example.text_a
            guid = example.guid
            text_b = example.text_b
            if pairwise:
                set_dict[text_a.lower()+''+text_b.lower()] = guid
            else:
                set_dict[text_a.lower()] = guid
            
    for example in subset:
        text_a = example.text_a
        text_b = example.text_b
        if pairwise:
            set_key = text_a.lower() + ''+text_b.lower()
        else:
            set_key = text_a.lower()
        if set_key in set_dict:
            guid_dict[set_dict[set_key]] = True
        else:
            print(set_key)
    return guid_dict
        

            
       
            
    
    
    

In [10]:
from pathlib import Path

data_dict = {}

for  task_name in [ 'MNLI']:
    for k in [10, 20, 30]:
        data_dict[k] = {}
        for seed in [1,2,3,4,5]:
            data_dict[k][seed] = {}
            
            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
            #output_dir = ust+ '/'+task_name+'/'+str(k)+'-'+str(seed)
            #Path(output_dir).mkdir(parents=True, exist_ok=True)

            processor = CLUE_processors_mapping[task_name.lower()]
            example_list = [] 
            for mode in ['train', 'un_train']:
                if mode == 'train':
                    examples = processor.get_train_examples(data_dir)
                elif mode == 'un_train':
                    examples = processor.get_un_train_examples(data_dir)
                elif mode == 'test_machted':
                    examples = processor.get_test_examples(data_dir)
                elif mode == 'test_mismatched':
                    examples = CLUE_processors_mapping['mnli-mm'].get_test_examples(data_dir)
                elif mode == 'test_clue':
                    examples = CLUE_processors_mapping[(task_name+'-clue').lower()].get_test_examples(data_dir)
                example_list.append(examples)
                    
            data_dict[k][seed] = match(example_list[0], example_list[1])
          

                    
            
        
  
        

KeyboardInterrupt: 

In [29]:
from pathlib import Path

data_dict = {}
keep = False

for  task_name in [ 'MNLI']:
    for k in [30]:
        data_dict[k] = {}
        for seed in [1,2,3,4,5]:
            print(k, seed)
            data_dict[k][seed] = {}
            
            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
            output_dir = clue+ '/'+task_name+'/'+str(k)+'-'+str(seed)
            #Path(output_dir).mkdir(parents=True, exist_ok=True)

            processor = CLUE_processors_mapping[task_name.lower()]
            example_list = [] 
            for mode in ['train', 'un_train']:
                if mode == 'train':
                    examples = processor.get_train_examples(data_dir)
                elif mode == 'un_train':
                    if not keep:
                        examples = processor.get_un_train_examples(data_dir)
                        keep = True
                        set_dict = {}
                        f = open(os.path.join(data_dir, mode+'.tsv'))
                        lines = f.readlines()
                        for i, example in enumerate(examples):
                            try:
                                text_a = example.text_a
                            except:
                                import pdb
                                pdb.set_trace()
                            guid = example.guid
                            text_b = example.text_b
                            
                            set_dict[text_a.lower()+''+text_b.lower()] = guid
                           
                elif mode == 'test_machted':
                    examples = processor.get_test_examples(data_dir)
                elif mode == 'test_mismatched':
                    examples = CLUE_processors_mapping['mnli-mm'].get_test_examples(data_dir)
                elif mode == 'test_clue':
                    examples = CLUE_processors_mapping[(task_name+'-clue').lower()].get_test_examples(data_dir)
                example_list.append(examples)
                
            if keep:
                    
                data_dict[k][seed] = match(example_list[0], set_dict)
            else:
                data_dict[k][seed] = match(example_list[0], example_list[1])
                
                
            
            train_output_file = open(os.path.join(output_dir, 'train_original_format.tsv'), 'w+')
            
            selected_list, remain_list = index_data(data_dict[k][seed], lines)
            
           

            for line in selected_list:
                train_output_file.write(line)
            train_output_file.close()

            un_train_output_file = open(os.path.join(output_dir, 'un_train_original.tsv'), 'w+')
            for line in remain_list:
                un_train_output_file.write(line)
            un_train_output_file.close()

                
            
                
                
            
          

                    
            
        
  
        

30 1
30 2
30 3
30 4
30 5


In [26]:
def index_data(idx_dict, lines):
    selected_list = [] 
    remain_list = []
    new_idx_dict = {}
    for l in idx_dict:
        new_idx_dict[l.split('-')[1]] = idx_dict[l]
    idx_dict = new_idx_dict
    for i, l in enumerate(lines):
        if i == 0:
            selected_list.append(l)
            remain_list.append(l)    
            continue
        idx = l.split('\t')[0]
        if idx in idx_dict:
            selected_list.append(l)
        else:
            remain_list.append(l)
    return selected_list, remain_list
            
        

In [32]:
from pathlib import Path
K_shot= 23
for  task_name in [ 'MNLI']:
    for k in [ 30]:
        for seed in [1,2,3,4,5]:
            
            
            output_dir = clue + '/'+task_name+'/'+str(100)+'-'+str(seed)
            #os.mkdir(output_dir)
            Path(output_dir).mkdir(parents=True, exist_ok=True)
#             try:
#                 os.remove(os.path.join(output_dir, 'test-clue.tsv'))
#             except:
#                 continue

            processor = CLUE_processors_mapping[task_name.lower()]
            for mode in ['train', 'un_train', 'test_matched', 'test_mismatched', 'test_clue']:
                lines = []
                if mode == 'train':
                    examples = []
                    for k in [30, K_shot]:
                        if k == K_shot:
                            new_seed = [100, 13, 21, 42, 87][seed-1]
                            new_data_dir = '/home/t-yaqingwang/Projects/Few-shot-Learning/data/k-shot' + '/'+task_name+'/'+str(k)+'-'+str(new_seed)
                            f = open(os.path.join(new_data_dir, 'train.tsv'))
                        else:
                            data_dir = clue + '/'+task_name+'/'+str(k)+'-'+str(seed)
                            
                            f = open(os.path.join(data_dir, 'train_original_format.tsv'))
                        if k == K_shot:
                
                            lines += f.readlines()[1:]
                        else:
                            lines += f.readlines()

                elif mode == 'un_train':
                    
                    f = open(os.path.join(data_dir, 'un_train_original.tsv'))
                    lines = f.readlines()
                   
                else:
                    f = open(os.path.join(data_dir, mode+'.tsv'))
                    lines = f.readlines()
                    
                
                output_file = open(os.path.join(output_dir, mode+'.tsv'), 'w')
               
                for line in lines:
                    output_file.write(line)
                output_file.close()
                  
#                     output_file.write(text_a+'\t')
#                     if text_b is not None:
#                         output_file.write(text_b+'\t')
#                     output_file.write(label+'\n')
#                 output_file.close()

                    
            
        
  
        