In [None]:
!uv add torch pyyaml datasets

import unittest
import io # For suppressing print statements if needed


class TiktokenWrapper:
    """
    Thin adapter around a tiktoken encoding.

    Offers plain ``encode``/``decode`` helpers plus a batched mapping function
    whose input/output shape suits Hugging Face ``datasets.map(batched=True)``.

    Attributes set by ``__init__``:
        encoder: object returned by ``tiktoken.get_encoding(encoding_name)``.
        vocab_size: copied from ``encoder.n_vocab``.
        pad_token_id: copied from ``encoder.eot_token``.
    """
    def __init__(self, encoding_name="gpt2"):
        """
        Load the encoding called ``encoding_name``.

        If loading fails, two diagnostic lines are printed and the original
        exception is re-raised unchanged.
        """
        try:
            encoder = tiktoken.get_encoding(encoding_name)
        except Exception as err:
            print(f"Failed to load tiktoken encoding '{encoding_name}'. Error: {err}")
            print("Please ensure 'tiktoken' is installed and the encoding name is correct.")
            raise

        self.encoder = encoder
        self.vocab_size = encoder.n_vocab
        # No dedicated padding id is read from the encoding here; the
        # end-of-text token id is reused so callers that need a pad id have one.
        self.pad_token_id = encoder.eot_token

    def encode(self, s: str, allowed_special="all") -> list[int]:
        """
        Turn ``s`` into token ids.

        ``allowed_special`` is forwarded untouched to the underlying encoder;
        the default "all" lets special-token text in ``s`` be encoded as such.
        """
        token_ids = self.encoder.encode(s, allowed_special=allowed_special)
        return token_ids

    def decode(self, l: list[int]) -> str:
        """Turn a list of token ids back into text via the underlying encoder."""
        text = self.encoder.decode(l)
        return text

    def hf_mapping_function(self, batch_dict: dict[str, list[str]]) -> dict[str, list[list[int]]]:
        """
        Batched tokenizer callback for Hugging Face ``datasets.map``.

        Input:  {'text': [string1, string2, ...]}
        Output: {'tokens': [token_ids1, token_ids2, ...]}

        Raises:
            ValueError: if ``batch_dict`` has no 'text' key.
        """
        if 'text' in batch_dict:
            # Each string is encoded on its own through ``self.encode`` (with
            # its default ``allowed_special``); output order follows input order.
            return {'tokens': list(map(self.encode, batch_dict['text']))}

        raise ValueError("Input to tokenizer's hf_mapping_function expects a 'text' key.")
    

