In [1]:
import pandas as pd
import torch
import sys, os
from torch.utils.data import DataLoader
from setupData import GameStatsTextDataset, collate_fn



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Sanity-check setupData.GameStatsTextDataset: load the CSV, inspect one
# example, decode it back to text, and confirm DataLoader batching with
# collate_fn. Every check is an assertion, so the cell fails loudly instead of
# printing a success message regardless of what it found.

# --- configuration -----------------------------------------------------------
TOKENIZER_NAME = "gpt2"
MAX_LENGTH = 128   # every tokenized sequence is expected to have exactly this length
BATCH_SIZE = 4

# Resolve the CSV without a machine-specific absolute path: honour the
# GAMESTATS_CSV environment variable first, then look for data/dataset.csv
# relative to the working directory or its parent.
_csv_candidates = [
    os.environ.get("GAMESTATS_CSV"),
    os.path.join("data", "dataset.csv"),
    os.path.join("..", "data", "dataset.csv"),
]
csv_path = next((p for p in _csv_candidates if p and os.path.isfile(p)), None)
assert csv_path is not None, (
    "dataset.csv not found; set the GAMESTATS_CSV environment variable "
    f"(searched: {[p for p in _csv_candidates if p]})"
)

# 1. Instantiate the dataset
dataset = GameStatsTextDataset(csv_file=csv_path, tokenizer_name=TOKENIZER_NAME, max_length=MAX_LENGTH)

# 2. Basic sanity checks
n_features = len(dataset.feature_cols)
assert len(dataset) > 0, "dataset is empty"
print(f"→ Number of examples: {len(dataset)}")
print(f"→ Stat feature columns ({n_features}): {dataset.feature_cols}\n")

# 3. Inspect a single example
sample = dataset[0]
print("Sample[0]['stats'] shape:", sample['stats'].shape)
print("Sample[0]['input_ids'] shape:", sample['input_ids'].shape)
print("Sample[0]['attention_mask'] shape:", sample['attention_mask'].shape)
print("Sample[0]['labels'] shape:", sample['labels'].shape)

# torch.Size is a tuple subclass, so it compares equal to a plain tuple.
assert sample['stats'].shape == (n_features,), (
    f"stats has shape {tuple(sample['stats'].shape)}, expected ({n_features},)"
)
for key in ("input_ids", "attention_mask", "labels"):
    assert sample[key].shape == (MAX_LENGTH,), (
        f"{key} has shape {tuple(sample[key].shape)}, expected ({MAX_LENGTH},)"
    )
assert set(sample['attention_mask'].unique().tolist()) <= {0, 1}, "attention_mask must be 0/1"

# 4. Decode to verify text alignment.
# Negative label ids (e.g. the conventional -100 "ignore" value) are not valid
# token ids and would make decode() fail, so drop them before decoding.
# NOTE(review): in the saved output the labels are padded with id 50256 (GPT-2's
# <|endoftext|>, presumably reused as the pad token) rather than -100, so a loss
# computed directly on `labels` would also score the padding — confirm setupData
# or the training loop masks it.
# NOTE(review): input_ids hold the question and labels hold the answer as two
# separate MAX_LENGTH sequences; a decoder-only model such as GPT-2 expects labels
# aligned position-by-position with input_ids — confirm how training combines them.
decoded_q = dataset.tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
decoded_a = dataset.tokenizer.decode(sample['labels'][sample['labels'] >= 0], skip_special_tokens=True)
print(f"\nDecoded question: {decoded_q}")
print(f"Decoded answer:   {decoded_a}")

# 5. Test batching
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
batch = next(iter(loader))
expected_rows = min(BATCH_SIZE, len(dataset))  # first batch can be short for tiny datasets
print("\nBatch shapes:")
for k, v in batch.items():
    print(f"  {k}: {v.shape}")
    assert v.shape[0] == expected_rows, f"batch[{k!r}] has {v.shape[0]} rows, expected {expected_rows}"
assert batch['stats'].shape == (expected_rows, n_features)
for key in ("input_ids", "attention_mask", "labels"):
    assert batch[key].shape == (expected_rows, MAX_LENGTH), f"batch[{key!r}] has unexpected shape"

print("\n✅ setupData.py is loading and batching correctly!")




→ Number of examples: 1193
→ Stat feature columns (10): ['MP', 'PTS', 'FG%', 'TRB', 'AST', 'STL', 'BLK', 'TOV', 'PF', 'Result']

Sample[0]['stats'] shape: torch.Size([10])
Sample[0]['input_ids'] shape: torch.Size([128])
Sample[0]['attention_mask'] shape: torch.Size([128])
Sample[0]['labels'] shape: torch.Size([128])

Decoded question: I know you say that you're a football player, but do you think football players would get some fouls?
Decoded answer:   (Laughing) It was definitely a physical game tonight.  You know, fouls were called at times and weren't called at times.  You know, this is what it's about.  You know, you can't look to get fouls and you've got to try to be as aggressive as possible.

Batch shapes:
  stats: torch.Size([4, 10])
  input_ids: torch.Size([4, 128])
  attention_mask: torch.Size([4, 128])
  labels: torch.Size([4, 128])

✅ setupData.py is loading and batching correctly!


In [14]:
# Raw stats vector for example 0 (presumably one value per dataset.feature_cols entry, in that order).
print(sample["stats"])

tensor([45.3667, 10.0000,  0.3330, 10.0000,  9.0000,  4.0000,  1.0000,  2.0000,
         1.0000,  0.0000])


In [15]:
# Question token ids for example 0; positions past the real text hold the padding id (50256 in the saved output).
print(sample["input_ids"])

tensor([   40,   760,   345,   910,   326,   345,   821,   257,  4346,  2137,
           11,   475,   466,   345,   892,  4346,  1938,   561,   651,   617,
        15626,    82,    30, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256])

In [17]:
# Attention mask for example 0: 1 over the question tokens, 0 over the padded tail.
print(sample["attention_mask"])

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])


In [18]:
# Answer (label) token ids for example 0.
# NOTE(review): the padding here is id 50256, not -100, so a loss over these labels
# would presumably also score the padding unless it is masked downstream — confirm.
print(sample["labels"])

tensor([    7,    43,  1567,   278,     8,   632,   373,  4753,   257,  3518,
          983,  9975,    13,   220,   921,   760,    11, 15626,    82,   547,
         1444,   379,  1661,   290,  6304,   470,  1444,   379,  1661,    13,
          220,   921,   760,    11,   428,   318,   644,   340,   338,   546,
           13,   220,   921,   760,    11,   345,   460,   470,   804,   284,
          651, 15626,    82,   290,   345,  1053,  1392,   284,  1949,   284,
          307,   355,  8361,   355,  1744,    13, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256])