# Quantum neural network belief propagation

This notebook implements the neural-network belief-propagation (NNBP) decoder for quantum codes`[1]` in PyTorch`[2]`. This part focuses on decoding the 2D toric code`[3]`. The note only illustrates the differences between decoding classical codes and quantum codes; for more detail, refer to the NNBP repo.  
The main difference between classical and quantum decoding in the computational basis is that the syndrome changes the result of the parity-check equations. We also introduce a residual connection to overcome the vanishing-gradient problem.

### **First Part: define the hyperparameters**

In [None]:
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import math
import torch 
import error_generate
import os
import torch.nn as nn
import torch.utils.data as Data
from torch.autograd import Variable
import torch.nn.functional as F
import time

# Size parameter handed to the error_generate helpers and used in the
# checkpoint file name (presumably the linear lattice size of the code -- confirm).
L = 4
# Six evenly spaced values in [0.01, 0.05], passed to error_generate.gen_syn
# (presumably the physical error rates to sample from -- confirm against error_generate).
P = list(np.linspace(0.01, 0.05, num = 6))
# Matrix returned by error_generate.generate_PCM, converted to a torch tensor.
# The decoder unpacks its size as (rows, cols).
H = torch.from_numpy(error_generate.generate_PCM(2 * L * L - 2, L))
# Second matrix derived from H by error_generate.H_Prep; it is consumed only by
# LossFunc (NOTE(review): its exact meaning is defined in error_generate -- confirm).
h_prep = error_generate.H_Prep(H)
H_prep = torch.from_numpy(h_prep.get_H_Prep())
# Mini-batch size used by the DataLoader in training().
BATCH_SIZE = 120
#torch.manual_seed(1)
# Passed to error_generate.gen_syn and reused to normalise the printed loss in
# the __main__ cell (presumably the number of generated samples -- confirm).
run = 1200
# Learning rate for the Adam optimiser in training().
lr = 2e-4
# Number of ResidualBlock layers stacked by NNBP.
Nc = 15
# Make CUDA device 0 the current device; the rest of the notebook calls .cuda() unconditionally.
torch.cuda.set_device(0)
# Enable autograd anomaly detection for every backward pass (a debugging aid).
torch.autograd.set_detect_anomaly(True)
# Flat sequence that CustomDataset reads as interleaved pairs: element 2*i is
# paired with element 2*i + 1.
dataset = error_generate.gen_syn(P, L, H, run)

### **Second Part: dataset and one iteration of the NNBP algorithm**  

In [None]:
class CustomDataset(Data.Dataset):
    """Dataset view over a flat sequence of interleaved pairs.

    Element ``2 * i`` of the wrapped sequence is returned as the first item of
    sample ``i`` and element ``2 * i + 1`` as the second item.
    """

    def __init__(self, dataset):
        """Wrap ``dataset``; the sequence is referenced, not copied."""
        self.dataset = dataset

    def __getitem__(self, index):
        """Return the pair stored at positions ``2 * index`` and ``2 * index + 1``."""
        # Read from the wrapped sequence rather than from a module-level
        # ``dataset`` global, so the instance serves the data it was built with.
        return self.dataset[2 * index], self.dataset[2 * index + 1]

    def __len__(self):
        """Return the number of complete pairs in the wrapped sequence."""
        return len(self.dataset) // 2


