In [1]:
%matplotlib inline

import os
import time

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from IPython import display

from PIL import Image

In [826]:
def sym(x):
    """Average ``x`` over the four 90-degree rotations of its dims 2 and 3.

    In this notebook it is applied to conv weights (out, in, k, k) so that each
    kernel is unchanged by quarter turns. Only rotations are averaged;
    mirror reflections are not included.
    """
    # Identity first, then quarter turns in the order 2, 1, 3 so the
    # floating-point summation order matches the explicit four-term sum.
    total = x
    for quarter_turns in (2, 1, 3):
        total = total + torch.rot90(x, quarter_turns, dims=(2, 3))
    return total / 4

class Model(nn.Module):
    """Lattice of -1/+1 spins resampled by a fixed, randomly initialised conv stack.

    The couplings J1..J5 are freshly initialised ``nn.Conv2d`` layers (5x5,
    'same' padding) whose kernels are averaged over the four 90-degree
    rotations via ``sym``; nothing in this class trains them. Each call to
    :meth:`step` resamples every spin.

    Args:
        RES: side length of the square spin lattice.
        device: device holding the spins and conv layers (default CUDA).
        dtype: floating dtype of the spins and conv layers (default float16).
    """

    # Weight of each site's current spin in its own update field
    # (added to the conv output before the sigmoid in step()).
    SELF_COUPLING = 0.3
    # Requested normalisation epsilon; _normalize raises it to the dtype's
    # smallest normal number so it cannot round to 0 in float16.
    EPS = 1e-8

    def __init__(self, RES=512, device="cuda", dtype=torch.half):
        super().__init__()
        self.RES = RES

        # Random initial configuration: every site is -1 or +1.
        self.spins = (2 * torch.randint(2, size=(1, 1, RES, RES)) - 1).to(device=device, dtype=dtype)

        # Layers are created in the same order as before so the RNG draws,
        # and therefore the initial weights, are unchanged for a given seed.
        self.J1 = self._sym_conv(1, 16, device, dtype)
        self.J2 = self._sym_conv(16, 16, device, dtype)
        self.J3 = self._sym_conv(16, 16, device, dtype)
        self.J4 = self._sym_conv(16, 16, device, dtype)
        self.J5 = self._sym_conv(16, 1, device, dtype, bias=False)

        # Scale applied to the field inside the sigmoid in step(): larger T
        # makes updates more deterministic, i.e. it acts as an inverse temperature.
        self.T = 10

    @staticmethod
    def _sym_conv(in_ch, out_ch, device, dtype, bias=True):
        """Build a 5x5, padding-2 conv on ``device``/``dtype`` with a rotation-averaged kernel."""
        conv = nn.Conv2d(in_ch, out_ch, 5, padding=2, bias=bias).to(device=device, dtype=dtype)
        with torch.no_grad():
            conv.weight.copy_(sym(conv.weight))
        return conv

    @staticmethod
    def _normalize(z, eps=1e-8):
        """Standardise each (batch, channel) slice of ``z`` over dims 2 and 3.

        Returns ``(z - mean) / (eps + std)`` with the unbiased std.
        """
        # 1e-8 is below float16's smallest subnormal (~6e-8), so in half
        # precision `std + 1e-8` rounds to `std`. A channel that ReLU zeroed
        # everywhere would then give 0/0 = NaN and poison every later layer.
        # Flooring eps at the dtype's smallest normal number fixes fp16 and
        # leaves float32/float64 behaviour (eps = 1e-8) unchanged.
        eps = max(eps, torch.finfo(z.dtype).tiny)
        mu = z.mean(dim=(2, 3), keepdim=True)
        std = z.std(dim=(2, 3), keepdim=True)
        return (z - mu) / (eps + std)

    @torch.no_grad()
    def step(self):
        """Resample every spin once, replacing ``self.spins`` with a new tensor.

        Each site becomes +1 with probability
        ``sigmoid(T * (field + SELF_COUPLING * spin))`` and -1 otherwise, where
        ``field`` is the conv stack applied to the current configuration.
        Runs without autograd: the conv weights require grad, but no gradient
        is ever needed from this update.
        """
        z = self.spins
        for conv in (self.J1, self.J2, self.J3, self.J4):
            z = self._normalize(F.relu(conv(z)), self.EPS)
        field = self.J5(z)

        p = torch.sigmoid(self.T * (field + self.SELF_COUPLING * self.spins))

        # Bernoulli draw per site, mapped from {0, 1} to {-1, +1}.
        self.spins = 2 * (torch.rand_like(self.spins) < p).to(self.spins.dtype) - 1
        

In [827]:
# Build the 512x512 spin lattice and its rotation-symmetrised conv couplings.
model = Model(RES=512)

torch.Size([400])
torch.Size([16])
torch.Size([6400])
torch.Size([16])
torch.Size([6400])
torch.Size([16])
torch.Size([6400])
torch.Size([16])
torch.Size([400])


In [828]:
# Anneal: start the sigmoid scale T at T_START and raise it by T_INCREMENT every
# SAVE_EVERY steps, writing the lattice as a numbered PNG frame each time.
N_STEPS = 4000
SAVE_EVERY = 10
T_START = 0.1
T_INCREMENT = 0.05
OUT_DIR = "output"

# Image.save does not create missing directories.
os.makedirs(OUT_DIR, exist_ok=True)

model.T = T_START
idx = 0

for i in range(N_STEPS):
    model.step()

    if i % SAVE_EVERY == 0:
        # Copy the lattice to the host only when a frame is written;
        # spins in {-1, +1} map to pixel values {0, 255}.
        spins = model.spins.cpu().numpy()[0, 0]
        im = Image.fromarray((255 * (1 + spins) / 2).astype(np.uint8))
        im.save(os.path.join(OUT_DIR, "%.6d.png" % idx))
        idx += 1

        model.T += T_INCREMENT