class TestGPTIntegration(unittest.TestCase):
    """
    Smoke test wiring tokenizer, dataloader and GPT model together.

    A single batch is taken from the dataloader and pushed through a freshly
    constructed model; only the shape, device and dtype of the logits are
    checked, not their values.
    """

    @classmethod
    def setUpClass(cls):
        """
        Build the shared tokenizer and a deliberately tiny configuration.

        Raises:
            unittest.SkipTest: if the dataset builder cannot be loaded
                (for example when running offline).
        """
        tokenizer = TiktokenWrapper(encoding_name='gpt2')
        cls.tokenizer = tokenizer
        cls.vocab_size = tokenizer.vocab_size

        # Small sizes everywhere so the test stays quick on CPU. Fields not
        # listed here are left to GPT2Config's own defaults.
        cls.config_params = dict(
            dataset_name="karpathy/tiny_shakespeare",  # read back below for the access probe
            block_size=32,
            vocab_size=cls.vocab_size,
            context_length=32,  # test_integration asserts this equals block_size
            emb_dim=64,
            n_heads=4,
            n_layers=2,
            dropout_rate=0.0,  # dropout disabled
            qkv_bias=False,
            mlp_bias=False,
            batch_size=2,
            device="cpu",
            compile_model=False,
            dtype="float32",
        )
        cls.config = GPT2Config(**cls.config_params)

        # Probe the dataset up front; any failure skips the class instead of
        # failing it.
        try:
            hf_datasets.load_dataset_builder(cls.config.dataset_name, trust_remote_code=True)
        except Exception as e:
            raise unittest.SkipTest(f"Skipping test: Cannot access dataset '{cls.config.dataset_name}'. Error: {e}")

    def test_integration(self):
        """Run dataloader -> model forward pass and validate the logits tensor."""
        cfg = self.config

        # Configuration sanity checks.
        self.assertEqual(cfg.vocab_size, self.vocab_size)
        self.assertEqual(cfg.block_size, cfg.context_length)
        self.assertTrue(cfg.emb_dim % cfg.n_heads == 0)

        # Step 1: dataloader.
        print("\nInitializing Dataloader for test...")
        try:
            # To silence the dataloader's own output, this call could be
            # wrapped in contextlib.redirect_stdout(io.StringIO()).
            loader = ShakespeareDataloader(
                batch_size=cfg.batch_size,
                sequence_length=cfg.block_size,
                tokenizer=self.tokenizer.hf_mapping_function,  # the datasets.map-style callback
                split="train",
                shuffle=False,
            )
        except Exception as e:
            self.fail(f"Dataloader initialization failed: {e}\n"
                      "Ensure the 'datasets' library is installed and you have internet access "
                      "for 'karpathy/tiny_shakespeare'.")

        self.assertGreater(len(loader), 0, "Dataloader generated zero batches.")

        # Step 2: model, moved to the configured device and put in eval mode.
        print("Initializing GPTModel for test...")
        try:
            net = GPTModel(cfg)
            net.to(cfg.device)
            net.eval()
        except Exception as e:
            self.fail(f"Model initialization failed: {e}")

        # Step 3: first batch.
        print("Fetching a batch from dataloader...")
        try:
            batch_iter = iter(loader)
            inputs, targets = next(batch_iter)
        except StopIteration:
            self.fail("Dataloader failed to produce a batch.")
        except Exception as e:
            self.fail(f"Fetching batch failed: {e}")

        batch_shape = (cfg.batch_size, cfg.block_size)
        self.assertEqual(inputs.shape, batch_shape)
        self.assertEqual(targets.shape, batch_shape)
        inputs = inputs.to(cfg.device)
        targets = targets.to(cfg.device)

        # Step 4: forward pass with gradient tracking switched off.
        print("Performing model forward pass...")
        try:
            with torch.no_grad():
                logits = net(inputs)
        except Exception as e:
            self.fail(f"Model forward pass failed: {e}")

        # Step 5: checks on the output tensor.
        self.assertIsNotNone(logits)
        self.assertEqual(logits.shape, (cfg.batch_size, cfg.block_size, cfg.vocab_size))
        self.assertEqual(logits.device.type, cfg.device)
        if cfg.dtype == "float32":
            expected_dtype = torch.float32
        else:
            expected_dtype = torch.bfloat16
        self.assertEqual(logits.dtype, expected_dtype)

if __name__ == '__main__':
    # Explicit argv keeps unittest from parsing the host process's own
    # command line (its first element is only the program name and is
    # ignored); exit=False stops unittest.main from calling sys.exit(), so an
    # interactive session keeps running after the tests finish.
    unittest.main(
        exit=False,
        argv=['first-arg-is-ignored'],
    )

In [None]:
# Build a model from a YAML config and load previously saved parameters into it.
# NOTE(review): from_yaml is invoked on a default-constructed GPT2Config;
# presumably it returns a config populated from the YAML file — confirm
# against GPT2Config's definition.
c = GPT2Config().from_yaml('gpt2_config_cpu.yaml')
m = GPTModel(c)
# The checkpoint path below is an absolute path on one specific machine and
# must be adjusted before running elsewhere. map_location is taken from the
# loaded config's `device` field.
m.load_parameters('/home/jimsingh/Downloads/gpt2_training_fineweb-edu_47000_steps.pth', map_location=c.device)


In [45]:
def gen_f(m):
    """
    Generate text from the fixed prompt "tomorrow is" using model ``m``.

    Forwards to ``generate_text(m, enc, "tomorrow is")`` and returns its
    result. ``enc`` is not defined in this cell; it is looked up in the
    surrounding (notebook-global) scope each time the function is called.
    """
    # A named def (rather than a lambda bound to a name) gives the function a
    # real __name__ in tracebacks and room for this docstring.
    return generate_text(m, enc, "tomorrow is")


gen_f(m)

'tomorrow is twice extreme times six dual in 7 price for e decrypted down into the ssh. asked you visualize'