In [None]:
import numpy as np
import matplotlib.pyplot as plt
import numba

In [None]:
def randomWalk1D(numberOfCycles, jumpsPerCycle, p=0.5):
    """Simulate independent 1D random walks and collect their statistics.

    Every cycle starts a fresh walker at the origin that performs
    ``jumpsPerCycle`` unit jumps: +1 with probability ``p``, otherwise -1.

    Parameters
    ----------
    numberOfCycles : int
        Number of independent walks (must be positive).
    jumpsPerCycle : int
        Number of jumps per walk.
    p : float, optional
        Probability of a +1 jump. Default 0.5.

    Returns
    -------
    distribution : ndarray, shape (2 * jumpsPerCycle + 1,)
        Probability of each final position; index ``i`` corresponds to
        position ``i - jumpsPerCycle``. Sums to 1.
    msd : ndarray, shape (jumpsPerCycle,)
        Time-averaged mean squared displacement per lag (in jumps), averaged
        over all walks (see ``sampleMSD``); ``msd[0]`` is 0.

    Raises
    ------
    ValueError
        If ``numberOfCycles`` is not positive (the averages would be undefined).
    """
    if numberOfCycles <= 0:
        raise ValueError("numberOfCycles must be positive")

    distribution = np.zeros(2 * jumpsPerCycle + 1)
    msd = np.zeros(jumpsPerCycle)

    for _ in range(numberOfCycles):
        # Draw all jumps of this walk at once; positions[k] is the position
        # after k + 1 jumps.
        steps = np.where(np.random.rand(jumpsPerCycle) < p, 1, -1)
        positions = np.cumsum(steps)
        finalPosition = positions[-1] if jumpsPerCycle > 0 else 0

        # Update MSD and final-position histogram; the shift by
        # jumpsPerCycle maps position -jumpsPerCycle to index 0.
        msd += sampleMSD(positions, jumpsPerCycle)
        distribution[int(finalPosition) + jumpsPerCycle] += 1

    # Each walk contributes exactly one final position, so dividing by the
    # number of walks turns counts into probabilities (sum == 1).
    distribution /= numberOfCycles
    msd /= numberOfCycles
    return distribution, msd

def sampleMSD(positions, jumpsPerCycle):
    """Time-averaged mean squared displacement of one trajectory.

    For each lag ``1 .. jumpsPerCycle - 1`` the squared displacement
    ``(positions[t + lag] - positions[t])**2`` is averaged over every
    available start time ``t``. Entry 0 (zero lag) stays 0.

    This is a post-hoc estimator: it needs the whole trajectory in memory.
    That is fine for a short 1D walk but intractable for large simulations,
    where a (T, N, 3) position history would have to be stored.
    """
    result = np.zeros(jumpsPerCycle)
    for lag in range(1, jumpsPerCycle):
        displacement = positions[lag:] - positions[:-lag]
        result[lag] = np.square(displacement).mean()
    return result

# Unit steps on the square lattice, one row per direction: +x, +y, -x, -y.
latticeVectors = np.array([[1, 0], [0, 1], [-1, 0], [0, -1]], dtype=np.int32)

