In [1]:
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import sys
from tensorflow.python.framework import ops
from helper_functions import load_code, syndrome
import os
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
device = 'cpu'

In [3]:
seed = 1221
np.random.seed(seed)
snr_lo = 1
snr_hi = 6
snr_step = 1
min_frame_errors = 100
max_frames = 100000000
# num_iterations = 5
H_filename = './codes/hamming.alist'
G_filename = './codes/hamming.gmat'
output_filename = 'out_file'
L = 0.5
steps = 100

In [4]:
code = load_code(H_filename, G_filename)
H = code.H
G = code.G
var_degrees = code.var_degrees
chk_degrees = code.chk_degrees
num_edges = code.num_edges
u = code.u
d = code.d
n = code.n
m = code.m
k = code.k

In [62]:
class Decoder(torch.nn.Module):
    def __init__(self, num_iterations = 5):
        super(Decoder, self).__init__()
        self.W_cv = torch.nn.Parameter(torch.abs(torch.randn((num_iterations, num_edges))))
        self.W_cv.requires_grad = True
        self.B_cv = torch.nn.Parameter(torch.abs(torch.randn((num_iterations, num_edges))))
        self.B_cv.requires_grad = True
        self.W_vc = torch.nn.Parameter(torch.abs(torch.randn((num_iterations, num_edges))))
        self.W_vc.requires_grad = True

        self.num_iterations = num_iterations
        
    def forward(self, soft_input):
        cv = torch.zeros((num_edges, soft_input.shape[1]))
        for iteration in range(0, self.num_iterations):
            vc = self.compute_vc(cv, soft_input, iteration)
            cv = self.compute_cv(vc, iteration)
            soft_input = self.marginalize(soft_input, cv)
        output = torch.sigmoid(soft_input)
        # output = soft_input
        return output

        
    def compute_vc(self, cv, soft_input, iteration):
        edges = []
        for i in range(0, n):
            for j in range(0, var_degrees[i]):
                edges.append(i)

        reordered_soft_input = torch.index_select(soft_input, 0, torch.tensor(edges).to(device))

        vc = []
        edge_order = []

        for i in range(0, n): # for each variable node v
            for j in range(0, var_degrees[i]):
                edge_order.append(d[i][j])
                extrinsic_edges = []
                for jj in range(0, var_degrees[i]):
                    if jj != j: # extrinsic information only
                        extrinsic_edges.append(d[i][jj])
                # if the list of edges is not empty, add them up
                if extrinsic_edges:
                    temp = torch.index_select(cv.to(device), 0, torch.tensor(extrinsic_edges).to(device))
                    temp = torch.sum(temp, 0)
                else:
                    temp = torch.zeros(cv.shape[1])
                vc.append(temp)
        # vc = torch.tensor(vc)
        vc = torch.stack(vc)
        new_order = np.zeros(num_edges).astype(int)
        new_order[edge_order] = np.arange(0, num_edges)
        vc = torch.index_select(vc, 0, torch.tensor(new_order).to(device))
        vc += reordered_soft_input * torch.tile(torch.reshape(self.W_vc[iteration], (-1,1)), (1, cv.shape[1]))       # add soft inputs of the previous iterations!
        return vc 

    def compute_cv(self, vc, iteration):
        cv_list = []
        prod_list = []
        min_list = []
        edge_order = []
        for i in range(0, m): # for each check node c
            for j in range(0, chk_degrees[i]):
                edge_order.append(u[i][j])
                extrinsic_edges = []
                for jj in range(0, chk_degrees[i]):
                    if jj != j:
                        extrinsic_edges.append(u[i][jj])
                temp = torch.index_select(vc.to(device),0,torch.tensor(extrinsic_edges).to(device))
                temp1 = torch.prod(torch.sign(temp),0)
                temp2 = torch.min(torch.abs(temp),0)[0]
                prod_list.append(temp1)
                min_list.append(temp2)
        prods = torch.stack(prod_list)
        mins = torch.stack(min_list)
        mins = torch.relu(mins - torch.tile(torch.reshape(self.B_cv[iteration], (-1,1)), (1, vc.shape[1])))
        cv = prods * mins
        new_order = np.zeros(num_edges).astype(int)
        new_order[edge_order] = np.array(range(0,num_edges)).astype(int)
        cv = torch.index_select(cv, 0, torch.tensor(new_order).to(device))
        cv = cv * torch.tile(torch.reshape(self.W_cv[iteration], (-1,1)), (1, vc.shape[1]))
        return cv

    # combine messages to get posterior LLRs
    def marginalize(self, soft_input, cv):
        weighted_soft_input = soft_input
        soft_output = []
        for i in range(0,n):
            edges = []
            for e in range(0,var_degrees[i]):
                edges.append(d[i][e])
            temp = torch.index_select(cv,0,torch.tensor(edges).to(device))
            temp = torch.sum(temp, 0)
            soft_output.append(temp)
        soft_output = torch.stack(soft_output)
        soft_output = weighted_soft_input + soft_output
        return soft_output

