# Copying Memory Task

## Overview

In this task, each input sequence has length T+20. The first 10 values are chosen randonly among the digits 1-8, with the rest being all zeros, except for the last 11 entries that are filled with the digits '9' (the first '9' is a delimiter). The goal is to generate an output of same length that is zero everywhere, except the last 10 values after the delimier, where the model is excepted to repeat the 10 values it encountered at the start of the input.

**NOTE**: Because a TCN's receptive fields depends on depth of the network and the filter size, we need to make sure the model we use can cover the sequence length T+20. Using the `seq_len` flag, one can change the # of values to recall(the typical setup is 10).

## Settings

In [1]:
import torch as th
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
import torch.nn.functional as F

BATCH_SIZE = 32
DEVICE = "cpu"
DROPOUT = 0.0
CLIP = 1.0
EPOCHS = 5
KSIZE = 8
ITERS = 100
LEVELS = 8
BLANK_LEN = 10000
SEQ_LEN = 10
LR = 5e-4
OPTIM = "Adam"
NHID = 10
SEED = 1111

T = BLANK_LEN
N_STEPS = T + (2 * SEQ_LEN)
N_CLASSES = 10
N_TRAIN = 10000
N_TEST = 1000

CHANNEL_SIZES = [NHID] * LEVELS

th.manual_seed(SEED)

<torch._C.Generator at 0x7f773c0999b0>

## Data Generation

In [2]:
def data_generator(b_size):
    """Generate data for the copying memory task.
    
    :param T: the total blank time length.
    :param mem_length: the length of the memory to be recalled.
    :param b_size: the batch size.
    """
    seq = th.from_numpy(np.random.randint(1, 9, size=(b_size, SEQ_LEN))).float()
    zeros = th.zeros((b_size, T))
    marker = 9 * th.ones((b_size, SEQ_LEN + 1))
    place_holders = th.zeros((b_size, SEQ_LEN))
    
    x = th.cat((seq, zeros[:, :-1], marker), 1)
    y = th.cat((place_holders, zeros, seq), 1).long()
    
    return x, y

print("Producing data...")
train_x, train_y = data_generator(N_TRAIN)
test_x, test_y = data_generator(N_TEST)

train_x, train_y = train_x.to(DEVICE), train_y.to(DEVICE)
test_x, test_y = test_x.to(DEVICE), test_y.to(DEVICE)
print("Finished.")

Producing data...
Finished.


## Build Model

In [3]:
from core.tcn import TemporalConvNet

class TCN(nn.Module):
    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        super().__init__()
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)
    
    def forward(self, x):
        y1 = self.tcn(x)
        return self.linear(y1.transpose(1, 2))



print("Building model...")
model = TCN(1, N_CLASSES, CHANNEL_SIZES, KSIZE, dropout=DROPOUT)
model = model.to(DEVICE)

optimizer = getattr(th.optim, OPTIM)(model.parameters(), lr=LR)
print("Finished.")

Building model...
Finished.


## Run

In [5]:
def evaluate():
    model.eval()
    with th.no_grad():
        out = model(test_x.unsqueeze(1).contiguous())
        pred = out.view(-1, N_CLASSES).data.max(1, keepdim=True)[1]
        correct = pred.eq(test_y.data.view_as(pred)).cpu().sum()
        counter = out.view(-1, N_CLASSES).size(0)
    print(f'Accuracy: {100. * correct / counter:.4f}')

def train(ep):
    model.train()
    process = tqdm(range(0, N_TRAIN, BATCH_SIZE))
    for batch in process:
        start_ind = batch
        end_ind = start_ind + BATCH_SIZE

        x = train_x[start_ind:end_ind]
        y = train_y[start_ind:end_ind]
        
        optimizer.zero_grad()
        out = model(x.unsqueeze(1).contiguous())
        loss = F.cross_entropy(out.view(-1, N_CLASSES), y.view(-1))
        pred = out.view(-1, N_CLASSES).data.max(1, keepdim=True)[1]
        if CLIP > 0:
            th.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        loss.backward()
        optimizer.step()
        process.set_description(f"Train Epoch: {ep:2d}, loss: {loss.item():.6f}")

for ep in range(1, EPOCHS + 1):
    train(ep)
    evaluate()

  0%|          | 0/313 [00:00<?, ?it/s]

Accuracy: 99.9125


  0%|          | 0/313 [00:00<?, ?it/s]

Accuracy: 99.9124


  0%|          | 0/313 [00:00<?, ?it/s]

Accuracy: 99.9122


  0%|          | 0/313 [00:00<?, ?it/s]

Accuracy: 99.9125


  0%|          | 0/313 [00:00<?, ?it/s]

Accuracy: 99.9124
