In [1]:
#!/bin/bash

# Create directory for data and models
!mkdir -p data
!mkdir -p models
!mkdir -p figs

# Download data, metadata, and word embeddings
!curl -LO http://cs.umd.edu/~miyyer/data/relationships.csv.gz
!curl -LO http://cs.umd.edu/~miyyer/data/metadata.pkl
!curl -LO http://cs.umd.edu/~miyyer/data/glove.We
!mv relationships.csv.gz data/
!mv metadata.pkl data/
!mv glove.We data/

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100   261  100   261    0     0   2211      0 --:--:-- --:--:-- --:--:--  2211
100   265  100   265    0     0   1409      0 --:--:-- --:--:-- --:--:--  1409
100   244  100   244    0     0    942      0 --:--:-- --:--:-- --:--:--  7393
100   245  100   245    0     0    587      0 --:--:-- --:--:-- --:--:--   587
100 28.7M  100 28.7M    0     0  16.5M      0  0:00:01  0:00:01 --:--:-- 34.1M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100   253  100   253    0     0   2162      0 --:--:-- --:--:-- --:--:--  2162
100   257  100   257    0     0   1374      0 --:--:

In [2]:
import csv, gzip, pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Seed both random number generators used in this notebook (PyTorch and NumPy).
torch.manual_seed(0)
np.random.seed(0)
# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# instead of failing at the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def load_data(span_path, metadata_path):
    """Load relationship spans and vocabulary metadata into padded arrays.

    Args:
        span_path: Path to a gzip-compressed CSV with the columns 'Book',
            'Char 1', 'Char 2' and 'Words' (whitespace-separated tokens).
        metadata_path: Path to a pickle containing the tuple
            (wmap, cmap, bmap). As used here, wmap maps word -> id,
            cmap maps id -> character name, and bmap is a sequence of book
            names whose position is the book id.

    Returns:
        (span_data, max_len, wmap, cmap, bmap), where span_data is an object
        ndarray of shape (num_relationships, 4). Each row holds
        [book, chars, spans, masks]:
            book  - int32 array with the single book id
            chars - int32 array with the two character ids
            spans - int32 array (num_spans, max_len) of zero-padded word ids
            masks - float32 array (num_spans, max_len), 1. on real tokens
        max_len is the longest span length in tokens (-1 if the CSV has no rows).

    Raises:
        KeyError: if a word, book or character is missing from the metadata.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # metadata files from a trusted source.
    with open(metadata_path, 'rb') as meta_file:
        wmap, cmap, bmap = pickle.load(meta_file, encoding='latin1')

    # Reverse lookups: name -> id
    revbmap = {name: idx for idx, name in enumerate(bmap)}
    revcmap = {name: idx for idx, name in cmap.items()}

    # Group spans by (book, char 1, char 2). A tuple key keeps names intact
    # even when they contain the '___' separator.
    max_len = -1
    span_dict = {}
    with gzip.open(span_path, 'rt') as span_file:
        for row in csv.DictReader(span_file):
            text = row['Words'].split()
            max_len = max(max_len, len(text))
            key = (row['Book'], row['Char 1'], row['Char 2'])
            span_dict.setdefault(key, []).append([wmap[w] for w in text])

    # Fill an object array cell by cell: letting NumPy infer the shape from
    # the ragged nested list raises a ValueError on recent NumPy versions.
    span_data = np.empty((len(span_dict), 4), dtype=object)
    for idx, ((book, c1, c2), spans) in enumerate(span_dict.items()):
        book_id = np.array([revbmap[book]], dtype='int32')
        char_ids = np.array([revcmap[c1], revcmap[c2]], dtype='int32')

        # convert spans to zero-padded numpy matrices plus a validity mask
        s = np.zeros((len(spans), max_len), dtype='int32')
        m = np.zeros((len(spans), max_len), dtype='float32')
        for i, curr_span in enumerate(spans):
            s[i, :len(curr_span)] = curr_span
            m[i, :len(curr_span)] = 1.

        span_data[idx, 0] = book_id
        span_data[idx, 1] = char_ids
        span_data[idx, 2] = s
        span_data[idx, 3] = m
    return span_data, max_len, wmap, cmap, bmap

def generate_negative_samples(num_traj, span_size, negs, span_data):
    """Randomly pick `negs` spans (and their masks) to use as negative examples.

    Args:
        num_traj: Exclusive upper bound for the sampled trajectory index.
        span_size: Width (in tokens) of every padded span.
        negs: Number of negative spans to draw.
        span_data: Indexable collection whose entries hold the span matrix at
            position 2 and the matching mask matrix at position 3.

    Returns:
        A pair (words, masks): an int64 tensor and a float32 tensor, both of
        shape (negs, span_size).
    """
    # All trajectory indices are drawn in one call (sampling with replacement);
    # the span inside each trajectory is then drawn individually.
    traj_ids = np.random.randint(0, num_traj, negs)
    sampled_words = np.zeros((negs, span_size), dtype='int32')
    sampled_masks = np.zeros((negs, span_size), dtype='float32')

    for slot, traj_id in enumerate(traj_ids):
        trajectory = span_data[traj_id]
        pick = np.random.randint(0, len(trajectory[2]))
        sampled_words[slot] = trajectory[2][pick]
        sampled_masks[slot] = trajectory[3][pick]

    words_tensor = torch.from_numpy(sampled_words).long()
    masks_tensor = torch.from_numpy(sampled_masks)
    return words_tensor, masks_tensor
    
print('Loading data')
span_data, span_size, wmap, cmap, bmap = load_data('data/relationships.csv.gz', 'data/metadata.pkl')

# NOTE(security): pickle.load can execute arbitrary code; only load an
# embeddings file obtained from a trusted source.
# The context manager closes the file handle once the matrix is in memory.
with open('data/glove.We', 'rb') as we_file:
    We = pickle.load(we_file, encoding='latin1').astype('float32')

# Scale every row of the embedding matrix to unit L2 norm. Rows that are all
# zeros yield NaN from the 0/0 division; nan_to_num turns those back into 0.
norm_We = We / np.linalg.norm(We, axis=1)[:, None]
We = np.nan_to_num(norm_We)

# Output files written by the later logging cells
descriptor_log = 'models/descriptors.log'
trajectory_log = 'models/trajectories.log'

In [None]:
# --- Dimensions -------------------------------------------------------------
d_word = We.shape[1]   # word-embedding width, taken from the loaded matrix
d_char = 50            # character-embedding width
d_book = 50            # book-embedding width
d_hidden = 50          # NOTE(review): not referenced in the visible cells - confirm before removing

# --- Model / sampling settings -----------------------------------------------
num_descs = 30         # number of descriptors (rows of the matrix R)
num_negs = 50          # negative samples drawn per relationship
p_drop = 0.75          # word dropout probability

# --- Optimisation ------------------------------------------------------------
n_epochs = 15
lr = 0.001
lda = 1e-6             # multiplier applied to the orthogonality penalty in the loss

# --- Sizes derived from the loaded data ---------------------------------------
num_chars = len(cmap)
num_books = len(bmap)
num_traj = len(span_data)
len_voc = len(wmap)

# id -> word lookup (inverse of wmap)
revmap = {idx: word for word, idx in wmap.items()}

print(
    'd_word: {}, span_size: {}, num_descs: {}, len_voc: {}, num_chars: {}, num_books: {}, num_traj: {}'
    .format(d_word, span_size, num_descs, len_voc, num_chars, num_books, num_traj)
)

In [None]:
class RMNDataset(Dataset):
    """Dataset exposing one relationship per item.

    Every record in `data` is indexed as [book, chars, spans, masks]: the book
    id is read from book[0], the two character ids from chars[0] and chars[1],
    and spans / masks are NumPy arrays converted with torch.from_numpy.
    """

    def __init__(self, data):
        """Convert each record's fields to tensors once, up front."""
        self.books, self.char1s, self.char2s = [], [], []
        self.spans, self.masks = [], []
        for record in data:
            book, chars = record[0], record[1]
            self.books.append(torch.as_tensor(book[0]))
            self.char1s.append(torch.as_tensor(chars[0]))
            self.char2s.append(torch.as_tensor(chars[1]))
            self.spans.append(torch.from_numpy(record[2]))
            self.masks.append(torch.from_numpy(record[3]))

    def __getitem__(self, index):
        """Return (book, char1, char2, span, mask); all ids are cast to int64."""
        return (
            self.books[index].long(),
            self.char1s[index].long(),
            self.char2s[index].long(),
            self.spans[index].long(),
            self.masks[index],
        )

    def __len__(self):
        """Number of relationships in the dataset."""
        return len(self.books)

