In [None]:
# %%
# !pip install DeepMIMO==4.0.0b10

# %%
# %%
# =============================================================================
# 1. IMPORTS AND WARNINGS SETUP
#    - Load necessary PyTorch modules, utilities, and suppress UserWarnings
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
import torch
from tqdm import tqdm
import math
# from utils import (generate_channels_and_labels, tokenizer_train, tokenizer, make_sample, nmse_loss,
                #    create_train_dataloader, patch_maker, count_parameters, train_lwm)
from collections import defaultdict
import numpy as np
# import pretrained_model  # Assuming this contains the LWM model definition
import matplotlib.pyplot as plt
import warnings
import os
import bisect
# from collections import defaultdict
from tqdm import tqdm
warnings.filterwarnings("ignore", category=UserWarning)
# from utils import *
import deepmimo as dm
from sklearn.metrics import mean_squared_error
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from models import GPTPathDecoder
from dataset.dataloaders import MySeqDataLoader
from utils.utils import *

In [None]:
# %%
# DeepMIMO scenario used throughout the notebook
# (alternative that was tried: 'city_89_nairobi_3p5').
scenario = 'city_0_newyork_3p5'

# Download the scenario, then load it into `dataset`.
dm.download(scenario)
dataset = dm.load(scenario)

# %%
dataset.scene.plot()

# %%
dm.info()

# %%
# Run settings shared by the data loaders, the optimizer and the training loop.
config = dict(
    BATCH_SIZE=64,
    PAD_VALUE=500,             # sentinel compared against delay/power entries to detect padding
    USE_WANDB=False,
    LR=2e-5,
    epochs=100,
    interaction_weight=0.01,   # weight of the interaction loss term
    # NOTE: the checkpoint filename is derived from this string, so renaming it
    # (even to fix the spelling) makes previously saved checkpoints unreachable.
    experiment=f"{scenario}_interacaction_all_inter_str_dec_all_aod",
)


# %%
# Train / validation datasets built from the same DeepMIMO `dataset`; the only
# difference between the two is the `train` flag. Each dataset supplies its own
# collate function to its DataLoader.
train_data = MySeqDataLoader(dataset, train=True, split_by="user", sort_by="power")
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=config["BATCH_SIZE"],
    shuffle=True,
    collate_fn=train_data.collate_fn,
)

val_data = MySeqDataLoader(dataset, train=False, split_by="user", sort_by="power")
val_loader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=config["BATCH_SIZE"],
    shuffle=False,
    collate_fn=val_data.collate_fn,
)

# Peek at a single batch to sanity-check the tensor shapes. The default of
# `None` keeps this cell from raising when the training set is empty.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print(f"Prompt shape: {first_batch[0].shape}, Paths shape: {first_batch[1].shape}, Num paths shape: {first_batch[2].shape}")


# %%
print("No. of Train Points   : ", len(train_data))
print("Batch Size           : ", config["BATCH_SIZE"])
print("Train Batches        : ", len(train_loader))
# The original label said "Train Points" here although the value is the validation-set size.
print("No. of Val Points     : ", len(val_data))
print("Val Batches          : ", len(val_loader))

# %%



In [None]:




def _circular_error_deg(pred_rad, gt_rad):
    """Return ``(rmse, mae)`` in degrees of the wrapped difference of two angle tensors.

    Both inputs are divided by pi/180 (i.e. treated as radians and converted to
    degrees) and the difference is wrapped into [-180, 180) so that angles on
    either side of the wrap-around point are compared along the short arc.
    """
    diff_deg = (pred_rad / (np.pi / 180) - gt_rad / (np.pi / 180) + 180) % 360 - 180
    rmse = torch.mean(diff_deg ** 2).sqrt().item()
    mae = torch.mean(torch.abs(diff_deg)).item()
    return rmse, mae


