In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
import math
from UnarySim.sw.stream.gen import RNG, SourceGen, BSGen
from UnarySim.sw.kernel.add import GainesAdd
from UnarySim.sw.kernel.shiftreg import ShiftReg
from UnarySim.sw.metric.metric import ProgressiveError, NormStability
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import time
import math
import numpy as np

In [3]:
class CheckNode(torch.nn.Module):
    """Check node: for every row of the stacked input, emit the parity of all *other* rows.

    Args:
        stype: dtype of the returned tensor.
    """

    def __init__(self, stype=torch.float):
        super().__init__()
        self.stype = stype

    def forward(self, input):
        """Return the extrinsic parity of ``input`` (messages stacked along axis 0).

        Row ``i`` of the result is ``(sum of all rows - row i) mod 2``, i.e. the
        XOR of every row except ``i``. The result has the same shape as ``input``.
        """
        column_total = input.sum(dim=0)
        others = column_total - input
        return torch.remainder(others, 2).type(self.stype)

    
class ParityNode(torch.nn.Module):
    """Parity node: XOR of all rows of the stacked input.

    Args:
        stype: dtype of the returned tensor.
    """

    def __init__(self, stype=torch.float):
        super().__init__()
        self.stype = stype

    def forward(self, input):
        """Return ``(sum over axis 0) mod 2``; axis 0 is reduced away."""
        column_total = input.sum(dim=0)
        return torch.remainder(column_total, 2).type(self.stype)
    
    