@numba.njit
def randomWalk2D(
    numberOfCycles: int,
    numberOfParticles: int,
    latticeSize: int,
    maxCorrelationTime: int = 500,
    maxOrigins: int = 50,
    originInterval: int = 50,
    equilibrationFraction: float = 0.25,
):
    """Monte Carlo simulation of a lattice gas with site exclusion.

    ``numberOfParticles`` particles occupy distinct sites of a
    ``latticeSize x latticeSize`` lattice with periodic boundaries. Each cycle
    picks one random particle and one of the four unit steps; the move is
    accepted only if the target site is empty.

    After the first ``equilibrationFraction * numberOfCycles`` cycles, the MSD
    is sampled with multiple time origins: every ``originInterval`` cycles the
    unwrapped positions are stored in a ring buffer of ``maxOrigins`` slots,
    and every cycle the squared displacement from each stored origin is
    accumulated at lag ``cycle - originTime`` if that lag is below
    ``maxCorrelationTime``.

    Parameters
    ----------
    numberOfCycles : int
        Total number of attempted single-particle moves.
    numberOfParticles : int
        Number of particles, between 1 and ``latticeSize**2``.
    latticeSize : int
        Edge length of the square lattice.
    maxCorrelationTime : int, optional
        Number of lags (in cycles) for which the MSD is sampled.
    maxOrigins : int, optional
        Size of the time-origin ring buffer. Lags longer than
        ``maxOrigins * originInterval`` lose their origins to overwriting.
    originInterval : int, optional
        Cycles between successive time origins.
    equilibrationFraction : float, optional
        Fraction of ``numberOfCycles`` skipped before MSD sampling starts.

    Returns
    -------
    msd : ndarray, shape (maxCorrelationTime, 2)
        Per-axis (x, y) squared displacement in lattice units, averaged over
        particles and time origins. Lags that were never sampled stay 0.

    Raises
    ------
    ValueError
        If the particle count does not fit the lattice or a sampling
        parameter is not positive.
    """
    if numberOfParticles <= 0 or numberOfParticles > latticeSize * latticeSize:
        raise ValueError("numberOfParticles must be between 1 and latticeSize**2")
    if maxCorrelationTime <= 0 or maxOrigins <= 0 or originInterval <= 0:
        raise ValueError("maxCorrelationTime, maxOrigins and originInterval must be positive")

    # Select distinct random lattice sites; a flat site index maps to
    # (x, y) = (index % latticeSize, index // latticeSize).
    indices = np.random.choice(latticeSize * latticeSize, size=numberOfParticles, replace=False)
    xPositions = indices % latticeSize
    yPositions = indices // latticeSize

    positions = np.column_stack((xPositions, yPositions))
    # Unwrapped positions are never folded back into the box, so
    # displacements across the periodic boundary are measured correctly.
    unwrappedPositions = positions.copy()

    # Occupation map (1 = occupied). It must reflect the initial placement,
    # otherwise the exclusion check sees an empty lattice and particles can
    # pile up on the same site.
    lattice = np.zeros((latticeSize, latticeSize), dtype=np.int32)
    for n in range(numberOfParticles):
        lattice[xPositions[n], yPositions[n]] = 1

    # MSD accumulators (per axis) and number of samples per lag.
    msd = np.zeros((maxCorrelationTime, 2))
    counts = np.zeros(maxCorrelationTime, dtype=np.int32)
    # Ring buffer of time origins; once full, the oldest slot is overwritten.
    originPositions = np.zeros((maxOrigins, numberOfParticles, 2))
    originTimes = np.zeros(maxOrigins, dtype=np.int32)
    originIndex = 0      # next slot to write
    storedOrigins = 0    # number of slots holding a real origin

    accepted = 0

    for cycle in range(numberOfCycles):
        # Select particle and displacement
        particleIndex = np.random.choice(numberOfParticles)
        dx = latticeVectors[np.random.choice(4)]
        xold = positions[particleIndex, 0]
        yold = positions[particleIndex, 1]

        # Get new position and wrap in box
        newPosition = (positions[particleIndex] + dx) % latticeSize
        xnew = newPosition[0]
        ynew = newPosition[1]

        # Accept only if the target site is empty
        if lattice[xnew, ynew] == 0:
            accepted += 1
            lattice[xold, yold] = 0
            lattice[xnew, ynew] = 1
            positions[particleIndex] = newPosition
            unwrappedPositions[particleIndex] += dx

        if cycle > equilibrationFraction * numberOfCycles:
            # Record a new time origin every originInterval cycles.
            if cycle % originInterval == 0:
                originTimes[originIndex] = cycle
                originPositions[originIndex] = unwrappedPositions
                originIndex = (originIndex + 1) % maxOrigins
                storedOrigins = min(storedOrigins + 1, maxOrigins)

            # Visit only filled slots: an empty slot would act as a bogus
            # origin at cycle 0 with all particles at (0, 0).
            for i in range(storedOrigins):
                timeDifference = cycle - originTimes[i]
                if timeDifference < maxCorrelationTime:
                    counts[timeDifference] += 1
                    msd[timeDifference] += np.sum((unwrappedPositions - originPositions[i]) ** 2, axis=0)

    # The accumulation sums over particles and origins; divide by both to
    # get a per-particle mean squared displacement.
    for t in range(maxCorrelationTime):
        if counts[t] > 0:
            norm = counts[t] * numberOfParticles
            msd[t, 0] /= norm
            msd[t, 1] /= norm

    print("Total accepted:", accepted)
    print("Lattice occupation:", np.sum(lattice))
    return msd

    

In [None]:
# 1D walk parameters. Seeding NumPy's global generator makes the sampled
# distribution and MSD reproducible on Restart & Run All.
np.random.seed(42)

numberOfCycles = 10000
jumpsPerCycle = 100
p = 0.5

walk, msd = randomWalk1D(numberOfCycles, jumpsPerCycle, p)

In [None]:
# Final-position distribution (top) and time-averaged MSD (bottom) of the 1D walk.
fig, ax = plt.subplots(2)
sites = np.arange(-jumpsPerCycle, jumpsPerCycle + 1)
# Every jump is +/-1, so after jumpsPerCycle jumps the walker can only end on
# sites with the same parity as jumpsPerCycle; the other entries of `walk` are
# always 0. Plot only the reachable sites to avoid a zig-zag curve.
ax[0].plot(sites[::2], walk[::2])
ax[0].set_xlim(-jumpsPerCycle, jumpsPerCycle)
ax[0].set_xlabel("final position")
ax[0].set_ylabel("relative frequency")
ax[1].plot(msd)
ax[1].set_xlabel("lag (jumps)")
ax[1].set_ylabel("MSD")
fig.tight_layout()

In [None]:
# 2D lattice-gas parameters: 800 particles on a 40 x 40 periodic lattice
# (half of the sites occupied); each cycle attempts one single-particle move.
latticeSize = 40
numberOfParticles = 800
numberOfCycles = 1_000_000

In [None]:
# Run the 2D lattice gas with the default MSD sampling settings.
msd = randomWalk2D(
    numberOfCycles=numberOfCycles,
    numberOfParticles=numberOfParticles,
    latticeSize=latticeSize,
)

In [None]:
# Per-axis MSD of the 2D lattice gas versus lag; one cycle is one attempted move.
fig, ax = plt.subplots()
ax.plot(msd[:, 0], label="x")
ax.plot(msd[:, 1], label="y")
ax.set_xlabel("lag (cycles)")
ax.set_ylabel("MSD")
ax.legend();

In [None]:
# y component of the 2D MSD on its own figure.
fig, ax = plt.subplots()
ax.plot(msd[:, 1])
ax.set_xlabel("lag (cycles)")
ax.set_ylabel("MSD (y)");