# Pythia Analysis - train small models on HDFS data

* use tokenized version of preprocessed HDFS events
* start with very small pythia models, test increasing size
* start with fine-tuning, then consider resetting weights and training from scratch
* experiment with different tokenizers
  * https://chatgpt.com/share/67448f53-29a0-800f-9913-af22d6ed0894


In [1]:
try:
    import logparser.Drain as Drain
except ImportError:
    %pip install requests git+https://github.com/logpai/logparser

%pip install transformers torch torchvision torchaudio wandb python-dotenv


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import logparser.Drain as Drain


In [3]:
%load_ext autoreload
%autoreload 2
import dataloaders as dl


# Download and unzip the HDFS dataset

The functions check if the data is already downloaded and unzipped, and only download and unzip if they are not present.


In [4]:

dl.download_data(dl.datasets["HDFS"]["url"], dl.datasets["HDFS"]["zip_file_name"])
dl.unzip_data(dl.datasets["HDFS"]["zip_file_name"], dl.datasets["HDFS"]["file_name"])

In [5]:
import pandas as pd

structured_file_path = dl.parse_dataset("HDFS")

structured_df = pd.read_csv(structured_file_path)
structured_df.head()


Unnamed: 0,LineId,Date,Time,Pid,Level,Component,Content,EventId,EventTemplate,ParameterList
0,1,81109,203518,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.19.102:5..."
1,2,81109,203518,35,INFO,dfs.FSNamesystem,BLOCK* NameSystem.allocateBlock: /mnt/hadoop/m...,3d91fa85,BLOCK* NameSystem.allocateBlock: <*> <*>,['/mnt/hadoop/mapred/system/job_200811092030_0...
2,3,81109,203519,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.10.6:405..."
3,4,81109,203519,145,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.14.224:4..."
4,5,81109,203519,145,INFO,dfs.DataNode$PacketResponder,PacketResponder 1 for block blk_-1608999687919...,d38aa58d,PacketResponder <*> for block <*> <*>,"['1', 'blk_-1608999687919862906 terminating']"


# Parse out the block id

In [6]:
dl.add_hdfs_blockid_column(structured_df)
structured_df.head()


Unnamed: 0,LineId,Date,Time,Pid,Level,Component,Content,EventId,EventTemplate,ParameterList,BlockId
0,1,81109,203518,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.19.102:5...",blk_-1608999687919862906
1,2,81109,203518,35,INFO,dfs.FSNamesystem,BLOCK* NameSystem.allocateBlock: /mnt/hadoop/m...,3d91fa85,BLOCK* NameSystem.allocateBlock: <*> <*>,['/mnt/hadoop/mapred/system/job_200811092030_0...,blk_-1608999687919862906
2,3,81109,203519,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.10.6:405...",blk_-1608999687919862906
3,4,81109,203519,145,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.14.224:4...",blk_-1608999687919862906
4,5,81109,203519,145,INFO,dfs.DataNode$PacketResponder,PacketResponder 1 for block blk_-1608999687919...,d38aa58d,PacketResponder <*> for block <*> <*>,"['1', 'blk_-1608999687919862906 terminating']",blk_-1608999687919862906


# Load the block labels

In [7]:
dl.unzip_data(dl.datasets["HDFS"]["zip_file_name"],"preprocessed/anomaly_label.csv", base_dir="data/hdfs" )

anomaly_label_df = pd.read_csv("data/hdfs/preprocessed/anomaly_label.csv")
anomaly_label_df.head()


Unnamed: 0,BlockId,Label
0,blk_-1608999687919862906,Normal
1,blk_7503483334202473044,Normal
2,blk_-3544583377289625738,Anomaly
3,blk_-9073992586687739851,Normal
4,blk_7854771516489510256,Normal


# Parse the parameter list

The parameter list is formatted as python code, so we need to use the `ast` library to parse it.

In [8]:
from ast import literal_eval

structured_df['ParsedParameterList'] = structured_df.ParameterList.apply(literal_eval)


In [9]:
event_id_mapping_pdf = (structured_df
 .EventId
 .value_counts()
 .reset_index()
 .reset_index()
 .rename(columns={"index":"NewEventId"})
 [["EventId", "NewEventId"]]
)

In [10]:
structured_with_event_id_pdf = structured_df.merge(event_id_mapping_pdf, on="EventId")
structured_with_event_id_pdf.head()

Unnamed: 0,LineId,Date,Time,Pid,Level,Component,Content,EventId,EventTemplate,ParameterList,BlockId,ParsedParameterList,NewEventId
0,1,81109,203518,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.19.102:5...",blk_-1608999687919862906,"[blk_-1608999687919862906, /10.250.19.102:5410...",0
1,2,81109,203518,35,INFO,dfs.FSNamesystem,BLOCK* NameSystem.allocateBlock: /mnt/hadoop/m...,3d91fa85,BLOCK* NameSystem.allocateBlock: <*> <*>,['/mnt/hadoop/mapred/system/job_200811092030_0...,blk_-1608999687919862906,[/mnt/hadoop/mapred/system/job_200811092030_00...,6
2,3,81109,203519,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.10.6:405...",blk_-1608999687919862906,"[blk_-1608999687919862906, /10.250.10.6:40524,...",0
3,4,81109,203519,145,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.14.224:4...",blk_-1608999687919862906,"[blk_-1608999687919862906, /10.250.14.224:4242...",0
4,5,81109,203519,145,INFO,dfs.DataNode$PacketResponder,PacketResponder 1 for block blk_-1608999687919...,d38aa58d,PacketResponder <*> for block <*> <*>,"['1', 'blk_-1608999687919862906 terminating']",blk_-1608999687919862906,"[1, blk_-1608999687919862906 terminating]",2


## Construct blocks to parse

https://raw.githubusercontent.com/EleutherAI/pythia/refs/heads/main/utils/20B_tokenizer.json has the tokenizer configuration.  We will use the `<|sep|>` token to immediately precede the short event id.  We need to add the `<|sep|>` token to the tokenizer, because it is not in the default tokenizer.  This will hopefully help the attention mechanism attend to the event id specifically.  We have shortened the event id to the minimum length based on the number occurences.  This will gives an efficient coding that will be less complicated for the attention mechanism.

We can consider a more customized tokenizer as another experiment.  This might help because of the special characters and the dominance of numbers in the logs.


In [58]:
from transformers import GPTNeoXTokenizerFast
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/pythia-14m")
tokenizer.add_special_tokens({"additional_special_tokens": ["<|sep|>"]})
tokenizer.sep_token = "<|sep|>"
tokenizer.sep_token_id
tokenizer.pad_token_id = tokenizer.eos_token_id


Double check that the tokenizer properly encodes the new special token

In [59]:

tokenizer.encode("<|sep|>")


[50277]

Review then tokenizer configuration, again to ensure the new special token is included


In [13]:
tokenizer

GPTNeoXTokenizerFast(name_or_path='EleutherAI/pythia-14m', vocab_size=50254, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'sep_token': '<|sep|>', 'additional_special_tokens': ['<|sep|>']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<|padding|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	50254: AddedToken("                        ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50255: AddedToken("                       ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50256: AddedToken("                      ", rstrip=False, lstrip=False, single_word=False, normalized=T

In [14]:
structured_with_event_id_pdf['event_encoded'] = structured_with_event_id_pdf.apply(lambda row: f"{tokenizer.sep_token}{row['NewEventId']} {' '.join(param for param in row['ParsedParameterList'] if 'blk_' not in param)}", axis=1)
structured_with_event_id_pdf.head()


Unnamed: 0,LineId,Date,Time,Pid,Level,Component,Content,EventId,EventTemplate,ParameterList,BlockId,ParsedParameterList,NewEventId,event_encoded
0,1,81109,203518,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.19.102:5...",blk_-1608999687919862906,"[blk_-1608999687919862906, /10.250.19.102:5410...",0,<|sep|>0 /10.250.19.102:54106 /10.250.19.102:5...
1,2,81109,203518,35,INFO,dfs.FSNamesystem,BLOCK* NameSystem.allocateBlock: /mnt/hadoop/m...,3d91fa85,BLOCK* NameSystem.allocateBlock: <*> <*>,['/mnt/hadoop/mapred/system/job_200811092030_0...,blk_-1608999687919862906,[/mnt/hadoop/mapred/system/job_200811092030_00...,6,<|sep|>6
2,3,81109,203519,143,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.10.6:405...",blk_-1608999687919862906,"[blk_-1608999687919862906, /10.250.10.6:40524,...",0,<|sep|>0 /10.250.10.6:40524 /10.250.10.6:50010
3,4,81109,203519,145,INFO,dfs.DataNode$DataXceiver,Receiving block blk_-1608999687919862906 src: ...,09a53393,Receiving block <*> src: <*> dest: <*>,"['blk_-1608999687919862906', '/10.250.14.224:4...",blk_-1608999687919862906,"[blk_-1608999687919862906, /10.250.14.224:4242...",0,<|sep|>0 /10.250.14.224:42420 /10.250.14.224:5...
4,5,81109,203519,145,INFO,dfs.DataNode$PacketResponder,PacketResponder 1 for block blk_-1608999687919...,d38aa58d,PacketResponder <*> for block <*> <*>,"['1', 'blk_-1608999687919862906 terminating']",blk_-1608999687919862906,"[1, blk_-1608999687919862906 terminating]",2,<|sep|>2 1


In [15]:
encoded_blocks_series = structured_with_event_id_pdf.groupby("BlockId")['event_encoded'].apply(lambda x: "".join(x))
encoded_blocks_series.head()


BlockId
blk_-1000002529962039464    <|sep|>0 /10.251.123.1:41333 /10.251.123.1:500...
blk_-100000266894974466     <|sep|>6 <|sep|>0 /10.250.10.144:36204 /10.250...
blk_-1000007292892887521    <|sep|>0 /10.251.127.47:50228 /10.251.127.47:5...
blk_-1000014584150379967    <|sep|>0 /10.251.43.210:49254 /10.251.43.210:5...
blk_-1000028658773048709    <|sep|>0 /10.251.107.196:58917 /10.251.107.196...
Name: event_encoded, dtype: object

In [16]:
print(encoded_blocks_series.shape)
print(encoded_blocks_series.iloc[0])


(575061,)
<|sep|>0 /10.251.123.1:41333 /10.251.123.1:50010<|sep|>0 /10.251.123.1:53174 /10.251.123.1:50010<|sep|>0 /10.251.202.181:32980 /10.251.202.181:50010<|sep|>6 <|sep|>2 2<|sep|>3 3553241 /10.251.123.1<|sep|>2 0<|sep|>3 3553241 /10.251.202.181<|sep|>1 10.251.126.22:50010 3553241<|sep|>1 10.251.202.181:50010 3553241<|sep|>1 10.251.123.1:50010 3553241<|sep|>2 1<|sep|>3 3553241 /10.251.123.1


# Start with pretrained weights

The intuition is that the model will benefit some from understanding words and numbers (to some extent) when they appear, even if the structure of logs is very different from english sentences.  We can test this with an ablation study by randomizing the weights before training and then looking at the difference in the loss.

### Understanding Pythia Model Vocabulary Size Discrepancy

When loading a Pythia model from EleutherAI, I noticed a discrepancy between the model's embedding weight shape and the tokenizer vocabulary size:

```python
import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-14m")
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/pythia-14m")
model.get_input_embeddings().weight.data.shape
```

This outputs:
```
torch.Size([50304, 128])
```

However, the tokenizer's vocab size is:
```python
>>> tokenizer.vocab_size + len(tokenizer.added_tokens_encoder)
50279
```

Including special tokens, the vocab size is 50277.

The original 50304 dimensions confused me at first, but it turns out the size is padded in order to facilitate alignment with tensor cores. Specifically, `50304 = 2^7 * 3 * 131`, so the embedding size is a multiple of 128.

From [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/pdf/2401.14489v2):

> Tensor Cores can be fully utilized when GEMM dimensions m, k, and n are multiples
> of 16 bytes and 128 bytes for V100 and A100 GPUs, respectively. Since a FP16
> element is 2 bytes, this corresponds to dimension sizes that are multiples of 8
> and 64 elements, respectively.

So it looks like the embedding size is a multiple of 64.

### Solution

Add padding to the embedding size to match the parallelization factor.
```
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
```



In [60]:
import torch

from transformers import GPTNeoXForCausalLM

def get_model():

    model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-14m")
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
    model.get_input_embeddings().weight.data.shape

    return model

model = get_model()

# Encode the blocks using the new tokenizer

In [18]:
encoded_blocks_pdf = encoded_blocks_series.to_frame()
encoded_blocks_pdf['encoded_block'] = encoded_blocks_pdf.event_encoded.apply(tokenizer.encode)


In [19]:
encoded_blocks_pdf

Unnamed: 0_level_0,event_encoded,encoded_block
BlockId,Unnamed: 1_level_1,Unnamed: 2_level_1
blk_-1000002529962039464,<|sep|>0 /10.251.123.1:41333 /10.251.123.1:500...,"[50277, 17, 1227, 740, 15, 21451, 15, 10683, 1..."
blk_-100000266894974466,<|sep|>6 <|sep|>0 /10.250.10.144:36204 /10.250...,"[50277, 23, 209, 50277, 17, 1227, 740, 15, 951..."
blk_-1000007292892887521,<|sep|>0 /10.251.127.47:50228 /10.251.127.47:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 11946, 1..."
blk_-1000014584150379967,<|sep|>0 /10.251.43.210:49254 /10.251.43.210:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 3079, 15..."
blk_-1000028658773048709,<|sep|>0 /10.251.107.196:58917 /10.251.107.196...,"[50277, 17, 1227, 740, 15, 21451, 15, 12224, 1..."
...,...,...
blk_999905757185707736,<|sep|>0 /10.251.39.160:41914 /10.251.39.160:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 1867, 15..."
blk_999915040208161699,<|sep|>0 /10.251.43.210:46583 /10.251.43.210:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 3079, 15..."
blk_999958959261325562,<|sep|>0 /10.251.203.246:56717 /10.251.203.246...,"[50277, 17, 1227, 740, 15, 21451, 15, 17490, 1..."
blk_999974850451006327,<|sep|>0 /10.251.126.5:32870 /10.251.126.5:500...,"[50277, 17, 1227, 740, 15, 21451, 15, 13381, 1..."


In [20]:
print(f"total token count: {encoded_blocks_pdf.encoded_block.apply(len).sum():,}")
encoded_blocks_pdf.encoded_block.apply(len).describe()

total token count: 137,942,766


count    575061.000000
mean        239.875015
std          85.098227
min          27.000000
25%         219.000000
50%         219.000000
75%         223.000000
max        5770.000000
Name: encoded_block, dtype: float64

In [63]:
encoded_blocks_pdf.encoded_block.apply(len).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])


count    575061.000000
mean        239.875015
std          85.098227
min          27.000000
1%           54.000000
5%          174.000000
10%         174.000000
25%         219.000000
50%         219.000000
75%         223.000000
90%         343.000000
95%         405.000000
99%         476.000000
max        5770.000000
Name: encoded_block, dtype: float64

In [21]:
torch.backends.mps.is_available()

True

In [22]:
from sklearn.model_selection import train_test_split

# Merge with anomaly labels
encoded_blocks_with_labels = encoded_blocks_pdf.merge(
    anomaly_label_df, 
    left_index=True, 
    right_on='BlockId'
)

# Split into train/test sets (80/20 split)
train_df, test_df = train_test_split(
    encoded_blocks_with_labels,
    test_size=0.2,
    random_state=42,
    stratify=encoded_blocks_with_labels['Label']
)

print(f"Training samples: {len(train_df)}")
print(f"Test samples: {len(test_df)}")

Training samples: 460048
Test samples: 115013


In [23]:
train_df

Unnamed: 0,event_encoded,encoded_block,BlockId,Label
257494,<|sep|>0 /10.251.67.211:54457 /10.251.67.211:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 2251, 15...",blk_-4040947678439826686,Normal
49365,<|sep|>6 <|sep|>0 /10.251.106.37:36707 /10.251...,"[50277, 23, 209, 50277, 17, 1227, 740, 15, 214...",blk_1870752360007129176,Normal
7319,<|sep|>6 <|sep|>0 /10.251.121.224:40809 /10.25...,"[50277, 23, 209, 50277, 17, 1227, 740, 15, 214...",blk_-1999301527305082358,Normal
295080,<|sep|>0 /10.251.123.20:56258 /10.251.123.20:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 10683, 1...",blk_-2322520798745751605,Normal
64733,<|sep|>6 <|sep|>0 /10.251.107.242:55242 /10.25...,"[50277, 23, 209, 50277, 17, 1227, 740, 15, 214...",blk_-4090429635427697097,Normal
...,...,...,...,...
424427,<|sep|>0 /10.251.37.240:42153 /10.251.37.240:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 1787, 15...",blk_4272247743717120753,Normal
403348,<|sep|>0 /10.251.215.50:36443 /10.251.215.50:5...,"[50277, 17, 1227, 740, 15, 21451, 15, 21351, 1...",blk_1218092075075778522,Normal
253046,<|sep|>0 /10.250.11.53:53272 /10.250.11.53:500...,"[50277, 17, 1227, 740, 15, 9519, 15, 883, 15, ...",blk_-4591257497708039986,Normal
495499,<|sep|>0 /10.251.125.174:53652 /10.251.125.174...,"[50277, 17, 1227, 740, 15, 21451, 15, 9312, 15...",blk_-4092465791855115484,Normal


In [64]:
# Set up training parameters
BATCH_SIZE = 4  # Small batch size for M3
MAX_LENGTH = 343  # Truncate sequences to manage memory
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3



In [25]:
from dotenv import load_dotenv
import os, wandb

load_dotenv()

wandb.login(key=os.getenv("WANDB_API_KEY"))


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
huggingface/tokenizers: The

True

In [68]:
import gc
import psutil

def print_memory_stats(prefix=""):
    """Detailed memory statistics"""
    allocated = torch.mps.current_allocated_memory() / 1024**3
    reserved = torch.mps.driver_allocated_memory() / 1024**3
    
    print(f"\n{prefix} Memory Status:")
    print(f"├── Allocated: {allocated:.2f} GB (actively used by tensors)")
    print(f"├── Reserved:  {reserved:.2f} GB (held by MPS driver)")
    print(f"├── Cached:    {(reserved - allocated):.2f} GB (reserved - allocated)")
    
    # System memory info
    vm = psutil.virtual_memory()
    print(f"└── System Available: {vm.available / 1024**3:.2f} GB")


def get_gpu_memory_metrics():
    """Get system metrics for logging"""
    return {        
        # MPS (GPU) metrics
        "system/mps_memory_allocated_gb": torch.mps.current_allocated_memory() / (1024**3),
        "system/mps_memory_reserved_gb": torch.mps.driver_allocated_memory() / (1024**3),
    }

def clear_memory():
    """Explicitly clear memory"""
    gc.collect()
    torch.mps.empty_cache()
    if hasattr(torch, 'cuda'):
        torch.cuda.empty_cache()

In [65]:
# Create DataLoader
class HDFSDataset(torch.utils.data.Dataset):
    def __init__(self, encoded_blocks, max_length):
        self.encoded_blocks = encoded_blocks
        self.max_length = max_length
        
    def __len__(self):
        return len(self.encoded_blocks)
    
    def __getitem__(self, idx):
        tokens = self.encoded_blocks.iloc[idx]['encoded_block']
        # Truncate if needed
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
        
        # Convert to tensor and pad
        input_ids = torch.tensor(tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }

# Create dataset and dataloader
dataset = HDFSDataset(encoded_blocks_pdf, MAX_LENGTH)
dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda x: {
        'input_ids': torch.nn.utils.rnn.pad_sequence(
            [item['input_ids'] for item in x],
            batch_first=True,
            padding_value=tokenizer.pad_token_id if tokenizer.pad_token_id else 0
        ),
        'attention_mask': torch.nn.utils.rnn.pad_sequence(
            [item['attention_mask'] for item in x],
            batch_first=True,
            padding_value=0
        )
    }
)

In [69]:
import numpy as np

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = get_model().to(device)

def find_optimal_batch_size(start_size=1, max_size=32):
    """Find the largest batch size that fits in memory with detailed logging"""
    
    print("Testing batch sizes...")
    for batch_size in [2**i for i in range(start_size, int(np.log2(max_size)) + 1)]:
        try:
            print(f"\nTesting batch size {batch_size}")
            print_memory_stats("Initial")
            
            # Create test dataloader
            test_loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                collate_fn=dataloader.collate_fn
            )
            
            # Get batch
            print("Loading batch...")
            batch = next(iter(test_loader))
            print_memory_stats("After batch load")
            
            # Move to device
            print("Moving to device...")
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            print_memory_stats("After moving to device")
            
            # Forward pass
            print("Forward pass...")
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )
            print_memory_stats("After forward pass")
            
            # Backward pass
            print("Backward pass...")
            loss = outputs.loss
            loss.backward()
            print_memory_stats("After backward pass")
            
            # Clean up
            del outputs, loss, input_ids, attention_mask
            torch.mps.empty_cache()
            gc.collect()
            print_memory_stats("After cleanup")
            
        except RuntimeError as e:
            print(f"\nBatch size {batch_size} failed!")
            print(f"Error: {str(e)[:200]}...")
            return batch_size//2
            
    return max_size

# Find optimal batch size with detailed memory tracking
optimal_batch_size = find_optimal_batch_size()
print(f"Optimal batch size: {optimal_batch_size}")


wandb-core(79451) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Testing batch sizes...

Testing batch size 2

Initial Memory Status:
├── Allocated: 1.45 GB (actively used by tensors)
├── Reserved:  10.61 GB (held by MPS driver)
├── Cached:    9.16 GB (reserved - allocated)
└── System Available: 2.13 GB
Loading batch...
After batch load MPS Memory:
  Allocated: 1.45 GB
  Reserved:  10.61 GB
Moving to device...
After moving to device MPS Memory:
  Allocated: 1.45 GB
  Reserved:  10.61 GB
Forward pass...
After forward pass MPS Memory:
  Allocated: 1.86 GB
  Reserved:  10.64 GB
Backward pass...
After backward pass MPS Memory:
  Allocated: 1.67 GB
  Reserved:  10.80 GB


wandb-core(79548) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(79592) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(79809) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(80002) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


After cleanup MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB

Testing batch size 4

Initial Memory Status:
├── Allocated: 1.50 GB (actively used by tensors)
├── Reserved:  6.51 GB (held by MPS driver)
├── Cached:    5.00 GB (reserved - allocated)
└── System Available: 2.67 GB
Loading batch...
After batch load MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB
Moving to device...
After moving to device MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB
Forward pass...
After forward pass MPS Memory:
  Allocated: 2.06 GB
  Reserved:  6.57 GB
Backward pass...
After backward pass MPS Memory:
  Allocated: 1.73 GB
  Reserved:  7.78 GB
After cleanup MPS Memory:

wandb-core(80128) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.



  Allocated: 1.50 GB
  Reserved:  6.51 GB

Testing batch size 8

Initial Memory Status:
├── Allocated: 1.50 GB (actively used by tensors)
├── Reserved:  6.51 GB (held by MPS driver)
├── Cached:    5.00 GB (reserved - allocated)
└── System Available: 3.12 GB
Loading batch...
After batch load MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB
Moving to device...
After moving to device MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB
Forward pass...
After forward pass MPS Memory:
  Allocated: 2.92 GB
  Reserved:  8.16 GB
Backward pass...
After backward pass MPS Memory:
  Allocated: 2.05 GB
  Reserved:  9.70 GB
After cleanup MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB

Testing batch size 16

Initial Memory Status:
├── Allocated: 1.50 GB (actively used by tensors)
├── Reserved:  6.51 GB (held by MPS driver)
├── Cached:    5.00 GB (reserved - allocated)
└── System Available: 3.18 GB
Loading batch...
After batch load MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB
Mov

wandb-core(80175) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


After forward pass MPS Memory:
  Allocated: 4.33 GB
  Reserved:  10.83 GB
Backward pass...
After backward pass MPS Memory:
  Allocated: 2.60 GB
  Reserved:  13.91 GB


wandb-core(80181) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(80238) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


After cleanup MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB

Testing batch size 32

Initial Memory Status:
├── Allocated: 1.50 GB (actively used by tensors)
├── Reserved:  6.51 GB (held by MPS driver)
├── Cached:    5.00 GB (reserved - allocated)
└── System Available: 3.09 GB
Loading batch...
After batch load MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB
Moving to device...
After moving to device MPS Memory:
  Allocated: 1.50 GB
  Reserved:  6.51 GB
Forward pass...
After forward pass MPS Memory:
  Allocated: 6.63 GB
  Reserved:  15.04 GB
Backward pass...


wandb-core(80331) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(80337) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.



Batch size 32 failed!
Error: MPS backend out of memory (MPS allocated: 7.27 GB, other allocations: 5.73 GB, max allowed: 18.13 GB). Tried to allocate 2.06 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable up...
Optimal batch size: 16


In [38]:
print(f"{torch.mps.current_allocated_memory() / 1024**3:.2f} GB")

1.38 GB


In [27]:
clear_memory()

In [70]:
# Move model to MPS device if available, otherwise CPU
model = get_model().to(device)

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

wandb.init(
    project="log-analysis-pythia",
    config={
        "batch_size": BATCH_SIZE,
        "max_length": MAX_LENGTH,
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "model": "pythia-14m",  
    }
)

# Training loop
global_step = 0
model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for batch_idx, batch in enumerate(dataloader):
        
        if batch_idx % 100 == 0:
            clear_memory()

        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids  # For causal language modeling
        )
        
        loss = outputs.loss
        total_loss += loss.item()
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print progress every 100 batches
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        wandb.log({
            "train/batch_loss": loss.item(),
            "train/epoch": epoch + 1,
            "train/batch": batch_idx,
            **get_gpu_memory_metrics()
        }, step=global_step)
        global_step += 1
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} complete. Average loss: {avg_loss:.4f}")
    wandb.log({
        "epoch_avg_loss": avg_loss,
        "epoch": epoch + 1,
    })

wandb-core(81352) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


0,1
batch,▁▁▁▁▁▂▂▂▂▂▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█████
batch_loss,█▃▂▂▂▁▂▂▂▂▂▂▁▁▂▂▁▂▁▂▁▁▁▁▂▁▂▂▁▁▁▁▂▁▁▁▁▁▂▂
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
system/mps_memory_allocated_gb,▂▄▃▄▅▄▄▁▁▄▄▃▂▃▆▃█▁▃▂▁▃▂▂▄▃▃▃▁▅▅▃▂▄█▆▄▂▅▅
system/mps_memory_reserved_gb,▁▄▅▄▄▆▆▆▄▄▄▅▅▅▅▆▅▆▆▆▆▇▆▆▆▆▆▇▇▇▇▆▆▆██▇▇▇▇

0,1
batch,6118.0
batch_loss,0.18284
epoch,1.0
system/mps_memory_allocated_gb,0.64247
system/mps_memory_reserved_gb,17.45039


wandb-core(81483) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(81484) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


wandb-core(81508) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(81521) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch 1, Batch 0, Loss: 115.7910


wandb-core(81595) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(81768) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(81819) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(81824) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(81864) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(82012) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(82020) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(82033) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(82037) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(82049) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 100, Loss: 1.0232


wandb-core(83669) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(83686) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(83754) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(83861) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(83925) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(83963) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(83969) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(83991) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(84001) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(84009) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 200, Loss: 0.6087


wandb-core(85111) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(85394) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(85586) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(85721) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(85790) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(85874) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(86344) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(86544) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(86725) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(86907) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 300, Loss: 0.5559


wandb-core(88762) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89079) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89301) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89318) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89324) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89419) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89442) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89621) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89655) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(89680) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 400, Loss: 0.4365


