In [1]:
import torch
import numpy as np
import wiredOR_dataset
from torch.utils.data import sampler

In [2]:
# Run-wide configuration read by the later cells.
USE_GPU = True         # flip to False to stay on the CPU even if CUDA exists
print_every = 100      # how often (in iterations) to report the training loss
dtype = torch.float32  # every tensor created by the helpers below uses float32

# CUDA is chosen only when it is both requested and actually available;
# in every other case the notebook runs on the CPU.
device = torch.device(
    'cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu'
)

print('using device:', device)

using device: cpu


In [3]:
def toeplitz(x, window):
    """Return all length-`window` sliding slices of `x` along dimension 1.

    This is a thin wrapper around ``Tensor.unfold`` with a step of 1, so
    consecutive slices overlap by ``window - 1`` elements.

    Args:
        x: tensor with at least two dimensions.
        window: number of consecutive elements in each slice.

    Returns:
        Whatever ``x.unfold(1, window, 1)`` returns: dimension 1 holds the
        slice start positions and a new last dimension of size `window`
        holds the elements of each slice.
    """
    dimension, step = 1, 1
    return x.unfold(dimension, window, step)

def signal_window_dot(x, w, b):
    """Contract dimension 2 of `x` with dimension 2 of `w`, then add `b`.

    Mathematically this is ``(x.unsqueeze(3) * w).sum(2) + b``. Evaluating
    that expression literally builds the full broadcast product (one entry
    per element of `x` per output column of `w`) before reducing it. For
    the common layout

        x: (E, T, W)      w: (E or 1, 1, W, O)      b: broadcastable to (E, T, O)

    the same result is a batched matrix product, which never creates that
    temporary, so it is used instead. Any other layout falls back to the
    original broadcast expression so existing callers keep working.

    Args:
        x: input tensor; the contraction runs over its dimension 2.
        w: weight tensor, broadcast against ``x.unsqueeze(3)``.
        b: bias, added to the contracted result with normal broadcasting.

    Returns:
        Tensor equal (up to floating-point summation order) to
        ``(x.unsqueeze(3) * w).sum(2) + b``.
    """
    use_matmul = (
        x.dim() == 3
        and w.dim() == 4
        and w.shape[1] == 1                   # weights shared across dimension 1 of x
        and w.shape[0] in (1, x.shape[0])     # per-row weights, or one shared set
        and w.shape[2] == x.shape[2]          # no broadcasting over the contracted axis
        and x.dtype == w.dtype                # matmul does not type-promote
        and x.is_floating_point()
    )
    if use_matmul:
        # (E, T, W) @ (E|1, W, O) -> (E, T, O)
        return x.matmul(w[:, 0]) + b

    # General path: identical to the original implementation.
    return (x.unsqueeze(3) * w).sum(2) + b

In [4]:
def random_weight(shape):
    """
    Build a trainable weight tensor of the given shape, drawn from a normal
    distribution and scaled by the Kaiming factor sqrt(2 / fan_in).

    fan_in is ``shape[0]`` for a 2-D (fully-connected) shape; for any other
    shape it is the product of every dimension after the first, matching a
    conv weight laid out as [out_channel, in_channel, kH, kW].

    The tensor is created on the module-level `device` with the module-level
    `dtype`, and has ``requires_grad`` switched on so that the backward pass
    accumulates a gradient for it.
    """
    is_fc = len(shape) == 2
    fan_in = shape[0] if is_fc else np.prod(shape[1:])
    scale = np.sqrt(2. / fan_in)

    # Standard-normal samples, rescaled; the product is a fresh leaf tensor.
    weight = torch.randn(shape, device=device, dtype=dtype) * scale
    weight.requires_grad = True
    return weight

def zero_weight(shape):
    """
    Build a trainable all-zero tensor of the given shape on the module-level
    `device` / `dtype` (used below for the bias terms).
    """
    zeros = torch.zeros(shape, device=device, dtype=dtype)
    return zeros.requires_grad_(True)

In [36]:
def decode(x, params):
    """Forward pass of the three-layer windowed decoder.

    Steps:
      1. Slice `x` into overlapping windows along dimension 1 with
         `toeplitz`, using the module-level `window` length.
      2. Apply three `signal_window_dot` layers; the first two are followed
         by ``clamp(0)`` (ReLU).
      3. Sum the last layer's output over its final dimension (dimension 2).

    Args:
        x: 2-D signal tensor, or a 3-D tensor whose leading dimension is a
           singleton batch axis (removed here).
        params: sequence ``(w1, b1, w2, b2, w3, b3)``.

    Returns:
        The last layer's output summed over dimension 2.
    """
    # Tensor.squeeze is NOT in-place: the original bare `x.squeeze(0)` threw
    # its result away. Only a 3-D input is squeezed, so a 2-D input whose
    # first dimension happens to be 1 keeps the rank `toeplitz` needs.
    if x.dim() == 3:
        x = x.squeeze(0)
    windows = toeplitz(x, window)

    w1, b1, w2, b2, w3, b3 = params
    hidden = signal_window_dot(windows, w1, b1).clamp(0)
    hidden = signal_window_dot(hidden, w2, b2).clamp(0)

    # The original applied a bare `.squeeze()` before `.sum(2)`. Whenever a
    # dimension really was 1 that dropped the rank below 3 and made `.sum(2)`
    # fail; otherwise it did nothing. It is therefore omitted.
    return signal_window_dot(hidden, w3, b3).sum(2)


In [6]:
# `window` is the sliding-window length used by decode(); `chunk_size` is the
# per-item sample count that train_decoder() splits into minibatches. Both
# are also handed to the dataset constructor.
window = 31
chunk_size = 100000