In [None]:
# Wrap the loaded relationships in a Dataset and iterate over them in shuffled
# order. batch_size=1 because every relationship has its own number of spans;
# the training loop strips the batch dimension again.
dataset = RMNDataset(span_data)
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
)

In [None]:
class Net(nn.Module):
    """Recurrent model describing a relationship as a mixture of learned descriptors.

    For each span of a (book, char1, char2) relationship the model updates a
    state vector of `num_descs` weights. During training that state is used to
    reconstruct the span embedding as a weighted combination of the rows of
    the descriptor matrix `R`.

    Args:
        We: Pretrained word-embedding matrix (vocabulary x d_word). It is loaded
            with nn.Embedding.from_pretrained, which freezes it by default.
        num_books, num_chars: Sizes of the book and character vocabularies.
        d_book, d_char, d_word: Embedding widths.
        num_descs: Number of descriptors (rows of R).
        balance: Interpolation weight between the newly computed descriptor
            distribution and the previous state.
        p_drop: Word dropout probability used on the training path. When left
            as None, the module-level `p_drop` is read at call time.
    """

    def __init__(self, We, num_books, num_chars, d_book, d_char, d_word, num_descs, balance = 0.5, p_drop = None):
        super().__init__()
        self.balance = balance
        self.p_drop = p_drop
        # embeddings
        self.b_emb = nn.Embedding(num_books, d_book)
        self.c_emb = nn.Embedding(num_chars, d_char)
        self.w_emb = nn.Embedding.from_pretrained(torch.as_tensor(We, dtype=torch.float32))
        # projections of the book / character / span embeddings into d_word space
        self.Wbook = nn.Linear(d_book, d_word)
        self.Wchar = nn.Linear(d_char, d_word)
        self.Wtext = nn.Linear(d_word, d_word)
        # maps from the combined input and from the previous state to descriptor scores
        self.Wd = nn.Linear(d_word, num_descs)
        self.Wd2 = nn.Linear(num_descs, num_descs)
        # descriptor matrix: one d_word-dimensional row per descriptor
        self.R = nn.Parameter(torch.zeros([num_descs, d_word]))

        init = [self.b_emb, self.c_emb, self.Wtext, self.Wchar, self.Wbook, self.Wd, self.Wd2]
        for e in init:
            torch.nn.init.xavier_uniform_(e.weight)
        torch.nn.init.xavier_uniform_(self.R)

    def _embed(self, book, char1, char2, span, span_mask, neg = None, neg_mask = None):
        """Embed the ids and average the word vectors of every span.

        Without `neg`, returns (book, char1, char2, emb_span). With `neg`,
        returns (book, char1, char2, emb_span_drop, emb_span, emb_neg), where
        emb_span_drop averages only the words that survive word dropout.
        """
        book = self.b_emb(book)
        char1 = self.c_emb(char1)
        char2 = self.c_emb(char2)
        # zero out padding positions, then take the masked mean per span;
        # clamp(1) avoids dividing by zero for spans without any token
        span = self.w_emb(span) * span_mask[:,:,None]
        emb_span = span.sum(1).div(span_mask.sum(1, keepdim=True).clamp(1))
        if neg is None: return book, char1, char2, emb_span

        # Word dropout: a position is kept when mask * u > p_drop with
        # u ~ U[0, 1). The noise is created on span_mask's device, so this
        # works on CPU as well as on CUDA.
        drop_prob = self.p_drop if self.p_drop is not None else p_drop
        noise = torch.rand(span_mask.shape, device=span_mask.device)
        drop_mask = (span_mask * noise > drop_prob).float()
        span_drop = span * drop_mask[:,:,None]
        emb_span_drop = span_drop.sum(1).div(drop_mask.sum(1, keepdim=True).clamp(1))
        # Embed negative samples
        neg = self.w_emb(neg) * neg_mask[:,:,None]
        emb_neg = neg.sum(1).div(neg_mask.sum(1, keepdim=True).clamp(1))
        return book, char1, char2, emb_span_drop, emb_span, emb_neg

    def _updateState(self, action, book, c1, c2, state):
        """Return the next descriptor state given one span embedding (`action`)."""
        update = F.relu(self.Wbook(book) + self.Wchar(c1) + self.Wchar(c2) + self.Wtext(action))
        new = F.softmax(self.Wd(update) + self.Wd2(state), dim=0)
        # interpolate between the new distribution and the previous state
        return self.balance * new + (1-self.balance) * state

    def forward(self, book, c1, c2, span, span_mask, neg = None, neg_mask = None):
        """Run the model over all spans of one relationship.

        With `neg`/`neg_mask` (training) returns
        (reconstructions, normalized span embeddings, normalized negative
        embeddings). Without them (evaluation) returns the stacked descriptor
        states, one row per span.
        """
        training_pass = neg is not None
        if training_pass:
            book, c1, c2, span_drop, span, neg = self._embed(book, c1, c2, span, span_mask, neg, neg_mask)
            actions = span_drop
        else:
            # use this method's own arguments; nothing is read from notebook globals
            book, c1, c2, span = self._embed(book, c1, c2, span, span_mask)
            actions = span

        # ret either contains reconstruction vectors (train) or state vectors (eval)
        ret = []
        # initial state: all zeros, sized and placed like the descriptor matrix
        state = torch.zeros(self.R.shape[0], dtype=self.R.dtype, device=self.R.device)
        for action in actions:
            state = self._updateState(action, book, c1, c2, state)
            if training_pass:
                ret.append(F.normalize(torch.matmul(self.R.t(), state), dim=0))
            else:
                ret.append(state)
        if training_pass:
            return torch.stack(ret), F.normalize(span, dim=1), F.normalize(neg, dim=1)
        return torch.stack(ret)
    
