In [64]:
from pyexpat import model

import torch
import torch.nn.functional as F
import numpy as np
from time import time

from scipy._lib.array_api_compat import device
from torch.utils.data import DataLoader, TensorDataset

In [77]:
class FredkinLayer(torch.nn.Module):
    M3 = torch.tensor([[0,0,0,0,1],
                       [0,1,1,-1,0],
                       [1,0,-1,1,0]
                       ])
    
    def __init__(self,din:int,dout:int,device:str='cpu',seed:int=None):
        super().__init__()
        self.din = din
        #assert dout % 3 == 0 , "number of outputs must be divisible by 3"
        self.dout = dout
        self.device = device
        #self.wgts = torch.nn.Parameter(torch.randn(dout,din,device=device),requires_grad=True) #learnable params
        if seed is not None:
            torch.manual_seed(seed)
        self.roles = torch.stack([torch.randperm(3) for _ in range(dout)],dim=0)        #(dout , 3)
        #
        init_w = self._ordered_initial_weights(din, dout)  # (dout, din), rows sum to 1
        init_logits = torch.log(init_w + 1e-12)
        self.wgts_logits = torch.nn.Parameter(init_logits.to(self.device), requires_grad=True)

        
        
    def _ordered_initial_weights(self, din, dout , top3_prob=0.6):
        """
        Build (dout, din) where for neuron i the three preferred indices
        are start=(i*3) % din, start+1, start+2 (wrapping).
        Those three get top3_prob/3 each (top3_prob total). Remaining din-3 share (1-top3_prob) uniformly.
        Rows sum to 1.
        """
        assert top3_prob > 0 and top3_prob <= 1.0
        init = torch.zeros((dout, din), dtype=torch.float32)
        if din <= 3:
            # edge case: if din <=3, just distribute evenly among available features
            init[:] = 1.0 / din
            return init

        rem_share = (1-top3_prob) / (din - 3)
        for i in range(dout):
            start = (i * 3) % din
            chosen = [(start + k) % din for k in range(3)]
            for j in range(din):
                if j in chosen:
                    init[i, j] = top3_prob/3
                else:
                    init[i, j] = rem_share
        # numerical safety: renormalize rows to sum exactly to 1
        init = init / init.sum(dim=1, keepdim=True)
        return init
    
    def forward(self,x:torch.Tensor):
        batch_size = x.shape[0]
        assert x.shape[-1] == self.din
        
        # compute normalized, non-negative weights per neuron (rows sum to 1)
        wgts = F.softmax(self.wgts_logits, dim=1)  # (dout, din)
        
        #hard selection of top 3 inputs for each gate
        top_vals , top_idx = torch.topk(wgts,k=3,dim=1) #(dout,3)
        
        #create mask for Straight through estimator
        mask = torch.zeros_like(wgts)
        mask.scatter_(1,top_idx,1.0)
        ste_mask = mask + wgts - wgts.detach()
        
        #replicate inputs to match output neurons
        x_tiled = x.unsqueeze(1).expand(-1 , self.dout, -1) #(batch_size,dout,din)
        
        #apply mask
        x_selected = ste_mask.unsqueeze(0) * x_tiled
        
        #collect inputs for fredkin computations
        top3_inputs = torch.gather(x_selected,2,top_idx.unsqueeze(0).expand(batch_size,-1,-1)) #(batch_size,dout,3)
        roles_idx = self.roles.unsqueeze(0).expand(batch_size, -1, -1)  # (batch_size, dout, 3)
        fredkin_inputs = torch.gather(top3_inputs,2,roles_idx) #(batch_size,dout,3)
        u,a,b = fredkin_inputs[:,:,0],fredkin_inputs[:,:,1],fredkin_inputs[:,:,2]
        #print("u[0]:", u[0].detach().cpu().numpy())
        #print("a[0]:", a[0].detach().cpu().numpy())
        #print("b[0]:", b[0].detach().cpu().numpy())

        
        #fredkin computation
        v = u
        a_out = u*a + (1-u)*b
        b_out = u*b + (1-u)*a
        out = torch.stack([v , a_out , b_out],dim=-1)
        
        #flatten output
        out = out.reshape(x.size(0),-1)
        return out
        