def _nmse_db_to_score(nmse_db, db_min=-20.0, db_max=0.0):
    """Linearly map an NMSE value onto a score clamped to [0, 1].

    ``db_min`` maps to 1.0 (best) and ``db_max`` to 0.0 (worst). The constant
    names suggest the NMSE is expressed in dB - TODO confirm against
    ``compute_channel_nmse``.

    A NaN input yields 0.0. With the plain ``max(0.0, min(1.0, score))`` clamp a
    NaN would come out as 1.0 (every comparison with NaN is False), i.e. a
    failed NMSE computation would be counted as a perfect prediction.
    """
    score = 1.0 - (nmse_db - db_min) / (db_max - db_min)
    if score != score:  # NaN is the only value that is not equal to itself
        return 0.0
    return max(0.0, min(1.0, score))


def _sample_channel_nmse(pred, gt, pad_value):
    """Channel NMSE between the predicted and ground-truth paths of one sample.

    ``pred`` and ``gt`` are 2-D tensors with one row per path whose first five
    columns are read as delay, power, phase, azimuth, elevation. Every quantity
    is given a leading axis of size 1 before being handed to the module-level
    ``mycomputer.compute_channels``; the two results are compared with the
    module-level ``compute_channel_nmse``.
    """
    def column(tensor, idx):
        # One feature column as a (1, T) NumPy array.
        return np.expand_dims(tensor[:, idx].detach().cpu().numpy(), 0)

    delay_gt = column(gt, 0)
    power_gt = column(gt, 1)

    # Padded entries are detected on the raw delay value (equivalent to, but
    # less fragile than, comparing after the division by 1e6).
    pad_mask = delay_gt == pad_value

    def masked(values):
        return np.where(pad_mask, np.nan, values)

    # Replace padded powers by 0 before exponentiating so 10**x cannot overflow.
    power_gt = np.where(power_gt == pad_value, 0, power_gt)

    # power / 0.01 / 10 -> exponent of the linear power; delay / 1e6 -> presumably
    # microseconds to seconds (the variable names and printed units suggest so).
    predicted_channels = mycomputer.compute_channels(
        masked(10 ** ((column(pred, 1) / 0.01) / 10)),
        masked(column(pred, 0) / 1e6),
        masked(column(pred, 2)),
        masked(column(pred, 3)),
        masked(column(pred, 4)),
        kwargs=None,
    )
    gt_channels = mycomputer.compute_channels(
        masked(10 ** ((power_gt / 0.01) / 10)),
        masked(delay_gt / 1e6),
        masked(column(gt, 2)),
        masked(column(gt, 3)),
        masked(column(gt, 4)),
        kwargs=None,
    )
    return compute_channel_nmse(predicted_channels, gt_channels)


