In [19]:
import logging
import os

import numpy as np
from model_LIT_CODING import LITCodingModel
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from utils.model_utils import SimulatedDataModule
import torch
import pytorch_lightning as pl


import matplotlib.pyplot as plt
import matplotlib
import yaml

In [20]:
# --- Simulation / model configuration ---------------------------------------
n_tbins = 1024                # histogram length passed to the data module and the model
k = 4                         # `k` argument of LITCodingModel — presumably number of codes; confirm

rep_freq = 5e6                # repetition frequency — presumably Hz (i.e. 5 MHz); confirm
rep_tau = 1.0 / rep_freq      # repetition period, the reciprocal of rep_freq

sigma = 10                    # forwarded as SimulatedDataModule(sigma=...); meaning/units not visible here

# Ten evenly spaced (linear) levels each for photon counts and SBR
# (presumably signal-to-background ratio — confirm).
# NOTE(review): with linear spacing the low-count regime is sampled only at 100
# before jumping to ~1.1e5 — confirm a log spacing was not intended.
counts = torch.linspace(100, 1_000_000, 10)
sbr = torch.linspace(0.1, 10.0, steps=10)

# Training hyper-parameters (lr, epochs, batch size, ...) are loaded from this file.
yaml_file = 'best_hyperparameters_2.yaml'

In [21]:
# Keep only the single best checkpoint, ranked by the lowest logged `val_loss`,
# under ./checkpoints with base file name "coded_model".
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath="checkpoints",
    filename="coded_model",
)

In [22]:
# Load the tuned hyper-parameters and bind them to the globals used by the
# data-module, trainer and model cells. The file is closed as soon as it has
# been parsed; everything afterwards works on the in-memory `config` dict.
with open(yaml_file, 'r') as file:
    config = yaml.safe_load(file)

# safe_load returns None for an empty file and may return a list/scalar for a
# malformed one; fail fast with a readable message rather than a TypeError.
if not isinstance(config, dict):
    raise ValueError(
        f"{yaml_file} must contain a mapping of hyper-parameters, "
        f"got {type(config).__name__}"
    )

# Report every missing entry at once instead of a bare KeyError on the first.
_required_keys = ('init_lr', 'lr_decay_gamma', 'tv_reg', 'epochs',
                  'batch_size', 'beta', 'num_samples')
_missing_keys = [key for key in _required_keys if key not in config]
if _missing_keys:
    raise KeyError(f"{yaml_file} is missing hyper-parameters: {_missing_keys}")

init_lr = config['init_lr']
lr_decay_gamma = config['lr_decay_gamma']
tv_reg = config['tv_reg']
epochs = config['epochs']
batch_size = config['batch_size']
beta = config['beta']
num_samples = config['num_samples']

In [None]:
# Seed all RNGs *before* the simulated dataset is built. The later
# `pl.seed_everything(42)` cell only runs after `setup()`, so any random
# sampling done here (presumably the photon-count simulation — confirm) would
# otherwise differ on every run.
pl.seed_everything(42)

# Simulated data module built from the configuration and hyper-parameter cells.
data_module = SimulatedDataModule(
    n_tbins, counts, sbr, rep_tau, batch_size,
    num_samples=num_samples, sigma=sigma, normalize=True,
)
data_module.setup()

In [None]:
# Pick the first CUDA device when present, otherwise fall back to the CPU.
# On GPU, also allow lower-precision (faster) float32 matmuls.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
if use_gpu:
    torch.set_float32_matmul_precision('medium')
print('GPU' if use_gpu else 'CPU')

In [None]:
# Seed Python, NumPy and torch RNGs so model initialisation is reproducible.
pl.seed_everything(seed=42)

In [15]:
# Metrics are written as CSV under tb_logs/my_model/ (despite the "tb" dir name,
# this is the CSV logger, not TensorBoard).
logger = CSVLogger(save_dir="tb_logs", name="my_model")

In [None]:
# Trainer: CSV logging, validation every half training epoch
# (val_check_interval=0.5), and best-checkpoint tracking via checkpoint_callback.
trainer = pl.Trainer(
    max_epochs=epochs,
    logger=logger,
    callbacks=[checkpoint_callback],
    val_check_interval=0.5,
    log_every_n_steps=250,
)

In [None]:
# Build the coded LightningModule from the configuration and loaded hyper-parameters.
lit_model = LITCodingModel(k=k, n_tbins=n_tbins, init_lr=init_lr, lr_decay_gamma=lr_decay_gamma,
                           beta=beta, tv_reg=tv_reg)

# Autograd anomaly detection records extra checks for every op and substantially
# slows training; turn it on only while hunting NaN/Inf gradients. It is a
# process-global switch, so always set it explicitly: re-running this cell in the
# same kernel then resets it instead of leaving a previous `True` in effect.
detect_anomaly = False
torch.autograd.set_detect_anomaly(detect_anomaly)

In [None]:
# Train (and periodically validate) the model on the simulated data module.
trainer.fit(model=lit_model, datamodule=data_module)