class VariableNodeCNT(torch.nn.Module):
    """Variable node that merges a channel bit stream with incoming check messages.

    Messages are stacked along axis 0: ``c2v[i]`` is the bit arriving on edge ``i``
    and ``v2c[i]`` is the bit sent back on edge ``i``. For each edge the node
    compares ``chn`` with all incoming messages *except* ``c2v[i]``. Where the
    compared bits agree, that bit is emitted; where they disagree, a bit is
    regenerated instead:

    * intermediate pairs (degree 3 and 4) read a random entry of a 2-deep
      ``ShiftReg`` ("internal memory"), which is then updated with the merged bit;
    * the final stage emits ``acc > r`` with ``r`` drawn uniformly from
      ``[0, 2**depth - 1]``. ``acc`` is a saturating counter that moves +1 / -1
      toward the emitted bit whenever the final comparison agrees.

    Args:
        degree: number of edges, including the channel input. Only 1..4 are implemented.
        depth: bit width of the counter; ``acc`` is clamped to ``[0, 2**depth - 1]``.
        LLR: initial counter value as a tensor. If None, the counter starts at 0.
        rtype: stored as ``self.rtype``; not otherwise used by this class.
        btype: dtype of the counter and of the agreement masks.
        stype: dtype of the emitted bit streams.

    Raises:
        AssertionError: if ``degree < 1``.
        NotImplementedError: if ``degree > 4``.
    """

    def __init__(self,
                 degree=1,
                 depth=7,
                 LLR=None,
                 rtype=torch.float,
                 btype=torch.float,
                 stype=torch.float):
        super(VariableNodeCNT, self).__init__()

        # this degree includes channel information
        self.degree = degree
        assert degree >= 1, "Input degree can't be smaller than 1."
        if degree > 4:
            # Only degree1..4_forward exist; fail here rather than with an
            # UnboundLocalError inside forward().
            raise NotImplementedError(f"VariableNodeCNT supports degree 1 to 4, got {degree}.")

        # Degree 1 and 2 need no shift register: the channel bit is compared directly.
        if degree == 3:
            # im_<i>: internal memory for edge i's (chn, c2v[j]) merge
            for i in range(3):
                setattr(self, f"im_{i}", ShiftReg(depth=2, stype=stype))
        elif degree == 4:
            # im_<i>_0 merges chn with one message; im_<i>_1 merges the remaining two.
            # Registration order (im_0_0, im_0_1, im_1_0, ...) matches the original code.
            for i in range(4):
                for j in range(2):
                    setattr(self, f"im_{i}_{j}", ShiftReg(depth=2, stype=stype))

        self.acc = torch.nn.Parameter(torch.zeros(1).type(btype), requires_grad=False)
        if LLR is not None:
            self.acc.data = LLR.type(btype)
        self.acc_max = 2**depth - 1
        # exclusive upper bound for torch.randint, so thresholds span [0, acc_max]
        self.acc_max_1 = 2**depth

        self.rtype = rtype
        self.btype = btype
        self.stype = stype

    def _xnor(self, a, b):
        """1 where bit streams ``a`` and ``b`` agree, 0 where they differ (as ``btype``)."""
        return (1 - (a.type(torch.int8) ^ b.type(torch.int8))).type(self.btype)

    def _im_merge(self, a, b, im):
        """Merge bits ``a`` and ``b``, using shift register ``im`` to regenerate on disagreement.

        Returns ``a`` where the bits agree; otherwise returns a random one of the two
        entries stored in ``im.sr.data``. ``im`` is then updated with the merged bit,
        masked by the agreement flag.
        """
        eq = self._xnor(a, b)
        # ShiftReg was built with depth=2, hence the random index in {0, 1}
        stored = im.sr.data[torch.randint(0, 2, (1, )).type(torch.long).item()]
        merged = eq.type(self.stype) * a + (1 - eq).type(self.stype) * stored
        im(merged, mask=eq)
        return merged

    def _emit(self, c2v_eq, value, c2v):
        """Final stage shared by degree 2..4.

        Args:
            c2v_eq: per-edge agreement mask of the final comparison, stacked on axis 0.
            value: bits to forward where ``c2v_eq`` is 1 (broadcast against ``c2v_eq``).
            c2v: incoming messages, used only for the posterior.
        """
        # One threshold per edge, shape (degree, 1): shared across trailing positions.
        threshold = torch.randint(0, self.acc_max_1, (self.degree, 1)).type(self.btype)
        regenerated = torch.gt(self.acc, threshold).type(self.stype)
        v2c = c2v_eq.type(self.stype) * value + (1 - c2v_eq).type(self.stype) * regenerated
        # Count up/down toward the agreed bit. After this update acc takes the broadcast
        # shape of c2v_eq (one counter per edge) rather than the shape of LLR.
        self.acc.data = (self.acc + c2v_eq * value.mul(2).sub(1).type(self.btype)).clamp(0, self.acc_max)
        # NOTE(review): posterior is 1 whenever v2c[0] equals c2v[0], independent of the
        # bit value itself (both all-0 and all-1 inputs give 1) -- confirm this is intended.
        posterior = torch.eq(v2c[0], c2v[0]).type(self.stype)
        return v2c, posterior

    def degree1_forward(self, c2v, chn):
        """Degree 1: the channel bit is forwarded unchanged on the single edge."""
        # c2v/v2c is [0][...]
        # chn/posterior is [...]
        c2v_eq = torch.zeros_like(c2v)
        c2v_eq[0] = chn.type(self.btype)
        v2c = c2v_eq.type(self.stype)
        posterior = torch.eq(v2c[0], c2v[0]).type(self.stype)
        return v2c, posterior

    def degree2_forward(self, c2v, chn):
        """Degree 2: edge i compares chn with the other edge's message."""
        # c2v/v2c is [0, 1][...]
        # chn/posterior is [...]
        c2v_eq = torch.stack((self._xnor(chn, c2v[1]), self._xnor(chn, c2v[0])), 0)
        return self._emit(c2v_eq, chn, c2v)

    def degree3_forward(self, c2v, chn):
        """Degree 3: merge chn with one other message via im_i, then compare with the last."""
        # c2v/v2c is [0, 1, 2][...]
        # chn/posterior is [...]
        # (internal memory, message merged with chn, message compared in the final stage)
        plan = ((self.im_0, 1, 2),
                (self.im_1, 0, 2),
                (self.im_2, 0, 1))
        internals, eqs = [], []
        for im, merged_idx, final_idx in plan:
            # each edge updates its OWN internal memory with its OWN merged bit
            internal = self._im_merge(chn, c2v[merged_idx], im)
            internals.append(internal)
            eqs.append(self._xnor(internal, c2v[final_idx]))
        return self._emit(torch.stack(eqs, 0), torch.stack(internals, 0), c2v)

    def degree4_forward(self, c2v, chn):
        """Degree 4: merge (chn, c2v[p]) and (c2v[q], c2v[r]) via two memories, then compare."""
        # c2v/v2c is [0, 1, 2, 3][...]
        # chn/posterior is [...]
        # (left memory, right memory, p, q, r) -- edge i uses every message except c2v[i]
        plan = ((self.im_0_0, self.im_0_1, 1, 2, 3),
                (self.im_1_0, self.im_1_1, 0, 2, 3),
                (self.im_2_0, self.im_2_1, 0, 1, 3),
                (self.im_3_0, self.im_3_1, 0, 1, 2))
        lefts, eqs = [], []
        for im_left, im_right, p, q, r in plan:
            left = self._im_merge(chn, c2v[p], im_left)
            right = self._im_merge(c2v[q], c2v[r], im_right)
            lefts.append(left)
            eqs.append(self._xnor(left, right))
        return self._emit(torch.stack(eqs, 0), torch.stack(lefts, 0), c2v)

    def forward(self, c2v, chn):
        """Run one cycle of the node.

        Args:
            c2v: incoming messages stacked along axis 0, one entry per edge.
            chn: channel bit(s) for this cycle.

        Returns:
            ``(v2c, posterior)``, both cast to ``stype``.

        Raises:
            NotImplementedError: if ``self.degree`` is not in 1..4.
        """
        if self.degree == 1:
            v2c, posterior = self.degree1_forward(c2v, chn)
        elif self.degree == 2:
            v2c, posterior = self.degree2_forward(c2v, chn)
        elif self.degree == 3:
            v2c, posterior = self.degree3_forward(c2v, chn)
        elif self.degree == 4:
            v2c, posterior = self.degree4_forward(c2v, chn)
        else:
            raise NotImplementedError(f"VariableNodeCNT supports degree 1 to 4, got {self.degree}.")
        return v2c.type(self.stype), posterior.type(self.stype)


