In [145]:
import numpy as np
import numpy.linalg as LA
import jax
import jax.numpy as jnp
import jax.numpy.linalg as JLA

import pandas as pd
import matplotlib.pyplot as plt
import japanize_matplotlib
from jax.example_libraries import optimizers
from tqdm.notebook import trange
from functools import partial
from flax import linen as nn
from flax.experimental import nnx
from typing import Sequence, Callable, Tuple

In [146]:
M = 32                  # number of distinct messages: one-hot length and decoder softmax size
K = 100                 # mini-batch size (messages per batch)
num_hidden_units = 10   # hidden-layer width of encoder and decoder
sigma = 0.1             # std. dev. of the additive Gaussian channel noise
SEED = 0                # seed for NumPy's global RNG (message batches and channel noise)
np.random.seed(SEED)    # make Restart & Run All reproducible

In [147]:
def mini_batch(K, num_classes=None, rng=None):
    """Draw a batch of K uniformly random one-hot messages.

    Args:
        K: number of messages (rows) in the batch.
        num_classes: length of each one-hot vector. Defaults to the notebook
            constant ``M``, looked up at call time as in the original cell.
        rng: optional ``np.random.Generator``. When omitted, NumPy's global
            RNG is used, so a global ``np.random.seed`` still controls it.

    Returns:
        jnp.ndarray of shape (K, num_classes) with exactly one 1 per row.
    """
    if num_classes is None:
        num_classes = M
    if rng is None:
        labels = np.random.randint(0, num_classes, K)
    else:
        labels = rng.integers(0, num_classes, K)
    onehot = np.zeros((K, num_classes))
    # Row i gets a single 1 in column labels[i].
    onehot[np.arange(K), labels] = 1
    return jnp.array(onehot)

In [148]:
class Encoder(nn.Module):
    """Encode each input row into a 2-dimensional output.

    Pipeline: Dense(hidden_dim) -> act_fn -> Dense(2) -> normalizer.

    Attributes:
        hidden_dim: width of the hidden Dense layer.
        normalizer: callable applied to the 2-D output (e.g. a power constraint).
        act_fn: activation applied after the hidden layer (ReLU by default).
    """
    hidden_dim : int
    normalizer : Callable
    act_fn : Callable = nnx.relu

    @nn.compact
    def __call__(self, x):
        # Dense layers are created in the same order as before, so the
        # auto-generated parameter names (Dense_0, Dense_1) are unchanged.
        hidden = self.act_fn(nn.Dense(self.hidden_dim)(x))
        coords = nn.Dense(2)(hidden)
        return self.normalizer(coords)

In [149]:
class Decoder(nn.Module):
    """Map received channel outputs to a probability vector over messages.

    Pipeline: Dense(hidden_dim) -> act_fn -> Dense(output_dim) -> softmax.

    Attributes:
        hidden_dim: width of the hidden Dense layer.
        output_dim: number of output classes (size of the softmax).
        act_fn: activation applied after the hidden layer (ReLU by default).
    """
    hidden_dim : int
    output_dim : int
    act_fn : Callable = nnx.relu

    @nn.compact
    def __call__(self, x):
        # Same layer creation order as before -> identical parameter names.
        hidden = self.act_fn(nn.Dense(self.hidden_dim)(x))
        logits = nn.Dense(self.output_dim)(hidden)
        return nnx.softmax(logits)

In [150]:
class ChannelModel(nn.Module):
    """Encoder -> additive Gaussian noise -> decoder, trained end to end.

    Attributes:
        hidden_dim: hidden-layer width used by both the encoder and the decoder.
        output_dim: number of decoder output classes (softmax size).
        normalizer: constraint applied to the encoder output (passed to Encoder).
        sigma: standard deviation multiplier of the additive channel noise.
        act_fn: hidden activation forwarded to both the encoder and the decoder.
    """
    hidden_dim : int
    output_dim : int
    normalizer : Callable
    sigma : float
    act_fn : Callable = nnx.relu

    def setup(self):
        # Use this module's own fields (not notebook globals) so hidden_dim and
        # act_fn passed to the constructor actually take effect.
        self.encoder = Encoder(hidden_dim=self.hidden_dim,
                               normalizer=self.normalizer,
                               act_fn=self.act_fn)
        self.decoder = Decoder(hidden_dim=self.hidden_dim,
                               output_dim=self.output_dim,
                               act_fn=self.act_fn)

    def __call__(self, x):
        z = self.encoder(x)
        # Noise is shaped like the encoder output, so any batch size works
        # (previously it was hard-coded to the global (K, 2)).
        z = z + self.sigma * self._channel_noise(z.shape)
        return self.decoder(z)

    def _channel_noise(self, shape):
        """Return standard-normal noise of the given shape.

        If the caller supplies a 'noise' RNG stream (e.g.
        ``apply(vars, x, rngs={"noise": key})``), JAX's RNG is used; this is
        reproducible and safe under ``jax.jit``. Otherwise it falls back to
        NumPy's global RNG, as the original cell did. Note that under jit
        the fallback noise is fixed at trace time.
        """
        if self.has_rng("noise"):
            return jax.random.normal(self.make_rng("noise"), shape)
        return jnp.asarray(np.random.randn(*shape))

In [151]:
def peak_const(x):
    """Scale a batch so that its largest per-row energy equals 1.

    Each row's energy is the sum of squares along axis 1. The whole batch is
    divided by the square root of the maximum row energy, i.e. a peak (not
    average) power constraint.

    ``jnp.max`` is used instead of Python's built-in ``max``. The built-in
    iterates over the array in Python and compares tracers, which fails
    under ``jax.jit``.

    Args:
        x: 2-D array; rows are normalized jointly.

    Returns:
        Array with the same shape as ``x``. An all-zero input yields NaNs (0/0).
    """
    peak_energy = jnp.max(jnp.sum(x ** 2, axis=1))
    return x / jnp.sqrt(peak_energy)

In [152]:
key = jax.random.PRNGKey(0)
x = mini_batch(K)
channel_model = ChannelModel(
    hidden_dim=num_hidden_units,
    output_dim=M,
    normalizer=peak_const,
    sigma=sigma,
)
# Initialise from a single example: Dense kernel shapes depend only on the
# feature dimension, not on the batch size.
params = channel_model.init(key, x[:1])["params"]
# Forward pass on the full batch; expected output shape is (K, M).
channel_model.apply({"params": params}, x).shape

(100, 32)