def evaluate_model(model, val_loader, max_generate=26, log_to_wandb=False):
    """Generate paths sample by sample and report error metrics against the ground truth.

    For every sample of every batch the function calls ``generate_paths``,
    truncates prediction and (un-padded) ground truth to their common length and
    computes RMSE / MAE for delay, power, phase, AoA azimuth, AoA elevation and
    the predicted path length, plus a [0, 1] channel score derived from the
    channel NMSE. Running averages of the channel score are printed after each
    batch and a full summary is printed at the end.

    Relies on these module-level names: ``generate_paths``, ``train_data``
    (for ``pad_value``), ``config`` (for ``"PAD_VALUE"``), ``mycomputer`` and
    ``compute_channel_nmse``.

    Args:
        model: Trained model; its parameters determine the device the batches
            are moved to.
        val_loader: Iterable yielding ``(prompts, paths, path_lengths, interactions)``.
        max_generate: Forwarded to ``generate_paths`` as ``max_steps``.
        log_to_wandb: When true, per-sample metrics and the final averages are
            sent to Weights & Biases (``wandb`` is imported on demand).

    Returns:
        Tuple ``(delay_rmse, power_rmse, phase_rmse, az_rmse, el_rmse,
        path_length_rmse, delay_mae, power_mae, phase_mae, path_length_mae)``
        of averages over all evaluated samples. The channel score is printed
        but, as before, not part of the returned tuple.
    """
    if log_to_wandb:
        # Imported here so the function does not depend on an earlier cell
        # having executed `import wandb`.
        import wandb

    model.eval()
    # Follow the model's device instead of hard-coding CUDA so that evaluation
    # also works on CPU-only machines.
    device = next(model.parameters()).device

    history = {
        name: [] for name in (
            "delay_rmse", "power_rmse", "phase_rmse", "az_rmse", "el_rmse", "length_rmse",
            "delay_mae", "power_mae", "phase_mae", "az_mae", "el_mae", "length_mae",
            "ch_score",
        )
    }

    with torch.no_grad():
        outer_bar = tqdm(val_loader, desc="Evaluating (batches)", leave=True)

        for prompts, paths, path_lengths, _interactions in outer_bar:
            prompts = prompts.to(device)
            paths = paths.to(device)
            path_lengths = path_lengths.to(device)

            # Inner bar shows per-sample progress within the batch.
            inner_bar = tqdm(range(prompts.size(0)),
                             desc="   Processing samples",
                             leave=False)

            for b in inner_bar:
                generated, path_lengths_pred, _inter_str_pred = generate_paths(
                    model, prompts[b], max_steps=max_generate)
                generated = generated.to(device)

                # Ground truth without the first sequence position; columns are
                # delay, power, phase, aoa_az, aoa_el. Padded rows are dropped.
                gt = paths[b][1:, :5]
                gt = gt[gt[:, 0] != train_data.pad_value]

                # Compare only the overlapping prefix of prediction and truth.
                T = min(len(gt), len(generated))
                if T == 0:
                    # Nothing to compare: every mean below would be NaN and
                    # would contaminate the aggregated results.
                    continue
                pred = generated[:T]
                gt = gt[:T]

                sample = {}
                delay_diff = pred[:, 0] - gt[:, 0]
                sample["delay_rmse"] = torch.mean(delay_diff ** 2).sqrt().item()
                sample["delay_mae"] = torch.mean(torch.abs(delay_diff)).item()

                # Power is stored scaled by 0.01; undo the scaling before comparing.
                power_diff = pred[:, 1] / 0.01 - gt[:, 1] / 0.01
                sample["power_rmse"] = torch.mean(power_diff ** 2).sqrt().item()
                sample["power_mae"] = torch.mean(torch.abs(power_diff)).item()

                sample["phase_rmse"], sample["phase_mae"] = _circular_error_deg(pred[:, 2], gt[:, 2])
                sample["az_rmse"], sample["az_mae"] = _circular_error_deg(pred[:, 3], gt[:, 3])
                sample["el_rmse"], sample["el_mae"] = _circular_error_deg(pred[:, 4], gt[:, 4])

                length_diff = path_lengths_pred - path_lengths[b]
                sample["length_rmse"] = torch.mean(length_diff ** 2).sqrt().item()
                sample["length_mae"] = torch.mean(torch.abs(length_diff)).item()

                sample["ch_score"] = _nmse_db_to_score(
                    _sample_channel_nmse(pred, gt, config["PAD_VALUE"]))

                for name, value in sample.items():
                    history[name].append(value)

                # Live metric values in the progress bar.
                inner_bar.set_postfix({
                    "delay_rmse": f"{sample['delay_rmse']:.3f}",
                    "power_rmse": f"{sample['power_rmse']:.3f}",
                    "phase_rmse": f"{sample['phase_rmse']:.3f}",
                    "az_rmse": f"{sample['az_rmse']:.3f}",
                    "el_rmse": f"{sample['el_rmse']:.3f}",
                    "ch_nmse": f"{sample['ch_score']:.3f}",
                    "length_rmse": f"{sample['length_rmse']:.3f}",
                    "delay_mae": f"{sample['delay_mae']:.3f}",
                    "power_mae": f"{sample['power_mae']:.3f}",
                    "phase_mae": f"{sample['phase_mae']:.3f}",
                    "az_mae": f"{sample['az_mae']:.3f}",
                    "el_mae": f"{sample['el_mae']:.3f}",
                    "length_mae": f"{sample['length_mae']:.3f}"
                })

                if log_to_wandb:
                    wandb.log({
                        "test_delay_rmse": sample["delay_rmse"],
                        "test_power_rmse": sample["power_rmse"],
                        "test_phase_circ_err": sample["phase_rmse"],
                        "test_stop_length_rmse": sample["length_rmse"],
                        "test_az_rmse": sample["az_rmse"],
                        "test_el_rmse": sample["el_rmse"],
                        "test_delay_mae": sample["delay_mae"],
                        "test_power_mae": sample["power_mae"],
                        "test_phase_circ_err_mae": sample["phase_mae"],
                        "test_az_mae": sample["az_mae"],
                        "test_el_mae": sample["el_mae"],
                        "test_stop_length_mae": sample["length_mae"],
                    })

            # Running average of the channel score over all samples seen so far.
            print(f"Avg ch_score_nsme    : {np.mean(history['ch_score']):.4f}")

    # ---- Final aggregated results ----
    avg_delay = np.mean(history["delay_rmse"])
    avg_power = np.mean(history["power_rmse"])
    avg_phase = np.mean(history["phase_rmse"])
    avg_az = np.mean(history["az_rmse"]) if len(history["az_rmse"]) > 0 else 0.0
    avg_el = np.mean(history["el_rmse"]) if len(history["el_rmse"]) > 0 else 0.0
    avg_path_length_rmse = np.mean(history["length_rmse"])
    avg_ch_nmse = np.mean(history["ch_score"])
    avg_delay_mae = np.mean(history["delay_mae"])
    avg_power_mae = np.mean(history["power_mae"])
    avg_phase_mae = np.mean(history["phase_mae"])
    avg_az_mae = np.mean(history["az_mae"]) if len(history["az_mae"]) > 0 else 0.0
    avg_el_mae = np.mean(history["el_mae"]) if len(history["el_mae"]) > 0 else 0.0
    avg_path_length_mae = np.mean(history["length_mae"])

    print("\n=================  Final EVALUATION RESULTS =================")
    print(f"Delay RMSE           : {avg_delay:.4f} µs")
    print(f"Power RMSE           : {avg_power:.4f} dB")
    print(f"Phase RMSE           : {avg_phase:.4f} degrees")
    print(f"AoA Azimuth RMSE     : {avg_az:.4f} degrees")
    print(f"AoA Elevation RMSE   : {avg_el:.4f} degrees")
    print(f"Path Length RMSE     : {avg_path_length_rmse:.4f}")
    print(f"avg_ch_nmse          :       {avg_ch_nmse:.4f}")
    print(f"Delay MAE           : {avg_delay_mae:.4f} µs")
    print(f"Power MAE           : {avg_power_mae:.4f} dB")
    print(f"Phase MAE           : {avg_phase_mae:.4f} degrees")
    print(f"AoA Azimuth MAE     : {avg_az_mae:.4f} degrees")
    print(f"AoA Elevation MAE   : {avg_el_mae:.4f} degrees")
    print(f"Path Length MAE     : {avg_path_length_mae:.4f}")
    print("=====================================================\n")

    if log_to_wandb:
        wandb.run.summary["test_delay_rmse"] = avg_delay
        wandb.run.summary["test_power_rmse"] = avg_power
        wandb.run.summary["test_phase_circ_err"] = avg_phase
        wandb.run.summary["test_path_length_rmse"] = avg_path_length_rmse

        wandb.run.summary["test_delay_mae"] = avg_delay_mae
        wandb.run.summary["test_power_mae"] = avg_power_mae
        wandb.run.summary["test_phase_circ_err_mae"] = avg_phase_mae
        wandb.run.summary["test_path_length_mae"] = avg_path_length_mae

    return (avg_delay, avg_power, avg_phase, avg_az, avg_el, avg_path_length_rmse,
            avg_delay_mae, avg_power_mae, avg_phase_mae, avg_path_length_mae)



# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = GPTPathDecoder()
model = model.to(device)

n_trainable = count_parameters(model)
print("Total trainable parameters:", n_trainable)


# %%
# Optional experiment tracking; the full `config` dict is recorded with the run.
if config["USE_WANDB"]:
    import wandb

    wandb.init(project="deepmimo-path-decoder", config=config)


# %%
optimizer = torch.optim.AdamW(model.parameters(), lr=config["LR"])

# Cosine annealing with warm restarts:
#   T_0     - length of the first cycle, counted in scheduler epochs
#   T_mult  - cycle-length multiplier after each restart (1 keeps it constant)
#   eta_min - lowest learning rate reached within a cycle
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=25,
    T_mult=1,
    eta_min=1e-8,
)

# Helper used to turn path parameters into channels for the NMSE metric.
mycomputer = MyChannelComputer()

# The best checkpoint is written to checkpoints2/<experiment>_best_model_checkpoint.pth.
checkpoint_dir = "checkpoints2"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(
    checkpoint_dir, f"{config['experiment']}_best_model_checkpoint.pth"
)

# Order of the values returned by `masked_loss`, used as keys of the per-epoch stats.
_TRAIN_LOSS_NAMES = ("loss", "delay", "power", "phase", "az", "el", "path_length", "interaction")


