In [None]:
import json
import os

import einops
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import transformers
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
import tqdm

from result_funcs import get_result_paths

In [None]:
def get_video_df(result_path):
    """Load a results parquet file and keep only rows whose response carries an ``id``.

    Args:
        result_path: path to a parquet file with ``result`` and ``args`` columns. Each
            ``result`` value is indexed with ``['return']`` (presumably a dict — confirm).

    Returns:
        A new DataFrame with added ``return`` and ``id`` columns, restricted to the rows
        where an ``id`` was found. It is an independent copy, so callers may add columns.
    """
    result_df = pd.read_parquet(result_path, columns=['result', 'args'])
    result_df['return'] = result_df['result'].map(lambda r: r['return'])
    # Build the id column with object dtype explicitly. Series.map would coerce numeric
    # ids mixed with None into float64: the None placeholders become NaN (which an
    # `is not None` filter keeps) and large integer ids lose precision.
    result_df['id'] = pd.Series(
        [r['id'] if r and 'id' in r else None for r in result_df['return']],
        index=result_df.index,
        dtype=object,
    )
    video_df = result_df[result_df['id'].notna()].copy()
    return video_df

def parse_ids(df):
    """Convert the ``id`` column of ``df`` into 64-element bit vectors.

    Ids are sorted by integer value and rendered as 64-character binary strings (most
    significant bit first). Only ids whose string positions 50-55 equal ``001101`` are kept.

    Args:
        df: DataFrame with an ``id`` column whose values accept ``int()`` (ints or
            decimal strings). It is not modified.

    Returns:
        A list of 1-D numpy integer arrays of 0/1 values, one per kept id, in ascending
        id order.
    """
    # Sort on the integer value: ids stored as decimal strings would otherwise sort
    # lexicographically (e.g. '10' before '9').
    ordered_ids = sorted(int(i) for i in df['id'])
    bit_strings = (format(i, '064b') for i in ordered_ids)
    # only looking at 1 sequence ID system for now
    return [
        np.array([int(b) for b in bits])
        for bits in bit_strings
        if bits[50:56] == '001101'
    ]

def load_data(num_files=None, data_dir_path=None):
    """Collect parsed id bit vectors from the result files under ``data_dir_path``.

    For each results file, a sibling ``videos.parquet.gzip`` cache is read if it exists.
    Otherwise the cache is built with ``get_video_df`` and written next to the results file.

    Args:
        num_files: when truthy, only the first ``num_files`` result paths (in sorted
            order) are processed; otherwise all of them.
        data_dir_path: directory handed to ``get_result_paths``. Defaults to the original
            hard-coded location.

    Returns:
        A flat list with the output of ``parse_ids`` for every processed file, concatenated.
    """
    if data_dir_path is None:
        data_dir_path = os.path.join('/', 'mnt', 'bigone', 'bsteel', 'tiktok', 'data')
    result_paths = sorted(get_result_paths(data_dir_path))
    if num_files:
        result_paths = result_paths[:num_files]

    seconds_ids = []
    for result_path in tqdm.tqdm(result_paths):
        video_path = result_path.replace('results.parquet.gzip', 'videos.parquet.gzip')
        if video_path == result_path:
            # Unexpected file name: the "cache" path would be the results file itself,
            # so reading it would return raw results. Compute without caching instead.
            batch_df = get_video_df(result_path)
        elif os.path.exists(video_path):
            batch_df = pd.read_parquet(video_path)
        else:
            batch_df = get_video_df(result_path)
            batch_df.to_parquet(video_path, compression='gzip')

        seconds_ids.extend(parse_ids(batch_df))

    return seconds_ids

In [None]:
# Only the first few result files, to keep iteration fast; load_data treats a falsy
# value as "use every file".
NUM_FILES = 10
seconds_ids = load_data(num_files=NUM_FILES)

In [None]:
class IDDataset(Dataset):
    """Windows of consecutive ids, each encoded as a float32 bit vector.

    Item ``i`` is ``(inputs, targets)``. ``inputs`` stacks ids ``i .. i+seq_len-1`` and
    ``targets`` stacks ids ``i+1 .. i+seq_len``, i.e. the same window shifted by one id.
    """

    def __init__(self, ids, sequence_length=10):
        self.ids = [self.binary_to_tensor(id) for id in ids]
        self.seq_len = sequence_length

    def binary_to_tensor(self, binary_id):
        """Convert an iterable of 0/1 values (or '0'/'1' characters) to a float32 tensor."""
        return torch.tensor([int(bit) for bit in binary_id], dtype=torch.float32)

    def __len__(self):
        # Each item needs seq_len + 1 ids (the inputs plus the one-step-shifted targets).
        # Clamp at 0 so a split with too few ids is an empty dataset; a negative value
        # would make len() raise.
        return max(0, len(self.ids) - self.seq_len)

    def __getitem__(self, index):
        inputs = torch.stack(self.ids[index:index + self.seq_len])
        targets = torch.stack(self.ids[index + 1:index + 1 + self.seq_len])
        return inputs, targets

