In [1]:
import sys, os, time
import numpy as np
from tqdm import tqdm_notebook as tqdm
%matplotlib notebook
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from foundation import util, models, train

In [2]:
# Experiment configuration: every tunable value is stored on this one namespace.
args = util.NS()

# Reproducibility: seed the RNGs this notebook imports before any shuffling,
# splitting, or weight initialisation takes place.
# NOTE(review): if `train.split_dataset` draws from a different RNG (for
# example Python's `random`), that one needs seeding too -- its implementation
# is not visible from this notebook.
args.seed = 0
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Model / hardware
args.dim = 300  # the MLP is built with 2*dim inputs and dim outputs
args.device = 'cpu'

# Data loading
args.batch_size = 128
args.num_workers = 4
args.val_per = 0.1  # the split below keeps 1 - val_per of the data for training

# Optimisation
args.lr = 1e-2
args.weight_decay = 1e-3

In [3]:
class Predict_Relations(Dataset):
    """Dataset of (subject, relation, object) rows backed by a vector lookup table.

    Every sample pairs the concatenated subject and object vectors (the input)
    with the relation vector (the target).
    """

    def __init__(self, path='../../fast_table.pth.tar'):
        """Load the serialized table at ``path`` and index its vectors by element.

        The file has to deserialize to a mapping providing the keys ``'rows'``,
        ``'elements'`` and ``'vecs'``.
        """
        super().__init__()

        # NOTE(review): torch.load unpickles the file, so only use trusted data here.
        payload = torch.load(path)

        self.tuples = payload['rows']
        # elements[i] is looked up to vecs[i]
        self.table = dict(zip(payload['elements'], payload['vecs']))

    def convert(self, word):
        """Return the stored vector for ``word`` as a tensor (raises KeyError if unknown)."""
        vector = self.table[word]
        return torch.from_numpy(vector)

    def __len__(self):
        """Number of rows available."""
        return len(self.tuples)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for row ``idx``: x = [subject ; object], y = relation."""
        head, relation, tail = self.tuples[idx]

        # Concatenate along the last axis so the input holds both entity vectors.
        endpoints = (self.convert(head), self.convert(tail))
        features = torch.cat(endpoints, dim=-1)
        target = self.convert(relation)

        return features, target
        
# Build the full dataset, then split it with `args.val_per` held back from the first part.
dataset = Predict_Relations()
train_fraction = 1 - args.val_per
traindata, valdata = train.split_dataset(dataset, shuffle=True, split1=train_fraction)
(len(traindata), len(valdata))

(254334, 28260)

In [4]:
# Both loaders share worker count and batch size; only the training loader shuffles.
loader_kwargs = dict(num_workers=args.num_workers, batch_size=args.batch_size)
trainloader = DataLoader(traindata, shuffle=True, **loader_kwargs)
valloader = DataLoader(valdata, shuffle=False, **loader_kwargs)
(len(trainloader), len(valloader))

(1987, 221)

In [5]:
# Epoch counter; the training loop increments it after every completed epoch.
args.epoch = 0

# MLP with one 512-unit hidden layer: 2*dim inputs (subject and object vectors
# side by side) mapped to dim outputs (the relation vector).
model = models.make_MLP(2 * args.dim, args.dim, nonlin='prelu', hidden_dims=[512])
model.to(args.device)
print(model)

Sequential(
  (0): Linear(in_features=600, out_features=512, bias=True)
  (1): PReLU(num_parameters=1)
  (2): Linear(in_features=512, out_features=300, bias=True)
)


In [6]:
# Mean-squared error between the predicted and the reference relation vector.
criterion = nn.MSELoss()

# Optimizer selected by name ('adam') over all model parameters, configured from `args`.
optim = util.get_optimizer('adam', model.parameters(), weight_decay=args.weight_decay, lr=args.lr)

In [11]:
def iterate(mode, model, dataloader, print_freq=None):
    """Run one pass over ``dataloader``, updating the model when ``mode == 'train'``.

    Uses the notebook-level ``criterion`` (loss function) and ``optim``
    (optimizer) objects.

    Args:
        mode: ``'train'`` to update the weights; any other value runs an
            evaluation pass (eval mode, gradients disabled).
        model: module that maps a batch of inputs to predictions.
        dataloader: iterable of ``(x, y)`` batches that supports ``len()``.
        print_freq: currently unused; kept so existing calls remain valid.

    Returns:
        The ``util.StatsMeter`` that tracked the ``'loss'`` statistic of this pass.
    """
    training = mode == 'train'
    if training:
        model.train()
    else:
        model.eval()

    # Batches are moved to wherever the model's weights live, so pointing
    # ``args.device`` at something other than the CPU does not produce a
    # device mismatch. A parameter-free model leaves the batches untouched.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    stats = util.StatsMeter()
    stats.new('loss')

    loader = tqdm(enumerate(dataloader), total=len(dataloader))

    torch.set_grad_enabled(training)
    try:
        for i, (x, y) in loader:
            if device is not None:
                x, y = x.to(device), y.to(device)

            pred = model(x)
            loss = criterion(pred, y)

            if training:
                optim.zero_grad()
                loss.backward()
                optim.step()

            stats.update('loss', loss)

            loader.set_description("loss {:.4f} ({:.4f})".format(stats['loss'].val.item(), stats['loss'].smooth.item()))
    finally:
        # Runs on interruption as well (e.g. KeyboardInterrupt): close the
        # progress bar and re-enable autograd so an aborted evaluation pass
        # does not leave gradients switched off for the rest of the kernel.
        loader.close()
        torch.set_grad_enabled(True)

    return stats

In [12]:
# One training pass followed by one validation pass per epoch.
for e in range(1):
    train_stats = iterate('train', model, trainloader)
    val_stats = iterate('val', model, valloader)

    # Average loss of each pass; epochs are reported starting from 1.
    train_loss = train_stats['loss'].avg.item()
    val_loss = val_stats['loss'].avg.item()
    summary = 'Epoch {} results: Train: {:.4f} Val: {:.4f}'.format(args.epoch + 1, train_loss, val_loss)
    print(summary)

    args.epoch += 1

A Jupyter Widget

  self.val = torch.tensor(val).float()


KeyboardInterrupt: 