<h1> Configuration </h1>

In [2]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict
import pandas as pd


In [None]:
class TrainingConfig:
    """Experiment configuration loaded from a JSON5 file plus per-run settings.

    Every top-level key of the JSON5 file becomes an attribute of the instance.
    ``update_run_config`` then fills in the run-specific fields (window size,
    antenna selection, run name, output directories) and rewrites the final
    decoder layer so its output channel count matches the selected antennas.
    """

    def __init__(self, config_path: str = 'config.json5'):
        """Load ``config_path`` and create the dataset directory.

        Args:
            config_path: Path to a JSON5 file. It must define at least
                ``base_data_dir`` and ``dataset_name``. ``update_run_config``
                additionally reads ``latent_distribution``,
                ``checkpoint_dir_base``, ``log_dir_base`` and
                ``decoder_conv_transpose_configs``.
        """
        print(f"Loading configuration from: {config_path}")
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json5.load(f)

        # Every config key becomes an attribute. NOTE(review): a key that is
        # named like a method (e.g. "update_run_config") would shadow it.
        for key, value in config_data.items():
            setattr(self, key, value)

        self.dataset_path = os.path.join(self.base_data_dir, self.dataset_name)
        os.makedirs(self.dataset_path, exist_ok=True)

    def update_run_config(self, window_size: int, antenna_indices: list, seed: int):
        """Apply run-specific parameters and prepare this run's output directories.

        Args:
            window_size: Stored as ``self.window_size`` and used in the run name.
            antenna_indices: Antennas to use. If empty or ``None``, all four
                antennas (``[0, 1, 2, 3]``) are used and a warning is printed.
                The sequence is copied, so later changes by the caller do not
                leak into the config.
            seed: Only used to build the run name.

        Side effects:
            Sets ``window_size``, ``antenna_indices``, ``input_channels``,
            ``current_run_name``, ``current_checkpoint_dir`` and
            ``current_log_dir``; creates both directories; makes the last entry
            of ``decoder_conv_transpose_configs`` a 1x1 Sigmoid layer whose
            ``out_channels`` equals ``input_channels`` (replacing a trailing
            Sigmoid layer if one is already there).
        """
        self.window_size = window_size
        if antenna_indices:
            self.antenna_indices = list(antenna_indices)
        else:
            self.antenna_indices = list(range(4))
            print("WARNING: 'antenna_indices' not specified, using all 4 antennas.")

        self.input_channels = len(self.antenna_indices)

        # Only the first four characters of the distribution name tag the run
        # (e.g. "gaussian" -> "gaus").
        dist_tag = self.latent_distribution[:4]
        checkpoint_base_path = f"{self.checkpoint_dir_base}_{dist_tag}"
        log_base_path = f"{self.log_dir_base}_{dist_tag}"

        # Encode every selected antenna in the run name: using only the first
        # index made e.g. [0, 1] and [0, 2] share (and overwrite) one directory.
        # A single antenna still yields the original "ant<idx>" form.
        antenna_tag = "-".join(str(idx) for idx in self.antenna_indices)
        self.current_run_name = (
            f'{dist_tag}_win{self.window_size}_ant{antenna_tag}_seed{seed}'
        )

        self.current_checkpoint_dir = os.path.join(checkpoint_base_path, self.current_run_name)
        self.current_log_dir = os.path.join(log_base_path, self.current_run_name)
        os.makedirs(self.current_checkpoint_dir, exist_ok=True)
        os.makedirs(self.current_log_dir, exist_ok=True)

        # input_channels can differ between runs, so drop a trailing Sigmoid
        # layer (appended by a previous call or present in the loaded config)
        # before appending one sized for this run. The emptiness check avoids
        # an IndexError on an empty layer list.
        decoder_configs = self.decoder_conv_transpose_configs
        if decoder_configs and decoder_configs[-1].get("activation") == "Sigmoid":
            decoder_configs.pop()
        decoder_configs.append(
            {"out_channels": self.input_channels, "kernel_size": (1, 1), "stride": (1, 1), "padding": 0, "activation": "Sigmoid"}
        )

    @staticmethod
    def get_activation(activation_name: str) -> nn.Module:
        """Instantiate a ``torch.nn`` activation module by class name.

        Args:
            activation_name: Name of an attribute of ``torch.nn`` (e.g.
                ``"ReLU"``, ``"Sigmoid"``), or ``None`` for no activation.

        Returns:
            A new instance constructed with no arguments, or ``nn.Identity()``
            when ``activation_name`` is ``None``.

        Raises:
            ValueError: If ``torch.nn`` has no attribute with that name.
        """
        if activation_name is None:
            return nn.Identity()

        try:
            activation_class = getattr(nn, activation_name)
        except AttributeError:
            raise ValueError(f"Activation function not found in torch: '{activation_name}'") from None
        # Instantiate outside the try so an AttributeError raised inside a
        # module's constructor is not misreported as "not found".
        return activation_class()

    def __repr__(self):
        # Both fields may be absent (before update_run_config / no "device"
        # key in the config); fall back to 'N/A' instead of raising.
        run_name = getattr(self, 'current_run_name', 'N/A')
        device = getattr(self, 'device', 'N/A')
        return f"TrainingConfig(run_name='{run_name}', device='{device}')"
