In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
from numpy import random

  Referenced from: <D9493EF5-8DAB-3A5D-85D5-684F04544B84> /opt/homebrew/lib/python3.10/site-packages/torchvision/image.so
  Expected in:     <BB02660F-1D5B-3388-B48B-486877D726F6> /opt/homebrew/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [56]:
N = 3
k1 = 2**N
L_enc = 3
L_dec = 3
I_dec = 3
N_enc = 96
N_dec = 96
n1 = 16

encoder_learning_rate = 2e-5
decoder_learning_rate = 2e-5
noise = 0

In [85]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        layers = []

        layers.append(nn.Linear(k1, N_enc))
        layers.append(nn.SELU())
        for i in range(L_enc):
            layers.append(nn.Linear(N_enc, N_enc))
            layers.append(nn.SELU())
        layers.append(nn.Linear(N_enc, n1))

        #layers.append(nn.Tanh())

        self.fcnn = nn.Sequential(
            *layers
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.fcnn(x)
        return logits

In [86]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # it's a set of layers of one decoder out of I_dec decoders
        layers = []

        # first decoder 
        layers.append(nn.Linear(n1, N_dec))
        layers.append(nn.SELU())
        for i in range(L_enc):
            layers.append(nn.Linear(N_dec, N_dec))
            layers.append(nn.SELU())
        
        # second & third decoders
        for j in range(2):
            for i in range(L_enc + 1):
                layers.append(nn.Linear(N_dec, N_dec))
                layers.append(nn.SELU())

        # last decoder
        for i in range(L_enc):
            layers.append(nn.Linear(N_dec, N_dec))
            layers.append(nn.SELU())
        layers.append(nn.Linear(N_dec, k1))
        
        #layers.append(nn.Tanh())

        self.fcnn = nn.Sequential(
            *layers
        )
        
    def forward(self, x):
        x = self.flatten(x)
        logits = self.fcnn(x)
        return logits

In [98]:
class PowerNormaliser:
    def normalize(self, tensor):
        #print(tensor.shape)
        norm = torch.norm(tensor, p=2)
        normalized_vector = tensor / norm
        return normalized_vector

In [101]:
a = torch.tensor([1, 2, 3], dtype=torch.float32)
p = PowerNormaliser()
print(p.normalize(a))

tensor([0.2673, 0.5345, 0.8018])


In [89]:
class Channel():
    def __init__(self, sigma):
        self.sigma = sigma
    def add_noise(self, data):
        noise = random.normal(0, self.sigma, data.shape)
        return data + torch.from_numpy(noise).to(dtype=torch.float32)

In [90]:
class Binarise():
    def __init__(self):
        self.threshold = 0
    def binarise(self, data):
        return (data > self.threshold).float()

In [91]:
encoder = Encoder()
awgn_channel = Channel(noise)
decoder = Decoder()
normaliser = PowerNormaliser()
binarise = Binarise()

In [92]:
enc_optimizer = torch.optim.Adam(encoder.parameters(), lr = encoder_learning_rate)
#enc_scheduler = torch.optim.lr_scheduler.ExponentialLR(enc_optimizer, gamma=0.95)
dec_optimizer = torch.optim.Adam(decoder.parameters(), lr = decoder_learning_rate)
#dec_scheduler = torch.optim.lr_scheduler.ExponentialLR(dec_optimizer, gamma=0.95)
loss_fn = nn.BCEWithLogitsLoss()

In [103]:
num_samples = 2**k1
training_epochs = 10
num_epochs = 1000
batch_size = 1

# x_train = torch.rand((num_samples, k1))
# y_train = x_train
# dataset = torch.utils.data.TensorDataset(x_train, y_train)
# train_loader = DataLoader(dataset, batch_size=batch_size)

In [None]:
#torch.randint(low = -1, high = )

In [106]:
# implement the training loop for the encoder part
# at the same time the decoder part is frozen
# the code should use train_loader to get the data
# and the loss function should be loss_fn
fl = True
for i in range(training_epochs):
#for i in range(1):
    # generate random data for training
    x_train = torch.randint(low=0, high=2, size=(num_samples, k1)).to(dtype=torch.float32)
    y_train = x_train
    dataset = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = DataLoader(dataset, batch_size=batch_size)

    # implement the training loop for the decoder part
    # at the same time the encoder part is frozen
    # the code should use train_loader to get the data
    # and the loss function should be loss_fn
    
    for epoch in range(200):
        for x_batch, y_batch in train_loader:
            dec_optimizer.zero_grad()
            
            encoded = encoder(x_batch)
            enc_norm = normaliser.normalize(encoded)
            enc_norm_noise = awgn_channel.add_noise(enc_norm)
            
            decoded = decoder(enc_norm_noise)
            # decoded_bin = binarise.binarise(decoded) - Can't do binarise here, because it's a part of the loss function
            
            loss = loss_fn(decoded, y_batch)
            loss.backward()
            dec_optimizer.step()
            
            #print(f"{epoch=}, loss={loss.item():.4f}")

    print(f"Decoder training, loss: {loss.item():.4f}, epoch: {i}")
    
    
    x_train = torch.randint(low=0, high=2, size=(num_samples, k1)).to(dtype=torch.float32)
    y_train = x_train
    dataset = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = DataLoader(dataset, batch_size=batch_size)
    
    for epoch in range(100):
        for x_batch, y_batch in train_loader:
            enc_optimizer.zero_grad()
            
            encoded = encoder(x_batch)
            enc_norm = normaliser.normalize(encoded)
            enc_norm_noisy = awgn_channel.add_noise(enc_norm)
            
            decoded = decoder(enc_norm_noisy)
            
            loss = loss_fn(decoded, y_batch)
            loss.backward()
            enc_optimizer.step()
        
    print(f"Encoder training, loss: {loss.item():.4f}, epoch: {i}")
    print()


Decoder training, loss: 0.5788, epoch: 0
Encoder training, loss: 0.3138, epoch: 0

Decoder training, loss: 0.4275, epoch: 1
Encoder training, loss: 0.4383, epoch: 1

