<h1> Configuration </h1>

In [9]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict
import pandas as pd


In [8]:
class TrainingConfig:
    """Training configuration loaded from a JSON5 file.

    Every top-level key of the file becomes an attribute of the instance;
    ``__post_init__`` then derives the run name, the output directories and
    the final decoder layer from those attributes.
    """

    def __init__(self, config_path: str = 'config.json5'):
        """Load ``config_path`` (JSON5) and expose each top-level key as an attribute.

        Args:
            config_path: Path to the JSON5 configuration file.
        """
        print(f"Loading configuration from: {config_path}")
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json5.load(f)

        for key, value in config_data.items():
            setattr(self, key, value)

        self.__post_init__()

    def __post_init__(self):
        """Initializes computed fields and directories after loading.

        Reads ``antenna_indices`` (optional), ``checkpoint_dir_base``,
        ``log_dir_base``, ``base_data_dir``, ``dataset_name`` and
        ``decoder_conv_transpose_configs``.

        Sets ``antenna_indices`` (an int is wrapped in a list; missing or
        empty falls back to ``[0, 1, 2, 3]``), ``input_channels``,
        ``current_run_name``, ``current_checkpoint_dir``, ``current_log_dir``
        and ``dataset_path``; creates those three directories; and appends a
        final 1x1 ``Sigmoid`` layer to ``decoder_conv_transpose_configs``.
        """
        # A missing key is treated like an empty value (use all antennas), which
        # is what the warning below describes, instead of raising AttributeError.
        antenna_indices = getattr(self, 'antenna_indices', None)
        if isinstance(antenna_indices, int):
            antenna_indices = [antenna_indices]

        if not antenna_indices:
            antenna_indices = list(range(4))
            print("WARNING: 'antenna_indices' not specified, using all 4 antennas.")
        self.antenna_indices = antenna_indices

        # One input channel per selected antenna.
        self.input_channels = len(self.antenna_indices)

        if self.input_channels == 1:
            run_name_prefix = f'single_antenna_{self.antenna_indices[0]}'
        elif self.input_channels == 4:
            # NOTE(review): any 4 indices are named 'all_antennas'; this assumes
            # exactly 4 antennas exist (matching the range(4) default above).
            run_name_prefix = 'all_antennas'
        else:
            antennas_str = '_'.join(map(str, sorted(self.antenna_indices)))
            run_name_prefix = f'custom_antennas_{antennas_str}'

        self.current_run_name = run_name_prefix

        self.current_checkpoint_dir = os.path.join(self.checkpoint_dir_base, self.current_run_name)
        self.current_log_dir = os.path.join(self.log_dir_base, self.current_run_name)
        self.dataset_path = os.path.join(self.base_data_dir, self.dataset_name, "dataset")

        os.makedirs(self.current_checkpoint_dir, exist_ok=True)
        os.makedirs(self.current_log_dir, exist_ok=True)
        os.makedirs(self.dataset_path, exist_ok=True)

        # Final decoder layer outputs one channel per selected antenna
        # (out_channels == input_channels).
        # NOTE: appends in place, so calling __post_init__ twice adds it twice.
        final_decoder_out_channels = self.input_channels
        self.decoder_conv_transpose_configs.append(
            {"out_channels": final_decoder_out_channels, "kernel_size": (1, 1), "stride": (1, 1), "padding": 0, "activation": "Sigmoid"}
        )

    @staticmethod
    def get_activation(activation_name: str) -> nn.Module:
        """Instantiate the ``torch.nn`` module class named ``activation_name``.

        Args:
            activation_name: Class name in ``torch.nn`` (e.g. ``"ReLU"``,
                ``"Sigmoid"``), or ``None`` for ``nn.Identity``.

        Returns:
            A new instance of that class, constructed with no arguments.

        Raises:
            ValueError: if ``torch.nn`` has no ``nn.Module`` subclass by that name.
        """
        if activation_name is None:
            return nn.Identity()

        activation_class = getattr(nn, activation_name, None)
        # Reject names that exist on torch.nn but are not module classes
        # (e.g. the 'functional' or 'init' submodules); calling those would
        # otherwise fail with an unrelated TypeError.
        if not (isinstance(activation_class, type) and issubclass(activation_class, nn.Module)):
            raise ValueError(f"Activation function not found in torch: '{activation_name}'")
        return activation_class()

    def __repr__(self):
        # getattr defaults keep repr usable on a partially initialized instance
        # (e.g. 'device' absent from the config, or __post_init__ having raised).
        run_name = getattr(self, 'current_run_name', None)
        device = getattr(self, 'device', None)
        return f"TrainingConfig(run_name='{run_name}', device='{device}')"
