In [1]:
from minichat.common import get_base_dir
import os
import fcntl
import urllib.request
import tempfile, zipfile
import shutil
import hashlib
import yaml
import csv
import json

In [2]:
def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    if os.path.exists(file_path):
        return file_path

    with open(lock_path, 'w', encoding='utf-8') as lock_file:

        # Only a single rank can acquire this lock
        # All other ranks block until it is released
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)

        # Recheck after acquiring lock (another process may have downloaded it)
        if os.path.exists(file_path):
            return file_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read() # bytes

        # Write to local file
        with open(file_path, 'wb') as f:
            f.write(content)
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(file_path)

    # Clean up the lock file after the lock is released
    try:
        os.remove(lock_path)
    except OSError:
        pass  # Ignore if already removed by another process

def place_eval_bundle(file_path):
    # here file_path is the path to the eval_bundle.zip file
    # we need to unzip it and place it in the base directory
    base_dir = get_base_dir()
    eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)
        extracted_bundle_dir = os.path.join(tmpdir, "eval_bundle")
        shutil.move(extracted_bundle_dir, eval_bundle_dir)
    print(f"Placed eval_bundle directory at {eval_bundle_dir}")

In [3]:
EVAL_BUNDLE_URL = "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip"

In [4]:
download_file_with_lock(EVAL_BUNDLE_URL, "eval_bundle.zip", postprocess_fn=place_eval_bundle)

'/home/shadeform/minichat/minichat/eval_bundle.zip'

In [5]:
base_dir = get_base_dir()
eval_bundle_dir = os.path.join(base_dir, "eval_bundle")
config_path = os.path.join(eval_bundle_dir, "core.yaml")
data_base_path = os.path.join(eval_bundle_dir, "eval_data")
eval_meta_data = os.path.join(eval_bundle_dir, "eval_meta_data.csv")

with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

In [6]:
tasks = config['icl_tasks']

In [7]:
random_baselines = {}
with open(eval_meta_data, 'r', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        task_name = row['Eval Task']
        random_baseline = row['Random baseline']
        random_baselines[task_name] = float(random_baseline)

In [8]:
task = tasks[0]  # Example: get the first task

label = task['label']
task_meta = {
    'task_type': task['icl_task_type'],
    'dataset_uri': task['dataset_uri'],
    'num_fewshot': task['num_fewshot'][0],
    'continuation_delimiter': task.get('continuation_delimiter', ' ')
}
        

In [9]:
label, task_meta

('hellaswag_zeroshot',
 {'task_type': 'multiple_choice',
  'dataset_uri': 'language_understanding/hellaswag.jsonl',
  'num_fewshot': 0,
  'continuation_delimiter': ' '})

In [10]:
data_path = os.path.join(data_base_path, task_meta['dataset_uri'])
with open(data_path, 'r', encoding='utf-8') as f:
    data = [json.loads(line.strip()) for line in f.readlines()]

In [11]:
import random
import torch

max_per_task = 10

shuffle_rng = random.Random(1337)
shuffle_rng.shuffle(data)
data = data[:max_per_task]


correct = torch.zeros(len(data))
for i in range(len(data)):
    example = data[i]
    task_type = task_meta['task_type']
    num_fewshot = task_meta['num_fewshot']
    continuation_delimiter = task_meta['continuation_delimiter']

    fewshot_examples = []
    if num_fewshot > 0:
        rng = random.Random(42 + i)  # Seed with a combination of a constant and the example index
        available_indices = [idx for idx in range(len(data)) if idx != i]
        fewshot_indices = rng.sample(available_indices, num_fewshot)
        fewshot_examples = [data[idx] for idx in fewshot_indices]         
    


In [12]:
from jinja2 import Template

template_str = """
{%- for ex in fewshot_examples %}
{{ex.query}}{{ continuation_delimiter }}{{ ex.choices[ex.gold] }}

{%- endfor %}
{{item.query}}{{ continuation_delimiter }}{{choice}}
""".strip() 

template = Template(template_str)

In [13]:
prompts = [template.render(
    fewshot_examples=fewshot_examples,
    item=example,
    choice=choice,
    continuation_delimiter=continuation_delimiter)
    for choice in example['choices']]

In [14]:
from minichat.tokenizer import get_tokenizer
tokenizer = get_tokenizer()

In [15]:
def find_common_length(prompts, direction='left'):
    """
    Finds the common length of tokenized prompts.
    direction: 'left' or 'right' - indicates which side to consider for commonality
    """
    tokenized_prompts = [tokenizer.encode(prompt) for prompt in prompts]
    if direction == 'left':
        min_length = min(len(tokens) for tokens in tokenized_prompts)
        common_length = 0
        for i in range(min_length):
            current_token = tokenized_prompts[0][i]
            if all(tokens[i] == current_token for tokens in tokenized_prompts):
                common_length += 1
            else:
                break
        return common_length
    elif direction == 'right':
        min_length = min(len(tokens) for tokens in tokenized_prompts)
        common_length = 0
        for i in range(1, min_length + 1):
            current_token = tokenized_prompts[0][-i]
            if all(tokens[-i] == current_token for tokens in tokenized_prompts):
                common_length += 1
            else:
                break
        return common_length
    else:
        raise ValueError("Direction must be 'left' or 'right'")

In [19]:
tokens = tokenizer.encode(prompts, prepand=tokenizer.get_bos_token_id())
start_indices = find_common_length(prompts, direction='left')
end_indices = [len(p) for p in prompts]

In [20]:
start_indices

59

In [23]:
def start_sequences(tokens, pad_token_id):
    bsz, seq_len = len(tokens), max(len(t) for t in tokens)
    input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
    for i, x in enumerate(tokens):
        input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
    return input_ids

In [None]:
# test start_sequences
vocab_size = 256
  
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))

In [None]:
max_token_len = 1024
new_tokens, new_start_ids, new_end_ids = [], [], []

for t, s, e in zip(tokens, start_indices, end_indices):
    if len(t) > max_token_len:
        num_to_crop = len(t) - max_token_len
        new_tokens.append(t[-max_token_len:])
        new_start_ids.append(s-num_to_crop)
        new_end_ids.append(e-num_to_crop)
    else:
        new_tokens.append(t)
        new_start_ids.append(s)
        new_end_ids.append(e)
    
tokens, start_ids, end_ids = new_tokens, new_start_ids, new_end_ids

pad_token_id = tokenizer.get_bos_token_id()
input_ids = 

    