# NOTE(review): absolute, machine-specific path - adjust before running
# this notebook anywhere else.
dset = wiredOR_dataset.WiredORDataset(
    'C:/Users/jbrown/Desktop/research/arg/signal_data/10b_1w/subset_0/dataset.h5',
    window,
    chunk_size,
)
num_electrode = dset.shape[0]

# Only dataset items 1, 2 and 3 are used for training; the DataLoader is
# created with its default settings.
loader_train = torch.utils.data.DataLoader(
    torch.utils.data.Subset(dset, [1, 2, 3])
)

In [29]:
# Output widths of the three decoder layers.
out_ch_1, out_ch_2, out_ch_3 = 31, 50, 31

# Every parameter carries a leading dimension of `num_electrode - 1`
# (train_decoder drops row 0 of the data) and a singleton second dimension
# that broadcasts over the window positions.
# NOTE(review): random_weight() derives fan_in from prod(shape[1:]) for these
# 4-D shapes, i.e. in_features * out_features - confirm that init scale is
# what is intended.
weight1 = random_weight((num_electrode - 1, 1, window, out_ch_1))
weight2 = random_weight((num_electrode - 1, 1, out_ch_1, out_ch_2))
weight3 = random_weight((num_electrode - 1, 1, out_ch_2, out_ch_3))

bias1 = zero_weight((num_electrode - 1, 1, out_ch_1))
bias2 = zero_weight((num_electrode - 1, 1, out_ch_2))
bias3 = zero_weight((num_electrode - 1, 1, out_ch_3))

# Order matters: decode() unpacks this list as (w1, b1, w2, b2, w3, b3).
params = [weight1, bias1, weight2, bias2, weight3, bias3]

In [39]:
# Optimisation hyper-parameters. `mb` is the number of samples per minibatch
# that train_decoder() slices out of each chunk.
mb = 1000
lr = 1e-3
wd = 0 #0.9

# A single Adam optimizer over all decoder parameters; train_decoder() uses
# this module-level object directly.
optimizer = torch.optim.Adam(params, lr=lr, weight_decay=wd)
def train_decoder(loader, model_fn, params, learning_rate):
    """
    Run one pass over `loader`, taking an optimizer step per minibatch.

    Each (x, y) item is squeezed, row 0 is dropped, and the remaining rows are
    cut into `chunk_size // mb` consecutive minibatches of `mb` samples. The
    targets are trimmed by `window // 2` samples at both ends of a minibatch
    before being compared with the model output. The loss is the summed
    squared error.

    Inputs:
    - loader: iterable yielding (x, y) pairs
    - model_fn: forward function with the signature scores = model_fn(x, params)
    - params: parameters forwarded to model_fn
    - learning_rate: NOTE(review): accepted but never read - the step size is
      whatever the module-level `optimizer` was constructed with

    Relies on the module-level names `optimizer`, `chunk_size`, `mb`,
    `window`, `device` and `dtype`.

    Returns: Nothing; prints progress, the score shape and the loss value for
    every minibatch.
    """
    for t, (x, y) in enumerate(loader):
        x = x.squeeze()
        y = y.squeeze()
        for step in range(chunk_size // mb):
            print(t, step)

            start = mb * step
            stop = mb * (step + 1)

            # Inputs: all rows but the first, moved to the target device/dtype.
            x_tmp = x[1:, start:stop].to(device=device, dtype=dtype)

            # Targets: same rows, trimmed by half a window on each side.
            half = window // 2
            y_tmp = y[1:, start + half : stop - half].to(device=device, dtype=dtype)

            # Forward pass on cleared gradients.
            optimizer.zero_grad()
            scores = model_fn(x_tmp, params)
            print(scores.shape)
            residual = scores - y_tmp
            loss = residual.pow(2).sum()
            print(loss.item())

            # Backward pass: autograd fills the .grad attribute of every
            # tensor in the graph that has requires_grad=True.
            loss.backward()

            # Parameter update is delegated to the optimizer.
            optimizer.step()

            #if t % print_every == 0:
            #    print('Iteration %d, loss = %.4f' % (t, loss.item()))
            #    print()

In [40]:
# Train the decoder for one pass over the training loader.
train_decoder(
    loader=loader_train,
    model_fn=decode,
    params=params,
    learning_rate=lr,
)

0 0
torch.Size([512, 970])
16814596.0
0 1
torch.Size([512, 970])
18008240.0


KeyboardInterrupt: 

# NOTE(review): the `def` line below is commented out, so everything that
# follows is an orphaned, indented body rather than a callable function, and
# nothing in the visible notebook calls `check_accuracy_part2`. The body
# computes classification accuracy (arg-max over dimension 1 compared with
# integer labels) and reads `loader.dataset.train`; it looks like a leftover
# from a classification example rather than part of the decoder above -
# confirm before re-enabling or deleting.
##def check_accuracy_part2(loader, model_fn, params):
    """
    Check the accuracy of a classification model.
    
    Inputs:
    - loader: A DataLoader for the data split we want to check
    - model_fn: A function that performs the forward pass of the model,
      with the signature scores = model_fn(x, params)
    - params: List of PyTorch Tensors giving parameters of the model
    
    Returns: Nothing, but prints the accuracy of the model
    """
    # Label for the printout only; requires the dataset to expose `.train`.
    split = 'val' if loader.dataset.train else 'test'
    print('Checking accuracy on the %s set' % split)
    num_correct, num_samples = 0, 0
    # No gradients are needed while evaluating.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype= torch.int16)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.int16)
            scores = model_fn(x, params)
            # Predicted class = index of the largest score along dimension 1.
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f%%)' % (num_correct, num_samples, 100 * acc))