class ResidualBlock(torch.nn.Module):
    """One layer of the decoder: a single weighted message-passing update.

    The layer takes the tuple ``(L, Syn, M_check, results)``, computes a new
    message tensor with the same layout as ``M_check``, appends one per-column
    output tensor to ``results`` and returns the updated tuple so that layers
    can be chained with ``torch.nn.Sequential``.

    All constant tensors are moved to the GPU with ``.cuda()``, so a CUDA
    device is required.
    """
    def __init__(self, H, resi):
        """Store the check matrix and create the trainable weights.

        Args:
            H: 2-D tensor whose size is unpacked as ``(rows, cols)``.
            resi: when equal to 1, ``forward`` adds the incoming ``M_check``
                to the newly computed messages (residual connection).
        """
        super(ResidualBlock, self).__init__()
        # Fixed (non-trainable) copy of the check matrix on the GPU.
        self.H = Variable(H, requires_grad = False).cuda()
        self.rows, self.cols = H.size()
        # Trainable (rows, rows) weights, initialised to ones with a zero diagonal.
        self.W_check = torch.nn.Parameter((torch.ones((self.rows, self.rows), dtype = torch.float64) - \
                                           torch.eye(self.rows, dtype = torch.float64)))
        # Trainable (rows, 1) weights used for the per-column output, initialised to ones.
        self.W_vec = torch.nn.Parameter(Variable(torch.ones((self.rows, 1), dtype = torch.float64)))
        self.resi = resi
        
    def forward(self, x):
        """Run one message-passing update.

        Args:
            x: tuple ``(L, Syn, M_check, results)``.
                L: tensor right-multiplied by ``H^T`` and by a ``(cols, 1)``
                    vector of ones; NNBP.forward builds it as a batch of
                    ``(cols, cols)`` diagonal matrices.
                Syn: 2-D tensor; only columns ``0 : rows`` are read.
                M_check: messages from the previous layer, laid out as
                    ``(batch, rows, cols)`` (it is masked with ``self.H`` and
                    transposed over dims 1 and 2).
                results: list that collects one output tensor per layer.

        Returns:
            ``(L, Syn, output, results)`` where ``output`` is the new message
            tensor in the layout of ``M_check`` and ``results`` has one more
            ``(batch, cols, 1)`` tensor appended.
        """
        L, Syn, M_check, results= x
        # Broadcast the check matrix to M_check's shape; it is used below as an
        # element-wise mask (entries equal to 0 are left untouched by each step).
        H = (Variable(torch.ones((M_check.size()), dtype = torch.float64))).cuda().mul(self.H)
        # All-ones matrices with a zero diagonal: multiplying by them sums over
        # every *other* row (check) / column (var) and excludes the entry itself.
        check = Variable(torch.ones((self.rows, self.rows), dtype = torch.float64) - \
                         torch.eye(self.rows, dtype = torch.float64)).cuda()
        var = Variable(torch.ones((self.cols, self.cols), dtype = torch.float64) - \
                       torch.eye(self.cols, dtype = torch.float64)).cuda()
        result = M_check
        # Switch to the transposed (batch, cols, rows) layout. Where H is
        # nonzero, each entry becomes the W_check-weighted sum of the incoming
        # messages of the same column over all other rows.
        result = torch.where(H.transpose(1, 2) == 0, result.transpose(1, 2), \
                             torch.matmul(result.transpose(1, 2), self.W_check.mul(check)))
        # Add L @ H^T; with a diagonal L this adds the column's diagonal entry
        # of L on the positions where H is nonzero.
        result += torch.matmul(L, self.H.t())
        # Bound the values before the tanh below.
        result = torch.clamp(result, -10 ,10)
        result = torch.where(H.transpose(1, 2) == 0, result, torch.tanh(result / 2))
        # Sign indicator: 1.0 where the value is negative, 0.0 otherwise.
        Coeff = torch.where(result < 0, torch.ones(result.size(), dtype = torch.float64).cuda(), \
                            torch.zeros(result.size(), dtype = torch.float64).cuda())
        # Keep only the magnitude on the masked-in entries; the sign is carried by Coeff.
        result = torch.where(H.transpose(1, 2) == 0, result, abs(result))
        # The +3 here and the -3 after the log cancel on the H == 0 entries.
        # NOTE(review): presumably this keeps the argument of torch.log nonzero
        # on those entries, because torch.where evaluates torch.log on every
        # entry, selected or not -- confirm.
        result = torch.where(H.transpose(1, 2) == 0, result + 3, result)
#        print(torch.where(result > 0, torch.zeros(result.size(), dtype = torch.float64).cuda(), result).to_sparse()._values())
        # Lower-bound the magnitudes so that the log below stays finite.
        result = torch.where(H.transpose(1, 2) == 0, result, result.clamp(1e-40))
        result = torch.where(H.transpose(1, 2) == 0, result, torch.log(result))
        result = torch.where(H.transpose(1, 2) == 0, result - 3, result)
        # Back to the (batch, rows, cols) layout. Where H is nonzero, sum the
        # log-magnitudes of the same row over all other columns. The sum also
        # runs over columns where H is zero; NOTE(review): those entries appear
        # to stay at zero when M_check is zero off the support of H -- confirm.
        result = torch.where(H == 0, result.transpose(1, 2), torch.matmul(result.transpose(1, 2), var))
