In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# (e.g. subset_active_learning) are picked up without restarting the kernel.
%load_ext autoreload  
%autoreload 2 

In [None]:
import ffcv
from ffcv.writer import DatasetWriter
from ffcv.fields import IntField, NDArrayField, FloatField
import datasets
from subset_active_learning.subset_selection import select, preprocess
import wandb
import numpy as np
import torch

In [None]:
import psutil

# RAM checkpoint 1: resident set size (RSS) of the kernel process before the dataset is preprocessed.
# Process.memory_info is expressed in bytes, so convert to megabytes
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
# Hugging Face model id handed to the project's SST-2 preprocessing helper.
# NOTE(review): presumably this selects the tokenizer -- confirm in preprocess.preprocess_sst2.
HF_MODEL_ID = "google/electra-small-discriminator"

# The result is indexed by split name ("train" / "validation" / "test") further down.
processed_ds = preprocess.preprocess_sst2(HF_MODEL_ID)

In [None]:
# RAM checkpoint 2: RSS after preprocess_sst2 has returned the processed dataset.
# Process.memory_info is expressed in bytes, so convert to megabytes
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
# Convert HF dataset into Torch dataset for ffcv support
class InMemorySST2(torch.utils.data.Dataset):
    """Map-style torch dataset that keeps every row of ``hf_ds`` in a Python list.

    The source is iterated exactly once at construction time; afterwards all
    lookups are served from the list stored in ``self.in_memory_ds``.
    """

    def __init__(self, hf_ds):
        """Materialise ``hf_ds`` (any iterable of rows) into memory, preserving order."""
        self.in_memory_ds = list(hf_ds)

    def __len__(self):
        """Return the number of rows held in memory."""
        return len(self.in_memory_ds)

    def __getitem__(self, i):
        """Return the row stored at position ``i`` (plain list indexing)."""
        return self.in_memory_ds[i]

In [None]:
# RAM checkpoint 3: RSS after defining InMemorySST2, before any split is copied into a list.
# Process.memory_info is expressed in bytes, so convert to megabytes
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
# Copy every row of the "train" split into an in-memory list (InMemorySST2 iterates it once).
train_ds = InMemorySST2(hf_ds=processed_ds["train"])

In [None]:
# RAM checkpoint 4: RSS after the train split has been copied into an InMemorySST2 list.
# Process.memory_info is expressed in bytes, so convert to megabytes
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
# Copy every row of the "validation" split into an in-memory list.
valid_ds = InMemorySST2(hf_ds=processed_ds["validation"])

In [None]:
# RAM checkpoint 5: RSS after the validation split has been copied into an InMemorySST2 list.
# Process.memory_info is expressed in bytes, so convert to megabytes
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
# Copy every row of the "test" split into an in-memory list.
test_ds = InMemorySST2(hf_ds=processed_ds["test"])

In [None]:
# RAM checkpoint 6: RSS after the test split has been copied into an InMemorySST2 list.
# Process.memory_info is expressed in bytes, so convert to megabytes
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
import time