def hingeOrthoLoss(recon, truth, negs, R, I, ortho_weight=None):
    """Hinge reconstruction loss plus an orthogonality penalty on the descriptors.

    Args:
        recon: Reconstructed span vectors, one row per span.
        truth: Target span embeddings, same shape as `recon`.
        negs: Negative-sample embeddings, one row per negative.
        R: Descriptor matrix; its rows are L2-normalized before the penalty.
        I: Identity matrix with as many rows as `R`.
        ortho_weight: Multiplier for the orthogonality penalty. When None, the
            module-level `lda` is used.

    Returns:
        Scalar tensor: hinge loss + ortho_weight * penalty.
    """
    if ortho_weight is None:
        ortho_weight = lda

    # similarity of every reconstruction with its own target ...
    correct = torch.sum(recon * truth, dim=1)
    # ... and with every negative sample
    incorrect = torch.matmul(recon, negs.t())
    # margins are summed over the negatives first, then clamped at zero per span
    loss = torch.sum(torch.sum(1. - correct[:, None] + incorrect, dim=1).clamp(0))

    # Squared deviation of the Gram matrix of the normalized descriptors from
    # the identity. Squaring keeps the penalty non-negative; a plain signed sum
    # could be pushed below zero by anti-correlated descriptors.
    norm_R = F.normalize(R)
    ortho_penalty = ((torch.mm(norm_R, norm_R.t()) - I) ** 2).sum()

    return loss + ortho_weight * ortho_penalty

