Skip to content

Commit

Permalink
added ELBO_IQ as originally implemented, corected scaling in loss (no…
Browse files Browse the repository at this point in the history
… sps), added namedtuple at output
  • Loading branch information
VincentLauinger authored and noc0lour committed Apr 5, 2024
1 parent cddbddd commit a3ff3d3
Showing 1 changed file with 262 additions and 6 deletions.
268 changes: 262 additions & 6 deletions src/mokka/equalizers/adaptive/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ..torch import Butterfly2x2
import torch

from collections import namedtuple


class CMA(torch.nn.Module):
"""
Expand Down Expand Up @@ -225,13 +227,88 @@ def ELBO_DP(

# We compute B without constants
# B_tilde = -N * torch.log(C)
loss = torch.sum(A) + torch.sum((N - L + 1) / sps * torch.log(C + 1e-8))
var = C / (N - L + 1) * sps
#bias1 = 1e-8
loss = torch.sum(A) + (N - L + 1) * torch.sum(torch.log(C + bias)) #/ sps
var = C / (N - L + 1) #* sps #N
return loss, var


##############################################################################################

def ELBO_DP_IQ(y, q, sps, amp_levels, h_est, p_amps=None):
"""
Calculate dual-pol. ELBO loss for arbitrary complex constellations.
Instead of splitting into in-phase and quadrature we can just
the whole thing.
This implements the dual-polarization case.
"""
# Input is a sequence y of length N
N = y.shape[1]
h = h_est
pol = 2 # dual-polarization
# Now we have two polarizations in the first dimension
# We assume the same transmit constellation for both, calculating
# q needs to be shaped 2 x N x M -> for each observation on each polarization we have M q-values
# we have M constellation symbols
L = h.shape[-1]
L_offset = (L - 1) // 2
if p_amps is None:
p_amps = (
torch.ones_like(amp_levels) / amp_levels.shape[0]
)



# # Precompute E_Q{c} = sum( q * c) where c is x and |x|**2
E_Q_x = torch.zeros(2,2,N, device=q.device, dtype=torch.float32)
Var = torch.zeros(2,N, device=q.device, dtype=torch.float32)
num_lev = amp_levels.shape[0]
E_Q_x[:,0,::sps] = torch.sum(q[:,:,:num_lev] * amp_levels.unsqueeze(0).unsqueeze(0), dim=-1)#.permute(0,2,1)
E_Q_x[:,1,::sps] = torch.sum(q[:,:,num_lev:] * amp_levels.unsqueeze(0).unsqueeze(0), dim=-1)#.permute(0,2,1)
Var[:,::sps] = torch.add( # Precompute E_Q{|x|^2}
torch.sum(q[:,:,:num_lev] * (amp_levels**2).unsqueeze(0).unsqueeze(0), dim=-1),
torch.sum(q[:,:,num_lev:] * (amp_levels**2).unsqueeze(0).unsqueeze(0), dim=-1)
)
Var[:,::sps] -= torch.sum(E_Q_x[:,:,::sps]**2, dim=1)
p_amps = p_amps.repeat(2)

h_absq = torch.sum(h**2, dim=2)

D_real = torch.zeros(2,N-2*L_offset, device=q.device, dtype=torch.float32)
D_imag = torch.zeros(2,N-2*L_offset, device=q.device, dtype=torch.float32)
E = torch.zeros(2, device=q.device, dtype=torch.float32)
idx = torch.arange(2*L_offset,N)
nm = idx.shape[0]

for j in range(2*L_offset+1): # h[chi,nu,c,k]
D_real += h[:,0,0:1,j].expand(-1,nm) * E_Q_x[0,0:1,idx-j].expand(pol,-1) - h[:,0,1:2,j].expand(-1,nm) * E_Q_x[0,1:2,idx-j].expand(pol,-1) \
+ h[:,1,0:1,j].expand(-1,nm) * E_Q_x[1,0:1,idx-j].expand(pol,-1) - h[:,1,1:2,j].expand(-1,nm) * E_Q_x[1,1:2,idx-j].expand(pol,-1)
D_imag += h[:,0,1:2,j].expand(-1,nm) * E_Q_x[0,0:1,idx-j].expand(pol,-1) + h[:,0,0:1,j].expand(-1,nm) * E_Q_x[0,1:2,idx-j].expand(pol,-1) \
+ h[:,1,1:2,j].expand(-1,nm) * E_Q_x[1,0:1,idx-j].expand(pol,-1) + h[:,1,0:1,j].expand(-1,nm) * E_Q_x[1,1:2,idx-j].expand(pol,-1)
Var_sum = torch.sum(Var[:,idx-j], dim=-1)
E += h_absq[:,0,j] * Var_sum[0] + h_absq[:,1,j] * Var_sum[1]


# Term A - sum all the things, but spare the first dimension, since the two polarizations
# are sorta independent
bias = 1e-14
A = torch.sum(
q[:,L_offset:-L_offset,:] * torch.log((q[:,L_offset:-L_offset,:] / p_amps.unsqueeze(0).unsqueeze(0)) + bias),
dim=(1, 2),
)# Limit the length of y to the "computable space" because y depends on more past values than given
# We try to generate the received symbol sequence with the estimated symbol sequence
C = torch.sum(
y[:, L_offset:-L_offset].real ** 2 + y[:, L_offset:-L_offset].imag ** 2, axis=1
)
C += -2*torch.sum( y[:, L_offset:-L_offset].real*D_real + y[:, L_offset:-L_offset].imag*D_imag, dim=1) + torch.sum( D_real**2 + D_imag**2, dim=1) + E

# We compute B without constants
# B_tilde = -N * torch.log(C)
loss = torch.sum(A) + (N - L + 1) * torch.sum(torch.log(C + 1e-8)) #/ sps
var = C / (N - L + 1) #* sps
return loss, var

##############################################################################################

class VAE_LE_DP(torch.nn.Module):
"""
Expand Down Expand Up @@ -384,7 +461,186 @@ def forward(self, y):
p_constellation=self.demapper.p_symbols,
IQ_separate=self.IQ_separate,
)