class BatchSizeComparisonRun:
    """Train one subset under two batch-size configurations, logging each run to W&B.

    Args:
        train_ds: Training subset passed to ``SubsetTrainer.train_one_step``.
        valid_ds: Validation split passed to ``SubsetTrainer``.
        test_ds: Test split passed to ``SubsetTrainer``.
        seed: Seed identifying this comparison; attached to every W&B run as a tag.
    """

    def __init__(self, train_ds: datasets.Dataset, valid_ds: datasets.Dataset, test_ds: datasets.Dataset, seed: int):
        self.train_ds, self.valid_ds, self.test_ds = train_ds, valid_ds, test_ds
        self.seed = seed

    def one_run(self, wandb_tag: str, config: select.SubsetTrainingArguments):
        """Run one training pass inside its own W&B run and log its duration.

        Logs ``batch_size`` before training and ``run_time`` (seconds, rounded to
        two decimals) after ``train_one_step`` returns.

        Args:
            wandb_tag: Tag describing this run (e.g. which batch-size arm it is).
            config: Training arguments forwarded to ``select.SubsetTrainer``.
        """
        # W&B tags are strings; the integer seed is converted explicitly.
        run_tags = [wandb_tag, str(self.seed)]
        # Using the run as a context manager guarantees it is finished even when
        # training raises, so a failed run is not left open for the next wandb.init().
        with wandb.init(project="subset-search-gpu-opt", entity="johnny-gary", tags=run_tags) as wandb_run:
            wandb_run.log({"batch_size": config.batch_size})
            subset_trainer = select.SubsetTrainer(
                params=config, valid_ds=self.valid_ds, test_ds=self.test_ds
            )
            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by system clock adjustments.
            start_time = time.perf_counter()
            subset_trainer.train_one_step(subset=self.train_ds, calculate_test_accuracy=True)
            wandb_run.log({"run_time": round(time.perf_counter() - start_time, 2)})

    def run_comparison(self, small_batch_config: select.SubsetTrainingArguments, large_batch_config: select.SubsetTrainingArguments):
        """Run the small-batch configuration, then the large-batch one.

        Each configuration gets its own W&B run, tagged ``small_batch_<size>`` or
        ``large_batch_<size>``. The original note says each run trains until early
        stopping; that behaviour lives in ``SubsetTrainer`` and is not visible here.
        """
        self.one_run(wandb_tag=f"small_batch_{small_batch_config.batch_size}", config=small_batch_config)
        self.one_run(wandb_tag=f"large_batch_{large_batch_config.batch_size}", config=large_batch_config)

In [None]:
############# In Memory Experiments ###############

# NOTE(review): this cell was left unfinished. Its last line used to be the bare
# expression `batch_size_comparison.on_run`, which raises AttributeError
# (BatchSizeComparisonRun defines `one_run` and `run_comparison`, not `on_run`)
# and therefore stopped "Restart & Run All" at this point. It also executes before
# the batch-size configs are defined. The failing expression has been removed;
# the comparison runs themselves are launched by the final cell of the notebook.
# TODO: decide whether this cell should run the comparison on in-memory data or be deleted.
for seed in range(42, 47):
    # Shuffle deterministically per seed and keep the first 100 rows as the subset.
    train_ds = processed_ds["train"].shuffle(seed=seed).select(range(100))
    batch_size_comparison = BatchSizeComparisonRun(
        train_ds=train_ds,
        valid_ds=processed_ds["validation"],
        test_ds=processed_ds["test"],
        seed=seed,
    )

In [None]:
# Multiplier applied to the small batch size to obtain the large one (8 -> 32).
INCREASE_FACTOR = 4

# Baseline arm of the comparison.
small_batch_config = select.SubsetTrainingArguments(batch_size=8, learning_rate=1e-5)

# Large-batch arm: batch size scaled by INCREASE_FACTOR, learning rate set to
# three quarters of the small-batch learning rate.
large_batch_config = select.SubsetTrainingArguments(
    batch_size=small_batch_config.batch_size * INCREASE_FACTOR,
    learning_rate=small_batch_config.learning_rate * (3 / 4),
)

In [None]:
# Repeat the small-vs-large batch comparison for seeds 42..46. Each seed shuffles the
# train split deterministically and keeps the first 100 rows as the training subset.
for seed in range(42, 47):
    train_ds = processed_ds["train"].shuffle(seed=seed).select(range(100))
    # `seed` is a required constructor argument (it becomes a W&B tag);
    # omitting it raised TypeError before any run could start.
    batch_size_comparison = BatchSizeComparisonRun(
        train_ds=train_ds,
        valid_ds=processed_ds["validation"],
        test_ds=processed_ds["test"],
        seed=seed,
    )
    batch_size_comparison.run_comparison(small_batch_config=small_batch_config, large_batch_config=large_batch_config)