In [None]:
net = Net(We, num_books, num_chars, d_book, d_char, d_word, num_descs).to(device)
I = torch.eye(len(net.R)).to(device)   # identity used by the orthogonality penalty
optimizer = optim.Adam(net.parameters(), lr=lr)

log_every = 1000   # report the mean loss after this many steps

net.train()
for epoch in range(n_epochs):
    step = 0
    cum_loss = 0.
    for book, char1, char2, span, span_mask in loader:
        # don't use batches: the loader yields batches of size 1, so strip that dimension
        book, char1, char2 = book[0].to(device), char1[0].to(device), char2[0].to(device)
        span, span_mask = span[0].to(device), span_mask[0].to(device)

        # fresh negative samples for every relationship
        neg, neg_mask = generate_negative_samples(num_traj, span_size, num_negs, span_data)
        neg, neg_mask = neg.to(device), neg_mask.to(device)

        net.zero_grad()
        recon, emb_words, emb_neg = net(book, char1, char2, span, span_mask, neg, neg_mask)
        loss = hingeOrthoLoss(recon, emb_words, emb_neg, net.R, I)
        loss.backward()
        optimizer.step()

        step += 1
        # .item() extracts a plain Python float, so the running total does not
        # carry autograd history and prints as a number rather than a tensor repr
        cum_loss += loss.item()
        if step % log_every == 0:
            print('Epoch {}, Step {}, Loss {}'.format(epoch, step, cum_loss/log_every))
            cum_loss = 0.

