In [1]:
from scipy import stats
import numpy as np

In [2]:
# Dict to convert movement idx to actual movement coordinates
idx2mov = {0:np.array([0,0], dtype=int), 
               1:np.array([1,0], dtype=int), 
               2:np.array([-1,0], dtype=int), 
               3:np.array([0,1], dtype=int), 
               4:np.array([0,-1], dtype=int)}

# Convert coordinate to flattened idx
def loc2idx(loc, grid_size=np.array([5, 5], dtype=int)):
    return loc[0]*grid_size[0] + loc[1]

# Convert location flattened idx to coordinate
def idx2loc(idx, grid_size=np.array([5, 5], dtype=int)):
    return np.array([idx // grid_size[0], idx % grid_size[0]], dtype=int)


def sample_n_back_spatial(n, p_stop=0.05, max_length=40, grid_size=np.array([5, 5], dtype=int), boundary='periodic', return_trajectory=False):
    """
    Function to generate a sample for the n-back spatial task.

    Args:
    - n: response delay
    - p_stop: after n steps, probability of stoping walk (default=0.05)
    - max_length: maximum trajectory length (left zero-padding is applied to reach this length)
    - grid_size (array-like): size of gridworld, must be odd (default=[5,5])
    - boundary ['periodic', 'strict']: boundary conditions
    - return_trajectory (bool): whether to return trajectory

    Returns: movements (1D array, as index), n_back_idx (n-back location as idx), (trajectory) 
    
    """
    assert boundary in ['periodic', 'strict'], "boundary must be either 'periodic' or 'strict'"
    assert (grid_size[0] % 2 == 1) & (grid_size[1] % 2 == 1), "grid size must be odd"

    zero = np.array([(grid_size[0]-1)//2, (grid_size[1]-1)//2], dtype=int)
    movements = np.random.randint(5, size=np.minimum((n + stats.nbinom.rvs(1, p_stop)), max_length))
    movements = np.concat(([0]*(max_length - movements.shape[0]), movements))
    
    trajectory = [zero]
    
    for idx in movements:
        if boundary == 'periodic':
            trajectory.append((trajectory[-1] + idx2mov[idx]) % grid_size)
        elif boundary == 'strict':
            trajectory.append(np.clip(trajectory[-1] + idx2mov[idx], a_min=[0,0], a_max=grid_size))
        
    trajectory = np.array(trajectory)

    n_back_idx = loc2idx(trajectory[-(n+1)], grid_size=grid_size)

    if return_trajectory:
        return movements, n_back_idx, trajectory
    else:
        return movements, n_back_idx


In [3]:
from torch.utils.data import Dataset, DataLoader
import torch

class NBackDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)
		
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def create_n_back_dataset(num_samples, n, p_stop=0.05, max_length=40, grid_size=np.array([5, 5], dtype=int), boundary='periodic'):
    X, Y = [], []
    for _ in range(num_samples):
        x, y = sample_n_back_spatial(n, p_stop=p_stop, max_length=max_length, grid_size=grid_size, boundary=boundary)
        X.append(x); Y.append(y)

    X = np.vstack(X)
    Y = np.array(Y)

    X = torch.tensor(X, dtype=int)
    Y = torch.tensor(Y, dtype=int)

    return NBackDataset(X, Y)

In [4]:
data = create_n_back_dataset(100, 3)
data[0][0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4, 2, 3, 1])

In [5]:
import torch.nn.functional as F
from torch import nn

class GRUExplorer(nn.Module):
    
    def __init__(self, hidden_state_size, num_layers=1, grid_size=np.array([5, 5], dtype=int)):

        super(GRUExplorer, self).__init__()
        
        self.hidden_state_size = hidden_state_size
        self.output_size = grid_size[0]*grid_size[1]
        self.num_layers = num_layers

        self.core = nn.GRU(5, self.hidden_state_size, batch_first=True)
        self.head = nn.Linear(self.hidden_state_size, self.output_size)

    def forward(self, X):

        X = F.one_hot(X, num_classes=5).to(torch.float32)
        h0 = torch.zeros(self.num_layers, X.size(0), self.hidden_state_size).to(X.device)
        
        states, _ = self.core(X, h0)
        logits = self.head(states[:, -1, :])
        return torch.softmax(logits, -1)

In [6]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 10 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [7]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
from torch import optim

batch_size = 1000
epochs = 100
train_sample_size = 100_000
test_sample_size = 10_000

n_back = 0

train_dataloader = DataLoader(create_n_back_dataset(train_sample_size, n_back), batch_size=batch_size)
test_dataloader = DataLoader(create_n_back_dataset(test_sample_size, n_back), batch_size=batch_size)

model = GRUExplorer(128, num_layers=2).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
    
print("Done!")

Epoch 1
-------------------------------
loss: 3.219199  [ 1000/100000]
loss: 3.207131  [11000/100000]
loss: 3.160513  [21000/100000]
loss: 3.181263  [31000/100000]
loss: 3.169338  [41000/100000]
loss: 3.197101  [51000/100000]
loss: 3.178111  [61000/100000]
loss: 3.169788  [71000/100000]
loss: 3.160528  [81000/100000]
loss: 3.170960  [91000/100000]
Test Error: 
 Accuracy: 11.6%, Avg loss: 3.168160 

Epoch 2
-------------------------------
loss: 3.163949  [ 1000/100000]
loss: 3.171055  [11000/100000]
loss: 3.162553  [21000/100000]
loss: 3.166800  [31000/100000]
loss: 3.155782  [41000/100000]
loss: 3.172868  [51000/100000]
loss: 3.160067  [61000/100000]
loss: 3.150602  [71000/100000]
loss: 3.153774  [81000/100000]
loss: 3.151710  [91000/100000]
Test Error: 
 Accuracy: 11.5%, Avg loss: 3.153340 

Epoch 3
-------------------------------
loss: 3.148552  [ 1000/100000]
loss: 3.160189  [11000/100000]
loss: 3.153510  [21000/100000]
loss: 3.159347  [31000/100000]
loss: 3.151096  [41000/100000]
l