In [5]:
import torch
import torch.nn as nn
from collections import OrderedDict
import shutil
import time
import gzip
import os
import json
import numpy as np
from dpp_nets.utils.io import make_embd, make_tensor_dataset, load_tensor_dataset
from dpp_nets.utils.io import data_iterator, load_embd
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
import time
from dpp_nets.my_torch.utilities import pad_tensor


root = '/Users/Max/data/beer_reviews'
data_file = 'reviews.aspect3.train.txt.gz'
embd_file = 'review+wiki.filtered.200.txt.gz'
save_path = os.path.join(root,'pytorch/aspect3_train.pt')
data_path = os.path.join(root, data_file)
embd_path = os.path.join(root, embd_file)


def read_rationales(path):
    """
    This reads the json.annotations file. 
    Creates a list of dictionaries, which holds the 994 reviews for which
    sentence-level annotations are available. 
    """
    data = []
    fopen = gzip.open if path.endswith(".gz") else open
    with fopen(path) as fin:
        for line in fin:
            item = json.loads(line)
            data.append(item)
    return data

In [None]:
from collections import defaultdict
import torch
import torch.nn as nn
from dpp_nets.my_torch.linalg import custom_decomp
from dpp_nets.my_torch.DPP import DPP
from dpp_nets.my_torch.DPP import AllInOne
from dpp_nets.my_torch.utilities import compute_baseline

class DPP_Classifier(nn.Module):
    
    def __init__(self, dtype):
        
        super(DPP_Classifier, self).__init__()
        # Float vs Double
        self.dtype = dtype

        # Network parameters
        self.kernel_in = kernel_in = 400
        self.kernel_h = kernel_h = 1000
        self.kernel_out = kernel_out = 400

        self.pred_in = pred_in = 200 # kernel_in / 2
        self.pred_h = pred_h = 500
        self.pred_h2 = pred_h2 = 200
        self.pred_out = pred_out = 3
        
        # 2-Hidden-Layer Networks 
        self.kernel_net = torch.nn.Sequential(nn.Linear(kernel_in, kernel_h), nn.ELU(),
                                              nn.Linear(kernel_h, kernel_h), nn.ELU(), 
                                              nn.Linear(kernel_h, kernel_out))
        # 3-Hidden-Layer-Networks
        self.pred_net = torch.nn.Sequential(nn.Linear(pred_in, pred_h), nn.ReLU(),
                                             nn.Linear(pred_h, pred_h), nn.ReLU(),
                                             nn.Linear(pred_h, pred_h2), nn.ReLU(),
                                             nn.Linear(pred_h2, pred_out), nn.Sigmoid())
        
        self.kernel_net.type(self.dtype)
        self.pred_net.type(self.dtype)
        
        # Sampling Parameter
        self.alpha_iter = 5

        # Convenience
        self.kernels = []
        self.subsets = None
        self.picks = None
        self.preds = None
        
        self.saved_subsets = None
        self.saved_losses = None # not really necesary
        self.saved_baselines = None # not really necessary
        
    def forward(self, reviews):
        """
        reviews: batch_size x max_set_size x embd_dim = 200
        Output: batch_size x pred_out (the prediction)
        Challenges: Need to resize tensor appropriately and 
        measure length etc. 
        """
        batch_size, max_set_size, embd_dim = reviews.size()
        alpha_iter = self.alpha_iter
        self.saved_subsets = actions = [[] for i in range(batch_size)]
        picks = [[] for i in range(batch_size)]
        
        # Create context
        lengths = reviews.sum(2).abs().sign().sum(1)
        context = (reviews.sum(1) / lengths.expand_as(reviews.sum(1))).expand_as(reviews)
        mask = reviews.sum(2).abs().sign().expand_as(reviews).byte()

        # Mask out zero words
        reviews = reviews.masked_select(mask).view(-1, embd_dim)
        context = context.masked_select(mask).view(-1, embd_dim)

        # Compute batched_kernel
        kernel_input = torch.cat([reviews, context], dim=1)
        kernel_output = self.kernel_net(kernel_input)
        
        # Extract the kernel for each review from batched_kernel
        s = list(lengths.squeeze().cumsum(0).long().data - lengths.squeeze().long().data)
        e = list(lengths.squeeze().cumsum(0).long().data)

        for i, (s, e) in enumerate(zip(s, e)):
            review = reviews[s:e] # original review, without zero words
            kernel = kernel_output[s:e] # corresponding kernel 
            self.kernels.append(kernel.data)
            #vals, vecs = custom_decomp()(kernel)
            for j in range(alpha_iter):
                subset = AllInOne()(kernel)
                #subset = DPP()(vals, vecs)
                actions[i].append(subset)
                pick = subset.diag().mm(review).sum(0)
                picks[i].append(pick)

        # Predictions
        picks = torch.stack([torch.stack(pick) for pick in picks]).view(-1, embd_dim)
        preds = self.pred_net(picks).view(batch_size, alpha_iter, -1)
        
        return preds

def register_rewards(preds, targets, criterion, net):
    
    #targets = targets.unsqueeze(1).unsqueeze(1).expand_as(preds)
    targets = targets.unsqueeze(1).expand_as(preds)
    loss = criterion(preds, targets)
    
    actions = net.saved_subsets
    
    losses = ((preds - targets)**2).mean(2)
    losses = [[i.data[0] for i in row] for row in losses]
    net.saved_losses = losses # not really necessary
    baselines = [compute_baseline(i) for i in losses]
    net.saved_baselines = baselines # not really necessary
    
    for actions, rewards in zip(actions, baselines):
        for action, reward in zip(actions, rewards):
            action.reinforce(reward)

    return loss

In [None]:
# Useful Support

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    This is good!
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
        
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    lr = lr * (0.1 ** (epoch // 5))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
def train(train_loader, embd, model, criterion, optimizer, epoch, dtype):
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    subset_size = AverageMeter()
    
    target_dim = 3

    end = time.time()
    for i, (review, target) in enumerate(train_loader):
        
        # measure data loading time
        data_time.update(time.time() - end)
        
        targets = Variable(target[:,:target_dim].type(dtype))
        reviews = embd(Variable(review)).type(dtype)

        # compute output
        model.alpha_iter = 2
        pred = model(reviews)
        loss = register_rewards(pred, targets, criterion, model)

        ##measure accuracy and record loss ????????????????????????
        # prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], reviews.size(0))
        for l in model.saved_subsets:
            for s in l:
                subset_size.update(s.data.sum())
        # top1.update(prec1[0], input.size(0))
        # top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #if i % args.print_freq == 0:
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'SSize {subset_size.val:.2f} ({subset_size.avg: .2f})'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, subset_size = subset_size, loss=losses))

def validate(val_loader, model, criterion):
    
    batch_time = AverageMeter()
    losses = AverageMeter()
    t_prec = AverageMeter()
    t_recall = AverageMeter()
    t_tp = AverageMeter()
    t_fp = AverageMeter()
    t_fn = AverageMeter()
    
    target_dim = 3

    # switch to evaluate mode
    # model.eval()

    end = time.time()
    for i, (review, target) in enumerate(val_loader):
        
        target = target.sum(1).sign().type(dtype).squeeze().byte()
        # targets = target[:,:target_dim,:].type(dtype)
        reviews = embd(Variable(review, volatile=True)).type(dtype)

        # compute output
        model.alpha_iter = 1
        preds = model(reviews)

        subset = model.saved_subsets[0][0]
        subset = pad_tensor(subset.data,0,0,412).byte()
        # target = targets

        # targets = target[:,:target_dim,:].type(dtype)
        reviews = embd(Variable(review, volatile=True)).type(dtype)

        # compute output
        model.alpha_iter = 1
        preds = model(reviews)
        
        subset = model.saved_subsets[0][0]
        subset = pad_tensor(subset.data,0,0,412).byte()
        # target = target[:,:target_dim,:].squeeze()

        retriev = subset.sum()
        relev = target.sum()

        tp = target.masked_select(subset).sum()
        fp = (1 - target.masked_select(subset)).sum()
        fn = (1 - subset.masked_select(target)).sum()
        t_tp.update(tp)
        t_fp.update(fp)
        t_fn.update(fn)

        if retriev: 
            prec = tp / retriev
            t_prec.update(prec)

        if relev: 
            recall = tp / relev
            t_recall.update(recall)

        # measure accuracy and record loss
        #prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        #losses.update(loss.data[0], input.size(0))
        #top1.update(prec1[0], input.size(0))
        #top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Precision {t_prec.val:.4f} ({t_prec.avg:.4f})\t'
                  'Recall {t_recall.val:.4f} ({t_recall.avg:.4f})\t'.format(
                   i, len(val_loader), batch_time=batch_time, t_prec=t_prec, t_recall=t_recall))
            
    return t_prec.avg