In [None]:
class SyntheticSequenceWorkerDataset(Dataset):
    """Toy token stream of interleaved (sequence id, worker id) pairs.

    For positions ``i`` in ``range(10000)``, the sequence id is ``i % 10`` and the worker
    id is ``(i // 10) % 10``. Both are mapped to tokens by ``_convert_to_id_seq``.
    """

    def __init__(self, *args, sequence_length=10):
        # Accepted call forms: (), (sequence_length,), (ids,), (ids, sequence_length).
        # The (ids, ...) forms let create_datasets() build this class the same way as the
        # real-data datasets; the ids are ignored because the data is synthetic.
        if len(args) > 2:
            raise TypeError(f'expected at most 2 positional arguments, got {len(args)}')
        if len(args) == 2:
            sequence_length = args[1]
        elif len(args) == 1 and isinstance(args[0], (int, np.integer)):
            sequence_length = args[0]
        num_ids = 10000
        positions = np.arange(num_ids)
        sequence_ids = positions % 10
        worker_ids = (positions // 10) % 10
        self._convert_to_id_seq(sequence_ids, worker_ids, sequence_length)

    def _convert_to_id_seq(self, sequence_ids, worker_ids, sequence_length):
        """Tokenize and interleave ids into ``self.seq`` = [s0, w0, s1, w1, ...].

        Sequence ids are used as tokens directly. Each distinct worker id gets a token
        just above the largest sequence id, so the two id spaces never collide.

        NOTE(review): the mapping comes from the ids passed in, so datasets built
        separately (e.g. train vs test) may give the same worker id different tokens and
        end up with different ``vocab_size`` values.
        """
        sequence_ids = np.asarray(sequence_ids)
        worker_ids = np.asarray(worker_ids)
        unique_worker_ids = np.unique(worker_ids)  # sorted
        worker_token_offset = int(sequence_ids.max()) + 1
        worker_tokens = worker_token_offset + np.searchsorted(unique_worker_ids, worker_ids)
        # (L, 2) pairs flattened row-major -> s0, w0, s1, w1, ...
        self.seq = np.stack([sequence_ids, worker_tokens], axis=1).reshape(-1)

        # None (e.g. forwarded by a caller without an explicit length) means the default.
        self.seq_len = 10 if sequence_length is None else sequence_length
        # Plain Python int so it stays JSON-serializable when stored in a model config.
        self.vocab_size = worker_token_offset + len(unique_worker_ids)

    def __len__(self):
        # Number of full windows of seq_len tokens (label shifting happens in the model).
        # Clamped at 0 so a short stream yields an empty dataset instead of an error.
        return max(0, len(self.seq) - self.seq_len + 1)

    def __getitem__(self, index):
        return torch.tensor(self.seq[index:index + self.seq_len], dtype=torch.int64)


class SequenceWorkerDataset(SyntheticSequenceWorkerDataset):
    """Token stream built from real 64-bit id bit vectors.

    For each id, positions 42-49 of its 64-bit vector are read as the sequence number and
    positions 56-63 as the worker number (both most-significant bit first). They are then
    tokenized and interleaved by ``_convert_to_id_seq``.
    """

    def __init__(self, seconds_ids, sequence_length=10):
        # NOTE(review): this is a count of seconds only when seconds_ids is grouped per
        # second; for a flat list of ids (what load_data returns) it is the id count.
        self.num_seconds = len(seconds_ids)
        id_bits = self._stack_id_bits(seconds_ids)
        sequence_ids = self._bits_to_int(id_bits[:, 42:50])
        worker_ids = self._bits_to_int(id_bits[:, 56:])
        self._convert_to_id_seq(sequence_ids, worker_ids, sequence_length)

    @staticmethod
    def _stack_id_bits(seconds_ids):
        """Collect every id bit vector into an (n_ids, 64) integer array.

        Accepts either a flat list of bit vectors or a list of groups (e.g. per second)
        of bit vectors; empty groups are skipped.
        """
        rows = []
        for item in seconds_ids:
            arr = np.asarray(item)
            if arr.ndim == 1:
                if arr.size:
                    rows.append(arr)  # a single id
            else:
                rows.extend(arr)  # a group of ids
        return np.asarray(rows, dtype=np.int64).reshape(len(rows), 64)

    @staticmethod
    def _bits_to_int(bits):
        """Read each row of a 0/1 matrix as an unsigned integer, most significant bit first."""
        weights = 1 << np.arange(bits.shape[-1] - 1, -1, -1)
        return bits.dot(weights)


In [None]:
def create_datasets(seconds_ids, dataset_cls, sequence_length=None):
    """Split ``seconds_ids`` 80/20 by position (no shuffling) and wrap each part.

    Args:
        seconds_ids: sliceable sequence of ids.
        dataset_cls: called as ``dataset_cls(ids)`` or ``dataset_cls(ids, sequence_length)``.
        sequence_length: forwarded as the second positional argument when given. When
            None, the dataset class's own default is used rather than being overridden
            with None.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    train_size = int(0.8 * len(seconds_ids))
    extra_args = () if sequence_length is None else (sequence_length,)
    train_dataset = dataset_cls(seconds_ids[:train_size], *extra_args)
    test_dataset = dataset_cls(seconds_ids[train_size:], *extra_args)
    return train_dataset, test_dataset

def get_loaders(train_dataset, test_dataset, batch_size=32):
    """Wrap the datasets in DataLoaders.

    The train loader shuffles; the test loader keeps dataset order.

    Args:
        batch_size: batch size for both loaders (default 32, the previous fixed value).

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seq_len = 2048
# 'bits': raw 64-bit vectors (IDDataset); 'sequence': (sequence id, worker id) tokens
# from real ids; 'synthetic': the toy token stream.
method = 'sequence'

dataset_classes = {
    'bits': IDDataset,
    'sequence': SequenceWorkerDataset,
    'synthetic': SyntheticSequenceWorkerDataset,
}
if method not in dataset_classes:
    raise ValueError(f'Unknown method {method!r}; expected one of {sorted(dataset_classes)}')

train_dataset, test_dataset = create_datasets(seconds_ids, dataset_classes[method], sequence_length=seq_len)
train_loader, test_loader = get_loaders(train_dataset, test_dataset)

if method == 'bits':
    # Each id is fed as a 64-dim float vector via inputs_embeds, so hidden_size equals the
    # bit width; the 64 output logits are scored against the next id's 64 bits.
    config = GPTNeoXConfig(
        vocab_size=64,
        max_position_embeddings=seq_len,
        hidden_size=64,
        intermediate_size=8,
        num_hidden_layers=1,
        num_attention_heads=1,
    )
else:
    # Token-level model over the dataset's (sequence id, worker id) vocabulary.
    config = GPTNeoXConfig(
        # int(): vocab_size may be a numpy integer; keep the config JSON-serializable.
        vocab_size=int(train_dataset.vocab_size),
        max_position_embeddings=seq_len,
        hidden_size=16,
        intermediate_size=16,
        num_hidden_layers=4,
        num_attention_heads=2,
    )

model = GPTNeoXForCausalLM(config).to(device)

In [None]:
def train(model, loader, device, epochs=5, method='sequence', lr=0.01):
    """Train ``model`` on ``loader`` with AdamW and a per-step cosine schedule with warmup.

    Args:
        model: for ``method='sequence'``, called as ``model(input_ids=..., labels=...)`` and
            must return an object with ``.loss``. For ``method='bits'``, called as
            ``model(inputs_embeds=...)`` and must return an object with ``.logits``.
        loader: iterable of batches with ``len()``. For 'sequence', each batch is a tensor
            of token ids. For 'bits', each batch is an ``(inputs, targets)`` pair of float
            bit tensors.
        device: device the batches are moved to.
        epochs: number of passes over ``loader``.
        method: 'sequence' or 'bits'.
        lr: peak learning rate.

    Returns:
        Mean loss over the batches of the final epoch (nan if no batch was seen).

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method not in ('bits', 'sequence'):
        raise ValueError(f"Unknown method {method!r}; expected 'bits' or 'sequence'")
    model.train()
    # torch's AdamW replaces the deprecated transformers.AdamW; eps and weight_decay
    # are set to that class's defaults so the optimisation is unchanged.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    total_steps = epochs * len(loader)
    # Stepped once per optimizer step, so warmup and cosine decay span the whole run
    # (stepping per epoch left the LR constant for short runs).
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)
    # The 64 output bits are independent binary targets.
    bit_loss_fn = torch.nn.BCEWithLogitsLoss()

    epoch_loss = float('nan')
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm.tqdm(loader)
        for batch in pbar:
            if method == 'bits':
                inputs, labels = (t.to(device) for t in batch)
                outputs = model(inputs_embeds=inputs)
                loss = bit_loss_fn(outputs.logits, labels)
            else:
                input_ids = batch.to(device)
                # Causal LMs shift labels internally, so the inputs double as labels.
                outputs = model(input_ids=input_ids, labels=input_ids)
                loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            # The loss is already a mean; dividing by batch*seq again made it meaningless.
            current_loss = loss.item()
            total_loss += current_loss
            num_batches += 1
            pbar.set_description(f"Epoch {epoch+1}, Loss: {current_loss:.4f}")
        epoch_loss = total_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}, Loss: {epoch_loss}')
    return epoch_loss

In [None]:
def evaluate_accuracy(model, test_loader, device):
    """Greedy next-token accuracy of ``model`` over ``test_loader``.

    Each batch is indexed as ``(batch, seq)`` token ids. The argmax prediction at position
    t is compared with token t+1, so each sequence contributes ``seq - 1`` scored positions.

    Returns:
        The accuracy (also printed), or nan if no position was scored.
    """
    model.eval()
    num_correct = 0
    num_total = 0
    with torch.no_grad():
        for input_ids in tqdm.tqdm(test_loader):
            input_ids = input_ids.to(device)
            predictions = model(input_ids=input_ids).logits.argmax(dim=-1)
            # The prediction at position t is for token t+1.
            predicted_next = predictions[:, :-1]
            labels = input_ids[:, 1:]
            num_correct += (predicted_next == labels).sum().item()
            # Count only the positions actually compared (seq - 1 per row).
            num_total += labels.numel()
    accuracy = num_correct / num_total if num_total else float('nan')
    print(f'Accuracy: {accuracy}')
    return accuracy

def evaluate_usage(model, test_dataset, device):
    """Count the requests a greedy guess-and-exclude strategy would spend on ``test_dataset``.

    For every window ``test_dataset[i]`` (indexed as a 1-D token tensor), the model scores
    the final token given the preceding ones. Candidates are tried in descending logit
    order; each wrong guess is a miss and is excluded before the next try. The requests
    spent on a window are therefore 1 + the number of tokens scored strictly higher than
    the true one (ties resolve in the true token's favour).

    TODO: an id is a (sequence, worker) token pair, so this counts requests per token,
    not per full id.

    Returns:
        ``(num_requests, num_misses)``, also printed.
    """
    model.eval()
    num_requests = 0
    num_misses = 0
    with torch.no_grad():
        for i in tqdm.tqdm(range(len(test_dataset))):
            window = test_dataset[i].to(device)
            context, target = window[:-1], window[-1]
            next_logits = model(input_ids=context.unsqueeze(0)).logits[0, -1]
            # Every candidate the model prefers over the true token is tried (and fails) first.
            misses = int((next_logits > next_logits[target]).sum().item())
            num_misses += misses
            num_requests += misses + 1
    print(f'Requests: {num_requests}, Misses: {num_misses}')
    return num_requests, num_misses

In [None]:
# `evaluate` is not defined anywhere; evaluate_usage is the function taking a dataset.
evaluate_usage(model, test_dataset, device)

In [None]:
# A single pass over the training data.
NUM_EPOCHS = 1
train(model, train_loader, device, epochs=NUM_EPOCHS)

In [None]:
# `evaluate` is not defined anywhere; evaluate_accuracy is the function taking a loader.
evaluate_accuracy(model, test_loader, device)

In [None]:
def baseline(test_dataset, combinations_path=None):
    """Request count for the exhaustive baseline that tries every listed combination.

    The number of entries in the JSON file is treated as requests per millisecond. That
    count is scaled by 1000 ms/s and ``test_dataset.num_seconds``. Coverage is reported
    as 1.0 because every combination is tried.

    Args:
        test_dataset: object with a ``num_seconds`` attribute (e.g. SequenceWorkerDataset).
        combinations_path: JSON file of combinations. Defaults to the original relative path.

    Returns:
        The total number of requests (also printed).
    """
    if combinations_path is None:
        combinations_path = os.path.join('..', 'figs', 'all_videos', 'all_two_segments_combinations.json')
    with open(combinations_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    num_reqs_per_milli = len(data)
    num_requests = num_reqs_per_milli * 1000 * test_dataset.num_seconds
    print(f"Num Requests: {num_requests}, Coverage: {1.0}")
    return num_requests
    
    
# Exhaustive-enumeration request count for comparison with the model.
baseline(test_dataset=test_dataset)