In [7]:
from src.data import MovingWindowProduct, MovingWindowSum, MovingWindowDifference

ModuleNotFoundError: No module named 'src'

In [8]:
import torch
from dotmap import DotMap

In [9]:

class MovingWindowDifference:
    """Sampler for a moving-window *difference* sequence task.

    Every sample is one row of ``2 * num_tokens + 1`` integers laid out as
    ``[inputs | sep | targets]``:

    * ``inputs``  -- ``num_tokens`` integers drawn uniformly from
      ``[min_num, max_num]`` (both ends inclusive).
    * ``sep``     -- a single separator token.
    * ``targets`` -- for every position ``j >= k - 1``, the first element of
      the length-``k`` window ending at ``j`` minus all other elements of
      that window, reduced modulo ``p``. Positions ``j < k - 1`` have no
      complete window and carry the input value reduced modulo ``p``.

    Args:
        config: Object whose attributes ``min_num``, ``max_num``, ``k``,
            ``p`` and ``sep`` are read with ``getattr``; a missing attribute
            falls back to 1, 16, 2, 17 and 17 respectively.
        device: Torch device the returned samples live on.

    Raises:
        AssertionError: If ``p <= max_num`` or ``k < 1``.
    """

    def __init__(self, config, device="cuda"):
        self.min_num = getattr(config, "min_num", 1)
        self.max_num = getattr(config, "max_num", 16)
        self.k = getattr(config, "k", 2)
        self.p = getattr(config, "p", 17)
        self.sep = getattr(config, "sep", 17)
        self.device = device
        assert self.p > self.max_num, "modulus p must exceed max_num"
        # A window needs at least one element; k < 1 would otherwise only
        # fail later, inside sample().
        assert self.k >= 1, "window size k must be at least 1"

    @torch.no_grad()
    def sample(
        self,
        num_samples,
        num_tokens,
    ):
        """Draw ``num_samples`` sequences of ``num_tokens`` inputs each.

        Args:
            num_samples: Number of rows to generate.
            num_tokens: Number of input tokens per row.

        Returns:
            An int64 tensor of shape ``(num_samples, 2 * num_tokens + 1)``
            on ``self.device`` holding ``[inputs | sep | targets]``.
        """
        # Drawn without an explicit device and then moved, so seeded runs keep
        # consuming the same generator as before.
        random_ints = torch.randint(
            low=self.min_num, high=self.max_num + 1, size=(num_samples, num_tokens)
        ).to(self.device)

        # Positions without a complete window keep the raw input value.
        moving_difference = random_ints.clone()

        if num_tokens >= self.k:
            # windows[:, i] is the length-k window ending at position i + k - 1;
            # shape (num_samples, num_tokens - k + 1, k).
            windows = random_ints.unfold(1, self.k, 1)
            # first element minus every other element of the window
            moving_difference[:, self.k - 1 :] = windows[..., 0] - windows[
                ..., 1:
            ].sum(dim=-1)

        # Build the separator in the integer dtype of the data. Concatenating
        # a float column would promote the whole row to float32 and silently
        # round integers above 2**24.
        separator = torch.full(
            (num_samples, 1),
            int(self.sep),
            dtype=random_ints.dtype,
            device=random_ints.device,
        )

        samples = torch.cat(
            [
                random_ints,
                separator,
                torch.remainder(input=moving_difference, other=self.p),
            ],
            dim=-1,
        )

        return samples.to(int).detach()

In [10]:
# Experiment configuration

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Nested settings wrapped in a DotMap so entries can be read with attribute
# access (e.g. config.data.k). Three groups: model, data, train.
config = DotMap(
    dict(
        model=dict(
            n_layer=1,
            n_head=1,
            n_embd=256,
            linear=True,
        ),
        data=dict(
            name='window',
            min_num=1,
            max_num=16,
            k=2,
            p=17,
            sep=17,
            cot=False,
            num_tokens=16,
            n_train=256,
            n_test=64,
            fixed_len=True,
        ),
        train=dict(
            lr=0.0001,
            grad_clip=-1,
            num_steps=500,
            norm_type="none_rank",
            wandb=True,
            save_ckpt=False,
            ckpt_freq=20,
        ),
    )
)

In [10]:
# Registry of task samplers, keyed by a short task name.
# NOTE(review): MovingWindowSum / MovingWindowProduct are imported from
# `src.data`; that import has to succeed on a fresh kernel for this cell to run.
# The separator token is read from the config (config.data.sep) rather than
# repeated as a literal, so it cannot drift from the value used elsewhere.
data_samplers = {}
data_samplers['mws'] = MovingWindowSum(
    min_num=config.data.min_num,
    max_num=config.data.max_num,
    k=config.data.k,
    p=config.data.p,
    sep=config.data.sep,
)
data_samplers['mwp'] = MovingWindowProduct(
    min_num=config.data.min_num,
    max_num=config.data.max_num,
    k=config.data.k,
    p=config.data.p,
    sep=config.data.sep,
)

In [None]:
# Preview a handful of rows from the moving-window-sum sampler.
# num_samples / num_tokens are passed explicitly instead of relying on
# sample() having defaults for them.
data_samplers['mws'].sample(num_samples=4, num_tokens=config.data.num_tokens)

In [11]:
# Split sizes and sequence length come straight from the data config.
n_train = config.data.n_train
n_test = config.data.n_test
num_tokens = config.data.num_tokens

# --- MIXED BATCH SAMPLING ---
# Each registered sampler contributes an equal (floor-divided) share of rows
# to the train split and to the test split.
task_names = [*data_samplers]
n_tasks = len(task_names)

n_train_each, n_test_each = n_train // n_tasks, n_test // n_tasks

mixed_train, mixed_test = [], []

for name, sampler in data_samplers.items():
    # One draw per task covers both splits.
    data = sampler.sample(
        num_samples=n_train_each + n_test_each, num_tokens=num_tokens
    )
    # Leading rows go to train, the remainder to test.
    train_part, test_part = data[:n_train_each, :], data[n_train_each:, :]
    mixed_train.append(train_part)
    mixed_test.append(test_part)

# Stack the per-task pieces along the row dimension.
train_data = torch.cat(mixed_train, dim=0)
test_data = torch.cat(mixed_test, dim=0)

In [22]:
# Shuffle the rows of each split with its own random permutation, so the
# per-task chunks that were concatenated are no longer contiguous.
perm = torch.randperm(len(train_data))
train_data = train_data[perm]

perm = torch.randperm(len(test_data))
test_data = test_data[perm]

In [None]:
config.