In [None]:
### MAIN PROGRAMME


global best_prec1
best_prec1 = 0

# set parameters
lr = 1e-1
momentum = 0.9
weight_decay = 0.
start_epoch = 0
epochs = 1
batch_size = 20
print_freq = 10

data = '/Users/Max/data/beer_reviews/pytorch'
dtype = torch.DoubleTensor

# create model
embd = load_embd('/Users/Max/data/beer_reviews/pytorch/embeddings.pt')
model = DPP_Classifier(torch.DoubleTensor)

# define loss function (criterion) and optimizer
criterion = nn.L1Loss()

optimizer = torch.optim.SGD(model.parameters(), lr,
                            momentum=momentum,
                            weight_decay=weight_decay)

# Data loading code
trainpath = os.path.join(data, 'aspect1_train.pt')
valpath = os.path.join(data, 'aspect1_heldout.pt')
ratpath = os.path.join(data, 'annotated.pt')

train_set = torch.load(trainpath)
val_set = torch.load(valpath)
rat_set = torch.load(ratpath)

rat_train_set = torch.load(os.path.join(data, 'annotated_common.pt'))
#train_loader = DataLoader(train_set, batch_size, shuffle=True)
#val_loader = DataLoader(val_set)
rat_train_loader = DataLoader(rat_train_set, batch_size, shuffle=True)
rat_loader = DataLoader(rat_set)

In [None]:
epochs = 20
criterion = nn.L1Loss()

for epoch in range(start_epoch, epochs):
    adjust_learning_rate(optimizer, epoch)

    # train for one epoch
    train(rat_train_loader, embd, model, criterion, optimizer, epoch, dtype)

    # evaluate on validation set
    prec1 = validate(rat_loader, model, criterion)

# remember best prec@1 and save checkpoint
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
    'optimizer' : optimizer.state_dict(),
}, is_best)

In [None]:
import random
#word_to_ix = make_embd(embd_path, only_index_dict=True)
#ix_to_word = {ix: word for word, ix in word_to_ix.items()}

rat_set, ix_to_word
def sample_words(rat_set, model, ix_to_word):
    
    # Sample a review
    ix = random.randint(0,len(rat_set))

    # Make a prediction
    x = rat_set.data_tensor[ix].unsqueeze(0)
    review = embd(Variable(x, volatile=True)).type(dtype)
    model.alpha_iter = 1
    model(review)

    # What words were selected
    subset = model.saved_subsets[0][0]
    subset = pad_tensor(subset.data,0,0,412).byte()

    # Convert to words
    all_words = [ix_to_word[ix -1] for ix in x.squeeze() if ix > 0]
    filtered_words = [ix_to_word[ix -1] for ix in x.masked_select(subset)]
    print(" ".join(all_words) )
    print("DPP Selection: ", filtered_words)


    ix = random.randint(0,len(rat_set))
    rat_set.data_tensor[ix].unsqueeze(0)

In [None]:
   
def sample_prediction(rat_set, model):
    # Sample a review
    ix = random.randint(0,len(rat_train_set))

    # Make a prediction
    x = rat_train_set.data_tensor[ix].unsqueeze(0)
    target = rat_train_set.target_tensor[ix][:3]
    review = embd(Variable(x, volatile=True)).type(dtype)
    model.alpha_iter = 1
    pred = model(review).data.squeeze()
    print(pred, target)
    return pred, target

In [None]:
pred, target = sample_prediction(rat_set, model)

In [None]:
criterion(Variable(pred), Variable(target))

In [None]:
torch.save(model.pred_net.state_dict(), 'pred_dict25.pt')

In [None]:
import torch

e = 0
for i in range(100):
    v = torch.normal(torch.FloatTensor([1,2,3,4,5]))
    e += v
e / 100

In [None]:
non_lin = torch.sin
torch.sin(v)

In [None]:
non_lin(v)

In [None]:
non_lin

In [None]:
batch_size = 2
set_size = 3
embd_dim = 4
words = torch.randn(batch_size, set_size, embd_dim)