In [63]:
def epoch_train(loader, clf, criterion, opt):
    clf.train(True)
    avg_loss = 0
    avg_acc = 0
    correct = 0
    # load batch
    for model_input, target in loader:
        # move data to device
        model_input = model_input.to(device)
        target = target.to(device)
        # calculate outputs, loss and accuracy
        model_output = clf(model_input)
        loss = criterion(model_output, target)
        avg_loss += loss
        correct += torch.count_nonzero(torch.heaviside(model_output-0.5, torch.tensor([0.])) == target)
        # calculate grad, upd weights
        opt.zero_grad()
        loss.backward()
        opt.step()
    avg_loss = avg_loss / len(loader)
    avg_acc = correct / len(loader.dataset)
    return avg_loss, avg_acc

def epoch_test(loader, clf, criterion):
    clf.eval()
    avg_loss = 0
    avg_acc = 0
    correct = 0
    # load batch
    for model_input, target in loader:
        # move data to device
        model_input = model_input.to(device)
        target = target.to(device)
        # calculate outputs, loss and accuracy
        model_output = clf(model_input)
        loss = criterion(model_output, target)
        avg_loss += loss
        correct += torch.count_nonzero(torch.heaviside(model_output-0.5, torch.tensor([0.])) == target)
    avg_loss = avg_loss / len(loader)
    avg_acc = correct / len(loader.dataset)

    return avg_loss, avg_acc

def train(train_loader, test_loader, clf, criterion, opt, n_epochs=50):
    for epoch in tqdm(range(n_epochs)):
        train_loss, train_acc = epoch_train(train_loader, clf, criterion, opt)
        test_loss, test_acc = epoch_test(test_loader, clf, criterion)

        if (np.mod(epoch+1,100)==0):
            print(f'[Epoch {epoch + 1}] train loss: {train_loss:.3f}; train acc: {train_acc:.2f}; ' + 
                  f'test loss: {test_loss:.3f}; test acc: {test_acc:.2f}')

## Create dataloader

