In [1]:
import numpy as np
import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from model import *

seed = 0 
seed_everything(seed)

In [2]:
data_num = 10000

In [3]:
loader_train,loader_validate = make_loader_train(data_num)

In [4]:
model = MLP()
model.cuda()
model

MLP(
  (layers): Sequential(
    (0): Linear(in_features=3, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1000, out_features=1, bias=True)
  )
)

In [5]:
optimizer = optim.Adam(model.parameters(),lr=1e-3)
criterion = nn.MSELoss()
scheduler = ReduceLROnPlateau(optimizer,min_lr=1e-5,patience=100)
early_stopping = EarlyStopping(patience=500, verbose=False)

In [6]:
def loop_train(loader_train,loader_validate) :
    
    losses_train = []
    for x,y in loader_train :
        x = x.cuda(); y = y.cuda()
        
        optimizer.zero_grad()
        y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses_train.append( loss.item() )
        
        loss.backward()
        optimizer.step()
        
    losses_validate = []
    for x,y in loader_train :
        x = x.cuda(); y = y.cuda()
        
        with torch.no_grad() :
            y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses_validate.append( loss.item() )
        
    lt = np.mean(losses_train)
    lv = np.mean(losses_validate)
        
    scheduler.step(lv)
    
    return (lt,lv)

def loop_test(loader) :
    
    losses = []
    for x,y in loader :
        x = x.cuda(); y = y.cuda()
        
        with torch.no_grad() :
            y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses.append( loss.item() )
            
    return np.mean(losses)

In [8]:
%%time

