In [1]:
from RNN_QSR import *
from Patched_TF import PE2D
def patch2D(x,n,Lx):
    # type: (Tensor,int,int) -> Tensor
    """Cut flattened Lx-by-Lx samples into flattened n-by-n patches.

    Inputs:
        x  - tensor with x.shape[0] samples of Lx*Lx values each (row-major grid)
        n  - side length of a square patch
        Lx - side length of the square grid
    Returns:
        tensor of shape [x.shape[0], Lx*Lx//n**2, n**2]; patches are listed row by row
        and the values inside a patch are listed row by row as well
    """
    batch = x.shape[0]
    #restore the 2D layout of every sample
    grid = x.view([batch,Lx,Lx])
    #window the rows, then the columns -> [batch, Lx//n, Lx//n, n, n]
    tiles = grid.unfold(-2,n,n).unfold(-2,n,n)
    #flatten the grid of patches and each patch; int() keeps the sizes integral
    return tiles.reshape([batch,int(Lx*Lx//n**2),int(n**2)])

def unpatch2D(x,n,Lx):
    # type: (Tensor,int,int) -> Tensor
    """Undo patch2D: turn [batch, Lx*Lx//n**2, n**2] patches back into [batch, Lx*Lx] samples.

    Inputs:
        x  - tensor of flattened n-by-n patches, listed row by row over the grid
        n  - side length of a square patch
        Lx - side length of the square grid
    Returns:
        tensor of shape [x.shape[0], Lx*Lx] in the original row-major order
    """
    patches_per_row = Lx//n
    #group the patch axis into rows of patches -> [batch, Lx//n, n*n, Lx//n]
    by_patch_row = x.unfold(-2,patches_per_row,patches_per_row)
    #split every patch into its n rows of n values -> [batch, Lx//n, n, Lx//n, n]
    by_grid_row = by_patch_row.unfold(-2,n,n)
    #reading this layout in order walks the grid row by row
    return by_grid_row.reshape([x.shape[0],Lx*Lx])


# Demo / sanity check: patch a 6x6 grid into 3x3 patches and undo it again.
Lx=6
a = torch.arange(Lx**2).unsqueeze(0)
print(a)
print(a.view([Lx,Lx]))
b = patch2D(a,3,Lx)
print(b)
c = unpatch2D(b,3,Lx)
print(c)
# make the round trip self-checking instead of relying on reading the printout
assert torch.equal(c,a), "unpatch2D(patch2D(x)) should return x"

cuda:0
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]])
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])
tensor([[[ 0,  1,  2,  6,  7,  8, 12, 13, 14],
         [ 3,  4,  5,  9, 10, 11, 15, 16, 17],
         [18, 19, 20, 24, 25, 26, 30, 31, 32],
         [21, 22, 23, 27, 28, 29, 33, 34, 35]]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]])


# The functions below only support patches of size 4 (2x2)

In [2]:
@torch.jit.script
def patch2idx(patch):
    """Encode patches of 4 binary values as integers in [0,16).

    Inputs:
        patch - tensor whose last dimension holds the 4 values (0 or 1) of a patch
    Returns:
        int64 tensor of shape patch.shape[:-1]; element i of a patch contributes 2**i
    """
    #cast before shifting: bit shifts are only well defined on integer tensors
    #(shifting a floating tensor depends on legacy behaviour of the installed torch)
    #then move the last dimension to the front
    bits=patch.to(torch.int64).unsqueeze(0).transpose(-1,0).squeeze(-1)
    out=torch.zeros(bits.shape[1:],dtype=torch.int64,device=bits.device)
    for i in range(4):
        out+=bits[i]<<i
    return out