In [64]:
messages = np.random.randint(0,2,[k,batch_size])
codewords = np.dot(G, messages) % 2
BPSK_codewords = (0.5 - codewords.astype(np.float32)) * 2.0
soft_input = np.zeros_like(BPSK_codewords)
channel_information = np.zeros_like(BPSK_codewords)
SNRs = np.arange(1,6)
for i in range(0,len(SNRs)):
    sigma = np.sqrt(1. / (2 * (np.float(k)/np.float(n)) * 10**(SNRs[i]/10)))
    noise = sigma * np.random.randn(n,batch_size//len(SNRs))
    start_idx = batch_size*i//len(SNRs)
    end_idx = batch_size*(i+1)//len(SNRs)
    soft_input[:,start_idx:end_idx] = BPSK_codewords[:,start_idx:end_idx] + noise

tensor_x = torch.Tensor(soft_input) # transform to torch tensor
tensor_y = torch.Tensor(codewords)

dataset = torch.utils.data.TensorDataset(tensor_x,tensor_y) # create your datset
loader = torch.utils.data.DataLoader(dataset, batch_size=120, shuffle=True) # create your dataloader

In [65]:
decoder = Decoder(num_iterations=5).to(device)
opt = torch.optim.Adam(decoder.parameters())
criterion = torch.nn.CrossEntropyLoss()
# criterion = torch.nn.MSELoss()

In [66]:
train(loader, loader, decoder, criterion, opt, n_epochs=2500)

  4%|█████▏                                                                                                                  | 109/2500 [00:02<00:50, 47.62it/s]

[Epoch 100] train loss: 295.776; train acc: 10.29; test loss: 296.759; test acc: 8.29


  8%|██████████                                                                                                              | 209/2500 [00:04<00:47, 47.79it/s]

[Epoch 200] train loss: 295.934; train acc: 10.43; test loss: 296.042; test acc: 10.71


 12%|██████████████▊                                                                                                         | 309/2500 [00:06<00:45, 47.80it/s]

[Epoch 300] train loss: 296.944; train acc: 9.00; test loss: 297.061; test acc: 8.00


 16%|███████████████████▋                                                                                                    | 409/2500 [00:08<00:43, 47.82it/s]

[Epoch 400] train loss: 295.680; train acc: 11.71; test loss: 296.731; test acc: 10.14


 20%|████████████████████████▍                                                                                               | 509/2500 [00:10<00:41, 47.79it/s]

[Epoch 500] train loss: 295.052; train acc: 16.00; test loss: 294.811; test acc: 13.14


 24%|█████████████████████████████▏                                                                                          | 609/2500 [00:12<00:39, 47.77it/s]

[Epoch 600] train loss: 294.154; train acc: 17.86; test loss: 295.488; test acc: 17.57


 28%|██████████████████████████████████                                                                                      | 709/2500 [00:14<00:37, 47.76it/s]

[Epoch 700] train loss: 294.775; train acc: 22.14; test loss: 293.599; test acc: 24.29


 32%|██████████████████████████████████████▊                                                                                 | 809/2500 [00:17<00:35, 47.77it/s]

[Epoch 800] train loss: 292.674; train acc: 29.00; test loss: 291.996; test acc: 32.00


 36%|███████████████████████████████████████████▋                                                                            | 909/2500 [00:19<00:33, 47.75it/s]

[Epoch 900] train loss: 291.160; train acc: 31.86; test loss: 291.145; test acc: 32.57


 40%|████████████████████████████████████████████████                                                                       | 1009/2500 [00:21<00:31, 47.72it/s]

[Epoch 1000] train loss: 291.454; train acc: 29.71; test loss: 291.678; test acc: 31.29


 44%|████████████████████████████████████████████████████▊                                                                  | 1109/2500 [00:23<00:29, 47.76it/s]

[Epoch 1100] train loss: 285.165; train acc: 58.14; test loss: 285.350; test acc: 53.86


 48%|█████████████████████████████████████████████████████████▌                                                             | 1209/2500 [00:25<00:27, 47.75it/s]

[Epoch 1200] train loss: 277.895; train acc: 78.86; test loss: 278.194; test acc: 78.14


 52%|██████████████████████████████████████████████████████████████▎                                                        | 1309/2500 [00:27<00:24, 47.80it/s]

[Epoch 1300] train loss: 272.532; train acc: 89.29; test loss: 275.342; test acc: 82.86


 56%|███████████████████████████████████████████████████████████████████                                                    | 1409/2500 [00:29<00:22, 47.76it/s]

[Epoch 1400] train loss: 275.116; train acc: 84.00; test loss: 271.824; test acc: 90.29


 60%|███████████████████████████████████████████████████████████████████████▊                                               | 1509/2500 [00:31<00:20, 47.79it/s]

[Epoch 1500] train loss: 275.171; train acc: 83.43; test loss: 271.968; test acc: 90.14


 64%|████████████████████████████████████████████████████████████████████████████▌                                          | 1609/2500 [00:33<00:18, 47.76it/s]

[Epoch 1600] train loss: 274.209; train acc: 84.14; test loss: 274.767; test acc: 83.71


 68%|█████████████████████████████████████████████████████████████████████████████████▎                                     | 1709/2500 [00:35<00:16, 47.76it/s]

[Epoch 1700] train loss: 272.157; train acc: 89.14; test loss: 269.275; test acc: 94.86


 72%|██████████████████████████████████████████████████████████████████████████████████████                                 | 1809/2500 [00:37<00:14, 47.70it/s]

[Epoch 1800] train loss: 275.204; train acc: 82.71; test loss: 272.342; test acc: 88.86


 76%|██████████████████████████████████████████████████████████████████████████████████████████▊                            | 1909/2500 [00:40<00:12, 47.77it/s]

[Epoch 1900] train loss: 274.486; train acc: 83.86; test loss: 274.534; test acc: 83.43


 80%|███████████████████████████████████████████████████████████████████████████████████████████████▋                       | 2009/2500 [00:42<00:10, 47.75it/s]

[Epoch 2000] train loss: 268.340; train acc: 96.43; test loss: 275.305; test acc: 82.43


 84%|████████████████████████████████████████████████████████████████████████████████████████████████████▍                  | 2109/2500 [00:44<00:08, 47.78it/s]

[Epoch 2100] train loss: 275.109; train acc: 82.14; test loss: 275.459; test acc: 82.00


 88%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▏             | 2209/2500 [00:46<00:06, 47.78it/s]

[Epoch 2200] train loss: 275.469; train acc: 83.00; test loss: 271.488; test acc: 91.29


 92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▉         | 2309/2500 [00:48<00:03, 47.77it/s]

[Epoch 2300] train loss: 264.975; train acc: 109.00; test loss: 273.827; test acc: 90.43


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋    | 2409/2500 [00:50<00:01, 47.74it/s]

[Epoch 2400] train loss: 269.826; train acc: 98.86; test loss: 270.532; test acc: 97.43


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [00:52<00:00, 47.69it/s]

[Epoch 2500] train loss: 270.875; train acc: 96.00; test loss: 269.515; test acc: 99.00





In [67]:
torch.count_nonzero(torch.round(decoder(tensor_x)) == tensor_y)

tensor(778)