# print("noise_sigma: ", self.demapper.noise_sigma)
loss.backward()
self.optimizer.step()
#self.optimizer_var.step()
self.optimizer.zero_grad()
#self.optimizer_var.zero_grad()

if self.var_from_estimate == True:
self.demapper.noise_sigma = torch.clamp(
torch.sqrt(torch.mean(var.detach().clone())/2), min=torch.tensor(0.05, requires_grad=False, device=q_hat.device) , max=2*self.demapper.noise_sigma.detach().clone() #torch.sqrt(var).detach()), min=0.1
)

output_symbols = y_symb[
:, : self.block_size
] # - self.butterfly_forward.num_taps // 2]
# logger.debug("VAE LE num output symbols: %s", output_symbols.shape[1])
out.append(
output_symbols
) # out.append(y_symb[:,:num_samps-self.butterfly_forward.num_taps +1])

output_q = q_hat[
:, : self.block_size, :
]
out_q.append(
output_q
)

#print("loss: ", loss, "\t\t\t var: ", var)
# out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps // 2 :])

if self.requires_q == True:
eq_out = namedtuple("eq_out", ["y", "q", "var", "loss"])
return eq_out(torch.cat(out, axis=1), torch.cat(out_q, axis=1), var, loss)
return torch.cat(out, axis=1)

def update_lr(self, new_lr):
self.lr = new_lr
for group in self.optimizer.param_groups:
group["lr"] = self.lr

def update_var(self, new_var):
self.demapper.noise_sigma = new_var

##############################################################################################

class VAE_LE_DP_IQ(torch.nn.Module):
"""
Class that can be dropped in to perform equalization as in ...
"""

def __init__(
self,
num_taps_forward,
num_taps_backward,
demapper,
sps,
block_size=200,
lr=0.5e-2,
requires_q=False,
var_from_estimate=False,
device='cpu'
):
super(VAE_LE_DP_IQ, self).__init__()