def _batch_channel_nmse(delay_pred, power_pred, phase_pred, az_pred, el_pred, paths_out, pad_value):
    """Channel NMSE between predicted and ground-truth paths for one batch.

    Ground truth is read from feature columns 0-4 of ``paths_out`` (delay,
    power, phase, azimuth, elevation). Positions whose ground-truth delay equals
    ``pad_value`` are set to NaN in every array before both sets of parameters
    are passed to the module-level ``mycomputer.compute_channels`` and compared
    with the module-level ``compute_channel_nmse``.
    """
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    delay_gt = to_numpy(paths_out[:, :, 0])
    power_gt = to_numpy(paths_out[:, :, 1])
    phase_gt = to_numpy(paths_out[:, :, 2])
    az_gt = to_numpy(paths_out[:, :, 3])
    el_gt = to_numpy(paths_out[:, :, 4])

    # Padding is detected on the raw delay value (equivalent to, but less
    # fragile than, comparing after the division by 1e6).
    pad_mask = delay_gt == pad_value

    def masked(values):
        return np.where(pad_mask, np.nan, values)

    # Replace padded powers by 0 before exponentiating so 10**x cannot overflow.
    power_gt = np.where(power_gt == pad_value, 0, power_gt)

    # power / 0.01 / 10 -> exponent of the linear power; delay / 1e6 -> presumably
    # microseconds to seconds (the variable names in the notebook suggest so).
    predicted_channels = mycomputer.compute_channels(
        masked(10 ** ((to_numpy(power_pred) / 0.01) / 10)),
        masked(to_numpy(delay_pred) / 1e6),
        masked(to_numpy(phase_pred)),
        masked(to_numpy(az_pred)),
        masked(to_numpy(el_pred)),
        kwargs=None,
    )
    gt_channels = mycomputer.compute_channels(
        masked(10 ** ((power_gt / 0.01) / 10)),
        masked(delay_gt / 1e6),
        masked(phase_gt),
        masked(az_gt),
        masked(el_gt),
        kwargs=None,
    )
    return compute_channel_nmse(predicted_channels, gt_channels)