In [4]:
cn0 = CheckNode()
cn0_out = cn0(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]))
print(cn0_out)
# each row is the XOR of the other three rows (both column sums are 2)
assert cn0_out.tolist() == [[0., 0.], [0., 1.], [1., 0.], [1., 1.]]

tensor([[0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 1.]])


In [5]:
pn0 = ParityNode()
pn0_out = pn0(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]))
print(pn0_out)
# both columns contain an even number of ones, so the parity is 0
assert pn0_out.tolist() == [0., 0.]

tensor([0., 0.])


In [6]:
# degree 1 draws no random numbers: the channel bit is forwarded as-is
vn0 = VariableNodeCNT(degree=1, depth=7, LLR=torch.tensor([0.7*(2**7)]))
vn0_out, posterior = vn0(torch.tensor([1.]), torch.tensor([1.]))
print(vn0_out, posterior)
assert vn0_out.tolist() == [1.] and posterior.item() == 1.
vn0_out, posterior = vn0(torch.tensor([0.]), torch.tensor([0.]))
print(vn0_out, posterior)
assert vn0_out.tolist() == [0.] and posterior.item() == 1.

tensor([1.]) tensor(1.)
tensor([0.]) tensor(1.)


In [7]:
torch.manual_seed(0)  # the node draws random thresholds; seed for reproducible reruns
vn0 = VariableNodeCNT(degree=2, depth=7, LLR=torch.tensor([0.7*(2**7)]))
vn0_out, posterior = vn0(torch.tensor([[1.], [1.]]), torch.tensor([1.]))
print(vn0_out, posterior)
# all inputs agree, so every edge forwards the channel bit without regeneration
assert vn0_out.tolist() == [[1.], [1.]] and posterior.tolist() == [1.]
vn0_out, posterior = vn0(torch.tensor([[0.], [0.]]), torch.tensor([0.]))
print(vn0_out, posterior)
assert vn0_out.tolist() == [[0.], [0.]] and posterior.tolist() == [1.]

tensor([[1.],
        [1.]]) tensor([1.])
tensor([[0.],
        [0.]]) tensor([1.])


In [8]:
torch.manual_seed(0)  # the node draws random indices/thresholds; seed for reproducible reruns
vn0 = VariableNodeCNT(degree=3, depth=7, LLR=torch.tensor([0.7*(2**7)]))
vn0_out, posterior = vn0(torch.tensor([[1.], [1.], [1.]]), torch.tensor([1.]))
print(vn0_out, posterior)
# all inputs agree, so no bit is regenerated
assert vn0_out.tolist() == [[1.], [1.], [1.]] and posterior.tolist() == [1.]
vn0_out, posterior = vn0(torch.tensor([[0.], [0.], [0.]]), torch.tensor([0.]))
print(vn0_out, posterior)
assert vn0_out.tolist() == [[0.], [0.], [0.]] and posterior.tolist() == [1.]

tensor([[1.],
        [1.],
        [1.]]) tensor([1.])
tensor([[0.],
        [0.],
        [0.]]) tensor([1.])


In [9]:
torch.manual_seed(0)  # the node draws random indices/thresholds; seed for reproducible reruns
vn0 = VariableNodeCNT(degree=4, depth=7, LLR=torch.tensor([0.7*(2**7)]))
vn0_out, posterior = vn0(torch.tensor([[1.], [1.], [1.], [1.]]), torch.tensor([1.]))
print(vn0_out, posterior)
# all inputs agree, so no bit is regenerated
assert vn0_out.tolist() == [[1.], [1.], [1.], [1.]] and posterior.tolist() == [1.]
vn0_out, posterior = vn0(torch.tensor([[0.], [0.], [0.], [0.]]), torch.tensor([0.]))
print(vn0_out, posterior)
assert vn0_out.tolist() == [[0.], [0.], [0.], [0.]] and posterior.tolist() == [1.]

tensor([[1.],
        [1.],
        [1.],
        [1.]]) tensor([1.])
tensor([[0.],
        [0.],
        [0.],
        [0.]]) tensor([1.])


In [None]:
def test():
    