@torch.jit.script
def patch2onehot(patch):
    """One-hot encode patches of 4 binary values over the 16 possible configurations.

    Inputs:
        patch - tensor whose last dimension holds the 4 values (0 or 1) of a patch
    Returns:
        int64 tensor of shape patch.shape[:-1]+[16] with a single 1 at the index
        sum_i patch[...,i]*2**i
    """
    #cast before shifting: bit shifts are only well defined on integer tensors
    #(shifting a floating tensor depends on legacy behaviour of the installed torch)
    #then move the last dimension to the front
    bits=patch.to(torch.int64).unsqueeze(0).transpose(-1,0).squeeze(-1)
    out=torch.zeros(bits.shape[1:],dtype=torch.int64,device=bits.device)
    for i in range(4):
        out+=bits[i]<<i
    return nn.functional.one_hot(out, num_classes=16)

In [3]:
class PTFRNN(Sampler):#(torch.jit.ScriptModule):
    """
    Patched transformer with an RNN head.

    Architecture-wise this is how the model works:

    You give it a flattened (2D) state and it is cut into px-by-px patches (patch2D). A causally masked
    transformer encoder produces one vector of size Nh per patch. That vector is used as the initial hidden
    state of an RNN which walks over the 2x2 sub-patches (4 values each) inside the patch, and for every
    sub-patch a small MLP ending in a softmax outputs 16 probabilities, one per possible 2x2 configuration.

    Limitations enforced in __init__: rnn_patch must be 4 and the LSTM cell is rejected.
    """
    TYPES={"GRU":nn.GRU,"ELMAN":nn.RNN,"LSTM":nn.LSTM}

    def __init__(self,Lx,px=4,device=device,Nh=128,dropout=0.0,num_layers=2,nhead=8,rnn_patch=4,rnntype="GRU", **kwargs):
        """
        Inputs:
            Lx (int)        - side length of the square system
            px (int)        - side length of a transformer patch (the patch holds px**2 values)
            device          - device on which the masks/buffers are created and the model is placed
            Nh (int)        - transformer model size and RNN hidden size
            dropout (float) - dropout of the transformer encoder layers
            num_layers (int)- number of transformer encoder layers
            nhead (int)     - number of attention heads
            rnn_patch (int) - number of values per RNN step, only 4 is implemented
            rnntype (str)   - key of TYPES, "LSTM" is not supported
            **kwargs        - accepted and ignored
        """
        super(PTFRNN, self).__init__()
        self.pe = PE2D(Nh, Lx,Lx,device)
        self.device=device
        #Encoder only transformer
        #NOTE(author): a misinterpretation of the encoder was said to make this code not work - TODO confirm if still true
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=Nh, nhead=nhead, dropout=dropout)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        assert rnn_patch==4
        #have to do more work to implement other values but I may want to do so later

        #the rest of the class passes a single hidden tensor to the rnn, an LSTM would need an (h,c) pair
        assert rnntype!="LSTM"
        #rnn is batch first and consumes rnn_patch values per step
        #use this class's own table instead of depending on the TYPES attribute of another class
        self.rnn = self.TYPES[rnntype](input_size=rnn_patch,hidden_size=Nh,batch_first=True)


        #maps an rnn output to the probabilities of the 2**rnn_patch sub-patch configurations
        self.lin = nn.Sequential(
                nn.Linear(Nh,Nh),
                nn.ReLU(),
                nn.Linear(Nh,2**rnn_patch),
                nn.Softmax(dim=-1)
            )

        #sequence size in x
        self.Lx=Lx
        #transformer patch size in x
        self.px=px
        #transformer patch size
        self.p=px**2
        #rnn patch size
        self.prnn=rnn_patch

        self.set_mask(Lx**2//self.p)

        #options[k] holds the 4 values of sub-patch configuration k (value i is bit i of k)
        self.options=torch.zeros([16,4],device=self.device)
        tmp=torch.arange(16,device=self.device)
        for i in range(4):
            self.options[:,i]=(tmp>>i)%2


        self.to(device)

    def set_mask(self, L):
        # type: (int) -> None
        """Build the causal attention mask for a sequence of L patches and store L."""
        # take the log of a lower triangular matrix: 0 where attention is allowed, -inf above the diagonal
        self.L=L
        self.mask = torch.log(torch.tril(torch.ones([L,L],device=self.device)))
        self.pe.L=L


    def rnnforward(self,hidden,input):
        """
        Runs the rnn over the 2x2 sub-patches of every transformer patch.

        Inputs:
            hidden - shape [L/p,B,Nh] tensor, used as the initial rnn state of each patch
            input - shape [L/p,B,p] tensor
        Outputs
            out - shape [L/p,B,p/4,Nh] tensor
        """

        Lp,B,Nh=hidden.shape
        h0 = hidden.view([1,Lp*B,Nh])
        input = patch2D(input.reshape([Lp*B,self.p]),2,self.px)
        #each rnn step sees the previous sub-patch and the first step sees zeros
        #the shift is written into a fresh buffer: copying a tensor onto an overlapping
        #slice of itself is not a safe operation and could also modify the caller's data
        shifted=torch.zeros(input.shape,dtype=input.dtype,device=input.device)
        shifted[:,1:]=input[:,:-1]

        out,h=self.rnn(shifted,h0)
        return out.view([Lp,B,self.p//4,Nh])


    def forward(self, input):
        """
        Inputs:
            input - [B,L,1] tensor
        Returns:
            output - [B,L/4,16] tensor of sub-patch probabilities

        NOTE(review): unlike logprobability, the patches given to the transformer are not
        shifted by one position here - confirm whether that is intended.
        """
        if input.shape[1]//self.p!=self.L:
            self.set_mask(input.shape[1]//self.p)

        #[B,L,1] -> [L/p,B,p]
        input=patch2D(input.squeeze(-1),self.px,self.Lx).transpose(1,0)
        #pe should be sequence first [L/p,B,Nh]
        hidden = self.transformer(self.pe(input),self.mask)
        rnnout = self.rnnforward(hidden,input)

        #[L/p,B,p/4,Nh] -> [B,L/4,16]
        #reshape (not view): the transposed tensor is not contiguous when B>1
        output = self.lin(rnnout).transpose(1,0).reshape([rnnout.shape[1],rnnout.shape[0]*rnnout.shape[2],16])
        return output

    def next_with_cache(self,tgt,cache=None,idx=-1):
        # type: (Tensor,Optional[Tensor],int) -> Tuple[Tensor,Tensor]
        """Efficiently calculates the next output of a transformer given the input sequence and
        cached intermediate layer encodings of the input sequence

        Inputs:
            tgt - sequence-first tensor, the callers in this class pass the positionally encoded
                  patches (presumably [L/p,B,Nh] - TODO confirm against PE2D)
            cache - stacked per-layer outputs of earlier positions (layer index first), or None
            idx - index from which to start, -1 means only the last position

        Outputs:
            output - output of the last layer (earlier positions are taken from the cache when one is given)
            new_cache - cache extended along dim 1 with the outputs computed in this call
        """
        output = tgt
        new_token_cache = []
        #go through each layer and apply self attention only to the new inputs
        for i,layer in enumerate(self.transformer.layers):

            tgt=output
            #have to merge the functions into one
            src = tgt[idx:, :, :]
            mask = None if idx==-1 else self.mask[idx:]

            # self attention part
            src2 = layer.self_attn(
                src,#only do attention with the last elem of the sequence
                tgt,
                tgt,
                attn_mask=mask,
                key_padding_mask=None,
            )[0]
            #straight from torch transformer encoder code
            src = src + layer.dropout1(src2)
            src = layer.norm1(src)
            src2 = layer.linear2(layer.dropout(layer.activation(layer.linear1(src))))
            src = src + layer.dropout2(src2)
            src = layer.norm2(src)

            output = src
            new_token_cache.append(output)
            if cache is not None:
                #layers after layer 1 need to use a cache of the previous layer's output on each input
                output = torch.cat([cache[i], output], dim=0)

        #update cache with new output
        if cache is not None:
            new_cache = torch.cat([cache, torch.stack(new_token_cache, dim=0)], dim=1)
        else:
            new_cache = torch.stack(new_token_cache, dim=0)

        return output, new_cache

    def make_cache(self,tgt):
        """Runs every encoder layer on the full sequence with the causal mask.

        Returns the output of the last layer and the stacked outputs of all layers (the cache).
        """
        output = tgt
        new_token_cache = []
        #go through each layer and keep its output
        for i, layer in enumerate(self.transformer.layers):
            output = layer(output,src_mask=self.mask)
            new_token_cache.append(output)
        #create cache with tensor
        new_cache = torch.stack(new_token_cache, dim=0)
        return output, new_cache

    @torch.jit.export
    def sample_with_labelsALT(self,B,L,grad=False,nloops=1):
        # type: (int,int,bool,int) -> Tuple[Tensor,Tensor,Tensor]
        """Like sample_with_labels, but returns the labels as (sample, sumsqrtp, logsqrtp) where
        logsqrtp is half the mean label per sample and sumsqrtp = sum(exp(label/2-logsqrtp))."""
        sample,probs = self.sample_with_labels(B,L,grad,nloops)
        logsqrtp=probs.mean(dim=1)/2
        sumsqrtp = torch.exp(probs/2-logsqrtp.unsqueeze(1)).sum(dim=1)
        return sample,sumsqrtp,logsqrtp

    @torch.jit.export
    def sample_with_labels(self,B,L,grad=False,nloops=1):
        # type: (int,int,bool,int) -> Tuple[Tensor,Tensor]
        """Draws B samples of length L and labels all of their single-flip neighbours
        (see _off_diag_labels, nloops is passed as its D argument)."""
        sample=self.sample(B,L,None)
        return self._off_diag_labels(sample,B,L,grad,nloops)

    @torch.jit.export
    def sample(self,B,L,cache=None):
        # type: (int,int,Optional[Tensor]) -> Tensor
        """ Generates a set states
        Inputs:
            B (int)            - The number of states to generate in parallel
            L (int)            - The length of generated vectors
        Returns:
            samples - [B,L,1] matrix of zeros and ones for ground/excited states
        """
        return self.sampleDebug(B,L,cache)[0]

    @torch.jit.export
    def sampleDebug(self,B,L,cache=None):
        # type: (int,int,Optional[Tensor]) -> Tuple[Tensor,Tensor]
        """Same as sample, but also returns the log-probability accumulated while sampling.

        Returns:
            samples - [B,L,1] matrix of zeros and ones
            DEBUG   - [B] sum of the log-probabilities of the sampled sub-patches
        """

        #length is divided by the transformer patch size due to patching
        L=L//self.p

        DEBUG=torch.zeros([B],device=self.device)
        #need one extra zero patch at the start for the first prediction hence input is [L+1,B,p]
        input = torch.zeros([L+1,B,self.p],device=self.device)

        with torch.no_grad():
            for idx in range(1,L+1):

                #pe should be sequence first [L,B,Nh]
                encoded_input = self.pe(input[:idx,:,:])

                #Get transformer output
                output,cache = self.next_with_cache(encoded_input,cache)

                #the newest transformer output is the initial rnn state for this patch
                h = output[-1,:,:].unsqueeze(0)

                #rnnseq[:,0] stays zero and is the input of the first rnn step
                rnnseq = torch.zeros([B,self.p//4+1,4],device=self.device)
                for rdx in range(1,self.p//4+1):

                    out,h=self.rnn(rnnseq[:,rdx-1:rdx,:],h)
                    #check out the probability of all 16 vectors
                    probs=self.lin(out).view([B,16])
                    #sample from the probability distribution
                    indices = torch.multinomial(probs,1,False).squeeze(1)
                    #extract samples
                    sample = self.options[indices]
                    #add sample to sequence
                    rnnseq[:,rdx] = sample

                    #debugging info: accumulate the log-probability of what was sampled
                    real=patch2onehot(sample)
                    total = torch.sum(real*probs,dim=-1)
                    DEBUG+=torch.log(total)

                #set input to the (unpatched) rnn sequence
                input[idx] = unpatch2D(rnnseq[:,1:],2,self.px)

        #remove the leading zero in the input
        input=input[1:]
        #go back to batch first and undo the transformer patching
        return unpatch2D(input.transpose(1,0),self.px,self.Lx).unsqueeze(-1),DEBUG


    @torch.jit.export
    def logprobability(self,input):
        # type: (Tensor) -> Tensor
        """Compute the logscale probability of a given state
            Inputs:
                input - [B,L,1] matrix of zeros and ones for ground/excited states
            Returns:
                logp - [B] size vector of logscale probability labels
        """

        if input.shape[1]//self.p!=self.L:
            self.set_mask(input.shape[1]//self.p)

        #shape is modified to [L//p,B,p]
        input = patch2D(input.squeeze(-1),self.px,self.Lx).transpose(1,0)

        #the transformer sees the previous patch at every position (zeros at the first one)
        data=torch.zeros(input.shape,device=self.device)
        data[1:]=input[:-1]

        #[L//p,B,p] -> [L//p,B,Nh]
        encoded=self.pe(data)
        # [L//p,B,Nh] -> [L//p,B,Nh]
        hidden = self.transformer(encoded,self.mask)
        # [L//p,B,Nh],[L//p,B,p] -> [L/p,B,p/4,Nh]
        rnnout = self.rnnforward(hidden,input)

        #[L/p,B,p/4,Nh] -> [B,L/4,16]
        output = self.lin(rnnout).transpose(1,0).reshape([rnnout.shape[1],rnnout.shape[0]*rnnout.shape[2],16])

        #real is going to be a onehot with the index of the appropriate patch set to 1
        #shape will be [B,L//4,16]

        Lp,B,Nh=hidden.shape

        #reshaping the input to match the shape of the output
        #(the sampled and recomputed log-probabilities agree in the notebook check, which supports this layout)
        susman = patch2D(input.reshape([Lp*B,self.p]),2,self.px).view([Lp,B,self.p//4,4]).transpose(1,0)

        real=patch2onehot(susman).reshape([rnnout.shape[1],rnnout.shape[0]*rnnout.shape[2],16])

        #[B,L//4,16] -> [B,L//4]
        total = torch.sum(real*output,dim=-1)
        #[B,L//4] -> [B], the small constant guards against log(0)
        logp=torch.sum(torch.log(total+1e-10),dim=1)
        return logp


    @torch.jit.export
    def _off_diag_labels(self,sample,B,L,grad,D=1):
        # type: (Tensor,int,int,bool,int) -> Tuple[Tensor, Tensor]
        """label all of the flipped states  - set D as high as possible without it slowing down runtime
        Parameters:
            sample - [B,L,1] matrix of zeros and ones for ground/excited states
            B,L (int) - batch size and sequence length
            grad (bool) - not used here, everything is computed under torch.no_grad()
            D (int) - Number of partitions sequence-wise. We must have L%D==0 (D divides L)

        Outputs:

            sample - same as input
            probs - [B,L] matrix of logscale probabilities of the states with one value flipped.
                    Column j*p+j2 belongs to the flip of value j2 of transformer patch j, i.e. the
                    columns follow the patched order and not the original sequence order.
        """

        sample0=sample
        #sample is batch first at the moment
        sample = patch2D(sample.squeeze(-1),self.px,self.Lx)

        sflip = torch.zeros([B,L,L//self.p,self.p],device=self.device)
        #collect all of the flipped states into one array
        for j in range(L//self.p):
            #have to change the order of in which states are flipped for the cache to be useful
            for j2 in range(self.p):
                sflip[:,j*self.p+j2] = sample*1.0
                sflip[:,j*self.p+j2,j,j2] = 1-sflip[:,j*self.p+j2,j,j2]

        #switch sample into sequence-first
        sample = sample.transpose(1,0)

        #compute all of their logscale probabilities
        with torch.no_grad():

            data=torch.zeros(sample.shape,device=self.device)
            data[1:]=sample[:-1]

            #[L//p,B,p] -> [L//p,B,Nh]
            encoded=self.pe(data)

            #make the cache from the positionally encoded, unflipped sample
            out,cache=self.make_cache(encoded)

            probs=torch.zeros([B,L],device=self.device)
            #expand cache to group L//D flipped states
            cache=cache.unsqueeze(2)

            #the cache has to be shaped such that the batch parts line up

            cache=cache.repeat(1,1,L//D,1,1).transpose(2,3).reshape(cache.shape[0],L//self.p,B*L//D,cache.shape[-1])

            rnnout = self.rnnforward(out,sample)
            #[L/p,B,p/4,Nh] -> [B,L/4,16]
            pred0 = self.lin(rnnout).transpose(1,0).reshape([rnnout.shape[1],rnnout.shape[0]*rnnout.shape[2],16])
            #real is going to be a onehot with the index of the appropriate patch set to 1
            #shape will be [B,L//4,16]
            Lp,B,Nh=out.shape
            #reshaping the input to match the shape of the output (same layout as in logprobability)
            susman = patch2D(sample.reshape([Lp*B,self.p]),2,self.px).view([Lp,B,self.p//4,4]).transpose(1,0)
            real=patch2onehot(susman).reshape([rnnout.shape[1],rnnout.shape[0]*rnnout.shape[2],16])

            #[B,L//4,16] -> [B,L//4] probabilities of the sub-patches of the unflipped sample
            total0 = torch.sum(real*pred0,dim=-1)

            for k in range(D):

                N = k*L//D
                #next couple of steps are crucial
                #get the samples from N to N+L//D
                #Note: samples are the same as the original up to the Nth spin
                real = sflip[:,N:(k+1)*L//D]
                #flatten it out and set to sequence first
                tmp = real.reshape([B*L//D,L//self.p,self.p]).transpose(1,0)
                #set up next state prediction
                fsample=torch.zeros(tmp.shape,device=self.device)
                fsample[1:]=tmp[:-1]
                # put sequence before batch so you can use it with your transformer
                tgt=self.pe(fsample)
                #grab your transformer output
                out,_=self.next_with_cache(tgt,cache[:,:N//self.p],N//self.p)
                #only the newly computed parts are necessary
                out = out [N//self.p:]

                rnnout = self.rnnforward(out,tmp[N//self.p:])

                Lp,Bp,Nh=out.shape

                # grab output for the new part
                output = self.lin(rnnout).transpose(1,0).reshape([rnnout.shape[1],rnnout.shape[0]*rnnout.shape[2],16])

                # reshape output separating batch from spin flip grouping
                pred = output.view([B,L//D,(L-N)//4,16])

                susman = patch2D(tmp[N//self.p:].reshape([Lp*Bp,self.p]),2,self.px).view([Lp,Bp,self.p//4,4]).transpose(1,0)

                real=patch2onehot(susman).reshape([rnnout.shape[1],rnnout.shape[0]*rnnout.shape[2],16])

                real = real.view([B,L//D,(L-N)//4,16])

                total = torch.sum(real*pred,dim=-1)
                #sum across the sequence for probabilities

                logp=torch.sum(torch.log(total+1e-10),dim=-1)
                #the part before position N is unchanged by the flips, so reuse the unflipped result
                logp+=torch.sum(torch.log(total0[:,:N//4]+1e-10),dim=-1).unsqueeze(-1)
                probs[:,N:(k+1)*L//D]=logp

        return sample0,probs

In [4]:
# Consistency check on a small 8x8 system: draw 6 samples together with the
# log-probability accumulated during sampling (p0) ...
pb = PTFRNN(8)
sample,p0 = pb.sampleDebug(6,8*8)

print(sample.shape)

# ... and recompute the log-probability of the same samples in one pass (ps).
ps = pb.logprobability(sample)

print(ps.shape)

torch.Size([6, 64, 1])
torch.Size([6])


In [5]:
# ps (recomputed) and p0 (accumulated while sampling) should agree up to float error
print("",ps,'\n',p0)

 tensor([-44.7609, -44.6124, -44.1222, -43.8909, -44.3229, -44.9748],
       device='cuda:0', grad_fn=<SumBackward1>) 
 tensor([-44.7608, -44.6122, -44.1223, -43.8911, -44.3229, -44.9747],
       device='cuda:0')


In [6]:
def get_indices(px,Lx):
    """Return the original site index of every position in the patched ordering.

    The site indices 0..Lx*Lx-1 are laid out on the Lx-by-Lx grid, cut into px-by-px
    patches with patch2D and flattened again, so entry k of the result is the site
    found at position k of the patched sequence.
    """
    site_ids = torch.arange(Lx*Lx,device=device).to(torch.int64).reshape([1,Lx,Lx])
    patched_ids = patch2D(site_ids,px,Lx)
    return patched_ids.reshape(Lx*Lx)

# Disabled check (never runs): compares the parent class's _off_diag_labels, with its
# columns reordered by get_indices(4,8), against the cached implementation in PTFRNN.
# NOTE(review): relies on `plt` being in scope - presumably via the star import; confirm before enabling.
if False:
    B=32

    s = pb.sample(B,8*8)
    # labels from the implementation inherited from the parent class, reordered into patch order
    probs = super(PTFRNN,pb)._off_diag_labels(s,B,8*8,False,D=4)[1][:,get_indices(4,8)]
    
    # labels from PTFRNN's own implementation
    p2 = pb._off_diag_labels(s,B,8*8,False,D=4)[1]

    # mean absolute difference next to the spread of the labels, then means and worst case
    print(abs(probs-p2).mean().item(),torch.var_mean(probs)[0].item()**0.5)
    print(probs.mean(),p2.mean())
    print(abs(probs-p2).max())
    plt.imshow(abs(probs-p2).cpu())

In [7]:
# Training configuration. Opt comes from the star import; the meaning of the individual
# fields is defined there and is not visible in this notebook.
op=Opt()
# side length of the square system; later passed to PTFRNN, whose default patch width is px=4
Lx=24
# total number of sites
op.L=Lx*Lx
# passed to PTFRNN as Nh (transformer model size and rnn hidden size)
op.Nh=128
# passed to the Adam optimizer as the learning rate
op.lr=5e-4
op.M=0.9
op.Q=1
op.K=256
# falsy value selects reg_train instead of queue_train in the training cell
op.USEQUEUE=0
op.kl=0.0
#op.apply(sys.argv[1:])
# derived from K and Q, so it must be set after them
op.B=op.K*op.Q

#op.steps=4000
op.dir="PTFRNN"
#op.steps=100
op.NLOOPS=36
print(op)

L                             			576
Q                             			1
K                             			256
B                             			256
TOL                           			0.15
M                             			0.9
USEQUEUE                      			0
NLOOPS                        			36
hamiltonian                   			Rydberg
steps                         			12000
dir                           			PTFRNN
Nh                            			128
lr                            			0.0005
kl                            			0.0



In [8]:
# Build the TorchScript-compiled model for the configured system size
trainsformer = torch.jit.script(PTFRNN(Lx,Nh=op.Nh))

# Adam with the usual momentum coefficients
beta1,beta2=0.9,0.999
optimizer = torch.optim.Adam(
    trainsformer.parameters(),
    lr=op.lr,
    betas=(beta1,beta2),
)

  "Consider removing it.".format(name))


In [None]:
# Start training with the routine selected by op.USEQUEUE.
if op.USEQUEUE:
    # NOTE(review): `sampleformer` is not defined anywhere in the visible notebook, so this
    # branch would raise a NameError if USEQUEUE were enabled - define it before using the queue.
    queue_train(op,(trainsformer,sampleformer,optimizer))
else:
    print("Training. . .")
    reg_train(op,(trainsformer,optimizer))

Training. . .
Output folder path established
-0.3724 576
1,2.79|636,-0.32|1268,-0.33|1900,-0.33|2561,-0.34|3196,-0.35|3830,-0.37|4464,-0.37|5098,-0.37|5735,-0.37|6371,-0.37|7421,-0.37|8634,-0.37|9869,-0.37|11073,-0.37|12012,-0.37|