def _forward_with_loss(model, batch, device, pad_value, interaction_weight):
    """Run one teacher-forced forward pass and evaluate ``masked_loss`` on it.

    The sequence is shifted by one position: everything but the last step is
    fed to the model, everything but the first step is the target.

    Returns ``(outputs, losses, paths_out, path_lengths)`` where ``outputs`` is
    the 13-tuple produced by the model and ``losses`` the 8-tuple produced by
    ``masked_loss`` (total loss first).
    """
    prompts, paths, path_lengths, interactions = (t.to(device) for t in batch)

    paths_in = paths[:, :-1, :]
    interactions_in = interactions[:, :-1, :]
    paths_out = paths[:, 1:, :]
    interactions_out = interactions[:, 1:, :]

    outputs = model(prompts, paths_in, interactions_in)
    (delay_pred, power_pred, phase_sin_pred, phase_cos_pred, _phase_pred,
     az_sin_pred, az_cos_pred, _az_pred, el_sin_pred, el_cos_pred, _el_pred,
     path_length_pred, interaction_logits) = outputs

    losses = masked_loss(
        delay_pred, power_pred, phase_sin_pred, phase_cos_pred,
        az_sin_pred, az_cos_pred, el_sin_pred, el_cos_pred,
        path_length_pred, interaction_logits, paths_out, path_lengths,
        interactions_out, pad_value=pad_value,
        interaction_weight=interaction_weight,
    )
    return outputs, losses, paths_out, path_lengths


def _batch_metrics(outputs, paths_out, path_lengths, pad_value):
    """Return ``(path_length_rmse, ch_nmse)`` for one batch of model outputs."""
    (delay_pred, power_pred, _, _, phase_pred,
     _, _, az_pred, _, _, el_pred,
     path_length_pred, _) = outputs

    path_length_rmse = compute_stop_metrics(path_length_pred.detach().squeeze(-1),
                                            path_lengths)
    ch_nmse = _batch_channel_nmse(delay_pred, power_pred, phase_pred, az_pred, el_pred,
                                  paths_out, pad_value)
    return path_length_rmse, ch_nmse


def _record_batch(stats, losses, path_length_rmse, ch_nmse):
    """Append the scalar values of one batch to the per-epoch ``stats`` lists."""
    for name, value in zip(_TRAIN_LOSS_NAMES, losses):
        stats[name].append(value.item())
    stats["path_length_rmse"].append(path_length_rmse)
    stats["ch_nmse"].append(ch_nmse)