#        print(torch.where(result < 0, torch.zeros(result.size(), dtype = torch.float64), result).to_sparse()._values())
        # exp of the summed logs = product of the magnitudes.
        result = torch.where(H == 0, result, torch.exp(result))
        # Number of negative values in the same row over all other columns.
        Coeff = torch.where(H == 0, Coeff.transpose(1, 2), torch.matmul(Coeff.transpose(1, 2), var))
        # cos(pi * k) is +1 for even k and -1 for odd k, so this applies the sign
        # given by the negative count plus the row's Syn entry (assumes Syn
        # holds integer values -- TODO confirm).
        result = torch.where(H == 0, result, torch.cos(math.pi * (Coeff + Syn[:, 0 : self.rows].unsqueeze(2))).mul(result))
        # log((1 + r) / (1 - r)) equals 2 * atanh(r).
        result = torch.where(H == 0, result, torch.log((1 + result) / (1 - result)))
        '''
        As the author suggest, we can introduce residual connect to overcome the challence of gradient vanishing.
        '''
        # Residual connection: add the incoming messages in place.
        if self.resi == 1:
            result += M_check
        output = result
        # Per-column output: W_vec-weighted sum of the new messages over the rows ...
        result = torch.matmul(output.transpose(1, 2), self.W_vec)
        # ... plus the row sums of L (the diagonal entries when L is diagonal).
        result += torch.matmul(L, Variable(torch.ones((self.cols, 1), dtype = torch.float64)).cuda())
        results.append(result)
        return L, Syn, output, results

### **Third part: neural network belief propagation**