self.register_buffer("block_size", torch.as_tensor(block_size))
self.register_buffer("sps", torch.as_tensor(sps))
self.register_buffer("start_lr", torch.as_tensor(lr))
self.register_buffer("lr", torch.as_tensor(lr))
self.register_buffer("num_taps_forward", torch.as_tensor(num_taps_forward))
self.register_buffer("num_taps_backward", torch.as_tensor(num_taps_backward))
self.register_buffer("requires_q", torch.as_tensor(requires_q))
self.register_buffer("var_from_estimate", torch.as_tensor(var_from_estimate))
self.butterfly_forward = Butterfly2x2(
num_taps=num_taps_forward, trainable=True, timedomain=True, device=device
)
pol = 2 # dual-polarization
self.h_est = torch.zeros([pol,pol,2,num_taps_backward]) # initialize estimated impulse response
self.h_est[0,0,0,num_taps_backward//2+1], self.h_est[1,1,0,num_taps_backward//2+1] = 1, 1
self.demapper = demapper
self.optimizer = torch.optim.Adam(
self.butterfly_forward.parameters(),
lr=self.start_lr, # 0.5e-2,
)
self.optimizer.add_param_group({"params": self.h_est})

self.optimizer_var = torch.optim.Adam(
[self.demapper.noise_sigma],
lr=0.5, # 0.5e-2,
)

def reset(self):
self.lr = self.start_lr.clone()
self.butterfly_forward = Butterfly2x2(
num_taps=self.num_taps_forward.item(), trainable=True, timedomain=True, device=self.butterfly_forward.taps.device
)
pol = 2 # dual-polarization
self.h_est = torch.zeros([pol,pol,2,self.num_taps_backward]) # initialize estimated impulse response
self.h_est[0,0,0,self.num_taps_backward//2+1], self.h_est[1,1,0,self.num_taps_backward//2+1] = 1, 1
self.optimizer = torch.optim.Adam(
self.butterfly_forward.parameters(),
lr=self.lr,
)
self.optimizer.add_param_group({"params": self.h_est})

def forward(self, y):
# We need to produce enough q values on each forward pass such that we can
# calculate the ELBO loss in the backward pass & update the taps

num_samps = y.shape[1]
# samples_per_step = self.butterfly_forward.num_taps + self.block_size

out = []
out_q = []
# We start our loop already at num_taps (because we cannot equalize the start)
# We will end the loop at num_samps - num_taps - sps*block_size (safety, so we don't overrun)
# We will process sps * block_size - 2 * num_taps because we will cut out the first and last block

index_padding = (self.butterfly_forward.num_taps - 1) // 2
for i, k in enumerate(
range(
index_padding,
num_samps
- index_padding
- self.sps
* self.block_size, # Back-off one block-size + filter_overlap from end to avoid overrunning
self.sps * self.block_size,
)
):
# if i % (20000//self.block_size) == 0 and i != 0:
# print("Updating learning rate")
# self.update_lr(self.lr * 0.5)
# logger.debug("VAE LE block: %s", i)
in_index = torch.arange(
k - index_padding,
k + self.sps * self.block_size + index_padding,
)
# Equalization will give sps * block_size samples (because we add (num_taps - 1) in the beginning)
y_hat = self.butterfly_forward(y[:, in_index], "valid")

# We downsample so we will have floor(((sps * block_size - num_taps + 1) / sps) = floor(block_size - (num_taps - 1)/sps)
y_symb = y_hat[
:, 0 :: self.sps
] # ---> y[0,(self.butterfly_forward.num_taps + 1)//2 +1 ::self.sps]

q_hat = torch.cat(
(
torch.cat(
(
self.demapper(y_symb[0, :].real).unsqueeze(0),
self.demapper(y_symb[0, :].imag).unsqueeze(0),
), axis=-1
),
torch.cat(
(
self.demapper(y_symb[1, :].real).unsqueeze(0),
self.demapper(y_symb[1, :].imag).unsqueeze(0),
), axis=-1
),
), axis=0
)

# We calculate the loss with less symbols, since the forward operation with "valid"
# is missing some symbols
# We assume the symbol of interest is at the center tap of the filter
y_index = in_index[
(self.butterfly_forward.num_taps - 1)
// 2 : -((self.butterfly_forward.num_taps - 1) // 2)
]
loss, var = ELBO_DP_IQ(
#loss, var = ELBO_DP(
y[:, y_index],
q_hat,
self.sps,
self.demapper.constellation,
self.h_est,
p_amps=self.demapper.p_symbols
#p_constellation=self.demapper.p_symbols
)

# print("noise_sigma: ", self.demapper.noise_sigma)
loss.backward()
self.optimizer.step()
Expand Down Expand Up @@ -412,9 +668,9 @@ def forward(self, y):
out_q.append(output_q)
# out.append(y_symb[:, self.block_size - self.butterfly_forward.num_taps // 2 :])
if self.requires_q == True:
return torch.cat(out, axis=1), torch.cat(out_q, axis=1)
else:
return torch.cat(out, axis=1)
eq_out = namedtuple("eq_out", ["y", "q", "var", "loss"])
return eq_out(torch.cat(out, axis=1), torch.cat(out_q, axis=1), var, loss)
return torch.cat(out, axis=1)

def update_lr(self, new_lr):
self.lr = new_lr
Expand Down

0 comments on commit a3ff3d3

Please sign in to comment.