def train_with_interactions(model, train_loader, val_loader, config, train_data, nmse_train=False):
    """Train ``model`` with the combined path + interaction loss and validate every epoch.

    Each epoch runs a teacher-forced training pass followed by a validation
    pass, tracks the individual loss terms, the path-length RMSE and the channel
    NMSE, saves a checkpoint whenever the mean validation loss improves, and
    prints (and optionally logs to Weights & Biases) the epoch averages.

    Relies on these module-level names: ``optimizer``, ``scheduler``,
    ``mycomputer``, ``checkpoint_path``, ``masked_loss``,
    ``compute_stop_metrics`` and ``compute_channel_nmse``.

    Args:
        model: Network called as ``model(prompts, paths_in, interactions_in)``;
            its parameters determine the device the batches are moved to.
        train_loader: Loader yielding ``(prompts, paths, path_lengths, interactions)``.
        val_loader: Loader with the same batch layout used for validation.
        config: Dict providing ``"epochs"`` and ``"PAD_VALUE"`` and, optionally,
            ``"interaction_weight"`` (default 0.1) and ``"USE_WANDB"``.
        train_data: Object whose ``pad_value`` is forwarded to ``masked_loss``.
        nmse_train: Currently unused; kept for backward compatibility.

    Returns:
        None. The best weights are written to ``checkpoint_path``.
    """
    # Follow the model's device instead of hard-coding CUDA so that training
    # also works on CPU-only machines.
    device = next(model.parameters()).device
    interaction_weight = config.get("interaction_weight", 0.1)
    use_wandb = config.get("USE_WANDB", False)
    if use_wandb:
        import wandb

    # Guard against a zero-length loader in the fractional-epoch computation.
    num_train_batches = max(len(train_loader), 1)
    stat_names = _TRAIN_LOSS_NAMES + ("path_length_rmse", "ch_nmse")
    best_val_loss = float('inf')

    for epoch in range(config["epochs"]):
        # -------------------- TRAINING --------------------
        model.train()
        train_stats = {name: [] for name in stat_names}
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False)
        for batch_idx, batch in enumerate(pbar):
            outputs, losses, paths_out, path_lengths = _forward_with_loss(
                model, batch, device, train_data.pad_value, interaction_weight)
            total_loss = losses[0]

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # Step the warm-restart scheduler with a fractional epoch. A bare
            # `scheduler.step()` per batch would count every batch as an epoch,
            # so the cosine cycle (T_0) would restart every T_0 batches instead
            # of every T_0 epochs.
            scheduler.step(epoch + batch_idx / num_train_batches)

            path_length_rmse, ch_nmse = _batch_metrics(
                outputs, paths_out, path_lengths, config["PAD_VALUE"])
            _record_batch(train_stats, losses, path_length_rmse, ch_nmse)

            current_lr = optimizer.param_groups[0]["lr"]
            pbar.set_postfix({
                "loss": f"{train_stats['loss'][-1]:.4f}",
                "delay": f"{train_stats['delay'][-1]:.4f}",
                "power": f"{train_stats['power'][-1]:.4f}",
                "phase": f"{train_stats['phase'][-1]:.4f}",
                "az": f"{train_stats['az'][-1]:.4f}",
                "el": f"{train_stats['el'][-1]:.4f}",
                "inter": f"{train_stats['interaction'][-1]:.4f}",
                "path_rmse": f"{path_length_rmse:.4f}",
                "ch_nmse": f"{ch_nmse:.4f}",
                "lr": f"{current_lr:.2e}"
            })

        train_avg = {name: np.mean(values) for name, values in train_stats.items()}

        # -------------------- VALIDATION --------------------
        model.eval()
        val_stats = {name: [] for name in stat_names}
        with torch.no_grad():
            pbar = tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False)
            for batch in pbar:
                outputs, losses, paths_out, path_lengths = _forward_with_loss(
                    model, batch, device, train_data.pad_value, interaction_weight)
                path_length_rmse, ch_nmse = _batch_metrics(
                    outputs, paths_out, path_lengths, config["PAD_VALUE"])
                _record_batch(val_stats, losses, path_length_rmse, ch_nmse)

                pbar.set_postfix({
                    "val_loss": f"{val_stats['loss'][-1]:.4f}",
                    "inter": f"{val_stats['interaction'][-1]:.4f}",
                })

        val_avg = {name: np.mean(values) for name, values in val_stats.items()}
        avg_train_loss = train_avg["loss"]
        avg_val_loss = val_avg["loss"]
        current_lr = optimizer.param_groups[0]["lr"]

        # -------------------- CHECKPOINT SAVING --------------------
        # Keep only the weights with the lowest mean validation loss so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': torch.tensor(best_val_loss),
            }, checkpoint_path)
            print(f"  ✓ Best checkpoint saved (val_loss: {best_val_loss:.4f})")

        if use_wandb:
            wandb.log({
                "train_loss": avg_train_loss,
                "train_loss_delay": train_avg["delay"],
                "train_loss_power": train_avg["power"],
                "train_loss_phase": train_avg["phase"],
                "train_loss_az": train_avg["az"],
                "train_loss_el": train_avg["el"],
                "train_loss_path_length": train_avg["path_length"],
                "train_loss_interaction": train_avg["interaction"],
                "train_path_length_rmse": train_avg["path_length_rmse"],
                "avg_train_ch_nmse": train_avg["ch_nmse"],
                "val_loss": avg_val_loss,
                "val_loss_delay": val_avg["delay"],
                "val_loss_power": val_avg["power"],
                "avg_val_ch_nmse": val_avg["ch_nmse"],
                "val_loss_phase": val_avg["phase"],
                "val_loss_az": val_avg["az"],
                "val_loss_el": val_avg["el"],
                "val_loss_path_length": val_avg["path_length"],
                "val_loss_interaction": val_avg["interaction"],
                "val_path_length_rmse": val_avg["path_length_rmse"],
                "epoch": epoch,
                "lr": current_lr,
            })

        print(f"\nEpoch {epoch:02d}")
        print(f"  Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"    Delay: {train_avg['delay']:.4f} (val: {val_avg['delay']:.4f})")
        print(f"    Power: {train_avg['power']:.4f} (val: {val_avg['power']:.4f})")
        print(f"    Phase: {train_avg['phase']:.4f} (val: {val_avg['phase']:.4f})")
        print(f"    NMSE: {train_avg['ch_nmse']:.4f} (val: {val_avg['ch_nmse']:.4f})")
        print(f"    Az: {train_avg['az']:.4f} (val: {val_avg['az']:.4f})")
        print(f"    El: {train_avg['el']:.4f} (val: {val_avg['el']:.4f})")

        print(f"    Interaction: {train_avg['interaction']:.4f} (val: {val_avg['interaction']:.4f})")
        print(f"    PathLength: {train_avg['path_length']:.4f} (val: {val_avg['path_length']:.4f})")

        print(f"  LR: {current_lr:.3e}")


# %%

# Training is switched off for this run; uncomment the call below to train.
# train_with_interactions(model, train_loader, val_loader, config, train_data)

# Load the best saved checkpoint into `model` before evaluating.
checkpoint_info = load_best_checkpoint(model, checkpoint_path=checkpoint_path)
best_epoch, best_loss = checkpoint_info



# %%
# Evaluate on the validation loader and show the returned tuple of averages.
eval_results = evaluate_model(model, val_loader)
print(eval_results)


# %%
# show_example(model, val_loader, sample_index=24)



 

Total trainable parameters: 20006413
✓ Loaded best checkpoint from epoch 98 (val_loss: 0.6563)


Evaluating (batches):   0%|          | 0/152 [00:00<?, ?it/s]
Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5660.33it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5121.25it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4373.62it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5419.00it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 3840.94it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 3276.80it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 2966.27it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 2133.42it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4072.14it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4466.78it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4341.93it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4865.78it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5017.11it/s]