In [None]:
v = torch.normal(torch.FloatTensor([1,2,3,4,5])torch.cos(torch.sin(words).mean(1)).squeeze()

In [None]:
v = torch.normal(torch.FloatTensor([1,2,3,4,5]))
torch.log(v)

In [None]:
import numpy as np
batch_size = 100
n_clusters = 10
set_size = 40
embd_dim = pred_in = 50
dtype = dtype = torch.DoubleTensor
np.random.seed(0)
means = dtype(np.random.randint(-50,50,[n_clusters, int(pred_in)]).astype("float"))

def generate(batch_size):
    """sdf"
    Arguments:
    means: Probs best to make this an attribute of the class, 
    so that repeated training works with the same data distribution.


    """


    # Generate index
    index = torch.cat([torch.arange(0, float(n_clusters)).expand(batch_size, n_clusters).long(), 
                      torch.multinomial(torch.ones(batch_size, n_clusters), set_size - n_clusters, replacement=True)]
                     ,dim=1)
    index = index.t()[torch.randperm(set_size)].t().contiguous()

    # Generate words, context, target
    words = dtype(torch.normal(means.index_select(0,index.view(index.numel()))).view(batch_size, set_size, embd_dim))
    context = dtype(words.sum(1).expand_as(words))

    target = torch.sin(torch.pow(words.abs(),2).mean(1)).squeeze()

    return words, context, target

In [None]:
words, context, target = generate(5)
print(target)

In [None]:
(torch.std(target, dim=0) / torch.mean(target, dim=0)).mean()

In [None]:
target

In [None]:
v1 = torch.randn(2,2)
v2 = torch.randn(2,2)
v3 = torch.randn(2,2)
v4 = torch.randn(2,2)
v5 = torch.randn(2,2)
v6 = torch.randn(2,2)


In [None]:
import torch.nn as nn
nn.MSELoss()

In [None]:
from dpp_nets.my_torch.simulator import SimKDPPDeepSet
import torch
network_params = {'set_size': 40, 'n_clusters': 10}
dtype = torch.DoubleTensor
sim = SimKDPPDeepSet(network_params, dtype)

In [None]:
mod = torch.nn.Sequential(nn.Linear(10,20), nn.ReLU(), nn.Linear(20,10))

In [None]:
for mod in mod.modules():
    print(mod)

In [None]:
A = Variable(torch.randn(10,20))

In [None]:
mod(A)

In [None]:
batch_size = 3
set_size = 4
embd_dim = 5
words = Variable(torch.randn(batch_size, set_size, embd_dim))
print(words)
subset = Variable(torch.ByteTensor([1,0,0,1]),requires_grad=True)
words[1].masked_select(Variable(subset.data.expand_as(words[1].t())).t()).view(-1,embd_dim)

In [11]:
data_path

'/Users/Max/data/beer_reviews/reviews.aspect3.train.txt.gz'

In [10]:
word_to_ix = make_embd(embd_path, only_index_dict=True)

In [74]:
old_dataset = make_tensor_dataset(data_path, word_to_ix)

In [76]:
old_dataset.data_tensor_tensor


 3.0000e+00  5.4000e+01  5.7364e+04  ...   0.0000e+00  0.0000e+00  0.0000e+00
 3.0000e+00  2.7000e+01  4.3000e+01  ...   0.0000e+00  0.0000e+00  0.0000e+00
 2.3000e+01  5.4800e+02  1.0875e+05  ...   0.0000e+00  0.0000e+00  0.0000e+00
                ...                   ⋱                   ...                
 1.3660e+03  2.2170e+03  2.4590e+03  ...   0.0000e+00  0.0000e+00  0.0000e+00
 1.6400e+02  1.6300e+02  7.3000e+01  ...   0.0000e+00  0.0000e+00  0.0000e+00
 8.0000e+00  5.6900e+02  1.6600e+02  ...   0.0000e+00  0.0000e+00  0.0000e+00
[torch.LongTensor of size 70000x1012]

In [107]:
ix = 1324
print(torch.cat([old_dataset.data_tensor[ix, 200:220].unsqueeze(1), dataset.data_tensor[ix, 200:220].unsqueeze(1)],dim=1))


  225   109
   13    98
   56   225
   67    13
   33    56
   15    67
  499    33
    2    15
    0   499
    0     2
    0     0
    0     0
    0     0
    0     0
    0     0
    0     0
    0     0
    0     0
    0     0
    0     0
[torch.LongTensor of size 20x2]



In [105]:
# this is bad as maximum set_size might increase#
from torch.utils.data import TensorDataset
import re

reviews = []
targets = []
max_set_size = 0

for i, (review, target) in enumerate(data_iterator(data_path)):
    review_ix = []
    for word in review:
        if word in word_to_ix:
            ix = word_to_ix[word] + 1
            review_ix.append(ix)
        else:
            candidates = re.split('[;|,-/."]',word)
            for word in candidates:
                if word in word_to_ix:
                    print(i)
                    ix = word_to_ix[word] + 1
                    review_ix.append(ix)

    max_set_size = max(max_set_size, len(review_ix))
    reviews.append(review_ix)
    targets.append(target)

reviews_tensor = []
for review in reviews:
    review = torch.LongTensor(review)
    review = pad_tensor(review, 0, 0, max_set_size)
    reviews_tensor.append(review)

reviews = torch.stack(reviews_tensor)
targets = torch.stack(targets)

dataset = TensorDataset(reviews, targets)

13
13
30
37
37
41
43
46
46
54
54
54
54
55
55
60
60
68
68
69
83
83
85
85
94
94
94
107
107
110
110
112
112
112
115
115
123
123
123
127
127
129
129
136
136
149
149
154
154
154
154
154
154
154
154
158
158
172
172
175
175
175
178
178
186
186
186
197
197
208
208
224
224
235
236
236
236
236
236
245
245
246
246
246
246
246
247
247
247
247
247
262
262
267
267
274
274
275
275
292
292
294
294
294
294
308
308
310
310
313
313
313
313
313
322
322
324
324
336
336
336
341
360
360
362
364
364
371
371
371
371
371
372
372
374
374
375
375
381
381
382
382
386
386
386
386
401
401
406
406
407
407
407
407
413
413
413
413
413
413
413
416
416
416
422
424
424
424
425
425
447
453
454
454
455
455
461
461
465
465
467
467
467
467
467
482
482
485
485
489
489
492
492
495
495
499
499
507
507
510
510
510
515
515
518
518
518
519
519
519
519
519
519
533
533
533
533
533
538
538
541
541
541
542
542
551
551
551
551
556
556
556
556
556
556
556
558
558
558
559
559
559
563
563
570
570
570
570
577
577
580
580
580
580
580
580
580

4097
4100
4100
4105
4105
4105
4105
4105
4105
4105
4105
4105
4106
4106
4111
4111
4111
4111
4111
4111
4115
4115
4120
4120
4123
4123
4123
4128
4131
4131
4132
4132
4133
4133
4135
4135
4135
4135
4135
4147
4147
4161
4161
4161
4162
4162
4171
4171
4171
4171
4171
4171
4175
4177
4177
4177
4177
4177
4177
4180
4180
4180
4180
4195
4195
4204
4204
4210
4210
4210
4224
4224
4224
4224
4224
4224
4224
4224
4224
4224
4236
4236
4240
4240
4243
4243
4246
4246
4246
4246
4246
4251
4251
4251
4251
4251
4251
4252
4252
4256
4256
4259
4259
4259
4259
4262
4262
4262
4262
4262
4263
4263
4263
4268
4268
4268
4268
4282
4282
4282
4282
4282
4286
4286
4288
4288
4298
4298
4301
4301
4303
4309
4309
4309
4309
4309
4312
4312
4320
4320
4320
4340
4355
4355
4355
4355
4357
4357
4361
4361
4361
4361
4362
4362
4365
4365
4370
4370
4370
4374
4374
4374
4378
4378
4381
4381
4384
4384
4384
4386
4386
4387
4387
4391
4391
4393
4393
4393
4401
4401
4402
4402
4402
4412
4419
4419
4438
4438
4438
4440
4440
4447
4447
4448
4448
4450
4455
4455
4455
4455


7676
7676
7682
7682
7689
7689
7689
7689
7689
7689
7689
7689
7694
7694
7694
7700
7700
7703
7703
7703
7703
7703
7703
7703
7703
7708
7708
7708
7708
7708
7708
7708
7708
7708
7714
7714
7718
7718
7722
7722
7731
7731
7742
7742
7750
7750
7762
7762
7762
7762
7762
7767
7767
7767
7767
7782
7782
7784
7784
7785
7785
7785
7785
7785
7785
7785
7785
7785
7785
7790
7790
7790
7790
7793
7807
7807
7811
7811
7813
7813
7816
7816
7816
7816
7816
7816
7822
7822
7837
7837
7837
7849
7849
7854
7854
7859
7859
7871
7871
7871
7871
7872
7872
7872
7872
7872
7880
7880
7884
7884
7891
7891
7893
7893
7897
7897
7897
7900
7900
7900
7900
7900
7900
7902
7902
7911
7911
7911
7911
7926
7926
7929
7929
7930
7930
7936
7936
7936
7936
7936
7936
7936
7936
7938
7938
7939
7949
7949
7952
7952
7952
7955
7955
7959
7959
7964
7964
7973
7973
7982
7987
7987
7988
7988
7993
7993
8004
8004
8004
8009
8009
8009
8012
8012
8018
8022
8022
8027
8027
8028
8032
8032
8041
8041
8041
8045
8045
8047
8047
8047
8047
8057
8057
8062
8062
8062
8062
8062
8073
8073


12289
12289
12291
12291
12291
12291
12295
12295
12313
12319
12320
12320
12324
12324
12324
12324
12329
12329
12329
12329
12335
12335
12335
12335
12347
12347
12347
12355
12355
12355
12355
12356
12356
12356
12357
12357
12359
12359
12362
12362
12365
12365
12366
12366
12366
12370
12370
12375
12375
12375
12375
12375
12379
12379
12384
12384
12384
12384
12384
12387
12387
12397
12397
12398
12398
12402
12402
12410
12410
12410
12424
12424
12424
12426
12426
12426
12426
12427
12427
12430
12430
12430
12434
12434
12437
12438
12438
12438
12438
12438
12438
12440
12440
12442
12442
12449
12449
12449
12455
12455
12471
12471
12471
12478
12478
12479
12479
12503
12503
12505
12505
12505
12519
12519
12520
12520
12522
12522
12530
12530
12534
12534
12543
12543
12546
12546
12546
12546
12547
12547
12551
12551
12551
12557
12557
12557
12572
12572
12574
12574
12582
12582
12582
12600
12600
12607
12607
12607
12607
12607
12612
12612
12612
12615
12615
12615
12615
12615
12615
12616
12616
12634
12634
12634
12682
12682
1268

16106
16106
16106
16117
16117
16118
16118
16129
16129
16129
16132
16132
16132
16132
16133
16133
16150
16150
16164
16164
16168
16168
16168
16168
16168
16178
16178
16185
16185
16185
16185
16185
16186
16196
16196
16200
16200
16200
16200
16205
16205
16205
16211
16220
16220
16220
16228
16228
16236
16236
16237
16237
16243
16243
16243
16256
16256
16259
16259
16267
16267
16267
16267
16267
16267
16272
16272
16272
16272
16279
16279
16279
16279
16283
16290
16292
16292
16292
16292
16293
16293
16293
16301
16307
16307
16307
16307
16314
16317
16317
16318
16318
16318
16318
16327
16327
16327
16332
16332
16339
16355
16355
16362
16362
16362
16362
16362
16362
16362
16362
16363
16363
16377
16377
16377
16377
16377
16378
16378
16378
16382
16382
16383
16383
16393
16393
16394
16394
16394
16398
16398
16398
16400
16400
16400
16404
16404
16414
16414
16414
16419
16423
16423
16424
16424
16442
16442
16445
16445
16445
16447
16447
16447
16447
16447
16447
16448
16448
16451
16451
16455
16457
16457
16468
16468
16468
1646

19068
19068
19068
19072
19072
19073
19073
19073
19073
19079
19079
19079
19088
19088
19088
19092
19092
19092
19092
19093
19093
19094
19094
19106
19106
19114
19114
19115
19115
19116
19116
19124
19124
19124
19124
19130
19130
19130
19130
19134
19134
19134
19134
19134
19134
19135
19139
19139
19140
19140
19163
19163
19163
19163
19163
19163
19163
19163
19174
19174
19181
19181
19184
19184
19188
19188
19188
19188
19188
19192
19200
19200
19202
19202
19202
19202
19206
19206
19210
19210
19211
19211
19211
19211
19211
19224
19224
19224
19224
19224
19224
19224
19224
19224
19224
19224
19224
19228
19228
19231
19231
19231
19245
19245
19257
19257
19263
19263
19263
19264
19264
19264
19264
19267
19267
19275
19275
19278
19278
19278
19278
19278
19278
19278
19278
19278
19285
19285
19285
19285
19285
19285
19285
19285
19287
19287
19287
19287
19316
19316
19320
19320
19321
19321
19321
19321
19321
19334
19345
19345
19360
19360
19367
19367
19367
19367
19367
19367
19375
19375
19383
19383
19391
19393
19393
19393
1940

22716
22716
22723
22723
22725
22727
22727
22738
22738
22738
22738
22738
22738
22738
22740
22740
22740
22740
22754
22754
22756
22756
22761
22761
22762
22762
22762
22762
22763
22763
22763
22763
22763
22763
22763
22763
22763
22763
22763
22763
22766
22766
22769
22769
22781
22781
22782
22782
22782
22782
22784
22784
22784
22784
22789
22789
22800
22800
22800
22800
22800
22800
22802
22805
22805
22805
22821
22821
22821
22821
22825
22825
22836
22836
22837
22837
22837
22837
22837
22840
22840
22843
22843
22843
22854
22854
22854
22854
22856
22856
22872
22872
22874
22874
22874
22883
22883
22884
22884
22890
22890
22894
22894
22896
22896
22897
22897
22902
22902
22913
22913
22913
22913
22917
22917
22917
22917
22920
22920
22921
22921
22923
22923
22923
22923
22925
22925
22948
22948
22960
22960
22974
22974
22974
22974
22974
22974
22974
22974
22974
22974
22974
22974
22974
22974
22984
22984
22984
22985
22985
22985
22985
22989
22989
22989
22990
22990
22992
22992
22997
22997
22997
22998
22998
23001
23001
2300

26273
26273
26295
26295
26295
26295
26313
26313
26313
26313
26313
26314
26314
26314
26314
26320
26320
26320
26320
26326
26326
26328
26328
26328
26329
26329
26329
26329
26331
26331
26331
26331
26331
26331
26331
26331
26345
26345
26359
26359
26359
26359
26364
26364
26364
26364
26364
26364
26364
26368
26368
26377
26377
26383
26383
26383
26391
26391
26395
26395
26397
26397
26397
26399
26399
26399
26399
26399
26399
26399
26414
26414
26414
26414
26430
26430
26434
26434
26434
26434
26437
26437
26437
26437
26438
26438
26461
26461
26461
26461
26461
26461
26465
26465
26465
26473
26473
26475
26475
26475
26495
26495
26496
26496
26496
26497
26497
26497
26497
26498
26498
26498
26498
26498
26499
26499
26499
26499
26503
26503
26509
26509
26509
26511
26511
26516
26516
26516
26516
26516
26516
26516
26521
26521
26521
26526
26526
26526
26532
26532
26542
26542
26542
26542
26545
26545
26559
26559
26562
26562
26563
26563
26564
26575
26575
26577
26577
26577
26577
26577
26589
26589
26589
26591
26591
26591
2659

29983
29983
29983
29983
29986
29986
29995
29995
29995
30004
30004
30013
30013
30033
30033
30035
30035
30035
30035
30035
30035
30035
30035
30035
30035
30043
30043
30045
30045
30045
30047
30047
30049
30049
30051
30051
30066
30066
30069
30069
30070
30070
30079
30079
30079
30079
30079
30079
30079
30079
30079
30079
30079
30079
30079
30079
30082
30082
30085
30085
30085
30085
30088
30088
30088
30088
30088
30088
30095
30095
30096
30096
30096
30096
30100
30104
30104
30104
30104
30104
30104
30106
30106
30109
30109
30120
30120
30127
30127
30134
30136
30136
30138
30138
30138
30143
30143
30147
30147
30147
30156
30156
30156
30163
30163
30168
30168
30168
30168
30172
30172
30174
30174
30174
30184
30184
30185
30185
30188
30188
30188
30192
30192
30192
30192
30192
30192
30194
30194
30194
30196
30196
30196
30206
30206
30206
30216
30216
30216
30216
30219
30219
30225
30225
30225
30231
30231
30237
30237
30237
30242
30242
30248
30254
30254
30259
30259
30264
30264
30264
30264
30271
30271
30275
30275
30275
3027

33335
33339
33339
33341
33341
33341
33349
33349
33351
33351
33352
33352
33355
33355
33356
33356
33356
33364
33364
33366
33366
33366
33366
33366
33379
33379
33379
33388
33388
33388
33389
33389
33398
33403
33403
33404
33405
33405
33405
33409
33409
33409
33409
33410
33410
33410
33415
33415
33421
33421
33421
33428
33428
33431
33431
33435
33435
33436
33436
33436
33436
33436
33437
33437
33438
33438
33442
33442
33449
33449
33457
33457
33466
33466
33471
33471
33471
33475
33475
33482
33482
33487
33487
33487
33493
33493
33507
33507
33507
33514
33514
33515
33515
33515
33515
33515
33515
33521
33521
33521
33521
33521
33521
33523
33523
33523
33526
33528
33533
33533
33533
33540
33540
33546
33546
33546
33548
33548
33549
33549
33550
33550
33555
33555
33559
33559
33560
33560
33560
33560
33560
33566
33566
33571
33571
33574
33574
33574
33596
33596
33596
33605
33605
33605
33605
33609
33609
33614
33614
33614
33615
33615
33618
33618
33621
33621
33623
33623
33634
33634
33643
33643
33645
33645
33645
33645
3364

36963
36963
36969
36969
36969
36969
36969
36969
36970
36970
36973
36973
36978
36978
36983
36983
36987
36987
36997
36997
37014
37014
37023
37023
37048
37048
37059
37059
37065
37065
37065
37066
37066
37071
37071
37071
37071
37071
37071
37071
37077
37077
37081
37082
37082
37082
37086
37086
37087
37087
37107
37107
37107
37117
37117
37118
37124
37124
37124
37124
37126
37126
37126
37135
37135
37137
37137
37137
37143
37143
37161
37161
37165
37165
37165
37165
37165
37165
37169
37169
37169
37169
37171
37171
37179
37179
37179
37179
37179
37179
37179
37179
37185
37185
37198
37198
37202
37202
37204
37204
37223
37223
37227
37227
37227
37236
37236
37248
37248
37249
37259
37259
37259
37259
37260
37261
37261
37261
37261
37274
37274
37274
37274
37274
37274
37274
37280
37280
37283
37283
37291
37291
37292
37292
37292
37292
37294
37294
37294
37298
37298
37298
37298
37298
37298
37299
37299
37299
37299
37303
37303
37303
37306
37306
37309
37309
37309
37312
37312
37312
37317
37317
37317
37317
37317
37317
3732

40428
40428
40428
40429
40429
40432
40432
40432
40437
40437
40437
40437
40442
40442
40442
40449
40449
40449
40452
40452
40456
40456
40466
40466
40466
40466
40466
40471
40472
40472
40472
40472
40480
40480
40481
40481
40484
40484
40484
40484
40484
40499
40499
40499
40502
40505
40505
40507
40507
40510
40510
40515
40515
40524
40524
40524
40524
40524
40524
40524
40524
40525
40525
40525
40534
40534
40535
40535
40535
40546
40550
40551
40551
40551
40552
40552
40565
40565
40565
40565
40569
40569
40570
40570
40585
40585
40585
40585
40585
40585
40585
40585
40586
40586
40586
40586
40586
40586
40586
40586
40588
40588
40588
40588
40591
40591
40594
40596
40596
40598
40598
40598
40598
40601
40601
40610
40610
40618
40618
40623
40623
40629
40629
40629
40629
40631
40631
40635
40635
40635
40637
40637
40651
40651
40661
40661
40667
40667
40677
40677
40677
40688
40688
40691
40691
40692
40692
40692
40692
40692
40692
40692
40698
40701
40701
40726
40726
40726
40726
40726
40738
40738
40738
40741
40741
40744
4074

43459
43459
43472
43472
43478
43478
43482
43482
43485
43485
43486
43486
43486
43488
43488
43497
43497
43503
43503
43503
43503
43508
43508
43514
43514
43521
43521
43522
43522
43529
43529
43533
43533
43533
43533
43533
43534
43534
43535
43535
43535
43535
43535
43537
43537
43537
43538
43538
43538
43538
43540
43540
43542
43542
43542
43544
43544
43544
43544
43544
43544
43545
43545
43555
43555
43556
43556
43557
43557
43557
43557
43557
43562
43562
43562
43562
43562
43562
43562
43568
43568
43568
43569
43569
43575
43575
43575
43575
43575
43575
43575
43581
43581
43597
43597
43597
43597
43600
43600
43602
43602
43603
43603
43603
43603
43603
43604
43604
43604
43606
43606
43606
43611
43611
43613
43613
43613
43613
43613
43624
43624
43629
43649
43649
43649
43649
43650
43650
43664
43664
43679
43679
43681
43681
43688
43688
43700
43700
43700
43700
43705
43705
43712
43712
43712
43712
43712
43712
43716
43716
43716
43720
43720
43720
43720
43722
43722
43723
43723
43723
43729
43729
43747
43747
43747
43747
4374

47048
47061
47061
47062
47062
47062
47068
47068
47068
47068
47070
47070
47070
47070
47071
47071
47073
47073
47075
47075
47085
47085
47085
47085
47085
47085
47086
47086
47086
47086
47086
47086
47093
47093
47096
47096
47096
47102
47104
47104
47108
47108
47108
47108
47108
47108
47108
47119
47119
47122
47122
47126
47126
47154
47154
47154
47154
47155
47163
47173
47173
47177
47177
47182
47182
47195
47195
47208
47208
47209
47209
47209
47212
47212
47219
47219
47227
47227
47233
47233
47234
47234
47234
47234
47234
47234
47234
47234
47239
47239
47244
47244
47246
47246
47248
47248
47248
47248
47254
47254
47254
47255
47256
47256
47256
47256
47266
47266
47275
47275
47289
47289
47297
47297
47302
47302
47302
47302
47302
47304
47304
47318
47318
47318
47325
47325
47328
47328
47343
47343
47343
47349
47349
47361
47361
47361
47361
47372
47372
47376
47376
47392
47392
47393
47393
47393
47393
47411
47411
47411
47411
47428
47428
47431
47431
47431
47434
47434
47443
47443
47444
47444
47453
47453
47453
47453
4745

50845
50845
50845
50849
50849
50855
50855
50855
50857
50857
50857
50857
50857
50857
50857
50857
50865
50865
50874
50874
50893
50893
50893
50895
50895
50895
50895
50903
50903
50905
50905
50907
50907
50912
50912
50918
50927
50927
50935
50935
50935
50936
50936
50938
50938
50946
50946
50950
50958
50958
50959
50959
50976
50976
50981
50984
50986
50993
50993
50999
50999
50999
51002
51002
51003
51003
51005
51005
51005
51008
51008
51009
51009
51009
51013
51013
51014
51014
51018
51020
51020
51020
51020
51020
51022
51022
51037
51037
51040
51042
51042
51050
51050
51053
51053
51053
51055
51055
51058
51058
51058
51058
51058
51059
51060
51060
51067
51067
51067
51067
51069
51069
51071
51071
51071
51071
51071
51071
51076
51080
51080
51081
51081
51084
51084
51084
51089
51090
51090
51090
51092
51092
51092
51092
51096
51103
51103
51114
51114
51115
51115
51117
51117
51119
51119
51121
51121
51121
51121
51121
51121
51130
51130
51130
51136
51136
51148
51148
51148
51150
51150
51162
51162
51162
51162
51169
5116

53717
53717
53731
53731
53731
53731
53731
53732
53732
53743
53743
53743
53743
53757
53757
53757
53757
53757
53768
53768
53768
53769
53769
53769
53769
53769
53769
53769
53790
53790
53792
53792
53794
53794
53794
53802
53802
53803
53807
53807
53807
53807
53807
53814
53814
53814
53814
53821
53821
53824
53824
53824
53824
53824
53824
53824
53824
53824
53824
53832
53832
53832
53834
53834
53836
53836
53837
53837
53837
53840
53840
53840
53841
53841
53844
53849
53849
53849
53849
53849
53852
53852
53860
53860
53860
53860
53869
53869
53871
53871
53871
53873
53873
53875
53875
53875
53875
53888
53888
53888
53908
53908
53914
53914
53914
53916
53916
53923
53923
53923
53923
53935
53935
53935
53937
53937
53944
53944
53944
53944
53944
53944
53944
53947
53947
53947
53947
53949
53949
53951
53951
53951
53951
53951
53951
53951
53952
53952
53955
53955
53958
53959
53959
53959
53959
53959
53960
53960
53960
53960
53960
53960
53964
53964
53970
53970
53970
53970
53970
53975
53975
53975
53975
53980
53980
53986
5398

56755
56760
56760
56768
56768
56768
56776
56776
56778
56778
56778
56778
56818
56818
56818
56818
56818
56818
56818
56820
56820
56822
56822
56822
56822
56822
56822
56822
56824
56824
56825
56825
56829
56829
56832
56832
56832
56832
56839
56839
56855
56855
56855
56855
56855
56855
56855
56855
56858
56858
56858
56858
56858
56858
56858
56858
56858
56858
56860
56860
56862
56862
56862
56869
56869
56875
56875
56876
56876
56880
56880
56882
56882
56883
56883
56884
56884
56898
56898
56900
56903
56903
56905
56905
56908
56908
56908
56908
56910
56910
56920
56920
56923
56923
56929
56929
56932
56932
56937
56937
56937
56939
56939
56942
56942
56942
56942
56947
56954
56954
56960
56960
56965
56965
56977
56977
56980
56980
56980
56980
56980
56981
56981
56981
56981
56981
56981
56981
56982
56982
56983
56983
56986
56986
56990
56990
56994
56994
56996
56996
56996
56996
56999
56999
57005
57005
57007
57007
57015
57015
57018
57021
57021
57021
57021
57022
57022
57040
57040
57040
57040
57040
57040
57043
57043
57046
5704

60044
60044
60049
60049
60056
60056
60056
60056
60056
60056
60056
60056
60056
60056
60056
60056
60056
60056
60057
60057
60057
60057
60057
60057
60060
60060
60060
60060
60063
60063
60063
60063
60063
60069
60069
60077
60077
60088
60088
60095
60096
60096
60096
60096
60096
60096
60096
60104
60104
60104
60104
60106
60106
60106
60106
60116
60116
60116
60116
60116
60120
60120
60120
60128
60128
60128
60128
60129
60129
60129
60129
60129
60132
60132
60134
60134
60145
60145
60157
60157
60157
60160
60160
60160
60160
60185
60185
60185
60190
60190
60192
60192
60197
60197
60200
60200
60200
60200
60200
60200
60200
60200
60200
60209
60209
60209
60211
60211
60214
60214
60225
60225
60225
60228
60228
60236
60236
60236
60236
60241
60241
60259
60259
60259
60259
60259
60259
60259
60259
60259
60259
60259
60263
60263
60263
60263
60263
60263
60263
60263
60263
60263
60263
60263
60270
60270
60270
60270
60273
60273
60273
60274
60274
60274
60275
60275
60282
60282
60282
60282
60282
60282
60284
60284
60284
60284
6028

63782
63782
63782
63782
63782
63782
63782
63782
63782
63782
63782
63784
63790
63790
63790
63790
63790
63793
63793
63804
63804
63806
63806
63807
63807
63807
63807
63807
63807
63807
63810
63810
63819
63819
63819
63819
63819
63822
63822
63834
63834
63834
63841
63841
63842
63842
63842
63842
63842
63855
63859
63859
63859
63859
63859
63868
63868
63870
63870
63870
63875
63875
63875
63875
63878
63878
63880
63880
63880
63880
63880
63880
63880
63889
63889
63889
63889
63900
63900
63902
63902
63913
63913
63919
63919
63919
63922
63922
63922
63924
63924
63927
63927
63928
63928
63933
63933
63933
63933
63936
63936
63939
63939
63948
63948
63948
63948
63948
63948
63949
63949
63949
63949
63963
63963
63963
63963
63966
63966
63975
63978
63980
63980
63980
63987
63987
64014
64014
64014
64014
64014
64014
64023
64023
64037
64037
64058
64058
64064
64064
64064
64065
64065
64065
64065
64067
64067
64067
64070
64070
64072
64072
64072
64072
64072
64072
64074
64074
64082
64082
64082
64082
64082
64082
64082
64102
6411

67883
67883
67883
67883
67896
67896
67896
67896
67896
67896
67907
67907
67908
67908
67910
67910
67911
67911
67911
67929
67929
67934
67934
67939
67939
67941
67947
67947
67958
67958
67959
67959
67959
67977
67981
67981
67981
67989
67989
67989
67989
67989
67989
67989
67989
67993
67993
67993
67993
67994
67994
67994
67995
67995
67995
67995
67995
67998
67998
68001
68001
68003
68003
68011
68011
68012
68012
68012
68012
68012
68012
68014
68014
68015
68015
68016
68016
68016
68024
68024
68026
68026
68027
68027
68032
68032
68037
68037
68039
68039
68046
68046
68046
68046
68046
68047
68057
68057
68057
68057
68058
68058
68062
68062
68069
68069
68071
68071
68075
68075
68075
68076
68076
68078
68078
68080
68080
68080
68082
68082
68094
68094
68100
68103
68110
68110
68110
68124
68124
68128
68128
68128
68128
68128
68128
68128
68128
68128
68128
68128
68128
68128
68128
68132
68132
68133
68133
68140
68140
68149
68156
68156
68169
68169
68169
68169
68170
68170
68171
68171
68175
68175
68180
68180
68180
68180
6818

In [123]:
word = 2
dim = 3
embd.weight[:,dim]

Variable containing:
 0.0000
 0.0955
-0.0327
   ⋮   
-0.0173
-0.0467
 0.0217
[torch.FloatTensor of size 147760]

In [75]:
dataset.data_tensor


 3.0000e+00  5.4000e+01  5.7364e+04  ...   0.0000e+00  0.0000e+00  0.0000e+00
 3.0000e+00  2.7000e+01  4.3000e+01  ...   0.0000e+00  0.0000e+00  0.0000e+00
 2.3000e+01  5.4800e+02  1.0875e+05  ...   0.0000e+00  0.0000e+00  0.0000e+00
                ...                   ⋱                   ...                
 1.3660e+03  2.2170e+03  2.4590e+03  ...   0.0000e+00  0.0000e+00  0.0000e+00
 1.6400e+02  1.6300e+02  7.3000e+01  ...   0.0000e+00  0.0000e+00  0.0000e+00
 8.0000e+00  5.6900e+02  1.6600e+02  ...   0.0000e+00  0.0000e+00  0.0000e+00
[torch.LongTensor of size 70000x1014]

In [1]:
def data_iterator(data_path):
    with gzip.open(data_path, 'rt') as f:
        for line in f:
            target, sep, words = line.partition("\t")
            words, target = words.split(), target.split()
            if len(words):
                target = torch.Tensor([float(v) for v in target])
                yield words, target

In [3]:
def make_tensor_dataset(data_path, word_to_ix, max_set_size=0, save_path=None):
        
    if not max_set_size:
        for (review, target) in data_iterator(data_path):
            review = [(word in word_to_ix) for word in review]
            max_set_size = max(sum(review),max_set_size)
            
    reviews, targets = [], []

    for (review, target) in data_iterator(data_path):
        review = [word_to_ix[word] + 1 for word in review if word in word_to_ix]
        review = torch.LongTensor(review)
        review = pad_tensor(review, 0, 0, max_set_size)
        reviews.append(review)
        targets.append(target)
    
    reviews = torch.stack(reviews)
    targets = torch.stack(targets)

    dataset = TensorDataset(reviews, targets)

    if save_path:
        torch.save(dataset, save_path)    
    else:
        return dataset

In [None]:
root = '/Users/Max/data/beer_reviews'
data_file = 'reviews.aspect3.train.txt.gz'
embd_file = 'review+wiki.filtered.200.txt.gz'
save_path = os.path.join(root,'pytorch/aspect3_train.pt')
data_path = os.path.join(root, data_file)
embd_path = os.path.join(root, embd_file)

embd_path = 

In [160]:
import torch
import torch.nn as nn
batch_size = 20
input_dim = 12
hidden_dim = 18
output_dim = 2
layer1 = nn.Linear(input_dim, hidden_dim)
batch_norm = nn.BatchNorm1d(hidden_dim)
layer2 = nn.Linear(hidden_dim, output_dim)
model = nn.Sequential(layer1, batch_norm, layer2)

In [206]:
layer1.train()
x = Variable(torch.randn(batch_size, input_dim))
#model2 = nn.Sequential(layer1, layer2)
#y1 = model(x)
#y2 = model2(x)
y = layer2(batch_norm(layer1(x)))
batch_norm(layer1(x))

Variable containing:

Columns 0 to 9 
 1.3708  0.7502  0.0854  0.4266  0.0059 -0.1973 -0.7639  1.1644 -0.0341  0.4740
-2.1605 -0.4336 -0.0277  0.0453 -0.0324  0.2395  0.7752 -0.8815 -0.3543  0.0030
 0.0925 -0.2283  0.0528  0.1211  0.0424 -0.0172  0.3966 -0.0855 -0.1579 -0.7209
 1.5900  0.3261 -0.0083 -0.0961  0.0140 -0.3643 -0.7462 -0.1769  0.0418 -0.5172
 1.0453 -0.6380  0.0516 -1.0351  0.0140  0.1042  0.1187  0.1786  0.0123  0.3371
-0.6965  0.5642  0.0196  0.7898 -0.0998  0.2735  0.3917 -0.5570  0.6402  0.7307
-0.1948 -0.3369  0.0310 -0.0530  0.0530  0.0680  0.0928 -0.6833 -0.4211 -0.6430
 0.0567  0.6595  0.0190 -0.5103 -0.0512  0.1156  1.5776  0.5863  0.5003 -0.1485
-0.5632  0.0859 -0.0170  0.0738 -0.0524 -0.0546 -0.0976 -0.2792  0.2557  0.0196
 0.8496  0.6132 -0.0932  0.0716 -0.0030 -0.1303 -0.3261 -0.4732  0.2063  0.0446
 0.5098  0.0896 -0.0007  0.0563 -0.0718 -0.3349 -0.4187  0.3595  0.0789 -0.1252
 0.6247 -0.1313  0.0327 -0.5086  0.0768 -0.2185  0.8786  1.5757  0.0865 -0.7573
 0

In [18]:
import argparse
import os
import shutil

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader

from dpp_nets.utils.io import make_embd, make_tensor_dataset
from dpp_nets.layers.layers import KernelVar, ReinforceSampler, PredNet, ReinforceTrainer


parser = argparse.ArgumentParser(description='REINFORCE VIMCO Trainer')

parser.add_argument('-a', '--aspect', type=str, choices=['aspect1', 'aspect2', 'aspect3', 'all'],
                    help='what is the target?', required=True)

parser.add_argument('-b', '--batch-size', default=50, type=int,
                    metavar='N', help='mini-batch size (default: 50)')
parser.add_argument('--epochs', default=30, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--lr_k', '--learning_rate_k', default=1e-3, type=float,
                    metavar='LRk', help='initial learning rate for kernel net')
parser.add_argument('--lr_p', '--learning_rate_p', default=1e-4, type=float,
                    metavar='LRp', help='initial learning rate for pred net')
parser.add_argument('--reg', type=float, required=True,
                    metavar='reg', help='regularization constant')
parser.add_argument('--reg_mean', type=float, required=True,
                    metavar='reg_mean', help='regularization_mean')
parser.add_argument('--alpha_iter', type=int, required=True,
                    metavar='alpha_iter', help='How many subsets to sample from DPP? At least 2!')

# Pre-training
parser.add_argument('--pretrain_kernel', type=str, default="",
                    metavar='pretrain_kernel', help='Give name of pretrain_kernel')
parser.add_argument('--pretrain_pred', type=str, default="",
                    metavar='pretrain_pred', help='Give name of pretrain_pred')

# Train locally or remotely?
parser.add_argument('--remote', type=int,
                    help='training locally or on cluster?', required=True)

# Burnt in Paths..
parser.add_argument('--data_path_local', type=str, default='/Users/Max/data/beer_reviews',
                    help='where is the data folder locally?')
parser.add_argument('--data_path_remote', type=str, default='/cluster/home/paulusm/data/beer_reviews',
                    help='where is the data folder remotely?')
parser.add_argument('--ckp_path_local', type=str, default='/Users/Max/checkpoints/beer_reviews',
                    help='where is the checkpoints folder locally?')
parser.add_argument('--ckp_path_remote', type=str, default='/cluster/home/paulusm/checkpoints/beer_reviews',
                    help='where is the data folder remotely?')

parser.add_argument('--pretrain_path_local', type=str, default='/Users/Max/checkpoints/beer_reviews',
                    help='where is the pre_trained model? locally')
parser.add_argument('--pretrain_path_remote', type=str, default='/cluster/home/paulusm/pretrain/beer_reviews',
                    help='where is the data folder? remotely')


def train(loader, trainer, optimizer):

    trainer.train()

    for t, (review, target) in enumerate(loader):
        review = Variable(review)

        if args.aspect == 'all':
            target = Variable(target[:,:3]).type(dtype)
        else:
            target = Variable(target[:,int(args.aspect[-1])]).type(dtype)

        loss  = trainer(review, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Weight mean is: ', trainer.kernel_net.layer1.weight.mean())
        print('Weight max is: ', trainer.kernel_net.layer1.weight.max())
        print('Weight min is: ', trainer.kernel_net.layer1.weight.min())
        print('Grad max is: ', trainer.kernel_net.layer1.weight.grad.max())
        print('Grad min is: ', trainer.kernel_net.layer1.weight.grad.min())
        print("trained one batch")

def validate(loader, trainer):
    """
    Note, we keep the sampling as before. 
    i.e what ever alpha_iter is, we take it. 
    """
    trainer.eval()
    total_loss = 0.0
    total_pred_loss = 0.0
    total_reg_loss = 0.0

    for i, (review, target) in enumerate(loader, 1):
        review = Variable(review, volatile=True)

        if args.aspect == 'all':
            target = Variable(target[:,:3], volatile=True).type(dtype)
        else:
            target = Variable(target[:,int(args.aspect[-1])], volatile=True).type(dtype)

        trainer(review, target)
        loss = trainer.loss.data[0]
        pred_loss = trainer.pred_loss.data[0]
        reg_loss = trainer.reg_loss.data[0]

        delta = loss - total_loss
        total_loss += (delta / i)
        delta = pred_loss - total_pred_loss 
        total_pred_loss += (delta / i)
        delta = reg_loss - total_reg_loss
        total_reg_loss += (delta / i)

# print("validated one batch")

    return total_loss, total_pred_loss, total_reg_loss

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR multiplied by factor 0.1 for every 10 epochs"""
    if not ((epoch + 1) % 10):
        factor = 0.1
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * factor

def log(epoch, loss, pred_loss, reg_loss):

    string = str.join(" | ", ['Epoch: %d' % (epoch), 'V Loss: %.5f' % (loss), 
                              'V Pred Loss: %.5f' % (pred_loss), 'V Reg Loss: %.5f' % (reg_loss)])

    if args.remote:
        destination = os.path.join(args.ckp_path_remote, args.aspect + 'reg' + str(args.reg) + 'reg_mean' + str(args.reg_mean) + 
          'alpha_iter' + str(args.alpha_iter) + str(args.pretrain_kernel) + str(args.pretrain_pred) + 'reinforce_log.txt')
    else:
        destination = os.path.join(args.ckp_path_local, args.aspect + 'reg' + str(args.reg) + 'reg_mean' + str(args.reg_mean) + 
          'alpha_iter' + str(args.alpha_iter) + str(args.pretrain_kernel) + str(args.pretrain_pred) + 'reinforce_log.txt')

    with open(destination, 'a') as log:
        log.write(string + '\n')

def save_checkpoint(state, is_best, filename='reinforce_checkpoint.pth.tar'):
    """
    State is a dictionary that cotains valuable information to be saved.
    """
    if args.remote:
        destination = os.path.join(args.ckp_path_remote, args.aspect + 'reg' + str(args.reg) + 'reg_mean' + str(args.reg_mean) + 
          'alpha_iter' + str(args.alpha_iter) + str(args.pretrain_kernel) + str(args.pretrain_pred) + str(args.filename))
    else:
        destination = os.path.join(args.ckp_path_local, args.aspect + 'reg' + str(args.reg) + 'reg_mean' + str(args.reg_mean) + 
          'alpha_iter' + str(args.alpha_iter) + str(args.pretrain_kernel) + str(args.pretrain_pred) + str(args.filename))
    
    torch.save(state, destination)

    if is_best:
        if args.remote:
            best_destination = os.path.join(args.ckp_path_remote, args.aspect + 'reg' + str(args.reg) + 'reg_mean' + str(args.reg_mean) + 
               'alpha_iter' + str(args.alpha_iter) + str(args.pretrain_kernel) + str(args.pretrain_pred) + 'reinforce_best.pth.tar')
        else:
            best_destination = os.path.join(args.ckp_path_local, args.aspect + 'reg' + str(args.reg) + 'reg_mean' + str(args.reg_mean) +  
               'alpha_iter' + str(args.alpha_iter) + str(args.pretrain_kernel) + str(args.pretrain_pred) + 'reinforce_best.pth.tar')
        
        shutil.copyfile(destination, best_destination)

In [19]:
global args, lowest_loss, dtype

args = parser.parse_args("-a aspect3 --remote 0 --reg 0.1 --reg_mean 10 --alpha_iter 4 --lr_k 1e-4".split())
lowest_loss = 100 # arbitrary high number as upper bound for loss
dtype = torch.DoubleTensor

### Load data
if args.remote:
    # print('training remotely')
    train_path = os.path.join(args.data_path_remote, str.join(".",['reviews', args.aspect, 'train.txt.gz']))
    val_path   = os.path.join(args.data_path_remote, str.join(".",['reviews', args.aspect, 'heldout.txt.gz']))
    embd_path = os.path.join(args.data_path_remote, 'review+wiki.filtered.200.txt.gz')

else:
    # print('training locally')
    train_path = os.path.join(args.data_path_local, str.join(".",['reviews', args.aspect, 'train.txt.gz']))
    val_path   = os.path.join(args.data_path_local, str.join(".",['reviews', args.aspect, 'heldout.txt.gz']))
    embd_path = os.path.join(args.data_path_local, 'review+wiki.filtered.200.txt.gz')

embd, word_to_ix = make_embd(embd_path)
train_set = make_tensor_dataset(train_path, word_to_ix)
val_set = make_tensor_dataset(val_path, word_to_ix)
print("loaded data")

torch.manual_seed(0)
train_loader = DataLoader(train_set, args.batch_size, shuffle=True)
val_loader = DataLoader(val_set, args.batch_size)
print("loader defined")

### Build model
# Network parameters
embd_dim = embd.weight.size(1)
kernel_dim = 200
hidden_dim = 500
enc_dim = 200
if args.aspect == 'all':
    target_dim = 3
else: 
    target_dim = 1

# Model
torch.manual_seed(1)


# Add pre-training here...
kernel_net = KernelVar(embd_dim, hidden_dim, kernel_dim)
sampler = ReinforceSampler(args.alpha_iter)
pred_net = PredNet(embd_dim, hidden_dim, enc_dim, target_dim)

if args.pretrain_kernel:
    if args.remote:
        state_dict = torch.load(args.pretrain_path_remote + args.pretrain_kernel)
    else:
        state_dict = torch.load(args.pretrain_path_local + args.pretrain_kernel)
    kernel_net.load_state_dict(state_dict)

if args.pretrain_pred:
    if args.remote:
        state_dict = torch.load(args.pretrain_path_remote + args.pretrain_pred)
    else:
        state_dict = torch.load(args.pretrain_path_local + args.pretrain_pred)
    pred_net.load_state_dict(state_dict)

# continue with trainer
trainer = ReinforceTrainer(embd, kernel_net, sampler, pred_net)
trainer.reg = args.reg
trainer.reg_mean = args.reg_mean
trainer.activation = nn.Sigmoid()
trainer.type(dtype)

print("created trainer")

params = [{'params': trainer.kernel_net.parameters(), 'lr': args.lr_k},
          {'params': trainer.pred_net.parameters(), 'lr': args.lr_p}]
optimizer = torch.optim.Adam(params)
print('set-up optimizer')

### Loop
l = []
torch.manual_seed(0)
print("started loop")
for epoch in range(args.epochs):

    adjust_learning_rate(optimizer, epoch)

    trainer.train()

    for t, (review, target) in enumerate(train_loader):
        review = Variable(review)

        if args.aspect == 'all':
            target = Variable(target[:,:3]).type(dtype)
        else:
            target = Variable(target[:,int(args.aspect[-1])]).type(dtype)

        loss  = trainer(review, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Weight mean is: ', trainer.kernel_net.layer1.weight.mean())
        print('Weight max is: ', trainer.kernel_net.layer1.weight.max())
        print('Weight min is: ', trainer.kernel_net.layer1.weight.min())
        print('Grad max is: ', trainer.kernel_net.layer1.weight.grad.max())
        print('Grad min is: ', trainer.kernel_net.layer1.weight.grad.min())
        print("trained one batch")
        
    loss, pred_loss, reg_loss = validate(val_loader, trainer)

    log(epoch, loss, pred_loss, reg_loss)
    print("logged")

    is_best = pred_loss < lowest_loss
    lowest_loss = min(pred_loss, lowest_loss)    
    save = {'epoch:': epoch + 1, 
            'model': 'Marginal Trainer',
            'state_dict': trainer.state_dict(),
            'lowest_loss': lowest_loss,
            'optimizer': optimizer.state_dict()} 

    save_checkpoint(save, is_best)
    print("saved a checkpoint")

print('*'*20, 'SUCCESS','*'*20)

loaded data
loader defined
created trainer
set-up optimizer
started loop
1
torch.Size([200, 1])
Zero Subset was produced. Re-sample
Weight mean is:  Variable containing:
1.00000e-05 *
 -8.7600
[torch.DoubleTensor of size 1]

Weight max is:  Variable containing:
1.00000e-02 *
  5.0096
[torch.DoubleTensor of size 1]

Weight min is:  Variable containing:
1.00000e-02 *
 -5.0100
[torch.DoubleTensor of size 1]

Grad max is:  Variable containing:
1.00000e-02 *
  5.2791
[torch.DoubleTensor of size 1]

Grad min is:  Variable containing:
1.00000e-02 *
 -4.9328
[torch.DoubleTensor of size 1]

trained one batch
1
torch.Size([200, 1])
Zero Subset was produced. Re-sample
Weight mean is:  Variable containing:
1.00000e-05 *
 -8.7408
[torch.DoubleTensor of size 1]

Weight max is:  Variable containing:
1.00000e-02 *
  5.0190
[torch.DoubleTensor of size 1]

Weight min is:  Variable containing:
1.00000e-02 *
 -5.0198
[torch.DoubleTensor of size 1]

Grad max is:  Variable containing:
1.00000e-02 *
  7.6660

RuntimeError: dimension 0 out of range of 0D tensor at /Users/soumith/miniconda2/conda-bld/pytorch_1493757319118/work/torch/lib/TH/generic/THTensor.c:24

In [25]:
trainer.kernel_net.layer1.weight.min()

RuntimeError: dimension 0 out of range of 0D tensor at /Users/soumith/miniconda2/conda-bld/pytorch_1493757319118/work/torch/lib/TH/generic/THTensor.c:24

NameError: name 'path' is not defined