## Data Setup

### Fetch Conversations

In [1]:
import json
import pathlib
from typing import Any, Dict, List

from datautils import ConversationDataset


# Experiment identifier; it names the sub-directory of
# ../datasets/inst-templates that holds this experiment's files.
expt_name = "expt3_1"

# Directory containing the source conversations for this experiment.
conv_dir = pathlib.Path("../datasets/inst-templates") / expt_name / "conversations"

# Dataset wrapper (from the project's `datautils`) built over that directory.
conversation_dataset = ConversationDataset(conv_dir)

### Prompt Setup

In [2]:
# Cue string passed to the text generator together with every instruction.
task_que = "(restyled-conversation)"

# System message passed to the text generator.
# Alternative kept for reference: an empty system message (system_message = "").
system_message = 'You are an expert in restyling 2-party dialogs.'

In [3]:
# Batch size handed to the DataLoader that feeds the generator.
batch_size = 16

# Experiment configuration (JSON) stored alongside the instruction templates.
setup_file = pathlib.Path("../datasets/inst-templates", expt_name, "expt-setup.json")
expt_setup = json.loads(setup_file.read_bytes())

In [4]:
# Display the model identifier configured for this experiment.
expt_setup["llm"]

'meta-llama/Llama-2-70b-chat-hf'

### Fetch Instruction Template

In [5]:
# Load the restyling instruction template for this experiment. The text is
# later filled in per conversation with `inst.format(...)`.
inst_path = pathlib.Path(f'../datasets/inst-templates/{expt_name}/conv-restyle.txt')
# Decode explicitly as UTF-8: without `encoding`, `read_text` falls back to the
# platform's locale encoding, so the same template could be decoded differently
# (or fail to decode) on another machine.
inst = inst_path.read_text(encoding="utf-8")

## Batch Restyle Conversations

In [6]:
# Destination directory for the restyled conversations; created on demand
# (including missing parents) and left untouched if it already exists.
output_dir = pathlib.Path("../datasets/inst-templates") / expt_name / "restyled-conversations"
output_dir.mkdir(exist_ok=True, parents=True)

In [7]:
import torch # Incase you want to load in half precision: `torch.float16`
from generation_utils import BatchTextGenerator

# Build the batch generator for the configured model, requesting 4-bit loading
# with float16 as the compute dtype.
# NOTE(review): how these keyword arguments are consumed is defined in
# `generation_utils.BatchTextGenerator` — confirm there.
# Refer: https://stackoverflow.com/a/77354686/10944913
text_gen = BatchTextGenerator(
    expt_setup['llm'],
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [8]:
import json
from tqdm import tqdm
from torch.utils.data import DataLoader


# Batch the conversations and wrap the loader in a progress bar.
convo_dataloader = tqdm(
    DataLoader(conversation_dataset, batch_size=batch_size),
    desc="Processing batch",
)

for batch_idx, batch_convos in enumerate(convo_dataloader):
    # Each batch is keyed by field name ('conv', 'tweet_id'), holding one
    # entry per conversation in the batch.
    # Fill the instruction template once per conversation.
    prompts = list(map(inst.format, batch_convos['conv']))

    batch_op = text_gen(system_message, task_que, prompts)

    # Persist one JSON file per conversation. Files are numbered by the
    # conversation's running position across batches.
    tweet_ids = batch_convos["tweet_id"].tolist()
    first_file_no = batch_idx * batch_size
    for file_no, (tweet_id, restyled_convo) in enumerate(
        zip(tweet_ids, batch_op), start=first_file_no
    ):
        payload = json.dumps(
            {"tweet_id": tweet_id, "restyled_conv": restyled_convo}
        )
        (output_dir / f"{file_no}.json").write_text(payload)


Processing batch: 100%|██████████| 1/1 [08:08<00:00, 488.55s/it]
