In [81]:
import os

import einops
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import transformers
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
import tqdm

from result_funcs import get_result_paths

In [61]:
def get_video_df(result_path):
    result_df = pd.read_parquet(result_path, columns=['result', 'args'])
    result_df['return'] = result_df['result'].map(lambda r: r['return'])
    result_df['id'] = result_df['return'].map(lambda r: r['id'] if r and 'id' in r else None)
    video_df = result_df[result_df['id'].map(lambda i: i is not None)]
    
    return video_df

def load_data():
    data_dir_path = os.path.join('/', 'mnt', 'bigone', 'bsteel', 'tiktok', 'data')
    result_paths = list(get_result_paths(data_dir_path))
    result_paths = sorted(result_paths)

    df = None
    # result_paths = result_paths[:10]
    for result_path in tqdm.tqdm(result_paths):
        video_path = result_path.replace('results.parquet.gzip', 'videos.parquet.gzip')
        if not os.path.exists(video_path):
            batch_df = get_video_df(result_path)
            batch_df.to_parquet(video_path, compression='gzip')
        else:
            batch_df = pd.read_parquet(video_path)

        if df is None:
            df = batch_df
        else:
            df = pd.concat([df, batch_df])

    df = df.sort_values('id')
    df['bits'] = df['id'].map(lambda i: format(int(i), '064b'))
    # only looking at 1 sequence ID system for now
    df = df[df['bits'].map(lambda b: b[50:56] == '001101')]
    df['bits'] = df['bits'].map(lambda bits: np.array([int(b) for b in bits]))

    # Example IDs
    ids = df['bits'].tolist()
    return ids

In [62]:
ids = load_data()

100%|██████████| 788/788 [1:03:57<00:00,  4.87s/it]


In [35]:
class IDDataset(Dataset):
    def __init__(self, ids, sequence_length=10):
        self.ids = [self.binary_to_tensor(id) for id in ids]
        self.seq_len = sequence_length

    def binary_to_tensor(self, binary_id):
        return torch.tensor([int(bit) for bit in binary_id], dtype=torch.float32)

    def __len__(self):
        return len(self.ids) - self.seq_len

    def __getitem__(self, index):
        return (
            torch.stack(self.ids[index:index+self.seq_len]), 
            torch.stack(self.ids[index+1:index+self.seq_len+1])
        )

In [106]:
class SequenceWorkerDataset(Dataset):
    def __init__(self, ids, sequence_length=10):
        self.seq = []
        worker_ids = set()
        id_bits = np.array(ids)
        sequence_bits = id_bits[:, 42:50]
        worker_bits = id_bits[:, 56:]
        sequence_ids = sequence_bits.dot(1 << np.arange(sequence_bits.shape[-1] - 1, -1, -1))
        worker_ids = worker_bits.dot(1 << np.arange(worker_bits.shape[-1] - 1, -1, -1))
        unique_worker_ids = np.unique(worker_ids)
        assert len(unique_worker_ids) <= 26
        max_sequence_id = sequence_ids.max()
        worker_id_map = {worker_id: i for worker_id, i in zip(unique_worker_ids, range(max_sequence_id, max_sequence_id + len(unique_worker_ids)))}
        worker_ids = np.array([worker_id_map[w] for w in worker_ids])
        self.seq = np.stack([sequence_ids, worker_ids])
        self.seq = einops.rearrange(self.seq, 't l -> (l t)')
        
        self.seq_len = sequence_length
        self.vocab_size = sequence_ids.max() + 1 + len(unique_worker_ids)

    def __len__(self):
        return len(self.seq) - self.seq_len

    def __getitem__(self, index):
        return (
            torch.tensor(self.seq[index:index+self.seq_len], dtype=torch.int64),
            torch.tensor(self.seq[index+1:index+self.seq_len+1], dtype=torch.int64)
        )


In [113]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seq_len = 2048
method = 'sequence'
if method == 'bits':
    dataset = IDDataset(ids, sequence_length=seq_len)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Assuming each ID has 64 bits
    config = GPTNeoXConfig(
        vocab_size=64,  # output a 64 bit string
        max_position_embeddings=seq_len,  # Maximum sequence length
        hidden_size=64,  # Embedding size
        intermediate_size=8,
        num_hidden_layers=1,  # Number of transformer layers
        num_attention_heads=1,  # Number of attention heads
        # torch_dtype=torch.float16,
    )
elif method == 'sequence':
    dataset = SequenceWorkerDataset(ids, sequence_length=seq_len)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Assuming each ID has 64 bits
    config = GPTNeoXConfig(
        vocab_size=dataset.vocab_size,  # output a 64 bit string
        max_position_embeddings=seq_len,  # Maximum sequence length
        hidden_size=16,  # Embedding size
        intermediate_size=16,
        num_hidden_layers=4,  # Number of transformer layers
        num_attention_heads=2,  # Number of attention heads
        # torch_dtype=torch.float16,
    )
    
model = GPTNeoXForCausalLM(config)
# model = model.half()
model = model.to(device)

In [104]:
def train(model, loader, device, epochs=5):
    model.train()
    optimizer = transformers.AdamW(model.parameters(), lr=0.01)
    scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(epochs * 0.1), num_training_steps=epochs)
    loss_fn = torch.nn.CrossEntropyLoss()

    method = 'sequence'

    for epoch in range(epochs):
        total_loss = 0
        current_loss = 0
        pbar = tqdm.tqdm(loader)
        for input_ids, labels in pbar:
            input_ids, labels = input_ids.to(device), labels.to(device)
            if method == 'bits':
                outputs = model(inputs_embeds=input_ids)
                loss = loss_fn(outputs.logits, labels)
            elif method == 'sequence':
                outputs = model(input_ids=input_ids, labels=labels)
                loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            current_loss = loss.item() / (input_ids.size(0) * input_ids.size(1))
            total_loss += current_loss
            if method == 'bits':
                pbar.set_description(f"Epoch {epoch+1}, Loss: {current_loss:.4f}")
            elif method == 'sequence':
                pbar.set_description(f"Epoch {epoch+1}, Loss: {current_loss:.8f}")
        scheduler.step()
        print(f'Epoch {epoch+1}, Loss: {total_loss / len(loader)}')

In [114]:
train(model, loader, device, epochs=10)

next_id = predict_next_id(model, "010101...")
print("Predicted next ID:", next_id)

  0%|          | 0/63706 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 