In [10]:
import numpy as np
import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from model import *

seed = 0 
seed_everything(seed)

In [11]:
data_num = 100000

In [12]:
loader_train,loader_validate = make_loader_train(data_num)

In [13]:
model = MLP()
model.cuda()
model

MLP(
  (layers): Sequential(
    (0): Linear(in_features=3, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1000, out_features=1, bias=True)
  )
)

In [14]:
optimizer = optim.Adam(model.parameters(),lr=1e-3)
criterion = nn.MSELoss()
scheduler = ReduceLROnPlateau(optimizer,min_lr=1e-5,patience=100)
early_stopping = EarlyStopping(patience=500, verbose=False)

In [15]:
def loop_train(loader_train,loader_validate) :
    
    losses_train = []
    for x,y in loader_train :
        x = x.cuda(); y = y.cuda()
        
        optimizer.zero_grad()
        y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses_train.append( loss.item() )
        
        loss.backward()
        optimizer.step()
        
    losses_validate = []
    for x,y in loader_train :
        x = x.cuda(); y = y.cuda()
        
        with torch.no_grad() :
            y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses_validate.append( loss.item() )
        
    lt = np.mean(losses_train)
    lv = np.mean(losses_validate)
        
    scheduler.step(lv)
    
    return (lt,lv)

def loop_test(loader) :
    
    losses = []
    for x,y in loader :
        x = x.cuda(); y = y.cuda()
        
        with torch.no_grad() :
            y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses.append( loss.item() )
            
    return np.mean(losses)

In [16]:
total_epoch = 100000
outer_unit = np.maximum(data_num//10,1)inner_unit = total_epoch//outer_unit
for outer_loop in range(outer_unit):  
    
    losses_train = []; losses_validate = []
    for inner_loop in range(inner_unit) :
        epoch = inner_unit*outer_loop + inner_loop
        loss_train,loss_validate = loop_train(loader_train,loader_validate)
        losses_train.append(loss_train)
        losses_validate.append(loss_validate)
        
        early_stopping(loss_validate.item(), model)
        if early_stopping.early_stop :
            break 
    
    if early_stopping.early_stop :
        print("Early stopping")
        break 
        
    lr = optimizer.param_groups[0]['lr']
    lt = np.mean(losses_train)
    lv = np.mean(losses_validate)
    print(f'{epoch:5d} {lr:5.2e} {lt:8.5e} {lv:8.5e}')

   99 1.00e-03 1.91474e-05 2.16792e-06
  199 1.00e-03 1.93902e-06 1.42639e-06
  299 1.00e-03 1.81359e-06 1.36208e-06
  399 1.00e-04 1.42263e-06 1.26793e-06
  499 1.00e-04 1.24566e-06 1.21816e-06
  599 1.00e-04 1.24037e-06 1.21338e-06
  699 1.00e-04 1.23715e-06 1.21081e-06
  799 1.00e-05 1.19112e-06 1.18971e-06
  899 1.00e-05 1.17978e-06 1.18401e-06
  999 1.00e-05 1.17950e-06 1.18368e-06
 1099 1.00e-05 1.17923e-06 1.18336e-06
 1199 1.00e-05 1.17889e-06 1.18271e-06
 1299 1.00e-05 1.17863e-06 1.18243e-06
 1399 1.00e-05 1.17832e-06 1.18206e-06
 1499 1.00e-05 1.17811e-06 1.18153e-06
 1599 1.00e-05 1.17787e-06 1.18164e-06
 1699 1.00e-05 1.17771e-06 1.18119e-06
 1799 1.00e-05 1.17750e-06 1.18082e-06
 1899 1.00e-05 1.17726e-06 1.18032e-06
 1999 1.00e-05 1.17708e-06 1.18066e-06
 2099 1.00e-05 1.17684e-06 1.18050e-06
 2199 1.00e-05 1.17673e-06 1.17983e-06
 2299 1.00e-05 1.17643e-06 1.17964e-06
 2399 1.00e-05 1.17619e-06 1.17936e-06
 2499 1.00e-05 1.17590e-06 1.17942e-06
 2599 1.00e-05 1.17590e-0

KeyboardInterrupt: 

In [None]:
model.load_state_dict( torch.load('tmp/checkpoint.pt') )

In [None]:
net_dir = 'net'
if not os.path.exists(net_dir) :
    os.mkdir(net_dir)
torch.save(model.state_dict(),f'net/bs_net_{data_num:.0e}.pt')