In [None]:
#DEBUG HELPER CODE
        print("shapes:")
        #print("x_selected",x_selected.shape)
        print("X_tiled",x_tiled.shape)
        print("ste_mask",ste_mask.shape)
        print("ste_mask_squeezed",ste_mask.unsqueeze(0).shape)
        print("weights",self.wgts.shape)
        
        print("shapes of inputs:")
        print("u",u.shape)
        print("a",a.shape)
        print("b",b.shape)

In [57]:
#Define some easy training dataset
def create_majority_dataset(samplesize = 1000 , n_bits = 9 , p_one=0.5, device='cpu' ):
    x = torch.bernoulli(torch.full((samplesize,n_bits) , p_one , device=device)).float()
    labels = ( x.sum(dim=1) > (n_bits/2) ).float()
    return x,labels

x_train,labels_train = create_majority_dataset()
x_test,labels_test = create_majority_dataset(samplesize=100,n_bits=9,p_one=0.5)

#sneak preview in training data
for i in range(max(10,x_train.shape[0])):
    print(x_train[i].numpy(),labels_train[i].numpy())

train_data = TensorDataset(x_train,labels_train)
val_data = TensorDataset(x_test,labels_test)


[0. 1. 0. 0. 1. 1. 0. 1. 1.] 1.0
[1. 1. 0. 1. 1. 0. 1. 1. 1.] 1.0
[1. 1. 0. 0. 0. 1. 1. 1. 0.] 1.0
[1. 1. 1. 0. 1. 1. 1. 1. 1.] 1.0
[1. 1. 0. 1. 0. 0. 1. 0. 1.] 1.0
[1. 1. 1. 0. 1. 0. 1. 0. 1.] 1.0
[1. 0. 0. 0. 0. 1. 0. 1. 1.] 0.0
[1. 0. 1. 0. 0. 0. 1. 0. 0.] 0.0
[0. 1. 0. 1. 1. 0. 0. 1. 1.] 1.0
[1. 0. 1. 1. 1. 1. 0. 0. 1.] 1.0
[0. 0. 0. 1. 0. 0. 0. 1. 1.] 0.0
[0. 1. 1. 0. 0. 1. 1. 0. 1.] 1.0
[0. 0. 1. 1. 0. 0. 1. 1. 1.] 1.0
[1. 1. 0. 0. 0. 1. 1. 1. 0.] 1.0
[0. 1. 0. 1. 1. 0. 1. 1. 1.] 1.0
[0. 0. 0. 1. 1. 1. 0. 0. 0.] 0.0
[0. 0. 1. 1. 1. 1. 0. 1. 1.] 1.0
[1. 1. 1. 1. 0. 1. 0. 1. 1.] 1.0
[1. 1. 1. 1. 0. 1. 0. 0. 0.] 1.0
[1. 0. 1. 0. 0. 1. 1. 1. 1.] 1.0
[0. 1. 0. 0. 0. 1. 1. 0. 0.] 0.0
[1. 1. 1. 1. 1. 1. 1. 1. 0.] 1.0
[0. 1. 1. 1. 1. 1. 0. 1. 1.] 1.0
[1. 1. 1. 0. 1. 1. 1. 1. 0.] 1.0
[1. 0. 1. 0. 0. 0. 1. 0. 0.] 0.0
[0. 0. 0. 1. 0. 1. 1. 0. 0.] 0.0
[1. 1. 1. 1. 1. 1. 0. 0. 1.] 1.0
[1. 0. 1. 1. 0. 1. 1. 1. 1.] 1.0
[1. 1. 1. 1. 0. 1. 1. 0. 0.] 1.0
[0. 0. 0. 1. 0. 0. 1. 1. 0.] 0.0
[1. 1. 1. 

In [81]:
#Create model and run training
NUM_EPOCHS = 100
LEARNING_RATE = 0.0001

class FredkinNet(torch.nn.Module):
    def __init__(self,din,dout):
        super().__init__()
        self.fred1 = FredkinLayer(din,5,seed=42)
        self.fred2 = FredkinLayer(15,5,seed=43)
        self.fred3 = FredkinLayer(15,5,seed=44)
        self.fred4 = FredkinLayer(15,5,seed=45)
        self.fred5 = FredkinLayer(15,5,seed=46)
        self.fred6 = FredkinLayer(15,3,seed=47)
        self.fred7 = FredkinLayer(9,2,seed=48)
        self.fred8 = FredkinLayer(6,1,seed=49)
    
    def forward(self,x):
        out = self.fred1(x)
        out = self.fred2(out)
        out = self.fred3(out)
        out = self.fred4(out)
        out = self.fred5(out)
        out = self.fred6(out)
        out = self.fred7(out)
        out = self.fred8(out)
        a_primes = out[:,1::3]     #(batch, dout)
        out_summmed = a_primes.sum(dim=1) #(batch,)
        return out_summmed

train_loader = DataLoader(train_data,batch_size=4,shuffle=True)
val_loader = DataLoader(val_data,batch_size=4,shuffle=False)

net = FredkinNet(9,15)

optim = torch.optim.Adam(net.parameters(),lr=LEARNING_RATE)
criterion = torch.nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(1,NUM_EPOCHS+1):
    #TRAIN
    net.train()
    running_loss =0.0
    running_samples = 0
    
    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        
        logits = net(x)
        loss = criterion(logits,y)
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        bs = x.size(0)
        running_loss += loss.item()*bs
        running_samples += bs
    
    epoch_loss = running_loss/running_samples
    #VALIDATE
    net.eval()
    val_loss = 0.0
    val_samples = 0
    correct =0
    
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = net(x)
            loss = criterion(logits,y)
            
            bs = x.size(0)
            val_loss += loss.item()*bs
            val_samples += bs
            
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).float()
            correct += (preds == y).float().sum().item()
    val_loss = val_loss/val_samples
    val_acc = correct / val_samples
    print(f"Epoch {epoch:02d} train_loss = {epoch_loss:.4f} val_loss={val_loss:.4f} val_acc={val_acc:.4f}")
    
        
        


