In [1]:
from datasets import load_dataset
import pandas as pd
from feature.selector import Selective, SelectionMethod
from textwiser import TextWiser, Embedding, Transformation
import datasets

In [20]:
dataset = load_dataset("takojunior/llama_2_finetune")

In [21]:
dataset

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 600210
    })
    validation: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 200
    })
})

In [14]:
df = dataset.to_pandas()


In [16]:
df.fillna("", inplace=True)
df.head()

Unnamed: 0,instruction,input,output
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin A...",Virgin Australia commenced services on 31 Augu...
1,Which is a species of fish? Tope or Rope,,Tope
2,Why can camels survive for long without water?,,Camels use the fat in their humps to keep them...
3,"Alice's parents have three daughters: Amy, Jes...",,The name of the third daughter is Alice
4,When was Tomoaki Komorida born?,Komorida was born in Kumamoto Prefecture on Ju...,"Tomoaki Komorida was born on July 10,1981."


In [17]:
json_records = df.to_json(orient ='records', indent=2) 
parsed = json.loads(json_records)

In [22]:
parsed[0:10]

[{'instruction': 'When did Virgin Australia start operating?',
  'input': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
  'output': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.'},
 {'instruction': 'Which is a species of fish? Tope or Rope',
  'input': '',
  'output': 'Tope'},
 {'instruction': 'Why can camels survive for long without water?',
  'input': '',
  'output': 'Camels use the fat in their humps to keep them filled with energy and hydration for long periods 

In [28]:
len(df)

300105

In [26]:
args={}
args['selection_percentage'] = 0.05

In [27]:
def get_selected_data(df, args, text_column="text", label_column="category"):
    num_rows = len(df)
    df[text_column] = df["instruction"] + " " + df["input"] + " " + df["output"]
    df_T = df.loc[:, text_column].to_frame().T
    
    labels = pd.get_dummies(df[label_column], dtype=int)
    labels.columns = ["label_" + str(i) for i in range(1, len(labels.columns)+1)] 
    labels_T = labels.T
    
    # TextWiser featurization method to create text embeddings
    textwiser = TextWiser(Embedding.TfIdf(), Transformation.NMF(n_components=20))

    # Text-based selection
    # The goal is to select a subset of articles 
    # that is most diverse in the text embedding space of articles
    # and covers the most labels in each topic
    selector = Selective(SelectionMethod.TextBased(num_features=round(num_rows*args["selection_percentage"]), 
                                                   featurization_method=textwiser,
                                                   optimization_method='kmeans'))

    # Feature reduction
    subset = selector.fit_transform(df_T, labels_T)
    
    return df.loc[subset.columns, ["instruction", "input", "output"]]

In [29]:
df['task_name'] = "dummy_task"
df.head()

Unnamed: 0,instruction,input,output,task_name
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin A...",Virgin Australia commenced services on 31 Augu...,dummy_task
1,Which is a species of fish? Tope or Rope,,Tope,dummy_task
2,Why can camels survive for long without water?,,Camels use the fat in their humps to keep them...,dummy_task
3,"Alice's parents have three daughters: Amy, Jes...",,The name of the third daughter is Alice,dummy_task
4,When was Tomoaki Komorida born?,Komorida was born in Kumamoto Prefecture on Ju...,"Tomoaki Komorida was born on July 10,1981.",dummy_task


In [30]:
selected_df = get_selected_data(df, args, text_column="text", label_column="task_name") # label_column is not used
selected_df.to_csv("../data/selected_data_from_300k.csv", index=False)

In [31]:
selected_df.head(200).to_csv("../data/validation_selected_data_from_300k.csv", index=False)