In [1]:
import json
import pickle

import numpy as np
import torch
from torch import nn, optim

# TODO move to constants file
# Action codes found in each path's per-step action sequence. ACTION_STOP marks
# the steps that Dataset.sample_task_batch treats as task boundaries.
ACTION_STOP = 0
ACTION_GO = 1
# Keys of the special entries looked up in the annotation vocabulary
# (padding, start-of-sequence, end-of-sequence).
TASK_PAD = "*pad*"
TASK_SOS = "*sos*"
TASK_EOS = "*eos*"

# Use the first CUDA device when one is available and fall back to the CPU
# otherwise, so the notebook still runs on machines without a usable GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class Dataset(object):
    """Container for demonstration paths and the annotation vocabulary.

    ``annotations`` and ``actions`` are indexed first by path and then by step;
    ``sample_task_batch`` reads them in parallel. ``xs``, ``ys``, ``zs`` and
    ``mats`` are stored as given and are not used by any method of this class.

    Args:
        annotations: per-path sequences of annotation ids.
        actions: per-path sequences of action codes (compared to ACTION_STOP).
        xs, ys, zs, mats: per-path position data, stored untouched.
        annotation_vocab: mapping that contains at least the TASK_PAD,
            TASK_SOS and TASK_EOS keys; its values are used as integer ids.
    """

    def __init__(self, annotations, actions, xs, ys, zs, mats, annotation_vocab):
        self.annotations = annotations
        self.actions = actions
        self.xs = xs
        self.ys = ys
        self.zs = zs
        self.mats = mats
        self.annotation_vocab = annotation_vocab

        # Vocabulary size; used by the model to size its embedding/output layers.
        self.n_annotations = len(annotation_vocab)

    def sample_task_batch(self, batch_size):
        """Sample ``batch_size`` paths (with replacement) as padded task sequences.

        For every sampled path, the annotation preceding each ACTION_STOP step
        is collected, and the resulting sequence is wrapped in the SOS/EOS ids.

        Returns:
            An int64 ``torch.Tensor`` of shape ``(max_len, batch_size)`` (one
            sequence per column). Short sequences are padded with the
            vocabulary's TASK_PAD id.
        """
        sos_id = self.annotation_vocab[TASK_SOS]
        eos_id = self.annotation_vocab[TASK_EOS]
        pad_id = self.annotation_vocab[TASK_PAD]

        indices = np.random.randint(len(self.annotations), size=batch_size)
        sequences = []
        for i in indices:
            boundaries = [
                j for j, action in enumerate(self.actions[i]) if action == ACTION_STOP
            ]
            # NOTE(review): if a path's very first action is ACTION_STOP then
            # j - 1 == -1 and this reads the path's LAST annotation - confirm
            # that the data never starts with a stop.
            tasks = [self.annotations[i][j - 1] for j in boundaries]
            sequences.append([sos_id] + tasks + [eos_id])

        max_len = max(len(seq) for seq in sequences)
        # Pad with the vocabulary's pad id rather than a literal 0: the loss
        # ignores targets equal to annotation_vocab[TASK_PAD], so any other
        # fill value would be scored as a real token.
        data = np.full((max_len, batch_size), pad_id, dtype=np.int64)
        for column, seq in enumerate(sequences):
            data[:len(seq), column] = seq
        return torch.tensor(data)
        
def make_dataset(path_pkl_data, ann_vocab_json_data):
    """Load pickled paths and a JSON vocabulary and wrap them in a ``Dataset``.

    The pickle must hold an iterable of paths. Each path is a sequence of steps
    that unpack as ``(annotation, action, position)``, and each position
    unpacks as ``(x, y, z, mat)``. Every path is transposed into per-field
    tuples, so the resulting ``Dataset`` holds one tuple per path per field.

    NOTE: ``pickle.load`` can execute arbitrary code while deserialising; only
    point this function at files from a trusted source.

    Args:
        path_pkl_data: path of the pickle file with the paths.
        ann_vocab_json_data: path of the JSON file with the annotation vocab.

    Returns:
        A ``Dataset`` built from the transposed paths and the vocabulary.
    """
    with open(path_pkl_data, "rb") as pkl_file:
        paths = pickle.load(pkl_file)
    with open(ann_vocab_json_data) as vocab_file:
        annotation_vocab = json.load(vocab_file)

    annotations, actions = [], []
    xs, ys, zs, mats = [], [], [], []
    for path in paths:
        # Transpose the list of steps into one tuple per field.
        step_annotations, step_actions, step_positions = zip(*path)
        step_xs, step_ys, step_zs, step_mats = zip(*step_positions)
        collected = (
            (annotations, step_annotations),
            (actions, step_actions),
            (xs, step_xs),
            (ys, step_ys),
            (zs, step_zs),
            (mats, step_mats),
        )
        for bucket, values in collected:
            bucket.append(values)

    return Dataset(annotations, actions, xs, ys, zs, mats, annotation_vocab)