wandb-core(93438) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93446) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93532) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93634) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93720) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93767) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93800) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93841) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93887) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch 1, Batch 500, Loss: 0.3324


wandb-core(93900) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(93919) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94096) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94167) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94171) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94182) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94289) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94295) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94320) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94324) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 600, Loss: 0.3675


wandb-core(94482) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94501) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94507) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94654) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94786) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94859) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(94952) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95047) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95202) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95396) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 700, Loss: 0.3071


wandb-core(95447) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95498) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95609) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95695) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95741) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95815) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(95936) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(96229) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Epoch 1, Batch 800, Loss: 0.3707


wandb-core(96237) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(96244) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(96249) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(96298) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(96708) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(96714) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97074) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97084) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97140) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97320) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 900, Loss: 0.3009


wandb-core(97485) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97517) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97569) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97648) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97908) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97917) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97958) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97965) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(97970) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(98019) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 1000, Loss: 0.3606


wandb-core(98961) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(98980) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99082) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99092) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99153) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99226) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99272) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99286) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99300) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(99343) MallocStackLogging: can't turn off malloc stack logging because 

Epoch 1, Batch 1100, Loss: 0.3932


wandb-core(597) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(626) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(710) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(780) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(857) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(930) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(943) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(965) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(1001) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(1015) MallocStackLogging: can't turn off malloc stack logging because it was not enabled

Epoch 1, Batch 1200, Loss: 0.3150


wandb-core(2155) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(2379) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(2445) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(2458) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(2469) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(2525) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(2731) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(3084) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(3628) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(3666) MallocStackLogging: can't turn off malloc stack logging because it was not

In [30]:
get_gpu_memory_metrics()


{'system/mps_memory_allocated_gb': 1.316429853439331,
 'system/mps_memory_reserved_gb': 16.292190551757812}

In [55]:
output_k = np.array(input_ids.cpu())
sum(sum(output_k == 0))

  output_k = np.array(input_ids.cpu())


np.int64(742)

In [56]:
tokenizer.pad_token_id

In [57]:
tokenizer.eos_token_id

0