## Data Setup

In [1]:
# Prompt settings shared by every generation call:
# `text_gen(system_message, task_que, instructions)`.
system_message = "You are an expert in generating realistic 2-party dialogs."
task_que = "(conversation)"
# Alternative tried previously: an empty system message, i.e. `system_message = ""`.

### Fetch Tweets

In [2]:
import pandas as pd

data_path = "../datasets/sample-tweets.csv"
df = pd.read_csv(data_path)

# Fail fast, before the expensive model load further down, if the CSV does not
# match what the generation loop expects: a `tweet_id` column, and exactly two
# columns, so that each `itertuples()` row unpacks as (index, col1, tweet).
assert 'tweet_id' in df.columns, f"{data_path} has no 'tweet_id' column"
assert df.shape[1] == 2, (
    f"expected exactly 2 columns in {data_path}, got {list(df.columns)}"
)

### Build user profiles

In [3]:
batch_size: int = 16  # tweets per `text_gen` call (rows are grouped by df.index // batch_size)

In [4]:
import json
import pathlib

expt_name = "expt3_1"

# Experiment configuration: the LLM to use (`llm`) and the user-profile
# fragments (`user_profiles`: motivations, each with a list of dispositions).
# Read as bytes so `json.loads` does its own encoding detection.
setup_path = pathlib.Path('../datasets/inst-templates') / expt_name / 'expt-setup.json'
expt_setup = json.loads(setup_path.read_bytes())

raw_user_profiles = expt_setup['user_profiles']

In [5]:
# Model identifier of the LLM passed to BatchTextGenerator for generation.
expt_setup["llm"]

'meta-llama/Llama-2-70b-chat-hf'

In [6]:
import random

# Seed the RNG used by `random.choices` in the generation loop, so the profile
# assigned to each tweet is reproducible on Restart & Run All.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# List of (motivation_text, disposition_text) pairs: every disposition listed
# under a motivation is paired with that motivation's text.
user_profiles = []

for motivation in raw_user_profiles:
    # Read each motivation file once, not once per disposition.
    motivation_text = pathlib.Path(motivation['path']).read_text()
    for disposition in motivation['dispositions']:
        user_profiles.append(
            (motivation_text, pathlib.Path(disposition['path']).read_text())
        )

# random.choices() on an empty list fails later with an opaque IndexError.
assert user_profiles, "expt-setup.json yields no (motivation, disposition) pairs"

### Fetch instruction template

In [7]:
# Prompt template for conversation generation. The generation loop fills it
# positionally via str.format(disposition, motivation, tweet).
inst_path = pathlib.Path('../datasets/inst-templates', expt_name, 'conv-gen.txt')
inst = inst_path.read_text()

## Batch Generate Conversations

In [8]:
output_dir = pathlib.Path(f"../datasets/inst-templates/{expt_name}/conversations")
# The generation loop writes one `<row index>.json` per tweet into this directory.
# Path.write_text does not create missing parent directories, so create it up front.
output_dir.mkdir(parents=True, exist_ok=True)

In [9]:
# `torch` is not used directly in this cell; it is imported in case you want to
# load the model in half precision (`torch.float16`) instead.
import torch
from generation_utils import BatchTextGenerator

# Build the batch generator for the experiment's LLM; `load_in_4bit=True` is
# passed through to BatchTextGenerator (presumably 4-bit weight loading).
llm_name = expt_setup['llm']
text_gen = BatchTextGenerator(llm_name, load_in_4bit=True)

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [11]:
import json
from tqdm import tqdm


# Generate one conversation per tweet, sending `batch_size` tweets per LLM call,
# and write each result to `output_dir/<row index>.json`.
# NOTE(review): rows of `batch_df.itertuples()` are unpacked as
# (index, <first column>, tweet), which assumes the CSV has exactly two columns
# with the tweet text second. Confirm against the dataset.
df_with_progbar = tqdm(df.groupby(df.index // batch_size), desc="Processing batch")

for batch_idx, batch_df in df_with_progbar:
    idxes = list(batch_df.index)
    tweet_ids = list(batch_df['tweet_id'])
    tweets = [tweet for _, _, tweet in batch_df.itertuples()]

    # One random (motivation, disposition) profile per row. Sizing the draw by
    # the actual batch length (not `batch_size`) avoids drawing unused profiles
    # for a final, partial batch.
    batch_user_profiles = random.choices(user_profiles, k=len(batch_df))

    # Template placeholders are filled positionally: disposition, motivation, tweet.
    instructions = [
        inst.format(disp, motv, tweet)
        for (motv, disp), tweet in zip(batch_user_profiles, tweets)
    ]

    # Materialize so the length check below works for any iterable return type.
    batch_op = list(text_gen(system_message, task_que, instructions))
    # zip() below would silently drop tweets if fewer outputs came back.
    if len(batch_op) != len(instructions):
        raise RuntimeError(
            f"batch {batch_idx}: expected {len(instructions)} conversations, "
            f"got {len(batch_op)}"
        )

    # Build output jsons: one file per tweet, named by its DataFrame row index.
    # The loop variable is `tweet_id` rather than `id`, so the builtin `id` is
    # not shadowed for the rest of the notebook session.
    for idx, tweet_id, convo, tweet, (motv, disp) in zip(
        idxes, tweet_ids, batch_op, tweets, batch_user_profiles
    ):
        convo_json = {
            'tweet_id': tweet_id,
            'disposition': disp,
            'motivation': motv,
            'tweet': tweet,
            'conv': convo,
        }

        convo_file = output_dir.joinpath(f'{idx}.json')
        convo_file.write_text(json.dumps(convo_json))


Processing batch: 100%|██████████| 1/1 [17:23<00:00, 1043.31s/it]