In [None]:
class NNBP(torch.nn.Module):
    """Decoder built from ``Nc`` chained ``ResidualBlock`` layers.

    Every layer shares the same check matrix ``H`` but owns its own trainable
    weights. All tensors are moved to the GPU with ``.cuda()``, so a CUDA
    device is required.
    """

    def __init__(self, H, Nc, resi):
        """Build the layer stack.

        Args:
            H: 2-D tensor whose size is unpacked as ``(rows, cols)``.
            Nc: number of ``ResidualBlock`` layers to stack.
            resi: forwarded to every ``ResidualBlock``; 1 enables its
                residual connection.
        """
        super(NNBP, self).__init__()
        self.H = H.cuda()
        self.resi = resi
        self.rows, self.cols = H.size()
        self.Nc = Nc
        self.layer = self._make_layer()

    def _make_layer(self):
        """Return a ``Sequential`` of ``Nc`` independent ``ResidualBlock`` layers."""
        blocks = [ResidualBlock(self.H, self.resi).cuda() for _ in range(self.Nc)]
        return torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Run all layers on a batch and return every layer's output.

        Args:
            x: tensor indexed as ``x[:, 0]`` (one vector per sample, placed on
                the diagonal of a ``(cols, cols)`` matrix) and ``x[:, 1, :]``
                (passed to the layers as ``Syn``).

        Returns:
            The ``results`` list filled by the layers: one tensor per layer,
            in layer order.
        """
        batch = x.size(0)
        # Build the whole batch of diagonal matrices in one vectorised call
        # instead of filling a CPU tensor sample by sample; the cast keeps the
        # float64 dtype that the per-sample copy into a float64 buffer produced.
        L = torch.diag_embed(x[:, 0]).to(torch.float64).cuda()
        Syn = x[:, 1, :].cuda()
        # Messages start at zero; the list collects one output per layer.
        M_check_init = torch.zeros(batch, self.rows, self.cols, dtype = torch.float64).cuda()
        state = (L, Syn, M_check_init, [])
        return self.layer(state)[3]

### **Fourth Part: loss function and training**

In [None]:
class LossFunc(torch.nn.Module):
    """Differentiable loss comparing a decoder output with the true error.

    The loss is ``sum(|sin(pi / 2 * H_prep @ (err^T + sigmoid(-result)))|)``
    taken over every element of the batch.
    """

    def __init__(self, H_prep):
        """Keep a GPU copy of ``H_prep`` (requires a CUDA device)."""
        super(LossFunc, self).__init__()
        self.H_prep = Variable(H_prep).cuda()

    def forward(self, result, err):
        """Return the scalar loss for one batch.

        Args:
            result: decoder output; its sigmoid is added to ``err`` transposed
                over dims 1 and 2, so the two must be broadcast-compatible
                (presumably ``result`` is (batch, n, 1) and ``err`` is
                (batch, 1, n) -- confirm against the data generator).
            err: target tensor paired with ``result``.

        Returns:
            0-dimensional tensor holding the summed loss.
        """
        # Squash the decoder output into (0, 1) with the sign flipped.
        soft_estimate = torch.sigmoid(-1 * result)
        combined = err.transpose(1, 2) + soft_estimate
        projected = torch.matmul(self.H_prep, combined)
        # |sin(pi / 2 * v)| vanishes at even integers and peaks at odd ones.
        penalty = torch.abs(torch.sin(projected * math.pi / 2))
        return torch.sum(penalty)

    
def training(H, lr, L, train_num, train, load):
    """Train the decoder, or evaluate a freshly built one, on the global ``dataset``.

    Reads the module-level names ``Nc``, ``dataset``, ``BATCH_SIZE`` and
    ``H_prep``. Requires a CUDA device.

    Args:
        H: check matrix handed to ``NNBP``.
        lr: learning rate for the Adam optimiser.
        L: value inserted into the checkpoint file name.
        train_num: number of passes (epochs) over the data.
        train: truthy -> optimise the weights and save them afterwards;
            falsy -> only accumulate the loss of the last layer's output.
        load: together with ``train`` decides whether the residual connection
            is enabled (``resi = 1`` when either flag is truthy).
            NOTE: no saved weights are restored here; in evaluation mode the
            decoder keeps its initial weights unless a
            ``decoder.load_state_dict(torch.load(...))`` call is added.

    Returns:
        In evaluation mode, the tensor sum of the last-layer loss over all
        batches and epochs. In training mode (or when no batch was
        processed), the integer 0.
    """
    loss_sum = 0
    resi = 1 if (train or load) else 0
    decoder = NNBP(H, Nc, resi).cuda()

    torch_dataset = CustomDataset(dataset)
    loader = Data.DataLoader(
        dataset = torch_dataset,      # torch TensorDataset format
        batch_size = BATCH_SIZE,      # mini batch size
        shuffle = True,               # random shuffle for training
        num_workers = 0,              # subprocesses for loading data
    )

    criterion = LossFunc(H_prep).cuda()
    optimizer = torch.optim.Adam(decoder.parameters(), lr = lr)
    for epoch in range(train_num):
        print('epoch',epoch)
        for data, target in loader:
            data, target = data.cuda(), target.cuda()
            if train:
                optimizer.zero_grad()
                results = decoder(data)
                # Average the loss over the outputs of all layers so that
                # every layer receives a direct training signal.
                loss = sum(criterion(result, target) for result in results) / len(results)
                loss.backward()
                optimizer.step()
                # Keep every trainable weight at or above 1e-10.
                for p in decoder.parameters():
                    p.data.clamp_(1e-10)
            else:
                # Evaluation needs no gradients; without no_grad every
                # accumulated loss would keep its autograd graph alive.
                with torch.no_grad():
                    results = decoder(data)
                    loss_sum += criterion(results[-1], target)
    if train:
        # Build the path portably (the previous literal relied on the invalid
        # escapes "\m" and "\d" and only formed a directory path on Windows)
        # and make sure the target directory exists before saving.
        model_dir = os.path.join('.', 'model')
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir)
        torch.save(decoder.state_dict(),
                   os.path.join(model_dir, 'decoder_parameters_L=%d.pkl' % L))
    return loss_sum

### **Fifth Part: training the neural network**

In [None]:
if __name__ == '__main__':
    # train = 0 selects the evaluation branch of training(); load = 1 still
    # enables the residual connection (training() sets resi = 1 when either
    # flag is truthy). One pass over the data.
    train, load, train_num = 0, 1, 1
    loss = training(H, lr, L, train_num, train, load)
    if not train:
        # Normalise the accumulated loss by run * (2 * L**2 - 2)
        # (presumably number of samples times number of checks -- confirm).
        print(loss.item() / (run * (2 * L ** 2 - 2)))

### **Sixth Part: results and evaluation**

Coming soon.

### **Reference**

`[1]`: Ye-Hua Liu and David Poulin, “Neural belief-propagation decoders for quantum error-correcting codes,” arXiv preprint arXiv:1811.07835 (2018).  
`[2]`: https://pytorch.org/  
`[3]`: A. Yu. Kitaev, Ann. Phys. (N.Y.) 303, 2 (2003).  