In [None]:
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from src.small_dataset import ReviewsDataModule
from src.models import NNMemoryModel
import wandb

In [None]:
# Training budget: upper bound on epochs per sweep trial (fed to Trainer(max_epochs=...)).
epochs: int = 1

def main():
    """Run one sweep trial: train and then test a model built from the wandb run config.

    Meant to be handed to ``wandb.agent(..., function=main)``. The hyperparameters
    ``batch_size``, ``reviews_history_size`` and ``learning_rate`` are read from the
    config of the run behind the ``WandbLogger``. The trial trains for at most
    ``epochs`` epochs and then runs the test loop on the same data module.
    """
    run_logger = WandbLogger(project="Memory ML")
    run_cfg = run_logger.experiment.config

    # Pull the sampled hyperparameters out once so each is read in a single place.
    batch_size = run_cfg.batch_size
    history_size = run_cfg.reviews_history_size
    lr = run_cfg.learning_rate

    datamodule = ReviewsDataModule(batch_size, history_size)
    net = NNMemoryModel(
        learning_rate=lr,
        reviews_history_size=history_size,
    )

    trainer = L.Trainer(
        max_epochs=epochs,
        # A float value is a fraction of an epoch: validate every 20% of training.
        val_check_interval=0.2,
        logger=run_logger,
    )
    trainer.fit(net, datamodule)
    trainer.test(net, datamodule)


# Random-search sweep: batch size and history length are picked from fixed
# choices; the learning rate is sampled log-uniformly between 1e-6 and 1e-3.
sweep_config = dict(
    method="random",
    # NOTE(review): assumes NNMemoryModel logs a metric named "test_loss" during
    # trainer.test(); if it does not, the sweep has nothing to rank runs by — confirm.
    metric=dict(goal="minimize", name="test_loss"),
    parameters=dict(
        batch_size=dict(values=[32, 64, 128, 256]),
        reviews_history_size=dict(values=[4, 8, 16, 32]),
        learning_rate=dict(min=1e-6, max=1e-3, distribution="log_uniform_values"),
    ),
)

# Register the sweep, then run 10 trials of `main` in this process.
sweep_id = wandb.sweep(sweep_config, project="Memory ML")
wandb.agent(sweep_id, function=main, count=10)