In [3]:
# Width of the annotation embedding vectors (also the LSTM input size in RnnModel).
N_EMB_TASK = 64
# LSTM hidden-state size in RnnModel (also the input width of its output layer).
N_HIDDEN = 100

class RnnModelState(object):
    """Empty placeholder class; it currently defines no attributes."""

    def __init__(self):
        """Create an instance without setting any state."""

class RnnModel(nn.Module):
    """LSTM next-token model over sequences of annotation ids.

    Args:
        dataset: object exposing ``n_annotations`` (vocabulary size) and
            ``annotation_vocab`` (mapping that contains the TASK_PAD key).
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

        self.annotation_emb = nn.Embedding(dataset.n_annotations, N_EMB_TASK)
        self.annotation_rnn = nn.LSTM(N_EMB_TASK, N_HIDDEN)
        self.annotation_out = nn.Linear(N_HIDDEN, dataset.n_annotations)
        # Targets equal to the pad id do not contribute to the loss.
        self.task_loss = nn.CrossEntropyLoss(ignore_index=dataset.annotation_vocab[TASK_PAD])

    @property
    def device(self):
        """Device holding the model's parameters.

        ``nn.Module`` has no ``device`` attribute of its own; this lets callers
        write ``batch.to(model.device)``.
        """
        return self.annotation_emb.weight.device

    def score_task(self, batch):
        """Return the mean next-token cross-entropy for a batch of sequences.

        Args:
            batch: integer tensor of shape ``(seq_len, batch_size)``; row ``t``
                is used as input and row ``t + 1`` as the prediction target.

        Returns:
            A scalar loss tensor.
        """
        ctx = batch[:-1, :]
        # reshape (not view) so non-contiguous batches, e.g. a transposed
        # batch-first tensor, are accepted as well.
        tgt = batch[1:, :].reshape(-1)
        annotation_emb = self.annotation_emb(ctx)
        annotation_hid, _ = self.annotation_rnn(annotation_emb)
        logits = self.annotation_out(annotation_hid).reshape(-1, self.dataset.n_annotations)
        return self.task_loss(logits, tgt)

In [4]:
# Learning rate handed to the Adam optimizer in train_tasks.
INIT_LR = 0.001
# step_size handed to the StepLR scheduler in train_tasks.
LR_STEP = 100
# Number of sequences requested from Dataset.sample_task_batch per training step.
BATCH_SIZE = 100

def train(data_type):
    """Build the dataset and model for ``data_type`` and print parameter shapes.

    Args:
        data_type: either ``"flat"`` or ``"hier"``; selects which pickle under
            ``../dataset/`` is loaded. Any other value fails the assertion.
    """
    assert data_type in ("flat", "hier")
    pkl_path = "../dataset/{}.pkl".format(data_type)
    vocab_path = "../dataset/annotation_vocab.json"
    dataset = make_dataset(pkl_path, vocab_path)

    model = RnnModel(dataset)
    model = model.to(DEVICE)

    shapes = []
    for param in model.parameters():
        shapes.append(param.shape)
    print(shapes)
    # The training stages are currently disabled:
    #   train_tasks(dataset, model)
    #   train_actions(dataset, model)
    
def train_tasks(dataset, model, n_steps=100):
    """Train ``model`` on task sequences sampled from ``dataset``.

    Each step samples a batch of BATCH_SIZE sequences, minimises
    ``model.score_task(batch)`` with Adam, and prints the loss.

    Args:
        dataset: object providing ``sample_task_batch(batch_size)``.
        model: ``nn.Module`` providing ``score_task(batch)``.
        n_steps: number of optimisation steps to run (default 100).
    """
    # nn.Module has no ``device`` attribute, so take the device from the
    # parameters instead of reading ``model.device``.
    device = next(model.parameters()).device
    opt = optim.Adam(model.parameters(), INIT_LR)
    opt_sched = optim.lr_scheduler.StepLR(opt, LR_STEP)
    for _ in range(n_steps):
        batch = dataset.sample_task_batch(BATCH_SIZE).to(device)
        opt.zero_grad()
        loss = model.score_task(batch)
        loss.backward()
        opt.step()
        opt_sched.step()
        print("{:0.3f}".format(loss.item()))
        
def train_actions(model):
    """Placeholder for action training; currently does nothing.

    NOTE(review): the commented-out call in ``train`` passes
    ``(dataset, model)``, which does not match this one-argument signature.
    """
    return None
        

In [5]:
# Build the dataset and model from ../dataset/hier.pkl and print the parameter shapes.
train("hier")

RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/THCTensorRandom.cu:25