total_epoch = 100000
outer_unit = np.maximum(data_num//10,1) 
inner_unit = total_epoch//outer_unit 

for outer_loop in range(outer_unit):  
    
    losses_train = []; losses_validate = []
    for inner_loop in range(inner_unit) :
        epoch = inner_unit*outer_loop + inner_loop
        loss_train,loss_validate = loop_train(loader_train,loader_validate)
        losses_train.append(loss_train)
        losses_validate.append(loss_validate)
        
        early_stopping(loss_validate.item(), model)
        if early_stopping.early_stop :
            break 
    
    if early_stopping.early_stop :
        print("Early stopping")
        break 
        
    lr = optimizer.param_groups[0]['lr']
    lt = np.mean(losses_train)
    lv = np.mean(losses_validate)
    print(f'{epoch:5d} {lr:5.2e} {lt:8.5e} {lv:8.5e}')

   99 1.00e-04 9.57234e-07 9.95441e-07
  199 1.00e-04 9.50845e-07 9.90286e-07
  299 1.00e-04 9.44947e-07 9.88168e-07
  399 1.00e-05 8.98357e-07 9.15707e-07
  499 1.00e-05 8.69331e-07 8.72549e-07
  599 1.00e-05 8.65981e-07 8.69321e-07
  699 1.00e-05 8.63472e-07 8.66623e-07
  799 1.00e-05 8.61413e-07 8.64952e-07
  899 1.00e-05 8.59559e-07 8.62566e-07
  999 1.00e-05 8.57910e-07 8.60933e-07
 1099 1.00e-05 8.56255e-07 8.59582e-07
 1199 1.00e-05 8.55096e-07 8.58242e-07
 1299 1.00e-05 8.53681e-07 8.57017e-07
 1399 1.00e-05 8.52528e-07 8.55390e-07
 1499 1.00e-05 8.51522e-07 8.54750e-07
 1599 1.00e-05 8.50330e-07 8.53573e-07
 1699 1.00e-05 8.49474e-07 8.52205e-07
 1799 1.00e-05 8.48571e-07 8.51793e-07
 1899 1.00e-05 8.47283e-07 8.50376e-07
 1999 1.00e-05 8.46292e-07 8.49825e-07
 2099 1.00e-05 8.45255e-07 8.48812e-07
 2199 1.00e-05 8.44144e-07 8.47973e-07
 2299 1.00e-05 8.43351e-07 8.46770e-07
 2399 1.00e-05 8.42523e-07 8.46425e-07
 2499 1.00e-05 8.41741e-07 8.45517e-07
 2599 1.00e-05 8.40797e-0

21199 1.00e-05 7.31809e-07 7.42163e-07
21299 1.00e-05 7.31299e-07 7.41543e-07
21399 1.00e-05 7.30802e-07 7.41491e-07
21499 1.00e-05 7.30305e-07 7.40997e-07
21599 1.00e-05 7.29896e-07 7.40211e-07
21699 1.00e-05 7.29172e-07 7.39758e-07
21799 1.00e-05 7.28780e-07 7.39340e-07
21899 1.00e-05 7.28285e-07 7.38816e-07
21999 1.00e-05 7.28054e-07 7.37912e-07
22099 1.00e-05 7.27325e-07 7.37564e-07
22199 1.00e-05 7.26813e-07 7.37302e-07
22299 1.00e-05 7.26329e-07 7.36499e-07
22399 1.00e-05 7.25909e-07 7.36528e-07
22499 1.00e-05 7.25348e-07 7.35991e-07
22599 1.00e-05 7.24855e-07 7.35539e-07
22699 1.00e-05 7.24255e-07 7.34827e-07
22799 1.00e-05 7.23759e-07 7.34611e-07
22899 1.00e-05 7.23347e-07 7.34113e-07
22999 1.00e-05 7.22924e-07 7.33737e-07
23099 1.00e-05 7.22312e-07 7.33391e-07
23199 1.00e-05 7.21845e-07 7.32767e-07
23299 1.00e-05 7.21558e-07 7.32299e-07
23399 1.00e-05 7.20925e-07 7.31560e-07
23499 1.00e-05 7.20574e-07 7.31612e-07
23599 1.00e-05 7.20164e-07 7.31207e-07
23699 1.00e-05 7.19588e-0

42299 1.00e-05 6.36506e-07 6.48488e-07
42399 1.00e-05 6.36147e-07 6.47958e-07
42499 1.00e-05 6.35694e-07 6.47088e-07
42599 1.00e-05 6.34965e-07 6.46537e-07
42699 1.00e-05 6.34681e-07 6.46472e-07
42799 1.00e-05 6.34488e-07 6.45490e-07
42899 1.00e-05 6.33842e-07 6.44998e-07
42999 1.00e-05 6.33470e-07 6.45004e-07
43099 1.00e-05 6.33047e-07 6.44623e-07
43199 1.00e-05 6.32644e-07 6.43932e-07
43299 1.00e-05 6.32354e-07 6.44148e-07
43399 1.00e-05 6.32054e-07 6.43797e-07
43499 1.00e-05 6.31564e-07 6.42809e-07
43599 1.00e-05 6.31034e-07 6.42633e-07
43699 1.00e-05 6.30871e-07 6.42076e-07
43799 1.00e-05 6.30465e-07 6.41550e-07
43899 1.00e-05 6.30054e-07 6.41243e-07
43999 1.00e-05 6.29817e-07 6.41538e-07
44099 1.00e-05 6.29326e-07 6.40778e-07
44199 1.00e-05 6.29030e-07 6.39969e-07
44299 1.00e-05 6.28710e-07 6.40087e-07
44399 1.00e-05 6.28338e-07 6.39537e-07
44499 1.00e-05 6.28008e-07 6.39011e-07
44599 1.00e-05 6.27602e-07 6.38727e-07
44699 1.00e-05 6.27170e-07 6.38068e-07
44799 1.00e-05 6.26779e-0

63399 1.00e-05 5.61327e-07 5.74086e-07
63499 1.00e-05 5.60939e-07 5.73815e-07
63599 1.00e-05 5.60740e-07 5.73470e-07
63699 1.00e-05 5.60459e-07 5.73666e-07
63799 1.00e-05 5.60080e-07 5.72699e-07
63899 1.00e-05 5.59774e-07 5.72737e-07
63999 1.00e-05 5.59375e-07 5.72860e-07
64099 1.00e-05 5.59183e-07 5.72344e-07
64199 1.00e-05 5.58817e-07 5.71983e-07
64299 1.00e-05 5.58225e-07 5.71110e-07
64399 1.00e-05 5.58050e-07 5.70022e-07
64499 1.00e-05 5.57831e-07 5.69814e-07
64599 1.00e-05 5.57359e-07 5.69912e-07
64699 1.00e-05 5.56955e-07 5.70312e-07
64799 1.00e-05 5.56603e-07 5.70311e-07
Early stopping
CPU times: user 3h 33min 48s, sys: 30min 8s, total: 4h 3min 57s
Wall time: 4h 5min 6s


In [9]:
model.load_state_dict( torch.load('tmp/checkpoint.pt') )

<All keys matched successfully>

In [10]:
net_dir = 'net'
if not os.path.exists(net_dir) :
    os.mkdir(net_dir)
torch.save(model.state_dict(),f'net/bs_net_{data_num:.0e}.pt')