Epoch 01 train_loss = 0.6684 val_loss=0.6932 val_acc=0.4600
Epoch 02 train_loss = 0.6701 val_loss=0.7422 val_acc=0.4600
Epoch 03 train_loss = 0.6834 val_loss=0.7580 val_acc=0.4600
Epoch 04 train_loss = 0.6780 val_loss=0.6932 val_acc=0.4600
Epoch 05 train_loss = 0.6749 val_loss=0.7546 val_acc=0.4600
Epoch 06 train_loss = 0.6702 val_loss=0.7266 val_acc=0.4600
Epoch 07 train_loss = 0.6747 val_loss=0.7756 val_acc=0.4600
Epoch 08 train_loss = 0.6713 val_loss=0.7580 val_acc=0.4600
Epoch 09 train_loss = 0.6836 val_loss=0.7580 val_acc=0.4600
Epoch 10 train_loss = 0.6708 val_loss=0.7266 val_acc=0.4600
Epoch 11 train_loss = 0.6739 val_loss=0.7756 val_acc=0.4600
Epoch 12 train_loss = 0.6675 val_loss=0.7246 val_acc=0.4600
Epoch 13 train_loss = 0.6712 val_loss=0.7370 val_acc=0.4600
Epoch 14 train_loss = 0.6870 val_loss=0.7370 val_acc=0.4600
Epoch 15 train_loss = 0.6829 val_loss=0.7370 val_acc=0.4600
Epoch 16 train_loss = 0.6849 val_loss=0.7266 val_acc=0.4600
Epoch 17 train_loss = 0.6748 val_loss=0.