In [None]:
# Switch to evaluation mode and write, for every descriptor (row of R), the
# ten vocabulary words whose embeddings score highest against it.
net.eval()
R = net.R.cpu().detach().numpy()
with open(descriptor_log, 'w+') as log:
    for descriptor in R:
        # unit-length descriptor; We holds the normalized word embeddings
        desc = descriptor / np.linalg.norm(descriptor)
        sims = We.dot(desc.T)
        # word ids sorted by descending similarity, truncated to the top ten
        top_ids = np.argsort(sims)[::-1][:10]
        log.write(' '.join(revmap[w] for w in top_ids) + '\n')

In [None]:
# Write one CSV row per span: the relationship's identifiers, the span index,
# and the descriptor weights the model assigns to that span.
# newline='' is what the csv module expects of the file object; it prevents
# doubled line endings on platforms that translate '\n'.
# torch.no_grad() skips building the autograd graph, which inference does not need.
with open(trajectory_log, 'w+', newline='') as log, torch.no_grad():
    traj_writer = csv.writer(log)
    traj_writer.writerow(['Book', 'Char 1', 'Char 2', 'Span ID'] + ['Topic ' + str(i) for i in range(num_descs)])
    for book, chars, span, span_mask in span_data:
        # human-readable names for the CSV columns
        c1, c2 = [cmap[c] for c in chars]
        bname = bmap[book[0]]

        # model inputs: scalar id tensors plus the span / mask matrices
        book = torch.as_tensor(book[0]).long().to(device)
        char1 = torch.as_tensor(chars[0]).long().to(device)
        char2 = torch.as_tensor(chars[1]).long().to(device)
        span = torch.from_numpy(span).long().to(device)
        span_mask = torch.from_numpy(span_mask).to(device)

        traj = net(book, char1, char2, span, span_mask).cpu().numpy()
        for ind, step in enumerate(traj):
            traj_writer.writerow([bname, c1, c2, ind] + list(step))