Generating channels: 100%|██

Avg ch_score_nsme    : 0.2665



Generating channels: 100%|██████████| 1/1 [00:00<00:00, 1492.10it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5289.16it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 3139.45it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6384.02it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4614.20it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4750.06it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4681.14it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 2061.08it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5825.42it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7073.03it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 3170.30it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7397.36it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7319.90it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7767.23it/s]

Generating channels

Avg ch_score_nsme    : 0.2680



Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6374.32it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7319.90it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6775.94it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7489.83it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6626.07it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5753.50it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7013.89it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7516.67it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5275.85it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4696.87it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5907.47it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6355.01it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5761.41it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5447.15it/s]

Generating channels

Avg ch_score_nsme    : 0.2433



Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6186.29it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7269.16it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4928.68it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7061.12it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6605.20it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7194.35it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5793.24it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7463.17it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6909.89it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7345.54it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7073.03it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6374.32it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5882.61it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5614.86it/s]

Generating channels

Avg ch_score_nsme    : 0.2404



Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5090.17it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7121.06it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5737.76it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6250.83it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5645.09it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5761.41it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6413.31it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 5405.03it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 4364.52it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6114.15it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6967.28it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 7423.55it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 6355.01it/s]

Generating channels: 100%|██████████| 1/1 [00:00<00:00, 2434.30it/s]

Generating channels

KeyboardInterrupt: 

In [None]:
# normalized = (db_value - db_min) / (db_max - db_min)
# score = 1.0 - normalized
# score = max(0.0, min(1.0, score))

In [None]:
# Run the evaluation on the validation loader and show the returned tuple of averages.
eval_results = evaluate_model(model, val_loader)
print(eval_results)